import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import DogCat
from torch.autograd import Variable
from efficientnet_pytorch import EfficientNet
#pip install efficientnet_pytorch
# 设置全局参数
modellr = 1e-4
BATCH_SIZE = 32
EPOCHS = 10
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
dataset_train = DogCat('data/train', transforms=transform, train=True)
dataset_test = DogCat("data/train", transforms=transform_test, train=False)
# 读取数据
print(dataset_train.imgs)
# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
model_ft = EfficientNet.from_pretrained('efficientnet-b3')
num_ftrs = model_ft._fc.in_features
model_ft._fc = nn.Linear(num_ftrs, 2)
model_ft.to(DEVICE)
# 选择简单暴力的Adam优化器,学习率调低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
def adjust_learning_rate(optimizer, epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
modellrnew = modellr * (0.1 ** (epoch // 50))
print("lr:", modellrnew)
for param_group in optimizer.param_groups:
param_group['lr'] = modellrnew
# 定义训练过程
def train(model, device, train_loader, optimizer, epoch):
model.train()
sum_loss = 0
total_num = len(train_loader.dataset)
print(total_num, len(train_loader))
for batch_idx, (data, target) in enumerate(train_loader):
data, target = Variable(data).to(device), Variable(target).to(device)
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print_loss = loss.data.item()
sum_loss += print_loss
if (batch_idx + 1) % 50 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
100. * (batch_idx + 1) / len(train_loader), loss.item()))
ave_loss = sum_loss / len(train_loader)
print('epoch:{},loss:{}'.format(epoch, ave_loss))
# 验证过程
def val(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
total_num = len(test_loader.dataset)
print(total_num, len(test_loader))
with torch.no_grad():
for data, target in test_loader:
data, target = Variable(data).to(device), Variable(target).to(device)
output = model(data)
loss = criterion(output, target)
_, pred = torch.max(output.data, 1)
correct += torch.sum(pred == target)
print_loss = loss.data.item()
test_loss += print_loss
correct = correct.data.item()
acc = correct / total_num
avgloss = test_loss / len(test_loader)
print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
avgloss, correct, len(test_loader.dataset), 100 * acc))
# 训练
for epoch in range(1, EPOCHS + 1):
adjust_learning_rate(optimizer, epoch)
train(model_ft, DEVICE, train_loader, optimizer, epoch)
val(model_ft, DEVICE, test_loader)
torch.save(model_ft, 'model.pth')
AI浩
- 粉丝: 15w+
- 资源: 232
最新资源
- 3b015大学生创业项目管理系统_springboot+vue0.zip
- x86-64架构下gmssl工具
- 3b016个性化课程推荐系统_springboot+vue.zip
- 电影订票及评论网站的设计与实现-springboot毕业项目,适合计算机毕-设、实训项目、大作业学习.zip
- 3b014宠物猫店管理系统_springboot+vue.zip
- 仓库管理系统pf-springboot毕业项目,适合计算机毕-设、实训项目、大作业学习.zip
- 洞见研报Geek+(智能物流机器人研发商,北京极智嘉科技股份有限公司)创投信息
- 付费问答系统的设计与实现-springboot毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 3b017旅游景区预约服务系统_springboot+vue0.zip
- 新版在线生成一合三网站缩微图工具PHP源码
- 基于BS的社区物业管理系统-springboot毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 基于Java的美妆购物网站的设计与实现-springboot毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 基于HTML语言的环保网站的设计与实现-springboot毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 德普微一级代理 DP040N04DTL TO-252 DPMOS N-MOSFET 40V 100A 3.2mΩ
- 3b019企业人事管理系统_springboot+vue.zip
- 3b018企业人力资源管理系统_springboot+vue.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
- 1
- 2
- 3
- 4
- 5
前往页