function [k,mu,alpha,sigma,nabla,delta,ypred,ypredv,post] = rjnn(x,y,chainLength,Ndata,bFunction,par,xv,yv);
%
% =============================
if nargin < 5, error('Not enough input arguments.'); end;
if ((nargin==5) | (nargin==7)),
if nargin == 5
Validation = 0;
else
Validation = 1;
end;
hyper.a = 2; % Hyperparameter for delta.
hyper.b = 10; % Hyperparameter for delta.
hyper.e1 = 0.0001; % Hyperparameter for nabla.
hyper.e2 = 0.0001; % Hyperparameter for nabla.
hyper.v = 0; % Hyperparameter for sigma
hyper.gamma = 0; % Hyperparameter for sigma.
kMax = 50; % Maximum number of basis.
arbC = 0.5; % Constant for birth and death moves.
doPlot = 1; % To plot or not to plot? Thats ...
sigStar = .1; % Merge-split parameter.
sWalk = .001;
Lambda = .5;
walkPer = 0.1;
elseif ((nargin==6) | (nargin==8))
if nargin == 6
Validation = 0;
else
Validation = 1;
end;
hyper.a = par.a;
hyper.b = par.b;
hyper.e1 = par.e1;
hyper.e2 = par.e2;
hyper.v = par.v;
hyper.gamma = par.gamma;
kMax = par.kMax;
arbC = par.arbC;
doPlot = par.doPlot;
sigStar = par.merge;
sWalk = par.sRW;
Lambda = par.Lambda;
walkPer = par.walkPer;
else
error('Wrong Number of input arguments.');
end;
if Validation,
[Nv,dv] = size(xv); % Nv = number of test data, dv = dimension of xv.
end;
[N,d] = size(x); % N = number of train data, d = dimension of x.
[N,c] = size(y); % c = dimension of y, i.e. number of outputs.
if Ndata ~= N, error('input must me N by d and output N by c.'); end;
% INITIALISATION:
% ==============
post = ones(chainLength,1); % p(centres,k|y).
if Validation,
ypredv = zeros(Nv,c,chainLength); % Output fit (test set).
end;
ypred = zeros(N,c,chainLength); % Output fit (train set).
nabla = zeros(chainLength,1); % Poisson parameter.
delta = zeros(chainLength,c); % Regularisation parameter.
k = ones(chainLength,1); % Model order - number of basis.
sigma = ones(chainLength,c); % Output noise variance.
mu = cell(chainLength,1); % Radial basis centres.
alpha = cell(chainLength,c); % Radial basis coefficients.
% DEFINE WALK INTERVAL FOR MU:
% ===========================
walk = walkPer*(max(x)-min(x));
walkInt=zeros(d,1);
for i=1:d,
walkInt(i,1) = (max(x(:,i))-min(x(:,i))) + 2*walk(i);
end;
% SAMPLE INITIAL CONDITIONS FROM THEIR PRIORS:
% ===========================================
nabla(1) = gengamma(0.5 + hyper.e1,hyper.e2);
k(1) = poissrnd(nabla(1));
k(1) = 40; % TEMPORARY: for demo1 comparison.
k(1) = max(k(1),1);
k(1) = min(k(1),kMax);
for i=1:c
delta(1,i) = inv(gengamma(hyper.a,hyper.b));
sigma(1,i) = inv(gengamma(hyper.v/2,hyper.gamma/2));
alpha{1,i} = mvnrnd(zeros(1,k(1)+d+1),sigma(1,i)*delta(1,i)*eye(k(1)+d+1),1)';
end;
% DRAW THE INITIAL RADIAL CENTRES:
% ===============================
mu{1}=zeros(k(1),d);
for i=1:d,
mu{1}(:,i)= (min(x(:,i))-walk(i))*ones(k(1),1) + ((max(x(:,i))+walk(i))-(min(x(:,i))-walk(i)))*rand(k(1),1);
end;
% FILL THE REGRESSION MATRIX:
% ==========================
M=zeros(N,k(1)+d+1);
M(:,1) = ones(N,1);
M(:,2:d+1) = x;
for j=d+2:k(1)+d+1,
M(:,j) = feval(bFunction,mu{1}(j-d-1,:),x);
end;
for i=1:c,
ypred(:,i,1) = M*alpha{1,i};
end;
if Validation
Mv=zeros(Nv,k(1)+d+1);
Mv(:,1) = ones(Nv,1);
Mv(:,2:d+1) = xv;
for j=d+2:k(1)+d+1,
Mv(:,j) = feval(bFunction,mu{1}(j-d-1,:),xv);
end;
for i=1:c,
ypredv(:,i,1) = Mv*alpha{1,i};
end;
end;
% INITIALISE COUNTERS:
% ===================
aUpdate=0;
rUpdate=0;
aBirth=0;
rBirth=0;
aDeath=0;
rDeath=0;
aMerge=0;
rMerge=0;
aSplit=0;
rSplit=0;
aRW=0;
rRW=0;
match=0;
if doPlot
figure(3)
clf;
end;
% ITERATE THE MARKOV CHAIN:
% ========================
for t=1:chainLength-1,
iteration=t
% COMPUTE THE CENTRES AND DIMENSION WITH METROPOLIS, BIRTH AND DEATH MOVES:
% ========================================================================
decision=rand(1);
birth=arbC*min(1,(nabla(t)/(k(t)+1)));
death=arbC*min(1,((k(t)+1)/nabla(t)));
if ((decision <= birth) & (k(t)<kMax)),
[k,mu,M,match,aBirth,rBirth] = radialBirth(match,aBirth,rBirth,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
elseif ((decision <= birth+death) & (k(t)>0)),
[k,mu,M,aDeath,rDeath] = radialDeath(aDeath,rDeath,k,mu,M,delta,x,y,hyper,t,nabla);
elseif ((decision <= 2*birth+death) & (k(t)<kMax) & (k(t)>1)),
[k,mu,M,aSplit,rSplit] = radialSplit(aSplit,rSplit,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt,walk);
elseif ((decision <= 2*birth+2*death) & (k(t)>1)),
[k,mu,M,aMerge,rMerge] = radialMerge(aMerge,rMerge,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt);
else
uLambda = rand(1);
if ((uLambda>Lambda) & (k(t)>0))
[k,mu,M,match,aRW,rRW] = radialRW(match,aRW,rRW,k,mu,M,delta,x,y,hyper,t,bFunction,sWalk,walk);
else
[k,mu,M,match,aUpdate,rUpdate] = radialUpdate(match,aUpdate,rUpdate,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
end;
end;
% UPDATE OTHER PARAMETERS WITH GIBBS:
% ==================================
H=zeros(k(t+1)+1+d,k(t+1)+1+d,c);
F=zeros(k(t+1)+1+d,c);
P=zeros(N,N,c);
for i=1:c,
H(:,:,i) = inv(M'*M + (1/delta(t,i))*eye(k(t+1)+1+d));
F(:,i) = H(:,:,i)*M'*y(:,i);
P(:,:,i) = eye(N) - M*H(:,:,i)*M';
sigma(t+1,i) = inv(gengamma((hyper.v+N)/2,(hyper.gamma+y(:,i)'*P(:,:,i)*y(:,i))/2));
alpha{t+1,i} = mvnrnd(F(:,i),sigma(t+1,i)*H(:,:,i),1)';
delta(t+1,i) = inv(gengamma(hyper.a+(k(t+1)+d+1)/2,hyper.b+inv(2*sigma(t+1,i))*alpha{t+1,i}'*alpha{t+1,i}));
end;
nabla(t+1) = gengamma(0.5+hyper.e1+k(t+1),1+hyper.e2);
% COMPUTE THE POSTERIOR FOR MONITORING:
% ====================================
posterior =exp(-nabla(t+1)) * delta(t+1,1)^(-(d+k(t+1)+1)/2) * inv(prod(1:k(t+1)) * prod(walkInt)^(k(t+1))) * nabla(t+1)^(k(t+1)) * sqrt(det(H(:,:,1))) * (hyper.gamma+y(:,1)'*P(:,:,1)*y(:,1))^(-(hyper.v+N)/2);
for i=2:c,
newpost = delta(t+1,i)^(-(d+k(t+1)+1)/2) * sqrt(det(H(:,:,i))) * (hyper.gamma+y(:,i)'*P(:,:,i)*y(:,i))^(-(hyper.v+N)/2);
posterior = posterior * newpost;
end;
post(t+1) = log(posterior);
% PLOT FOR FUN AND MONITORING:
% ============================
for i=1:c,
ypred(:,i,t+1) = M*alpha{t+1,i};
end;
msError = inv(N) * trace((y-ypred(:,:,t+1))'*(y-ypred(:,:,t+1)));
% NRMSE = sqrt((y-ypred(:,:,t+1))'*(y-ypred(:,:,t+1))*inv((y-mean(y)*ones(size(y)))'*(y-mean(y)*ones(size(y)))))
if Validation,
% FILL THE VALIDATION REGRESSION MATRIX:
% ======================================
Mv=zeros(Nv,k(t+1)+d+1);
Mv(:,1) = ones(Nv,1);
Mv(:,2:d+1) = xv;
for j=d+2:k(t+1)+d+1,
Mv(:,j) = feval(bFunction,mu{t+1}(j-d-1,:),xv);
end;
for i=1:c,
ypredv(:,i,t+1) = Mv*alpha{t+1,i};
end;
msErrorv = inv(Nv) * trace((yv-ypredv(:,:,t+1))'*(yv-ypredv(:,:,t+1)));
end;
if doPlot,
figure(1)
clf
if (c==2),
plot(x(:,1),y(:,1),'b+',x(:,2),y(:,2),'r+',x(:,1),ypred(:,1,t+1),'bo',x(:,2),ypred(:,2,t+1),'ro');
elseif c==1,
plot(x,y,'b+',x,ypred(:,:,t+1),'ro');
end;
errorv = sum(abs(yv-ypredv(:,:,t+1)))*100*inv(Nv);
ylabel('Output','fontsize',15)
xlabel('Input','fontsize',15)
figure(3)
subplot(511);
hold on;
plot(t,k(t),'*');
ylabel('k','fontsize',15);
subplot(512);
hold on;
plot(t,post(t+1),'*');
ylabel('p(k,mu|y)','fontsize',15);
subplot(513);
hold on;
plot(t,msError,'r*');
ylabel('Train error','fontsize',15);
subplot(514);
hold on;
plot(t,msErrorv,'r*');
ylabel('Test error','fontsize',15);
subplot(515);
hold on;
bar([1 2 3 4 5 6 7 8 9 10 11 12 13]
评论0