import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variable
from torchvision.models import inception_v3
# Global hyper-parameters
modellr = 1e-4
BATCH_SIZE = 32
EPOCHS = 10
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Data preprocessing — Inception v3 expects 299x299 inputs.
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# NOTE(review): both datasets point at 'data/train'; presumably SeedlingData
# splits train/val internally via the train flag — confirm in dataset.dataset.
dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# Print the discovered image list as a sanity check.
print(dataset_train.imgs)
# Wrap the datasets in DataLoaders.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
# Instantiate the model and move it to the GPU (if available).
criterion = nn.CrossEntropyLoss()
model_ft = inception_v3(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Replace the final classifier with a 12-class head.
# NOTE(review): the auxiliary classifier head (model_ft.AuxLogits.fc) is NOT
# replaced; its output is discarded in train() below, so this is harmless here.
model_ft.fc = nn.Linear(num_ftrs, 12)
model_ft.to(DEVICE)
# Plain Adam optimizer with a low learning rate.
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
def adjust_learning_rate(optimizer, epoch, base_lr=None):
    """Set the learning rate to the initial LR decayed by 10 every 50 epochs.

    The original docstring claimed a decay every 30 epochs, but the code
    uses ``epoch // 50``; the documentation now matches the behavior.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: 1-based epoch index used to compute the decay step.
        base_lr: initial learning rate; defaults to the module-level
            ``modellr`` for backward compatibility with existing callers.
    """
    if base_lr is None:
        base_lr = modellr  # fall back to the script-level global
    modellrnew = base_lr * (0.1 ** (epoch // 50))
    print("lr:", modellrnew)
    for param_group in optimizer.param_groups:
        param_group['lr'] = modellrnew
# Training procedure
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch, logging progress every 10 batches.

    Args:
        model: network to train. Its forward may return either a plain
            logits tensor or a tuple like ``(logits, aux_logits)`` (e.g.
            Inception v3 in training mode); only the primary logits feed
            the loss, matching the original ``output, hid`` unpack.
        device: torch.device each batch is moved to.
        train_loader: DataLoader yielding (input, target) batches.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number (used for logging only).
    """
    # CrossEntropyLoss is stateless, so constructing it locally is
    # equivalent to the previous module-level global and keeps the
    # function self-contained.
    criterion = nn.CrossEntropyLoss()
    model.train()
    sum_loss = 0.0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        # Variable is deprecated — tensors are autograd-aware directly.
        data, target = data.to(device), target.to(device)
        output = model(data)
        if isinstance(output, tuple):
            # e.g. torchvision InceptionOutputs (a namedtuple): keep the
            # main logits, discard the auxiliary head's output.
            output = output[0]
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() replaces the deprecated .data.item() access.
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
# Validation procedure
def val(model, device, test_loader):
    """Evaluate the model on the validation set and print loss/accuracy.

    Args:
        model: network to evaluate (put into eval mode; forward must
            return a plain logits tensor, as Inception v3 does in eval).
        device: torch.device each batch is moved to.
        test_loader: DataLoader yielding (input, target) batches.

    Returns:
        float: overall accuracy in [0, 1]. (New, backward-compatible:
        the original returned None and no caller uses the return value.)
    """
    # Stateless criterion constructed locally — same behavior as the
    # previous module-level global, but self-contained.
    criterion = nn.CrossEntropyLoss()
    model.eval()
    test_loss = 0.0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            # Variable is deprecated — tensors work directly.
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output, 1)
            # .item() per batch replaces the deprecated .data access and
            # the final correct.data.item() conversion; same final int.
            correct += torch.sum(pred == target).item()
            test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, len(test_loader.dataset), 100 * acc))
    return acc
# Training loop: adjust LR, train one epoch, validate, then checkpoint.
for epoch in range(1, EPOCHS + 1):
    adjust_learning_rate(optimizer, epoch)
    train(model_ft, DEVICE, train_loader, optimizer, epoch)
    val(model_ft, DEVICE, test_loader)
# NOTE(review): this pickles the entire model object and overwrites
# 'model.pth' once after all epochs; saving model_ft.state_dict() is the
# recommended, more portable practice — confirm how downstream code loads
# the checkpoint before changing.
torch.save(model_ft, 'model.pth')
没有合适的资源?快使用搜索试试~ 我知道了~
GoogLeNet图像分类.rar
共2000个文件
jpg:5755个
py:6个
pyc:2个
1.该资源内容由用户上传,如若侵权请联系客服进行举报
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
版权申诉
5星 · 超过95%的资源 9 下载量 14 浏览量
2021-06-12
19:32:28
上传
评论 1
收藏 972.01MB RAR 举报
温馨提示
【图像分类】实战——使用GoogLeNet识别动漫https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/117852220?spm=1001.2014.3001.5501,里面有你模型和数据集。
资源推荐
资源详情
资源评论
收起资源包目录
GoogLeNet图像分类.rar (2000个子文件)
img_0800068.jpg 686KB
img_1700163.jpg 640KB
img_1700105.jpg 596KB
img_0700130.jpg 575KB
img_2400097.jpg 558KB
img_1600094.jpg 557KB
img_0700230.jpg 557KB
img_1700155.jpg 543KB
img_0200184.jpg 541KB
img_0300190.jpg 521KB
img_0700213.jpg 519KB
img_0700227.jpg 516KB
img_0300179.jpg 511KB
img_0800072.jpg 502KB
img_0800042.jpg 501KB
img_0700121.jpg 501KB
img_1500124.jpg 500KB
img_1900194.jpg 492KB
img_1500162.jpg 491KB
img_0200063.jpg 487KB
img_1900100.jpg 486KB
img_0200061.jpg 484KB
img_1500224.jpg 484KB
img_1600093.jpg 481KB
img_0800092.jpg 472KB
img_1700127.jpg 472KB
img_0500085.jpg 471KB
img_1500230.jpg 471KB
img_1800039.jpg 469KB
img_0900085.jpg 466KB
img_1500197.jpg 464KB
img_1900178.jpg 461KB
img_1500237.jpg 460KB
img_1800038.jpg 458KB
img_0700114.jpg 450KB
img_1700149.jpg 450KB
img_0600054.jpg 446KB
img_0800156.jpg 445KB
img_1500145.jpg 443KB
img_2300133.jpg 441KB
img_0700024.jpg 440KB
img_1700080.jpg 439KB
img_0200033.jpg 438KB
img_0700140.jpg 438KB
img_0100124.jpg 436KB
img_0400031.jpg 433KB
img_1500241.jpg 433KB
img_0300047.jpg 432KB
img_1700212.jpg 430KB
img_0700162.jpg 430KB
img_1500199.jpg 427KB
img_0600097.jpg 426KB
img_1500065.jpg 426KB
img_2300138.jpg 424KB
img_0200064.jpg 424KB
img_1500240.jpg 423KB
img_1500199.jpg 422KB
img_2300139.jpg 422KB
img_0700094.jpg 421KB
img_1700156.jpg 421KB
img_0200069.jpg 421KB
img_0200187.jpg 420KB
img_2400153.jpg 418KB
img_0700219.jpg 417KB
img_0300159.jpg 416KB
img_0600032.jpg 416KB
img_1500233.jpg 416KB
img_0100110.jpg 412KB
img_0300186.jpg 412KB
img_0200128.jpg 411KB
img_1800032.jpg 410KB
img_0500101.jpg 410KB
img_0800057.jpg 410KB
img_0900170.jpg 410KB
img_1600065.jpg 410KB
img_1500228.jpg 409KB
img_0700212.jpg 409KB
img_2500061.jpg 409KB
img_0800084.jpg 408KB
img_0500180.jpg 408KB
img_1700094.jpg 407KB
img_0800043.jpg 406KB
img_0300081.jpg 406KB
img_0200032.jpg 406KB
img_0300117.jpg 405KB
img_0700025.jpg 405KB
img_0700223.jpg 404KB
img_0200070.jpg 404KB
img_2400160.jpg 404KB
img_1500227.jpg 402KB
img_0200142.jpg 402KB
img_0700110.jpg 401KB
img_0600060.jpg 401KB
img_1600179.jpg 401KB
img_0800059.jpg 401KB
img_1800032.jpg 400KB
img_1500120.jpg 399KB
img_0600118.jpg 398KB
img_0300088.jpg 397KB
img_0900047.jpg 397KB
共 2000 条
- 1
- 2
- 3
- 4
- 5
- 6
- 20
AI浩
- 粉丝: 14w+
- 资源: 216
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
前往页