import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import org.apache.spark.sql.Row;
/**
 * A multinomial naive-Bayes text classifier.
 *
 * <p>Training rows are Spark {@link Row}s where column 3 holds the article's
 * space-separated word tokens and column 4 holds its category label.
 * Classification scores each known category with log P(c) + sum of
 * log P(w|c) under add-one (Laplace) smoothing and returns the argmax.
 *
 * <p>NOTE(review): {@code train} mutates shared maps from a sequential
 * {@code forEach}; the ConcurrentHashMaps make concurrent *reads* during
 * classification safe, but concurrent training is not supported.
 */
public class BayesClassifier {
    // Total number of word tokens across the whole training set.
    // (was: new Long(0) — deprecated boxing ctor and autoboxed accumulator)
    private long totalWordCount = 0L;
    // Known categories, used as a set (key == value).
    private Map<String, String> classMap = new HashMap<String, String>();
    // Number of training articles per category.
    private Map<String, Long> classArticleCount = new ConcurrentHashMap<String, Long>();
    // Total number of word tokens per category.
    private Map<String, Long> classWordCount = new ConcurrentHashMap<String, Long>();
    // Per-category vocabulary with word frequencies.
    private Map<String, Map<String, Long>> classWordMap = new ConcurrentHashMap<String, Map<String, Long>>();
    // Every distinct word seen during training (the vocabulary V).
    private Set<String> allWordSet = new HashSet<String>();

    /**
     * Trains the model on the given records, updating all counters in place.
     *
     * @param records rows where column 3 is the tokenized article text
     *                (space-separated) and column 4 is the category label
     */
    public void train(List<Row> records) {
        records.forEach(record -> {
            // The article's category.
            String category = record.getString(4);
            // First time we see this category: initialize its counters.
            if (!classMap.containsKey(category)) {
                classMap.put(category, category);
                classArticleCount.put(category, 0L);
                classWordCount.put(category, 0L);
                classWordMap.put(category, new HashMap<String, Long>());
            }
            // BUG FIX: the original initialized classArticleCount but never
            // incremented it, so every category stayed at 0 articles.
            classArticleCount.merge(category, 1L, Long::sum);
            Map<String, Long> wordMap = classWordMap.get(category);
            // Tokens were pre-split on single spaces upstream.
            String[] words = record.getString(3).split(" ");
            for (String word : words) {
                // Update this category's vocabulary and word frequency.
                wordMap.merge(word, 1L, Long::sum);
                allWordSet.add(word);
            }
            // Update this category's total token count and the global total.
            classWordCount.merge(category, (long) words.length, Long::sum);
            totalWordCount += words.length;
        });
    }

    /**
     * Laplace-smoothed count of {@code word} in category {@code classKey}.
     *
     * @param classKey category label (must have been seen during training)
     * @param word     word token
     * @return observed frequency plus one; unseen words yield 1
     */
    public Long wordInClassCount(String classKey, String word) {
        Map<String, Long> wordMap = classWordMap.get(classKey);
        Long wordCount = wordMap.get(word);
        // BUG FIX: the original returned the raw count for seen words but 1
        // for unseen ones, so a word seen once scored the same as an unseen
        // word. Add-one smoothing requires count + 1 uniformly (the classify
        // denominator already adds |V| accordingly).
        return (wordCount == null) ? 1L : wordCount + 1;
    }

    /**
     * Picks the category with the highest score.
     *
     * @param resultMap category -> log-probability score
     * @return the category with the maximum score, or {@code null} if
     *         {@code resultMap} is empty
     */
    public String getMaxClassification(Map<String, Double> resultMap) {
        String maxClassification = null;
        double maxProbability = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, Double> entry : resultMap.entrySet()) {
            if (entry.getValue() > maxProbability) {
                maxProbability = entry.getValue();
                maxClassification = entry.getKey();
            }
        }
        return maxClassification;
    }

    /**
     * Classifies an article.
     *
     * @param record row whose column 3 holds the article's space-separated
     *               word tokens
     * @return the most probable category, or {@code null} if the model has
     *         not been trained
     */
    public String classify(Row record) {
        // The article's pre-split tokens.
        String[] words = record.getString(3).split(" ");
        Map<String, Double> resultMap = new HashMap<String, Double>();
        // Score the article against every known category.
        for (String classKey : classMap.keySet()) {
            long classTotal = classWordCount.get(classKey);
            // Add-one smoothing denominator: tokens in class + |vocabulary|.
            // (was: "classTotal + + allWordSet.size()" — stray unary plus)
            double denominator = classTotal + allWordSet.size();
            double probability = 0.0;
            for (String word : words) {
                // Accumulate log P(word | class) to avoid underflow.
                probability += Math.log(wordInClassCount(classKey, word) / denominator);
            }
            // Class prior P(class), estimated from its token share.
            probability += Math.log(classTotal * 1.0 / totalWordCount);
            resultMap.put(classKey, probability);
        }
        // Return the highest-scoring category.
        return getMaxClassification(resultMap);
    }
}