## 3.5 Classification and Regression Trees
> 决策树是一种广受欢迎的、强大的预测方法。它之所以受到欢迎,是因为其最终的模型对于从业人员来说易于理解,给出的决策树可以确切解释为何做出特定的预测。决策树是最简单的机器学习算法,它易于实现,可解释性强,完全符合人类的直观思维,有着广泛的应用。
>
> 同时,决策树也是更为高级的集成算法(如bagging,random forests和gradient boosting等)的基础。在本节中,您将了解Gini指数的概念、如何创建数据集的拆分、如何构建一棵树、如何利用构建的树作出分类决策以及如何在Banknote数据集上应用这些知识。
### 3.5.1 算法介绍
Classification and Regression Trees(简称CART),指的是可用于分类或回归预测建模问题的决策树算法。在本节中,我们将重点介绍如何使用CART解决分类问题,并以Banknote数据集为例进行演示。
CART模型的表示形式是一棵二叉树。每个节点表示单个输入变量(X)和该变量的分割点(假定变量是数字化的)。树的叶节点(也称作终端节点)包含用于预测的输出变量(y)。
创建二元决策树实际上是划分输入空间的过程。一般采用贪婪方法对变量进行递归的二进制拆分,使用某个成本函数(通常是Gini指数)测试不同的分割点,选择成本最高的拆分(即拆分完之后,剩余成本降到最低,亦代表这种拆分所含的“信息量”最大)。
### 3.5.2 算法讲解
#### 按属性分割数据
* 功能:切分函数,根据切分点将数据分为左右两组
* 输出:从切分点处切分后的数据结果
```c
struct dataset *test_split(int index, double value, int row, int col, double **data)
{
// 将切分结果作为结构体返回
struct dataset *split = (struct dataset *)malloc(sizeof(struct dataset));
int count1=0,count2=0;
double ***groups = (double ***)malloc(2 * sizeof(double **));
for (int i = 0; i < 2; i++)
{
groups[i]=(double **)malloc(row * sizeof(double *));
for (int j = 0; j < row; j++)
{
groups[i][j] = (double *)malloc(col * sizeof(double ));
}
}
for (int i = 0; i < row; i++)
{
if (data[i][index]<value)
{
groups[0][count1]=data[i];
count1 ++;
}else{
groups[1][count2] = data[i];
count2++;
}
}
split->splitdata = groups;
split->row1 = count1;
split->row2 = count2;
return split;
}
```
#### Gini指数
基尼指数是用于评估数据集中的拆分所常用的成本函数。数据集中的拆分涉及一个输入属性和该属性的一个值。它可以用于将训练模式分为两组。最理想的拆分是使基尼指数变为0,而最坏的情况是在二分类问题中分为每一类的概率都是50%(即基尼指数变为0.5)。
基尼系数的具体计算公式如下:
$$
G = 1-\sum^{k}_{i=1}{p_i^2}\tag{5.1}
$$
其中$k$是数据集中样本分类的数量,$p_i$表示第$i$类样本占总样本的比例。如果某一属性取多个值,则按照每一个值所占的比重进行加权平均。
例如,对于下面这些样本:
| day | deadline? | party? | lazy? | activity |
| ---- | --------- | ------ | ----- | -------- |
| 1 | urgent | yes | yes | party |
| 2 | urgent | no | yes | study |
| 3 | near | yes | yes | party |
| 4 | none | yes | no | party |
| 5 | none | no | yes | pub |
| 6 | none | yes | no | party |
| 7 | near | no | no | study |
| 8 | near | no | yes | TV |
| 9 | near | yes | yes | party |
| 10 | urgent | no | no | study |
以“deadline?”这个属性为例。首先计算deadline这个属性取每一个值的比例:
$$
P(deadline=urgent)={3\over10}\\
P(deadline=near)={4\over10}\\
P(deadline=none)={3\over10}\tag{5.2}
$$
然后分别计算deadline这个属性取每一个值下的Gini指数:
$$
P(deadline=urgent\&activity=party)={1\over3}\\
P(deadline=urgent\&activity=study)={2\over3}\\
G(urgent)=1-(({1\over3})^2+({2\over3})^2)={4\over9}\tag{5.3}
$$
$$
P(deadline=near\&activity=party)={2\over4}\\
P(deadline=near\&activity=study)={1\over4}\\
P(deadline=near\&activity=TV)={1\over4}\\
G(near)=1-(({2\over4})^2+({1\over4})^2+({1\over4})^2)={5\over8}\tag{5.4}
$$
$$
P(deadline=none\&activity=party)={2\over3}\\
P(deadline=none\&activity=pub)={1\over3}\\
G(none)=1-(({2\over3})^2+({1\over3})^2)={4\over9}\tag{5.5}
$$
最后按照取每一个值所占的比重对以上三个Gini指数做加权平均:
$$
G_1=G(deadline)={3\over10}\times{4\over9}+{4\over10}\times{5\over8}+{3\over10}\times{4\over9}={31\over60}\tag{5.6}
$$
同理可以算出按属性“party?”和“lazy?”切分时的Gini指数:
$$
G_2=G(party)={5\over10}\times[1-({5\over5})^2]+{5\over10}\times[1-(({3\over5})^2+({1\over5})^2+({1\over5})^2)]={7\over25}\tag{5.7}
$$
$$
G_3=G(lazy)={6\over10}\times[1-(({3\over6})^2+({1\over6})^2+({1\over6})^2+({1\over6})^2)]+{4\over10}\times[1-(({2\over4})^2+({2\over4})^2)]={3\over5}\tag{5.8}
$$
由于$G_2<G_1<G_3$
```{c}
double gini_index(int index,double value,int row, int col, double **dataset, double *class, int classnum)
{
float *numcount1 = (float *)malloc(classnum * sizeof(float));
float *numcount2 = (float *)malloc(classnum * sizeof(float));
for (int i = 0; i < classnum; i++)
{
numcount1[i]=numcount2[i]=0;
}
float count1 = 0, count2 = 0;
double gini1,gini2,gini;
gini1=gini2=gini=0;
// 计算每一类的个数
for (int i = 0; i < row; i++)
{
if (dataset[i][index] < value)
{
count1 ++;
for (int j = 0; j < classnum; j++)
if (dataset[i][col-1]==class[j])
numcount1[j] += 1;
}
else
{
count2++;
for (int j = 0; j < classnum; j++)
if (dataset[i][col - 1] == class[j])
numcount2[j]++;
}
}
// 判断分母是否为0,防止运算错误
if (count1==0)
{
gini1=1;
for (int i = 0; i < classnum; i++)
gini2 += (numcount2[i] / count2) * (numcount2[i] / count2);
}else if (count2==0)
{
gini2=1;
for (int i = 0; i < classnum; i++)
gini1 += (numcount1[i] / count1) * (numcount1[i] / count1);
}else
{
for (int i = 0; i < classnum; i++)
{
gini1 += (numcount1[i] / count1) * (numcount1[i] / count1);
gini2 += (numcount2[i] / count2) * (numcount2[i] / count2);
}
}
// 计算Gini指数
gini1 = 1 - gini1;
gini2 = 1 - gini2;
gini = (count1 / row) * gini1 + (count2 / row) * gini2;
free(numcount1);free(numcount2);
numcount1=numcount2=NULL;
return gini;
}
```
#### 寻找最佳分割点
我们需要根据计算出的Gini指数来决定最佳的分割点。具体做法是计算所有切分点Gini指数,选出Gini指数最小的切分点作为最后的分割点。
* 功能:选取数据的最优切分点
* 输出:数据中最优切分点下的树结构
```c
struct treeBranch *get_split(int row, int col, double **dataset, double *class, int classnum)
{
struct treeBranch *tree=(struct treeBranch *)malloc(sizeof(struct treeBranch));
int b_index=999;
double b_score = 999, b_value = 999,score;
// 计算所有切分点Gini系数,选出Gini系数最小的切分点
for (int i = 0; i < col-1; i++)
{
for (int j = 0; j < row; j++)
{
double value=dataset[j][i];
score=gini_index(i,value,row,col,dataset,class,classnum);
if (score<b_score)
{
b_score=score;
b_value=value
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
C语言手撕机器学习深度学习算法_C-machine-learning.zip (216个子文件)
BA.c 11KB
RF.c 11KB
DT.c 8KB
main.c 5KB
test_prediction.c 5KB
main.c 4KB
test_prediction.c 4KB
main.c 3KB
evaluate.c 3KB
evaluate.c 3KB
knn_model.c 3KB
evaluate.c 3KB
evaluate.c 2KB
evaluate.c 2KB
evaluate.c 2KB
evaluate.c 2KB
evaluate.c 2KB
evaluate.c 2KB
evaluate.c 2KB
evaluate.c 2KB
evaluate.c 2KB
evaluate.c 2KB
test_prediction.c 2KB
main.c 2KB
main.c 2KB
main.c 2KB
main.c 2KB
perceptron_model.c 1KB
read_csv.c 1KB
main.c 1KB
read_csv.c 1KB
test_prediction.c 1KB
read_csv.c 1010B
read_csv.c 1010B
read_csv.c 1010B
read_csv.c 1010B
read_csv.c 1010B
read_csv.c 1010B
read_csv.c 1010B
read_csv.c 1010B
read_csv.c 1010B
read_csv.c 1010B
main.c 972B
read_csv.c 970B
test_prediction.c 970B
read_csv.c 967B
k_fold.c 951B
main.c 926B
test_prediction.c 861B
main.c 844B
k_fold.c 820B
k_fold.c 820B
k_fold.c 820B
k_fold.c 820B
k_fold.c 820B
k_fold.c 820B
k_fold.c 820B
k_fold.c 820B
k_fold.c 820B
k_fold.c 820B
k_fold.c 819B
test_prediction.c 810B
test_prediction.c 786B
main.c 771B
k_fold.c 769B
k_fold.c 769B
test_prediction.c 758B
stacking_model.c 741B
main.c 704B
normalize.c 690B
normalize.c 690B
normalize.c 690B
normalize.c 690B
normalize.c 690B
normalize.c 690B
normalize.c 690B
normalize.c 690B
normalize.c 690B
normalize.c 688B
test_prediction.c 621B
test_prediction.c 616B
test_prediction.c 597B
test_prediction.c 497B
rmse.c 394B
rmse.c 329B
rmse.c 329B
rmse.c 329B
score.c 329B
score.c 296B
score.c 296B
score.c 296B
score.c 296B
score.c 296B
score.c 296B
score.c 296B
score.c 296B
score.c 253B
score.c 249B
winequality-white.csv 258KB
winequality-white.csv 258KB
共 216 条
- 1
- 2
- 3
资源评论
普通网友
- 粉丝: 0
- 资源: 510
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功