'''导包'''
import pandas as pd
import numpy as np
import tqdm
import datetime
import os
import argparse
import random
import creat_data as Data
from def_function import *
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data
import logging
import sys
from config import Config
def main(args, logger):
    """Train one model on one city's series and return losses/predictions.

    Args:
        args: parsed argparse.Namespace carrying all hyper-parameters
            (data paths, model name, look_back, lr, epochs, ...).
        logger: logging handle used for progress output.

    Returns:
        dict with two keys:
            '<loss_type>_loss'  -> list of per-epoch average training losses
            '<loss_type>_preds' -> de-normalized test-set predictions (list)

    Raises:
        ValueError: if ``args.loss_type`` is not 'mse' or 'smoothl1loss'.
    """
    device = torch.device(args.device)  # target device for training
    data_x, data_y, scalar = get_dataset(args.data_path, args.dataset_name,
                                         args.city_name, args.look_back)

    # Chronological 80/20 train/test split. shuffle=False below keeps the
    # time-series order intact; a held-out future segment would be even
    # better, but this preserves the original evaluation scheme.
    train_slice = int(data_x.shape[0] * 0.8)
    train_dataset = Data.generate_dataset(data_x[:train_slice, :],
                                          data_y[:train_slice, :])
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        num_workers=args.num_workers, shuffle=False)
    test_dataset = Data.generate_dataset(data_x[train_slice:, :],
                                         data_y[train_slice:, :])
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size,
        num_workers=args.num_workers)

    feature_nums = args.look_back  # each sample uses look_back rows as features
    model = get_model(args.model_name, feature_nums, args.hidden_dims,
                      args.num_layers).to(device)
    logger.info(model)

    # Select the loss function. Fail fast on an unsupported type instead of
    # hitting a NameError later (the original left `loss` undefined here).
    if args.loss_type == 'mse':
        loss = nn.MSELoss()
    elif args.loss_type == 'smoothl1loss':
        loss = nn.SmoothL1Loss()
    else:
        raise ValueError('unsupported loss_type: {}'.format(args.loss_type))

    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    start_time = datetime.datetime.now()
    train_epoch_loss = []  # per-epoch average training loss
    for epoch_i in range(args.epochs):
        torch.cuda.empty_cache()  # release unused cached CUDA memory
        train_average_loss = train(model, optimizer, train_data_loader,
                                   loss, device)
        train_epoch_loss.append(train_average_loss)
        train_end_time = datetime.datetime.now()
        if epoch_i % args.print_interval == 0:
            logger.info('City {}, model {}, epoch {}, train_{}_loss {}, '
                        '[{}s]'.format(args.city_name, args.model_name,
                                       epoch_i, args.loss_type,
                                       train_average_loss,
                                       (train_end_time - start_time).seconds))

    # Persist the final parameters once after training (the flattened source
    # may have saved every epoch; the final state is identical either way).
    torch.save(model.state_dict(),
               os.path.join(args.save_param_dir,
                            args.model_name + '_best_' + args.loss_type + '.pth'))

    # Evaluate on the held-out tail of the series.
    valid_loss, preds = sub(model, test_data_loader, loss, device)

    output_dict = {}
    output_dict.setdefault(args.loss_type + '_loss', train_epoch_loss)
    # De-normalize predictions back to the original scale before returning.
    output_dict.setdefault(args.loss_type + '_preds', (preds * scalar).tolist())
    return output_dict
if __name__ == '__main__':
    # Script entry point. The guard is required because DataLoader uses
    # worker subprocesses, which re-import this module.
    parser = argparse.ArgumentParser()  # CLI overrides for config.py defaults
    parser.add_argument('--data_path', default=Config.data_path)
    parser.add_argument('--dataset_name', default='yjz.csv', help='dataset')
    parser.add_argument('--model_name', default=Config.model_name, help='LR, RNN, MLP')
    parser.add_argument('--num_workers', default=Config.num_workers, help='4, 8, 16, 32')
    parser.add_argument('--hidden_dims', default=Config.hidden_dims)
    parser.add_argument('--num_layers', default=Config.num_layers, help='1, 2')
    parser.add_argument('--look_back', default=Config.look_back,
                        help='number of rows used as the feature window')
    parser.add_argument('--city_name', default='yjz5', help='Zone 1,Zone 2,Zone 3')
    parser.add_argument('--epochs', type=int, default=Config.epochs)
    parser.add_argument('--lr', type=float, default=Config.lr)
    parser.add_argument('--weight_decay', type=float, default=Config.weight_decay)
    parser.add_argument('--batch_size', type=int, default=Config.batch_size)
    parser.add_argument('--print_interval', type=int, default=Config.print_interval)
    parser.add_argument('--device', default=Config.device)
    parser.add_argument('--loss_type', type=str, default='mse', help='smoothl1loss')
    parser.add_argument('--save_log_dir', default=Config.save_log_dir)
    parser.add_argument('--save_res_dir', default=Config.save_res_dir)
    parser.add_argument('--save_param_dir', default=Config.save_param_dir)
    args = parser.parse_args()

    # Fix the RNG seed for reproducibility.
    setup_seed(Config.seed)

    # Create output directories. makedirs(exist_ok=True) also handles nested
    # paths and avoids the exists()/mkdir race of the original.
    for out_dir in (args.save_log_dir, args.save_res_dir, args.save_param_dir):
        os.makedirs(out_dir, exist_ok=True)

    # File log gets everything (DEBUG); stdout mirrors INFO and above.
    logging.basicConfig(level=logging.DEBUG,
                        filename=os.path.join(args.save_log_dir,
                                              args.model_name + '_output.log'),
                        datefmt='%Y/%m/%d %H:%M:%S',
                        format='%(asctime)s - %(name)s - %(levelname)s - '
                               '%(lineno)d - %(module)s - %(message)s')
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(logging.INFO)
    logger.addHandler(stream_handler)
    logger.info('===> start training! ')

    # Run every (city, loss_type) combination and collect the results per
    # city. Add 'Zone 1' to the string below to include it.
    current_city_output_dicts = dict.fromkeys(tuple('Zone 2,Zone 3'.split(',')))
    for city_name in current_city_output_dicts.keys():
        logger.info('===> now execute the city {} '.format(city_name))
        args.city_name = city_name
        current_city_output_dict = {}
        for loss_type in ['mse', 'smoothl1loss']:
            args.loss_type = loss_type
            output_dict = main(args, logger)  # train + evaluate
            for key in output_dict.keys():
                current_city_output_dict.setdefault(key, output_dict[key])
        current_city_output_dicts[city_name] = current_city_output_dict

    # Persist per-city losses and predictions, one CSV per loss type.
    for loss_type in ['mse', 'smoothl1loss']:
        city_preds = {}
        city_losses = {}
        for city_name in current_city_output_dicts.keys():
            city_losses.setdefault(
                city_name,
                current_city_output_dicts[city_name][loss_type + '_loss'])
            city_preds.setdefault(
                city_name,
                current_city_output_dicts[city_name][loss_type + '_preds'])
        city_loss_df = pd.DataFrame(data=city_losses)
        city_loss_df.to_csv(
            os.path.join(args.save_res_dir,
                         'ele_' + loss_type + '_losses_' + args.model_name + '.csv'),
            index=None)
        # The original computed city_preds but never wrote it; save it
        # symmetrically (matches the project's ele_*_preds_*.csv artifacts).
        city_pred_df = pd.DataFrame(data=city_preds)
        city_pred_df.to_csv(
            os.path.join(args.save_res_dir,
                         'ele_' + loss_type + '_preds_' + args.model_name + '.csv'),
            index=None)
没有合适的资源?快使用搜索试试~ 我知道了~
LSTM时间序列预测 LSTM时间序列预测
共120个文件
csv:60个
jpg:24个
xml:10个
需积分: 4 12 下载量 59 浏览量
2023-07-01
21:02:55
上传
评论
收藏 5.32MB RAR 举报
温馨提示
LSTM时间序列预测 LSTM时间序列预测
资源推荐
资源详情
资源评论
收起资源包目录
LSTM时间序列预测 LSTM时间序列预测 (120个子文件)
Tetuan City power consumption.csv 4.03MB
算例2 步长2 ele_smoothl1loss_preds_RNN.csv 390KB
ele_mse_preds_RNN.csv 389KB
算例2 步长2 ele_mse_preds_RNN.csv 389KB
算例2 步长4 ele_mse_preds_RNN.csv 389KB
ele_smoothl1loss_preds_RNN.csv 389KB
算例2 步长4 ele_smoothl1loss_preds_RNN.csv 389KB
步长6 ele_mse_preds_RNN.csv 200KB
步长2 ele_smoothl1loss_preds_RNN.csv 200KB
LSTM步长2 ele_mse_preds_RNN.csv 199KB
步长2 ele_mse_preds_RNN.csv 199KB
LSTM步长2 ele_smoothl1loss_preds_RNN.csv 199KB
LSTM步长4 ele_mse_preds_RNN.csv 199KB
步长6 ele_smoothl1loss_preds_RNN.csv 199KB
步长4 ele_mse_preds_RNN.csv 199KB
步长4 ele_smoothl1loss_preds_RNN.csv 199KB
LSTM步长6 ele_mse_preds_RNN.csv 199KB
el_smoothl1loss_preds_RNN.csv 199KB
el_mse_preds_RNN.csv 199KB
LSTM步长6 ele_smoothl1loss_preds_RNN.csv 199KB
LSTM步长4 ele_smoothl1loss_preds_RNN.csv 199KB
yjz.csv 184KB
算例2 步长4 ele_smoothl1loss_losses_RNN.csv 2KB
ele_mse_losses_RNN.csv 2KB
ele_smoothl1loss_losses_RNN.csv 2KB
算例2 步长2 ele_smoothl1loss_losses_RNN.csv 2KB
算例2 步长2 ele_mse_losses_RNN.csv 2KB
算例2 步长4 ele_mse_losses_RNN.csv 2KB
步长6 ele_mse_losses_RNN.csv 1KB
步长4 ele_mse_losses_RNN.csv 1KB
步长4 ele_smoothl1loss_losses_RNN.csv 1KB
步长2 ele_mse_losses_RNN.csv 1KB
LSTM步长4 ele_smoothl1loss_losses_RNN.csv 1KB
步长2 ele_smoothl1loss_losses_RNN.csv 1KB
步长6 ele_smoothl1loss_losses_RNN.csv 1KB
LSTM步长6 ele_smoothl1loss_losses_RNN.csv 1KB
LSTM步长2 ele_smoothl1loss_losses_RNN.csv 1KB
LSTM步长2 ele_mse_losses_RNN.csv 1KB
LSTM步长4 ele_mse_losses_RNN.csv 1KB
LSTM步长6 ele_mse_losses_RNN.csv 1KB
算例2 步长2 smoothl1loss_RNN.csv 65B
算例2 步长4 mse_RNN.csv 65B
smoothl1loss_RNN.csv 65B
算例2 步长2 mse_RNN.csv 64B
算例2 步长4 smoothl1loss_RNN.csv 64B
mse_RNN.csv 63B
步长4 smoothl1loss_RNN.csv 36B
步长6 mse_RNN.csv 35B
步长6 smoothl1loss_RNN.csv 35B
步长2 mse_RNN.csv 35B
步长2 smoothl1loss_RNN.csv 35B
步长4 mse_RNN.csv 32B
LSTM步长2 smoothl1loss_RNN.csv 32B
LSTM步长2 mse_RNN.csv 32B
LSTM步长6 mse_RNN.csv 32B
LSTM步长6 smoothl1loss_RNN.csv 32B
LSTM步长4 smoothl1loss_RNN.csv 32B
LSTM步长4 mse_RNN.csv 32B
el_smoothl1loss_losses_RNN.csv 32B
el_mse_losses_RNN.csv 31B
.gitignore 50B
.gitignore 50B
lstm用工业用电量预测.iml 506B
LSTM-大号-2023-7-1.iml 491B
Zone 2_mse.jpg 42KB
Zone 2_smoothl1loss.jpg 42KB
算例2 步长2 Zone 2_mse.jpg 42KB
算例2 步长2Zone 2_smoothl1loss.jpg 42KB
算例2 步长4 Zone 2_mse.jpg 41KB
算例2 步长4 Zone 2_smoothl1loss.jpg 41KB
算例2 步长2Zone 3_smoothl1loss.jpg 40KB
Zone 3_mse.jpg 40KB
Zone 3_smoothl1loss.jpg 40KB
算例2 步长4 Zone 3_smoothl1loss.jpg 40KB
算例2 步长4 Zone 3_mse.jpg 40KB
算例2 步长2 Zone 3_mse.jpg 40KB
步长4 Zone 1_mse.jpg 39KB
步长4 Zone 1_smoothl1loss.jpg 39KB
步长2 Zone 1_smoothl1loss.jpg 39KB
LSTM步长6 Zone 1_smoothl1loss.jpg 39KB
步长2 Zone 1_mse.jpg 39KB
LSTM步长4 Zone 1_mse.jpg 39KB
LSTM步长2 Zone 1_smoothl1loss.jpg 39KB
步长6 Zone 1_smoothl1loss.jpg 39KB
LSTM步长4 Zone 1_smoothl1loss.jpg 39KB
步长6 Zone 1_mse.jpg 39KB
LSTM步长6 Zone 1_mse.jpg 39KB
LSTM步长2 Zone 1_mse.jpg 39KB
RNN_plot.log 5.15MB
RNN_output.log 36KB
lstmmoxing 528KB
RNN_best_smoothl1loss.pth 13KB
RNN_best_mse.pth 13KB
main.py 8KB
valid.py 6KB
def_function.py 5KB
plot_main.py 4KB
lstm.py 4KB
Model.py 2KB
config.py 2KB
共 120 条
- 1
- 2
资源评论
程序员奇奇
- 粉丝: 3w+
- 资源: 294
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功