package je;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.RandomAccessFile;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeSet;
/**
 * Command-line driver that scores e-mails with a logistic (sigmoid) function
 * over the summed feature values of each message and adjusts the shared weight
 * table {@link #W} after each message.
 *
 * <p>All state is held in public static fields, so the class is not safe for
 * concurrent use.
 */
public class JELR {
    /** Shared weight table: feature string to its current weight. */
    public static HashMap<String, Float> W = new HashMap<String, Float>();
    /** Writer for the result file; opened and closed by {@link #main(String[])}. */
    public static FileWriter fw;
    /** Largest distance of a score from 0.5 for which {@link #trainFM(float)} still returns true. */
    public static float TONE = 0.45f;
    /** Step size applied to every weight adjustment. */
    public static float TRAIN_RATE = 0.003f;
    /** Constant used by {@link #trainb_sampling(float)}. */
    public static float b = 0.1f;
    /** Decay factor used by {@link #trainLMS(float)}. */
    public static float gamma = 13f;

    /**
     * Appends one result line of the form {@code lujing + line + score} to the writer.
     * An {@link IOException} is reported on stderr and otherwise ignored, so one
     * failed write does not stop the run.
     *
     * @param fw     writer receiving the line
     * @param lujing path of the e-mail that was scored
     * @param line   text placed between the path and the score
     * @param score  score to append
     */
    public static void writerResult(FileWriter fw, String lujing, String line, float score) {
        try {
            fw.write(lujing + line + score + "\n");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Decides whether to train based on the distance of the score from 0.5.
     *
     * @param score score of the current e-mail
     * @return {@code false} when {@code |score - 0.5|} exceeds {@link #TONE}, otherwise {@code true}
     */
    public static boolean trainFM(float score) {
        // Kept as a negated '>' so that a NaN score still yields true, exactly as before.
        return !(Math.abs(score - 0.5f) > TONE);
    }

    /**
     * Randomly decides whether to train, with probability {@code b / (b + |score - 0.5|)}.
     *
     * @param score score of the current e-mail
     * @return {@code true} when a fresh {@link Math#random()} draw falls below that probability
     */
    public static boolean trainb_sampling(float score) {
        return Math.random() < b / (b + Math.abs(score - 0.5));
    }

    /**
     * Randomly decides whether to train, with probability {@code exp(-gamma * |score - 0.5|)}.
     *
     * @param score score of the current e-mail
     * @return {@code true} when a fresh {@link Math#random()} draw falls below that probability
     */
    public static boolean trainLMS(float score) {
        return Math.random() < Math.exp(gamma * -1 * Math.abs(score - 0.5));
    }

    /**
     * Training predicate that currently always requests training.
     *
     * <p>The check that would reject scores further than {@link #TONE} from 0.5
     * (the same test as {@link #trainFM(float)}) was already disabled in the
     * original code and remains disabled.
     *
     * @param score score of the current e-mail (unused)
     * @return always {@code true}
     */
    public static boolean trainTONE(float score) {
        return true;
    }

    /**
     * Sums every value in {@code X} and maps the sum through the logistic function.
     *
     * <p>The form {@code 1 / (1 + exp(-s))} is used because {@code exp(s) / (exp(s) + 1)}
     * evaluates to {@code Infinity / Infinity = NaN} once {@code exp(s)} overflows;
     * this form saturates at 1 for large positive sums and at 0 for large negative ones.
     *
     * @param X feature values of one e-mail; values must be non-null
     * @return a score in {@code [0, 1]} (NaN only if the sum itself is NaN)
     */
    public static float calculateScore(HashMap<String, Float> X) {
        float score = 0f;
        for (Float value : X.values()) {
            score += value;
        }
        return (float) (1.0d / (1.0d + Math.exp(-score)));
    }

    /**
     * Raises the weight of every feature of {@code X} that is already present in
     * {@link #W} by {@code (1 - score) * TRAIN_RATE}.
     *
     * @param X     feature values of the e-mail just scored
     * @param score score that was computed for the e-mail
     */
    public static void updateSpam(HashMap<String, Float> X, float score) {
        adjustWeights(X, (1.0f - score) * TRAIN_RATE);
    }

    /**
     * Lowers the weight of every feature of {@code X} that is already present in
     * {@link #W} by {@code score * TRAIN_RATE}.
     *
     * @param X     feature values of the e-mail just scored
     * @param score score that was computed for the e-mail
     */
    public static void updateHam(HashMap<String, Float> X, float score) {
        adjustWeights(X, -(score * TRAIN_RATE));
    }

    /**
     * Stores {@code X's value + delta} in W for every key of X that W already contains.
     */
    private static void adjustWeights(HashMap<String, Float> X, float delta) {
        for (Map.Entry<String, Float> entry : X.entrySet()) {
            if (W.containsKey(entry.getKey())) {
                // NOTE(review): the new weight is based on the value stored in X, not on
                // W's current value. That is only equivalent if X carries a copy of the
                // matching W weights - confirm against LuceneJEAnalyzerText.
                W.put(entry.getKey(), entry.getValue() + delta);
            }
        }
    }

    /**
     * Scores one e-mail, writes its result line, and trains when {@link #trainFM(float)} allows it.
     */
    private static void classifyAndTrain(String lujing, boolean isSpam) throws IOException {
        HashMap<String, Float> X = new HashMap<String, Float>();
        // Presumably fills X with the features found in the e-mail, using W - TODO confirm.
        LuceneJEAnalyzerText je = new LuceneJEAnalyzerText(lujing, W, X);
        je.readEmail();
        float score = calculateScore(X);

        String judge = isSpam ? "spam" : "ham";
        String predicted = score > 0.5f ? "spam" : "ham";
        writerResult(fw, lujing, " judge=" + judge + " class=" + predicted + " score=", score);

        if (trainFM(score)) {
            if (isSpam) {
                updateSpam(X, score);
            } else {
                updateHam(X, score);
            }
        }
    }

    /**
     * Reads the index file {@code args[0] + "index"}, whose lines have the form
     * {@code <label> <relative-path>}, scores each e-mail labelled {@code spam}
     * or {@code ham} (case-insensitive), writes one result line per e-mail to
     * {@code args[1]}, and prints the number of index entries read.
     *
     * <p>Lines with fewer than two space-separated fields are skipped. Lines with
     * any other label are counted but not scored.
     *
     * @param args {@code args[0]} is the input path prefix, {@code args[1]} the result file
     */
    public static void main(String[] args) {
        if (args.length < 2) {
            System.err.println("usage: JELR <input-path-prefix> <result-file>");
            return;
        }
        String input = args[0];
        String resultfile = args[1];
        String filename = input + "index";
        System.out.println("input: " + input);
        System.out.println("result: " + resultfile);

        BufferedReader br = null;
        try {
            fw = new FileWriter(resultfile);
            br = new BufferedReader(new InputStreamReader(new FileInputStream(filename)));
            String temp;
            int total = 0;
            while ((temp = br.readLine()) != null) {
                String[] info = temp.split(" ");
                if (info.length < 2) {
                    // A blank or truncated line has no path field; skip it instead of crashing.
                    if (temp.trim().length() > 0) {
                        System.err.println("skipping malformed index line: " + temp);
                    }
                    continue;
                }
                String lujing = input + info[1];
                total++;
                if (info[0].equalsIgnoreCase("spam")) {
                    classifyAndTrain(lujing, true);
                } else if (info[0].equalsIgnoreCase("ham")) {
                    classifyAndTrain(lujing, false);
                }
            }
            System.out.println("total email is " + total);
        } catch (IOException e1) {
            e1.printStackTrace();
        } finally {
            // Release both files even when reading or scoring failed part-way through.
            if (br != null) {
                try {
                    br.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            if (fw != null) {
                try {
                    fw.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}
// "Comments 0" (stray web-page footer text; not part of the program)