# -*- coding: utf-8 -*-
"""
# @file name : train_lenet.py
# @author : tingsongyu
# @date : 2019-09-07 10:08:00
# @brief : 人民币分类模型训练
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
def set_seed(seed=1):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
set_seed() # 设置随机种子
rmb_label = {"1": 0, "100": 1}
# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1
# ============================ step 1/5 数据 ============================
split_dir = os.path.join("..", "..", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
# ============================ step 2/5 模型 ============================
net = LeNet(classes=2)
net.initialize_weights()
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss() # 选择损失函数
# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 设置学习率下降策略
# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()
for epoch in range(MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
net.train()
for i, data in enumerate(train_loader):
# forward
inputs, labels = data
outputs = net(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
# 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().sum().numpy()
# 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0.
scheduler.step() # 更新学习率
# validate the model
if (epoch+1) % val_interval == 0:
correct_val = 0.
total_val = 0.
loss_val = 0.
net.eval()
with torch.no_grad():
for j, data in enumerate(valid_loader):
inputs, labels = data
outputs = net(inputs)
loss = criterion(outputs, labels)
_, predicted = torch.max(outputs.data, 1)
total_val += labels.size(0)
correct_val += (predicted == labels).squeeze().sum().numpy()
loss_val += loss.item()
valid_curve.append(loss_val/valid_loader.__len__())
print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct_val / total_val))
train_x = range(len(train_curve))
train_y = train_curve
train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve
plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
# ============================ inference ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")
test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)
for i, data in enumerate(valid_loader):
# forward
inputs, labels = data
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
rmb = 1 if predicted.numpy()[0] == 0 else 100
print("模型获得{}元".format(rmb))
没有合适的资源?快使用搜索试试~ 我知道了~
transforms.zip
共410个文件
jpg:401个
py:4个
ds_store:3个
需积分: 32 1 下载量 125 浏览量
2020-08-09
22:27:17
上传
评论
收藏 124.32MB ZIP 举报
温馨提示
简单粗暴PyTorch之transforms详解中的代码与数据集
资源推荐
资源详情
资源评论
收起资源包目录
transforms.zip (410个子文件)
.DS_Store 6KB
.DS_Store 6KB
.DS_Store 6KB
0GPYRDQM.jpg 687KB
0GPYRDQM.jpg 687KB
0MOHTNXQ.jpg 664KB
0MOHTNXQ.jpg 664KB
0KYWGVO5.jpg 633KB
0KYWGVO5.jpg 633KB
09F2SGOT.jpg 624KB
09F2SGOT.jpg 624KB
073LW92O.jpg 621KB
073LW92O.jpg 621KB
0E5Q62TM.jpg 613KB
0E5Q62TM.jpg 613KB
0RPZ5WDL.jpg 610KB
0RPZ5WDL.jpg 610KB
01LNYXO4.jpg 610KB
01LNYXO4.jpg 610KB
0LWI5TZA.jpg 607KB
0LWI5TZA.jpg 607KB
08C3EHPG.jpg 606KB
08C3EHPG.jpg 606KB
0E6AGCOW.jpg 601KB
0E6AGCOW.jpg 601KB
0KOIAHWT.jpg 600KB
0KOIAHWT.jpg 600KB
0GRZFSDG.jpg 594KB
0GRZFSDG.jpg 594KB
0MLDWG4I.jpg 591KB
0MLDWG4I.jpg 591KB
0OFE6MSI.jpg 586KB
0OFE6MSI.jpg 586KB
027AXFQE.jpg 583KB
027AXFQE.jpg 583KB
02OE5LH4.jpg 582KB
02OE5LH4.jpg 582KB
013MNV9B.jpg 579KB
013MNV9B.jpg 579KB
0ICF2DMA.jpg 578KB
0ICF2DMA.jpg 578KB
0HBEG1TG.jpg 577KB
0HBEG1TG.jpg 577KB
0NAQUMVX.jpg 573KB
0NAQUMVX.jpg 573KB
03WGM2XG.jpg 570KB
03WGM2XG.jpg 570KB
0MEG4GXO.jpg 568KB
0MEG4GXO.jpg 568KB
0FY3IOKC.jpg 566KB
0FY3IOKC.jpg 566KB
0NVLGX81.jpg 562KB
0NVLGX81.jpg 562KB
07IUEGQX.jpg 562KB
07IUEGQX.jpg 562KB
0IPXU5A9.jpg 561KB
0IPXU5A9.jpg 561KB
07UHGSGR.jpg 553KB
07UHGSGR.jpg 553KB
0PQXSWVG.jpg 552KB
0PQXSWVG.jpg 552KB
0F9X81GD.jpg 551KB
0F9X81GD.jpg 551KB
01GUGTQ4.jpg 551KB
01GUGTQ4.jpg 551KB
06DCY1X7.jpg 550KB
06DCY1X7.jpg 550KB
0GE1UZT5.jpg 547KB
0GE1UZT5.jpg 547KB
08KCVAP1.jpg 545KB
08KCVAP1.jpg 545KB
0R6X4SO8.jpg 533KB
0R6X4SO8.jpg 533KB
0I376P29.jpg 527KB
0I376P29.jpg 527KB
0OWAK5B7.jpg 510KB
0OWAK5B7.jpg 510KB
0GHKAWQX.jpg 505KB
0GHKAWQX.jpg 505KB
04VA2NX7.jpg 504KB
04VA2NX7.jpg 504KB
05IDEW2M.jpg 503KB
05IDEW2M.jpg 503KB
0WV65B8Z.jpg 501KB
0WV65B8Z.jpg 501KB
0KS8UVFH.jpg 499KB
0KS8UVFH.jpg 499KB
0P1HGRT2.jpg 493KB
0P1HGRT2.jpg 493KB
0R2P4H1I.jpg 491KB
0R2P4H1I.jpg 491KB
0ON7E9RU.jpg 488KB
0ON7E9RU.jpg 488KB
0BOVSMYN.jpg 485KB
0BOVSMYN.jpg 485KB
0RBDE8G9.jpg 485KB
0RBDE8G9.jpg 485KB
06NEIRC4.jpg 480KB
06NEIRC4.jpg 480KB
0MN9158I.jpg 477KB
共 410 条
- 1
- 2
- 3
- 4
- 5
资源评论
刘大鸭
- 粉丝: 30
- 资源: 2
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功