import os
import time
import copy
import numpy as np
import matplotlib as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms
# define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load data and do data augmention
path = 'data/'
mode = ('train', 'val')
transform = {
'train':transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
]),
'val':transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
}
kwargs = {'num_workers':4, 'pin_memory':True}
image_datasets = {x: datasets.ImageFolder(root=os.path.join(path, x), transform = transform[x])
for x in mode}
data_loaders = {x: DataLoader(image_datasets[x], batch_size=4, shuffle=True, **kwargs)
for x in mode}
class_names = image_datasets['train'].classes
dataset_size = {x: len(image_datasets[x]) for x in mode}
# define my net and criterion optimizer
my_resnet18 = torchvision.models.resnet18(pretrained=True)
num_features = my_resnet18.fc.in_features
my_resnet18.fc = nn.Linear(512, 2)
#my_resnet18 = nn.DataParallel(my_resnet18)
my_resnet18 = my_resnet18.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(my_resnet18.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
print(my_resnet18)
#import ipdb; ipdb.set_trace()
# train
def train_model(epochs=25):
best_model_wts = copy.deepcopy(my_resnet18.state_dict())
best_acc = 0.
for epoch in range(epochs):
# in each epoch
#epoch_start = time.time()
print('Epoch {}/{}'.format(epoch, epochs-1))
print('-'*10)
# iterate on the whole data training set
for phase in mode:
running_loss = 0.
running_corrects = 0
if phase == 'train':
exp_lr_scheduler.step()
my_resnet18.train()
else :
my_resnet18.eval()
# in each epoch iterate over all dataset
for inputs, labels in data_loaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
with torch.set_grad_enabled(phase == 'train'):
# in each iter step
# 1. zero the parameter gradients
optimizer.zero_grad()
# 2. forward
outputs = my_resnet18(inputs)
# 3. compute loss and backward and update parameters
loss = criterion(outputs, labels)
if phase == 'train':
loss.backward()
optimizer.step()
# statistics
preds = outputs.max(1)[1]
running_loss += loss.item()*inputs.size(0)
running_corrects += torch.sum(preds == labels)
epoch_loss = running_loss/dataset_size[phase]
epoch_acc = running_corrects.double()/dataset_size[phase]
print('%s Loss: %.4f ACC: %.4f'%(phase, epoch_loss, epoch_acc))
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(my_resnet18.state_dict())
print()
# load best model weights
my_resnet18.load_state_dict(best_model_wts)
train_model()
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
实现功能 读取torchvision中保存的resnet18网络模型,设置预训练pretrained=True,修改最后一个全连接层为my_resnet18.fc = nn.Linear(512, 2) 定义交叉熵损失函数criterion = nn.CrossEntropyLoss() 定义随机梯度下降优化器optimizer = optim.SGD(my_resnet18.parameters(), lr=0.001, momentum=0.9) 定义学习率每7步自动衰减exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) 进行25轮次训练,每一轮都在训练集上训练,在验证集测试 把25轮次中的最优模型参数保存下来best_model_wts = copy.deepcopy(my_resnet18.state_dict()) 最终模型读取最优参数my_resnet18.load_state_dict(best_model_wts) 注意 多卡运行可以平均使用每张卡,但是卡之
资源推荐
资源详情
资源评论
收起资源包目录
人工智能-项目实践-迁移学习-基于resnet18的迁移学习分类网络,用于给bee和ant二分类.zip (399个子文件)
imageNotFound.gif 5KB
formica.jpeg 8KB
ants-devouring-remains-of-large-dead-insect-on-red-tile-in-Stellenbosch-South-Africa-closeup-1-DHD.jpg 2MB
Nepenthes_rafflesiana_ant.jpg 694KB
477437164_bc3e6e594a.jpg 249KB
1181173278_23c36fac71.jpg 241KB
774440991_63a4aa0cbe.jpg 235KB
59798110_2b6a3c8031.jpg 220KB
339670531_94b75ae47a.jpg 217KB
215512424_687e1e0821.jpg 209KB
226951206_d6bf946504.jpg 208KB
1519368889_4270261ee3.jpg 208KB
2477349551_e75c97cf4d.jpg 206KB
2191997003_379df31291.jpg 205KB
2841437312_789699c740.jpg 204KB
562589509_7e55469b97.jpg 200KB
2683605182_9d2a0c66cf.jpg 199KB
2039585088_c6f47c592e.jpg 199KB
161076144_124db762d6.jpg 199KB
540543309_ddbb193ee5.jpg 196KB
488272201_c5aa281348.jpg 193KB
318052216_84dff3f98a.jpg 192KB
188552436_605cc9b36b.jpg 191KB
1119630822_cd325ea21a.jpg 190KB
196658222_3fffd79c67.jpg 187KB
649407494_9b6bc4949f.jpg 187KB
460874319_0a45ab4d05.jpg 186KB
207947948_3ab29d7207.jpg 186KB
198508668_97d818b6c4.jpg 185KB
3044402684_3853071a87.jpg 185KB
1807583459_4fe92b3133.jpg 185KB
374435068_7eee412ec4.jpg 183KB
129236073_0985e91c7d.jpg 182KB
2792000093_e8ae0718cf.jpg 181KB
2709775832_85b4b50a57.jpg 176KB
2457841282_7867f16639.jpg 174KB
751649788_78dd7d16ce.jpg 174KB
196757565_326437f5fe.jpg 174KB
212100470_b485e7b7b9.jpg 174KB
2509402554_31821cb0b6.jpg 173KB
522104315_5d3cb2758e.jpg 171KB
445356866_6cb3289067.jpg 171KB
238161922_55fa9a76ae.jpg 171KB
3006264892_30e9cced70.jpg 171KB
2722592222_258d473e17.jpg 171KB
1030023514_aad5c608f9.jpg 171KB
2493379287_4100e1dacc.jpg 170KB
484293231_e53cfc0c89.jpg 170KB
272986700_d4d4bf8c4b.jpg 168KB
365759866_b15700c59b.jpg 168KB
1097045929_1753d1c765.jpg 168KB
2638074627_6b3ae746a0.jpg 167KB
1355974687_1341c1face.jpg 167KB
969455125_58c797ef17.jpg 165KB
1440002809_b268d9a66a.jpg 165KB
150801003_3390b73135.jpg 165KB
208072188_f293096296.jpg 164KB
319494379_648fb5a1c6.jpg 164KB
1660097129_384bf54490.jpg 164KB
459694881_ac657d3187.jpg 163KB
167890289_dd5ba923f3.jpg 162KB
457457145_5f86eb7e9c.jpg 162KB
2962405283_22718d9617.jpg 159KB
509247772_2db2d01374.jpg 158KB
466430434_4000737de9.jpg 158KB
2741763055_9a7bb00802.jpg 157KB
485743562_d8cc6b8f73.jpg 156KB
350436573_41f4ecb6c8.jpg 155KB
150801171_cd86f17ed8.jpg 155KB
590318879_68cf112861.jpg 154KB
2668391343_45e272cd07.jpg 153KB
386190770_672743c9a7.jpg 152KB
45472593_bfd624f8dc.jpg 151KB
354167719_22dca13752.jpg 151KB
2707440199_cd170bd512.jpg 150KB
1473187633_63ccaacea6.jpg 149KB
144098310_a4176fd54d.jpg 149KB
450057712_771b3bfc91.jpg 149KB
153783656_85f9c3ac70.jpg 149KB
2486746709_c43cec0e42.jpg 148KB
44105569_16720a960c.jpg 147KB
2751836205_6f7b5eff30.jpg 146KB
586474709_ae436da045.jpg 146KB
359928878_b3b418c728.jpg 145KB
512863248_43c8ce579b.jpg 145KB
936182217_c4caa5222d.jpg 144KB
2652877533_a564830cbf.jpg 143KB
2364597044_3c3e3fc391.jpg 143KB
384191229_5779cf591b.jpg 142KB
203868383_0fcbb48278.jpg 142KB
2292213964_ca51ce4bef.jpg 142KB
72100438_73de9f17af.jpg 142KB
957233405_25c1d1187b.jpg 140KB
892108839_f1aad4ca46.jpg 140KB
116570827_e9c126745d.jpg 140KB
desert_ant.jpg 140KB
2501530886_e20952b97d.jpg 140KB
2345177635_caf07159b3.jpg 139KB
1360291657_dc248c5eea.jpg 139KB
476347960_52edd72b06.jpg 139KB
共 399 条
- 1
- 2
- 3
- 4
资源评论
博士僧小星
- 粉丝: 2381
- 资源: 5995
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- Delphi 12 控件之FlashAV FFMPEG VCL Player For Delphi v7.0 for D10-D11 Full Source.7z
- Delphi 12 控件之DevExpressVCLProducts-24.2.3.exe.zip
- Mysql配置文件优化内容 my.cnf
- 中国地级市CO2排放数据(2000-2023年).zip
- smart200光栅报警程序
- 企业信息部门2024年终工作总结与2025规划方案
- 串口AT命令发送工具,集成5G模组常用At命令
- 通过python实现归并排序示例代码.zip
- 复旦大学张奇:2023年大规模语言模型中的多语言对齐与知识分区研究
- 通过python实现一个堆排序示例代码.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功