import java.util.ArrayList ;
import java.io.* ;
/**
 * Trains a support vector machine with the SMO (sequential minimal optimization)
 * procedure using a linear (dot-product) kernel, then reports the fraction of
 * training samples that are classified correctly.
 *
 * <p>The training file is read as tab-separated numeric columns. Every column
 * except the last is a feature; the last column is the label. The decision rule
 * in {@link #predict(double[])} treats a positive score as class {@code 1} and
 * anything else as class {@code -1}.
 */
public class SmoTool {
    /** Smallest change of the second multiplier that still counts as progress. */
    private static final double MIN_ALPHA_STEP = 0.00001;

    private final String filePath;
    OptStruct os;

    /**
     * Loads the training data, runs at most 40 outer SMO iterations and prints the
     * training accuracy to standard output.
     *
     * @param filePath path of the tab-separated training file
     * @throws IllegalStateException if the training file cannot be read
     * @throws NumberFormatException if a field of the file is not a valid number
     */
    public SmoTool(String filePath) {
        this.filePath = filePath;
        os = new OptStruct();
        init();
        smoP(40);
        System.out.println("test result:" + test());
    }

    /**
     * Reads a tab-separated numeric file and returns one {@code double[]} per line.
     * Lines that contain only whitespace are skipped.
     *
     * @param filePath path of the file to read
     * @return the parsed rows, in file order
     * @throws IllegalStateException if the file cannot be opened or read
     * @throws NumberFormatException if a field is not a valid number
     */
    private ArrayList<double[]> getData(String filePath) {
        ArrayList<double[]> rows = new ArrayList<>();
        // try-with-resources guarantees the reader is closed even when parsing fails.
        try (BufferedReader in = new BufferedReader(new FileReader(new File(filePath)))) {
            String line;
            while ((line = in.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // a blank line carries no sample
                }
                String[] fields = line.split("\t");
                double[] row = new double[fields.length];
                for (int c = 0; c < fields.length; c++) {
                    row[c] = Double.parseDouble(fields[c]);
                }
                rows.add(row);
            }
        } catch (IOException e) {
            // Fail loudly: continuing would silently train on missing or partial data.
            throw new IllegalStateException("Failed to read data file: " + filePath, e);
        }
        return rows;
    }

    /**
     * Reads the training data and initializes the fields of the shared
     * {@link OptStruct}: samples, labels, C, tolerance, sample count, error cache
     * and multipliers (the last two start at zero).
     */
    private void init() {
        ArrayList<double[]> rows = getData(filePath);
        ArrayList<double[]> features = new ArrayList<>(rows.size());
        double[] labels = new double[rows.size()];
        for (int r = 0; r < rows.size(); r++) {
            double[] row = rows.get(r);
            int featureCount = row.length - 1; // the last column is the label
            double[] sample = new double[featureCount];
            System.arraycopy(row, 0, sample, 0, featureCount);
            features.add(sample);
            labels[r] = row[featureCount];
        }
        os.labels = labels;
        os.datas = features;
        os.C = 0.6;
        os.tolor = 0.00001;
        os.m = os.datas.size();
        // Freshly allocated arrays are zero-filled, which is the required start state.
        os.eCache = new double[os.m][2];
        os.alphas = new double[os.m];
    }

    /**
     * Outer loop of SMO. Alternates between a pass over every sample and passes
     * over the non-bound samples (0 &lt; alpha &lt; C), stopping after
     * {@code maxIter} passes or when a full pass changes no pair.
     *
     * @param maxIter maximum number of passes
     */
    private void smoP(int maxIter) {
        int iter = 0;
        boolean entireSet = true;
        int alphaPairsChanged = 0;
        while (iter < maxIter && (alphaPairsChanged > 0 || entireSet)) {
            alphaPairsChanged = 0;
            int i;
            if (entireSet) {
                for (i = 0; i < os.m; i++) {
                    alphaPairsChanged += innerL(i, os);
                }
                System.out.printf("fullSet,iter:%d,i:%d,pairsChange:%d", iter, i, alphaPairsChanged);
            } else {
                for (i = 0; i < os.alphas.length; i++) {
                    if (os.alphas[i] > 0 && os.alphas[i] < os.C) {
                        alphaPairsChanged += innerL(i, os);
                    }
                }
                System.out.printf("non-bound,iter:%d,i:%d,pairsChange:%d", iter, i, alphaPairsChanged);
            }
            System.out.println();
            iter += 1;

            // After a full pass switch to non-bound passes; return to a full pass
            // once a non-bound pass makes no progress.
            if (entireSet) {
                entireSet = false;
            } else if (alphaPairsChanged == 0) {
                entireSet = true;
            }
            System.out.println("iteration number:" + iter);
        }
    }

    /**
     * Computes the prediction error E(k): the current decision value for sample
     * {@code k} minus its label.
     */
    private double calcEk(OptStruct os, int k) {
        double decision = 0;
        for (int i = 0; i < os.m; i++) {
            decision += os.alphas[i] * os.labels[i] * K(os, i, k);
        }
        decision += os.b;
        return decision - os.labels[k];
    }

    /**
     * Chooses the index of the second multiplier. Among the samples whose error
     * cache entry is marked valid, picks the one whose error differs most from
     * {@code ei}; if no such sample exists, picks a random index other than
     * {@code i}. As a side effect, marks the cache entry of {@code i} valid and
     * stores {@code ei} in it.
     */
    private int selectJ(int i, OptStruct os, double ei) {
        double maxDeltaE = 0;
        int maxK = -1;
        os.eCache[i][0] = 1;
        os.eCache[i][1] = ei;
        ArrayList<Integer> validEcacheList = nonZero(os.eCache);
        System.out.println(i + ": ei=" + ei);
        if (validEcacheList.size() > 1) {
            for (int k : validEcacheList) {
                if (k == i) {
                    continue;
                }
                double deltaE = Math.abs(ei - calcEk(os, k));
                if (deltaE > maxDeltaE) {
                    maxK = k;
                    maxDeltaE = deltaE;
                }
            }
            if (maxK != -1) {
                return maxK;
            }
        }
        return selectJrand(i, os.m - 1);
    }

    /** Recomputes E(k) and stores it in the error cache, marking the entry valid. */
    private void updateEk(OptStruct os, int k) {
        double ek = calcEk(os, k);
        os.eCache[k][0] = 1;
        os.eCache[k][1] = ek;
    }

    /**
     * Performs one SMO step for sample {@code i}: if it violates the KKT
     * conditions beyond the tolerance, selects a partner {@code j} and jointly
     * updates both multipliers and the bias.
     *
     * @return 1 if a pair of multipliers was changed, 0 otherwise
     */
    private int innerL(int i, OptStruct os) {
        double ei = calcEk(os, i);
        boolean violatesKkt =
                ((os.labels[i] * ei < -os.tolor) && (os.alphas[i] < os.C))
                || ((os.labels[i] * ei > os.tolor) && (os.alphas[i] > 0));
        if (!violatesKkt) {
            return 0;
        }

        int j = selectJ(i, os, ei);
        double ej = calcEk(os, j);
        double alphaIold = os.alphas[i];
        double alphaJold = os.alphas[j];

        // Bounds that keep alpha[j] inside [0, C] while preserving the linear constraint.
        double L;
        double H;
        if (os.labels[i] != os.labels[j]) {
            L = Math.max(0, os.alphas[j] - os.alphas[i]);
            H = Math.min(os.C, os.C + os.alphas[j] - os.alphas[i]);
        } else {
            L = Math.max(0, os.alphas[j] + os.alphas[i] - os.C);
            H = Math.min(os.C, os.alphas[j] + os.alphas[i]);
        }
        if (L == H) {
            System.out.println("L==H");
            return 0;
        }

        double eta = 2.0 * K(os, i, j) - K(os, i, i) - K(os, j, j);
        if (eta >= 0) {
            System.out.println("eta >= 0");
            return 0;
        }

        os.alphas[j] -= os.labels[j] * (ei - ej) / eta;
        os.alphas[j] = clipAlpha(os.alphas[j], H, L);
        updateEk(os, j);
        if (Math.abs(os.alphas[j] - alphaJold) < MIN_ALPHA_STEP) {
            System.out.println("j not moving enough");
            return 0;
        }

        // Move alpha[i] by the same amount in the opposite direction.
        os.alphas[i] += os.labels[i] * os.labels[j] * (alphaJold - os.alphas[j]);
        updateEk(os, i);

        double deltaI = os.alphas[i] - alphaIold;
        double deltaJ = os.alphas[j] - alphaJold;
        double b1 = os.b - ei - os.labels[i] * deltaI * K(os, i, i)
                - os.labels[j] * deltaJ * K(os, j, i);
        double b2 = os.b - ej - os.labels[i] * deltaI * K(os, i, j)
                - os.labels[j] * deltaJ * K(os, j, j);
        if ((0 < os.alphas[i]) && (os.C > os.alphas[i])) {
            os.b = b1;
        } else if ((0 < os.alphas[j]) && (os.C > os.alphas[j])) {
            os.b = b2;
        } else {
            os.b = (b1 + b2) / 2.0;
        }
        return 1;
    }

    /** Clamps {@code alpha} into the interval [{@code L}, {@code H}]. */
    private double clipAlpha(double alpha, double H, double L) {
        if (alpha > H) {
            return H;
        } else if (alpha < L) {
            return L;
        }
        return alpha;
    }

    /**
     * Picks a random index in {@code [0, m]} that differs from {@code i}; every
     * other index is equally likely when {@code i} itself lies in that range.
     *
     * @param i index to exclude
     * @param m largest index that may be returned (inclusive)
     * @return an index in {@code [0, m]} other than {@code i}
     * @throws IllegalArgumentException if {@code m < 1}, i.e. there are fewer than
     *         two indices to choose between
     */
    private int selectJrand(int i, int m) {
        if (m < 1) {
            throw new IllegalArgumentException(
                    "At least two samples are required to select a second index; upper bound was " + m);
        }
        // Draw from the m values [0, m-1], then shift draws at or above i up by
        // one. This skips i without a retry loop and makes index 0 reachable.
        int j = (int) (Math.random() * m);
        if (j >= i) {
            j++;
        }
        return j;
    }

    /**
     * Returns the indices of all error-cache rows whose validity flag (column 0)
     * is non-zero.
     */
    private ArrayList<Integer> nonZero(double[][] eCache) {
        ArrayList<Integer> validEcacheList = new ArrayList<>();
        for (int i = 0; i < eCache.length; i++) {
            if (eCache[i][0] != 0) {
                validEcacheList.add(i);
            }
        }
        return validEcacheList;
    }

    /** Kernel value for training samples {@code i} and {@code j}. */
    private double K(OptStruct os, int i, int j) {
        return K(os.datas.get(i), os.datas.get(j));
    }

    /**
     * Returns the fraction of training samples whose predicted sign matches the
     * label. The result is NaN when there are no training samples (0 / 0).
     */
    private double test() {
        double correct = 0;
        for (int i = 0; i < os.m; i++) {
            if (os.labels[i] * predict(os.datas.get(i)) > 0) {
                correct += 1;
            }
        }
        return correct / os.m;
    }

    /**
     * Classifies one sample with the trained model.
     *
     * @param testData feature vector; it must have at least as many entries as the
     *        training samples (extra entries are ignored by the kernel)
     * @return {@code 1} if the decision value is positive, otherwise {@code -1}
     */
    public double predict(double[] testData) {
        double result = 0;
        for (int i = 0; i < os.alphas.length; i++) {
            result += os.alphas[i] * os.labels[i] * K(os.datas.get(i), testData);
        }
        result += os.b;
        return result > 0 ? 1 : -1;
    }

    /** Linear kernel: the dot product of {@code xi} and {@code xj} over the length of {@code xi}. */
    private double K(double[] xi, double[] xj) {
        double sum = 0;
        for (int i = 0; i < xi.length; i++) {
            sum += xi[i] * xj[i];
        }
        return sum;
    }
}
c