batch_size = 16 #设置批次大小 根据你电脑的显卡、显存 4/8 8/16/32/64
learning_rate = 1e-4 #设置学习率
epoches = 30 #设置训练的次数 //如果训练结束,精度不高 多训练几次 可以设置成8/9/10
num_of_classes=7 #要分的类别个数
from tqdm import tqdm
import torch
import os
from torch.utils import data
import torchvision.datasets as dsets
import torchvision.transforms as transforms
trainpath = './dataset/train/'
valpath = './dataset/train/'
#数据增强的方式
traintransform = transforms .Compose([
transforms .RandomRotation (20), #随机旋转角度
transforms .ColorJitter(brightness=0.1), #颜色亮度
transforms .Resize([224, 224]), #设置成224×224大小的张量
transforms .ToTensor(), # 将图⽚数据变为tensor格式
# transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225]),
])
valtransform = transforms .Compose([
transforms .Resize([224, 224]),
transforms .ToTensor(), # 将图⽚数据变为tensor格式
])
trainData = dsets.ImageFolder (trainpath, transform =traintransform ) # 读取训练集,标签就是train⽬录下的⽂件夹的名字,图像保存在格⼦标签下的⽂件夹⾥
valData = dsets.ImageFolder (valpath, transform =valtransform ) #读取演正剧
trainLoader = torch.utils.data.DataLoader(dataset=trainData, batch_size=batch_size, shuffle=True) #将数据集分批次 并打乱顺序
valLoader = torch.utils.data.DataLoader(dataset=valData, batch_size=batch_size, shuffle=False) #将测试集分批次并打乱顺序
test_sum = sum([len(x) for _, _, x in os.walk(os.path.dirname(trainpath))]) #计算 训练集和测试集的图片总数
train_sum = sum([len(x) for _, _, x in os.walk(os.path.dirname(valpath))])
import numpy as np
import torchvision.models as models
model = models.resnet34(pretrained=True) #pretrained表⽰是否加载已经与训练好的参数
model.fc = torch.nn.Linear(512, num_of_classes) #将最后的fc层的输出改为标签数量(如3),512取决于原始⽹络fc层的输⼊通道
#model = model.cuda() # 如果有GPU,⽽且确认使⽤则保留;如果没有GPU,请删除
criterion = torch.nn.CrossEntropyLoss() # 定义损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 定义优化器
from torch.autograd import Variable
#定义训练的函数
def train(model, optimizer, criterion):
model.train()
total_loss = 0
train_corrects = 0
for i, (image, label) in enumerate (tqdm(trainLoader)):
#image = Variable(image.cuda()) # 同理
#label = Variable(label.cuda()) # 同理
#print(i,image,label)
optimizer.zero_grad ()
target = model(image)
loss = criterion(target, label)
loss.backward()
optimizer.step()
total_loss += loss.item()
max_value , max_index = torch.max(target, 1)
pred_label = max_index.cpu().numpy()
true_label = label.cpu().numpy()
train_corrects += np.sum(pred_label == true_label)
return total_loss / float(len(trainLoader)), train_corrects / train_sum
testLoader=valLoader
#定义测试的函数
def evaluate(model, criterion):
model.eval()
corrects = eval_loss = 0
with torch.no_grad():
for image, label in tqdm(testLoader):
#image = Variable(image.cuda()) # 如果不使⽤GPU,删除.cuda()
#label = Variable(label.cuda()) # 同理
pred = model(image)
loss = criterion(pred, label)
eval_loss += loss.item()
max_value, max_index = torch.max(pred, 1)
pred_label = max_index.cpu().numpy()
true_label = label.cpu().numpy()
corrects += np.sum(pred_label == true_label)
return eval_loss / float(len(testLoader)), corrects, corrects / test_sum
#torch.save(model,"./resnet1.pt")
for i in range(epoches):
print("第{}个epoch".format(i+1))
train_loss,train_acc=train(model,optimizer,criterion)
print("train_loss: {} train_acc: {}\n".format(train_loss,train_acc))
test_loss,test_correct,test_acc=evaluate(model,criterion)
print("test_loss: {} test_correct:{} test_acc:{}".format(test_loss,test_correct,test_acc))
torch.save(model,"./resnet34_final.pt")#保存模型,第二个参数是保存的路径
TechMasterPlus
- 粉丝: 3211
- 资源: 24
最新资源
- ECharts象形柱图-象形柱图变形为柱状图-1.zip
- ECharts象形柱图-虚线柱状图效果-3.zip
- ECharts象形柱图-精灵-5.zip
- java jdk8 windows macos linux
- 协作臂控制软件包C++
- ImageMagick-7.1.0-57-Q16-HDRI-x64
- 三极管全自动套管装配机工程图机械结构设计图纸和其它技术资料和技术方案非常好100%好用.zip
- 基于java+springboot+mysql+微信小程序的超市售货管理平台小程序 源码+数据库+论文(高分毕业设计).zip
- 基于java+springboot+mysql+微信小程序的仓储管理系统 源码+数据库+论文(高分毕业设计).zip
- macos java jdk17
- 对接顺丰开放平台获取顺丰速运快递路由信息的PHP程序
- 基于java+springboot+mysql+微信小程序的大学生校园兼职小程序 源码+数据库+论文(高分毕业设计).zip
- 基于java+springboot+mysql+微信小程序的大学生心理健康测评管理系统 源码+数据库+论文(高分毕业设计).zip
- 基于java+springboot+mysql+微信小程序的大学生党务学习平台小程序 源码+数据库+论文(高分毕业设计).zip
- 基于java+springboot+mysql+微信小程序的电影交流平台小程序 源码+数据库+论文(高分毕业设计).zip
- 基于java+springboot+mysql+微信小程序的电影院票务系统 源码+数据库+论文(高分毕业设计).zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈