import json
import os
from math import log2
class DecisionTree(object):
def __init__(self, train_set, feature_index_set):
"""
根据训练集构建决策树
:param train_set:全量训练数据集
:param feature_index_set: 所选取的特征的集合
"""
print('construct tree, index:')
train_index_set = [x['id'] for x in train_set]
print('len', len(train_index_set))
self.train_set = train_set
self.feature_index_set = feature_index_set
self.count_class = partition_by_class(self.train_set)
self.majority = max(self.count_class, key=lambda x: len(self.count_class[x]))
self.is_leaf = False
self.best_feature = None
self.prediction = None
self.sub_trees = None
if not self.is_need_to_partition():
# 如果不需要划分,则对该节点作为一个叶子节点处理
self.is_leaf = True
print('不需要划分')
print('预测值为:', self.prediction)
print('===========================')
return
self.best_feature = self.get_best_feature()
self.count_best_feature = partition_by_feature(self.train_set, self.best_feature)
if len(self.count_best_feature) == 1:
# 按当前特征无法划分数据集
self.is_leaf = True
value = self.count_best_feature.popitem()[0]
print('按当前特征无法划分')
print('因为当前数据在特征[%d]下取值都为%s' % (self.best_feature, value))
self.prediction = self.majority
print('预测值为:', self.prediction)
print('===================================')
return
print('按特征[%d]对数据集划分' % self.best_feature)
print('================================================')
print('************************************************')
self.sub_trees = list()
for feature_value, index_list in self.count_best_feature.items():
print('特征:%d,取值: %s' % (self.best_feature, feature_value))
list(feature_index_set).remove(self.best_feature)
new_train_set = [x for x in train_set if x['id'] in index_list]
tree = DecisionTree(new_train_set, feature_index_set)
temp_dict = dict(feature=self.best_feature, value=feature_value, tree=tree)
self.sub_trees.append(temp_dict)
def get_best_feature(self):
info_gains = list()
for feature in self.feature_index_set:
info_gain = get_information_gain(self.train_set, feature)
temp_tuple = (feature, info_gain)
print('feature: %d, info_gain: %f' % (feature, info_gain))
info_gains.append(temp_tuple)
best_feature = max(info_gains, key=lambda x: x[1])[0]
max_info_gain = -1
for feature, info_gain in info_gains:
if info_gain >= max_info_gain:
max_info_gain = info_gain
best_feature = feature
print('当前数据下的最优特征为: %d' % best_feature)
return best_feature
def is_need_to_partition(self):
"""
判断当前节点是否需要继续划分
:return:
"""
if len(self.train_set) == 0:
# 当前节点对应数据集为空,不需要划分
self.prediction = None
print('当前节点对应数据集为空,不需要划分')
return False
if len(self.count_class) == 1:
# 当前数据集包含的样本属于同一类,不需要划分
self.prediction = self.count_class.popitem()[0]
print('当前数据集包含的样本属于同一类,不需要划分')
return False
current_feature_list = self.train_set[0]['feature']
for data in self.train_set[1:]:
# 当前数据特征不完全一样,还需要继续划分
if data['feature'] != current_feature_list:
return True
else:
current_feature_list = data['feature']
# 当前数据特征完全一样,无法继续划分
print('当前数据特征完全一样,无法继续划分')
self.prediction = self.majority
return False
def get_tree_dict(self):
temp_dict = dict()
temp_dict['best_feature'] = self.best_feature
if self.is_leaf:
temp_dict['prediction'] = self.prediction
if not self.sub_trees:
return temp_dict
temp_dict['sub_trees'] = list()
for tree in self.sub_trees:
data = tree['tree'].get_tree_dict()
temp_dict['sub_trees'].append(dict(value=tree['value'], tree=data))
return temp_dict
def predict(self, data):
"""
在当前树下预测data属于哪一类
:param data:
:return:
"""
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
if self.is_leaf:
print('data[%d]的预测值为:%s' % (data['id'], self.prediction))
return self.prediction
else:
print('此时判断特征【%d】' % self.best_feature)
value = data['feature'][self.best_feature]
print('data[%d]的feature[%d]为%s' % (data['id'], self.best_feature, value))
for tree in self.sub_trees:
if tree['value'] == value:
return tree['tree'].predict(data)
def partition_by_feature(data_set, feature):
"""
按特征对数据进行划分
:param data_set: 按类划分data_set数据集
:param feature: 按类划分data_set数据集
:return: 一个字典,键为一个类名,值为该类对应的数据id
"""
r_dict = dict()
for data in data_set:
f = data['feature'][feature]
if f not in r_dict.keys():
r_dict[f] = list()
r_dict[f].append(data['id'])
return r_dict
def partition_by_class(data_set):
"""
按类对数据进行划分
:param data_set: 按类划分data_set数据集
:return: 一个字典,键为一个类名,值为该类对应的数据id
"""
r_dict = dict()
for data in data_set:
c = data['class']
if c not in r_dict.keys():
r_dict[c] = list()
r_dict[c].append(data['id'])
return r_dict
def get_entropy(data_set):
"""
计算data_set数据集的熵
:param: 待计算的数据集
:return:熵
"""
count_class = partition_by_class(data_set)
result = 0
for c, c_list in count_class.items():
freq = len(c_list) / len(data_set)
result -= freq * log2(freq)
return result
def get_condition_entropy(data_set, feature):
"""
计算当前数据集在给定特征feature下的条件熵
:param data_set: 待计算的数据集
:param feature: 给定的特征
:return: 条件熵
"""
result = 0
count_feature = partition_by_feature(data_set, feature)
for feature_value, index_list in count_feature.items():
# feature_value为特征的一个取值
# index_list为该取值对应的样本id
current_data = [x for x in data_set if x['id'] in index_list]
freq = len(current_data) / len(data_set)
temp_entropy = get_entropy(current_data)
result += freq * temp_entropy
return result
def get_information_gain(data_set, feature):
"""
计算当前数据集在给定特征feature下的条件熵
:param data_set: 待计算的数据集
:param feature: 给定的特征
:return: 信息增益
"""
entropy = get_entropy(data_set)
condition_entropy = get_condition_entropy(data_set, feature)
information_gain = entropy - condition_entropy
return information_gain
if __name__ == '__main__':
path = os.path.join(os.getcwd(),
没有合适的资源?快使用搜索试试~ 我知道了~
基于决策树的蘑菇分类.zip
共27个文件
json:13个
py:5个
xml:3个
需积分: 5 2 下载量 113 浏览量
2024-04-25
18:59:34
上传
评论
收藏 227KB ZIP 举报
温馨提示
决策树 决策树(Decision Tree)是一种在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法。由于这种决策分支画成图形很像一棵树的枝干,因此得名决策树。在机器学习中,决策树是一个预测模型,代表的是对象属性与对象值之间的一种映射关系。 决策树的应用场景非常广泛,包括但不限于以下几个方面: 金融风险评估:决策树可以用于预测客户借款违约概率,帮助银行更好地管理风险。通过客户的历史数据构建决策树,可以根据客户的财务状况、征信记录、职业等信息来预测违约概率。 医疗诊断:医生可以通过病人的症状、体征、病史等信息构建决策树,根据不同的症状和体征来推断病情和诊断结果,从而帮助医生快速、准确地判断病情。 营销策略制定:企业可以通过客户的喜好、购买记录、行为偏好等信息构建决策树,根据不同的特征来推断客户需求和市场走势,从而制定更有效的营销策略。 网络安全:决策树可以用于网络安全领域,帮助企业防范网络攻击、识别网络威胁。通过网络流量、文件属性、用户行为等信息构建决策树,可以判断是否有异常行为和攻击威胁。
资源推荐
资源详情
资源评论
收起资源包目录
基于决策树的蘑菇分类.zip (27个子文件)
content
preprocess.py 926B
xzk.py 876B
decision_tree.py 8KB
mushrooms.csv 365KB
data
data_4.json 120KB
data_8.json 120KB
data_7.json 120KB
data_6.json 120KB
data_9.json 120KB
data_5.json 120KB
data_0.json 120KB
data_1.json 120KB
data_3.json 120KB
data_2.json 120KB
total.json 1.17MB
.idea
vcs.xml 180B
misc.xml 185B
modules.xml 282B
MushroomClassfy.iml 453B
sonarlint
issuestore
index.pb 188B
f
9
f9369dfc8dccf9a62f1f8ddcb9c7b2da07133a89 0B
4
6
46a3e0eeda104eadc31cc20ef438e684f0c84c50 68B
3
8
382352104437719732864a63311912c50818d683 0B
classfier.py 186B
p.json 2KB
my_test.py 778B
test_data.json 1KB
共 27 条
- 1
资源评论
生瓜蛋子
- 粉丝: 3927
- 资源: 7441
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 环形导轨椭圆线体STEP全套设计资料100%好用.zip
- 第八章_焊接金相学.ppt
- 常用金属材料的焊接.ppt
- 管理者的目标计划执行.pptx
- 超(超)临界锅炉用新型耐热钢的焊接及热处理.ppt
- 第二章_焊接检验员安全须知.ppt
- 第七章_焊接检验中的公制英制单位制转换.ppt
- 第四章_焊接接头的几何形状及焊接符号.ppt
- 第一章_焊接检验及资格认证.ppt
- 典型焊接结构的生产工艺.ppt
- 第五章_焊接检验及资格认可的有关资料.ppt
- 钢制压力容器焊接工艺评定.ppt
- 过程装备制造Chapter 2 焊接变形与应力.ppt
- 过程装备制造Chapter 1 焊接接头与焊接规范.ppt
- 过程装备制造Chapter 4 焊接结构的断裂失效与防治.ppt
- 过程装备制造Chapter 3 焊接接头的强度计算.ppt
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功