import torch
import sys
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# 检查设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# 1. 设置数据集的路径
data_dir = "D:\py_work\mnist_ai_bp\AI_test_mnist_demo_BP" # 替换为您的实际路径
# 2. 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # 使用MNIST的均值和标准差
])
# 3. 加载数据集
trainset = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False)
# 可视化部分训练数据
examples = enumerate(trainloader)
batch_idx, (example_data, example_targets) = next(examples)
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title(f"标签: {example_targets[i].item()}")
plt.xticks([])
plt.yticks([])
plt.show()
# 4. 定义神经网络模型
class BPnetwork(nn.Module):
def __init__(self):
super(BPnetwork, self).__init__()
self.linear1 = nn.Linear(28 * 28, 128)
self.ReLU1 = nn.ReLU()
self.linear2 = nn.Linear(128, 64)
self.ReLU2 = nn.ReLU()
self.linear3 = nn.Linear(64, 10)
self.log_softmax = nn.LogSoftmax(dim=1)
def forward(self, x):
x = x.view(x.size(0), -1) # 将输入数据展平成一维
x = self.linear1(x)
x = self.ReLU1(x)
x = self.linear2(x)
x = self.ReLU2(x)
x = self.linear3(x)
x = self.log_softmax(x)
return x
# 5. 创建模型、损失函数和优化器
model = BPnetwork().to(device)
criterion = nn.NLLLoss() # 负对数似然损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用Adam优化器
# 6. 模型训练
epochs = 10
for epoch in range(epochs):
model.train() # 设置模型为训练模式
running_loss = 0.0
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad() # 清空梯度
output = model(images) # 前向传播
loss = criterion(output, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新权重
running_loss += loss.item()
average_loss = running_loss / len(trainloader)
print(f"Epoch {epoch+1}, Loss: {average_loss:.4f}")
# 在每个epoch结束后,评估模型在测试集上的性能
model.eval() # 设置模型为评估模式
correct = 0
total = 0
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(device), labels.to(device)
output = model(images)
_, predicted = torch.max(output, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"测试准确率: {accuracy:.2f}%")
# 7. 保存模型
torch.save(model.state_dict(), 'mnist_bpnetwork.pth')
print("模型已保存为 mnist_bpnetwork.pth")
# 8. 可视化部分测试结果
model.eval()
examples = enumerate(testloader)
batch_idx, (example_data, example_targets) = next(examples)
example_data, example_targets = example_data.to(device), example_targets.to(device)
with torch.no_grad():
output = model(example_data)
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i].cpu()[0], cmap='gray', interpolation='none')
pred = output.data.max(1, keepdim=True)[1][i].item()
plt.title(f"预测: {pred}")
plt.xticks([])
plt.yticks([])
plt.show()
好奇龙猫
- 粉丝: 3w+
- 资源: 131
最新资源
- 三菱FX3U PLC与Factory IO通讯仿真PID液位调节程序 说到学习PLC 的PID ,要有硬件 测温度的PID设备有: 输入输出模拟量模块300左右X2(有些PLC自带)
- carsim+simulink联合仿真实现变道 包含路径规划 carsim+simulink联合仿真实现变道 包含路径规划算法+mpc轨迹跟踪算法 可选simulink版本和c++版本算法(二选一)
- carsim+simulink联合仿真实现变道 包含路径规划 carsim+simulink联合仿真实现变道 包含路径规划算法+mpc轨迹跟踪算法 可选simulink版本和c++版本算法 可以适用于
- mpc模型预测控制从原理到代码实现 mpc模型预测控制从原理到代码实现 mpc模型预测控制详细原理推导 matlab和c++两种编程实现 四个实际控制工程案例: 双积分控制系统 倒立摆控制系统 车辆运
- 1.中性点不接地系统的小电流接地故障及故障选线的MATLAB仿真,也可以改接地的 2.两个打包(中性点不接地与中性点经消弧线圈接地),一个(中性点不接地或中性点经消弧线圈接地) 4.选线方法的仿真
- 多智能体,神经网络,自适应动态滑模,有文献可以参考 符合要求请放心联系,simulink,复现,保证能够运行
- 该模型是内置式的MTPA控制,速度环的输出为给定转矩,然后方式1通过求解MTPA方程得到dq给定电流,方式2进行工程近似得到dq给定电流,并外和id=0控制进行比较
- 永磁同步电机的无传感器控制算法 基于永磁同步电机(PMSM)的改进的卡尔曼滤波速度观测器simulink模型;可与普通卡尔曼滤波进行比对,精度大大提高
- 基于ESO的永磁同步电机无感FOC 1.采用线性扩张状态观测器(LESO)估计电机反电势,利用锁相环从反电势中提取位置和转速信息,从而实现无位置传感器控制; 2.提供算法对应的参考文献和仿真模型 拿
- 电机过调制算法模型从线性调制区到过调制区,算法已在量产车中验证过 电子文件产品
- 交错并联buck 两重化交错并联buck电路,采用电压电流双闭环控制,电流采用平均电流采样,载波移相180°,减少了电流纹波,可以减少电感体积 仿真波形如图所示,当采用软启动时,0.3秒的时间输出
- 永磁同步电机风力发电系统仿真模型,包含变桨系统与传动系统,运行各项指标正确,可稳定发出有功功率,无功功率为0
- fpga MIL-STD1553B源码,支持BC ,BM,RT 可任意移植到xilinx,altera,actel全系列型号 功能和接口可参考actel芯片1553b核,纯源码
- carsim交通场景搭建,carsim与matlab,prescan联合仿真,巡航、路径规划及道控制算法,cpar文件输出及场景图生成
- MMC,模块化多电平变流器的MATLAB,Simulink仿真 11电平三相MMC逆变器并网仿真,调制方式选用载波移相调制 采用双闭环矢量控制,施加环流抑制控制和子模块电容电压均衡控制 直流侧采用
- 三段式电流保护Matlab编程 Simulink仿真 1. Matlab编程计算三段式电流保护的整定值,并进行灵敏度校验; 2.Simulink搭建仿真模型,对三段式电流保护模型进行仿真分析
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈