import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;
public class ID3 {
//private int NT;
private ArrayList<String> attribute = new ArrayList<String>(); // attribute names, in the order readARFF encounters them
private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // possible values of each attribute; same ordering as `attribute`
private ArrayList<String[]> data = new ArrayList<String[]>();; // raw data rows: each line after "@data" in the ARFF file, split on ","
int decatt; // index of the decision (class) attribute within `attribute`; set by setDec
public static final String patternString = "@attribute(.*)[{](.*?)[}]";
// Regex for an "@attribute name {v1,v2,...}" line: group 1 is the name, group 2 the
// value list. ".*?" is a reluctant quantifier (matches as little as possible), so the
// value list ends at the first "}" instead of running on to a later one.
Document xmldoc; // dom4j document created in the constructor
Element root; // the "root" element of xmldoc; buildDT looks up tree nodes beneath it
public ID3() {
xmldoc = DocumentHelper.createDocument();
root = xmldoc.addElement("root");
root.addElement("DecisionTree").addAttribute("value", "null");
}
/**
 * Command-line entry point: reads an ARFF file, builds the decision tree and
 * writes it out as XML.
 *
 * <p>Usage: {@code java ID3 [arffFile [outputXml [decisionAttribute]]]}.
 * Any argument that is omitted falls back to the value that used to be
 * hard-coded here, so running without arguments behaves as before.
 *
 * @param args optional input path, output path and decision attribute name
 */
public static void main(String[] args) {
    int argCount = (args == null) ? 0 : args.length;
    String arffPath = argCount > 0 ? args[0] : "E:/MyeclipseWorkspace/ID3/src/weather.nominal.arff";
    String xmlPath = argCount > 1 ? args[1] : "E:/MyeclipseWorkspace/ID3/src/dt.xml";
    String decisionName = argCount > 2 ? args[2] : "play";

    ID3 inst = new ID3();
    inst.readARFF(new File(arffPath));
    // Select the decision (class) attribute by name; setDec exits if it is unknown.
    inst.setDec(decisionName);

    // Candidate split attributes: every attribute index except the decision attribute.
    LinkedList<Integer> ll = new LinkedList<Integer>();
    for (int i = 0; i < inst.attribute.size(); i++) {
        if (i != inst.decatt) {
            ll.add(i);
        }
    }

    // The initial subset is every row of the data set.
    ArrayList<Integer> al = new ArrayList<Integer>();
    for (int i = 0; i < inst.data.size(); i++) {
        al.add(i);
    }

    inst.buildDT("DecisionTree", "null", al, ll);
    inst.writeXML(xmlPath);
}
/**
 * Reads an ARFF file and fills {@code attribute}, {@code attributevalue} and
 * {@code data}.
 *
 * <p>Header lines matching {@link #patternString} contribute one attribute name
 * and its list of values. Every line after the {@code @data} marker is split on
 * commas into one data row; blank lines and {@code %} comment lines in that
 * section are skipped. All names and values are trimmed, because even a stray
 * space or line terminator would make later string comparisons fail.
 *
 * <p>An {@link IOException} is printed to standard error rather than
 * propagated; whatever was parsed before the failure is kept.
 *
 * @param file the ARFF file to read
 */
public void readARFF(File file) {
    BufferedReader br = null;
    try {
        br = new BufferedReader(new FileReader(file));
        Pattern pattern = Pattern.compile(patternString);
        boolean inData = false;
        String line;
        while ((line = br.readLine()) != null) {
            if (inData) {
                String trimmed = line.trim();
                // Compare by content: the old `line == ""` reference check never matched.
                if (trimmed.length() == 0 || trimmed.startsWith("%")) {
                    continue;
                }
                String[] row = trimmed.split(",");
                for (int i = 0; i < row.length; i++) {
                    row[i] = row[i].trim();
                }
                data.add(row);
                continue;
            }
            Matcher matcher = pattern.matcher(line);
            if (matcher.find()) {
                // Group 1 is the attribute name, group 2 the comma-separated value list.
                attribute.add(matcher.group(1).trim());
                String[] values = matcher.group(2).split(",");
                ArrayList<String> al = new ArrayList<String>(values.length);
                for (String value : values) {
                    al.add(value.trim());
                }
                attributevalue.add(al);
            } else if (line.startsWith("@data")) {
                // Everything from here to the end of the file is data.
                inData = true;
            }
        }
    } catch (IOException e1) {
        e1.printStackTrace();
    } finally {
        // Close the reader even when reading failed part-way through.
        if (br != null) {
            try {
                br.close();
            } catch (IOException closeError) {
                closeError.printStackTrace();
            }
        }
    }
}
/**
 * Selects the decision (class) attribute by its index in {@code attribute}.
 * Prints an error and terminates the JVM with status 2 when the index is out
 * of range.
 *
 * @param n index of the decision attribute
 */
public void setDec(int n) {
    boolean inRange = 0 <= n && n < attribute.size();
    if (!inRange) {
        System.err.println("决策变量指定错误。");
        System.exit(2);
    }
    decatt = n;
}
/**
 * Selects the decision (class) attribute by name. The name is resolved to its
 * index in {@code attribute}; an unknown name resolves to -1, which
 * {@link #setDec(int)} rejects.
 *
 * @param name name of the decision attribute
 */
public void setDec(String name) {
    setDec(attribute.indexOf(name));
}
/**
 * Computes the Shannon entropy, in bits, of a sample described by per-class
 * counts.
 *
 * <p>Uses the identity
 * {@code H = (N*log2(N) - sum_i(c_i*log2(c_i))) / N} with {@code N = sum_i(c_i)},
 * which equals {@code -sum_i(p_i*log2(p_i))} for {@code p_i = c_i/N}.
 *
 * @param arr occurrence count of each class; zero entries contribute nothing
 * @return the entropy in bits, or 0.0 when the sample is empty (all counts
 *         zero or {@code arr} has no elements) instead of the NaN that a
 *         division by zero would produce
 */
public double getEntropy(int[] arr) {
    final double ln2 = Math.log(2);
    double entropy = 0.0;
    int sum = 0;
    for (int i = 0; i < arr.length; i++) {
        // By convention 0*log(0) is 0, so zero counts are skipped.
        if (arr[i] > 0) {
            entropy -= arr[i] * Math.log(arr[i]) / ln2;
        }
        sum += arr[i];
    }
    if (sum == 0) {
        // An empty sample carries no uncertainty; avoid 0/0 = NaN.
        return 0.0;
    }
    entropy += sum * Math.log(sum) / ln2;
    entropy /= sum;
    return entropy;
}
/**
 * Computes the Shannon entropy, in bits, of a sample described by per-class
 * counts when the total is already known.
 *
 * <p>Uses the identity
 * {@code H = (N*log2(N) - sum_i(c_i*log2(c_i))) / N}.
 *
 * @param arr occurrence count of each class; zero entries contribute nothing
 * @param sum the total {@code N}; it is not checked against the entries of
 *            {@code arr}, so the caller must pass a matching value
 * @return the entropy in bits, or 0.0 when {@code sum} is 0 instead of the NaN
 *         that a division by zero would produce
 */
public double getEntropy(int[] arr, int sum) {
    if (sum == 0) {
        // An empty sample carries no uncertainty; avoid 0/0 = NaN.
        return 0.0;
    }
    final double ln2 = Math.log(2);
    double entropy = 0.0;
    for (int i = 0; i < arr.length; i++) {
        // By convention 0*log(0) is 0, so zero counts are skipped.
        if (arr[i] > 0) {
            entropy -= arr[i] * Math.log(arr[i]) / ln2;
        }
    }
    entropy += sum * Math.log(sum) / ln2;
    entropy /= sum;
    return entropy;
}
/**
 * Tells whether every row of the subset has the same decision value, in which
 * case the subset becomes a leaf of the tree.
 *
 * <p>Values are compared by content after trimming ({@code equals}, not
 * {@code ==}). Each inspected value and a final success marker are printed to
 * standard output.
 *
 * @param subset row indices into {@code data}; must contain at least one index
 * @return {@code true} when all rows share one decision value
 */
public boolean infoPure(ArrayList<Integer> subset) {
    final String reference = data.get(subset.get(0))[decatt];
    int pos = 1;
    while (pos < subset.size()) {
        final String candidate = data.get(subset.get(pos))[decatt];
        System.out.println("next=" + candidate);
        boolean sameLabel = reference.trim().equals(candidate.trim());
        if (!sameLabel) {
            return false;
        }
        pos++;
    }
    System.out.println("返回true");
    return true;
}
/**
 * Computes the expected entropy of the decision attribute after splitting the
 * given rows on attribute {@code index}: the entropy of each partition,
 * weighted by the fraction of rows that fall into it.
 *
 * <p>Both the split value and the decision value of each row are trimmed
 * before being looked up. Partitions that receive no rows are skipped, so an
 * attribute value absent from the subset cannot turn the result into NaN.
 * The per-partition debug lines are still printed to standard output.
 *
 * @param subset row indices into {@code data}
 * @param index index of the attribute to split on
 * @return the weighted entropy in bits; 0.0 for an empty subset
 * @throws IllegalStateException if a row holds a value that is not declared
 *         for the split attribute or for the decision attribute
 */
public double calNodeEntropy(ArrayList<Integer> subset, int index) {
    ArrayList<String> nodeValues = attributevalue.get(index);
    ArrayList<String> decValues = attributevalue.get(decatt);
    int sum = subset.size();

    // info[v][d]: rows with split value v and decision value d; count[v]: rows with split value v.
    int[][] info = new int[nodeValues.size()][decValues.size()];
    int[] count = new int[nodeValues.size()];
    for (int i = 0; i < sum; i++) {
        int n = subset.get(i);
        String[] row = data.get(n);
        int nodeind = nodeValues.indexOf(row[index].trim());
        int decind = decValues.indexOf(row[decatt].trim());
        if (nodeind < 0 || decind < 0) {
            throw new IllegalStateException("Row " + n + " holds an undeclared value: '"
                    + row[index] + "' for attribute " + index + " or '"
                    + row[decatt] + "' for decision attribute " + decatt);
        }
        count[nodeind]++;
        info[nodeind][decind]++;
    }

    double entropy = 0.0;
    for (int i = 0; i < info.length; i++) {
        double partEntropy = getEntropy(info[i]);
        // NOTE(review): debug tracing kept so existing console output is unchanged.
        System.out.println("info.len="+info.length);
        System.out.println("N+entropy="+entropy);
        System.out.println("getinfo[i]="+partEntropy);
        // An empty partition has weight 0; skip it so a NaN entropy cannot propagate.
        if (count[i] > 0) {
            entropy += partEntropy * count[i] / sum;
        }
    }
    return entropy;
}
// 构建决策树 (主要函数)
public void buildDT(String name, String value, ArrayList<Integer> subset,
LinkedList<Integer> selatt) {
//NT+=1;
//System.out.println("现在是第"+NT+"决策树");
//System.out.println("subset="+subset);
//System.out.println("selat="+selatt);
Element ele = null;
@SuppressWarnings("unchecked")
List<Element> list = root.selectNodes("//"+name);
Iterator<Element> iter=list.iterator();
while(iter.hasNext()){
ele=iter.next();
if(ele.attributeValue("value").equals(value))
break;
}
if (infoPure(subset)) {
//System.out.println("结果唯一");
ele.setText(data.get(subset.get(0))[decatt]);
return;
}
int minIndex = -1;
double minEntropy = Double.MAX_VALUE;
for (int i = 0; i < selatt.size(); i++) {
if (i == decatt)
continue;
double entropy = calNode
评论1
最新资源