package coforest;
/**
* Description: CoForest is a semi-supervised algorithm, which exploits the power of ensemble learning and available
* large amount of unlabeled data to produce hypothesis with better performance.
*
* Reference: M. Li, Z.-H. Zhou. Improve computer-aided diagnosis with machine learning techniques using undiagnosed
* samples. IEEE Transactions on Systems, Man and Cybernetics - Part A: Systems and Humans, 2007, 37(6).
*
* ATTN: This package is free for academic usage. You can run it at your own risk.
* For other purposes, please contact Prof. Zhi-Hua Zhou (zhouzh@nju.edu.cn).
*
* Requirement: To use this package, the whole WEKA environment (ver 3.4) must be available.
* refer: I.H. Witten and E. Frank. Data Mining: Practical Machine Learning
* Tools and Techniques with Java Implementations. Morgan Kaufmann,
* San Francisco, CA, 2000.
*
* Data format: Both the input and output formats are the same as those used by WEKA.
*
* ATTN2: This package was developed by Mr. Ming Li (lim@lamda.nju.edu.cn). There
* is a ReadMe file provided for roughly explaining the codes. But for any
* problem concerning the code, please feel free to contact Mr. Li.
*
*/
import java.io.*;
import java.text.*;
import java.util.*;
import weka.core.*;
import weka.classifiers.*;
import weka.classifiers.trees.*;
public class CoForest
{
/** The ensemble of component classifiers (the random trees forming the Random Forest) */
protected Classifier[] m_classifiers = null;
/** The number of component classifiers in the ensemble */
protected int m_numClassifiers = 10;
/** The seed for the random number generator used when building the ensemble */
protected int m_seed = 1;
/** Number of features to consider in random feature selection.
If less than 1 will use int(logM+1) ) */
protected int m_numFeatures = 0;
/** Final number of features that were considered in last build (the K value handed to each RandomTree). */
protected int m_KValue = 0;
/** confidence threshold -- presumably the minimum ensemble confidence for adopting an unlabeled instance (used in the labeling step; see the referenced paper) -- TODO confirm against full buildClassifier body */
protected double m_threshold = 0.75;
/** Number of instances in the original labeled training set, recorded at build time */
private int m_numOriginalLabeledInsts = 0;
/**
 * Default constructor. All parameters keep their default values
 * until changed via the setter methods.
 */
public CoForest()
{
}
/**
 * Sets the seed used to initialize the random number generator
 * employed when building the ensemble.
 *
 * @param s int -- the random seed
 */
public void setSeed(int s)
{
m_seed = s;
}
/**
 * Sets the number of component trees used in the Random Forest.
 *
 * @param n int -- the number of component classifiers
 */
public void setNumClassifiers(int n)
{
m_numClassifiers = n;
}
/**
 * Gets the number of component trees used in the Random Forest.
 *
 * @return int -- the number of component classifiers
 */
public int getNumClassifiers()
{
return m_numClassifiers;
}
/**
 * Sets the number of features to consider in random feature selection.
 * A value less than 1 means int(log2(numAttributes)) + 1 is used at build time.
 *
 * @param n int -- value to assign to m_numFeatures
 */
public void setNumFeatures(int n)
{
m_numFeatures = n;
}
/**
 * Gets the number of features to consider in random feature selection.
 *
 * @return int -- the number of features (0 or less means the default is computed at build time)
 */
public int getNumFeatures()
{
return m_numFeatures;
}
/**
 * Resamples the data set with replacement according to the instance
 * weights (a weighted bootstrap), recording which original instances
 * were drawn at least once.
 *
 * @param data Instances -- the original data set
 * @param random Random -- the random object
 * @param sampled boolean[] -- output parameter, at least data.numInstances()
 *        long: sampled[i] is set true if instance i was drawn at least once
 *        (i.e. it is "in-bag"); entries for undrawn instances are left untouched
 * @return Instances -- a new data set of the same size as data, drawn with
 *         replacement, with every copied instance's weight reset to 1
 */
public final Instances resampleWithWeights(Instances data,
Random random,
boolean[] sampled)
{
// Snapshot the per-instance weights of the source data.
double[] weights = new double[data.numInstances()];
for (int i = 0; i < weights.length; i++) {
weights[i] = data.instance(i).weight();
}
Instances newData = new Instances(data, data.numInstances());
// Nothing to sample from an empty data set.
if (data.numInstances() == 0) {
return newData;
}
// Build a sorted list of n cumulative uniform draws; after normalization
// these act as n random "pointers" into the cumulative weight distribution.
double[] probabilities = new double[data.numInstances()];
double sumProbs = 0, sumOfWeights = Utils.sum(weights);
for (int i = 0; i < data.numInstances(); i++) {
sumProbs += random.nextDouble();
probabilities[i] = sumProbs;
}
// Rescale the draws so they span [0, sumOfWeights].
Utils.normalize(probabilities, sumProbs / sumOfWeights);
// Make sure that rounding errors don't mess things up
probabilities[data.numInstances() - 1] = sumOfWeights;
// Two-pointer walk: k indexes the sorted draws, l indexes the instances.
// Instance l is copied once for every draw that falls inside its weight
// interval, giving sampling with replacement proportional to weight.
int k = 0; int l = 0;
sumProbs = 0;
while ((k < data.numInstances() && (l < data.numInstances()))) {
// NOTE(review): the message says "positive" but the check only rejects
// strictly negative weights; zero weights pass through (and simply can
// never be drawn).
if (weights[l] < 0) {
throw new IllegalArgumentException("Weights have to be positive.");
}
sumProbs += weights[l];
while ((k < data.numInstances()) &&
(probabilities[k] <= sumProbs)) {
newData.add(data.instance(l));
sampled[l] = true;
// Resampled copies get unit weight regardless of the original weight.
newData.instance(k).setWeight(1);
k++;
}
l++;
}
return newData;
}
/**
 * Computes the ensemble's class-probability distribution for an instance
 * by accumulating the distributions of all component trees and
 * normalizing the result to sum to one.
 *
 * @param inst Instance -- the instance to evaluate
 * @return double[] -- the normalized class-probability distribution
 * @throws Exception -- if a component classifier fails to predict
 */
public double[] distributionForInstance(Instance inst) throws Exception
{
    double[] sums = new double[inst.numClasses()];
    for (int t = 0; t < m_classifiers.length; t++)
    {
        double[] treeDistr = m_classifiers[t].distributionForInstance(inst);
        for (int c = 0; c < sums.length; c++)
        {
            sums[c] += treeDistr[c];
        }
    }
    Utils.normalize(sums);
    return sums;
}
/**
 * Classifies a given instance by picking the class with the largest
 * entry in the ensemble's probability distribution.
 *
 * @param inst Instance -- the instance to classify
 * @return double -- the predicted class index
 * @throws Exception -- if a component classifier fails to predict
 */
public double classifyInstance(Instance inst) throws Exception
{
    return Utils.maxIndex(distributionForInstance(inst));
}
/**
* Build the classifiers using Co-Forest algorithm
*
* @param labeled Instances -- Labeled training set
* @param unlabeled Instances -- unlabeled training set
* @throws Exception -- certain exception
*/
public void buildClassifier(Instances labeled, Instances unlabeled) throws Exception
{
double[] err = new double[m_numClassifiers];
double[] err_prime = new double[m_numClassifiers];
double[] s_prime = new double[m_numClassifiers];
boolean[][] inbags = new boolean[m_numClassifiers][];
Random rand = new Random(m_seed);
m_numOriginalLabeledInsts = labeled.numInstances();
RandomTree rTree = new RandomTree();
// set up the random tree options
m_KValue = m_numFeatures;
if (m_KValue < 1) m_KValue = (int) Utils.log2(labeled.numAttributes())+1;
rTree.setKValue(m_KValue);
m_classifiers = Classifier.makeCopies(rTree, m_numClassifiers);
Instances[] labeleds = new Instances[m_numClassifiers];
int[] randSeeds = new int[m_numClassifiers];
for(int i = 0; i < m_numClassifiers; i++)
{
randSeeds[i] = rand.nextInt();
((RandomTree)m_classifiers[i]).setSeed(randSeeds[i]);
inbags[i] = new boolean[labeled.numInstances()];
labeleds[i] = resampleWithWeights(labeled, rand, inbags[i]);
m_classifiers[i].buildClassifier(labeleds[i]);
err_prime[i] = 0.5;
s_prime[i] = 0;
}
boolean bChanged = true;
while(bChanged)
{
bChanged = false;
boolean[] bUpdate = new boolean[m_classifiers.length];
Instances[] Li = new Instances[m_numClassifiers];
for(int i = 0; i < m_numClassifiers; i++)
{
err[i] = measureError(labeled, inbags, i);
Li[i] = new Instances(labeled, 0);
/** if (e_i < e'_i) */
if(err[i] < err_prime[i])
{
if(s_prime[i] == 0)
s_prime[i] = Math.min(unlabeled.sumOfWeights() / 10, 100);
/** Subsample U for each hi */
double weight = 0;
unlabeled.randomize(rand);
int numWeightsAfterSubsample = (int) Math.ceil(err_prime[i] * s_prime[i] / err[i] - 1);
// NOTE(review): source truncated here -- the remainder of buildClassifier
// (and any further members of this class) is missing; the lines removed
// above were web-page pagination residue, not code.