function [net,stats] = cnn_train_dag(net, imdb, getBatch, varargin)
%CNN_TRAIN_DAG Demonstrates training a CNN using the DagNN wrapper
%   CNN_TRAIN_DAG() is similar to CNN_TRAIN(), but works with the DagNN
%   wrapper instead of the SimpleNN wrapper.
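%
%   Example (an illustrative sketch: the 'input'/'label' names and the
%   IMDB field layout below are assumptions that must match your own
%   network and dataset):
%
%     getBatchFn = @(imdb, batch) ...
%       {'input', imdb.images.data(:,:,:,batch), ...
%        'label', imdb.images.labels(1,batch)} ;
%     [net, stats] = cnn_train_dag(net, imdb, getBatchFn, ...
%       'expDir', fullfile('data','exp'), ...
%       'batchSize', 128, 'numEpochs', 20, 'gpus', []) ;
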
% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).
addpath(fullfile(vl_rootnn, 'examples'));
opts.expDir = fullfile('data','exp') ;
opts.continue = true ;
opts.batchSize = 256 ;
opts.numSubBatches = 1 ;
opts.train = [] ;
opts.val = [] ;
opts.gpus = [] ;
opts.prefetch = false ;
opts.epochSize = inf ;
opts.numEpochs = 300 ;
opts.learningRate = 0.001 ;
opts.weightDecay = 0.0005 ;
opts.solver = [] ; % Empty array means use the default SGD solver
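% Parse the solver option first, so that solver-specific defaults can be
% filled in before the remaining options are parsed below.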
[opts, varargin] = vl_argparse(opts, varargin) ;
if ~isempty(opts.solver)
assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,...
'Invalid solver; expected a function handle with two outputs.') ;
% Call without input arguments to get the default solver options
opts.solverOpts = opts.solver() ;
end
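% A custom solver is any function handle obeying this two-output contract:
% called with no inputs it returns its default options; called with inputs
% it performs one update step. A minimal AdaGrad-style sketch (the name
% mySolver and its epsilon option are illustrative, not part of the API):
%
%   function [w, state] = mySolver(w, state, grad, solverOpts, lr)
%     if nargin == 0, w = struct('epsilon', 1e-10) ; return ; end
%     state = state + grad.^2 ;  % accumulate squared gradients
%     w = w - lr * grad ./ (sqrt(state) + solverOpts.epsilon) ;
%   end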
opts.momentum = 0.9 ;
opts.saveSolverState = true ;
opts.nesterovUpdate = false ;
opts.randomSeed = 0 ;
opts.profile = false ;
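% gradients are aggregated across MATLAB labs (GPUs) through a parameter
% server; the 'mmap' method exchanges them via memory-mapped files, and
% the prefix names those files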
opts.parameterServer.method = 'mmap' ;
opts.parameterServer.prefix = 'mcn' ;
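% derivative seeds for the backward pass, as {outputName, seedValue, ...} pairs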
opts.derOutputs = {'objective', 1} ;
opts.extractStatsFn = @extractStats ;
opts.plotStatistics = true ;
% postEpochFn(net,params,state) is called after each epoch; it may return
% a new learning rate, 0 to stop training, or [] for no change.
opts.postEpochFn = [] ;
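% For instance, a hypothetical decay hook could be passed as
% 'postEpochFn', @(net, params, state) params.learningRate * 0.95,
% which shrinks the learning rate by 5% after every epoch.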
opts = vl_argparse(opts, varargin) ;
if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
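% a scalar NaN for 'train' or 'val' explicitly disables that phase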
if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train)
opts.train = [] ;
end
if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val)
opts.val = [] ;
end
% -------------------------------------------------------------------------
% Initialization
% -------------------------------------------------------------------------
evaluateMode = isempty(opts.train) ;
if ~evaluateMode
if isempty(opts.derOutputs)
error('DEROUTPUTS must be specified when training.') ;
end
end
% -------------------------------------------------------------------------
% Train and validate
% -------------------------------------------------------------------------
modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;
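% opts.continue is logical, so the product is 0 (start from scratch) when
% continuation is disabled, and the index of the last checkpoint otherwise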
start = opts.continue * findLastCheckpoint(opts.expDir) ;
if start >= 1
fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
[net, state, stats] = loadState(modelPath(start)) ;
else
state = [] ;
end
for epoch=start+1:opts.numEpochs
% Set the random seed based on the epoch and opts.randomSeed.
% This is important for reproducibility, including when training
% is restarted from a checkpoint.
rng(epoch + opts.randomSeed) ;
prepareGPUs(opts, epoch == start+1) ;
% Train for one epoch.
params = opts ;
params.epoch = epoch ;
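% the learning rate may be a vector schedule (one entry per epoch), e.g.
% [1e-2*ones(1,10), 1e-3*ones(1,10)]; beyond its end the last value is reused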
params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
params.train = opts.train(randperm(numel(opts.train))) ; % shuffle
params.train = params.train(1:min(opts.epochSize, numel(opts.train)));
params.val = opts.val(randperm(numel(opts.val))) ;
params.imdb = imdb ;
params.getBatch = getBatch ;
if numel(opts.gpus) <= 1
[net, state] = processEpoch(net, state, params, 'train') ;
[net, state] = processEpoch(net, state, params, 'val') ;
if ~evaluateMode
saveState(modelPath(epoch), net, state) ;
end
lastStats = state.stats ;
else
spmd
[net, state] = processEpoch(net, state, params, 'train') ;
[net, state] = processEpoch(net, state, params, 'val') ;
if labindex == 1 && ~evaluateMode
saveState(modelPath(epoch), net, state) ;
end
lastStats = state.stats ;
end
lastStats = accumulateStats(lastStats) ;
end
stats.train(epoch) = lastStats.train ;
stats.val(epoch) = lastStats.val ;
clear lastStats ;
saveStats(modelPath(epoch), stats) ;
if opts.plotStatistics
switchFigure(1) ; clf ;
plots = setdiff(...
cat(2,...
fieldnames(stats.train)', ...
fieldnames(stats.val)'), {'num', 'time'}) ;
for p = plots
p = char(p) ;
values = zeros(0, epoch) ;
leg = {} ;
for f = {'train', 'val'}
f = char(f) ;
if isfield(stats.(f), p)
tmp = [stats.(f).(p)] ;
values(end+1,:) = tmp(1,:)' ;
leg{end+1} = f ;
end
end
subplot(1,numel(plots),find(strcmp(p,plots))) ;
plot(1:epoch, values','o-') ;
xlabel('epoch') ;
title(p) ;
legend(leg{:}) ;
grid on ;
end
drawnow ;
print(1, modelFigPath, '-dpdf') ;
end
if ~isempty(opts.postEpochFn)
if nargout(opts.postEpochFn) == 0
opts.postEpochFn(net, params, state) ;
else
lr = opts.postEpochFn(net, params, state) ;
if ~isempty(lr), opts.learningRate = lr; end
if opts.learningRate == 0, break; end
end
end
end
% With multiple GPUs, return one copy
if isa(net, 'Composite'), net = net{1} ; end
% -------------------------------------------------------------------------
function [net, state] = processEpoch(net, state, params, mode)
% -------------------------------------------------------------------------
% Note that net is not strictly needed as an output argument, since DagNN
% is a handle class. However, returning it explicitly works around an
% aliasing issue in the spmd caller.
% initialize the solver state (zero momentum for the default SGD solver)
if isempty(state) || isempty(state.solverState)
state.solverState = cell(1, numel(net.params)) ;
state.solverState(:) = {0} ;
end
% move CNN to GPU as needed
numGpus = numel(params.gpus) ;
if numGpus >= 1
net.move('gpu') ;
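% the solver state is either a plain array (SGD momentum) or a struct of
% arrays (e.g. Adam moments); move either representation to the GPU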
for i = 1:numel(state.solverState)
s = state.solverState{i} ;
if isnumeric(s)
state.solverState{i} = gpuArray(s) ;
elseif isstruct(s)
state.solverState{i} = structfun(@gpuArray, s, 'UniformOutput', false) ;
end
end
end
if numGpus > 1
parserv = ParameterServer(params.parameterServer) ;
net.setParameterServer(parserv) ;
else
parserv = [] ;
end
% profile
if params.profile
if numGpus <= 1
profile clear ;
profile on ;
else
mpiprofile reset ;
mpiprofile on ;
end
end
num = 0 ;
epoch = params.epoch ;
subset = params.(mode) ;
adjustTime = 0 ;
stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;
start = tic ;
for t=1:params.batchSize:numel(subset)
fprintf('%s: epoch %02d: %3d/%3d:', mode, epoch, ...
fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize)) ;
batchSize = min(params.batchSize, numel(subset) - t + 1) ;
for s=1:params.numSubBatches
% get this image batch and prefetch the next
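% labs (one per GPU) and sub-batches take interleaved slices of the batch:
% lab l, sub-batch s starts at t+(l-1)+(s-1)*numlabs and steps by
% numSubBatches*numlabs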
batchStart = t + (labindex-1) + (s-1) * numlabs ;
batchEnd = min(t+params.batchSize-1, numel(subset)) ;
batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
num = num + numel(batch) ;
if isempty(batch), continue ; end
inputs = params.getBatch(params.imdb, batch) ;
if params.prefetch
if s == params.numSubBatches
batchStart = t + (labindex-1) + params.batchSize ;
batchEnd = min(t+2*params.batchSize-1, numel(subset)) ;
else
batchStart = batchStart + numlabs ;
end
nextBatch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
params.getBatch(params.imdb, nextBatch) ;
end
if strcmp(mode, 'train')
net.mode = 'normal' ;
net.accumulateParamDers = (s ~= 1) ;
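% 'holdOn' prevents derivatives from being cleared until the final
% sub-batch, so gradients accumulate across sub-batches before the update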
net.eval(inputs, params.derOutputs, 'holdOn', s < params.numSubBatches) ;
else