from MINISTModel import MINISTModel
import torch
from torch import nn
# 定义超参数
learning_rate = 1e-3
batch_size = 64
epochs = 200
flag = "A" # A means origial MNIST;B for mine
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 选择模型,loss,optimizer
model = MINISTModel().to(device) # BP神经网络
loss_fn = nn.CrossEntropyLoss()
# 优化器,优化算法我们采用SGD随机梯度下降,模型内部的参数(w,b)已经被初始化好了
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
train_loss_list = []
# 训练train_dataset
def train_loop(dataloader, model, loss_fn, optimizer):
# training_data是MNIST对象,train_dataloader.dataset从train_dataloader取出该对象,len方法返回数据集的大小
size = len(dataloader.dataset)
# batch 代表从dataloader中抽取出的第几个batch_size,是通过枚举enumerate得到的序号。X是64个image的Tensor,y是对应的标签
for batch, (X, y) in enumerate(dataloader): # enmumerate:枚举,元素一个个列举出来
# Compute prediction and loss
pred = model(X.to(device)) # pred包含了64个样本的输出,是一个64*10的Tensor
loss = loss_fn(pred, y.to(device))
# Back propagation
optimizer.zero_grad() # 重置模型参数的梯度,默认情况下梯度会迭代相加
loss.backward() # 反向传播预测损失,计算梯度
optimizer.step() # 梯度下降,w = w - lr * 梯度。随机梯度下降是迭代的,通过随机噪声能避免鞍点的出现
loss = loss.item()
if batch % 100 == 0: # 取余的数值可以自己设置
loss, current = loss, batch * batch_size # 我将len(X)替换为了batch_size
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
train_loss_list.append(round(loss, 7)) # 损失加入到列表中
没有合适的资源?快使用搜索试试~ 我知道了~
收起资源包目录
MINIST CNN.rar (27个子文件)
MINIST CNN
main.py 70B
data
dataLoader.py 1KB
MNIST_data
t10k-images-idx3-ubyte.gz 1.57MB
train-images.idx3-ubyte 44.86MB
train-labels-idx1-ubyte.gz 28KB
train-labels.idx1-ubyte 59KB
train-images-idx3-ubyte.gz 9.45MB
t10k-labels-idx1-ubyte.gz 4KB
__pycache__
dataLoader.cpython-310.pyc 871B
mnist
MNIST
raw
t10k-images-idx3-ubyte.gz 1.57MB
train-images-idx3-ubyte 44.86MB
t10k-images-idx3-ubyte 7.48MB
train-labels-idx1-ubyte.gz 28KB
t10k-labels-idx1-ubyte 10KB
train-images-idx3-ubyte.gz 9.45MB
t10k-labels-idx1-ubyte.gz 4KB
train-labels-idx1-ubyte 59KB
MNIST_data.rar 11.06MB
.idea
MINIST CNN.iml 335B
workspace.xml 3KB
misc.xml 199B
inspectionProfiles
Project_Default.xml 410B
profiles_settings.xml 174B
modules.xml 279B
.gitignore 50B
MINISTModel
MINISTModel.py 587B
train.py 2KB
共 27 条
- 1
资源推荐
资源预览
资源评论
5星 · 资源好评率100%
176 浏览量
5星 · 资源好评率100%
119 浏览量
2025-01-06 上传
2025-01-02 上传
188 浏览量
5星 · 资源好评率100%
2022-05-27 上传
2022-02-10 上传
186 浏览量
5星 · 资源好评率100%
5星 · 资源好评率100%
5星 · 资源好评率100%
114 浏览量
128 浏览量
5星 · 资源好评率100%
166 浏览量
2021-03-21 上传
155 浏览量
2022-06-21 上传
2018-11-28 上传
167 浏览量
5星 · 资源好评率100%
158 浏览量
2017-08-18 上传
2018-01-13 上传
2020-12-13 上传
2023-10-06 上传
5星 · 资源好评率100%
资源评论
新兴AI民工
- 粉丝: 6202
- 资源: 11
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 漂亮动态效果PPT柱形图-3.pptx
- 山形柱状图数据分析PPT模板-1.pptx
- 长阴影扁平化PPT柱形图模板-1.pptx
- 山形锥形柱状图PPT模板素材-1.pptx
- 条形图-数据图表-简约扁平-3.pptx
- 条形图-数据图表-时尚红蓝-PPT模板-3.pptx
- 小人人数比例分析说明PPT模板-1.pptx
- 柱状图-数据图表-高端商务-3.pptx
- 柱状图-数据图表-扁平简洁-3.pptx
- 柱状图-数据图表-简约扁平 -3.pptx
- 柱状图-数据图表-清新活泼-3.pptx
- 柱状图-数据图表-折纸简洁-3.pptx
- 柱状图-数据图表-简约扁平--1.pptx
- windows tcp连通性测试工具tcping64
- CDN(内容分发网络)核心技术解析及其在网络优化中的应用
- 饼图-数据图表-简约清新 -3.pptx
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功