function [eigvector, eigvalue, elapse] = SDA(gnd,fea,LabelIdx,UnlabelIdx,options)
% SDA: Semi-supervised Discriminant Analysis
%
%       [eigvector, eigvalue, elapse] = SDA(gnd,fea,LabelIdx,UnlabelIdx,options)
%
%             Input:
%               gnd     - Label vector. gnd(LabelIdx) are the labels of the
%                         labeled samples; entries at other positions are
%                         ignored.
%               fea     - data matrix. Each row vector of fea is a data point.
%
%             LabelIdx  - fea(LabelIdx,:) is the labeled data matrix.
%           UnlabelIdx  - fea(UnlabelIdx,:) is the unlabeled data matrix.
%                         length(LabelIdx)+length(UnlabelIdx) must equal
%                         size(fea,1).
%
%               options - Struct value in Matlab. The fields in options
%                         that can be set:
%
%                    WOptions   Please see ConstructW.m for detailed options.
%                       or
%                           W   You can construct the W outside
%                               (an nSmp x nSmp graph weight matrix).
%
%                        beta   Parameter to tune the weight between
%                               supervised info and local info.
%                               Default 0.1 (also used when beta <= 0).
%                               The regularized matrix is beta*L+\tilde{I},
%                               where L is the graph Laplacian of W and
%                               \tilde{I} is diagonal with ones at the
%                               labeled samples.
%
%                    ReguType   'Ridge' (default), 'Tensor' or 'Custom'.
%                   ReguAlpha   Regularization weight. Default 0.1.
%                regularizerR   Regularizer matrix, used when ReguType is
%                               'Tensor' or 'Custom'.
%                    keepMean   If true, fea is not centered.
%                       bEigs   1 forces eigs, 0 forces eig; if absent the
%                               choice is made from the problem size.
%
%                     Please see LGE.m for other options.
%
%             Output:
%               eigvector - Each column is an embedding function, for a new
%                           data point (row vector) x,  y = x*eigvector
%                           will be the embedding result of x.
%               eigvalue  - The eigvalue of SDA eigen-problem. sorted from
%                           largest to smallest.
%               elapse    - Time spent on different steps (fields timeW,
%                           timePCA, timeMethod, timeAll).
%
%
%    Examples:
%
%       options = [];
%       options.W = W;           % nSmp x nSmp graph built elsewhere
%       options.beta = 0.1;
%       [eigvector, eigvalue] = SDA(gnd, fea, LabelIdx, UnlabelIdx, options);
%       Y = fea*eigvector;
%
% See also LPP, LGE
%
%Reference:
%
%   Deng Cai, Xiaofei He and Jiawei Han, "Semi-Supervised Discriminant
%   Analysis ", IEEE International Conference on Computer Vision (ICCV),
%   Rio de Janeiro, Brazil, Oct. 2007.
%
%   version 2.0 --July/2007
%   version 1.0 --May/2006
%
%   Written by Deng Cai (dengcai2 AT cs.uiuc.edu)

if nargin < 5
    options = struct();
end
if ~isfield(options,'ReguType')
    options.ReguType = 'Ridge';
end
if ~isfield(options,'ReguAlpha')
    options.ReguAlpha = 0.1;
end

[nSmp,nFea] = size(fea);
nSmpLabel = length(LabelIdx);
nSmpUnlabel = length(UnlabelIdx);
if nSmpLabel+nSmpUnlabel ~= nSmp
    error('input error!');
end

gnd = gnd(LabelIdx);
classLabel = unique(gnd);
nClass = length(classLabel);
Dim = nClass;

%------------------------------------------------------------
% Graph weight matrix W: either supplied or built by constructW.
%------------------------------------------------------------
if isfield(options,'W')
    W = options.W;
    timeW = 0;
else
    if ~isfield(options,'WOptions')
        error('SDA: either options.W or options.WOptions must be provided.');
    end
    [W, timeW] = constructW(fea,options.WOptions);
    if isfield(options.WOptions,'bSemiSupervised') && options.WOptions.bSemiSupervised
        if ~isfield(options.WOptions,'SameCategoryWeight')
            options.WOptions.SameCategoryWeight = 1;
        end
        % Overwrite the labeled-labeled block: same-class pairs get
        % SameCategoryWeight, all other labeled pairs get 0.
        G2 = zeros(nSmpLabel,nSmpLabel);
        for idx = 1:nClass
            classIdx = find(gnd==classLabel(idx));
            G2(classIdx,classIdx) = options.WOptions.SameCategoryWeight;
        end
        W(LabelIdx,LabelIdx) = G2;
    end
end
if size(W,1) ~= nSmp || size(W,2) ~= nSmp
    error('SDA: W must be an nSmp x nSmp matrix.');
end

tmp_T = cputime;
% Graph Laplacian L = diag(sum(W,2)) - W; stays sparse when W is sparse.
degree = full(sum(W,2));
L = -W;
diagIdx = 1:nSmp+1:nSmp*nSmp;
L(diagIdx) = L(diagIdx) + degree.';

beta = 0.1;
if isfield(options,'beta') && (options.beta > 0)
    beta = options.beta;
end
% beta*L + \tilde{I}: add 1 to the diagonal entry of every labeled sample.
% sparse() sums repeated subscripts, so a repeated label index adds 1 per
% occurrence.
labelPos = double(LabelIdx(:));
regMatrix = beta*L + sparse(labelPos,labelPos,ones(numel(labelPos),1),nSmp,nSmp);
elapse.timeW = timeW + cputime - tmp_T;

tmp_T = cputime;
%==========================
% If data is too large, set options.keepMean to skip centering.
%==========================
if ~(isfield(options,'keepMean') && options.keepMean)
    if issparse(fea)
        fea = full(fea);
    end
    sampleMean = mean(fea,1);
    fea = fea - repmat(sampleMean,nSmp,1);
end
%==========================

% Right-hand side of the generalized eigenproblem.
DPrime = fea'*regMatrix*fea;
switch lower(options.ReguType)
    case {'ridge'}
        diagF = 1:nFea+1:nFea*nFea;
        DPrime(diagF) = DPrime(diagF) + options.ReguAlpha;
    case {'tensor','custom'}
        DPrime = DPrime + options.ReguAlpha*options.regularizerR;
    otherwise
        error('ReguType does not exist!');
end
DPrime = max(DPrime,DPrime');
% eig rejects sparse input (DPrime is sparse when keepMean is set with
% sparse fea and W); WPrime below is already dense nFea x nFea, so this
% costs no extra asymptotic memory.
DPrime = full(DPrime);

% Between-class scatter built from the labeled samples:
% each row of Hb is sqrt(class size) times that class mean.
feaLabel = fea(LabelIdx,:);
Hb = zeros(nClass,nFea);
for i = 1:nClass
    index = find(gnd==classLabel(i));
    classMean = mean(feaLabel(index,:),1);
    Hb(i,:) = sqrt(length(index))*classMean;
end
WPrime = Hb'*Hb;
WPrime = max(WPrime,WPrime');
% Field name kept for backward compatibility; this times centering and
% scatter-matrix construction.
elapse.timePCA = cputime - tmp_T;

tmp_T = cputime;
dimMatrix = size(WPrime,2);
Dim = min(Dim,dimMatrix);

if isfield(options,'bEigs')
    if options.bEigs
        bEigs = 1;
    else
        bEigs = 0;
    end
else
    if (dimMatrix > 1000 && Dim < dimMatrix/10) || (dimMatrix > 500 && Dim < dimMatrix/20) || (dimMatrix > 250 && Dim < dimMatrix/30)
        bEigs = 1;
    else
        bEigs = 0;
    end
end

if bEigs
    %disp('use eigs to speed up!');
    option = struct('disp',0);
    [eigvector, eigvalue] = eigs(WPrime,DPrime,Dim,'la',option);
    eigvalue = diag(eigvalue);
else
    [eigvector, eigvalue] = eig(WPrime,DPrime);
    eigvalue = diag(eigvalue);
end

% Order eigenpairs from largest to smallest eigenvalue on both paths.
[dummy, order] = sort(-eigvalue);
eigvalue = eigvalue(order);
eigvector = eigvector(:,order);
if Dim < size(eigvector,2)
    eigvector = eigvector(:,1:Dim);
    eigvalue = eigvalue(1:Dim);
end

% Scale each embedding function to unit Euclidean norm.
for i = 1:size(eigvector,2)
    eigvector(:,i) = eigvector(:,i)./norm(eigvector(:,i));
end

elapse.timeMethod = cputime - tmp_T;
elapse.timeAll = elapse.timeW + elapse.timePCA + elapse.timeMethod;