function [center, U, distortion] = vqKmeans(data, clusterNum, plotOpt)
% vqKmeans: Vector quantization using K-means clustering (Forgy's batch-mode method)
%	Usage: [center, U, distortion] = vqKmeans(data, clusterNum, plotOpt)
%		data (dim x dataNum): data set to be clustered; each column is a sample
%		clusterNum: number of clusters (a single value), or a matrix whose
%			columns are the initial centers
%		plotOpt: nonzero to plot the data and animate the centers; plotting is
%			only done when data has at least two rows (default: 0)
%		center (dim x clusterNum): final cluster centers, one per column
%		U: partition matrix returned by the last center update (entries are 0 or 1)
%		distortion: column vector of objective-function values, one per iteration
%	Calling the function with no input arguments runs the self demo.
%
%	Roger Jang, 20030330

if nargin==0
    selfdemo;
    return;
end
if nargin<3
    plotOpt = 0;
end

iterLimit = 100;                    % upper bound on the number of iterations
distortion = zeros(iterLimit, 1);   % objective value recorded at each iteration
dataDim = size(data, 1);

% A single value is a cluster count; anything else is taken as initial centers.
if length(clusterNum)~=1
    center = clusterNum;
else
    center = initCenter(clusterNum, data, 4);
end

if plotOpt & dataDim>=2
    plot(data(1,:), data(2,:), 'b.');
    centerH = line(center(1,:), center(2,:), 'color', 'r', 'marker', 'o', 'linestyle', 'none', 'linewidth', 2);
    axis image
end

% ====== Main loop
for iter = 1:iterLimit
    [center, distortion(iter), U] = vqUpdateCenter(center, data);
    if plotOpt & dataDim>=2
        set(centerH, 'xdata', center(1,:), 'ydata', center(2,:));
        drawnow;
    end
    % Stop as soon as the objective function no longer changes
    if iter>1
        if abs(distortion(iter-1)-distortion(iter))<eps
            break;
        end
    end
end

% Keep only the entries of the iterations that were actually executed
distortion = distortion(1:iter);
if plotOpt & dataDim>=2
    vqPlotResult(data, center, U);
end
% ========== subfunctions ==========
% ====== Find the initial centers
function center = initCenter(clusterNum, data, method)
% initCenter: Pick clusterNum columns of data as the initial cluster centers
%	clusterNum: number of centers to return
%	data (dim x dataNum): data set, one sample per column
%	method: selection rule (default: 3)
%		1: randomly chosen samples
%		2: samples closest to the mean vector
%		3: samples furthest from the mean vector
%		4: the first clusterNum samples
%	center (dim x clusterNum): the selected samples
if nargin < 3, method = 3; end

switch method
    case 1
        % ====== Method 1: random subset of the samples
        shuffled = randperm(size(data, 2));
        center = data(:, shuffled(1:clusterNum));
    case {2, 3}
        % ====== Methods 2 and 3: rank the samples by squared distance to the mean
        sqrDist = pairwiseSqrDist(mean(data, 2), data);
        if method == 3
            % Negate so that the ascending sort yields the furthest samples first
            sqrDist = -sqrDist;
        end
        [sortedKey, order] = sort(sqrDist);
        center = data(:, order(1:clusterNum));
    case 4
        % ====== Method 4: use the first few data points as the centers
        center = data(:, 1:clusterNum);
    otherwise
        error('Unknown method!');
end
% ====== Update centers
function [center, distortion, U] = vqUpdateCenter(center, data)
% vqUpdateCenter: One batch K-means step (nearest-center assignment, then center update)
%	center (dim x centerNum): current centers on input, updated centers on output
%	data (dim x dataNum): data set, one sample per column
%	distortion: sum over all samples of the squared distance to the nearest center
%	U (centerNum x dataNum): partition matrix; U(i,j) is 1 when sample j is
%		assigned to center i and 0 otherwise. Row i of U always corresponds to
%		column i of the returned center.
%
%	Empty clusters are repaired in one of two ways:
%	- when at most half of the clusters are empty, the clusters with the
%	  largest distortion are split in two;
%	- otherwise randomly selected samples become the new centers.
dim = size(data, 1);
dataNum = size(data, 2);
centerNum = size(center, 2);
% ====== Compute distance matrix
distMat = pairwiseSqrDist(center, data);
% ====== Find the U (partition matrix)
% The explicit dimension keeps the column-wise minimum correct when there is
% only one center (distMat is then a row vector).
[minDist, colIndex] = min(distMat, [], 1);
U = zeros(size(distMat));
U(colIndex+centerNum*(0:dataNum-1)) = 1;
distortion = sum(minDist);		% objective function
% ====== Find new centers
emptyIndex = find(sum(U,2)==0);
emptyGroupNum = length(emptyIndex);
if emptyGroupNum==0		% Find the new centers (with no empty cluster)
    center = (data*U')./(ones(dim,1)*sum(U,2)');
else		% Add new centers for the empty clusters
    fprintf('Found %d empty group(s)!\n', emptyGroupNum);
    U(emptyIndex,:) = [];
    center = (data*U')./(ones(dim,1)*sum(U,2)');	% centers of the non-empty clusters
    keptNum = centerNum-emptyGroupNum;
    if emptyGroupNum<=centerNum/2	% Split the clusters with the largest distortion
        fprintf('Try center splitting...\n');
        distMat(emptyIndex,:) = [];
        distortionByGroup = sum(distMat.*U, 2);
        [junk, splitIndex] = sort(-distortionByGroup);	% indices of the centers to be split
        splitIndex = splitIndex(1:emptyGroupNum);
        splitCenter = center(:, splitIndex);
        % Perturb relative to the magnitude of each coordinate. An absolute
        % offset of eps is lost in rounding for coordinates larger than about
        % 2, which would leave the two copies identical.
        delta = sqrt(eps)*max(abs(splitCenter), 1);
        temp = center; temp(:, splitIndex) = [];
        center = [temp, splitCenter-delta, splitCenter+delta];	% Center splitting
        distMat = pairwiseSqrDist(center, data);
        [minDist, colIndex] = min(distMat, [], 1);
        U = zeros(size(distMat));
        U(colIndex+centerNum*(0:dataNum-1)) = 1;
        distortion = sum(minDist);		% objective function
        % Re-estimate only the clusters that received samples; a cluster that
        % is still empty keeps its perturbed center instead of becoming NaN.
        count = sum(U,2)';
        nonEmpty = count>0;
        center(:, nonEmpty) = (data*U(nonEmpty,:)')./(ones(dim,1)*count(nonEmpty));
    else	% Select new centers based on random selection on the data points
        fprintf('Try random selection...\n');
        % Retry until no cluster is emptied by the reassignment. The number of
        % attempts is bounded because success is impossible when there are
        % fewer samples than centers; a cluster left empty here is handled by
        % the empty-cluster logic of the next call.
        maxTry = 100;
        for tryCount = 1:maxTry
            temp = randperm(dataNum);
            selectedIndex = temp(1:emptyGroupNum);
            % New clusters are appended AFTER the kept ones so that the rows of
            % U line up with the columns of [center, data(:, selectedIndex)].
            newU = [U; zeros(emptyGroupNum, dataNum)];
            for k = 1:emptyGroupNum
                dataIndex = selectedIndex(k);
                newU(:, dataIndex) = 0;		% remove the sample from its old cluster
                newU(keptNum+k, dataIndex) = 1;	% and make it the sole member of the new one
            end
            if all(sum(newU, 2)>0), break; end
        end
        U = newU;
        distortion = sum(minDist);		% objective function
        center = [center, data(:, selectedIndex)];
    end
end
% ====== Self demo
function selfdemo
% selfdemo: Run this file's main function, with plotting enabled, on the
% transposed output of dcdata(2), asking for 8 clusters.
colordef('black');
demoData = ctranspose(dcdata(2));
clusterCount = 8;
showPlot = 1;
[demoCenter, demoU, demoDistortion] = feval(mfilename, demoData, clusterCount, showPlot);