package com.bjtu.dd.approach;
import com.bjtu.dd.util.MyMath;
/**
 * Projected gradient descent (PGD) for fitting a non-negative weight vector of length
 * {@code cycle}.
 *
 * <p>The objective (see {@link #function(double[], int)}) is the sum, over every training
 * instance and every sliding-window start {@code j}, of the squared residual
 * {@code ||X_j * w - y_j||^2}, plus the L2 penalty {@code lamda * ||w||^2}. {@code X_j} is a
 * {@code length_vari x cycle} slice and {@code y_j} a length-{@code length_vari} target, both
 * extracted through {@code MyMath} helpers. NOTE(review): the exact slicing semantics of those
 * helpers are not visible from this file — confirm against {@code MyMath}.
 */
public class PGD {
    /** Training data; each {@code train[i]} is one 2-D instance sliced into windows. */
    public double[][][] train;
    /** Weight of the L2 penalty term {@code lamda * ||w||^2} in the objective. */
    public double lamda;
    /** Maximum number of PGD iterations performed by {@link #runModel()}. */
    public int maxIter;
    /**
     * Convergence tolerance on the absolute change of the objective between iterations.
     * A non-positive value selects the default of {@code 1e-4}.
     */
    public double fTol;
    /** Length of the weight vector, i.e. the number of columns in each sliding window. */
    public int cycle;
    // Together with week_day, bounds the window starts: j + cycle < (week_amount - 1) * week_day.
    public final int week_amount = 7;
    // Not referenced in this class. NOTE(review): presumably consumed elsewhere — confirm.
    public static int predict_days = 7;
    public final int week_day = 7;
    /** Number of rows per window slice and length of each target/residual vector. */
    public final int length_vari = 10;

    /** Tolerance used by {@link #runModel()} when {@code fTol} is not positive. */
    private static final double DEFAULT_F_TOL = 1e-4;
    /** Armijo sufficient-decrease fraction for the backtracking line search. */
    private static final double BACKTRACK_ALPHA = 0.4;
    /** Step-shrink factor for the backtracking line search. */
    private static final double BACKTRACK_BETA = 0.8;

    /**
     * Creates a solver.
     *
     * @param train   training instances
     * @param lamda   L2 regularisation weight
     * @param cycle   length of the weight vector / sliding window
     * @param maxIter maximum number of iterations for {@link #runModel()}
     * @param fTol    convergence tolerance on the objective change; non-positive selects 1e-4
     */
    public PGD(double[][][] train, double lamda, int cycle, int maxIter, double fTol) {
        this.train = train;
        this.lamda = lamda;
        this.maxIter = maxIter;
        this.fTol = fTol;
        this.cycle = cycle;
    }

    /**
     * Runs projected gradient descent starting from an all-ones weight vector.
     *
     * <p>Each iteration takes a backtracking steepest-descent step, projects the trial point onto
     * the non-negative orthant, then runs a second backtracking search along the feasible
     * direction {@code P(trial) - weight}. The loop stops when the objective changes by less than
     * the tolerance, or after {@code maxIter} iterations.
     *
     * @return the fitted weight vector of length {@code cycle}
     */
    public double[] runModel() {
        double[] weight = new double[cycle];
        for (int i = 0; i < cycle; i++) {
            weight[i] = 1;
        }
        // Honour the caller-supplied tolerance; only a non-positive value falls back to 1e-4.
        double tol = fTol > 0 ? fTol : DEFAULT_F_TOL;
        double func = function(weight, cycle);
        boolean converged = false;
        for (int iter = 1; iter <= maxIter && !converged; iter++) {
            System.out.println("iterator number: " + iter);
            double preFunc = func;
            double[] grad = gradient(weight, cycle);
            // 1) Unconstrained steepest-descent trial step.
            double[] direction = MyMath.numMultiplyVector(grad, cycle, -1);
            double t = linearSearch_backtrack(direction, cycle, weight, cycle, preFunc, grad);
            double[] trial = MyMath.vectorPlusVector(weight, cycle,
                    MyMath.numMultiplyVector(direction, cycle, t), cycle);
            // 2) Project and search along the feasible direction. The current point is
            //    unchanged, so f and grad evaluated above are reused instead of recomputed.
            direction = MyMath.vectorMinusVector(projection(trial, cycle), cycle, weight, cycle);
            t = linearSearch_backtrack(direction, cycle, weight, cycle, preFunc, grad);
            weight = MyMath.vectorPlusVector(weight, cycle,
                    MyMath.numMultiplyVector(direction, cycle, t), cycle);
            func = function(weight, cycle);
            if (Math.abs(func - preFunc) < tol) {
                System.out.print("function value get it!");
                converged = true;
            }
        }
        System.out.println(func);
        // Report only genuine non-convergence, not convergence on the last allowed iteration.
        if (!converged) {
            System.out.println("exceed the maxIter!");
        }
        return weight;
    }

    /**
     * Evaluates the objective at {@code weight}.
     *
     * @param weight weight vector of length {@code cycle}
     * @param len_1  unused; kept for interface compatibility (the field {@code cycle} is used)
     * @return sum of squared window residuals plus {@code lamda * ||weight||^2}
     */
    public double function(double[] weight, int len_1) {
        double func = 0;
        int limit = windowLimit();
        // Data term: squared residual of every sliding window of every instance.
        for (double[][] instance : train) {
            for (int j = 0; j + cycle < limit; j++) {
                double[][] input_x = windowInput(instance, j);
                double[] residual = windowResidual(instance, input_x, j, weight);
                func += MyMath.vectorMultiplyVector(residual, length_vari, residual, length_vari);
            }
        }
        // Regularisation term.
        func += lamda * MyMath.vectorMultiplyVector(weight, cycle, weight, cycle);
        return func;
    }

    /**
     * Evaluates the gradient of {@link #function(double[], int)} at {@code weight}.
     *
     * @param weight weight vector of length {@code cycle}
     * @param len_1  unused; kept for interface compatibility (the field {@code cycle} is used)
     * @return gradient vector of length {@code cycle}
     */
    public double[] gradient(double[] weight, int len_1) {
        double[] grad = new double[cycle];
        int limit = windowLimit();
        for (double[][] instance : train) {
            for (int j = 0; j + cycle < limit; j++) {
                double[][] input_x = windowInput(instance, j);
                double[] residual = windowResidual(instance, input_x, j, weight);
                // d/dw ||X w - y||^2 = 2 * (X w - y)^T X
                double[] term = MyMath.vectorMultiplyMatrix(residual, length_vari, input_x,
                        length_vari, cycle);
                term = MyMath.numMultiplyVector(term, cycle, 2);
                grad = MyMath.vectorPlusVector(grad, cycle, term, cycle);
            }
        }
        // d/dw (lamda * ||w||^2) = 2 * lamda * w. This must match the penalty in function():
        // both line searches compare f against a slope derived from this gradient.
        return MyMath.vectorPlusVector(grad, grad.length,
                MyMath.numMultiplyVector(weight, weight.length, 2 * lamda), weight.length);
    }

    /** Exclusive upper bound on {@code j + cycle} for sliding-window starts {@code j}. */
    private int windowLimit() {
        return (week_amount - 1) * week_day;
    }

    /** Slices the {@code length_vari x cycle} input matrix for the window starting at {@code j}. */
    private double[][] windowInput(double[][] instance, int j) {
        return MyMath.split2DimenMatrix2Dimen(instance, 0, length_vari - 1, j, j + cycle - 1);
    }

    /** Computes {@code input_x * weight - target} for the window starting at {@code j}. */
    private double[] windowResidual(double[][] instance, double[][] input_x, int j,
            double[] weight) {
        double[] target = MyMath.split2DimenMatrix1Dimen(instance, length_vari, j + cycle);
        double[] prediction = MyMath.matrixMultiplytVector(input_x, length_vari, cycle, weight,
                cycle);
        return MyMath.vectorMinusVector(prediction, length_vari, target, length_vari);
    }

    /**
     * Euclidean projection onto the non-negative orthant: each component is clamped at 0.
     *
     * @param weight point to project
     * @param len_1  expected length of {@code weight}
     * @return a new array holding {@code max(0, weight[i])} for each component
     * @throws IllegalArgumentException if {@code weight.length != len_1}
     */
    public double[] projection(double[] weight, int len_1) {
        long begin = System.currentTimeMillis();
        if (weight.length != len_1) {
            throw new IllegalArgumentException(
                    "projection error! expected length " + len_1 + " but got " + weight.length);
        }
        double[] new_weight = new double[len_1];
        for (int i = 0; i < len_1; i++) {
            // max(0, w) is the nearest feasible point. abs(w) would reflect negative components
            // and could turn P(x - t*g) - x into an ascent direction.
            new_weight[i] = Math.max(0.0, weight[i]);
        }
        long end = System.currentTimeMillis();
        System.out.println("projection time: " + (end - begin));
        return new_weight;
    }

    /**
     * Bracketing/bisection line search for a step satisfying the weak Wolfe conditions.
     *
     * @param direction search direction
     * @param len_1     length of {@code direction}
     * @param weight    current point
     * @param len_2     length of {@code weight}
     * @return the selected step size
     */
    public double linearSearch_wolfe(double[] direction, int len_1, double[] weight, int len_2) {
        long begin = System.currentTimeMillis();
        double c1 = 0.01;
        double c2 = 0.9;
        double tol = 0.001;
        double low = 0;
        double high = 1000;
        double t = 1;
        double fx = function(weight, len_2);
        double[] dfx = gradient(weight, len_2);
        // Directional derivative at t = 0; invariant across the search, so computed once.
        double slope = MyMath.vectorMultiplyVector(dfx, cycle, direction, cycle);
        while (true) {
            double[] weight_direction = MyMath.vectorPlusVector(weight, len_2,
                    MyMath.numMultiplyVector(direction, len_1, t), len_1);
            // Wolfe condition 1 (sufficient decrease) fails: the step is too long.
            if (function(weight_direction, len_2) > fx + slope * c1 * t) {
                high = t;
                t = (low + high) / 2;
            } else if (MyMath.vectorMultiplyVector(gradient(weight_direction, cycle), cycle,
                    direction, cycle) < c2 * slope) {
                // Wolfe condition 2 (curvature) fails: the step is too short. Double while high
                // is still above the 600 heuristic threshold (initial bracket is 1000), else bisect.
                low = t;
                t = high > 600 ? 2 * low : (low + high) / 2;
            } else {
                break;
            }
            if (high - low < tol) {
                break;
            }
        }
        long end = System.currentTimeMillis();
        System.out.println("linearSearch_wolfe time: " + (end - begin));
        return t;
    }

    /**
     * Armijo backtracking line search starting from {@code t = 1}.
     *
     * @param direction search direction
     * @param len_1     length of {@code direction}
     * @param weight    current point
     * @param len_2     length of {@code weight}
     * @return the first step {@code t = 0.8^k} satisfying the sufficient-decrease condition
     */
    public double linearSearch_backtrack(double[] direction, int len_1, double[] weight,
            int len_2) {
        return linearSearch_backtrack(direction, len_1, weight, len_2,
                function(weight, len_2), gradient(weight, len_2));
    }

    /**
     * Backtracking search that reuses an already-computed {@code f(weight)} and gradient.
     *
     * @param fx  objective value at {@code weight}
     * @param dfx gradient at {@code weight}
     */
    private double linearSearch_backtrack(double[] direction, int len_1, double[] weight,
            int len_2, double fx, double[] dfx) {
        // Directional derivative grad f(x) . d; constant for the whole search.
        double slope = MyMath.vectorMultiplyVector(dfx, cycle, direction, cycle);
        double t = 1;
        while (true) {
            double[] candidate = MyMath.vectorPlusVector(weight, len_2,
                    MyMath.numMultiplyVector(direction, len_1, t), len_1);
            // Shrink until f(x + t d) <= f(x) + alpha * t * slope (Armijo condition).
            if (function(candidate, len_2) > fx + slope * BACKTRACK_ALPHA * t) {
                t *= BACKTRACK_BETA;
            } else {
                return t;
            }
        }
    }
}
// "评论0" (Comments: 0) — residue from the web page this code was copied from; not part of the program.