function [err,raw,prob,sse,distmat] = pnncv(train,ncls,clsinfo,clsize,smooth,choice,prntopt,distmat)
% PNNCV probabilistic neural network cross-validation
% Author: Ron Shaffer
% Revisions: 4/10/96 Version 1.0 Original code (based on PROBNN 1.4)
% 4/15/96 Version 1.1 Added ability to pass in distance matrix
% 4/16/96 Version 1.2 Added the ability to suppress printout
% 4/25/96 Version 1.3 Compute continuous error criterion
% using sum-of-squared-errors
% 4/26/96 Version 1.4 Removed possibility of sse returning NaN
%
% [err,raw,prob,sse,distmat] = pnncv(train,ncls,clsinfo,clsize,smooth,choice,prntopt,distmat)
%
% err: number of misclassified patterns in cross-validation
% raw: raw PNN outputs for the cross-validation procedure
% prob: Bayes posterior probabilities
% sse: sum of squared errors
% distmat: matrix of distance values
% train: training set patterns (number of patterns x number of sensors)
% ncls: number of classes (number of outputs for the PNN)
% clsinfo: vector containing the classification of each pattern in train
% clsize: vector of class sizes
% smooth: smoothing factor
% choice: choice of distance measure (1 = dot product, 2 = Euclidean)
% prntopt: printout control (1 = full, 0 = minimal) [optional]
% distmat: input distance matrix [optional]
% NOTE: Use this code at your own risk because the author assumes no liability!
%
% Note: this m-file makes use of routines from the PLS_Toolbox package from
% Eigenvector Technologies and the Neural Network Toolbox from MATLAB.
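%
% Example usage (illustrative only; the data, class labels, number of classes,
% and smoothing values below are hypothetical):
%
%   % first call computes the distance matrix and returns it
%   [err,raw,prob,sse,distmat] = pnncv(train,3,clsinfo,clsize,0.5,1,1);
%   % later calls can reuse distmat while varying the smoothing factor
%   [err2,raw2,prob2,sse2] = pnncv(train,3,clsinfo,clsize,0.25,1,1,distmat);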
%
%
% set constants
%
err = 0;
sse = 0;
misclassed = zeros(1,ncls);
[npat_t,ndim] = size(train);
nhcel = npat_t;
smooth_sqr = smooth * smooth;
%
% if prntopt is not passed in (only six arguments) then suppress printout
%
if (nargin == 6)
prntopt = 0;
end
%
% move training data to hidden units (i.e., training) and normalize.
% Normalization method based on Mark Beale's normr routine from the
% neural network toolbox
%
hcel = sqrt(ones./(sum((train.*train)')))'*ones(1,ndim).*train;
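% each row of hcel is the corresponding training pattern divided by its
% Euclidean norm, so every hidden-unit weight vector has unit length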
%
% if distmat is not passed in (fewer than eight arguments) then the distance
% matrix must be computed; otherwise skip this time-consuming step
%
if (nargin <= 7)
%
% compute distance matrix using the dot product calculation (much faster method!)
% or the (squared) Euclidean distance calculation
%
if choice == 1
if prntopt == 1
fprintf('Computing distance matrix using Dot Product calculation \n');
end
distmat = 1-(hcel * hcel');
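% for unit-length rows, 1 - hcel*hcel' equals half the squared Euclidean
% distance, so this differs from choice 2 only by a factor of two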
else
if prntopt == 1
fprintf('Computing distance matrix using Euclidean Distance calculation \n');
end
distmat = zeros(nhcel,nhcel);
for i = 1:nhcel
for j = 1:i
distmat(i,j) = sum((hcel(j,1:ndim) - hcel(i,1:ndim)).^2);
distmat(j,i) = distmat(i,j);
end
end
end
end
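%
% distmat is returned to the caller so that repeated calls (e.g., when tuning
% the smoothing factor) can pass it back in and skip the computation above
%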
%
% now perform leave-one-out cross-validation: each training pattern is held
% out in turn and classified using the remaining patterns
%
for i = 1:npat_t
%
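% take the held-out pattern's distances to all other patterns; delsamps (a
% PLS_Toolbox routine) drops element i so the pattern's distance to itself
% is excluded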
distvect = distmat(i,:)';
distvect = delsamps(distvect,i);
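% Gaussian-type kernel weights: distvect already holds squared (or, for the
% dot-product choice, half squared) distances, so no extra squaring is needed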
weight = exp((-distvect)/smooth_sqr);
%
newclsinfo = delsamps(clsinfo,i);
%
% summation for each output unit: accumulate each remaining pattern's kernel
% weight into its class's output (weight and newclsinfo both have the held-out
% pattern removed, so they stay aligned)
%
output = zeros(1,ncls);
for j = 1:nhcel-1
output(newclsinfo(j)) = output(newclsinfo(j)) + weight(j);
end
%
% save the raw (unnormalized) summed outputs for this held-out pattern
%
raw(i,1:ncls) = output(1:ncls);
%
% Compute mean output for each class by dividing by class size
% This calculation corrects for unequal class sizes.
%
cvclsize = clsize;
cvclsize(clsinfo(i)) = clsize(clsinfo(i)) - 1;
output(1:ncls) = output(1:ncls) ./ cvclsize;
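%
% the class with the largest (class-size-corrected) output wins; normalizing
% by the total output gives the estimated Bayes posterior probabilities
%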
[junk,winner(i)] = max(output);
sumout = sum(output);
prob(i,1:ncls) = output(1:ncls)./sumout;
%
% compute sum of squared error (Masters' book, pages 197-201)
%
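% this expands to the sum over all classes of (target - prob)^2, where the
% target is 1 for the pattern's true class and 0 for every other class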
sse = sse + (((1-prob(i,clsinfo(i)))^2) + (sum((prob(i,:).^2)) - prob(i,clsinfo(i))^2));
%
% collect misclassified patterns
%
if winner(i) ~= clsinfo(i)
err = err + 1;
misclassed(clsinfo(i)) = misclassed(clsinfo(i)) + 1;
end
end
%
% Print out results and exit
%
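% convert the per-class misclassification counts into percentage correct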
misclassed(1:ncls) = 100 .* (clsize(1:ncls) - misclassed(1:ncls)) ./ clsize(1:ncls);
if prntopt == 1
for i = 1:ncls
fprintf('Class %d Percentage Correct %7.4f \n',i,misclassed(i));
end
end
overallerr = 100*(err/npat_t);
if isnan(sse)
sse = 9999;
end
if prntopt == 1
fprintf('Overall Error %7.4f \n',overallerr);
fprintf('Sum of Squared Errors %7.4f \n', sse);
end