import pickle
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torchvision import transforms
from torchsummary import summary
from hwdb import HWDB
from model import ConvNet
def valid(epoch, net, test_loarder, writer):
print("epoch %d 开始验证..." % epoch)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loarder:
images, labels = images.cuda(), labels.cuda()
outputs = net(images)
# 取得分最高的那个类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('correct number: ', correct)
print('totol number:', total)
acc = 100 * correct / total
print('第%d个epoch的识别准确率为:%d%%' % (epoch, acc))
writer.add_scalar('valid_acc', acc, global_step=epoch)
def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100):
print("epoch %d 开始训练..." % epoch)
net.train()
sum_loss = 0.0
total = 0
correct = 0
# 数据读取
for i, (inputs, labels) in enumerate(train_loader):
# 梯度清零
optimizer.zero_grad()
if torch.cuda.is_available():
inputs = inputs.cuda()
labels = labels.cuda()
outputs = net(inputs)
loss = criterion(outputs, labels)
# 取得分最高的那个类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
loss.backward()
optimizer.step()
# 每训练100个batch打印一次平均loss与acc
sum_loss += loss.item()
if (i + 1) % save_iter == 0:
batch_loss = sum_loss / save_iter
# 每跑完一次epoch测试一下准确率
acc = 100 * correct / total
print('epoch: %d, batch: %d loss: %.03f, acc: %.04f'
% (epoch, i + 1, batch_loss, acc))
writer.add_scalar('train_loss', batch_loss, global_step=i + len(train_loader) * epoch)
writer.add_scalar('train_acc', acc, global_step=i + len(train_loader) * epoch)
for name, layer in net.named_parameters():
writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(),
global_step=i + len(train_loader) * epoch)
writer.add_histogram(name + '_data', layer.cpu().data.numpy(),
global_step=i + len(train_loader) * epoch)
total = 0
correct = 0
sum_loss = 0.0
if __name__ == "__main__":
# 超参数
epochs = 20
batch_size = 100
lr = 0.01
data_path = r'data'
log_path = r'logs/batch_{}_lr_{}'.format(batch_size, lr)
save_path = r'checkpoints/'
if not os.path.exists(save_path):
os.mkdir(save_path)
# 读取分类类别
with open('char_dict', 'rb') as f:
class_dict = pickle.load(f)
num_classes = len(class_dict)
# 读取数据
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
])
dataset = HWDB(path=data_path, transform=transform)
print("训练集数据:", dataset.train_size)
print("测试集数据:", dataset.test_size)
trainloader, testloader = dataset.get_loader(batch_size)
net = ConvNet(num_classes)
if torch.cuda.is_available():
net = net.cuda()
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
print('网络结构:\n')
summary(net, input_size=(3, 64, 64), device='cuda')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr)
writer = SummaryWriter(log_path)
for epoch in range(epochs):
train(epoch, net, criterion, optimizer, trainloader, writer=writer)
valid(epoch, net, testloader, writer=writer)
print("epoch%d 结束, 正在保存模型..." % epoch)
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % epoch)
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于pytorch卷积神经网络的中文手写汉字识别项目源码(高分大作业).zip 已获导师指导的高分设计项目,代码完整下载可用,纯手打高分设计,可作为期末大作业和课程设计,小白也可实战。 基于pytorch卷积神经网络的中文手写汉字识别项目源码(高分大作业).zip 已获导师指导的高分设计项目,代码完整下载可用,纯手打高分设计,可作为期末大作业和课程设计,小白也可实战。基于pytorch卷积神经网络的中文手写汉字识别项目源码(高分大作业).zip 已获导师指导的高分设计项目,代码完整下载可用,纯手打高分设计,可作为期末大作业和课程设计,小白也可实战。基于pytorch卷积神经网络的中文手写汉字识别项目源码(高分大作业).zip 已获导师指导的高分设计项目,代码完整下载可用,纯手打高分设计,可作为期末大作业和课程设计,小白也可实战。基于pytorch卷积神经网络的中文手写汉字识别项目源码(高分大作业).zip 已获导师指导的高分设计项目,代码完整下载可用,纯手打高分设计,可作为期末大作业和课程设计,小白也可实战。基于pytorch卷积神经网络的中文手写汉字识别项目源码(高分大作业).zi
资源推荐
资源详情
资源评论
收起资源包目录
手写汉字识别.zip (5个子文件)
hand-writing-recognition-master
process_gnt.py 3KB
hwdb.py 2KB
model.py 3KB
hwdb.jpg 179KB
train.py 4KB
手写
共 5 条
- 1
资源评论
- m0_655814972023-07-08内容与描述一致,超赞的资源,值得借鉴的内容很多,支持!
- m0_503499162023-11-12这个资源内容超赞,对我来说很有价值,很实用,感谢大佬分享~
- 2301_763135182024-04-25资源内容详细,总结地很全面,与描述的内容一致,对我启发很大,学习了。
盈梓的博客
- 粉丝: 6864
- 资源: 1248
下载权益
C知道特权
VIP文章
课程特权
开通VIP
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功