# -*- coding: utf-8 -*-
"""
Created on Wed Dec 13 10:27:08 2017
@author: ChenYing
"""
import csv
import sys

import matplotlib.pyplot as plt
import numpy as np
from sklearn import tree

import trees
import treePlotter
def translate(filename):
    """Discretise the continuous columns of an adult-census CSV file.

    The first row of ``filename`` is treated as the header and is copied
    through unchanged.  Every following non-blank row must have at least 13
    columns and is rewritten as a 13-element list:

    * column 0 (age)           -> bucket code 0-3 (25-year wide buckets)
    * columns 1-7              -> copied verbatim (strings)
    * column 8 (capital gain)  -> '>0' or '=0'
    * column 9 (capital loss)  -> '>0' or '=0'
    * column 10 (hours/week)   -> '=40', '>40' or '<40'
    * column 11 (country)      -> 'USA' for 'United-States', else 'not USA'
    * column 12 (income)       -> '<=50K' if exactly that, else '>50K'

    Args:
        filename: path of the CSV file to read.

    Returns:
        A list of rows: the header followed by the converted data rows.

    Raises:
        ValueError: if an age, gain, loss or hours cell is not an integer.
        IndexError: if a non-blank data row has fewer than 13 columns.
    """
    age = {'0-25': 0, '25-50': 1, '50-75': 2, '75-100': 3}

    # The csv module wants a binary file on Python 2 and a text file opened
    # with newline='' on Python 3.
    if sys.version_info[0] >= 3:
        csvfile = open(filename, 'r', newline='')
    else:
        csvfile = open(filename, 'rb')
    # ``with`` guarantees the handle is closed even if parsing raises.
    with csvfile:
        data = list(csv.reader(csvfile))

    new_data = []
    for index, dataline in enumerate(data):
        if index == 0:
            # Header row: passed through untouched.
            new_data.append(dataline)
            continue
        if not dataline:
            # csv.reader yields [] for a blank line; there is nothing to
            # convert, so skip it instead of failing on dataline[0].
            continue

        agenum = int(dataline[0])
        if agenum >= 75:
            age_code = age['75-100']
        elif agenum >= 50:
            age_code = age['50-75']
        elif agenum >= 25:
            age_code = age['25-50']
        else:
            # Includes negative ages, which previously also ended up as 0.
            age_code = age['0-25']

        x = [age_code]
        x.extend(dataline[1:8])

        x.append('>0' if int(dataline[8]) > 0 else '=0')
        x.append('>0' if int(dataline[9]) > 0 else '=0')

        hour = int(dataline[10])
        if hour == 40:
            x.append('=40')
        elif hour > 40:
            x.append('>40')
        else:
            x.append('<40')

        x.append('USA' if dataline[11] == 'United-States' else 'not USA')
        # Anything that is not exactly '<=50K' is treated as the high bracket.
        x.append('<=50K' if dataline[12] == '<=50K' else '>50K')

        new_data.append(x)
    return new_data
def translateToValue(filename):  # convert the data set to numeric form
    """Encode every column of an adult-census CSV file as a small integer.

    The first row of ``filename`` is treated as the header and is copied
    through unchanged.  Every following non-blank row must have at least 13
    columns and becomes a list of 13 integer codes: age is bucketed into
    25-year bands, columns 1-7 are looked up in the categorical tables
    below, capital gain/loss are reduced to zero / non-zero, hours per week
    is compared with 40, the country is reduced to USA / not USA and the
    income to '<=50K' / anything else.

    Args:
        filename: path of the CSV file to read.

    Returns:
        A list of rows: the header followed by the encoded data rows.

    Raises:
        ValueError: if an age, gain, loss or hours cell is not an integer.
        KeyError: if a categorical cell holds a value missing from its table.
        IndexError: if a non-blank data row has fewer than 13 columns.
    """
    age = {'0-25': 0, '25-50': 1, '50-75': 2, '75-100': 3}
    capital_gain = {'=0': 0, '>0': 1}
    capital_loss = {'=0': 0, '>0': 1}
    hours_per_week = {'=40': 0, '>40': 1, '<40': 2}
    native_country = {'USA': 0, 'not USA': 1}
    workclass = {'Freelance': 1, 'Other': 3, 'Proprietor': 4, 'Private': 2, 'Government': 0}
    education = {'Primary': 2, 'Tertiary': 0, 'Secondary': 1}
    maritial_status = {'1': 1, '0': 0}
    occupation = {'High': 1, 'Med': 2, 'Low': 0}
    relationship = {'Other': 0, 'Husband': 1, 'Wife': 2}
    race = {'1': 0, '0': 1}
    sex = {'Male': 0, 'Female': 1}
    income = {'<=50K': 0, '>50K': 1}

    # Lookup tables for the categorical columns 1-7, in column order.
    categorical = [workclass, education, maritial_status, occupation,
                   relationship, race, sex]

    # The csv module wants a binary file on Python 2 and a text file opened
    # with newline='' on Python 3.
    if sys.version_info[0] >= 3:
        csvfile = open(filename, 'r', newline='')
    else:
        csvfile = open(filename, 'rb')
    # ``with`` guarantees the handle is closed even if parsing raises.
    with csvfile:
        data = list(csv.reader(csvfile))

    new_data = []
    for index, dataline in enumerate(data):
        if index == 0:
            # Header row: passed through untouched.
            new_data.append(dataline)
            continue
        if not dataline:
            # csv.reader yields [] for a blank line; skip it instead of
            # failing on dataline[0].
            continue

        agenum = int(dataline[0])
        if agenum >= 75:
            age_key = '75-100'
        elif agenum >= 50:
            age_key = '50-75'
        elif agenum >= 25:
            age_key = '25-50'
        else:
            # Includes negative ages, which previously also ended up as 0.
            age_key = '0-25'
        x = [age[age_key]]

        for column, table in enumerate(categorical, 1):
            x.append(table[dataline[column]])

        x.append(capital_gain['>0' if int(dataline[8]) > 0 else '=0'])
        x.append(capital_loss['>0' if int(dataline[9]) > 0 else '=0'])

        hour = int(dataline[10])
        if hour == 40:
            x.append(hours_per_week['=40'])
        elif hour > 40:
            x.append(hours_per_week['>40'])
        else:
            x.append(hours_per_week['<40'])

        if dataline[11] == 'United-States':
            x.append(native_country['USA'])
        else:
            x.append(native_country['not USA'])

        # Anything that is not exactly '<=50K' is treated as the high bracket.
        if dataline[12] == '<=50K':
            x.append(income['<=50K'])
        else:
            x.append(income['>50K'])

        new_data.append(x)
    return new_data
def write_new_data():
    """Convert the train/test CSV files and write the four derived files.

    Reads ``adult_data_all.csv`` and ``adult_test_all.csv`` from the current
    directory and writes, in this order:

    * ``./new_data_value.csv`` / ``./new_data_value_test.csv`` -- the
      integer-coded form produced by ``translateToValue``;
    * ``./new_data.csv`` / ``./new_data_test.csv`` -- the discretised form
      produced by ``translate``.

    Existing output files are overwritten.  Returns nothing.
    """
    # Per the original author's note: adult_data_all is the raw data set with
    # some attributes already merged or modified.
    jobs = [
        (translateToValue, 'adult_data_all.csv', './new_data_value.csv'),
        (translateToValue, 'adult_test_all.csv', './new_data_value_test.csv'),
        (translate, 'adult_data_all.csv', './new_data.csv'),
        (translate, 'adult_test_all.csv', './new_data_test.csv'),
    ]
    for convert, source, target in jobs:
        rows = convert(source)
        # The csv module wants a binary file on Python 2 and a text file
        # opened with newline='' on Python 3.
        if sys.version_info[0] >= 3:
            out = open(target, 'w', newline='')
        else:
            out = open(target, 'wb')
        # The with-block closes the file; no explicit close() is needed.
        with out:
            csv.writer(out).writerows(rows)
def readData(filename):
    """Load a CSV file and split its rows into features and labels.

    The first row is treated as a header: it is not returned, and its
    length minus one fixes how many leading columns count as features.
    Blank lines after the header are ignored.  All cells are returned as
    strings, exactly as read.

    Args:
        filename: path of the CSV file to read.

    Returns:
        A tuple ``(data_all, data_feature, data_label)`` of parallel lists:
        the complete data rows, the feature columns of each row, and the
        last cell of each row.
    """
    data_all = []      # complete data rows
    data_feature = []  # feature columns
    data_label = []    # label column
    feature_len = 0

    # The csv module wants a binary file on Python 2 and a text file opened
    # with newline='' on Python 3.
    if sys.version_info[0] >= 3:
        csvfile = open(filename, 'r', newline='')
    else:
        csvfile = open(filename, 'rb')
    # ``with`` guarantees the handle is closed even if parsing raises.
    with csvfile:
        for index, line in enumerate(csv.reader(csvfile)):
            if index == 0:
                # Header: every column except the last one is a feature.
                feature_len = len(line) - 1
                continue
            if not line:
                # csv.reader yields [] for a blank line; it has no label, so
                # skip it instead of failing on line[-1].
                continue
            data_all.append(line)
            data_feature.append(line[0:feature_len])
            data_label.append(line[-1])
    return data_all, data_feature, data_label
# Train and evaluate scikit-learn's decision tree classifier.
def use_sklearn_tree():
    """Fit sklearn's DecisionTreeClassifier and print its test accuracy.

    Trains on ``new_data_value.csv`` and evaluates on
    ``new_data_value_test.csv`` (both loaded with ``readData``), using
    ``max_depth=8`` and ``min_samples_split=9``.  The accuracy is the share
    of test rows whose predicted label equals the label read from the file.

    Returns:
        The ``tree_`` attribute of the fitted classifier.
    """
    _, trainX, trainY = readData('new_data_value.csv')
    _, testX, testY = readData('new_data_value_test.csv')

    # Hyper-parameters go through the constructor rather than being patched
    # onto the instance afterwards.
    model = tree.DecisionTreeClassifier(max_depth=8, min_samples_split=9)
    model.fit(trainX, trainY)
    predict = model.predict(testX)

    total = len(predict)
    accuratyNum = sum(1 for guess, truth in zip(predict, testY)
                      if guess == truth)

    # Single-argument print(...) produces the same output on Python 2 and 3.
    print("when use the sklearn............")
    print("the accuraty is")
    accuracy = float(accuratyNum) / total
    print('accuracy: %.2f%%' % (100 * accuracy))
    return model.tree_
def use_myTree():
    """Build a decision tree with the project's ``trees`` module and plot it.

    Loads the rows of ``new_data.csv`` via ``readData``, passes them together
    with the feature names to ``trees.createTree`` and hands the result to
    ``treePlotter.createPlot``.
    """
    # Names of the 12 feature columns, passed to trees.createTree below.
    adultLabels = ['age','workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country']
    # NOTE(review): identical copy of adultLabels; not used in the lines
    # visible here -- presumably meant for a later test step, confirm.
    adultLabels_test = ['age','workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country']
    # Element [0] of readData's result: the complete data rows.
    adult = readData('new_data.csv')[0]
    # NOTE(review): the test file is read twice (once for the feature
    # columns, once for the labels) and neither result is used in the lines
    # visible here -- presumably consumed further down; confirm.
    adult_test = readData('new_data_test.csv')[1]
    adult_test_label = readData('new_data_test.csv')[2]
    adultTree = trees.createTree(adult,adultLabels) # build the decision tree
    treePlotter.createPlot(adultTree) # plot the decision tree