import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
public class NB extends AbstractClassifier {
private double [][] m_ClassAtt;
private double [] m_ClassCounts;
private int m_ClassNum; //统计类标记种类数
private int [] m_NumAtt; //统计每个属性当中取值有多少种
private int m_AttNum; //统计属性个数
private int [] m_StartAttIndex; //统计属性I是从哪个下标开始
private int m_AllAttNum; //计算下标位置
private int m_instanceNum;
private int m_classIndex;
public void buildClassifier(Instances instances) throws Exception {
m_ClassNum=instances.numClasses();
m_AttNum=instances.numAttributes();
m_AllAttNum=0;
m_instanceNum=instances.numInstances();
m_NumAtt=new int [m_AttNum];
m_StartAttIndex=new int[m_AttNum];
m_classIndex=instances.classIndex();
for(int i=0;i<m_AttNum;i++)
{
if(i!=instances.classIndex())
{
m_StartAttIndex[i]=m_AllAttNum;
m_NumAtt[i]=instances.attribute(i).numValues();
m_AllAttNum+=m_NumAtt[i];
}
else
{
m_NumAtt[i]=instances.numClasses();
}
}
m_ClassAtt=new double[m_AllAttNum][m_ClassNum];
m_ClassCounts=new double[m_ClassNum];
for(int i=0;i<instances.numInstances();i++)
{
int ClassType;
ClassType=(int)instances.instance(i).classValue();
m_ClassCounts[ClassType]++;
for(int j=0;j<instances.numAttributes();j++)
{
if(j!=instances.classIndex())
{
int AttIndex;
AttIndex=(int)instances.instance(i).value(j);
AttIndex+=m_StartAttIndex[j];
m_ClassAtt[AttIndex][ClassType]++;
}
}
}
}
public double [] distributionForInstance(Instance instance) throws Exception {
double [] probs=new double[m_ClassNum];
int []AttIndex=new int[m_AttNum];
for(int k=0;k<m_AttNum;k++)
{
if(k!=m_classIndex)
{
AttIndex[k]=m_StartAttIndex[k]+(int)instance.value(k);
}
else
{
AttIndex[k]=-1;
}
}
for(int i=0;i<m_ClassNum;i++)
{
probs[i]=(m_ClassCounts[i]+1)/(m_instanceNum+m_ClassNum);
for(int j=0;j<m_AttNum;j++)
{
if(AttIndex[j]==-1) continue;
probs[i]*=(m_ClassAtt[AttIndex[j]][i]+1)/(m_ClassCounts[i]+m_NumAtt[j]);
}
}
Utils.normalize(probs);
return probs;
}
public static void main(String [] argv) {
try {
System.out.println(Evaluation.evaluateModel(new NB(), argv));
}
catch (Exception e) {
e.printStackTrace();
System.err.println(e.getMessage());
}
}
}