"""
LSTM-长短期记忆网络
Datetime:2020-9-8
purpose:预测20世纪70年代中期波士顿郊区房屋价格的中位数。使用boston_housing数据集
Data:404个训练样本和102个测试样本,包含地区的犯罪率、当地房产税率等13个特征量
Author:kris.wang
"""
from keras.models import Sequential
from keras.layers import Dense,Dropout,Activation
from keras.layers import Embedding
from keras.layers import SimpleRNN,LSTM,GRU
from keras.optimizers import SGD,Nadam,Adam,RMSprop
from keras.callbacks import TensorBoard
from keras.utils import np_utils
import scipy.io as si
import os
import numpy as np
import matplotlib.pyplot as plt
#1,导入数据
# path=os.getcwd().replace('smallprojectsummary\LSTMproject','')
data=si.loadmat('sp1s_aa_1000Hz.mat')
# y_test=np.loadtxt(path+r'BCI_II_data\')
#预处理数据
#316个样本,28个电极通道,500毫秒数据。
data_train=data['x_train'].reshape((316,500,28))
data_train/=200
data_train=data_train.astype('float32')
data_test=data['x_test'].reshape((100,500,28))
data_test/=200
data_test=data_test.astype('float32')
data_label=data['y_train'].reshape((316,1))
tmp_train = []
for i in data_label:
if i == 1:
tmp_train.append(1)
elif i == 0:
tmp_train.append(-1)
data_label = np.array(tmp_train)
data_label = np_utils.to_categorical(data_label, 2)
data_label = data_label.astype('float32')
# print(data_label.shape)
#训练模型
def build_model():
model=Sequential()
model.add(LSTM(10,return_sequences=True,input_shape=(500,28)))
model.add(LSTM(10,return_sequences=True))
model.add(LSTM(5))
model.add(Dense(2,activation='softmax'))
model.summary()
model.compile(
optimizer=Nadam(lr=0.001),
loss='categorical_crossentropy',
metrics=['acc']
)
# model.summary()#显示模型结构。
return model
# score,acc=model.evaluate()
# k折交叉验证
k=4
num_samples=len(data_train)//k
all_acc_histories=[]
for i in range(k):
print('processing fold #',i)
#准备验证数据
val_data=data_train[i*num_samples:(i+1)*num_samples]
val_targets=data_label[i*num_samples:(i+1)*num_samples]
#准备训练数据
partial_train_data=np.concatenate(
[data_train[:i*num_samples],
data_train[(i+1)*num_samples:]],
axis=0
)
partial_train_targets=np.concatenate(
[data_label[:i*num_samples],
data_label[(i+1)*num_samples:]],
axis=0
)
#构建模型
model=build_model()
history=model.fit(partial_train_data,partial_train_targets,validation_data=(val_data,val_targets),epochs=80,batch_size=20)#verbose静默模式
val_los,val_acc=model.evaluate(val_data,val_targets,verbose=0)#给出损失值和目标值
#记录每个轮次中所有折mae
acc_history=history.history['val_acc']
all_acc_histories.append(acc_history)#验证分数,即mae
#计算所有轮次中的K折验证分数平均值
average_acc_history=[np.mean([x[i] for x in all_acc_histories]) for i in range(80)]
#绘制验证分数
def smooth_curve(points,factor=0.9):
smoothed_points=[]
for point in points:
if smoothed_points:
previous=smoothed_points[-1]
smoothed_points.append(previous*factor+point*(1-factor))
else:
smoothed_points.append(point)
return smoothed_points
if __name__=="__main__":
smooth_mae_history = smooth_curve(average_acc_history[10:]) # 因为前10个数据点的取值范围和其它点不同,所以删除它们。
plt.plot(range(1, len(smooth_mae_history) + 1), smooth_mae_history)
plt.xlabel('Epochs')
plt.ylabel('Validation MAE')
plt.show() # -从图中能明显发现模型在80轮次后出现了过拟合。
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
使用keras库,数据来自BCI Competiton数据集下的Data from Berlin组的mat文件,请仅限用于研究,数据包组成,使用后三个量x_train(训练集),y_train(标签),x_test(测试集),训练集有316组样本,样本由500毫秒下28通道的数据构成,数据详细描述:http://www.bbci.de/competition/ii/berlin_desc.html。使用k折验证法验证,验证结果极佳,但没有测试集的标签,所以不知道对于新数据的分类情况如何。
资源推荐
资源详情
资源评论
收起资源包目录
LSTMproject.rar (18个子文件)
LSTMproject
LSTM.md 0B
image
LSTM核心思想.jpg 37KB
添加信息.png 26KB
门结构.jpg 55KB
输出门.png 34KB
忘记门.png 21KB
细胞状态.jpg 28KB
LSTM链式结构.jpg 36KB
汇总.jpg 191KB
RNN.jpg 39KB
符号.jpg 17KB
利用LSTM来处理脑电数据.png 67KB
更新信息.png 23KB
LSTM结构.jpg 49KB
LSTM定义.png 15KB
机制.png 449KB
LSTM.py 4KB
sp1s_aa_1000Hz.mat 44.44MB
共 18 条
- 1
资源评论
- 芊暖2023-07-24文件中的代码块和注释非常详细,很容易理解和跟踪。
- 蒋寻2023-07-24作者从多个角度探讨了LSTM的优缺点,使读者能够全面评估这项技术的实用性。
- 断脚的鸟2023-07-24真实案例的引入使这个文件更具实际应用性,读者可以更好地理解LSTM的工作原理。
- 东郊椰林放猪散仙2023-07-24提供了一个简单而实用的LSTM项目示例,能够帮助读者快速上手。
- 方2郭2023-07-24这个文件对于初学者来说非常友好,解释清晰易懂。
朔漠君
- 粉丝: 184
- 资源: 14
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- (源码)基于Spring Boot和Vue的高性能售票系统.zip
- (源码)基于Windows API的USB设备通信系统.zip
- (源码)基于Spring Boot框架的进销存管理系统.zip
- (源码)基于Java和JavaFX的学生管理系统.zip
- (源码)基于C语言和Easyx库的内存分配模拟系统.zip
- (源码)基于WPF和EdgeTTS的桌宠插件系统.zip
- (源码)基于PonyText的文本排版与预处理系统.zip
- joi_240913_8.8.0_73327_share-2EM46K.apk
- Library-rl78g15-fpb-1.2.1.zip
- llvm-17.0.1.202406-rl78-elf.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功