# -*- coding: utf-8 -*-
#!/usr/bin/env python3
'''
Pytorch implementation for LPRNet.
Author: aiboy.wei@outlook.com .
'''
from load_data import CHARS, LPRDataLoader
from model.LPRNet import build_lprnet
from torch.autograd import Variable
from torch.utils.data import *
from torch import optim
import torch.nn as nn
import numpy as np
import argparse
import torch
import time
import os
import matplotlib.pyplot as plt
def sparse_tuple_for_ctc(T_length, lengths):
    """Build the (input_lengths, target_lengths) tuples for nn.CTCLoss.

    Every sample in the batch shares the same network output (time) length
    ``T_length``, while each target keeps its own label length.

    Args:
        T_length: time dimension of the network output, identical per sample.
        lengths: iterable of per-sample target (label) lengths.

    Returns:
        (input_lengths, target_lengths) as two tuples of equal size.
    """
    target_lengths = tuple(lengths)
    input_lengths = tuple(T_length for _ in target_lengths)
    return input_lengths, target_lengths
def adjust_learning_rate(optimizer, cur_epoch, base_lr, lr_schedule):
    """Apply step-decayed learning rate to all of the optimizer's param groups.

    The learning rate is ``base_lr * 0.1 ** k`` where ``k`` is the number of
    milestones in ``lr_schedule`` that ``cur_epoch`` has already reached.

    Args:
        optimizer: optimizer whose ``param_groups`` receive the new lr.
        cur_epoch: current epoch number.
        base_lr: learning rate before any decay.
        lr_schedule: increasing sequence of epoch milestones.

    Returns:
        The learning rate that was applied.
    """
    # Count milestones already passed. The previous implementation used an
    # `lr == 0` sentinel that reset the lr back to base_lr once cur_epoch ran
    # past the final milestone, undoing all decay; counting avoids that bug.
    decay_steps = sum(1 for milestone in lr_schedule if cur_epoch >= milestone)
    lr = base_lr * (0.1 ** decay_steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
def _str2bool(value):
    """Parse a command-line boolean; argparse's type=bool is broken (bool('False') is True)."""
    if isinstance(value, bool):
        return value
    return value.lower() in ('yes', 'true', 't', '1')

def get_parser(argv=None):
    """Build and parse command-line options for training.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``
            (passing a list makes the function testable without touching
            process state).

    Returns:
        argparse.Namespace with all training hyper-parameters.
    """
    parser = argparse.ArgumentParser(description='parameters to train net')
    # Numeric options carry explicit type= so CLI overrides are not left as
    # strings (the previous version only typed a few of them).
    parser.add_argument('--max_epoch', default=15, type=int, help='epoch to train the network')
    parser.add_argument('--img_size', default=(94, 24), help='the image size')
    parser.add_argument('--train_img_dirs', default="train", help='the train images path')
    parser.add_argument('--test_img_dirs', default="test", help='the test images path')
    parser.add_argument('--dropout_rate', default=0.5, type=float, help='dropout rate.')
    parser.add_argument('--learning_rate', default=0.1, type=float, help='base value of learning rate.')
    parser.add_argument('--lpr_max_len', default=8, type=int, help='license plate number max length.')
    parser.add_argument('--train_batch_size', default=128, type=int, help='training batch size.')
    parser.add_argument('--test_batch_size', default=120, type=int, help='testing batch size.')
    parser.add_argument('--phase_train', default=True, type=_str2bool, help='train or test phase flag.')
    parser.add_argument('--num_workers', default=8, type=int, help='Number of workers used in dataloading')
    parser.add_argument('--cuda', default=False, type=_str2bool, help='Use cuda to train model')
    parser.add_argument('--resume_epoch', default=0, type=int, help='resume iter for retraining')
    parser.add_argument('--interval', default=2000, type=int, help='interval for evaluate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=2e-5, type=float, help='Weight decay for SGD')
    parser.add_argument('--lr_schedule', default=[4, 8, 12, 14, 16], help='schedule for learning rate.')
    parser.add_argument('--save_folder', default='./weights/', help='Location to save checkpoint models')
    parser.add_argument('--pretrained_model', default='./weights/Final_LPRNet_model.pth', help='pretrained base model')
    args = parser.parse_args(argv)
    return args
def collate_fn(batch):
    """Custom batch collation for the LPR data loader.

    Each sample is ``(img, label, length)``. Images are stacked into one
    tensor, the variable-length labels are concatenated into a single flat
    1-D integer tensor (the layout nn.CTCLoss expects), and the per-sample
    label lengths are returned so targets can be split back apart.

    Args:
        batch: list of (numpy image, label sequence, label length) samples.

    Returns:
        (stacked image tensor, flat label tensor, list of label lengths).
    """
    imgs = []
    labels = []
    lengths = []
    for img, label, length in batch:
        imgs.append(torch.from_numpy(img))
        labels.extend(label)
        lengths.append(length)
    # np.int was removed in NumPy 1.20+; use a fixed-width integer dtype.
    labels = np.asarray(labels).flatten().astype(np.int64)
    return (torch.stack(imgs, 0), torch.from_numpy(labels), lengths)
def train():
args = get_parser()
T_length = 18 # args.lpr_max_len
epoch = 0 + args.resume_epoch
loss_val = 0
if not os.path.exists(args.save_folder):
os.mkdir(args.save_folder)
# build network
lprnet = build_lprnet(lpr_max_len=args.lpr_max_len, phase=args.phase_train, class_num=len(CHARS), dropout_rate=args.dropout_rate)
device = torch.device("cuda:0" if args.cuda else "cpu")
lprnet.to(device)
print("Successful to build network!")
# load pretrained model
if args.pretrained_model:
lprnet.load_state_dict(torch.load(args.pretrained_model,map_location=torch.device('cpu')))
print("load pretrained model successful!")
else:
def xavier(param):
nn.init.xavier_uniform(param)
def weights_init(m):
for key in m.state_dict():
if key.split('.')[-1] == 'weight':
if 'conv' in key:
nn.init.kaiming_normal_(m.state_dict()[key], mode='fan_out')
if 'bn' in key:
m.state_dict()[key][...] = xavier(1)
elif key.split('.')[-1] == 'bias':
m.state_dict()[key][...] = 0.01
lprnet.backbone.apply(weights_init)
lprnet.container.apply(weights_init)
print("initial net weights successful!")
# define optimizer
optimizer = optim.RMSprop(lprnet.parameters(), lr=args.learning_rate, alpha = 0.9, eps=1e-08,
momentum=args.momentum, weight_decay=args.weight_decay)
# prepare dataset
train_img_dirs = os.path.expanduser(args.train_img_dirs)
test_img_dirs = os.path.expanduser(args.test_img_dirs)
train_dataset = LPRDataLoader(train_img_dirs.split(','), args.img_size, args.lpr_max_len)
test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size, args.lpr_max_len)
epoch_size = len(train_dataset) // args.train_batch_size
max_iter = args.max_epoch * epoch_size
ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean') # reduction: 'none' | 'mean' | 'sum'
if args.resume_epoch > 0:
start_iter = args.resume_epoch * epoch_size
else:
start_iter = 0
Tp = 0
Tn_1 = 0
Tn_2 = 0
Train_Accuracy=[]
Test_Accuracy=[]
Train_loss=[]
for iteration in range(start_iter, max_iter):
if iteration % epoch_size == 0:
# create batch iterator
batch_iterator = iter(DataLoader(train_dataset, args.train_batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn))
loss_val = 0
epoch += 1
start_time = time.time()
# load train data
images, labels, lengths = next(batch_iterator)
# get ctc parameters
input_lengths, target_lengths = sparse_tuple_for_ctc(T_length, lengths)
# update lr
lr = adjust_learning_rate(optimizer, epoch, args.learning_rate, args.lr_schedule)
targets = []
start=0
for length in lengths:
label = labels[start:start+length]
targets.append(label)
start += length
targets = np.array([el.numpy() for el in targets])
if args.cuda:
images = Variable(images, requires_grad=False).cuda()
labels = Variable(labels, requires_grad=False).cuda()
else:
images = Variable(images, requires_grad=False)
labels = Variable(labels, requires_grad=False)
# forward
logits = lprnet(images)
log_probs = logits.permute(2, 0, 1) # for ctc loss: T x N x C
log_probs = log_probs.log_softmax(2).requires_grad_()
# backprop
optimizer.zero_grad()
loss = ctc_loss(log_probs, labels, input_lengths=input_lengths, target_lengths=target_lengths)
if loss.item() == np.inf:
continue
loss.backward()
optimizer.step()
loss_val += loss.item()
end_time = time.time()
# calculate accuracy
prebs = logits.cpu().detach().numpy()
preb_labels = list()
for i in range(prebs.shape[0]):
preb = prebs[i, :, :]
preb_label = list()
for j in range(preb.shape[1]):
preb_label.append(np.argmax(preb[:, j], axis=0))
no_repeat_blank_label = list()
pre_c = preb_label[0]
if pre_c != len(CHARS) - 1:
no_repeat_bla
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
构建LPRNet,基于CCPD数据集进行车牌检测和识别_License-plate-recognition.zip (30个子文件)
License-plate-recognition-main
load_data.py 3KB
test_LPRNet.py 6KB
weights
NotoSansCJK-Regular.ttc 17.88MB
Final_LPRNet_model.pth 1.72MB
.idea
codeStyles
codeStyleConfig.xml 153B
cv_big_project.iml 595B
misc.xml 207B
inspectionProfiles
Project_Default.xml 22KB
profiles_settings.xml 174B
modules.xml 287B
.gitignore 50B
locate_result
plate3-1_0.png 531KB
plate3-3_0.png 343KB
plate3-3_1.png 23KB
plate2-2_0.png 586KB
plate2-1_0.png 876KB
plate1-3_0.png 2.07MB
plate1-1_0.png 2.58MB
plate2-3_0.png 766KB
plate_location.py 4KB
model
LPRNet.py 4KB
__pycache__
LPRNet.cpython-38.pyc 3KB
train_LPRNet.py 12KB
images
2-2.jpg 2.77MB
3-2.jpg 7.48MB
2-3.jpg 5.88MB
3-1.jpg 4.56MB
3-3.jpg 7.07MB
2-1.jpg 3.11MB
__pycache__
load_data.cpython-38.pyc 3KB
共 30 条
- 1
资源评论
2401_87496566
- 粉丝: 853
- 资源: 3373
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 三维项目-通过迁移学习进行高质量的单目深度估计实现-优质项目实战.zip
- DllToShellCode DLLSHELL注入
- 《机器人SLAM导航》课件(完整版)-第1季:第5章-机器人主机
- 三维文物重建-基于NeRF实现的文物三维重建算法-附项目源码-优质项目实战.zip
- 光伏电站无功补偿优化及其混合权重法的多目标评价体系构建
- element element-UI离线包下载 版本2.15.6
- 三维视觉测距-基于SIFT特征匹配的双目立体视觉测距项目-优质项目分享.zip
- 超声波距离发送到蓝牙APP
- MATLAB实现BO-LSTM贝叶斯优化长短期记忆神经网络股票价格预测(含完整的程序和代码详解)
- MATLAB实现TCN-GRU时间卷积门控循环单元多输入单输出回归预测(含完整的程序和代码详解)
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功