没有合适的资源?快使用搜索试试~ 我知道了~
利用pytorch搭建lenet网络,并利用mnist数据集进行训练测试
19 下载量 175 浏览量
2021-01-06
16:11:10
上传
评论 1
收藏 31KB PDF 举报
温馨提示
最近在学习pytorch,手工复现了LeNet网络,并附源码如下,欢迎大家留言交流 import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms class LeNet(nn.Module): def __init__(self): super(LeNet,self).__init__() self.conv1 = nn.Conv2d(1
资源推荐
资源详情
资源评论
利用利用pytorch搭建搭建lenet网络,并利用网络,并利用mnist数据集进行训练测数据集进行训练测
试试
最近在学习pytorch,手工复现了LeNet网络,并附源码如下,欢迎大家留言交流
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, 1,)
self.conv2 = nn.Conv2d(6, 16, 5, 1)
self.fc1 = nn.Linear(16*4*4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
#x:1*28*28
x = F.max_pool2d(self.conv1(x), 2, 2)
x = F.max_pool2d(self.conv2(x), 2, 2)
x = x.view(-1, 16*4*4)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return F.log_softmax(x,dim=1)
def train(model, device, train_loader, optimizer, epoch):
model.train()
for idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
pred = model(data)
loss = F.nll_loss(pred,target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if idx % 100 == 0:
print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))
def test(model, device, test_loader):
model.eval()
total_loss = 0.
correct = 0.
with torch.no_grad():
for idx, (data, target) in enumerate(test_loader):
data, target = data.to(device), target.to(device)
output = model(data)
total_loss += F.nll_loss(output, target, reduction="sum").item()
pred = output.argmax(dim=1)
correct +=pred.eq(target.view_as(pred)).sum().item()
total_loss /= len(test_loader.dataset)
acc = correct/len(test_loader.dataset)*100
print("Test loss: {}, Accuracy: {}".format(total_loss, acc))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(
datasets.MNIST("./MNIST_data", train=True, download=False,
transform = transforms.Compose([
transforms.ToTensor(),
资源评论
weixin_38617001
- 粉丝: 5
- 资源: 902
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- (源码)基于 JavaFX 和 MySQL 的影院管理系统.zip
- (源码)基于EAV模型的动态广告位系统.zip
- (源码)基于Qt的长沙地铁换乘系统.zip
- (源码)基于ESP32和DM02A模块的智能照明系统.zip
- (源码)基于.NET Core和Entity Framework Core的学校管理系统.zip
- (源码)基于C#的WiFi签到管理系统.zip
- (源码)基于WPF和MVVM框架的LikeYou.WAWA管理系统.zip
- (源码)基于C#的邮件管理系统.zip
- 【yan照门】chen冠希(1323张) [2月25日凌晨新增容祖儿全94张].rar.torrent
- (源码)基于C++的员工管理系统.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功