from datetime import datetime
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import pylab
from matplotlib import pyplot as plt #plt.imshow需要
import numpy
from matplotlib.pyplot import MultipleLocator
def get_acc(output, label):
    """Return the fraction of samples whose predicted class matches `label`.

    `output` is a (batch, num_classes) score tensor; `label` is a (batch,)
    tensor of class indices.
    """
    predictions = output.argmax(dim=1)          # class with the highest score per row
    correct = (predictions == label).sum().item()
    return correct / output.shape[0]
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Train `net` for `num_epochs` epochs and plot per-epoch accuracy curves.

    Args:
        net: the model to train (moved to GPU if available).
        train_data: iterable of (image, label) training batches.
        valid_data: iterable of (image, label) validation batches, or None
            to skip validation.
        num_epochs: number of full passes over `train_data`.
        optimizer: optimizer updating `net`'s parameters.
        criterion: loss function taking (output, label).

    Side effects: prints one status line per epoch and shows a matplotlib
    figure of train (and, if available, validation) accuracy at the end.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()
    prev_time = datetime.now()
    # Per-epoch accuracy history for plotting.
    train_acc_plot = numpy.zeros(num_epochs)
    valid_acc_plot = numpy.zeros(num_epochs)
    for epoch in range(num_epochs):
        train_loss = 0.0
        train_acc = 0.0
        net = net.train()
        for im, label in train_data:
            # Variable is deprecated; tensors carry autograd state directly.
            if use_cuda:
                im = im.cuda()
                label = label.cuda()
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += get_acc(output, label)
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0.0
            valid_acc = 0.0
            net = net.eval()
            # no_grad wraps the whole validation pass so no autograd graph
            # is built (the original only wrapped the input copies).
            with torch.no_grad():
                for im, label in valid_data:
                    if use_cuda:
                        im = im.cuda()
                        label = label.cuda()
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
            # Only record validation accuracy when validation actually ran
            # (the original read valid_acc unconditionally and crashed when
            # valid_data was None).
            valid_acc_plot[epoch] = valid_acc / len(valid_data)
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        train_acc_plot[epoch] = train_acc / len(train_data)
        prev_time = cur_time
        print(epoch_str + time_str)
    epoch_list = numpy.arange(num_epochs)
    plt.plot(epoch_list, train_acc_plot, label="train_acc")
    if valid_data is not None:
        plt.plot(epoch_list, valid_acc_plot, label="valid_acc")
    plt.xlabel('Number of EPOCH')
    plt.ylabel('train_acc&valid_acc')
    ax = plt.gca()
    ax.xaxis.set_major_locator(MultipleLocator(1))  # one tick per epoch
    # Derive the x range from num_epochs instead of the hard-coded 19.
    plt.xlim(0, max(num_epochs - 1, 1))
    plt.ylim(0.8, 1)
    plt.legend()
    plt.show()
def conv3x3(in_channel, out_channel, stride=1):
    """Build a 3x3 convolution with padding 1 and no bias (a BatchNorm
    layer is expected to follow, making the bias redundant)."""
    return nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class residual_block(nn.Module):
    """Basic two-convolution residual unit.

    When `same_shape` is False the first conv downsamples with stride 2 and
    a 1x1 conv projects the identity branch so the two paths can be added.
    """

    def __init__(self, in_channel, out_channel, same_shape=True):
        super(residual_block, self).__init__()
        self.same_shape = same_shape
        first_stride = 1 if same_shape else 2
        self.conv1 = conv3x3(in_channel, out_channel, stride=first_stride)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = conv3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)
        if not same_shape:
            # 1x1 projection so the shortcut matches the main path's shape.
            self.conv3 = nn.Conv2d(in_channel, out_channel, 1, stride=first_stride)

    def forward(self, x):
        shortcut = x
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        if not self.same_shape:
            shortcut = self.conv3(shortcut)
        return F.relu(shortcut + out, inplace=True)
class resnet(nn.Module):
    """Small ResNet-style image classifier built from `residual_block`s.

    Five stages (stem conv, then four residual stages ending in a 3x3
    average pool) followed by a linear classifier over 512 features.
    With `verbose=True` the output shape of each stage is printed.
    """

    def __init__(self, in_channel, num_classes, verbose=False):
        super(resnet, self).__init__()
        self.verbose = verbose
        # Stem: 7x7 conv, stride 2 (no padding here).
        self.block1 = nn.Conv2d(in_channel, 64, 7, 2)
        self.block2 = nn.Sequential(
            nn.MaxPool2d(3, 2),
            residual_block(64, 64),
            residual_block(64, 64),
        )
        self.block3 = nn.Sequential(
            residual_block(64, 128, False),
            residual_block(128, 128),
        )
        self.block4 = nn.Sequential(
            residual_block(128, 256, False),
            residual_block(256, 256),
        )
        self.block5 = nn.Sequential(
            residual_block(256, 512, False),
            residual_block(512, 512),
            nn.AvgPool2d(3),
        )
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        for idx, stage in enumerate(stages, start=1):
            x = stage(x)
            if self.verbose:
                print('block {} output: {}'.format(idx, x.shape))
        x = x.view(x.shape[0], -1)  # flatten (batch, 512, 1, 1) -> (batch, 512)
        return self.classifier(x)
没有合适的资源?快使用搜索试试~ 我知道了~
lstm图像分类_lstm图像处理_
共9个文件
xml:4个
py:2个
pyc:1个
1.该资源内容由用户上传,如若侵权请联系客服进行举报
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
版权申诉
5星 · 超过95%的资源 2 下载量 108 浏览量
2021-09-30
02:01:17
上传
评论 2
收藏 9KB ZIP 举报
温馨提示
pytorch库进行lstm图像处理,数据集为mnisit手写字数据。
资源推荐
资源详情
资源评论
收起资源包目录
lstm图像分类.zip (9个子文件)
lstm图像分类
utils.py 6KB
main.py 2KB
.idea
misc.xml 188B
workspace.xml 4KB
rnn图像分类.iml 291B
inspectionProfiles
profiles_settings.xml 174B
modules.xml 289B
.gitignore 50B
__pycache__
utils.cpython-38.pyc 5KB
共 9 条
- 1
Dyingalive
- 粉丝: 88
- 资源: 4808
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
- 3
前往页