from MINISTModel import MINISTModel
import torch
from torch import nn
# 定义超参数
learning_rate = 1e-3
batch_size = 64
epochs = 200
flag = "A" # A means origial MNIST;B for mine
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 选择模型,loss,optimizer
model = MINISTModel().to(device) # BP神经网络
loss_fn = nn.CrossEntropyLoss()
# 优化器,优化算法我们采用SGD随机梯度下降,模型内部的参数(w,b)已经被初始化好了
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
train_loss_list = []
# 训练train_dataset
def train_loop(dataloader, model, loss_fn, optimizer):
# training_data是MNIST对象,train_dataloader.dataset从train_dataloader取出该对象,len方法返回数据集的大小
size = len(dataloader.dataset)
# batch 代表从dataloader中抽取出的第几个batch_size,是通过枚举enumerate得到的序号。X是64个image的Tensor,y是对应的标签
for batch, (X, y) in enumerate(dataloader): # enmumerate:枚举,元素一个个列举出来
# Compute prediction and loss
pred = model(X.to(device)) # pred包含了64个样本的输出,是一个64*10的Tensor
loss = loss_fn(pred, y.to(device))
# Back propagation
optimizer.zero_grad() # 重置模型参数的梯度,默认情况下梯度会迭代相加
loss.backward() # 反向传播预测损失,计算梯度
optimizer.step() # 梯度下降,w = w - lr * 梯度。随机梯度下降是迭代的,通过随机噪声能避免鞍点的出现
loss = loss.item()
if batch % 100 == 0: # 取余的数值可以自己设置
loss, current = loss, batch * batch_size # 我将len(X)替换为了batch_size
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
train_loss_list.append(round(loss, 7)) # 损失加入到列表中
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
MINIST CNN.rar (27个子文件)
MINIST CNN
main.py 70B
data
dataLoader.py 1KB
MNIST_data
t10k-images-idx3-ubyte.gz 1.57MB
train-images.idx3-ubyte 44.86MB
train-labels-idx1-ubyte.gz 28KB
train-labels.idx1-ubyte 59KB
train-images-idx3-ubyte.gz 9.45MB
t10k-labels-idx1-ubyte.gz 4KB
__pycache__
dataLoader.cpython-310.pyc 871B
mnist
MNIST
raw
t10k-images-idx3-ubyte.gz 1.57MB
train-images-idx3-ubyte 44.86MB
t10k-images-idx3-ubyte 7.48MB
train-labels-idx1-ubyte.gz 28KB
t10k-labels-idx1-ubyte 10KB
train-images-idx3-ubyte.gz 9.45MB
t10k-labels-idx1-ubyte.gz 4KB
train-labels-idx1-ubyte 59KB
MNIST_data.rar 11.06MB
.idea
MINIST CNN.iml 335B
workspace.xml 3KB
misc.xml 199B
inspectionProfiles
Project_Default.xml 410B
profiles_settings.xml 174B
modules.xml 279B
.gitignore 50B
MINISTModel
MINISTModel.py 587B
train.py 2KB
共 27 条
- 1
资源评论
新兴AI民工
- 粉丝: 6067
- 资源: 11
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功