import torch
import math
import os
import time
import copy
import numpy as np
from lib.logger import get_logger
from lib.metrics import All_Metrics
class Trainer(object):
def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
             scaler, args, lr_scheduler=None):
    """Bundle model, criterion, optimizer, data loaders, and logging for a run.

    Args:
        model: forecasting model, called as
            ``model(data, target, teacher_forcing_ratio=...)``.
        loss: criterion comparing model output against the label.
        optimizer: torch optimizer over ``model.parameters()``.
        train_loader: batch iterable yielding ``(data, target)`` pairs.
        val_loader: optional validation loader; when ``None`` the test
            loader is used for validation (see ``train``).
        test_loader: held-out test loader.
        scaler: provides ``inverse_transform`` to map normalized values
            back to the real scale.
        args: experiment config namespace; this class reads ``log_dir``,
            ``debug``, ``model``, ``input_dim``, ``output_dim``,
            ``real_value``, ``teacher_forcing``, ``tf_decay_steps``,
            ``grad_norm``, ``max_grad_norm``, ``log_step``, ``lr_decay``,
            ``epochs``, ``early_stop``, ``early_stop_patience``.
        lr_scheduler: optional LR scheduler, stepped once per epoch when
            ``args.lr_decay`` is set.
    """
    super(Trainer, self).__init__()
    self.model = model
    self.loss = loss
    self.optimizer = optimizer
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.test_loader = test_loader
    self.scaler = scaler
    self.args = args
    self.lr_scheduler = lr_scheduler
    self.train_per_epoch = len(train_loader)
    if val_loader is not None:
        self.val_per_epoch = len(val_loader)
    self.best_path = os.path.join(self.args.log_dir, 'best_model.pth')
    self.loss_figure_path = os.path.join(self.args.log_dir, 'loss.png')
    # Logging setup. makedirs(exist_ok=True) already tolerates an existing
    # directory, so the original isdir() pre-check was redundant.
    if not args.debug:
        os.makedirs(args.log_dir, exist_ok=True)
    self.logger = get_logger(args.log_dir, name=args.model, debug=args.debug)
    self.logger.info('Experiment log path in: {}'.format(args.log_dir))
def val_epoch(self, epoch, val_dataloader):
    """Evaluate for one epoch on ``val_dataloader`` and return the average loss.

    Runs in eval mode with gradients disabled. NaN batch losses are skipped
    (a whole batch of Metr_LA is filtered), but the average is still taken
    over the total number of batches in the loader.
    """
    self.model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for data, target in val_dataloader:
            data = data[..., :self.args.input_dim]
            label = target[..., :self.args.output_dim]
            # no teacher forcing at evaluation time
            output = self.model(data, target, teacher_forcing_ratio=0.)
            if self.args.real_value:
                label = self.scaler.inverse_transform(label)
            batch_loss = self.loss(output.cuda(), label)
            if not torch.isnan(batch_loss):
                loss_sum += batch_loss.item()
    val_loss = loss_sum / len(val_dataloader)
    self.logger.info('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss))
    return val_loss
def train_epoch(self, epoch):
    """Run one training pass over ``self.train_loader``; return the mean batch loss.

    Applies optional teacher-forcing decay, gradient clipping, and a
    once-per-epoch LR scheduler step.
    """
    self.model.train()
    running_loss = 0.0
    for step, (data, target) in enumerate(self.train_loader):
        data = data[..., :self.args.input_dim]
        label = target[..., :self.args.output_dim]  # (..., 1)
        self.optimizer.zero_grad()
        # Teacher forcing for RNN encoder-decoder models:
        # a ratio of 1. feeds the ground-truth label to the decoder at every step.
        if self.args.teacher_forcing:
            global_step = (epoch - 1) * self.train_per_epoch + step
            teacher_forcing_ratio = self._compute_sampling_threshold(
                global_step, self.args.tf_decay_steps)
        else:
            teacher_forcing_ratio = 1.
        # data and target shape: B, T, N, F; output shape: B, T, N, F
        output = self.model(data, target, teacher_forcing_ratio=teacher_forcing_ratio)
        if self.args.real_value:
            label = self.scaler.inverse_transform(label)
        loss = self.loss(output.cuda(), label)
        loss.backward()
        # clip gradients to guard against explosion
        if self.args.grad_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
        self.optimizer.step()
        running_loss += loss.item()
        # periodic progress logging
        if step % self.args.log_step == 0:
            self.logger.info('Train Epoch {}: {}/{} Loss: {:.6f}'.format(
                epoch, step, self.train_per_epoch, loss.item()))
    train_epoch_loss = running_loss / self.train_per_epoch
    self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}, tf_ratio: {:.6f}'.format(epoch, train_epoch_loss, teacher_forcing_ratio))
    # learning rate decay, stepped once per epoch
    if self.args.lr_decay:
        self.lr_scheduler.step()
    return train_epoch_loss
def train(self):
    """Full training loop with validation, early stopping, and best-model tracking.

    Validates on the test loader when no dedicated validation loader was
    given. Tracks the best validation loss, keeps a deep copy of the best
    weights, optionally early-stops, then reloads the best weights and
    runs ``self.test`` on the test loader.
    """
    best_model = None
    best_loss = float('inf')
    not_improved_count = 0
    train_loss_list = []
    val_loss_list = []
    start_time = time.time()
    for epoch in range(1, self.args.epochs + 1):
        train_epoch_loss = self.train_epoch(epoch)
        # fall back to the test set for validation when no val set exists
        if self.val_loader is None:
            val_dataloader = self.test_loader
        else:
            val_dataloader = self.val_loader
        val_epoch_loss = self.val_epoch(epoch, val_dataloader)
        train_loss_list.append(train_epoch_loss)
        val_loss_list.append(val_epoch_loss)
        if train_epoch_loss > 1e6:
            self.logger.warning('Gradient explosion detected. Ending...')
            break
        if val_epoch_loss < best_loss:
            best_loss = val_epoch_loss
            not_improved_count = 0
            best_state = True
        else:
            not_improved_count += 1
            best_state = False
        # early stop after `early_stop_patience` epochs without improvement
        if self.args.early_stop and not_improved_count == self.args.early_stop_patience:
            self.logger.info("Validation performance didn\'t improve for {} epochs. "
                             "Training stops.".format(self.args.early_stop_patience))
            break
        # keep a snapshot of the best weights seen so far
        if best_state:
            self.logger.info('*********************************Current best model saved!')
            best_model = copy.deepcopy(self.model.state_dict())
    training_time = time.time() - start_time
    self.logger.info("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss))
    # best_model is None if training aborted before any epoch improved
    # (e.g. gradient explosion on epoch 1); originally this saved/loaded
    # None and crashed — guard instead.
    if best_model is not None:
        # save the best model to file
        if not self.args.debug:
            torch.save(best_model, self.best_path)
            self.logger.info("Saving current best model to " + self.best_path)
        self.model.load_state_dict(best_model)
    # final evaluation on the test set
    self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
def save_checkpoint(self):
    """Persist model weights, optimizer state, and run config to ``self.best_path``."""
    torch.save(
        {
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'config': self.args,
        },
        self.best_path,
    )
    self.logger.info("Saving current best model to " + self.best_path)
@staticmethod
def test(model, args, data_loader, scaler, logger, path=None):
    """Evaluate ``model`` on ``data_loader``, dump arrays, and log metrics.

    When ``path`` is given, a checkpoint (``state_dict`` + ``config``) is
    loaded first and its config replaces ``args``. Predictions and ground
    truth are concatenated over all batches, inverse-transformed to the
    real scale, saved as ``{dataset}_true.npy`` / ``{dataset}_pred.npy``,
    and summarized with per-horizon and average metrics.

    NOTE(review): the source file was truncated mid-statement after the
    first ``np.save`` — the prediction save and the metric loop below are
    reconstructed from the visible imports (``All_Metrics``) and the
    symmetric ``np.save`` call; confirm against the original repository.
    """
    if path is not None:
        check_point = torch.load(path)
        state_dict = check_point['state_dict']
        args = check_point['config']
        model.load_state_dict(state_dict)
        model.to(args.device)
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            data = data[..., :args.input_dim]
            label = target[..., :args.output_dim]
            # no teacher forcing at test time
            output = model(data, target, teacher_forcing_ratio=0)
            y_true.append(label)
            y_pred.append(output)
    y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
    if args.real_value:
        # model already outputs real-scale values
        y_pred = torch.cat(y_pred, dim=0)
    else:
        y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
    np.save('./{}_true.npy'.format(args.dataset), y_true.cpu().numpy())
    np.save('./{}_pred.npy'.format(args.dataset), y_pred.cpu().numpy())
    # per-horizon metrics along the time axis (assumes B, T, N, F layout
    # as commented in train_epoch — TODO confirm)
    for t in range(y_true.shape[1]):
        mae, rmse, mape, _, _ = All_Metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                            args.mae_thresh, args.mape_thresh)
        logger.info("Horizon {:02d}, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format(
            t + 1, mae, rmse, mape * 100))
    mae, rmse, mape, _, _ = All_Metrics(y_pred, y_true, args.mae_thresh, args.mape_thresh)
    logger.info("Average Horizon, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format(
        mae, rmse, mape * 100))
用于交通流预测的联合时空图卷积网络JSTGCN(python代码)(不带数据集)
版权申诉
124 浏览量
2022-06-23
14:37:58
上传
评论 1
收藏 34KB ZIP 举报
资源存储库
- 粉丝: 4502
- 资源: 392
最新资源
- 基于matlab实现夜间车牌识别程序(1).rar
- 基于matlab实现无线传感器网络无需测距定位算法matlab源代码 包括apit,dv-hop,amorphous在内的共7个
- 基于python的yolov5实现的旋转目标检测
- 基于matlab实现无线传感器网络 CAB定位仿真程序 这是无线传感器节点定位CAB算法的仿真程序,由matlab完成.rar
- 基于matlab实现图像处理,本程序使用背景差分法对来往车辆进行检测和跟踪.rar
- 基于matlab实现视频监控中车型识别代码,自己写的,希望和大家多多交流.rar
- springcodespringcodespringcodespringcode
- 基于matlab实现权值的MAXDEV无线传感器网络定位算法研究 MAXDEV 无线传感器 定位 算法.rar
- sdk.config
- 基于matlab实现配电网三相潮流计算方法,对几种常用的配电网潮流计算方法进行了对比分析.rar
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈