function [k,mu,alpha,sigma,nabla,delta,ypred,ypredv,post] = rjnn(x,y,chainLength,Ndata,bFunction,par,xv,yv);
%
% =============================
if nargin < 5, error('Not enough input arguments.'); end;
if ((nargin==5) | (nargin==7)),
if nargin == 5
Validation = 0;
else
Validation = 1;
end;
hyper.a = 2; % Hyperparameter for delta.
hyper.b = 10; % Hyperparameter for delta.
hyper.e1 = 0.0001; % Hyperparameter for nabla.
hyper.e2 = 0.0001; % Hyperparameter for nabla.
hyper.v = 0; % Hyperparameter for sigma
hyper.gamma = 0; % Hyperparameter for sigma.
kMax = 50; % Maximum number of basis.
arbC = 0.5; % Constant for birth and death moves.
doPlot = 1; % To plot or not to plot? Thats ...
sigStar = .1; % Merge-split parameter.
sWalk = .001;
Lambda = .5;
walkPer = 0.1;
elseif ((nargin==6) | (nargin==8))
if nargin == 6
Validation = 0;
else
Validation = 1;
end;
hyper.a = par.a;
hyper.b = par.b;
hyper.e1 = par.e1;
hyper.e2 = par.e2;
hyper.v = par.v;
hyper.gamma = par.gamma;
kMax = par.kMax;
arbC = par.arbC;
doPlot = par.doPlot;
sigStar = par.merge;
sWalk = par.sRW;
Lambda = par.Lambda;
walkPer = par.walkPer;
else
error('Wrong Number of input arguments.');
end;
if Validation,
[Nv,dv] = size(xv); % Nv = number of test data, dv = dimension of xv.
end;
[N,d] = size(x); % N = number of train data, d = dimension of x.
[N,c] = size(y); % c = dimension of y, i.e. number of outputs.
if Ndata ~= N, error('input must me N by d and output N by c.'); end;
% INITIALISATION:
% ==============
post = ones(chainLength,1); % p(centres,k|y).
if Validation,
ypredv = zeros(Nv,c,chainLength); % Output fit (test set).
end;
ypred = zeros(N,c,chainLength); % Output fit (train set).
nabla = zeros(chainLength,1); % Poisson parameter.
delta = zeros(chainLength,c); % Regularisation parameter.
k = ones(chainLength,1); % Model order - number of basis.
sigma = ones(chainLength,c); % Output noise variance.
mu = cell(chainLength,1); % Radial basis centres.
alpha = cell(chainLength,c); % Radial basis coefficients.
% DEFINE WALK INTERVAL FOR MU:
% ===========================
walk = walkPer*(max(x)-min(x));
walkInt=zeros(d,1);
for i=1:d,
walkInt(i,1) = (max(x(:,i))-min(x(:,i))) + 2*walk(i);
end;
% SAMPLE INITIAL CONDITIONS FROM THEIR PRIORS:
% ===========================================
nabla(1) = gengamma(0.5 + hyper.e1,hyper.e2);
k(1) = poissrnd(nabla(1));
k(1) = 40; % TEMPORARY: for demo1 comparison.
k(1) = max(k(1),1);
k(1) = min(k(1),kMax);
for i=1:c
delta(1,i) = inv(gengamma(hyper.a,hyper.b));
sigma(1,i) = inv(gengamma(hyper.v/2,hyper.gamma/2));
alpha{1,i} = mvnrnd(zeros(1,k(1)+d+1),sigma(1,i)*delta(1,i)*eye(k(1)+d+1),1)';
end;
% DRAW THE INITIAL RADIAL CENTRES:
% ===============================
mu{1}=zeros(k(1),d);
for i=1:d,
mu{1}(:,i)= (min(x(:,i))-walk(i))*ones(k(1),1) + ((max(x(:,i))+walk(i))-(min(x(:,i))-walk(i)))*rand(k(1),1);
end;
% FILL THE REGRESSION MATRIX:
% ==========================
M=zeros(N,k(1)+d+1);
M(:,1) = ones(N,1);
M(:,2:d+1) = x;
for j=d+2:k(1)+d+1,
M(:,j) = feval(bFunction,mu{1}(j-d-1,:),x);
end;
for i=1:c,
ypred(:,i,1) = M*alpha{1,i};
end;
if Validation
Mv=zeros(Nv,k(1)+d+1);
Mv(:,1) = ones(Nv,1);
Mv(:,2:d+1) = xv;
for j=d+2:k(1)+d+1,
Mv(:,j) = feval(bFunction,mu{1}(j-d-1,:),xv);
end;
for i=1:c,
ypredv(:,i,1) = Mv*alpha{1,i};
end;
end;
% INITIALISE COUNTERS:
% ===================
aUpdate=0;
rUpdate=0;
aBirth=0;
rBirth=0;
aDeath=0;
rDeath=0;
aMerge=0;
rMerge=0;
aSplit=0;
rSplit=0;
aRW=0;
rRW=0;
match=0;
if doPlot
figure(3)
clf;
end;
% ITERATE THE MARKOV CHAIN:
% ========================
for t=1:chainLength-1,
iteration=t
% COMPUTE THE CENTRES AND DIMENSION WITH METROPOLIS, BIRTH AND DEATH MOVES:
% ========================================================================
decision=rand(1);
birth=arbC*min(1,(nabla(t)/(k(t)+1)));
death=arbC*min(1,((k(t)+1)/nabla(t)));
if ((decision <= birth) & (k(t)<kMax)),
[k,mu,M,match,aBirth,rBirth] = radialBirth(match,aBirth,rBirth,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
elseif ((decision <= birth+death) & (k(t)>0)),
[k,mu,M,aDeath,rDeath] = radialDeath(aDeath,rDeath,k,mu,M,delta,x,y,hyper,t,nabla);
elseif ((decision <= 2*birth+death) & (k(t)<kMax) & (k(t)>1)),
[k,mu,M,aSplit,rSplit] = radialSplit(aSplit,rSplit,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt,walk);
elseif ((decision <= 2*birth+2*death) & (k(t)>1)),
[k,mu,M,aMerge,rMerge] = radialMerge(aMerge,rMerge,k,mu,M,delta,x,y,hyper,t,bFunction,sigStar,walkInt);
else
uLambda = rand(1);
if ((uLambda>Lambda) & (k(t)>0))
[k,mu,M,match,aRW,rRW] = radialRW(match,aRW,rRW,k,mu,M,delta,x,y,hyper,t,bFunction,sWalk,walk);
else
[k,mu,M,match,aUpdate,rUpdate] = radialUpdate(match,aUpdate,rUpdate,k,mu,M,delta,x,y,hyper,t,bFunction,walkInt,walk);
end;
end;
% UPDATE OTHER PARAMETERS WITH GIBBS:
% ==================================
H=zeros(k(t+1)+1+d,k(t+1)+1+d,c);
F=zeros(k(t+1)+1+d,c);
P=zeros(N,N,c);
for i=1:c,
H(:,:,i) = inv(M'*M + (1/delta(t,i))*eye(k(t+1)+1+d));
F(:,i) = H(:,:,i)*M'*y(:,i);
P(:,:,i) = eye(N) - M*H(:,:,i)*M';
sigma(t+1,i) = inv(gengamma((hyper.v+N)/2,(hyper.gamma+y(:,i)'*P(:,:,i)*y(:,i))/2));
alpha{t+1,i} = mvnrnd(F(:,i),sigma(t+1,i)*H(:,:,i),1)';
delta(t+1,i) = inv(gengamma(hyper.a+(k(t+1)+d+1)/2,hyper.b+inv(2*sigma(t+1,i))*alpha{t+1,i}'*alpha{t+1,i}));
end;
nabla(t+1) = gengamma(0.5+hyper.e1+k(t+1),1+hyper.e2);
% COMPUTE THE POSTERIOR FOR MONITORING:
% ====================================
posterior =exp(-nabla(t+1)) * delta(t+1,1)^(-(d+k(t+1)+1)/2) * inv(prod(1:k(t+1)) * prod(walkInt)^(k(t+1))) * nabla(t+1)^(k(t+1)) * sqrt(det(H(:,:,1))) * (hyper.gamma+y(:,1)'*P(:,:,1)*y(:,1))^(-(hyper.v+N)/2);
for i=2:c,
newpost = delta(t+1,i)^(-(d+k(t+1)+1)/2) * sqrt(det(H(:,:,i))) * (hyper.gamma+y(:,i)'*P(:,:,i)*y(:,i))^(-(hyper.v+N)/2);
posterior = posterior * newpost;
end;
post(t+1) = log(posterior);
% PLOT FOR FUN AND MONITORING:
% ============================
for i=1:c,
ypred(:,i,t+1) = M*alpha{t+1,i};
end;
msError = inv(N) * trace((y-ypred(:,:,t+1))'*(y-ypred(:,:,t+1)));
% NRMSE = sqrt((y-ypred(:,:,t+1))'*(y-ypred(:,:,t+1))*inv((y-mean(y)*ones(size(y)))'*(y-mean(y)*ones(size(y)))))
if Validation,
% FILL THE VALIDATION REGRESSION MATRIX:
% ======================================
Mv=zeros(Nv,k(t+1)+d+1);
Mv(:,1) = ones(Nv,1);
Mv(:,2:d+1) = xv;
for j=d+2:k(t+1)+d+1,
Mv(:,j) = feval(bFunction,mu{t+1}(j-d-1,:),xv);
end;
for i=1:c,
ypredv(:,i,t+1) = Mv*alpha{t+1,i};
end;
msErrorv = inv(Nv) * trace((yv-ypredv(:,:,t+1))'*(yv-ypredv(:,:,t+1)));
end;
if doPlot,
figure(1)
clf
if (c==2),
plot(x(:,1),y(:,1),'b+',x(:,2),y(:,2),'r+',x(:,1),ypred(:,1,t+1),'bo',x(:,2),ypred(:,2,t+1),'ro');
elseif c==1,
plot(x,y,'b+',x,ypred(:,:,t+1),'ro');
end;
errorv = sum(abs(yv-ypredv(:,:,t+1)))*100*inv(Nv);
ylabel('Output','fontsize',15)
xlabel('Input','fontsize',15)
figure(3)
subplot(511);
hold on;
plot(t,k(t),'*');
ylabel('k','fontsize',15);
subplot(512);
hold on;
plot(t,post(t+1),'*');
ylabel('p(k,mu|y)','fontsize',15);
subplot(513);
hold on;
plot(t,msError,'r*');
ylabel('Train error','fontsize',15);
subplot(514);
hold on;
plot(t,msErrorv,'r*');
ylabel('Test error','fontsize',15);
subplot(515);
hold on;
bar([1 2 3 4 5 6 7 8 9 10 11 12 13]
没有合适的资源?快使用搜索试试~ 我知道了~
MCMC的matlab源代码
共15个文件
m:13个
asv:1个
txt:1个
4星 · 超过85%的资源 需积分: 31 357 下载量 58 浏览量
2010-06-17
09:59:12
上传
评论 15
收藏 14KB RAR 举报
温馨提示
MCMC的MATLAB源代码，非常适合用于多目标跟踪。希望对大家有所帮助。
资源推荐
资源详情
资源评论
收起资源包目录
87361032particle-filter-mcmc.rar (15个子文件)
MCMC
rjGaussian.m 396B
radialUpdate.m 3KB
radialBirth.m 2KB
rjdemo1.asv 5KB
rjdemo1.m 5KB
rjtpSpline.m 300B
radialRW.m 3KB
gengamma.m 539B
rjnn.m 8KB
radialDeath.m 2KB
radialSplit.m 3KB
radialMerge.m 3KB
rjCubic.m 301B
rjMultiquadric.m 315B
www.pudn.com.txt 218B
共 15 条
- 1
houiverson
- 粉丝: 8
- 资源: 1
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
- 3
- 4
前往页