from dataloader import *
import numpy as np
import pandas as pd
import json
import sys
def mean_Y(data_Y):
    """Return the arithmetic mean of the target values in ``data_Y``.

    :param data_Y: sequence of numeric target values
    :return: the result of ``np.mean`` over ``data_Y``
    """
    average = np.mean(data_Y)
    return average
def var(data_Y):
    """Return the total squared deviation of ``data_Y`` from its mean.

    Computed as ``np.var(data_Y) * len(data_Y)``, i.e. the population
    variance scaled back up by the number of samples.

    :param data_Y: sized sequence of numeric target values
    :return: the summed squared error, or NaN when ``data_Y`` is empty
    """
    if len(data_Y) == 0:
        # np.var([]) already evaluates to NaN, but only after emitting
        # RuntimeWarnings. Return NaN directly: the value is unchanged, so a
        # comparison such as ``best > err1 + err2`` is still False for a
        # split with an empty side, and the warning noise disappears.
        return float('nan')
    return np.var(data_Y) * len(data_Y)
def group(orgGroup, label, feature=None):
    """Partition ``orgGroup`` in two on ``label`` and score both parts.

    The split rule depends on ``is_Discrete(label)``:

    * ``0``: threshold split. Records with ``x[label] <= feature`` go to
      the first group, records with ``x[label] > feature`` to the second.
    * ``2``: two-valued split. Records whose value equals 0 go to the first
      group, records whose value equals 1 to the second; ``feature`` must
      be ``None``.
    * anything else: membership split. Records whose value is contained in
      ``feature`` go to the first group, all others to the second.

    :param orgGroup: list of record dicts; each has ``label`` and 'charges'
    :param label: name of the column to split on, e.g. 'age', 'sex'
    :param feature: threshold, ``None``, or a collection of values such as
        (0, 1), depending on the split rule
    :return: ``(group1, group2, var1, var2)`` where ``varN`` is ``var()``
        applied to the 'charges' of ``groupN``
    """
    kind = is_Discrete(label)
    if kind == 0:
        def in_first(row):
            return row[label] <= feature

        def in_second(row):
            return row[label] > feature
    elif kind == 2:
        assert feature is None

        def in_first(row):
            return row[label] == 0

        def in_second(row):
            return row[label] == 1
    else:
        def in_first(row):
            return row[label] in feature

        def in_second(row):
            return row[label] not in feature

    group1 = [row for row in orgGroup if in_first(row)]
    group2 = [row for row in orgGroup if in_second(row)]

    charges1 = [row['charges'] for row in group1]
    charges2 = [row['charges'] for row in group2]
    return group1, group2, var(charges1), var(charges2)
def _candidate_features(orgGroup, label):
    """Return the split values to try for ``label``, in evaluation order."""
    tag = is_Discrete(label)
    if tag == 0:
        # Continuous column: every distinct value except the largest is a
        # threshold (splitting at the largest would leave the right side empty).
        return sorted({x[label] for x in orgGroup})[:-1]
    if tag == 2:
        # Two-valued column: a single 0-vs-1 split, signalled by None.
        return [None]
    # Remaining columns: the fixed value subsets tried for the left side.
    return [(0, 1), (0, 2), (0, 3)]


def chooseBestSplit(orgGroup, tolS=1, tolN=4):
    """Find the split of ``orgGroup`` with the lowest summed squared error.

    Every candidate produced for the columns 'age', 'sex', 'bmi',
    'children', 'smoker' and 'region' is evaluated with ``group()``; the
    candidate with the smallest ``err1 + err2`` wins (first one wins ties).

    :param orgGroup: non-empty list of record dicts containing 'charges'
    :param tolS: do not split when the error reduction achieved by the best
        split is smaller than ``tolS``
    :param tolN: do not consider a split when either resulting group has
        fewer than ``tolN`` elements
    :return: ``(orgGroup, None, mean)`` when the node should become a leaf
        (``mean`` is the mean of 'charges'), otherwise
        ``([group1, group2], label, feature_value)``
    :raises ValueError: if ``orgGroup`` is empty
    """
    if len(orgGroup) == 0:
        # Without this guard the NaN arithmetic below falls through to the
        # "split" return with two empty groups, which a recursive tree
        # builder would then try to split again forever.
        raise ValueError("chooseBestSplit() requires a non-empty group")

    y_data = [x['charges'] for x in orgGroup]
    if np.unique(y_data).size == 1:  # all targets identical: nothing to gain
        return orgGroup, None, mean_Y(y_data)

    TosErr = var(y_data)
    bestErr = np.inf
    bestLabel = 'age'
    bestFeaVal = 0
    bestGroup1 = []
    bestGroup2 = []
    for label in ['age', 'sex', 'bmi', 'children', 'smoker', 'region']:
        for value in _candidate_features(orgGroup, label):
            group1, group2, err1, err2 = group(orgGroup, label, value)
            if len(group1) < tolN or len(group2) < tolN:
                continue
            if bestErr > err1 + err2:
                bestLabel = label
                bestFeaVal = value
                bestErr = err1 + err2
                bestGroup1 = group1
                bestGroup2 = group2

    # Also covers "no admissible split found": bestErr is still inf, so the
    # difference is -inf and the node becomes a leaf.
    if (TosErr - bestErr) < tolS:
        return orgGroup, None, mean_Y(y_data)
    return [bestGroup1, bestGroup2], bestLabel, bestFeaVal
def createTree(group, stopN=4, stopErr=1):
    """Recursively build a regression tree for ``group``.

    :param group: list of record dicts to fit
    :param stopN: minimum group size, forwarded to ``chooseBestSplit`` as
        ``tolN``
    :param stopErr: minimum error reduction, forwarded to
        ``chooseBestSplit`` as ``tolS``
    :return: the leaf value returned by ``chooseBestSplit`` when no split is
        made, otherwise a dict with keys 'label', 'feature', 'left', 'right'
        whose 'left'/'right' entries are sub-trees built the same way
    """
    ngroup, label, feature = chooseBestSplit(group, stopErr, stopN)
    if label is None:
        # Leaf: chooseBestSplit puts the prediction in the third slot.
        return feature

    left, right = ngroup
    # Pass the thresholds by keyword. The positional form
    # ``createTree(left, stopErr, stopN)`` bound them to the wrong
    # parameters, swapping the two limits at every level of the tree.
    return {
        'label': label,
        'feature': feature,
        'left': createTree(left, stopN=stopN, stopErr=stopErr),
        'right': createTree(right, stopN=stopN, stopErr=stopErr),
    }
def forward(file_path, data_path, err, n, n_fold, fold):
    """Train one tree per data subset and save the forest as JSON.

    :param file_path: path of the JSON file to write the list of trees to
    :param data_path: path handed to ``Dataloader``
    :param err: minimum error reduction required for a split (``stopErr``)
    :param n: minimum number of records per split side (``stopN``)
    :param n_fold: passed to ``Dataloader`` as ``fold``
    :param fold: passed to ``Dataloader`` as ``bagging``
    """
    dl = Dataloader(data_path, bagging=fold, fold=n_fold)

    # Keywords make sure ``err`` reaches stopErr and ``n`` reaches stopN;
    # createTree's positional order is (stopN, stopErr).
    forest = [createTree(orgGroup, stopN=n, stopErr=err) for orgGroup in dl.data]

    # Open the output only once training has succeeded, so a failure does not
    # leave a truncated file behind, and close it even if dumping fails.
    with open(file_path, 'w') as json_f:
        json.dump(forest, json_f, indent=4)
def predict(item, tree_dic):
    """Return the leaf value that ``tree_dic`` assigns to ``item``.

    Each internal node is a dict with keys 'label', 'feature', 'left' and
    'right'. The node's 'feature' selects the routing rule:

    * ``None``: go left when ``item[label] == 0``, otherwise right.
    * list or tuple: go left when ``item[label]`` is contained in it.
    * anything else: go left when ``item[label] <= feature``.

    :param item: record dict holding a value for every label used in the tree
    :param tree_dic: a tree node dict, or a leaf value
    :return: the first non-dict value reached while walking the tree
    """
    node = tree_dic
    while isinstance(node, dict):
        label = node['label']
        feature = node['feature']
        if feature is None:
            go_left = item[label] == 0
        elif isinstance(feature, (list, tuple)):
            # A subset is a tuple in a freshly built tree and a list after a
            # JSON round-trip; both must be treated as membership tests.
            go_left = item[label] in feature
        else:
            go_left = item[label] <= feature
        node = node['left'] if go_left else node['right']
    return node
def test(tree_dict_path, test_file_path, has_y=True):
    """Predict 'charges' for every record of a data file with a saved forest.

    The prediction for a record is the average of ``predict()`` over all
    trees stored in the JSON file.

    :param tree_dict_path: path of the JSON file holding the list of trees
    :param test_file_path: path handed to ``Dataloader``
    :param has_y: when true, each record's 'charges' is read as the ground
        truth and the R^2 score of the predictions is printed
    :return: ``(data, predict_y)``: the loaded records and their predictions
    """
    # Read the forest inside a context manager so the handle is released.
    with open(tree_dict_path, 'r') as json_r:
        forest = json.load(json_r)
    n_fold = len(forest)

    dl = Dataloader(test_file_path)
    data = dl.data

    true_y = []
    predict_y = []
    for item in data:
        # ``total`` rather than ``sum`` to keep the builtin usable.
        total = 0
        for tree_dic in forest:
            total += predict(item, tree_dic)
        if has_y:
            true_y.append(item['charges'])
        predict_y.append(total / n_fold)

    if has_y:
        y_true = np.array(true_y)
        y_pred = np.array(predict_y)
        # Coefficient of determination: 1 - SS_residual / SS_total.
        r2 = 1 - np.sum((y_true - y_pred) ** 2) / np.sum((y_true - np.mean(y_true)) ** 2)
        print('r_score:{}'.format(r2))
    return data, predict_y
def write_answer(path, test_data, predict_y):
    """Write records and their predicted charges to a CSV file.

    The numeric codes stored under 'sex', 'smoker' and 'region' are mapped
    back to their text labels; the prediction fills the 'charges' column.

    :param path: destination CSV path (written as UTF-8)
    :param test_data: iterable of record dicts with keys 'age', 'sex',
        'bmi', 'children', 'smoker' and 'region'
    :param predict_y: iterable of predictions, paired with ``test_data`` by
        position
    :raises KeyError: if a record holds a code missing from the mappings
    """
    # Imported here because this file has no explicit ``import csv``; the
    # name would otherwise only exist if the star import happened to leak it.
    import csv

    sex_dic = {0: 'male', 1: 'female'}
    smoker_dic = {0: 'no', 1: 'yes'}
    region_dic = {0: 'northwest', 1: 'northeast', 2: 'southwest', 3: 'southeast'}

    # ``with`` closes the file even when a row fails to serialize.
    with open(path, 'w', newline="", encoding='utf-8') as fw:
        csv_write = csv.writer(fw)
        csv_write.writerow(['age', 'sex', 'bmi', 'children', 'smoker', 'region', 'charges'])
        for item, y in zip(test_data, predict_y):
            csv_write.writerow([
                item['age'],
                sex_dic[item['sex']],
                item['bmi'],
                item['children'],
                smoker_dic[item['smoker']],
                region_dic[item['region']],
                y,
            ])
if __name__ == "__main__":
    n_fold = 10
    fold = True

    # Alternative invocations kept for reference:
    #   train a forest
    #     forward('tree_0.5_10_all.json', 'train.csv', err=0.5, n=5,n_fold=n_fold,fold=fold)
    #   score a forest on labelled data
    #     test('tree_0.5_10_p.json','s_test.csv')
    #   predict with another forest / data set
    #     data, predict_y = test('tree_0.1_1_all.json','public_dataset/test_sample.csv',False)

    # Predict the unlabelled sample set and write the submission file.
    data, predict_y = test('tree_0.5_10_p.json', 'test_sample.csv', has_y=False)
    write_answer('submission.csv', data, predict_y)
没有合适的资源?快使用搜索试试~ 我知道了~
BUAA机器学习作业医疗花费预测.zip
共18个文件
csv:9个
xml:3个
py:3个
需积分: 5 0 下载量 64 浏览量
2024-04-16
22:22:04
上传
评论
收藏 59KB ZIP 举报
温馨提示
众所周知,人工智能是当前最热门的话题之一, 计算机技术与互联网技术的快速发展更是将对人工智能的研究推向一个新的高潮。 人工智能是研究模拟和扩展人类智能的理论与方法及其应用的一门新兴技术科学。 作为人工智能核心研究领域之一的机器学习, 其研究动机是为了使计算机系统具有人的学习能力以实现人工智能。 那么, 什么是机器学习呢? 机器学习 (Machine Learning) 是对研究问题进行模型假设,利用计算机从训练数据中学习得到模型参数,并最终对数据进行预测和分析的一门学科。 机器学习的用途 机器学习是一种通用的数据处理技术,其包含了大量的学习算法。不同的学习算法在不同的行业及应用中能够表现出不同的性能和优势。目前,机器学习已成功地应用于下列领域: 互联网领域----语音识别、搜索引擎、语言翻译、垃圾邮件过滤、自然语言处理等 生物领域----基因序列分析、DNA 序列预测、蛋白质结构预测等 自动化领域----人脸识别、无人驾驶技术、图像处理、信号处理等 金融领域----证券市场分析、信用卡欺诈检测等 医学领域----疾病鉴别/诊断、流行病爆发预测等 刑侦领域----潜在犯罪识别与预测、模拟人工智能侦探等 新闻领域----新闻推荐系统等 游戏领域----游戏战略规划等 从上述所列举的应用可知,机器学习正在成为各行各业都会经常使用到的分析工具,尤其是在各领域数据量爆炸的今天,各行业都希望通过数据处理与分析手段,得到数据中有价值的信息,以便明确客户的需求和指引企业的发展。
资源推荐
资源详情
资源评论
收起资源包目录
BUAA机器学习作业医疗花费预测.zip (18个子文件)
content
module.py 7KB
s_test.csv 4KB
.idea
Medical.iml 284B
misc.xml 185B
inspectionProfiles
profiles_settings.xml 174B
modules.xml 266B
.gitignore 220B
dataloader.py 2KB
old_1_1.csv 11KB
DecisionTree.py 6KB
train.csv 43KB
s_train.csv 38KB
__pycache__
dataloader.cpython-38.pyc 3KB
submission.csv 13KB
test_sample.csv 9KB
public_dataset
train.csv 43KB
test_sample.csv 9KB
test.csv 3KB
共 18 条
- 1
资源评论
生瓜蛋子
- 粉丝: 3809
- 资源: 4660
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功