import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
class Model:
def __init__(self, net, cost, optimist):
self.net = net
self.cost = self.create_cost(cost)
self.optimizer = self.create_optimizer(optimist)
pass
def create_cost(self, cost):
support_cost = {
'CROSS_ENTROPY': nn.CrossEntropyLoss(),
'MSE': nn.MSELoss()
}
return support_cost[cost]
def create_optimizer(self, optimist, **rests):
support_optim = {
'SGD': optim.SGD(self.net.parameters(), lr=0.1, **rests),
'ADAM': optim.Adam(self.net.parameters(), lr=0.01, **rests),
'RMSP':optim.RMSprop(self.net.parameters(), lr=0.001, **rests)
}
return support_optim[optimist]
def train(self, train_loader, epoches=3):
for epoch in range(epoches):
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data
self.optimizer.zero_grad()
# forward + backward + optimize
outputs = self.net(inputs)
loss = self.cost(outputs, labels)
loss.backward()
self.optimizer.step()
running_loss += loss.item()
if i % 100 == 0:
print('[epoch %d, %.2f%%] loss: %.3f' %
(epoch + 1, (i + 1)*1./len(train_loader), running_loss / 100))
running_loss = 0.0
print('Finished Training')
def evaluate(self, test_loader):
print('Evaluating ...')
correct = 0
total = 0
with torch.no_grad(): # no grad when test and predict
for data in test_loader:
images, labels = data
outputs = self.net(images)
predicted = torch.argmax(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def mnist_load_data():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0,], [1,])])
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=True, num_workers=2)
return trainloader, testloader
class MnistNet(torch.nn.Module):
def __init__(self):
super(MnistNet, self).__init__()
self.fc1 = torch.nn.Linear(28*28, 512)
self.fc2 = torch.nn.Linear(512, 512)
self.fc3 = torch.nn.Linear(512, 10)
def forward(self, x):
x = x.view(-1, 28*28)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.softmax(self.fc3(x), dim=1)
return x
if __name__ == '__main__':
# train for mnist
net = MnistNet()
model = Model(net, 'CROSS_ENTROPY', 'RMSP')
train_loader, test_loader = mnist_load_data()
model.train(train_loader)
model.evaluate(test_loader)
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
pytorch.rar (10个子文件)
pytorch
linear.py 466B
容器.py 403B
data
MNIST
raw
t10k-labels-idx1-ubyte 10KB
train-images-idx3-ubyte 44.86MB
t10k-images-idx3-ubyte 7.48MB
train-labels-idx1-ubyte 59KB
processed
test.pt 7.55MB
training.pt 45.32MB
自动求导.py 286B
mnistnet.py 4KB
共 10 条
- 1
资源评论
- GithL2023-03-07资源不错,对我启发很大,获得了新的灵感,受益匪浅。
- weixin_579173822022-10-12资源不错,很实用,内容全面,介绍详细,很好用,谢谢分享。
- pepper_salt2022-02-14用户下载后在一定时间内未进行评价,系统默认好评。
- 瑾色安年9852023-12-04资源不错,很实用,内容全面,介绍详细,很好用,谢谢分享。
摇滚死兔子
- 粉丝: 54
- 资源: 4227
下载权益
C知道特权
VIP文章
课程特权
开通VIP
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功