/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* SimpleKMeans.java
* Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
*
*/
package weka.clusterers;
import weka.classifiers.rules.DecisionTableHashKey;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.ManhattanDistance;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Random;
import java.util.Vector;
/**
<!-- globalinfo-start -->
* Cluster data using the k means algorithm
* <p/>
<!-- globalinfo-end -->
*
<!-- options-start -->
* Valid options are: <p/>
*
* <pre> -N <num>
* number of clusters.
* (default 2).</pre>
*
* <pre> -V
* Display std. deviations for centroids.
* </pre>
*
* <pre> -M
* Replace missing values with mean/mode.
* </pre>
*
* <pre> -S <num>
* Random number seed.
* (default 10)</pre>
*
* <pre> -A <classname and options>
* Distance function to be used for instance comparison
* (default weka.core.EuclideanDistance)</pre>
*
* <pre> -I <num>
* Maximum number of iterations. </pre>
*
* <pre> -O
* Preserve order of instances. </pre>
*
*
<!-- options-end -->
*
* @author Mark Hall (mhall@cs.waikato.ac.nz)
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision: 5538 $
* @see RandomizableClusterer
*/
public class SimpleKMeans
extends RandomizableClusterer
implements NumberOfClustersRequestable, WeightedInstancesHandler {
/** for serialization */
static final long serialVersionUID = -3235809600124455376L;
/**
 * Filter used by buildClusterer to replace missing values in the training
 * instances; it is only applied when m_dontReplaceMissing is false.
 */
private ReplaceMissingValues m_ReplaceMissingFilter;
/**
 * Number of clusters to generate (default 2).
 */
private int m_NumClusters = 2;
/**
 * Holds the cluster centroids.
 */
private Instances m_ClusterCentroids;
/**
 * Holds the standard deviations of the numeric attributes in each cluster.
 */
private Instances m_ClusterStdDevs;
/**
 * For each cluster, holds the frequency counts for the values of each
 * nominal attribute.
 */
private int [][][] m_ClusterNominalCounts;
/**
 * Presumably the number of missing values per cluster and attribute
 * (by analogy with m_FullMissingCounts) -- TODO confirm against the code
 * that populates it.
 */
private int[][] m_ClusterMissingCounts;
/**
 * Per-attribute statistics on the full data set, kept for comparison
 * purposes. For a numeric attribute the entry is the mean when the
 * Euclidean distance is used, or the median when the Manhattan distance is
 * used; for a nominal attribute it is the mode.
 * buildClusterer overwrites an entry with NaN when every value of a numeric
 * attribute is missing, and with -1 when a nominal attribute has more
 * missing values than occurrences of its most frequent value.
 */
private double[] m_FullMeansOrMediansOrModes;
/**
 * Standard deviation of each numeric attribute over the full data set.
 * Only allocated and filled when m_displayStdDevs is set.
 */
private double[] m_FullStdDevs;
/**
 * For each nominal attribute, the frequency counts of its values over the
 * full data set; entries for numeric attributes stay empty arrays.
 */
private int[][] m_FullNominalCounts;
/**
 * Number of missing values of each attribute over the full data set.
 */
private int[] m_FullMissingCounts;
/**
 * Display standard deviations for numeric attributes.
 */
private boolean m_displayStdDevs;
/**
 * If true, missing values are NOT replaced before clustering; when false
 * (the default) buildClusterer runs the data through
 * m_ReplaceMissingFilter first.
 */
private boolean m_dontReplaceMissing = false;
/**
 * The number of instances in each cluster.
 */
private int [] m_ClusterSizes;
/**
 * Maximum number of iterations to be executed (default 500).
 */
private int m_MaxIterations = 500;
/**
 * Keeps track of the number of iterations completed before convergence;
 * reset to 0 at the start of buildClusterer.
 */
private int m_Iterations = 0;
/**
 * Holds the squared errors for all clusters.
 */
private double [] m_squaredErrors;
/**
 * The distance function used (Euclidean distance by default);
 * buildClusterer hands it the training instances via setInstances.
 */
protected DistanceFunction m_DistanceFunction = new EuclideanDistance();
/**
 * Preserve order of instances. When set, buildClusterer also keeps the
 * per-instance cluster assignments in m_Assignments.
 */
private boolean m_PreserveOrder = false;
/**
 * Cluster assignment of each training instance; only set by buildClusterer
 * when m_PreserveOrder is true, otherwise left null.
 */
protected int[] m_Assignments = null;
/**
 * Creates a k-means clusterer whose default random number seed is 10
 * and applies that seed immediately.
 */
public SimpleKMeans() {
  // The no-argument superclass constructor is invoked implicitly.

  // Record 10 as this clusterer's default seed, then make it the active seed.
  m_SeedDefault = 10;
  setSeed(m_SeedDefault);
}
/**
 * Returns a short, human-readable description of this clusterer.
 *
 * @return a description of the clusterer suitable for
 * displaying in the explorer/experimenter gui
 */
public String globalInfo() {
  // All three fragments are constant variables, so their concatenation is
  // still folded into a single compile-time constant string.
  final String algorithm = "Cluster data using the k means algorithm.";
  final String distances =
    " Can use either the Euclidean distance (default) or the Manhattan distance.";
  final String centroids =
    " If the Manhattan distance is used, then centroids are computed as the"
    + " component-wise median rather than mean.";
  return algorithm + distances + centroids;
}
/**
 * Reports which kinds of data this clusterer can handle: data without a
 * class attribute, nominal and numeric attributes, and missing values.
 *
 * @return the capabilities of this clusterer
 */
public Capabilities getCapabilities() {
  final Capabilities caps = super.getCapabilities();

  // Discard whatever the superclass enabled and switch on only what is
  // listed below, in this order.
  caps.disableAll();
  final Capability[] supported = {
    Capability.NO_CLASS,
    // attributes
    Capability.NOMINAL_ATTRIBUTES,
    Capability.NUMERIC_ATTRIBUTES,
    Capability.MISSING_VALUES
  };
  for (int i = 0; i < supported.length; i++) {
    caps.enable(supported[i]);
  }
  return caps;
}
/**
* Generates a clusterer. Has to initialize all fields of the clusterer
* that are not being set via options.
*
* @param data set of instances serving as training data
* @throws Exception if the clusterer has not been
* generated successfully
*/
public void buildClusterer(Instances data) throws Exception {
// can clusterer handle the data?
getCapabilities().testWithFail(data);
m_Iterations = 0;
m_ReplaceMissingFilter = new ReplaceMissingValues();
Instances instances = new Instances(data);
instances.setClassIndex(-1);
if (!m_dontReplaceMissing) {
m_ReplaceMissingFilter.setInputFormat(instances);
instances = Filter.useFilter(instances, m_ReplaceMissingFilter);
}
m_FullMissingCounts = new int[instances.numAttributes()];
if (m_displayStdDevs) {
m_FullStdDevs = new double[instances.numAttributes()];
}
m_FullNominalCounts = new int[instances.numAttributes()][0];
m_FullMeansOrMediansOrModes = moveCentroid(0, instances, false);
for (int i = 0; i < instances.numAttributes(); i++) {
m_FullMissingCounts[i] = instances.attributeStats(i).missingCount;
if (instances.attribute(i).isNumeric()) {
if (m_displayStdDevs) {
m_FullStdDevs[i] = Math.sqrt(instances.variance(i));
}
if (m_FullMissingCounts[i] == instances.numInstances()) {
m_FullMeansOrMediansOrModes[i] = Double.NaN; // mark missing as mean
}
} else {
m_FullNominalCounts[i] = instances.attributeStats(i).nominalCounts;
if (m_FullMissingCounts[i]
> m_FullNominalCounts[i][Utils.maxIndex(m_FullNominalCounts[i])]) {
m_FullMeansOrMediansOrModes[i] = -1; // mark missing as most common value
}
}
}
m_ClusterCentroids = new Instances(instances, m_NumClusters);
int[] clusterAssignments = new int [instances.numInstances()];
if(m_PreserveOrder)
m_Assignments = clusterAssignments;
m_DistanceFunction.setInstances(instances);
Random RandomO = new Random(getSeed());
int instIndex;
HashMap initC = new HashMap();
DecisionTableHashKey hk = null;
Instances initInstances = null;
if(m_PreserveOrder)
initInstances = new Instances(instances);
else
initInstances = i