%% 基于长短期记忆 (LSTM) 网络对序列数据进行分类
%使用 LSTM 神经网络对序列数据进行分类,LSTM 神经网络将序列数据输入网络,并根据序列数据的各个时间步进行预测。
%% 加载序列数据
%使用 Waveform 数据集,训练数据包含四种波形的时间序列数据。每个序列有三个通道,且长度不同。
%从 WaveformData 加载示例数据。
%序列数据是序列的 numObservations×1 元胞数组,其中 numObservations 是序列数。每个序列都是一个 numTimeSteps×-numChannels 数值数组,其中 numTimeSteps 是序列的时间步,numChannels 是序列的通道数。标签数据是 numObservations×1 分类向量。
load WaveformData
%绘制部分序列
numChannels = size(data{1},2);
idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
nexttile
stackedplot(data{idx(i)},DisplayLabels="Channel "+string(1:numChannels))
xlabel("Time Step")
title("Class: " + string(labels(idx(i))))
end
%查看类名称
classNames = categories(labels)
%划分数据
%使用 trainingPartitions 函数将数据划分为训练集(包含 90% 数据)和测试集(包含其余 10% 数据),
numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations,[0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);
XTest = data(idxTest);
TTest = labels(idxTest);
%% 准备要填充的数据
%默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度
%获取观测值序列长度
numObservations = numel(XTrain);
for i=1:numObservations
sequence = XTrain{i};
sequenceLengths(i) = size(sequence,1);
end
%序列长度排序
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);
%查看序列长度
figure
bar(sequenceLengths)
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")
%% 定义 LSTM 神经网络架构
%将输入大小指定为输入数据的通道数。
%指定一个具有 120 个隐藏单元的双向 LSTM 层,并输出序列的最后一个元素。
%最后,包括一个输出大小与类的数量匹配的全连接层,后跟一个 softmax 层。
numHiddenUnits = 120;
numClasses = 4;
layers = [
sequenceInputLayer(numChannels)
bilstmLayer(numHiddenUnits,OutputMode="last")
fullyConnectedLayer(numClasses)
softmaxLayer]
%% 指定训练选项
%使用 Adam 求解器进行训练。
%进行 200 轮训练。
%指定学习率为 0.002。
%使用阈值 1 裁剪梯度。
%为了保持序列按长度排序,禁用乱序。
%在图中显示训练进度并监控准确度。
options = trainingOptions("adam", ...
MaxEpochs=200, ...
InitialLearnRate=0.002,...
GradientThreshold=1, ...
Shuffle="never", ...
Plots="training-progress", ...
Metrics="accuracy", ...
Verbose=false);
%% 训练 LSTM 神经网络
%使用 trainnet 函数训练神经网络
net = trainnet(XTrain,TTrain,layers,"crossentropy",options);
%% 测试 LSTM 神经网络
%对测试数据进行分类,并计算预测的分类准确度。
numObservationsTest = numel(XTest);
for i=1:numObservationsTest
sequence = XTest{i};
sequenceLengthsTest(i) = size(sequence,1);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);
%对测试数据进行分类,并计算预测的分类准确度。
scores = minibatchpredict(net,XTest);
YTest = scores2label(scores,classNames);
%计算分类准确度
acc = mean(YTest == TTest)
%混淆图中显示分类结果
figure
confusionchart(TTest,YTest)
逼子歌
- 粉丝: 3569
- 资源: 41
最新资源
- 基于C语言的嵌入式软件定时器详细文档+全部资料+高分项目+源码.zip
- 基于ffmpeg的直播推流器,超级稳定,经过长时间稳定性测试,超低延时,可用于手机,电视,嵌入式等直播App及设备。详细文档+全部资料+高分项目+源码.zip
- 基于DCT算法的水印嵌入和提取的移动智能终端数字图像证据系统详细文档+全部资料+高分项目+源码.zip
- 基于FPGA的DDR1控制器,为低端FPGA嵌入式系统提供廉价、大容量的存储详细文档+全部资料+高分项目+源码.zip
- 基于FreeRTOS开发的嵌入式开发框架详细文档+全部资料+高分项目+源码.zip
- 基于FMCW雷达的多天线定位系统详细文档+全部资料+高分项目+源码.zip
- 基于FriendlyARM6410平台的嵌入式Qt程序:实时天气信息,远程vnc控制,远程监视摄像头,语音控制,语音输出TTS详细文档+全部资料+高分项目+源码.zip
- 基于FSMPSTem32的嵌入式音乐播放器、实训作业详细文档+全部资料+高分项目+源码.zip
- 基于GEC6818嵌入式大作业详细文档+全部资料+高分项目+源码.zip
- 基于jetty嵌入式容器的java性能分析工具,内嵌H2 database,以图表形式直观展现应用当前性能数据详细文档+全部资料+高分项目+源码.zip
- 基于jq开发的数学公式插件,可随意嵌入web中详细文档+全部资料+高分项目+源码.zip
- 基于Linux系统的应用程序,旨在搭建一套完整的多进程多线程通讯的消息框架. 支持多SOC的嵌入式APP详细文档+全部资料+高分项目+源码.zip
- 基于mplayer的嵌入式音视频播放器详细文档+全部资料+高分项目+源码.zip
- 基于LSM-Tree的嵌入式数据库详细文档+全部资料+高分项目+源码.zip
- 基于liunx下的一个QT程序,KTV点歌系统嵌入式设备详细文档+全部资料+高分项目+源码.zip
- 基于MySQL的嵌入式Linux智慧农业采集控制系统详细文档+全部资料+高分项目+源码.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈