import torch
import sys
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# 检查设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# 1. 设置数据集的路径
data_dir = "D:\py_work\mnist_ai_bp\AI_test_mnist_demo_BP" # 替换为您的实际路径
# 2. 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # 使用MNIST的均值和标准差
])
# 3. 加载数据集
trainset = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False)
# 可视化部分训练数据
examples = enumerate(trainloader)
batch_idx, (example_data, example_targets) = next(examples)
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title(f"标签: {example_targets[i].item()}")
plt.xticks([])
plt.yticks([])
plt.show()
# 4. 定义神经网络模型
class BPnetwork(nn.Module):
def __init__(self):
super(BPnetwork, self).__init__()
self.linear1 = nn.Linear(28 * 28, 128)
self.ReLU1 = nn.ReLU()
self.linear2 = nn.Linear(128, 64)
self.ReLU2 = nn.ReLU()
self.linear3 = nn.Linear(64, 10)
self.log_softmax = nn.LogSoftmax(dim=1)
def forward(self, x):
x = x.view(x.size(0), -1) # 将输入数据展平成一维
x = self.linear1(x)
x = self.ReLU1(x)
x = self.linear2(x)
x = self.ReLU2(x)
x = self.linear3(x)
x = self.log_softmax(x)
return x
# 5. 创建模型、损失函数和优化器
model = BPnetwork().to(device)
criterion = nn.NLLLoss() # 负对数似然损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用Adam优化器
# 6. 模型训练
epochs = 10
for epoch in range(epochs):
model.train() # 设置模型为训练模式
running_loss = 0.0
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad() # 清空梯度
output = model(images) # 前向传播
loss = criterion(output, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新权重
running_loss += loss.item()
average_loss = running_loss / len(trainloader)
print(f"Epoch {epoch+1}, Loss: {average_loss:.4f}")
# 在每个epoch结束后,评估模型在测试集上的性能
model.eval() # 设置模型为评估模式
correct = 0
total = 0
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(device), labels.to(device)
output = model(images)
_, predicted = torch.max(output, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"测试准确率: {accuracy:.2f}%")
# 7. 保存模型
torch.save(model.state_dict(), 'mnist_bpnetwork.pth')
print("模型已保存为 mnist_bpnetwork.pth")
# 8. 可视化部分测试结果
model.eval()
examples = enumerate(testloader)
batch_idx, (example_data, example_targets) = next(examples)
example_data, example_targets = example_data.to(device), example_targets.to(device)
with torch.no_grad():
output = model(example_data)
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i].cpu()[0], cmap='gray', interpolation='none')
pred = output.data.max(1, keepdim=True)[1][i].item()
plt.title(f"预测: {pred}")
plt.xticks([])
plt.yticks([])
plt.show()
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
AI_test_mnist_demo_BP_win.zip (54个子文件)
AI_test_mnist_demo_BP
new_python_demo_minst_model1.py 4KB
MNIST
python_display_image.py 755B
train-images.idx3-ubyte 44.86MB
display_mnist_data_iamge.py 1KB
t10k-images.idx3-ubyte 7.48MB
display_mnist_data_iamge2.py 1KB
train-labels.idx1-ubyte 59KB
t10k-labels.idx1-ubyte 10KB
display_mnist_data_iamge1.py 2KB
raw
t10k-images-idx3-ubyte.gz 1.57MB
train-images-idx3-ubyte 44.86MB
t10k-images-idx3-ubyte 7.48MB
train-labels-idx1-ubyte.gz 28KB
t10k-labels-idx1-ubyte 10KB
train-images-idx3-ubyte.gz 9.45MB
t10k-labels-idx1-ubyte.gz 4KB
train-labels-idx1-ubyte 59KB
mnist_bpnetwork.pth 430KB
test_image2
test_iamge_ai
test_num1.jpg 89KB
testNum6.jpg 1KB
test_num2.jpg 97KB
testNum10.jpg 4KB
testNum2.jpg 1KB
testNum1.jpg 1KB
testNum5.jpg 1KB
test_num111.jpg 77KB
test_num7.jpg 86KB
test_num12.jpg 98KB
mnist_data_num1.jpg 9KB
test_num4.jpg 102KB
testNum7.jpg 987B
test_num114.jpg 60KB
ai_test_image1.jpg 1KB
testNum0.jpg 782B
test_num6.jpg 93KB
test_num113.jpg 69KB
testNum3.jpg 1KB
test_num5.jpg 98KB
mnist_data_num2.jpg 8KB
test_num11.jpg 102KB
ai_test_image4.jpg 1KB
testNum4.jpg 1KB
ai_test_image3.jpg 1KB
test_num3.jpg 87KB
ai_test_image2.jpg 1KB
testNum8.jpg 1KB
test_num112.jpg 70KB
testNum9.jpg 1KB
mnist_bpnetwork.pth 430KB
testNum7.jpg 987B
python_ai_test_image_dif_size.py 2KB
python_ai_test_image.py 2KB
.ai_test_tool_num_check.py.swp 12KB
test_iamge_ai.zip 12KB
共 54 条
- 1
资源评论
好奇龙猫
- 粉丝: 3w+
- 资源: 131
下载权益
C知道特权
VIP文章
课程特权
开通VIP
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功