import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Stack;
public class HMM {
/*
 * Trains the tagger on data/train.pos, prints two sample model parameters,
 * then runs Viterbi decoding over the words of data/test.pos.
 */
public static void main(String[] args){
    // Training phase: gather counts from the tagged training corpus.
    HMMParser trainParser = new HMMParser("data/train.pos");
    trainParser.parseTrainer();
    HMM model = new HMM(trainParser);

    // Spot-check one emission and one transition probability.
    double agreementLikelihood = model.calcLikelihood("NN", "agreement");
    System.out.println("likelihood of 'NN' corresponding to 'agreement': " + agreementLikelihood);
    double nnToVbgPrior = model.calcPriorProb("NN", "VBG");
    System.out.println("prior probability of NN -> VBG: " + nnToVbgPrior);

    // Decoding phase: tag the test corpus.
    HMMParser testParser = new HMMParser("data/test.pos");
    model.viterbi(testParser.wordSequence());
}
// --- Count tables taken from the HMMParser given to the constructor (references are shared, not copied) ---
// tag -> count; read as C(tag) by calcLikelihood and calcPriorProb.
HashMap<String, Integer> tagCounts;
// tag -> word -> count; read as C(tag, word) by calcLikelihood.
HashMap<String, HashMap<String, Integer>> wordCounts;
// tag1 -> tag2 -> count; read as C(tag1, tag2) by calcPriorProb and makeGoodTuringCounts.
HashMap<String, HashMap<String, Integer>> tagBigramCounts;
// word -> tag -> count; its key set is used as the word vocabulary (add-one emission smoothing),
// and viterbi takes the candidate tags for a known word from it.
HashMap<String, HashMap<String, Integer>> tagForWordCounts;
// --- Good-Turing tables, filled lazily by makeGoodTuringCounts() ---
// tag1 -> tag2 -> discounted count c*.
HashMap<String, HashMap<String, Double>> goodTuringTagBigramCounts;
// tag1 -> sum of c* over every tag2 that followed tag1.
HashMap<String, Double> goodTuringTagUnigramCounts;
// c -> N(c): how many distinct tag bigrams were seen exactly c times.
HashMap<Integer, Integer> numberOfBigramsWithCount;
// True once makeGoodTuringCounts() has filled the Good-Turing tables.
boolean goodTuringCountsAvailable = false;
// Copied from the parser; used as N in the unseen-bigram estimate N1/N.
// Presumably the total number of tag-bigram tokens in training -- confirm against HMMParser.
int numTrainingBigrams;
// Copied from the parser; not read in this part of the file. Presumably the most frequent
// training tag -- confirm.
String mostFreqTag;
// Output writer opened by the constructor.
FileWriter writer;
// Estimator switches for calcLikelihood/calcPriorProb. ADDONE is tested first, so while it is
// true the Good-Turing and unsmoothed branches are never taken.
final boolean ADDONE = true;
final boolean GOODTURING = false;
/*
 * Builds a tagger from the counts gathered by a parser and writes tagged output to
 * data/output.pos. Equivalent to HMM(p, "data/output.pos").
 *
 * p: parser whose count tables are shared (not copied) by this model.
 */
public HMM(HMMParser p){
    this(p, "data/output.pos");
}

/*
 * Builds a tagger from the counts gathered by a parser and opens outputPath for the
 * tagged output. FileWriter truncates any existing file at that path.
 *
 * p:          parser whose count tables are shared (not copied) by this model.
 * outputPath: file the tagger writes its output to.
 * The JVM exits with status 1 if the output file cannot be opened.
 */
public HMM(HMMParser p, String outputPath){
    this.tagCounts = p.tagCounts;
    this.wordCounts = p.wordCounts;
    this.tagBigramCounts = p.tagBigramCounts;
    this.tagForWordCounts = p.tagForWordCounts;
    this.mostFreqTag = p.mostFreqTag;
    this.numTrainingBigrams = p.numTrainingBigrams;
    // Good-Turing tables start empty; makeGoodTuringCounts() fills them on first use.
    this.goodTuringTagBigramCounts = new HashMap<String, HashMap<String, Double>>();
    this.goodTuringTagUnigramCounts = new HashMap<String, Double>();
    this.numberOfBigramsWithCount = new HashMap<Integer, Integer>();
    try {
        writer = new FileWriter(new File(outputPath));
    } catch (IOException e) {
        // Without an output file the tagger cannot do its job; fail fast as before,
        // but say which file was the problem.
        System.err.println("Could not open output file " + outputPath + ": " + e.getMessage());
        e.printStackTrace();
        System.exit(1);
    }
}
/*
 * Returns map[key] for an integer count table, or 0 when the key is absent.
 */
private int counts(HashMap<String, Integer> map, String key){
    if (!map.containsKey(key)) {
        return 0;
    }
    return map.get(key);
}
/*
 * Returns map[key1][key2] for a nested integer count table, or 0 when either
 * level of the lookup is missing.
 */
private int counts(HashMap<String, HashMap<String,Integer>> map, String key1, String key2){
    if (!map.containsKey(key1)) {
        return 0;
    }
    HashMap<String, Integer> row = map.get(key1);
    return counts(row, key2);
}
/*
 * Returns map[key] for a real-valued (Good-Turing) count table, or 0.0 when the
 * key is absent.
 *
 * The parameter is declared as Map rather than HashMap on purpose:
 * counts(HashMap<String, Double>, String) and the int-valued overload
 * counts(HashMap<String, Integer>, String) both erase to counts(HashMap, String),
 * which javac 7 and later reject as a name clash even though the return types
 * differ. Every existing call passes a HashMap<String, Double>, so overload
 * resolution still picks this method.
 */
private double counts(Map<String, Double> map, String key){
    return (map.containsKey(key)) ? map.get(key) : 0.0;
}
/*
 * Returns map[key1][key2] for a nested real-valued (Good-Turing) count table, or
 * 0.0 when either level of the lookup is missing.
 *
 * The outer parameter is declared as Map rather than HashMap on purpose: with
 * HashMap it would have the same erasure, counts(HashMap, String, String), as the
 * int-valued overload counts(HashMap<String, HashMap<String,Integer>>, String, String),
 * which javac 7 and later reject as a name clash. Every existing call passes a
 * HashMap<String, HashMap<String, Double>>, so overload resolution still picks
 * this method.
 */
private double counts(Map<String, HashMap<String,Double>> map, String key1, String key2){
    return (map.containsKey(key1)) ? counts(map.get(key1), key2) : 0.0;
}
/*
 * Returns N(count): how many distinct tag bigrams were seen exactly `count` times,
 * as tallied by makeGoodTuringCounts(), or 0 when no bigram has that count.
 */
private int numberOfBigramsWithCount(int count){
    return numberOfBigramsWithCount.containsKey(count) ? numberOfBigramsWithCount.get(count) : 0;
}
/*
 * Builds the Good-Turing smoothed tag-bigram tables from tagBigramCounts:
 *  - numberOfBigramsWithCount:   c -> N(c), the number of distinct bigrams seen exactly c times;
 *  - goodTuringTagBigramCounts:  tag1 -> tag2 -> c*, where c* = (c+1) * N(c+1) / N(c),
 *                                or the raw count c when N(c+1) == 0 (see below);
 *  - goodTuringTagUnigramCounts: tag1 -> sum of c* over every tag2 that followed tag1.
 * Sets goodTuringCountsAvailable when done. The tables are rebuilt from scratch on each
 * call, so calling it twice does not accumulate.
 */
private void makeGoodTuringCounts(){
    // Start clean so a repeated call does not double every N(c) (and thus N1).
    numberOfBigramsWithCount.clear();
    goodTuringTagBigramCounts.clear();
    goodTuringTagUnigramCounts.clear();

    // Pass 1: frequency of frequencies, N(c).
    for (HashMap<String, Integer> innerMap : tagBigramCounts.values()) {
        for (int count : innerMap.values()) {
            numberOfBigramsWithCount.put(count, numberOfBigramsWithCount(count) + 1);
        }
    }

    // Pass 2: discounted count for every seen bigram, plus the per-tag1 totals.
    for (String tag1 : tagBigramCounts.keySet()) {
        HashMap<String, Integer> innerMap = tagBigramCounts.get(tag1);
        HashMap<String, Double> innerGTMap = new HashMap<String, Double>();
        goodTuringTagBigramCounts.put(tag1, innerGTMap);
        double unigramCount = 0.0;
        for (String tag2 : innerMap.keySet()) {
            int count = innerMap.get(tag2);
            int nextCountFreq = numberOfBigramsWithCount(count + 1);
            double newCount;
            if (nextCountFreq == 0) {
                // (c+1) * N(c+1) / N(c) collapses to 0 whenever no bigram was seen c+1 times,
                // which always happens for the most frequent bigram. A zero c* would make
                // calcPriorProb treat a frequent bigram as unseen and drop its mass from the
                // tag1 total, so keep the raw count for these high-frequency cases instead.
                newCount = count;
            } else {
                // c* = (c+1) * N(c+1) / N(c); N(c) >= 1 because this bigram was tallied in pass 1.
                newCount = (count + 1.0) * nextCountFreq / numberOfBigramsWithCount(count);
            }
            innerGTMap.put(tag2, newCount);
            unigramCount += newCount;
        }
        goodTuringTagUnigramCounts.put(tag1, unigramCount);
    }
    goodTuringCountsAvailable = true;
}
/*
 * Calculates P(word|tag), the emission probability.
 * The estimator is chosen by the ADDONE / GOODTURING switches; ADDONE is checked first.
 */
public double calcLikelihood(String tag, String word){
    if(ADDONE){
        // Add-one smoothing over the word vocabulary: (C(tag,word)+1) / (C(tag)+V),
        // where V is the number of distinct words (keys of tagForWordCounts).
        int vocabSize = tagForWordCounts.keySet().size();
        return (double) (counts(wordCounts,tag,word)+1) / (double) (counts(tagCounts,tag)+vocabSize);
    } else if(GOODTURING) {
        // The Good-Turing tables are built lazily (as calcPriorProb does). Without this,
        // goodTuringTagUnigramCounts is empty unless calcPriorProb ran first, and the
        // division below would be by 0.0.
        if(!goodTuringCountsAvailable) {
            System.out.println("Making good turing counts...");
            makeGoodTuringCounts();
            System.out.println("Done making good turing counts.");
        }
        // NOTE(review): the denominator is the smoothed tag-*transition* mass
        // (sum over tag2 of c*(tag, tag2)), not a smoothed emission count, so it only
        // approximates C(tag) -- confirm this is intended.
        return (double) counts(wordCounts,tag,word) / (double) counts(goodTuringTagUnigramCounts,tag);
    } else {
        // Unsmoothed maximum-likelihood estimate; divides by zero (NaN/Infinity) if C(tag) == 0.
        return (double) counts(wordCounts,tag,word) / (double) counts(tagCounts,tag);
    }
}
/*
 * Calculates P(tag2|tag1), the tag-transition probability.
 * The estimator is chosen by the ADDONE / GOODTURING switches; ADDONE is checked first,
 * then Good-Turing, and otherwise the unsmoothed maximum-likelihood estimate is used.
 */
public double calcPriorProb(String tag1, String tag2){
    if (ADDONE) {
        // Add-one smoothing over the tag set: (C(tag1,tag2)+1) / (C(tag1)+T), T = number of tags.
        int numTags = tagCounts.size();
        int smoothedBigram = counts(tagBigramCounts, tag1, tag2) + 1;
        int smoothedUnigram = counts(tagCounts, tag1) + numTags;
        return (double) smoothedBigram / (double) smoothedUnigram;
    }
    if (!GOODTURING) {
        // Unsmoothed maximum-likelihood estimate; divides by zero if C(tag1) == 0.
        int bigram = counts(tagBigramCounts, tag1, tag2);
        int unigram = counts(tagCounts, tag1);
        return (double) bigram / (double) unigram;
    }
    // Good-Turing: the smoothed tables are built on first use.
    if (!goodTuringCountsAvailable) {
        System.out.println("Making good turing counts...");
        makeGoodTuringCounts();
        System.out.println("Done making good turing counts.");
    }
    double discounted = counts(goodTuringTagBigramCounts, tag1, tag2);
    if (discounted > 0.0) {
        // Seen bigram: discounted count over the tag1 total of discounted counts.
        return discounted / counts(goodTuringTagUnigramCounts, tag1);
    }
    // Unseen bigram: N1/N, as per the book (page 101).
    // NOTE(review): in Good-Turing, N1/N estimates the combined mass of *all* unseen
    // bigrams; returning it for each unseen bigram likely over-weights them relative to
    // the conditional estimate above -- confirm intent.
    return numberOfBigramsWithCount(1) / (double) numTrainingBigrams;
}
public void viterbi(ArrayList<String> words){
//two-dimensional Viterbi Matrix
boolean sentenceStart = true;
HashMap<String, Node> prevMap = null;
for(int i=0; i<words.size(); i++){
if (i%500==0) {
System.out.println("working on "+i+" of "+words.size()+" words");
}
String word = words.get(i);
HashMap<String, Node> subMap = new HashMap<String,Node>();
if(sentenceStart){
Node n = new Node(word, "<s>", null, 1.0);
subMap.put(word, n);
sentenceStart = false;
} else {
//add all possible tags (given the current word)
//to the Viterbi matrix
if(tagForWordCounts.containsKey(word)){
// Only Training Set tags
HashMap<String, Integer> tagcounts = tagForWordCounts.get(word);
for(String tag : tagcounts.keySet()){
subMap.put(tag, calcNode(word, tag, prevMap));
}
// Every Tag
//for(String tag : tagCounts.keySet()){
// subMap.put(tag, calcNode(word, tag, prevMap));
//}
} else if (word.matches("[A-Z]\\w*")) {
subMap.put("NNP", calcNode(word, "NNP", prevMap));
} else if (word.matches("\\p{Digit}*.\\p{Digit}*") || word.matches("(\\p{Punct}+|\\p{Digit}+)+")) {
subMap.put("CD", calcNode(word, "CD", prevMap));
} else if (word.contains("-") || word.matches(".*able")) {
subMap.put("JJ", calcNode(word, "JJ", prevMap));
} else if (word