import argparse
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from models import *
from data_loader import data_loader
from helper import AverageMeter, save_checkpoint, accuracy, adjust_learning_rate
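
# Architectures selectable via -a/--arch; each name maps to a constructor
# exported by the local models package (imported above via `from models import *`).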
model_names = [
    'alexnet', 'squeezenet1_0', 'squeezenet1_1', 'densenet121',
    'densenet169', 'densenet201', 'densenet161',
    'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19', 'vgg19_bn', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
    'resnet152'
]
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='alexnet', choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) + ' (default: alexnet)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                    help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR',
                    help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-m', '--pin-memory', dest='pin_memory', action='store_true',
                    help='use pinned host memory for data loading')
parser.add_argument('-p', '--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--print-freq', '-f', default=10, type=int, metavar='N',
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
best_prec1 = 0.0  # best top-1 validation accuracy seen so far


def main():
    global args, best_prec1
    args = parser.parse_args()

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))

    if args.arch == 'alexnet':
        model = alexnet(pretrained=args.pretrained)
    elif args.arch == 'squeezenet1_0':
        model = squeezenet1_0(pretrained=args.pretrained)
    elif args.arch == 'squeezenet1_1':
        model = squeezenet1_1(pretrained=args.pretrained)
    elif args.arch == 'densenet121':
        model = densenet121(pretrained=args.pretrained)
    elif args.arch == 'densenet169':
        model = densenet169(pretrained=args.pretrained)
    elif args.arch == 'densenet201':
        model = densenet201(pretrained=args.pretrained)
    elif args.arch == 'densenet161':
        model = densenet161(pretrained=args.pretrained)
    elif args.arch == 'vgg11':
        model = vgg11(pretrained=args.pretrained)
    elif args.arch == 'vgg13':
        model = vgg13(pretrained=args.pretrained)
    elif args.arch == 'vgg16':
        model = vgg16(pretrained=args.pretrained)
    elif args.arch == 'vgg19':
        model = vgg19(pretrained=args.pretrained)
    elif args.arch == 'vgg11_bn':
        model = vgg11_bn(pretrained=args.pretrained)
    elif args.arch == 'vgg13_bn':
        model = vgg13_bn(pretrained=args.pretrained)
    elif args.arch == 'vgg16_bn':
        model = vgg16_bn(pretrained=args.pretrained)
    elif args.arch == 'vgg19_bn':
        model = vgg19_bn(pretrained=args.pretrained)
    elif args.arch == 'resnet18':
        model = resnet18(pretrained=args.pretrained)
    elif args.arch == 'resnet34':
        model = resnet34(pretrained=args.pretrained)
    elif args.arch == 'resnet50':
        model = resnet50(pretrained=args.pretrained)
    elif args.arch == 'resnet101':
        model = resnet101(pretrained=args.pretrained)
    elif args.arch == 'resnet152':
        model = resnet152(pretrained=args.pretrained)
    else:
        raise NotImplementedError

    # use cuda
    model.cuda()
    # model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # cudnn.benchmark = True

    # Data loading
    train_loader, val_loader = data_loader(args.data, args.batch_size, args.workers, args.pin_memory)

    if args.evaluate:
        validate(val_loader, model, criterion, args.print_freq)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args.print_freq)

        # evaluate on validation set
        prec1, prec5 = validate(val_loader, model, criterion, args.print_freq)

        # remember the best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict()
        }, is_best, args.arch + '.pth')


def train(train_loader, model, criterion, optimizer, epoch, print_freq):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # `async=True` is a syntax error on Python 3.7+; `non_blocking` is the
        # supported keyword for asynchronous host-to-device copies
        target = target.cuda(non_blocking=True)
        input = input.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
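

def validate(val_loader, model, criterion, print_freq):
    # NOTE: the original file is truncated after train(). This validate() is a
    # reconstruction in the style of the official PyTorch ImageNet example,
    # written to match the call sites in main(): it takes
    # (val_loader, model, criterion, print_freq) and returns the average
    # (prec1, prec5) over the validation set.
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg, top5.avg


if __name__ == '__main__':
    main()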