% Regression example: train a small CNN that predicts the angle value
% (trainAngles / testAngles) associated with each digit image.
% Load the training data set
clear; clc; close all;
[trainImages,~,trainAngles] = digitTrain4DArrayData;

% Show 20 randomly chosen training images
numTrainImages = size(trainImages,4);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)
    imshow(trainImages(:,:,:,idx(i)))
    drawnow
end

% Build the regression network: one convolution + ReLU stage feeding a
% single fully connected output unit and a regression output layer.
layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(12,25)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer
    ];

% Training options.
% Output functions that trainNetwork calls during training:
%  - plotTrainingRMSE plots the training RMSE at every iteration;
%  - stopTrainingAtThreshold requests a stop once the training RMSE drops
%    below stopRMSE. With stopRMSE = 0 it never fires (an RMSE is never
%    negative), so early stopping is effectively disabled; raise the value
%    to enable it.
% The cell array is named outputFcns (not "functions") so that it does not
% shadow MATLAB's built-in functions() function.
stopRMSE = 0;
outputFcns = { ...
    @plotTrainingRMSE, ...
    @(info)stopTrainingAtThreshold(info,stopRMSE)};
options = trainingOptions('sgdm', ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-3, ...
    'MiniBatchSize',128, ...   % number of samples per gradient-descent step
    'ExecutionEnvironment','cpu', ...
    'OutputFcn',outputFcns);

% Train the network
net = trainNetwork(trainImages,trainAngles,layers,options);

% Evaluate on the test set (predict runs the trained network directly)
[testImages,~,testAngles] = digitTest4DArrayData;
predictedTestAngles = predict(net,testImages);

% Prediction error: fraction of test predictions within thr of the true
% value ("accuracy"), followed by the root-mean-square error.
predictionError = testAngles - predictedTestAngles;
thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numTestImages = size(testImages,4);
accuracy = numCorrect/numTestImages;
disp('accuracy');
disp(accuracy);
squares = predictionError.^2;
rmse = sqrt(mean(squares));
disp('the rmse');
disp(rmse);
% Local functions passed to trainingOptions as 'OutputFcn' callbacks
function plotTrainingRMSE(info)
% PLOTTRAININGRMSE Training output function that plots the training RMSE
% against the iteration number in a figure of its own.
%   info - training information struct supplied by trainNetwork; this
%          function reads info.State, info.Iteration and info.TrainingRMSE.
% The animated line is kept in a persistent variable so that later
% "iteration" calls can append to the line created at "start".
persistent rmseLine
switch info.State
    case "start"
        % New training run: open a fresh figure with an empty line.
        figure;
        rmseLine = animatedline;
        xlabel("Iteration")
        ylabel("Training RMSE")
    case "iteration"
        addpoints(rmseLine,info.Iteration,double(info.TrainingRMSE))
        % Rate-limited redraw without processing callbacks, to keep
        % training responsive.
        drawnow limitrate nocallbacks
end
end
function stop = stopTrainingAtThreshold(info,thr)
% STOPTRAININGATTHRESHOLD Training output function that requests an early
% stop once the training RMSE falls below a threshold. Adjust thr at the
% call site to make training end early.
%   info - training information struct from trainNetwork; this function
%          reads info.State and info.TrainingRMSE.
%   thr  - RMSE threshold; a stop is requested when TrainingRMSE < thr.
%   stop - logical scalar; true asks trainNetwork to stop training.
% Only "iteration" calls are evaluated; any other state returns false.
% An empty or NaN RMSE never triggers a stop.
stop = false;
if info.State ~= "iteration"
    return
end
% Compare as double (TrainingRMSE may be reported as single). Assumes the
% value is a scalar or empty; the isempty guard keeps stop a logical
% scalar instead of an empty array.
rmse = double(info.TrainingRMSE);
stop = ~isempty(rmse) && rmse < thr;
end