package test;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;
public class ID3
{
// Attribute names, in the order readARFF encounters their "@attribute" lines.
private ArrayList<String> attribute = new ArrayList<String>();
private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // Possible values of each attribute; entry i belongs to attribute.get(i)
// Raw data rows; each row is one line of the "@data" section split on commas.
private ArrayList<String[]> data = new ArrayList<String[]>();
// Index of the decision (target) attribute within the attribute list; assigned by setDec.
int decatt;
// Regex for an attribute declaration line: group 1 is the attribute name,
// group 2 is the comma-separated value list found between the braces.
public static final String patternString = "@attribute(.*)[{](.*?)[}]";
// XML document that receives the decision tree (created in the constructor).
Document xmldoc;
// Root element ("root") of xmldoc.
Element root;
public ID3()
{
xmldoc = DocumentHelper.createDocument();
root = xmldoc.addElement("root");
root.addElement("DecisionTree").addAttribute("value", "null");
}
/**
 * Command-line entry point: reads an ARFF file, builds the decision tree
 * and writes it out as XML.
 *
 * All arguments are optional; a missing argument falls back to the value
 * that used to be hard-coded:
 * <ul>
 * <li>{@code args[0]} - input ARFF path (default {@code "F:/app.arff"})</li>
 * <li>{@code args[1]} - output XML path (default {@code "F:/dt.xml"})</li>
 * <li>{@code args[2]} - name of the decision attribute (default {@code "precision"})</li>
 * </ul>
 *
 * @param args optional overrides as described above
 */
public static void main(String[] args)
{
    String arffPath = (args != null && args.length > 0) ? args[0] : "F:/app.arff";
    String xmlPath = (args != null && args.length > 1) ? args[1] : "F:/dt.xml";
    String decisionName = (args != null && args.length > 2) ? args[2] : "precision";

    ID3 inst = new ID3();
    inst.readARFF(new File(arffPath));
    inst.setDec(decisionName);

    // Candidate attributes for splitting: every attribute except the decision one.
    LinkedList<Integer> ll = new LinkedList<Integer>();
    for (int i = 0; i < inst.attribute.size(); i++)
    {
        if (i != inst.decatt)
        {
            ll.add(i);
        }
    }

    // The initial subset contains every data row.
    ArrayList<Integer> al = new ArrayList<Integer>();
    for (int i = 0; i < inst.data.size(); i++)
    {
        al.add(i);
    }

    inst.buildDT("DecisionTree", "null", al, ll);
    inst.writeXML(xmlPath);
}
/**
 * Reads an ARFF file and fills {@code attribute}, {@code attributevalue}
 * and {@code data}, then prints what was loaded to standard output.
 *
 * Lines matching {@code patternString} declare an attribute and its values.
 * Everything after a line starting with "@data" is treated as data: blank
 * lines and ARFF comment lines (starting with '%') are skipped, every other
 * line is split on commas and each cell is trimmed so that it compares equal
 * to the (also trimmed) declared attribute values.
 *
 * An {@link IOException} is reported with {@code printStackTrace()} and not
 * rethrown; the reader is closed in every case.
 *
 * @param file the ARFF file to read
 */
public void readARFF(File file)
{
    BufferedReader br = null;
    try
    {
        br = new BufferedReader(new FileReader(file));
        Pattern pattern = Pattern.compile(patternString);
        String line;
        while ((line = br.readLine()) != null)
        {
            Matcher matcher = pattern.matcher(line);
            if (matcher.find())
            {
                attribute.add(matcher.group(1).trim());
                String[] values = matcher.group(2).split(",");
                // All values this attribute may take.
                ArrayList<String> al = new ArrayList<String>(values.length);
                for (String value : values)
                {
                    al.add(value.trim());
                }
                attributevalue.add(al);
            }
            else if (line.startsWith("@data"))
            {
                // The data section runs to the end of the file.
                while ((line = br.readLine()) != null)
                {
                    String trimmed = line.trim();
                    // Compare by content (the old "line == \"\"" reference check never matched).
                    if (trimmed.length() == 0 || trimmed.startsWith("%"))
                    {
                        continue;
                    }
                    String[] row = trimmed.split(",");
                    for (int i = 0; i < row.length; i++)
                    {
                        row[i] = row[i].trim();
                    }
                    data.add(row);
                }
            }
        }
        System.out.println("attribute:");
        System.out.println(attribute);
        System.out.println("attributevalue");
        System.out.println(attributevalue);
        System.out.println("data:");
        for (String[] strs : data)
        {
            for (String s : strs)
            {
                System.out.print(s + ",");
            }
            System.out.println();
        }
    }
    catch (IOException e1)
    {
        e1.printStackTrace();
    }
    finally
    {
        // Release the file handle even when reading failed part-way.
        if (br != null)
        {
            try
            {
                br.close();
            }
            catch (IOException closeError)
            {
                closeError.printStackTrace();
            }
        }
    }
}
/**
 * Selects the decision attribute by its index in the attribute list.
 * An index outside the list prints an error to standard error and
 * terminates the JVM with exit status 2.
 *
 * @param n index of the decision attribute
 */
public void setDec(int n)
{
    boolean inRange = n >= 0 && n < attribute.size();
    if (!inRange)
    {
        System.err.println("决策变量指定错误。");
        System.exit(2);
    }
    decatt = n;
}
/**
 * Selects the decision attribute by name. The name is looked up in the
 * attribute list and the resulting index (-1 when the name is unknown)
 * is passed on to {@code setDec(int)}.
 *
 * @param name name of the decision attribute
 */
public void setDec(String name)
{
    setDec(attribute.indexOf(name));
}
/**
 * Computes the base-2 entropy of a distribution given as occurrence counts.
 *
 * With {@code sum} the total of all counts, the result is
 * {@code (sum*log2(sum) - SUM(c*log2(c))) / sum}.
 *
 * @param arr occurrence count of each outcome
 * @return the entropy in bits; 0.0 when {@code arr} is empty or all counts
 *         are zero (previously this case produced NaN from a 0/0 division)
 */
public double getEntropy(int[] arr)
{
    int sum = 0;
    double weightedLogs = 0.0;
    for (int count : arr)
    {
        sum += count;
        // A zero count contributes nothing, so it is skipped instead of
        // being nudged away from log(0).
        if (count > 0)
        {
            weightedLogs += count * Math.log(count);
        }
    }
    if (sum <= 0)
    {
        // No observations: define the entropy as 0 rather than dividing by zero.
        return 0.0;
    }
    // Dividing by log(2) converts natural logarithms to base 2.
    return (sum * Math.log(sum) - weightedLogs) / (sum * Math.log(2));
}
/**
 * Computes the base-2 entropy of a distribution given as occurrence counts
 * when the caller already knows their total.
 *
 * The result is {@code (sum*log2(sum) - SUM(c*log2(c))) / sum}.
 *
 * @param arr occurrence count of each outcome
 * @param sum total used as the denominator; the code does not verify that
 *            it equals the sum of {@code arr}
 * @return the entropy in bits; 0.0 when {@code sum} is not positive
 *         (previously {@code sum == 0} produced NaN from a 0/0 division)
 */
public double getEntropy(int[] arr, int sum)
{
    if (sum <= 0)
    {
        // No observations: define the entropy as 0 rather than dividing by zero.
        return 0.0;
    }
    double weightedLogs = 0.0;
    for (int count : arr)
    {
        // A zero count contributes nothing, so it is skipped instead of
        // being nudged away from log(0).
        if (count > 0)
        {
            weightedLogs += count * Math.log(count);
        }
    }
    // Dividing by log(2) converts natural logarithms to base 2.
    return (sum * Math.log(sum) - weightedLogs) / (sum * Math.log(2));
}
/**
 * Tells whether every row of the subset has the same value for the
 * decision attribute.
 *
 * @param subset row indices into {@code data}; must contain at least one
 *               entry, because the first row supplies the reference value
 * @return {@code true} when all rows agree on the decision value
 */
public boolean infoPure(ArrayList<Integer> subset)
{
    // Decision value of the first row; all remaining rows are compared to it.
    final String expected = data.get(subset.get(0))[decatt];
    for (Integer rowIndex : subset.subList(1, subset.size()))
    {
        String actual = data.get(rowIndex)[decatt];
        if (!expected.equals(actual))
        {
            return false;
        }
    }
    return true;
}
/**
 * Finds the decision-attribute value that occurs most often in the subset.
 * On a tie, the value declared earliest in the decision attribute's value
 * list wins. Rows whose decision value is not in that list are ignored.
 *
 * @param subset row indices into {@code data}
 * @return the most frequent decision value
 */
public String getGeneralAttr(ArrayList<Integer> subset)
{
    ArrayList<String> decisionValues = attributevalue.get(decatt);

    // tally[j] = number of subset rows whose decision value is decisionValues.get(j)
    int[] tally = new int[decisionValues.size()];
    for (Integer rowIndex : subset)
    {
        int position = decisionValues.indexOf(data.get(rowIndex)[decatt]);
        if (position >= 0)
        {
            tally[position]++;
        }
    }

    // Strict comparison keeps the first maximum when several values tie.
    int bestIndex = 0;
    int bestCount = -1;
    for (int j = 0; j < tally.length; j++)
    {
        if (tally[j] > bestCount)
        {
            bestCount = tally[j];
            bestIndex = j;
        }
    }
    return decisionValues.get(bestIndex);
}
/**
 * Computes the weighted entropy of the decision attribute after splitting
 * the subset on attribute {@code index}: for each value of that attribute,
 * the entropy of the decision values among matching rows, weighted by the
 * share of rows that carry the value.
 *
 * Attribute values that do not occur in the subset are skipped. Previously
 * such an empty partition produced NaN inside {@code getEntropy}, which
 * turned the whole result into NaN.
 *
 * A row whose value for either attribute is not in the declared value list
 * makes {@code indexOf} return -1 and therefore raises
 * {@link ArrayIndexOutOfBoundsException}.
 *
 * @param subset row indices into {@code data}
 * @param index  index of the attribute to split on
 * @return the weighted entropy; 0.0 for an empty subset
 */
public double calNodeEntropy(ArrayList<Integer> subset, int index)
{
    ArrayList<String> nodeValues = attributevalue.get(index);
    ArrayList<String> decisionValues = attributevalue.get(decatt);
    int sum = subset.size();

    // info[v][d] = rows having split value v and decision value d;
    // count[v]   = rows having split value v.
    int[][] info = new int[nodeValues.size()][decisionValues.size()];
    int[] count = new int[nodeValues.size()];
    for (int i = 0; i < sum; i++)
    {
        String[] row = data.get(subset.get(i));
        int nodeind = nodeValues.indexOf(row[index]);
        int decind = decisionValues.indexOf(row[decatt]);
        count[nodeind]++;
        info[nodeind][decind]++;
    }

    double entropy = 0.0;
    for (int i = 0; i < info.length; i++)
    {
        if (count[i] == 0)
        {
            // Empty partition: contributes nothing and must not inject NaN.
            continue;
        }
        entropy += getEntropy(info[i]) * count[i] / sum;
    }
    return entropy;
}
// Build the decision tree
public void buildDT(String name, String value, ArrayList<Integer> subset,
LinkedList<Integer> selatt)
{
System.out.println();
System.out.println("buildDT");
// System.out.println("subset:" + subset);
System.out.println("selatt:" + selatt);
System.out.println("name:"+name);
System.out.println("value:"+value);
Element ele = null;
@SuppressWarnings("unchecked")
List<Element> list = root.selectNodes("//" + name);
System.out.print("list:");
for (Element e : list)
System.out.print("{"+e.getName()+":"+e.attributeValue("value")+"} ");
System.out.println();
Iterator<Element> iter = list.iterator();
//使得当前节点元素名与给定name、value相匹配、并确保当前节点之前未被处理过
Element temp=null;
while (iter.hasNext())
{
ele = iter.next();
if (ele.attributeValue("value").equals(value))
{
temp=ele;
}
}
ele=temp;
System.out.println("ele:"+"{"+ele.getName()+":"+ele.attributeValue("value")+"}");
//如果当前节点只含有单一属性
if (infoPure(subset))
{
System.out.println("isPure");
ele.setText(data.get(subset.get(0))[decatt]);
return;
}
// 如果当前属性集(subset)为空,返回叶节点,其值为当前子集中决策属性最普遍的值
if (selatt.size() == 0)
{
System.out.println("single attribute");
ele.setText(getGeneralAttr(subset));
return;
}
int minIndex = -1;
double minEntropy = Double.MAX_VALUE;
for (int i = 0; i < selatt.size(); i++)
{
if (i == decatt)
continue