import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
# 设置设备,如果有GPU则使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 数据预处理
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4), # 随机裁剪32x32的图像,外边补充4个像素
transforms.RandomHorizontalFlip(), # 随机水平翻转图像
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), # 标准化
])
transform_test = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), # 标准化
])
# 加载 CIFAR-100 数据集
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2) # 加载训练集
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) # 加载测试集
# 定义卷积神经网络
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
# 使用 nn.Sequential 定义网络结构
self.net = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1), # 第一个卷积层,输入3通道,输出64通道
nn.ReLU(), # ReLU 激活函数
nn.MaxPool2d(kernel_size=2, stride=2), # 最大池化层
nn.Conv2d(64, 128, kernel_size=3, padding=1), # 第二个卷积层,输入64通道,输出128通道
nn.ReLU(), # ReLU 激活函数
nn.MaxPool2d(kernel_size=2, stride=2), # 最大池化层
nn.Conv2d(128, 256, kernel_size=3, padding=1),# 第三个卷积层,输入128通道,输出256通道
nn.ReLU(), # ReLU 激活函数
nn.MaxPool2d(kernel_size=2, stride=2), # 最大池化层
nn.Flatten(), # 展平层,将多维度的张量展平为一维
nn.Linear(256*4*4, 512), # 全连接层,输入大小为256*4*4,输出大小为512
nn.ReLU(), # ReLU 激活函数
nn.Linear(512, 100) # 输出层,100 个类别
)
def forward(self, x):
return self.net(x)
# 训练和验证模型
def main():
# 实例化网络,并将其移动到GPU(如果可用)
net = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001) # Adam优化器,学习率0.001
num_epochs = 25 # 训练的轮数
best_acc = 0.0 # 用于保存最佳模型的准确率
train_losses = [] # 用于保存训练损失
test_losses = [] # 用于保存测试损失
train_accuracies = [] # 用于保存训练准确率
test_accuracies = [] # 用于保存测试准确率
# 开始训练
for epoch in range(num_epochs):
net.train() # 设置为训练模式
running_loss = 0.0
correct = 0
total = 0
# 进度条显示训练进度
train_bar = tqdm(trainloader, desc=f'Training Epoch {epoch+1}/{num_epochs}')
for inputs, labels in train_bar:
inputs, labels = inputs.to(device), labels.to(device) # 将输入和标签移动到GPU
optimizer.zero_grad() # 梯度清零
outputs = net(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
running_loss += loss.item()
_, predicted = outputs.max(1) # 获取预测的类别
total += labels.size(0)
correct += predicted.eq(labels).sum().item() # 计算准确率
# 更新进度条
train_bar.set_postfix(loss=running_loss/(len(train_bar)*trainloader.batch_size), acc=100.*correct/total)
# 记录训练损失和准确率
train_loss = running_loss / len(trainloader)
train_acc = 100. * correct / total
train_losses.append(train_loss)
train_accuracies.append(train_acc)
# 验证模型
net.eval() # 设置为评估模式
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
# 进度条显示验证进度
test_bar = tqdm(testloader, desc=f'Validating Epoch {epoch+1}/{num_epochs}')
for inputs, labels in test_bar:
inputs, labels = inputs.to(device), labels.to(device) # 将输入和标签移动到GPU
outputs = net(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失
test_loss += loss.item()
_, predicted = outputs.max(1) # 获取预测的类别
total += labels.size(0)
correct += predicted.eq(labels).sum().item() # 计算准确率
# 更新进度条
test_bar.set_postfix(loss=test_loss/(len(test_bar)*testloader.batch_size), acc=100.*correct/total)
# 记录测试损失和准确率
test_loss = test_loss / len(testloader)
test_acc = 100. * correct / total
test_losses.append(test_loss)
test_accuracies.append(test_acc)
# 保存验证集上表现最好的模型
if test_acc > best_acc:
best_acc = test_acc
torch.save(net.state_dict(), 'best_model_cifar100.pth')
# 绘制损失和准确率曲线
plt.figure(figsize=(12, 5))
# 绘制训练和验证损失曲线
plt.subplot(1, 2, 1)
plt.plot(range(num_epochs), train_losses, label='Training Loss')
plt.plot(range(num_epochs), test_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
# 绘制训练和验证准确率曲线
plt.subplot(1, 2, 2)
plt.plot(range(num_epochs), train_accuracies, label='Training Accuracy')
plt.plot(range(num_epochs), test_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
# 保存图像
plt.savefig('loss_accuracy_curves_cifar100.png')
plt.show()
if __name__ == "__main__":
main() # 调用主函数,开始训练和验证模型
Seraphina_Lily
- 粉丝: 1379
- 资源: 11
最新资源
- (178548844)zotero文献阅读以及主题和翻译插件
- (179839044)64402-MySQL数据库基础实例教程(第3版)(微课版)-源代码(含例题、案例、实训、实战四个项目).zip.zip
- 基于微信小程序的户外旅游小程序.zip
- 双摇臂履带底盘sw16可编辑全套技术开发资料100%好用.zip
- 国外某地气温数据(extend:2011-2016年).zip
- (18695238)libsvm文档
- 数据分析-51-小红书达人画像
- 基于微信小程序的华云智慧园区(包括数据库,源码).zip
- 步进电机驱动 C#上位机和STM32下位机源程序 步数方向控制
- 船上用品检测12-YOLO(v5至v11)、COCO、CreateML、Paligemma、TFRecord、VOC数据集合集.rar
- 非常好用 的,局域网,文件共享,文档管理 工作,方便检索文件 ,支持HTTP服务
- 手机组装自动镭焊机step全套技术开发资料100%好用.zip
- java项目,毕业设计-基于协同过滤算法商品推荐系统
- 大三-一个简单的安卓移动开发课程设计Android Studio
- 数据分析-53-「猛男的童年回忆」三大类型玩具在京东平台的销售分析
- C# TouchSocket的基础使用,连接,发送,接收WPF
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈