package jay.NaiveBayes;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
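/**
 * Classifies 20-Newsgroups documents with a multinomial Naive Bayes model.
 * Earlier preprocessing steps are assumed to have produced three inputs: a
 * SequenceFile of documents keyed "class@fileName" with the raw bytes as the
 * value, a class dictionary holding the number of documents per class, and a
 * per-class term-count file keyed "class@term".
 */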
public class NaiveBayes extends Configured implements Tool {
private static final String BASE_PATH = "hdfs://localhost:9000/user/jay/NaiveBayes/";
private static final String BASE_DATA_PATH = "hdfs://localhost:9000/user/jay/NaiveBayes/"
+ "OriginalData/20news-bydate/20news-bydate-test/";
private static final String INPUT_PATH_1 = BASE_PATH + "InputSequenceData";
private static final String OUTPUT_PATH = BASE_PATH
+ "ResultOfClassification";
private static final String CLASS_DICT_PATH = BASE_PATH
+ "Dict/classDict.list"; // class dictionary: number of documents per class
private static final String TERM_FREQUENCE_IN_CLASS = BASE_PATH
+ "TermFrequenceInClass/part-r-00000"; // term counts per class, keyed "class@term"
public static class NaiveBayesMapper extends
Mapper<Text, BytesWritable, Text, Text> {
private Text docID = new Text();
private Text classAndProbility = new Text();
// Prior probability of each class, P(c).
private Map<String, Double> classProbility = new HashMap<String, Double>();
// Probability of a term given a class, P(t|c), keyed "class@term".
private Map<String, Double> termInClassProbility = new HashMap<String, Double>();
// Matches runs of English letters; used to tokenize the document body.
private static final Pattern PATTERN = Pattern.compile("[a-zA-Z]+");
// Stop-word list.
private static String[] stopWordsArray = { "A", "a", "the", "an", "in",
"on", "and", "The", "As", "as", "AND" };
private static Vector<String> stopWords;
// Classes to score; only "alt.atheism" is enabled in this listing.
// private String[] classGroup = {"alt.atheism", "comp.graphics",
// "comp.os.ms-windows.misc", "comp.sys.ibm.pc.hardware" };
private String[] classGroup = { "alt.atheism" };
/**
 * Loads the class priors P(c) and the per-class term probabilities P(t|c)
 * before any map() calls run.
 */
@Override
protected void setup(Context context) throws IOException,
InterruptedException {
if (null == stopWords) {
stopWords = new Vector<String>();
for (String tem : stopWordsArray) {
stopWords.add(tem);
}
}
Configuration conf = context.getConfiguration();
Path classDict = new Path(conf.get("CLASS_DICT_PATH"));
Path termFrequence = new Path(conf.get("TERM_FREQUENCE_IN_CLASS"));
FileSystem fs = FileSystem.get(conf);
// Read the per-class document counts from the class dictionary.
SequenceFile.Reader reader = new SequenceFile.Reader(fs, classDict,
conf);
IntWritable value = new IntWritable();
Text key = new Text();
double sum = 0;
Map<String, Integer> tmpclassProbility = new HashMap<String, Integer>();
while (reader.next(key, value)) {
tmpclassProbility.put(key.toString(), value.get());
sum += value.get();
}
reader.close();
// Estimate the priors: P(c) = documents in c / total documents.
for (Map.Entry<String, Integer> entry : tmpclassProbility.entrySet()) {
classProbility.put(entry.getKey(), entry.getValue() / sum);
}
// Read the per-class term counts (keys of the form "class@term").
SequenceFile.Reader reader1 = new SequenceFile.Reader(fs,
termFrequence, conf);
IntWritable value1 = new IntWritable();
Text key1 = new Text();
double sum1 = 0;
Map<String, Integer> tmptermInClassProbility = new HashMap<String, Integer>();
while (reader1.next(key1, value1)) {
tmptermInClassProbility.put(key1.toString(), value1.get());
sum1 += value1.get();
}
reader1.close();
// Estimate P(t|c); note the counts are normalized by the global term
// total rather than per class, which is harmless while classGroup
// holds a single class.
for (Map.Entry<String, Integer> entry : tmptermInClassProbility.entrySet()) {
termInClassProbility.put(entry.getKey(), entry.getValue() / sum1);
}
super.setup(context);
}
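/**
 * Scores one document against every class in classGroup and emits
 * (docID, "class/score") pairs. For example, a document containing only
 * the terms "god" and "atheism" scored against alt.atheism would yield
 *   log10 P(alt.atheism) + log10 P(god|alt.atheism)
 *                        + log10 P(atheism|alt.atheism).
 */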
public void map(Text key, BytesWritable value, Context context)
throws IOException, InterruptedException {
String content = new String(value.getBytes(), 0, value.getLength());
Matcher m = PATTERN.matcher(content);
// The key is "parentDirectory@fileName"; the file name becomes the docID
// (the parent directory is the true class label and is not used here).
String[] classAndFile = key.toString().split("@");
String fileName = classAndFile[1];
for (String classname : classGroup) {
double multipleTerm = 0;
// Rewind the matcher so every class scans the full document.
m.reset();
while (m.find()) {
String temkey = m.group();
String classAndWord;
if (!stopWords.contains(temkey)) {
classAndWord = classname + "@" + temkey;
// Sum log-probabilities rather than multiplying raw
// probabilities, which would underflow; terms unseen in the
// class are skipped (no smoothing is applied).
if (termInClassProbility.containsKey(classAndWord)) {
multipleTerm += Math.log10(termInClassProbility
.get(classAndWord));
}
}
}
multipleTerm += Math.log10(classProbility.get(classname));
// Emit docID -> "class/score"; the reducer picks the best class.
this.docID.set(fileName);
this.classAndProbility.set(classname + "/"
+ Double.toString(multipleTerm));
context.write(this.docID, this.classAndProbility);
}
}
}
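/*
 * The reducer receives every "class/score" pair emitted for a document and
 * keeps the class with the highest score, i.e. the Naive Bayes argmax.
 */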
public static class NaiveBayesReducer extends
Reducer<Text, Text, Text, Text> {
private Text docID = new Text();
private Text classID = new Text();
public void reduce(Text key, Iterable<Text> value, Context context)
throws IOException, InterruptedException {
double maxProbility = Double.NEGATIVE_INFINITY;
String maxClass = "";
// Pick the class whose log-probability score is largest.
for (Text val : value) {
String[] tmpVal = val.toString().split("/");
double probility = Double.valueOf(tmpVal[1]);
if (probility > maxProbility) {
maxProbility = probility;
maxClass = tmpVal[0];
}
}
this.docID.set(key);
this.classID.set(maxClass + "/" + maxProbility);
context.write(this.docID, this.classID);
}
}
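/*
 * Driver: deletes any stale output directory, passes the dictionary
 * locations to the mappers through the Configuration, and submits the job.
 */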
@Override
public int run(String[] args) throws Exception {
Configuration conf = getConf();
// Apply any generic Hadoop options (-D, -fs, ...); all job-specific paths
// are hard-coded in the constants above, so no further arguments are read.
new GenericOptionsParser(conf, args);
Path outputPath = new Path(OUTPUT_PATH);
FileSystem fs = outputPath.getFileSystem(conf);
if (fs.exists(outputPath)) {
fs.delete(outputPath, true);
}
conf.set("CLASS_DICT_PATH", CLASS_DICT_PATH);
conf.set("TERM_FREQUENCE_IN_CLASS", TERM_FREQUENCE_IN_CLASS);
Job job = new Job(conf, "NaiveBayes");
job.setJarByClass(NaiveBayes.class);
job.setMapperClass(NaiveBayesMapper.class);
job.setReducerClass(NaiveBayesReducer.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(Text.class);
// Documents arrive as (Text "class@fileName", BytesWritable body) pairs.
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
FileInputFormat.addInputPath(job, new Path(INPUT_PATH_1));
SequenceFileOutputFormat.setOutputPath(job, outputPath);
return job.waitForCompletion(true) ? 0 : 1;
}
public static void main(String[] args) throws Exception {
System.exit(ToolRunner.run(new Configuration(), new NaiveBayes(), args));
}
}
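// A typical launch, assuming the compiled job jar is named naive-bayes.jar
// (hypothetical name):
//   hadoop jar naive-bayes.jar jay.NaiveBayes.NaiveBayes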