% Load the data stored in the struct 'mnist' inside MnistSet.mat and extract
% the training/test images and labels.
mnistFile = load('MnistSet.mat');
mnist = mnistFile.mnist;
trainImages = mnist.train_images;
trainLabels = mnist.train_labels;
testImages = mnist.test_images;
testLabels = mnist.test_labels;
% Images are stacked along the 3rd dimension (height x width x count), which
% is what the reshape in the preprocessing step relies on. Count along that
% dimension explicitly: length() returns the LARGEST dimension, which is only
% the image count when there are more images than pixel rows/columns.
trainLength = size(trainImages, 3);
testLength = size(testImages, 3);
% Image preprocessing: flatten each height x width image into one column
% (result is pixels x count) and scale the pixel values by 1/255.
% NOTE(review): the /255 scaling assumes 8-bit pixel data (0-255) — confirm.
trainImages = double(reshape(trainImages, [], size(trainImages, 3))) / 255;
testImages = double(reshape(testImages, [], size(testImages, 3))) / 255;
% One predicted label per test image, i.e. per column of testImages.
% (length() here would return max(pixels, count), which is wrong whenever
% there are fewer test images than pixels per image.)
testResults = zeros(1, size(testImages, 2));
% Evaluate a k-nearest-neighbour classifier using L1 (sum of absolute pixel
% differences) distance on the whole test set for every K from 1 to maxK,
% then plot the error rate against K.
%
% Uses from earlier in this script:
%   trainImages / testImages - one image per column
%   trainLabels / testLabels - labels aligned with those columns
%   trainLength / testLength - number of training / test images
% Leaves in the workspace:
%   testResults - predicted labels for the last K evaluated (K = maxK)
%   p           - maxK x 1 column, p(K) = test error rate with K neighbours
maxK = min(120, trainLength);   % cannot take more neighbours than exist
verbose = false;                % true => print every prediction and label

% Distances do not depend on K, so rank the training set ONCE per test image
% and keep only the labels of its maxK nearest neighbours (nearest first).
% Row i of neighborLabels then serves every K via its first K columns.
neighborLabels = zeros(testLength, maxK);
for i = 1:testLength
    dist = sum(abs(bsxfun(@minus, trainImages, testImages(:, i))), 1);
    [~, pos] = sort(dist);
    nearest = trainLabels(pos(1:maxK));
    neighborLabels(i, :) = nearest(:).';
end

trueLabels = double(testLabels(1:testLength));
trueLabels = trueLabels(:);
p = zeros(maxK, 1);
for K = 1:maxK
    % Majority vote over exactly the K nearest labels. (The previous version
    % voted over a fixed 120-element buffer whose unused slots were 0, which
    % added 120-K bogus votes for label 0.) mode breaks ties in favour of the
    % smallest label, the same choice the earlier tabulate/max vote made.
    predicted = mode(neighborLabels(:, 1:K), 2);
    testResults = predicted.';
    if verbose
        for i = 1:testLength
            disp(testResults(i));
            disp(testLabels(i));
        end
    end
    % Error rate: fraction of test images whose prediction is wrong.
    numErrors = sum(predicted ~= trueLabels);
    p(K, :) = numErrors / testLength;
    fprintf('K = %d, error rate = %.4f\n', K, p(K));
end
plot(p)
xlabel('K (number of neighbours)')
ylabel('Test error rate')