import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from nin import NiN
data_train = torchvision.datasets.FashionMNIST(
root="FashionMNIST", train=True,transform=transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor()
]), download=True)
data_test = torchvision.datasets.FashionMNIST(
root="FashionMNIST", train=False, transform=transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor()
]), download=True)
data_train_loader = DataLoader(data_train, batch_size=32, shuffle=True, num_workers=4)#数据加载器加载训练数据
data_test_loader = DataLoader(data_test, batch_size=16, num_workers=4)#数据加载器加载测试数据
model = NiN()
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)
# config
epochs = 12#迭代次数
lr = 0.0001#学习率
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
def train():
print('start training')
# 训练模型
for epoch in range(epochs):
model.train()#训练模式
epoch_loss = 0
epoch_accuracy = 0
for _, (data, label) in enumerate(data_train_loader):
data = data.to(device)
label = label.to(device)
output = model(data)#输出
loss = criterion(output, label)#计算loss
optimizer.zero_grad()#清空过往梯度(因为每次循环都是一次完整的训练)
loss.backward()#反向传播
optimizer.step()#更新参数
acc = (output.argmax(dim=1) == label).float().mean()
epoch_accuracy += acc / len(data_train_loader)#当前训练平均准确率
epoch_loss += loss / len(data_train_loader)#累计loss
print(f'EPOCH:{epoch:2}, train loss:{epoch_loss:.4f}, train acc:{epoch_accuracy:.4f}')
def test():
best_accuracy = 0
model.eval() #加与不加都行
total_correct = 0 #记录正确数目
avg_loss = 0.0 #记录平均错误
for _, (images, labels) in enumerate(data_test_loader):
images = images.to(device)
labels = labels.to(device)
output = model(images)
avg_loss += criterion(output, labels).sum() #将损失累加起来
pred = output.detach().max(1)[1] #max(1)得到每行最大值的第一个(得到概率最大的那个),.detach()指这个tensor永远不需要计算其梯度
total_correct += pred.eq(labels.view_as(pred)).sum() #累加与pred同类型的labels(即为正确)的数值,即记录正确分数(如果预测对了对应的位置就是1)
avg_loss /= len(data_test) #平均误差
if(float(total_correct) / len(data_test) > best_accuracy):
torch.save(model.cpu().state_dict(), 'model.pth')
best_accuracy = max(best_accuracy,float(total_correct) / len(data_test))
print('bestaccuracy is %f' % best_accuracy)
print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test))) #输出信息
def main(): #开始训练和测试
train()
test()
if __name__ == '__main__':
main()
没有合适的资源?快使用搜索试试~ 我知道了~
NiN.zip
共3个文件
py:2个
pyc:1个
1.该资源内容由用户上传,如若侵权请联系客服进行举报
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
版权申诉
0 下载量 66 浏览量
2023-08-20
15:18:30
上传
评论
收藏 3KB ZIP 举报
温馨提示
NiN.zip
资源推荐
资源详情
资源评论
收起资源包目录
NiN.zip (3个子文件)
NiN
nin_torch.py 3KB
nin.py 1KB
__pycache__
nin.cpython-38.pyc 1KB
共 3 条
- 1
资源评论
sjx_alo
- 粉丝: 1w+
- 资源: 1206
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功