import argparse
#from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from dataset.data import Data
import socket
from datetime import datetime
import os
from model.BaseNet import CPFNet
import torch
from tensorboardX import SummaryWriter
import tqdm
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from PIL import Image
import utils.utils as u
import utils.loss as LS
from utils.config import DefaultConfig
import torch.backends.cudnn as cudnn
def train(args, model, optimizer, criterion, dataloader_train):
    """Train ``model`` for ``args.num_epochs`` epochs and checkpoint each epoch.

    For every epoch the learning rate is updated through
    ``u.adjust_learning_rate``, one pass is made over ``dataloader_train``,
    the mean batch loss is written to TensorBoard and printed, and
    ``u.save_checkpoint`` is called with the same ``checkpoint_latest.pth``
    path under ``args.save_model_path``.

    Args:
        args: Configuration object. This function reads ``log_dirs``,
            ``num_epochs``, ``batch_size``, ``use_gpu`` and
            ``save_model_path`` from it.
        model: Network invoked as ``model(data)``; it must return the pair
            ``(aux_out, main_out)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        criterion: Indexable collection of losses. Only ``criterion[1]`` is
            used (on the main output); the auxiliary output always goes
            through ``F.nll_loss``.
        dataloader_train: Sized iterable yielding ``(data, label)`` batches.
    """
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join(args.log_dirs, current_time + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)

    checkpoint_dir = args.save_model_path
    checkpoint_latest = os.path.join(checkpoint_dir, 'checkpoint_latest.pth')
    use_cuda = torch.cuda.is_available() and args.use_gpu

    step = 0
    try:
        for epoch in range(args.num_epochs):
            lr = u.adjust_learning_rate(args, optimizer, epoch)
            model.train()
            # The bar counts samples, assuming every batch holds batch_size items.
            tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
            tq.set_description('epoch %d, lr %f' % (epoch, lr))
            loss_record = []
            train_loss = 0.0
            try:
                for i, (data, label) in enumerate(dataloader_train):
                    if use_cuda:
                        data = data.cuda()
                        label = label.cuda()
                    # nll_loss needs integer class targets; cast on the CPU
                    # path as well, not only after moving to the GPU.
                    label = label.long()

                    optimizer.zero_grad()
                    aux_out, main_out = model(data)
                    # presumably aux_out holds log-probabilities, as nll_loss
                    # expects -- confirm against the model definition.
                    loss_aux = F.nll_loss(aux_out, label, weight=None)
                    loss_main = criterion[1](main_out, label)
                    loss = loss_aux + loss_main
                    loss.backward()
                    optimizer.step()

                    # Read the scalar once; it is reused for the running
                    # mean, the TensorBoard point and the epoch record.
                    loss_value = loss.item()
                    tq.update(args.batch_size)
                    train_loss += loss_value
                    tq.set_postfix(loss='%.6f' % (train_loss / (i + 1)))
                    step += 1
                    if step % 10 == 0:
                        writer.add_scalar('Train/loss_step', loss_value, step)
                    loss_record.append(loss_value)
            finally:
                # Release the progress bar even if a batch raises.
                tq.close()

            loss_train_mean = np.mean(loss_record)
            writer.add_scalar('Train/loss_epoch', float(loss_train_mean), epoch)
            print('loss for train : %f' % (loss_train_mean))

            # exist_ok avoids the check-then-create race of os.path.exists.
            os.makedirs(checkpoint_dir, exist_ok=True)
            u.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict()
            }, epoch, checkpoint_dir, filename=checkpoint_latest)
    finally:
        # Flush pending events and release the event file handle.
        writer.close()
def test(model, dataloader, args, save_path):
    """Run inference over ``dataloader`` and save colour-mapped predictions.

    Each batch is passed through ``model`` with gradients disabled, the
    per-pixel class index is taken from the main output, mapped through
    ``Data.COLOR_DICT`` and written to ``save_path`` under the base name of
    the corresponding entry in the batch's path list.

    Args:
        model: Network invoked as ``model(data)``; it must return the pair
            ``(aux_pred, predict)``. Only ``predict`` is used.
        dataloader: Iterable yielding ``(data, label_path)`` batches, where
            ``label_path`` is an iterable of file paths, one per sample.
        args: Configuration object; only ``use_gpu`` is read.
        save_path: Directory receiving the output images. It is created when
            missing.
    """
    print('start test!')
    # Image.save does not create parent directories; make sure it exists.
    # An empty save_path means "current directory" and needs no creation.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    use_cuda = torch.cuda.is_available() and args.use_gpu
    with torch.no_grad():
        model.eval()
        tq = tqdm.tqdm(dataloader, desc='test')
        for data, label_path in tq:
            if use_cuda:
                data = data.cuda()
            _, predict = model(data)
            # exp() is strictly increasing, so argmax over the raw output
            # selects the same class while avoiding ties that appear when
            # very negative values underflow to zero.
            predict = torch.argmax(predict, dim=1)
            pred = predict.cpu().numpy()
            # assumes Data.COLOR_DICT is an array indexable by class id -- TODO confirm
            pred_RGB = Data.COLOR_DICT[pred.astype(np.uint8)]
            for index, item in enumerate(label_path):
                img = Image.fromarray(pred_RGB[index].squeeze().astype(np.uint8))
                name = os.path.basename(item)
                img.save(os.path.join(save_path, name))
        tq.close()
def main(mode='train', args=None):
    """Build the training data pipeline, model, optimizer and losses, then train.

    Args:
        mode: Run mode. Training starts only when this equals ``'train'``;
            any other value builds everything and returns without training.
        args: Configuration object. When ``None`` a fresh ``DefaultConfig()``
            is used. Attributes read here: ``data``, ``crop_width``,
            ``crop_height``, ``batch_size``, ``num_workers``, ``cuda``,
            ``num_classes``, ``net_work``, ``use_gpu``, ``lr``, ``momentum``
            and ``weight_decay``.

    Raises:
        KeyError: If ``args.net_work`` is not a registered network name.
    """
    # The signature advertises args=None, so honour it instead of failing
    # with an AttributeError on the first attribute access.
    if args is None:
        args = DefaultConfig()

    # create dataset and dataloader
    dataset_path = args.data
    dataset_train = Data(
        os.path.join(dataset_path, 'train'),
        scale=(args.crop_width, args.crop_height),
        mode='train',
    )
    dataloader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    )

    # build model
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda

    # load model: the registry holds constructors so that only the selected
    # network is instantiated.
    model_all = {'BaseNet': CPFNet}
    model = model_all[args.net_work](out_planes=args.num_classes)
    print(args.net_work)
    cudnn.benchmark = True
    if torch.cuda.is_available() and args.use_gpu:
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    criterion_aux = nn.NLLLoss(weight=None)
    criterion_main = LS.Multi_DiceLoss(class_num=args.num_classes)
    criterion = [criterion_aux, criterion_main]

    if mode == 'train':
        train(args, model, optimizer, criterion, dataloader_train)
if __name__ == '__main__':
    # Seed the CPU generator and every CUDA generator so that weight
    # initialisation is repeatable between runs.
    seed = 1234
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    args = DefaultConfig()
    modes = 'train'
    # Only the training mode is wired up in this script.
    if modes == 'train':
        main(mode=modes, args=args)
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
收起资源包目录
![package](https://csdnimg.cn/release/downloadcmsfe/public/img/package.f3fc750b.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/ZIP.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/PNG.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/ZIP.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/PNG.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/PNG.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
共 8 条
- 1
资源评论
![avatar-default](https://csdnimg.cn/release/downloadcmsfe/public/img/lazyLogo2.1882d7f4.png)
![avatar](https://profile-avatar.csdnimg.cn/default.jpg!1)
唐先生的博客
- 粉丝: 3262
- 资源: 632
上传资源 快速赚钱
我的内容管理 展开
我的资源 快来上传第一个资源
我的收益
登录查看自己的收益我的积分 登录查看自己的积分
我的C币 登录后查看C币余额
我的收藏
我的下载
下载帮助
![voice](https://csdnimg.cn/release/downloadcmsfe/public/img/voice.245cc511.png)
![center-task](https://csdnimg.cn/release/downloadcmsfe/public/img/center-task.c2eda91a.png)
安全验证
文档复制为VIP权益,开通VIP直接复制
![dialog-icon](https://csdnimg.cn/release/downloadcmsfe/public/img/green-success.6a4acb44.png)