function [rcnn_model, rcnn_k_fold_model] = ...
my_rcnn_train(imdb, varargin)
% [rcnn_model, rcnn_k_fold_model] = my_rcnn_train(imdb, varargin)
% Trains an R-CNN detector for all classes in the imdb.
%
% Keys that can be passed in:
%
% svm_C SVM regularization parameter
% bias_mult Bias feature value (for liblinear)
% pos_loss_weight Cost factor on hinge loss for positives
% layer Feature layer to use (either 5, 6 or 7)
% k_folds Train on folds of the imdb
% checkpoint Save the rcnn_model every checkpoint images
% crop_mode Crop mode (either 'warp' or 'square')
% crop_padding Amount of padding in crop
% net_file Path to the Caffe CNN to use
% cache_name Path to the precomputed feature cache
% AUTORIGHTS
% ---------------------------------------------------------
% Copyright (c) 2014, Ross Girshick
%
% This file is part of the R-CNN code and is available
% under the terms of the Simplified BSD License provided in
% LICENSE. Please retain this notice and LICENSE if you use
% this file (or any portion of it) in your project.
% ---------------------------------------------------------
% TODO:
% - allow training just a subset of classes
ip = inputParser;
ip.addRequired('imdb', @isstruct);
ip.addParamValue('svm_C', 10^-3, @isscalar);
ip.addParamValue('bias_mult', 10, @isscalar);
ip.addParamValue('pos_loss_weight', 2, @isscalar);
ip.addParamValue('layer', 7, @isscalar);
ip.addParamValue('k_folds', 2, @isscalar);
ip.addParamValue('checkpoint', 0, @isscalar);
ip.addParamValue('crop_mode', 'warp', @isstr);
ip.addParamValue('crop_padding', 16, @isscalar);
ip.addParamValue('net_file', ...
'./data/caffe_nets/finetune_voc_2007_trainval_iter_70k', ...
@isstr);
ip.addParamValue('cache_name', ...
'v1_finetune_voc_2007_trainval_iter_70000', @isstr);
ip.parse(imdb, varargin{:});
opts = ip.Results;
%opts.net_def_file = './model-defs/rcnn_batch_256_output_fc7.prototxt';
opts.net_def_file = './model-defs/pascal_finetune_deploy.prototxt';
conf = rcnn_config('sub_dir', imdb.name);
conf
done_fname = fullfile(conf.cache_dir, 'done_training');
if exist(done_fname, 'file')
fprintf('done training, skipping training phase\n');
model_data = load([conf.cache_dir 'rcnn_model']);
rcnn_model = model_data.rcnn_model;
rcnn_k_fold_model = [];
if opts.k_folds > 0
model_data = load([conf.cache_dir 'rcnn_k_fold_model']);
rcnn_k_fold_model = model_data.rcnn_k_fold_model;
end
return;
end
% Record a log of the training and test procedure
timestamp = datestr(datevec(now()), 'dd.mmm.yyyy:HH.MM.SS');
diary_file = [conf.cache_dir 'rcnn_train_' timestamp '.txt'];
diary(diary_file);
fprintf('Logging output in %s\n', diary_file);
fprintf('\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n');
fprintf('Training options:\n');
disp(opts);
fprintf('~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n');
% ------------------------------------------------------------------------
% Create a new rcnn model
rcnn_model = rcnn_create_model(opts.net_def_file, opts.net_file, opts.cache_name);
rcnn_model = rcnn_load_model(rcnn_model, conf.use_gpu);
rcnn_model.detectors.crop_mode = opts.crop_mode;
rcnn_model.detectors.crop_padding = opts.crop_padding;
rcnn_model.classes = imdb.classes;
% ------------------------------------------------------------------------
% ------------------------------------------------------------------------
% Get the average norm of the features
opts.feat_norm_mean = rcnn_feature_stats(imdb, opts.layer, rcnn_model);
fprintf('average norm = %.3f\n', opts.feat_norm_mean);
rcnn_model.training_opts = opts;
% ------------------------------------------------------------------------
% ------------------------------------------------------------------------
% Get all positive examples
% We cache only the pool5 features and convert them on-the-fly to
% fc6 or fc7 as required
save_file = sprintf('./feat_cache/%s/%s/gt_pos_layer_5_cache.mat', ...
rcnn_model.cache_name, imdb.name);
try
load(save_file);
fprintf('Loaded saved positives from ground truth boxes\n');
catch
[X_pos, keys_pos] = get_positive_pool5_features(imdb, opts);
save(save_file, 'X_pos', 'keys_pos', '-v7.3');
end
% Init training caches
caches = {};
for i = imdb.class_ids
fprintf('%14s has %6d positive instances\n', ...
imdb.classes{i}, size(X_pos{i},1));
X_pos{i} = rcnn_pool5_to_fcX(X_pos{i}, opts.layer, rcnn_model);
X_pos{i} = rcnn_scale_features(X_pos{i}, opts.feat_norm_mean);
caches{i} = init_cache(X_pos{i}, keys_pos{i});
end
% ------------------------------------------------------------------------
% ------------------------------------------------------------------------
% Train with hard negative mining
first_time = true;
% one pass over the data is enough
max_hard_epochs = 1;
for hard_epoch = 1:max_hard_epochs
for i = 1:length(imdb.image_ids)
fprintf('%s: hard neg epoch: %d/%d image: %d/%d\n', ...
procid(), hard_epoch, max_hard_epochs, i, length(imdb.image_ids));
% Get hard negatives for all classes at once (avoids loading feature cache
% more than once)
[X, keys] = sample_negative_features(first_time, rcnn_model, caches, ...
imdb, i);
% Add sampled negatives to each class's training cache, removing
% duplicates
for j = imdb.class_ids
if ~isempty(keys{j})
if ~isempty(caches{j}.keys_neg)
[~, ~, dups] = intersect(caches{j}.keys_neg, keys{j}, 'rows');
assert(isempty(dups));
end
caches{j}.X_neg = cat(1, caches{j}.X_neg, X{j});
caches{j}.keys_neg = cat(1, caches{j}.keys_neg, keys{j});
caches{j}.num_added = caches{j}.num_added + size(keys{j},1);
end
% Update the model if:
% - this is the first time we are seeing negatives
% - more than retrain_limit negatives have been added since the last update
% - it's the final image of the final epoch
is_last_time = (hard_epoch == max_hard_epochs && i == length(imdb.image_ids));
hit_retrain_limit = (caches{j}.num_added > caches{j}.retrain_limit);
if (first_time || hit_retrain_limit || is_last_time) && ...
~isempty(caches{j}.X_neg)
fprintf('>>> Updating %s detector <<<\n', imdb.classes{j});
fprintf('Cache holds %d pos examples %d neg examples\n', ...
size(caches{j}.X_pos,1), size(caches{j}.X_neg,1));
[new_w, new_b] = update_model(caches{j}, opts);
rcnn_model.detectors.W(:, j) = new_w;
rcnn_model.detectors.B(j) = new_b;
caches{j}.num_added = 0;
z_pos = caches{j}.X_pos * new_w + new_b;
z_neg = caches{j}.X_neg * new_w + new_b;
caches{j}.pos_loss(end+1) = opts.svm_C * opts.pos_loss_weight * ...
sum(max(0, 1 - z_pos));
caches{j}.neg_loss(end+1) = opts.svm_C * sum(max(0, 1 + z_neg));
caches{j}.reg_loss(end+1) = 0.5 * new_w' * new_w + ...
0.5 * (new_b / opts.bias_mult)^2;
caches{j}.tot_loss(end+1) = caches{j}.pos_loss(end) + ...
caches{j}.neg_loss(end) + ...
caches{j}.reg_loss(end);
for t = 1:length(caches{j}.tot_loss)
fprintf(' %2d: obj val: %.3f = %.3f (pos) + %.3f (neg) + %.3f (reg)\n', ...
t, caches{j}.tot_loss(t), caches{j}.pos_loss(t), ...
caches{j}.neg_loss(t), caches{j}.reg_loss(t));
end
% store negative support vectors for visualizing later
SVs_neg = find(z_neg > -1 - eps);
rcnn_model.SVs.keys_neg{j} = caches{j}.keys_neg(SVs_neg, :);
rcnn_model.SVs.scores_neg{j} = z_neg(SVs_neg);
% evict easy examples
easy = find(z_neg < caches{j}.evict_thresh);
caches{j}.X_neg(easy,:) = [];
caches{j}.keys_neg(easy,:) = [];
fprintf(' Pruning easy negatives\n');