#include <algorithm>
#include <cmath>
#include <iostream>
#include <queue>
#include <vector>
#define INF 99999
#define N 75 //样本数
#define M 4 //属性数
double traindata[N][M+1] = {5, 3, 1.6, 0.2, 1, 5, 3.4, 1.6, 0.4, 1, 5.2, 3.5, 1.5, 0.2, 1, 5.2, 3.4, 1.4, 0.2, 1, 4.7, 3.2, 1.6, 0.2, 1, 4.8, 3.1, 1.6, 0.2, 1, 5.4, 3.4, 1.5, 0.4, 1, 5.2, 4.1, 1.5, 0.1, 1, 5.5, 4.2, 1.4, 0.2, 1, 4.9, 3.1, 1.5, 0.2, 1, 5, 3.2, 1.2, 0.2, 1, 5.5, 3.5, 1.3, 0.2, 1, 4.9, 3.6, 1.4, 0.1, 1, 4.4, 3, 1.3, 0.2, 1, 5.1, 3.4, 1.5, 0.2, 1, 5, 3.5, 1.3, 0.3, 1, 4.5, 2.3, 1.3, 0.3, 1, 4.4, 3.2, 1.3, 0.2, 1, 5, 3.5, 1.6, 0.6, 1, 5.1, 3.8, 1.9, 0.4, 1, 4.8, 3, 1.4, 0.3, 1, 5.1, 3.8, 1.6, 0.2, 1, 4.6, 3.2, 1.4, 0.2, 1, 5.3, 3.7, 1.5, 0.2, 1, 5, 3.3, 1.4, 0.2, 1, 6.6, 3, 4.4, 1.4, 2, 6.8, 2.8, 4.8, 1.4, 2, 6.7, 3, 5, 1.7, 2, 6, 2.9, 4.5, 1.5, 2, 5.7, 2.6, 3.5, 1, 2, 5.5, 2.4, 3.8, 1.1, 2, 5.5, 2.4, 3.7, 1, 2, 5.8, 2.7, 3.9, 1.2, 2, 6, 2.7, 5.1, 1.6, 2, 5.4, 3, 4.5, 1.5, 2, 6, 3.4, 4.5, 1.6, 2, 6.7, 3.1, 4.7, 1.5, 2, 6.3, 2.3, 4.4, 1.3, 2, 5.6, 3, 4.1, 1.3, 2, 5.5, 2.5, 4, 1.3, 2, 5.5, 2.6, 4.4, 1.2, 2, 6.1, 3, 4.6, 1.4, 2, 5.8, 2.6, 4, 1.2, 2, 5, 2.3, 3.3, 1, 2, 5.6, 2.7, 4.2, 1.3, 2, 5.7, 3, 4.2, 1.2, 2, 5.7, 2.9, 4.2, 1.3, 2, 6.2, 2.9, 4.3, 1.3, 2, 5.1, 2.5, 3, 1.1, 2, 5.7, 2.8, 4.1, 1.3, 2, 7.2, 3.2, 6, 1.8, 3, 6.2, 2.8, 4.8, 1.8, 3, 6.1, 3, 4.9, 1.8, 3, 6.4, 2.8, 5.6, 2.1, 3, 7.2, 3, 5.8, 1.6, 3, 7.4, 2.8, 6.1, 1.9, 3, 7.9, 3.8, 6.4, 2, 3, 6.4, 2.8, 5.6, 2.2, 3, 6.3, 2.8, 5.1, 1.5, 3, 6.1, 2.6, 5.6, 1.4, 3, 7.7, 3, 6.1, 2.3, 3, 6.3, 3.4, 5.6, 2.4, 3, 6.4, 3.1, 5.5, 1.8, 3, 6, 3, 4.8, 1.8, 3, 6.9, 3.1, 5.4, 2.1, 3, 6.7, 3.1, 5.6, 2.4, 3, 6.9, 3.1, 5.1, 2.3, 3, 5.8, 2.7, 5.1, 1.9, 3, 6.8, 3.2, 5.9, 2.3, 3, 6.7, 3.3, 5.7, 2.5, 3, 6.7, 3, 5.2, 2.3, 3, 6.3, 2.5, 5, 1.9, 3, 6.5, 3, 5.2, 2, 3, 6.2, 3.4, 5.4, 2.3, 3, 5.9, 3, 5.1, 1.8, 3};
double testdata [N][M+1] = {5.1, 3.5, 1.4, 0.2, 1, 4.9, 3, 1.4, 0.2, 1, 4.7, 3.2, 1.3, 0.2, 1, 4.6, 3.1, 1.5, 0.2, 1, 5, 3.6, 1.4, 0.2, 1, 5.4, 3.9, 1.7, 0.4, 1, 4.6, 3.4, 1.4, 0.3, 1, 5, 3.4, 1.5, 0.2, 1, 4.4, 2.9, 1.4, 0.2, 1, 4.9, 3.1, 1.5, 0.1, 1, 5.4, 3.7, 1.5, 0.2, 1, 4.8, 3.4, 1.6, 0.2, 1, 4.8, 3, 1.4, 0.1, 1, 4.3, 3, 1.1, 0.1, 1, 5.8, 4, 1.2, 0.2, 1, 5.7, 4.4, 1.5, 0.4, 1, 5.4, 3.9, 1.3, 0.4, 1, 5.1, 3.5, 1.4, 0.3, 1, 5.7, 3.8, 1.7, 0.3, 1, 5.1, 3.8, 1.5, 0.3, 1, 5.4, 3.4, 1.7, 0.2, 1, 5.1, 3.7, 1.5, 0.4, 1, 4.6, 3.6, 1, 0.2, 1, 5.1, 3.3, 1.7, 0.5, 1, 4.8, 3.4, 1.9, 0.2, 1, 7, 3.2, 4.7, 1.4, 2, 6.4, 3.2, 4.5, 1.5, 2, 6.9, 3.1, 4.9, 1.5, 2, 5.5, 2.3, 4, 1.3, 2, 6.5, 2.8, 4.6, 1.5, 2, 5.7, 2.8, 4.5, 1.3, 2, 6.3, 3.3, 4.7, 1.6, 2, 4.9, 2.4, 3.3, 1, 2, 6.6, 2.9, 4.6, 1.3, 2, 5.2, 2.7, 3.9, 1.4, 2, 5, 2, 3.5, 1, 2, 5.9, 3, 4.2, 1.5, 2, 6, 2.2, 4, 1, 2, 6.1, 2.9, 4.7, 1.4, 2, 5.6, 2.9, 3.6, 1.3, 2, 6.7, 3.1, 4.4, 1.4, 2, 5.6, 3, 4.5, 1.5, 2, 5.8, 2.7, 4.1, 1, 2, 6.2, 2.2, 4.5, 1.5, 2, 5.6, 2.5, 3.9, 1.1, 2, 5.9, 3.2, 4.8, 1.8, 2, 6.1, 2.8, 4, 1.3, 2, 6.3, 2.5, 4.9, 1.5, 2, 6.1, 2.8, 4.7, 1.2, 2, 6.4, 2.9, 4.3, 1.3, 2, 6.3, 3.3, 6, 2.5, 3, 5.8, 2.7, 5.1, 1.9, 3, 7.1, 3, 5.9, 2.1, 3, 6.3, 2.9, 5.6, 1.8, 3, 6.5, 3, 5.8, 2.2, 3, 7.6, 3, 6.6, 2.1, 3, 4.9, 2.5, 4.5, 1.7, 3, 7.3, 2.9, 6.3, 1.8, 3, 6.7, 2.5, 5.8, 1.8, 3, 7.2, 3.6, 6.1, 2.5, 3, 6.5, 3.2, 5.1, 2, 3, 6.4, 2.7, 5.3, 1.9, 3, 6.8, 3, 5.5, 2.1, 3, 5.7, 2.5, 5, 2, 3, 5.8, 2.8, 5.1, 2.4, 3, 6.4, 3.2, 5.3, 2.3, 3, 6.5, 3, 5.5, 1.8, 3, 7.7, 3.8, 6.7, 2.2, 3, 7.7, 2.6, 6.9, 2.3, 3, 6, 2.2, 5, 1.5, 3, 6.9, 3.2, 5.7, 2.3, 3, 5.6, 2.8, 4.9, 2, 3, 7.7, 2.8, 6.7, 2, 3, 6.3, 2.7, 4.9, 1.8, 3, 6.7, 3.3, 5.7, 2.1, 3};
using namespace std;
struct attribute{
int attribute_name; //决策树中的节点的属性名
double threshold; //用上述属性进行二分类时的界限
};
struct node{
int level; //当前节点处于树的第几层
attribute elem;//决策树上节点的数据
node* lchild; //决策树上节点的左孩子
node* rchild; //决策树上节点的右孩子
};
//计算H(S1)或H(S2)
double calculate_entropy(int count_1, int count_2, int count_3){
double p1=0, p2=0, p3=0, H = 0;
p1 = 1.0 * count_1/(count_1+count_2+count_3);
p2 = 1.0 * count_2/(count_1+count_2+count_3);
p3 = 1.0 * count_3/(count_1+count_2+count_3);
// H = -p1*log(p1) -p2*log(p2) -p3*log(p3); //如果直接用此式会有bug
if(p1 > 0) H = H - p1*log(p1);
if(p2 > 0) H = H - p2*log(p2);
if(p3 > 0) H = H - p3*log(p3);
return H;
}
//返回当前数据集中数量最多的标签名,实际上最终可以不要此函数,之前忽略了"叶节点上的数据集中所有数据标签全相同"。
int which_label(double data[][M+1]){
int count_1=0, count_2=0, count_3=0;
for(int k=0; k<N; k++){
if(data[k][M] == 1) count_1++;
else if(data[k][M] == 2) count_2++;
else if(data[k][M] == 3) count_3++;
}
if(count_1>=count_2 && count_1>=count_3) return 1;
else if(count_2>=count_1 && count_2>=count_3) return 2;
else if(count_3>=count_1 && count_3>=count_2) return 3;
}
//判断当前数据集data中是否所有数据的标签都相同
bool judge_same_label(double data[][M+1], int n){
int label = data[0][M];
for(int i=0; i<n; i++)
if(data[i][M] == label)
continue;
else
return false;
return true;
}
//递归执行ID3算法,根据训练集生成决策树
node* recursion(double data[][M+1], vector<int> attributes, int n){
attribute elem;
node* root = new node;
int count_1=0, count_2=0, count_3=0;
double H1=0, H2=0, H=0, entropy=INF;
//迭代更新来选择一个最佳的attribute,使得它对应的条件熵entropy最小
for(int m=0,j; m<attributes.size(); m++){
j = attributes[m];
//========================================================================================
//分割线中的部分:依据attributes表中的第j个属性,对样本data进行升序排序
for(int ii=0; ii<n; ii++)
for(int jj=1; jj<n-ii; jj++) //jj的起始位置为1,终止位置为n-ii
if(data[jj][j] < data[jj-1][j])
swap(data[jj], data[jj-1]);
//========================================================================================
for(int i=0; i<n; i++){
while(data[i][j] == data[i+1][j])
i++;
//统计前i个样本中的标签,有count_1个y1, count_2个y2, count_3个y3
count_1 = 0, count_2 = 0, count_3 = 0;
for(int k=0; k<=i; k++){
if(data[k][M] == 1) count_1++;
else if(data[k][M] == 2) count_2++;
else if(data[k][M] == 3) count_3++;
}
H1 = calculate_entropy(count_1, count_2, count_3); //计算H(S1)
//统计后N-i个样本中的标签,有count_1个y1, count_2个y2, count_3个y3
count_1 = 0, count_2 = 0, count_3 = 0;
for(int k=i+1; k<n; k++){
if(data[k][M] == 1) count_1++;
else if(data[k][M] == 2) count_2++;
else if(data[k][M] == 3) count_3++;
}
H2 = calculate_entropy(count_1, count_2, count_3); //计算H(S2)
//计算条件为属性j 且j的二分点为(data[i][j]+data[i+1][j])/2 时的条件熵H
H = 1.0* i/n *H1 + (1 - 1.0* i/n)*H2;
//递归出口
if( judge_same_label(data,n) ){
root->elem.attribute_name = which_label(data);
root->lchild = NULL; root->rchild = NULL;
return root;
}
//与已经记录的H做比较,迭代更新为最小的H值
if(H < entropy){
entropy = H;
elem.attribute_name = j;
elem.threshold = (1.0* data[i][j] + data[i+1][j]) / 2.0;
}
}
}
root->elem = elem;
//生成子节点递归时的属性集
vector<int> new_attributes = attributes;
/* for(vector<int>::iterator it = new_attributes.begin(); it!=new_attributes.end(); it++)
if((*it) == elem.attribute_name)
{ new_attributes.erase(it); break; }*/
//生成子节点递归时的数据集
double new_data_l[N][M+1], new_data_r[N][M+1];
int n_l = 0, n_r = 0;
for(int i=0; i<n; i++)
if( data[i][elem.attribute_name] <= elem.threshold )
for(int j=0; j<M+1; j++) new_data_l[n_l /5][j] = data[i][j], n_l = n_l+1;
else
for(int j=0; j<M+1; j++) new_data_r[n_r /5][j] = data[i][j], n_r = n_r+1;
//递归
root->lchild = recursion(new_data_l, new_attributes, n_l/5);
root->rchild = recursion(new_data_r, new_attributes, n_r/5);
return root;
}
//输入一行测试数据,查询决策树对它输出的标签
int querry(double data[M+1], node* root){
if(!root-
- 1
- 2
前往页