# -*- coding: utf-8 -*-
import numpy
import matplotlib.pyplot as plt
from pandas import read_csv
import math
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from sklearn.metrics import mean_squared_error
from keras.callbacks import ReduceLROnPlateau
from sklearn.preprocessing import MinMaxScaler
#载入数据
dataframe = read_csv('data.csv',encoding='utf-8' ,usecols=[1])
dataset = dataframe.values
plt.plot(dataset)#查看趋势
plt.show()
dataset_or = dataset
#为后续lstm的输入创建一个数据处理函数
#look_back为滑窗
def create_dataset(dataset, look_back=1):
dataX, dataY = [], []
for i in range(len(dataset)-look_back):
a = dataset[i:(i+look_back), 0]
dataX.append(a)
dataY.append(dataset[i + look_back, 0])
return numpy.array(dataX), numpy.array(dataY)
#设置随机种子,标准化数据
numpy.random.seed(7)
scaler = MinMaxScaler(feature_range=(0, 1))
dataset = scaler.fit_transform(dataset)
train = dataset
#设置时间滑窗,创建训练集
look_back = 7
trainX, trainY = create_dataset(train, look_back)
#对训练集x做reshape
trainX = numpy.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))
#搭建lstm网络
model = Sequential()
#输出节点为25,输入的每个样本长度为look_back
model.add(LSTM(25, input_shape=(1, look_back)))
#添加一个全连接层,输出维度为1
model.add(Dense(1))
#使用均方差做损失函数。优化器用adam
model.compile(loss='mean_squared_error', optimizer='adam')
reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=10, mode='max')
#训练模型,100epoch,批次为1,每一个epoch显示一次日志,学习率动态减小
model.fit(trainX, trainY, epochs=100, batch_size=1, verbose=2, callbacks=[reduce_lr])
#预测
trainPredict = model.predict(trainX)#预测训练集
#反标准化
trainPredict = scaler.inverse_transform(trainPredict)
trainY = scaler.inverse_transform([trainY])
#testPredict = scaler.inverse_transform(testPredict)
#输出训练RMSE
trainScore = math.sqrt(mean_squared_error(trainY[0,:], trainPredict[:,0]))
print('Train Score: %.2f RMSE' % (trainScore))
#画图查看模型预测结果
trainPredictPlot = numpy.reshape(numpy.array([None]*(len(dataset)+7)),((len(dataset)+7),1))
trainPredictPlot[look_back:len(trainPredict)+look_back, :] = trainPredict
plt.plot(dataset_or,label='true')
plt.plot(trainPredictPlot,label='trainpredict')
plt.legend()
plt.show()
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
COVID.zip (2个子文件)
data.csv 2KB
covid.py 3KB
共 2 条
- 1
资源评论
Substituteman
- 粉丝: 2
- 资源: 2
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功