# 训练代码,训练集使用CIFAR100,初次训练会自动下载
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from net import AlexNet
# 训练参数
batch_size = 8
epochs = 10
lr = 0.01
gamma = 0.7
no_cuda = False
seed = 1
log_interval = 10
save_model = True
def train(model, device, train_loader, optimizer, epoch, log_interval):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch_idx % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
if save_model:
torch.save(model.state_dict(), "model.pth")
def main():
use_cuda = not no_cuda and torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR100('./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Resize(224),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])),
batch_size=batch_size, shuffle=True, **kwargs)
model = AlexNet(num_labels=100).to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
for epoch in range(1, epochs + 1):
train(model, device, train_loader, optimizer, epoch, log_interval)
scheduler.step()
if __name__ == '__main__':
main()
AlexNet卷积神经网络图像分类Pytorch训练代码 使用Cifar100数据集
版权申诉
5星 · 超过95%的资源 19 浏览量
2023-01-28
16:24:22
上传
评论 10
收藏 2KB ZIP 举报
两只程序猿
- 粉丝: 338
- 资源: 158
最新资源
- C语言基础-C语言编程基础之Leetcode编程题解之第39题组合总和.zip
- C语言基础-C语言编程基础之Leetcode编程题解之第38题外观数列.zip
- C语言基础-C语言编程基础之Leetcode编程题解之第37题解数独.zip
- C语言基础-C语言编程基础之Leetcode编程题解之第36题有效的数独.zip
- C语言基础-C语言编程基础之Leetcode编程题解之第35题搜索插入位置.zip
- index.wxml
- C语言基础-C语言编程基础之Leetcode编程题解之第33题搜索旋转排序数组.zip
- 基于Python实现的手写数字识别系统源码.zip
- 从网页提取禁止转载的文字
- C语言基础-C语言编程基础之Leetcode编程题解之第32题最长有效括号.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
- 1
- 2
- 3
- 4
前往页