from __future__ import print_function
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from model import train_model
# Command-line configuration for one training/evaluation run.
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='CNN_BiLSTM', choices=['CNN_BiLSTM_Attn', 'CNN_BiLSTM'])
parser.add_argument('--epochs', type=int, default=600, help="Epochs")
parser.add_argument('--batch_size', type=int, default=256, help="Batch Size")
# NOTE(review): --learning_rate is parsed but never passed to train_model() or
# model.fit() in this script; confirm whether model.py is meant to consume it.
parser.add_argument('--learning_rate', type=float, default=0.001, help="Learning rate")
parser.add_argument('--sequence_length', type=int, default=10, help="sequence length")
parser.add_argument('--dataset', type=str, default='all', choices=['cluster1', 'cluster2', 'cluster3', 'all'])
parser.add_argument('--path', type=str, default='./results/')
args = parser.parse_args()
# Create the output directory, including any missing parent directories.
# exist_ok=True replaces the racy exists()/mkdir() pair, and makedirs does not
# fail when an intermediate directory is missing.
os.makedirs(args.path, exist_ok=True)
def convertSeriesToMatrix(vectorSeries, sequence_length):
    """Split a series into overlapping sliding windows.

    Window k is ``vectorSeries[k:k + sequence_length]``. Consecutive windows
    advance by one element, so a series of length L yields
    ``L - sequence_length + 1`` windows, and none when the series is shorter
    than ``sequence_length``.

    Args:
        vectorSeries: any sliceable sequence (list, NumPy array, ...).
        sequence_length: number of elements in each window.

    Returns:
        A list of slices of ``vectorSeries``, in order of their start index.
    """
    window_count = len(vectorSeries) - sequence_length + 1
    return [vectorSeries[start:start + sequence_length] for start in range(window_count)]
# Load the raw wind-power series for the selected dataset.
if args.dataset in {'cluster1', 'cluster2', 'cluster3'}:
    # Pre-clustered CSV in the working directory; every column is kept.
    df = pd.read_csv(args.dataset + '.csv')
    wind = np.array(df)
elif args.dataset == 'all':
    # Sum the first column of every site CSV under ./dataset into one series.
    dataset_dir = os.path.join(os.getcwd(), 'dataset')
    # Read only CSV files. This skips .gitignore as before, and also any other
    # stray file (e.g. .DS_Store) that read_csv would choke on. Sort the names
    # so the summation order, and hence float rounding, is reproducible.
    X_files = sorted(name for name in os.listdir(dataset_dir) if name.lower().endswith('.csv'))
    if not X_files:
        raise FileNotFoundError('No .csv files found in ' + dataset_dir)
    wind = []
    for filename in X_files:
        df = pd.read_csv(os.path.join(dataset_dir, filename))
        data = df.values
        wind.append(data[:, 0])
    # Element-wise sum across sites gives a single (n_rows, 1) column.
    # NOTE(review): all site files must have the same number of rows.
    wind = np.sum(wind, axis=0).reshape(-1, 1)
# Keep every 12th row (rows 0, 12, 24, ...). The original code called this
# "hourly" data, presumably because the raw rows are 5-minute samples; confirm.
# np.array() copies, so later in-place edits of wind_data leave `wind` untouched.
wind_data = np.array(wind[::12])
# Min-max normalise every column of wind_data to [0, 1].
# Convert to float first. Assigning the normalised values back into an
# integer-typed array in place would silently truncate them to 0/1.
wind_data = np.asarray(wind_data, dtype=np.float64)
col_max = wind_data.max(axis=0)  # per-column maximum
col_min = wind_data.min(axis=0)  # per-column minimum
col_range = col_max - col_min
# A constant column has zero range. Divide it by 1 instead, so it maps to 0 rather than NaN.
wind_data = (wind_data - col_min) / np.where(col_range == 0, 1, col_range)
# NOTE(review): min/max are taken over the whole series, test period included,
# so test data influences the scaling (leakage). Fitting them on the training
# rows only would avoid this, but it changes the reported results.
# [max, min] of each column, kept for de-normalisation.
MAX_MIN = [[mx, mn] for mx, mn in zip(col_max, col_min)]
# Column 0 is the prediction target (labels are taken from index 0 of the last
# window row), so its range is what de-normalisation uses.
MAX_NUM = MAX_MIN[0][0]
MIN_NUM = MAX_MIN[0][1]
# Each window needs at least one input row plus the label row.
if args.sequence_length < 2:
    raise ValueError('--sequence_length must be >= 2 (input steps plus one label step)')
# Slice the normalised series into overlapping windows of sequence_length rows,
# giving shape (n_windows, sequence_length, n_features).
matrix_load = convertSeriesToMatrix(wind_data, args.sequence_length)
matrix_load = np.array(matrix_load).astype(np.float32)
if matrix_load.shape[0] == 0:
    raise ValueError('Series has %d rows, fewer than --sequence_length=%d; no windows can be built.'
                     % (len(wind_data), args.sequence_length))
# Chronological split: the first 90% of the windows train, the rest test.
train_row = int(round(0.9 * matrix_load.shape[0]))
train_set = matrix_load[:train_row, :]
# Shuffle the training windows in place (train_set is a view into matrix_load).
# The test windows stay in time order so they can be plotted as a series.
np.random.shuffle(train_set)
# Inputs are the first sequence_length-1 rows of a window; the label is column 0 of its last row.
X_train = train_set[:, :-1]
y_train = train_set[:, -1, 0].reshape(-1, 1)
X_test = matrix_load[train_row:, :-1]
y_test = matrix_load[train_row:, -1, 0].reshape(-1, 1)
# Flatten (steps, features) into one axis and append a channel axis: (n, steps * features, 1).
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1] * X_train.shape[2], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1] * X_test.shape[2], 1))
# -------------------------------------------------------------
# Build the selected network and train it
# -------------------------------------------------------------
# Shape of one sample, i.e. everything after the batch axis.
input_shape = X_train.shape[1:]
model = train_model(args.model_name, input_shape=input_shape)
model.summary()  # print the layer-by-layer architecture
# Keras takes validation_split from the tail of the training arrays. The
# training windows are shuffled before this point, so that tail is a random subset.
history = model.fit(
    X_train,
    y_train,
    batch_size=args.batch_size,
    epochs=args.epochs,
    validation_split=0.05,
    verbose=1,
)
# Model saving is disabled. To keep the trained network, call e.g.
# model.save(os.path.join(args.path, args.model_name + '_' + args.dataset + 'model.h5'))
# -------------------------------------------------------------
# Evaluate on the held-out test windows
# -------------------------------------------------------------
# Forecast every test window and arrange the output as an (N, 1) column.
predicted_values = model.predict(X_test)
num_test_samples = len(predicted_values)
predicted_values = np.reshape(predicted_values, (num_test_samples, 1))
# evaluate() is indexed as a sequence whose element 0 is the loss, which is
# computed on the normalised scale. The message calls it MSE; this assumes
# model.py compiles with an MSE loss (confirm).
test_mse = model.evaluate(X_test, y_test, verbose=1)
print('\nThe mean squared error (MSE) on the test data set is %.3f over %d test samples.' % (test_mse[0], len(y_test)))


def _to_original_scale(values):
    """Invert the min-max scaling of the target column (column 0)."""
    return values * (MAX_NUM - MIN_NUM) + MIN_NUM


predicted_values = _to_original_scale(predicted_values)
y_test = _to_original_scale(y_test)
y_train = _to_original_scale(y_train)
# Plot the de-normalised test-set forecast against the measured values.
fig = plt.figure(figsize=(15, 5), dpi=600)
plt.plot(y_test, label='real')
plt.plot(predicted_values, label='predicted')
plt.xlabel('Hour')
plt.ylabel('Wind Power')
plt.legend(loc='upper right')
# os.path.join works whether or not --path ends with a separator. Plain
# concatenation wrote e.g. "./resultsCNN_..." when the slash was missing.
fig.savefig(os.path.join(args.path, args.model_name + '_' + args.dataset + '_output.png'), bbox_inches='tight')
# Release the (large, 600-dpi) figure; pyplot keeps figures alive until they are closed.
plt.close(fig)
# Training-curve figures, one per metric recorded in history.history.
# Each row is (history key, curve label suffix, title, y-axis label, file suffix).
# Every key also has a 'val_' counterpart, recorded because model.fit uses
# validation_split. NOTE(review): the 'mae' and 'root_mean_squared_error' keys
# exist only if model.py compiles the model with those metrics; confirm.
_HISTORY_PLOTS = (
    ('loss', 'loss', 'model loss', 'MSE Loss', '_loss.png'),
    ('mae', 'mae', 'model MAE', 'MAE', '_mae.png'),
    ('root_mean_squared_error', 'rmse', 'model RMSE', 'RMSE', '_rmse.png'),
)
for key, short_name, title, ylabel, suffix in _HISTORY_PLOTS:
    fig = plt.figure(figsize=(15, 5), dpi=600)
    plt.plot(history.history[key], label='training ' + short_name)
    plt.plot(history.history['val_' + key], label='val ' + short_name)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend(loc='upper right')
    # os.path.join is correct with or without a trailing separator on --path.
    fig.savefig(os.path.join(args.path, args.model_name + '_' + args.dataset + suffix))
    # Close each figure; otherwise every 600-dpi canvas stays in memory.
    plt.close(fig)
# Save the de-normalised test predictions next to the ground truth.
# os.path.join is correct with or without a trailing separator on --path.
data = np.hstack((predicted_values, y_test))
df = pd.DataFrame(data, columns=['predicted', 'real'])
df.to_csv(os.path.join(args.path, args.model_name + '_' + args.dataset + '_result.csv'))
# Save the per-epoch training history (every key in history.history, train and val).
df = pd.DataFrame(history.history)
df.to_csv(os.path.join(args.path, args.model_name + '_' + args.dataset + '_loss.csv'))
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于高斯混合模型聚类CNN-BiLSTM-Attention风电场短期功率预测方法(Python和Matlab代码实现) 基于高斯混合模型聚类CNN-BiLSTM-Attention风电场短期功率预测方法(Python和Matlab代码实现) 基于高斯混合模型聚类CNN-BiLSTM-Attention风电场短期功率预测方法(Python和Matlab代码实现)
资源推荐
资源详情
资源评论
收起资源包目录
高斯混合模型聚类的风电场短期功率预测.zip (74个子文件)
Attention_utils.py 2KB
main.py 6KB
cluster2.csv 1.4MB
dataset
site_122081_attributes.csv 2.24MB
site_122449_attributes.csv 2.07MB
site_124893_attributes.csv 2.24MB
site_124681_attributes.csv 2.11MB
site_123145_attributes.csv 2.13MB
site_122907_attributes.csv 2.09MB
site_122237_attributes.csv 2.19MB
site_125565_attributes.csv 2.29MB
site_122455_attributes.csv 2.08MB
site_122768_attributes.csv 2.1MB
site_120561_attributes.csv 2.12MB
site_122083_attributes.csv 2.22MB
site_125079_attributes.csv 2.19MB
site_123556_attributes.csv 2.14MB
site_125971_attributes.csv 2.36MB
site_123188_attributes.csv 2.24MB
site_121856_attributes.csv 2.25MB
site_120270_attributes.csv 2.13MB
site_121076_attributes.csv 2.14MB
site_123862_attributes.csv 2.12MB
site_124643_attributes.csv 2.07MB
site_125743_attributes.csv 2.17MB
site_121161_attributes.csv 2.26MB
site_121636_attributes.csv 2.26MB
site_121704_attributes.csv 2.14MB
site_122369_attributes.csv 2.1MB
site_121432_attributes.csv 2.31MB
site_124268_attributes.csv 2.12MB
site_125346_attributes.csv 2.26MB
site_123000_attributes.csv 2.41MB
site_124646_attributes.csv 2.37MB
site_120750_attributes.csv 2.09MB
model.py 2KB
result_analysis.m 2KB
gmm_bic.png 11KB
gmm_main.m 2KB
AICC.m 569B
cluster3.csv 1.39MB
data_process.m 3KB
results
CNN_BiLSTM_Attn_cluster2_output.png 1007KB
CNN_BiLSTM_Attn_all_loss.png 360KB
CNN_BiLSTM_Attn_all_rmse.png 442KB
CNN_BiLSTM_Attn_cluster2_loss.csv 74KB
CNN_BiLSTM_Attn_cluster1_mae.png 535KB
CNN_BiLSTM_all_result.csv 21KB
CNN_BiLSTM_Attn_cluster1_loss.csv 75KB
CNN_BiLSTM_Attn_all_mae.png 478KB
CNN_BiLSTM_Attn_cluster3_output.png 1MB
CNN_BiLSTM_Attn_cluster1_loss.png 362KB
CNN_BiLSTM_Attn_all_loss.csv 75KB
CNN_BiLSTM_all_output.png 942KB
CNN_BiLSTM_Attn_cluster3_result.csv 19KB
CNN_BiLSTM_Attn_cluster1_rmse.png 491KB
CNN_BiLSTM_Attn_cluster3_mae.png 441KB
CNN_BiLSTM_all_loss.csv 74KB
CNN_BiLSTM_Attn_cluster2_result.csv 19KB
CNN_BiLSTM_Attn_cluster2_rmse.png 469KB
CNN_BiLSTM_Attn_cluster3_loss.csv 74KB
CNN_BiLSTM_Attn_cluster1_output.png 892KB
CNN_BiLSTM_all_rmse.png 482KB
CNN_BiLSTM_Attn_all_output.png 896KB
CNN_BiLSTM_Attn_all_result.csv 21KB
CNN_BiLSTM_Attn_cluster1_result.csv 19KB
CNN_BiLSTM_all_mae.png 533KB
CNN_BiLSTM_Attn_cluster3_loss.png 361KB
CNN_BiLSTM_Attn_cluster2_loss.png 371KB
CNN_BiLSTM_Attn_cluster2_mae.png 518KB
CNN_BiLSTM_Attn_cluster3_rmse.png 451KB
CNN_BiLSTM_all_loss.png 382KB
k_means_sc.png 15KB
cluster1.csv 1.39MB
共 74 条
- 1
资源评论
- m0_74835587 2023-11-13 感谢大佬分享的资源给了我灵感,果断支持!感谢分享~
前程算法屋
- 粉丝: 4156
- 资源: 710
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功