package org.mybp;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Scanner;
/**
 * A three-layer (input / hidden / output) back-propagation neural network that
 * classifies comma-separated Iris records of the form
 * {@code f1,f2,f3,f4,className}.
 *
 * <p>Typical use: {@code new Mybp(trainingFile, hiddenNodes).start()}, which
 * assigns random initial weights, trains for {@code stepmax} epochs on the
 * training file and then evaluates on {@code src/iris.data}.
 *
 * <p>Not thread-safe: training and testing share mutable per-sample buffers.
 */
public class Mybp{
    /** Evaluation file used by {@link #test()}. */
    private static final String DEFAULT_TEST_FILE = "src/iris.data";

    private String fileURL;          // training data file
    private int inputnodenumber;     // number of input features
    private int hidenodenumber;      // number of hidden neurons
    private int outputnodenumber;    // number of output neurons (one per class)
    private double[] inputdata;      // input-layer values of the current sample
    private double[] hidedata;       // hidden-layer activations of the current sample
    private double[] outputdata;     // output-layer activations of the current sample
    private double[] rightoutput;    // one-hot target vector of the current sample
    private double[][] deltaw0;      // last update applied to w0 (momentum memory)
    private double[][] deltaw1;      // last update applied to w1 (momentum memory)
    private double e=1.0;            // fixed threshold subtracted from every neuron's net input
    private double a=0.04;           // momentum coefficient
    private double n=0.5;            // learning rate
    HashMap<String, Integer> map;    // class name -> output-neuron index
    private String[] classname;      // output-neuron index -> class name
    private int stepmax=1000;        // number of training epochs
    private double[][] w0;           // input->hidden weights, indexed w0[i][j]
    private double[][] w1;           // hidden->output weights, indexed w1[j][k]

    /**
     * Creates an untrained network; call {@link #initialize()} before training.
     *
     * @param fileURL path of the training data file
     * @param hidenodenumber number of hidden-layer neurons (must be positive)
     * @throws IllegalArgumentException if {@code hidenodenumber} is not positive
     */
    public Mybp(String fileURL,int hidenodenumber){
        if(hidenodenumber<1)
            throw new IllegalArgumentException("hidenodenumber must be positive: "+hidenodenumber);
        this.fileURL = fileURL;
        this.hidenodenumber = hidenodenumber;
    }

    /**
     * Allocates all buffers, builds the class-name lookup and assigns random
     * initial weights drawn from {@link Math#random()} (uniform in [0, 1)).
     * Also resets the momentum memory to zero.
     */
    public void initialize(){
        classname=new String[]{"Iris-setosa","Iris-versicolor","Iris-virginica"};
        inputnodenumber = 4;
        outputnodenumber = classname.length;
        w0=new double[inputnodenumber][hidenodenumber];
        w1=new double[hidenodenumber][outputnodenumber];
        inputdata=new double[inputnodenumber];
        hidedata=new double[hidenodenumber];
        outputdata=new double[outputnodenumber];
        rightoutput=new double[outputnodenumber];
        deltaw0=new double[inputnodenumber][hidenodenumber];
        deltaw1=new double[hidenodenumber][outputnodenumber];
        map=new HashMap<String,Integer>();
        for(int k=0;k<classname.length;++k){
            map.put(classname[k], k);
        }
        randomize(w0);
        randomize(w1);
    }

    /** Fills every element of a weight matrix, row by row, with {@link Math#random()}. */
    private static void randomize(double[][] weights){
        for(double[] row : weights){
            for(int c=0;c<row.length;++c){
                row[c]=Math.random();
            }
        }
    }

    /** Logistic activation function {@code 1 / (1 + e^-x)}. */
    public double sigmoid(double x)
    {
        return 1/(1+Math.exp(-x));
    }

    /**
     * Trains the network with per-sample (online) gradient descent plus
     * momentum for {@code stepmax} epochs over the training file.
     *
     * <p>The file is read and parsed once; every whitespace-separated token is
     * one comma-separated sample.
     *
     * @throws FileNotFoundException if the training file cannot be opened
     * @throws IllegalArgumentException if a sample is malformed or carries an unknown class label
     */
    public void bptraining() throws FileNotFoundException{
        List<String> records = readRecords(fileURL);
        int count = records.size();
        double[][] inputs = new double[count][];
        int[] labels = new int[count];
        for(int r=0;r<count;++r){
            labels[r]=parseRecord(records.get(r));
            inputs[r]=inputdata.clone();
        }
        for(int step=1;step<=stepmax;++step){
            for(int r=0;r<count;++r){
                System.arraycopy(inputs[r],0,inputdata,0,inputnodenumber);
                for(int k=0;k<outputnodenumber;++k){
                    rightoutput[k]=(k==labels[r]) ? 1 : 0;
                }
                forward();
                backpropagate();
            }
        }
    }

    /**
     * Propagates {@code inputdata} through the network, filling
     * {@code hidedata} and {@code outputdata}.
     */
    private void forward(){
        for(int j=0;j<hidenodenumber;++j){
            double net=-e;
            for(int i=0;i<inputnodenumber;++i){
                net+=w0[i][j]*inputdata[i];
            }
            hidedata[j]=sigmoid(net);
        }
        for(int k=0;k<outputnodenumber;++k){
            double net=-e;
            for(int j=0;j<hidenodenumber;++j){
                net+=w1[j][k]*hidedata[j];
            }
            outputdata[k]=sigmoid(net);
        }
    }

    /**
     * Applies one gradient step (squared error, sigmoid units) for the current
     * sample. Each update is {@code n * delta * activation + a * previousUpdate},
     * i.e. classic momentum; the update is remembered for the next sample.
     */
    private void backpropagate(){
        double[] outdelta=new double[outputnodenumber];
        for(int k=0;k<outputnodenumber;++k){
            outdelta[k]=(rightoutput[k]-outputdata[k])*(1-outputdata[k])*outputdata[k];
        }
        // Hidden deltas must be computed from w1 before w1 is updated below.
        double[] hidedelta=new double[hidenodenumber];
        for(int j=0;j<hidenodenumber;++j){
            double sum=0;
            for(int k=0;k<outputnodenumber;++k){
                sum+=outdelta[k]*w1[j][k];
            }
            hidedelta[j]=sum*(1-hidedata[j])*hidedata[j];
        }
        for(int j=0;j<hidenodenumber;++j){
            for(int k=0;k<outputnodenumber;++k){
                deltaw1[j][k]=n*outdelta[k]*hidedata[j]+a*deltaw1[j][k];
                w1[j][k]+=deltaw1[j][k];
            }
        }
        for(int i=0;i<inputnodenumber;++i){
            for(int j=0;j<hidenodenumber;++j){
                deltaw0[i][j]=n*hidedelta[j]*inputdata[i]+a*deltaw0[i][j];
                w0[i][j]+=deltaw0[i][j];
            }
        }
    }

    /**
     * Evaluates the trained network on {@code src/iris.data}, printing one
     * line per sample followed by summary statistics.
     *
     * @throws FileNotFoundException if the test file cannot be opened
     */
    public void test() throws FileNotFoundException{
        test(DEFAULT_TEST_FILE);
    }

    /**
     * Evaluates the trained network on the given file. For each sample, prints
     * the record, the predicted class and "right"/"wrong". Then prints the
     * total count, the number correct and the accuracy.
     *
     * @param testFileURL path of the evaluation data file
     * @return fraction of correctly classified samples ({@code NaN} if the file has no samples)
     * @throws FileNotFoundException if the test file cannot be opened
     * @throws IllegalArgumentException if a sample is malformed or carries an unknown class label
     */
    public double test(String testFileURL) throws FileNotFoundException{
        int total=0;
        int right=0;
        for(String record : readRecords(testFileURL)){
            int expected=parseRecord(record);
            forward();
            int classid=predictedClass();
            System.out.print( record+"--"+classname[classid]+" ");
            if(classid==expected){
                System.out.println("right");
                right++;
            }
            else System.out.println("wrong");
            total++;
        }
        double accuracy=(double)right/total;
        System.out.println();
        System.out.println("测试数据总数"+total);
        System.out.println("测试正确的数量"+right);
        System.out.println("测试集准确率"+accuracy);
        return accuracy;
    }

    /** Returns the index of the largest output activation (the lowest index wins ties). */
    private int predictedClass(){
        int best=0;
        for(int k=1;k<outputnodenumber;++k){
            if(outputdata[best]<outputdata[k])
                best=k;
        }
        return best;
    }

    /**
     * Reads every whitespace-separated token of a file, closing the file afterwards.
     *
     * @throws FileNotFoundException if the file cannot be opened
     */
    private static List<String> readRecords(String path) throws FileNotFoundException{
        List<String> records=new ArrayList<String>();
        Scanner jin=new Scanner(new File(path));
        try{
            while(jin.hasNext()){
                records.add(jin.next());
            }
        }finally{
            jin.close();
        }
        return records;
    }

    /**
     * Parses one {@code f1,...,fN,className} record into {@code inputdata}.
     * Fields after the class name are ignored.
     *
     * @return the output-neuron index of the record's class
     * @throws IllegalArgumentException if the record has too few fields, a
     *         non-numeric feature or an unknown class label
     */
    private int parseRecord(String record){
        String[] fields=record.split(",");
        if(fields.length<=inputnodenumber)
            throw new IllegalArgumentException("Malformed sample, expected "+(inputnodenumber+1)
                    +" comma-separated fields: "+record);
        for(int i=0;i<inputnodenumber;++i){
            inputdata[i]=Double.parseDouble(fields[i]);
        }
        Integer classid=map.get(fields[inputnodenumber]);
        if(classid==null)
            throw new IllegalArgumentException("Unknown class label '"+fields[inputnodenumber]
                    +"' in sample: "+record);
        return classid;
    }

    /**
     * Runs the full pipeline: {@link #initialize()}, {@link #bptraining()}, {@link #test()}.
     *
     * @throws FileNotFoundException if the training or test file cannot be opened
     */
    public void start() throws FileNotFoundException{
        initialize();
        bptraining();
        test();
    }
}