function [model, errors] = rbmFit(X, numhid, y, varargin)
%Fit an RBM jointly to binary data X and discrete labels y.
%This is not meant to be applied to image data
%code by Andrej Karpathy
%based on implementation of Kevin Swersky and Ruslan Salakhutdinov
%INPUTS:
%X ... data. should be binary, or in [0,1] interpreted as
% ... probabilities
%numhid ... number of hidden units
%y ... List of discrete labels
%additional inputs (specified as name value pairs or in struct)
%nclasses ... number of classes (default: number of unique labels in y)
%method ... 'CD' (contrastive divergence) or 'SML' (stochastic
% ... maximum likelihood / persistent CD)
%eta ... learning rate
%momentum ... momentum for smoothness and to prevent overfitting
% ... NOTE: momentum is not recommended with SML
%maxepoch ... # of epochs: each is a full pass through train data
%avglast ... how many epochs before maxepoch to start averaging
% ... before. Procedure suggested for faster convergence by
% ... Kevin Swersky in his MSc thesis
%penalty ... weight decay factor
%batchsize ... The number of training instances per batch
%verbose ... For printing progress
%anneal ... Flag. If set true, the penalty is annealed linearly
% ... through epochs to 10% of its original value
%OUTPUTS:
%model.W ... The weights of the connections
%model.b ... The biases of the hidden layer
%model.c ... The biases of the visible layer
%model.Wc ... The weights on labels layer
%model.cc ... The biases on labels layer
%model.labels ... The distinct label values, in the column order of Wc
%errors ... The sum-squared reconstruction error at every epoch
%Process options
args= prepareArgs(varargin);
[ nclasses ...
  method ...
  eta ...
  momentum ...
  maxepoch ...
  avglast ...
  penalty ...
  batchsize ...
  verbose ...
  anneal ...
  ] = process_options(args , ...
  'nclasses' , nunique(y), ...
  'method' , 'CD' , ...
  'eta' , 0.1 , ...
  'momentum' , 0.5 , ...
  'maxepoch' , 50 , ...
  'avglast' , 5 , ...
  'penalty' , 2e-4 , ...
  'batchsize' , 100 , ...
  'verbose' , false , ...
  'anneal' , false);
avgstart = maxepoch - avglast;
oldpenalty= penalty; % remember original penalty so annealing can scale from it
[N,d]=size(X);
if (verbose)
    fprintf('Preprocessing data...\n')
end
%Create targets: 1-of-k encodings for each discrete label
u= unique(y);
targets= zeros(N, nclasses);
for i=1:length(u)
    targets(y==u(i),i)=1;
end
%Create batches: assign each instance a random batch id so batches are
%shuffled; the last batch may be smaller when N is not a multiple of
%batchsize
numbatches= ceil(N/batchsize);
groups= repmat(1:numbatches, 1, batchsize);
groups= groups(1:N);
groups = groups(randperm(N));
for i=1:numbatches
    batchdata{i}= X(groups==i,:);
    batchtargets{i}= targets(groups==i,:);
end
%fit RBM
numcases=N;
numdims=d;
numclasses= length(u);
%small random weights, zero biases (standard RBM initialization)
W = 0.1*randn(numdims,numhid);
c = zeros(1,numdims);
b = zeros(1,numhid);
Wc = 0.1*randn(numclasses,numhid);
cc = zeros(1,numclasses);
%preallocation (these are resized per batch inside the loop)
ph = zeros(numcases,numhid);
nh = zeros(numcases,numhid);
phstates = zeros(numcases,numhid);
nhstates = zeros(numcases,numhid);
negdata = zeros(numcases,numdims);
negdatastates = zeros(numcases,numdims);
%momentum accumulators for each parameter
Winc = zeros(numdims,numhid);
binc = zeros(1,numhid);
cinc = zeros(1,numdims);
Wcinc = zeros(numclasses,numhid);
ccinc = zeros(1,numclasses);
%running averages of parameters (used after epoch avgstart)
Wavg = W;
bavg = b;
cavg = c;
Wcavg = Wc;
ccavg = cc;
t = 1; % averaging step counter
errors=zeros(1,maxepoch);
for epoch = 1:maxepoch
    errsum=0;
    if (anneal)
        %linearly decay penalty to 10% of its original value by maxepoch
        penalty= oldpenalty - 0.9*epoch/maxepoch*oldpenalty;
    end
    for batch = 1:numbatches
        [numcases numdims]=size(batchdata{batch});
        data = batchdata{batch};
        classes = batchtargets{batch};
        %go up: positive-phase hidden probabilities and samples
        ph = logistic(data*W + classes*Wc + repmat(b,numcases,1));
        phstates = ph > rand(numcases,numhid);
        if (isequal(method,'SML'))
            %persistent chain: only seed from data on the very first step
            if (epoch == 1 && batch == 1)
                nhstates = phstates;
            end
        elseif (isequal(method,'CD'))
            nhstates = phstates;
        end
        %go down: reconstruct visibles and class labels
        negdata = logistic(nhstates*W' + repmat(c,numcases,1));
        negdatastates = negdata > rand(numcases,numdims);
        negclasses = softmaxPmtk(nhstates*Wc' + repmat(cc,numcases,1));
        negclassesstates = softmax_sample(negclasses);
        %go up one more time: negative-phase hidden probabilities
        nh = logistic(negdatastates*W + negclassesstates*Wc + ...
            repmat(b,numcases,1));
        nhstates = nh > rand(numcases,numhid);
        %update weights and biases: positive minus negative statistics
        dW = (data'*ph - negdatastates'*nh);
        dc = sum(data) - sum(negdatastates);
        db = sum(ph) - sum(nh);
        dWc = (classes'*ph - negclassesstates'*nh);
        dcc = sum(classes) - sum(negclassesstates);
        Winc = momentum*Winc + eta*(dW/numcases - penalty*W);
        binc = momentum*binc + eta*(db/numcases);
        cinc = momentum*cinc + eta*(dc/numcases);
        Wcinc = momentum*Wcinc + eta*(dWc/numcases - penalty*Wc);
        ccinc = momentum*ccinc + eta*(dcc/numcases);
        W = W + Winc;
        b = b + binc;
        c = c + cinc;
        Wc = Wc + Wcinc;
        cc = cc + ccinc;
        if (epoch > avgstart)
            %apply averaging: incremental running mean over updates
            Wavg = Wavg - (1/t)*(Wavg - W);
            cavg = cavg - (1/t)*(cavg - c);
            bavg = bavg - (1/t)*(bavg - b);
            Wcavg = Wcavg - (1/t)*(Wcavg - Wc);
            ccavg = ccavg - (1/t)*(ccavg - cc);
            t = t+1;
        else
            %before the averaging window, track current parameters exactly
            Wavg = W;
            bavg = b;
            cavg = c;
            Wcavg = Wc;
            ccavg = cc;
        end
        %accumulate reconstruction error
        err= sum(sum( (data-negdata).^2 ));
        errsum = err + errsum;
    end
    errors(epoch)= errsum;
    if (verbose)
        fprintf('Ended epoch %i/%i, Reconstruction error is %f\n', ...
            epoch, maxepoch, errsum);
    end
end
model.W= Wavg;
model.b= bavg;
model.c= cavg;
model.Wc= Wcavg;
model.cc= ccavg;
model.labels= u;
没有合适的资源?快使用搜索试试~ 我知道了~
受限玻尔兹曼机RBM-matlab实现
共17个文件
m:16个
mat:1个
5星 · 超过95%的资源 需积分: 12 178 下载量 11 浏览量
2015-11-28
09:41:57
上传
评论 11
收藏 2.79MB ZIP 举报
温馨提示
条件受玻尔兹曼机的matlab实现代码 网址:https://code.google.com/p/matrbm/downloads/detail?name=RBMLIB.zip
资源推荐
资源详情
资源评论
收起资源包目录
RBMLIB.zip (17个子文件)
RBMLIB
mnist_classify.mat 2.84MB
RBM
dbnFit.m 2KB
rbmVtoH.m 355B
dbnPredict.m 495B
rbmPredict.m 877B
rbmBB.m 5KB
softmax_sample.m 371B
logistic.m 65B
process_options.m 4KB
visualize.m 750B
softmaxPmtk.m 286B
rbmHtoV.m 358B
rbmFit.m 6KB
interweave.m 409B
nunique.m 977B
prepareArgs.m 690B
examplecode.m 2KB
共 17 条
- 1
风翼冰舟
- 粉丝: 2452
- 资源: 56
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
前往页