#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <algorithm>
#include <iostream>
#include <vector>

using namespace std;
// Number of node slots DecisionTree's constructor allocates for its node queue.
static const int maximum = 10000;
// Tally of samples seen for one category label; a fresh tally starts at 1
// because it is created when the first such sample is found.
struct counter{
    int type;   // category label being counted
    int count;  // number of samples with that label
    counter(int t) : type(t), count(1) {}
};
// Returns x squared.
double square(double x){
    const double product = x * x;
    return product;
}
// Writes val into the first `length` ints of array; does nothing when length <= 0.
void setValue(int *array,int val,int length){
    int *cursor = array;
    int remaining = length;
    while(remaining > 0){
        *cursor++ = val;
        --remaining;
    }
}
// Writes val into the first `length` doubles of array; does nothing when length <= 0.
void setValue(double *array,double val,int length){
    double *cursor = array;
    int remaining = length;
    while(remaining > 0){
        *cursor++ = val;
        --remaining;
    }
}
// One node of the binary decision tree. Every node of a tree is built with the
// same dimensions: `row` dataset rows and `col` features.
class treeNode{
public:
    treeNode *leftkid;   // child receiving samples whose split feature is 0
    treeNode *rightkid;  // child receiving samples whose split feature is nonzero
    int *samples;        // dataset row indices held here; first `total` are valid
    int *availables;     // per-feature flag: nonzero if the feature may still be split on
    int total;           // number of valid entries in samples
    int row;             // capacity of samples (dataset row count)
    int col;             // number of features
    int index;           // feature this node splits on, or -1 for a leaf
    int type;            // category predicted by a leaf, -1 otherwise

    treeNode(int r,int c,int *ava);
    void addSample(int x);
    int measure(int **feature,int *category);
    int getNumTypes(int *category);
    void createKids();
    void devideSamples(int **feature);
    bool examine(int index,int **feature,int *category,int totalTypes);
    void majority(int *category,int totalTypes);
};
// Builds an empty node with room for `r` sample indices (all set to -1) and
// its own copy of the first `c` entries of the feature mask `ava`.
// The node starts as an unsplit leaf with no children and no predicted type.
treeNode::treeNode(int r,int c,int *ava)
    : leftkid(NULL),rightkid(NULL),
      samples(new int[r]),availables(new int[c]),
      total(0),row(r),col(c),index(-1),type(-1){
    setValue(samples,-1,row);
    for(int k=0;k<col;++k)
        availables[k]=ava[k];
}
// Allocates both children. Each child inherits this node's feature mask with
// the split feature `index` switched off, so it is never split on again below.
// `index` must already be set to a valid feature.
void treeNode::createKids(){
    leftkid=new treeNode(row,col,availables);
    leftkid->availables[index]=0;
    rightkid=new treeNode(row,col,availables);
    rightkid->availables[index]=0;
}
// Appends dataset row index x to this node's sample list. There is no bounds
// check: callers must not add more than `row` samples.
void treeNode::addSample(int x){
    samples[total]=x;
    ++total;
}
// Counts the distinct labels in category[0..row-1]. This covers the whole
// dataset, not just this node's samples. Returns 0 when row <= 0.
// Uses a heap-backed std::vector instead of a stack VLA (non-standard and prone
// to stack overflow for large row). An empty dataset no longer reads
// category[0] out of bounds.
int treeNode::getNumTypes(int *category){
    std::vector<int> seen;
    for(int i=0;i<row;i++){
        if(std::find(seen.begin(),seen.end(),category[i])==seen.end())
            seen.push_back(category[i]);
    }
    return static_cast<int>(seen.size());
}
int treeNode::measure(int **feature,int *category){
double min_gain=1,index=-1,gain;
int numTypes=getNumTypes(category);
for(int i=0;i<col;i++){
if(!availables[i])
continue;
struct counter *set1[numTypes],*set2[numTypes];
int set1_count=0,set2_count=0;
int set1_total=0,set2_total=0;
for(int j=0;j<total;j++){
if(feature[samples[j]][i]){
bool found=false;
for(int k=0;k<set1_count;k++)
if(set1[k]->type==category[samples[j]]){
set1[k]->count++;
found=true;
break;
}
if(!found)
set1[set1_count++]=new counter(category[samples[j]]);
set1_total++;
}
else{
bool found=false;
for(int k=0;k<set2_count;k++)
if(set2[k]->type==category[samples[j]]){
set2[k]->count++;
found=true;
break;
}
if(!found)
set2[set2_count++]=new counter(category[samples[j]]);
set2_total++;
}
}
//calculate GINI index
double set1_sum=1,set2_sum=1;
for(int k=0;k<set1_count;k++)
set1_sum=set1_sum-square((double)set1[k]->count/set1_total);
for(int k=0;k<set2_count;k++)
set2_sum=set2_sum-square((double)set2[k]->count/set2_total);
gain=set1_sum*set1_total/total+set2_sum*set2_total/total;
if(gain<min_gain){
min_gain=gain;
index=i;
}
}
return index;
}
// Chi-square statistic between feature `index` (zero vs nonzero) and the
// category labels of this node's samples. Returns true when it exceeds 0.5;
// train() uses that to decide whether to split on `index`.
//
// Labels are used directly as contingency-table rows and are assumed to be
// non-negative (TODO confirm with callers). The table gets
// max(totalTypes, largest label + 1) rows, so labels that are not dense in
// [0, totalTypes) stay in bounds. Rows with no samples have an expected count
// of 0 and are skipped, so extra rows never change the statistic.
// Tables are std::vectors; the original leaked four heap arrays per call.
bool treeNode::examine(int index,int **feature,int *category,int totalTypes){
    if(total<=0)
        return false;
    int tableRows=std::max(totalTypes,0);
    for(int i=0;i<total;i++)
        tableRows=std::max(tableRows,category[samples[i]]+1);
    // observed[2*c+f]: samples with label c whose feature falls in column f
    std::vector<int> observed(static_cast<std::size_t>(tableRows)*2,0);
    std::vector<int> category_sum(static_cast<std::size_t>(tableRows),0);
    int feature_sum[2]={0,0};
    for(int i=0;i<total;i++){
        const int order=samples[i];
        const int c=category[order];
        // Any nonzero value is column 1, as in measure/devideSamples/predict;
        // the raw value is no longer used as an index.
        const int f=feature[order][index]!=0 ? 1 : 0;
        observed[2*c+f]++;
        category_sum[c]++;
        feature_sum[f]++;
    }
    double target=0;
    for(int j=0;j<tableRows;j++)
        for(int i=0;i<2;i++){
            const double expected=(double)category_sum[j]*feature_sum[i]/total;
            if(expected!=0)
                target+=square(expected-observed[2*j+i])/expected;
        }
    return target>0.5;
}
void treeNode::devideSamples(int **feature){
for(int i=0;i<total;i++){
if(feature[samples[i]][index]==0)
this->leftkid->addSample(samples[i]);
else
this->rightkid->addSample(samples[i]);
}
}
// Makes this node a leaf by setting `type` to the most frequent label among
// its samples. Ties go to the smallest label.
//
// Labels index the vote table directly and are assumed non-negative (TODO
// confirm with callers). The table has max(totalTypes, largest label + 1)
// slots, so labels that are not dense in [0, totalTypes) stay in bounds.
// With no samples, slot 0 wins with zero votes; with no slots, type becomes -1.
// The vote table is a std::vector; the original leaked its heap array.
void treeNode::majority(int *category,int totalTypes){
    int slots=std::max(totalTypes,0);
    for(int i=0;i<total;i++)
        slots=std::max(slots,category[samples[i]]+1);
    std::vector<int> votes(static_cast<std::size_t>(slots),0);
    for(int i=0;i<total;i++)
        votes[category[samples[i]]]++;
    int mark=-1,top=-1;
    for(int i=0;i<slots;i++)
        if(votes[i]>top){
            top=votes[i];
            mark=i;
        }
    type=mark;
}
// Decision tree over rows of zero/nonzero features, grown breadth-first by
// train() and queried with predict().
class DecisionTree{
public:
    treeNode *root;       // root node; initially holds every training row
    treeNode **array;     // breadth-first queue of every node created so far
    int count;            // number of nodes stored in array
    int ptr;              // position in array of the next node train() processes
    int row;              // number of training rows
    int col;              // number of features per row
    int totalTypes;       // distinct label count, filled in by train()

    DecisionTree(int r,int c);
    void train(int **feature,int *category);
    int predict(int *sample);
};
// Prepares a tree for `r` training rows with `c` features. The root holds all
// rows and may split on every feature. The node queue gets `maximum` slots;
// train() enlarges it if needed.
// Fix: the temporary all-ones feature mask is now freed. treeNode's
// constructor copies it, so it was previously leaked.
DecisionTree::DecisionTree(int r,int c){
    row=r;
    col=c;
    count=0;
    ptr=0;
    totalTypes=0;
    int *ava=new int[col];
    setValue(ava,1,col);
    root=new treeNode(r,c,ava);
    delete[] ava;  // root keeps its own copy of the mask
    for(int i=0;i<row;i++)
        root->addSample(i);
    array=new treeNode*[maximum];
    array[count++]=root;
}
void DecisionTree::train(int **feature,int *category){
totalTypes=root->getNumTypes(category);
while(true){
if(ptr==count)
break;
int index=array[ptr]->measure(feature,category);
bool devisible=false;
if(index>-1)
devisible=array[ptr]->examine(index,feature,category,totalTypes);
if(devisible){
array[ptr]->index=index;
array[ptr]->createKids();
array[count++]=array[ptr]->leftkid;
array[count++]=array[ptr]->rightkid;
array[ptr]->devideSamples(feature);
}
else
array[ptr]->majority(category,totalTypes);
ptr++;
}
}
int DecisionTree::predict(int *sample){
treeNode* ptr=this->root;
while(ptr->index!=-1){
if(sample[ptr->index]==0)
ptr=ptr->leftkid;
else
ptr=ptr->rightkid;
}
return ptr->type;
}
/*int main(){
srand((unsigned)time(NULL));
int rows=1000,cols=8;
DecisionTree tree(rows,cols);
int **feature,*category;
feature=new int*[rows];
category=new int[rows]; // 0=toxic,1=disgusting but non-toxic,2=delicious
setValue(category,0,rows);
for(int i=0;i<rows;i++)
feature[i]=new int[cols];
for(int i=0;i<rows;i++)
for(int j=0;j<cols;j++)
feature[i][j]=rand()%2;
for(int i=0;i<rows;i++){
if(feature[i][0]+feature[i][3]==feature[i][5]+feature[i][6])
category[i]=5;
else if(feature[i][2]+feature[i][7]==feature[i][4]+feature[i][6])
category[i]=4;
else if(feature[i][1]+feature[i][4]+feature[i][6]==2)
category[i]=3;
else if(feature[i][0]+feature[i][1]+feature[i][2]==1)
category[i]=2;
else if(feature[i][7]&&feature[i][3]!=feature[i][5])
category[i]=1;
}
tree.train(feature,category);
int *sample=new int[cols];
int mark=0,correct=0;
while(true){
for(int i=0;i<cols;i++){
sample[i]=rand()%2;
cout<<sample[i]<<" ";
}
int result=tree.predict(sample);
cout<<result<<endl;
if(sample[0]+sample[3]==sample[5]+sample[6]){
if(result==5)
correct++;
}
else if(sample[2]+sample[7]==sample[4]+sample[6]){
if(result==4)
correct++;
}
else if(sample[1]+sample[4]+sample[6]==2){
if(result==3)
correct++;
}
else if(sample[0]+sample[1]+sample[2]==1){
if(result==2)
correct++;
}
else if(sample[7]&&sample[3]!=sample[5]){
if(result==1)
correct++;
}
else{
if(!result)
correct++;
}
mark++;
if(mark==100)
break;
}
cout<<"correct prediction:"<<correct<<endl;
return 0;
}*/
// 评论0 (web-page residue: "Comments 0")
// 最新资源 (web-page residue: "Latest resources")