# -- coding: utf-8 --
# @Time : 2021/9/18 18:33
# @Author : LiDingZhao
# @Email : 1023822090@qq.com
# @File : main.py
# @Software: PyCharm
import os
import sys
import argparse
import datetime
import time
import os.path as osp
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import datasets
import models
from utils1 import AverageMeter, Logger
from center_loss import CenterLoss
# Command-line configuration. Parsed once at import time into the module-level
# ``args`` namespace, which main/train/test/plot_features read directly.
parser = argparse.ArgumentParser("Center Loss Example")
# dataset
parser.add_argument('-d', '--dataset', type=str, default='birddataset', choices=['birddataset'])
parser.add_argument('-j', '--workers', default=0, type=int,
                    help="number of data loading workers (default: 0)")
# optimization
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--lr-model', type=float, default=0.001, help="learning rate for model")
parser.add_argument('--lr-cent', type=float, default=0.5, help="learning rate for center loss")
parser.add_argument('--weight-cent', type=float, default=1, help="weight for center loss")
parser.add_argument('--max-epoch', type=int, default=100)
parser.add_argument('--stepsize', type=int, default=20)  # StepLR period; <= 0 disables the scheduler
parser.add_argument('--gamma', type=float, default=0.5, help="learning rate decay")
# model
# ``choices`` must be a container of full names. The original passed the bare
# string 'LSTM'. argparse tests it with ``in``, i.e. substring matching, so an
# explicit ``--model densenet121`` (the default!) was rejected while fragments
# such as 'L' or 'ST' were accepted.
# NOTE(review): presumably these are the names models.create() accepts — confirm.
parser.add_argument('--model', type=str, default='densenet121', choices=['densenet121', 'LSTM'])
# misc
parser.add_argument('--eval-freq', type=int, default=10)  # evaluate every N epochs (plus the last)
parser.add_argument('--print-freq', type=int, default=50)  # log every N batches
parser.add_argument('--gpu', type=str, default='0')  # exported as CUDA_VISIBLE_DEVICES
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--use-cpu', action='store_true')
parser.add_argument('--save-dir', type=str, default='log')
parser.add_argument('--plot', action='store_true', help="whether to plot features for every epoch")
args = parser.parse_args()
def main():
    """Set up data, model, losses and optimizers, then run the train/eval loop.

    Every setting is read from the module-level ``args`` namespace. ``sys.stdout``
    is replaced by a ``utils1.Logger`` targeting ``<save_dir>/log_<dataset>.txt``
    (presumably tee-ing to console and file — see utils1).
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # CUDA is used only when present AND not explicitly disabled.
    use_gpu = torch.cuda.is_available() and not args.use_cpu

    log_path = osp.join(args.save_dir, 'log_' + args.dataset + '.txt')
    sys.stdout = Logger(log_path)

    if not use_gpu:
        print("Currently using CPU")
    else:
        print("Currently using GPU: {}".format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)

    print("Creating dataset: {}".format(args.dataset))
    dataset = datasets.create(
        name=args.dataset,
        batch_size=args.batch_size,
        use_gpu=use_gpu,
        num_workers=args.workers,
    )
    trainloader = dataset.trainloader
    testloader = dataset.testloader

    print("Creating model: {}".format(args.model))
    # models.create returns a constructor, which is then instantiated.
    model_factory = models.create(name=args.model)
    model = model_factory()
    if use_gpu:
        model = nn.DataParallel(model).cuda()

    criterion_xent = nn.CrossEntropyLoss()
    criterion_cent = CenterLoss(num_classes=dataset.num_classes, feat_dim=2, use_gpu=use_gpu)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=args.lr_model,
                                      weight_decay=5e-04, momentum=0.9)
    # The center-loss parameters get their own optimizer and learning rate.
    optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=args.lr_cent)

    # Learning-rate decay is attached only for a positive step size.
    scheduler = None
    if args.stepsize > 0:
        scheduler = lr_scheduler.StepLR(optimizer_model, step_size=args.stepsize, gamma=args.gamma)

    start_time = time.time()
    for epoch in range(args.max_epoch):
        epoch_no = epoch + 1
        print("==> Epoch {}/{}".format(epoch_no, args.max_epoch))
        train(model, criterion_xent, criterion_cent,
              optimizer_model, optimizer_centloss,
              trainloader, use_gpu, dataset.num_classes, epoch)
        if scheduler is not None:
            scheduler.step()

        # Evaluate every eval_freq epochs (when enabled) and always after the last one.
        periodic = args.eval_freq > 0 and epoch_no % args.eval_freq == 0
        if periodic or epoch_no == args.max_epoch:
            print("==> Test")
            acc, err = test(model, testloader, use_gpu, dataset.num_classes, epoch)
            print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err))

    elapsed_seconds = round(time.time() - start_time)
    elapsed_text = str(datetime.timedelta(seconds=elapsed_seconds))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed_text))
def train(model, criterion_xent, criterion_cent,
          optimizer_model, optimizer_centloss,
          trainloader, use_gpu, num_classes, epoch):
    """Run one training epoch with a joint cross-entropy + weighted center loss.

    Args:
        model: network whose forward pass returns ``(features, outputs)``.
        criterion_xent: classification loss applied to ``outputs``.
        criterion_cent: center loss applied to ``features``; its parameters
            (presumably the class centers) are updated by ``optimizer_centloss``.
        optimizer_model: optimizer over the model's parameters.
        optimizer_centloss: optimizer over ``criterion_cent.parameters()``.
        trainloader: iterable of ``(data, labels)`` batches.
        use_gpu: when True, each batch is moved to CUDA.
        num_classes: number of classes, forwarded to ``plot_features``.
        epoch: zero-based epoch index, used only for plot file names.
    """
    model.train()
    xent_losses = AverageMeter()
    cent_losses = AverageMeter()
    losses = AverageMeter()
    if args.plot:
        all_features, all_labels = [], []

    for batch_idx, (data, labels) in enumerate(trainloader):
        if use_gpu:
            data, labels = data.cuda(), labels.cuda()
        features, outputs = model(data)
        loss_xent = criterion_xent(outputs, labels)
        loss_cent = criterion_cent(features, labels) * args.weight_cent
        loss = loss_xent + loss_cent

        optimizer_model.zero_grad()
        optimizer_centloss.zero_grad()
        loss.backward()
        optimizer_model.step()
        # Only loss_cent reaches the center-loss parameters, and it was scaled by
        # weight_cent. Undo that scaling so lr_cent alone sets the step size.
        # With weight_cent == 0 the division would raise ZeroDivisionError; those
        # gradients are already zero, so there is nothing to rescale.
        if args.weight_cent != 0:
            for param in criterion_cent.parameters():
                if param.grad is not None:
                    param.grad.mul_(1. / args.weight_cent)
        optimizer_centloss.step()

        batch_size = labels.size(0)
        losses.update(loss.item(), batch_size)
        xent_losses.update(loss_xent.item(), batch_size)
        cent_losses.update(loss_cent.item(), batch_size)

        if args.plot:
            # .cpu() is a no-op for CPU tensors, so one path covers both devices.
            all_features.append(features.detach().cpu().numpy())
            all_labels.append(labels.detach().cpu().numpy())

        if args.print_freq > 0 and (batch_idx + 1) % args.print_freq == 0:
            print("Batch {}/{}\t Loss {:.6f} ({:.6f}) XentLoss {:.6f} ({:.6f}) CenterLoss {:.6f} ({:.6f})"
                  .format(batch_idx + 1, len(trainloader), losses.val, losses.avg,
                          xent_losses.val, xent_losses.avg, cent_losses.val, cent_losses.avg))

    # np.concatenate raises on an empty list, so skip plotting if no batch was seen.
    if args.plot and all_features:
        all_features = np.concatenate(all_features, 0)
        all_labels = np.concatenate(all_labels, 0)
        plot_features(all_features, all_labels, num_classes, epoch, prefix='train')
def test(model, testloader, use_gpu, num_classes, epoch):
    """Evaluate ``model`` on ``testloader`` and return accuracy and error rate.

    Args:
        model: network whose forward pass returns ``(features, outputs)``; the
            predicted class is the argmax of ``outputs`` along dim 1.
        testloader: iterable of ``(data, labels)`` batches.
        use_gpu: when True, each batch is moved to CUDA.
        num_classes: number of classes, forwarded to ``plot_features``.
        epoch: zero-based epoch index, used only for plot file names.

    Returns:
        ``(acc, err)`` as Python floats in percent, with ``err = 100 - acc``.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    if args.plot:
        all_features, all_labels = [], []

    with torch.no_grad():
        for data, labels in testloader:
            if use_gpu:
                data, labels = data.cuda(), labels.cuda()
            features, outputs = model(data)
            predictions = outputs.argmax(dim=1)
            total += labels.size(0)
            # .item() keeps the count a Python int. Accumulating the raw tensor
            # made acc/err 0-d tensors, which printed as "tensor(...)".
            correct += (predictions == labels).sum().item()
            if args.plot:
                all_features.append(features.cpu().numpy())
                all_labels.append(labels.cpu().numpy())

    if args.plot and all_features:
        all_features = np.concatenate(all_features, 0)
        all_labels = np.concatenate(all_labels, 0)
        plot_features(all_features, all_labels, num_classes, epoch, prefix='test')

    if total == 0:
        raise ValueError("test loader yielded no samples; cannot compute accuracy")
    acc = correct * 100. / total
    err = 100. - acc
    return acc, err
def plot_features(features, labels, num_classes, epoch, prefix):
    """Scatter-plot features coloured by class and save the figure as a PNG.

    Args:
        features: array of shape (num_instances, num_features); only columns
            0 and 1 are plotted.
        labels: array of shape (num_instances,) holding integer class ids.
        num_classes: classes ``0 .. num_classes-1`` are drawn and listed in the legend.
        epoch: zero-based epoch index; the file is named ``epoch_<epoch+1>.png``.
        prefix: sub-directory of ``args.save_dir`` to write into (e.g. 'train').
    """
    for label_idx in range(num_classes):
        mask = labels == label_idx
        plt.scatter(
            features[mask, 0],
            features[mask, 1],
            # 'CN' indexes matplotlib's colour cycle and wraps around, so any
            # number of classes works. The old five-entry list raised IndexError
            # for label_idx >= 5.
            c='C{}'.format(label_idx),
            s=1,
        )
    plt.legend([str(i) for i in range(num_classes)], loc='upper right')

    dirname = osp.join(args.save_dir, prefix)
    # makedirs also creates a missing save_dir and tolerates an existing directory.
    os.makedirs(dirname, exist_ok=True)
    save_name = osp.join(dirname, 'epoch_' + str(epoch + 1) + '.png')
    plt.savefig(save_name, bbox_inches='tight')
    plt.close()
if __name__ == '__main__':
    # Script entry point: run the full training/evaluation loop.
    main()
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
融合了CBAM与DenseNet121,使用中心损失函数及鸟声融合特征进行数据预处理、鸟声识别等任务.zip (8个子文件)
-Xeno-Canto--main
__init__.py 0B
center_loss.py 2KB
utils1.py 2KB
main.py 8KB
models.py 7KB
datasets.py 2KB
数据说明.xlsx 10KB
cbam.py 4KB
共 8 条
- 1
资源评论
博士僧小星
- 粉丝: 1931
- 资源: 5896
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功