clear
clc
% BP (backprop) neural network: train on the 60000 MNIST training samples
% for `epochs` passes, evaluating on the test set after every pass.
% Load the MNIST data set (provides train_x/train_y and test_x/test_y).
load test_MNIST.mat;
load train_MNIST.mat;
% One-hot encode the digit labels (0-9) into 10-row indicator matrices:
% column n has a single 1 in row (label + 1).
train_Y = zeros(10, 60000);
test_Y = zeros(10, 10000);
for n = 1:60000
    train_Y(train_y(n) + 1, n) = 1;
end
for n = 1:10000
    test_Y(test_y(n) + 1, n) = 1;
end
% --- Hyperparameters ---
learning_rate1 = 0.01;  % base learning rate (decayed every epoch)
epochs = 50;            % number of training passes over the data

% --- Network architecture (single hidden layer) ---
input_neurons  = 784;   % one input per pixel (28x28 images)
hidden_neurons = 30;    % hidden-layer width
output_neurons = 10;    % one output per digit class

% --- Random initialization of weights and biases ---
weights1 = randn(hidden_neurons, input_neurons);   % 30x784
weights2 = randn(hidden_neurons, output_neurons);  % 30x10 (applied transposed in the forward pass)
bias1 = randn(hidden_neurons, 1);                  % 30x1
bias2 = randn(output_neurons, 1);                  % 10x1

% --- Activation function and its derivative ---
sigmoid = @(x) 1 ./ (1 + exp(-x));
sigmoid_derivative = @(x) sigmoid(x) .* (1 - sigmoid(x));
loss1 = [];  % loss history, one entry recorded per epoch
% --- Training loop: plain stochastic gradient descent, one sample at a time ---
for s = 1:1:epochs
    learning_rate = learning_rate1 .* 0.98.^s;  % exponentially decayed learning rate
    for m = 1:1:60000
        train_x1 = train_x(m,:);  % 1x784 input sample
        train_y1 = train_Y(:,m);  % 10x1 one-hot target
        % Forward pass. Note: activations are sigmoid(z - bias).
        z1 = weights1 * train_x1';   % 30x1 hidden pre-activation
        a1 = sigmoid(z1 - bias1);    % 30x1 hidden-layer output
        z2 = weights2' * a1;         % 10x1 output pre-activation
        a2 = sigmoid(z2 - bias2);    % 10x1 network output
        % Half mean-squared-error loss for this sample.
        loss = 0.5.*mean(((a2-train_y1).^2),'all');
        % Backward pass.
        % BUG FIX: since a = sigmoid(z - bias), its derivative must be
        % evaluated at (z - bias), not at z as the original code did with
        % sigmoid_derivative(z2)/sigmoid_derivative(z1). Using a.*(1-a)
        % is exactly that value (and avoids recomputing the sigmoid).
        delta2 = -(a2 - train_y1) .* (a2 .* (1 - a2));     % output-layer error term
        delta1 = (weights2 * delta2) .* (a1 .* (1 - a1));  % hidden-layer error term
        % Gradient-descent parameter updates. delta2/delta1 carry a minus
        % sign, so the weights use +, while the biases use - because
        % a = sigmoid(z - bias) flips the bias gradient's sign.
        weights2 = weights2 + learning_rate .* (a1 * delta2');
        bias2 = bias2 - learning_rate .* delta2;
        weights1 = weights1 + learning_rate .* (delta1 * train_x1);
        bias1 = bias1 - learning_rate .* delta1;
    end
    loss1 = [loss1, loss];  % record the last sample's loss of this epoch
    % Evaluate classification accuracy on the 10000-sample test set.
    count = 0;  % number of correctly classified test samples
    for k = 1:1:10000
        test_x1 = test_x(k,:);
        z3 = weights1 * test_x1';
        a3 = sigmoid(z3 - bias1);  % hidden-layer output
        z4 = weights2' * a3;
        a4 = sigmoid(z4 - bias2);  % output-layer output, 10x1
        [~,inx] = max(a4);         % most active output unit
        inx = inx - 1;             % convert 1-based row index back to digit 0-9
        if inx == test_y(k)
            count = count + 1;
        end
    end
    accuracy = count/10000;  % test-set accuracy for this epoch
    disp(['Epochs: ' num2str(s) ' Loss: ' num2str(loss) ' Test accuracy: ' num2str(accuracy)])
end
% Plot the recorded per-epoch loss curve.
figure
epoch_axis = 1:epochs;
plot(epoch_axis, loss1(epoch_axis))
xlabel('迭代次数')
ylabel('损失值')
% Visualize predictions on the test set (external helper function).
pain1(test_x, test_y, weights1, bias1, weights2, bias2)