#手写数字识别
#--------------------准备阶段----------------
#1.加载必要的库
#2.定义超参数
#3.构建pipeline,对图像做处理
#4.下载、加载数据
#5.构建网络模型
#6.定义优化器
#7.定义训练方法
#8.定义测试方法
#9.调用方法(7、8)
#1.加载必要的库
import numpy as np
import torch
import torch.nn as nn #加载网络模型
import torch.nn.functional as F
import torch.optim as optim #导入优化器
from torchvision import datasets,transforms
#参数:模型f(x,teta)中的teta称为模型的优化参数,可以通过优化算法进行学习
#超参数:用来定义模型结构或优化策略
#2.定义超参数(超参数是人为给定的)
BATCH_SIZE=128 #每批处理的数据
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu") #决定cpu训练还是gpu训练
EPOCHS=10 #训练数据集的轮次
#3.构建pipeline,对图像做处理 transforms主要对图像做变换
pipeline=transforms.Compose([
transforms.ToTensor(), #将图片转换成tensor
transforms.Normalize((0.1307,),(0.3081,)) #正则化 模型出现过拟合现象时,降低模型复杂度。
])
#4.下载、加载数据
from torch.utils.data import DataLoader
#下载数据集
train_set=datasets.MNIST("data",train=True,download=True,transform=pipeline)
test_set=datasets.MNIST("data",train=False,download=True,transform=pipeline)
#加载数据
train_loader=DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True) #shuffle:打乱图片
test_loader=DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)
#插入代码显示MINST中的图片
with open("./data/MNIST/raw/train-images-idx3-ubyte","rb")as f:
file=f.read()
image1=[int(str(item).encode('ascii'),10)for item in file[16:16+784]]
#print(image1)
import cv2
import numpy as np
image1_np=np.array(image1,dtype=np.uint8).reshape(28,28,1)
print(image1_np.shape)
cv2.imwrite("1.jpg",image1_np)
#5.构建网络模型
class Digit(nn.Module):
def __init__(self): #构造方法
super().__init__() #调用父类的构造方法
self.conv1=nn.Conv2d(1,10,5) # 二维卷积 1:灰度图片的通道 ,10:输出通道 ,5:卷积核
self.conv2=nn.Conv2d(10,20,3) #10:输入通道 ,20:输出通道 ,3:卷积核Kernel
self.fc1 =nn.Linear(20*10*10,500) #20*10*10:输入通道; 500:输出通道 全连接层
self.fc2=nn.Linear(500,10) #500:输入通道,10:输出通道
def forward(self,x):
input_size=x.size(0) #batch_size x 1 x 28 x 28
x=self.conv1(x) #输入:batch*1*28*28,输出:batch*10*24*24 24=28-5+1
x=F.relu(x) #激活函数保持shape不变 输出:batch*10*24*24
x=F.max_pool2d(x,2,2) #池化(对图片进行压缩的方法)降采样 输入:batch*10*24*24 输出:batch*10*12*12
x=self.conv2(x) #输入:batch*10*12*12 输出:batch*20*10*10 10=12-3+1
x=F.relu(x)
x=x.view(input_size,-1) #拉平,-1自动计算维度,20*10*10=2000
x=self.fc1(x) #输入:batch*2000 输出:batch*500
x=F.relu(x) #激活保持shape不变
x=self.fc2(x) #输入:batch*500 输出:batch*10
output=F.log_softmax(x,dim=1) #计算分类后,每个数字的概率值
return output
#6.定义优化器
model=Digit().to(DEVICE)
optimizer=optim.Adam(model.parameters()) #优化器的作用是用来更新模型的参数
#7.定义训练方法
def train_model(model,device,train_loader,optimizer,epoch):
#模型训练
model.train()
for batch_index,(data,target) in enumerate(train_loader):
#部署到DEVICE上去
data,target=data.to(device),target.to(device)
#梯度初始化为0
optimizer.zero_grad()
#训练后的结果
output=model(data)
#计算损失
loss=F.cross_entropy(output,target)
#找到概率值最大的下标
pred=output.max(1,keepdim=True) #pred=output.argmax(dim=1)
#反向传播
loss.backward()
#参数优化
optimizer.step()
if batch_index %3000 ==0:
print("Train Epoch : {}\t loss : {:.6f}".format(epoch,loss.item()))
#8.定义测试方法
def test_model(model,device,test_loader):
#模型验证
model.eval()
#统计正确率
correct=0.0
#测试损失
test_loss =0.0
with torch.no_grad(): #不会计算梯度,也不会进行反向传播
for data,target in test_loader:
#部署到device上
data,target=data.to(device),target.to(device)
#测试数据
output=model(data)
#计算测试损失
test_loss+=F.cross_entropy(output,target).item()
#找到概率值最大的下标
pred=output.max(1,keepdim=True)[1] #值,索引
#累计正确率
correct+=pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print("Test——Average loss:{:.4f},Accuracy: {:.3f}\n".format(test_loss,100.0*correct/len(test_loader.dataset)))
#9.调用方法(7、8)
for epoch in range(1,EPOCHS+1):
train_model(model,DEVICE,train_loader,optimizer,epoch)
test_model(model,DEVICE,test_loader)
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
HandWriteReg.zip (18个子文件)
HandWriteReg
.idea
workspace.xml 6KB
misc.xml 188B
inspectionProfiles
profiles_settings.xml 174B
modules.xml 287B
deployment.xml 871B
pythonProject2.iml 291B
.gitignore 188B
Exercises
data
MNIST
raw
t10k-images-idx3-ubyte.gz 1.57MB
train-images-idx3-ubyte 44.86MB
t10k-images-idx3-ubyte 7.48MB
train-labels-idx1-ubyte.gz 28KB
t10k-labels-idx1-ubyte 10KB
train-images-idx3-ubyte.gz 9.45MB
t10k-labels-idx1-ubyte.gz 4KB
train-labels-idx1-ubyte 59KB
1.jpg 884B
NumbRw.py 5KB
digit.jpg 774B
共 18 条
- 1
资源评论
XINYUW
- 粉丝: 121
- 资源: 9
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功