package edu.hitsz.c102c.cnn;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import edu.hitsz.c102c.cnn.Layer.Size;
import edu.hitsz.c102c.dataset.Dataset;
import edu.hitsz.c102c.dataset.Dataset.Record;
import edu.hitsz.c102c.util.ConcurenceRunner.TaskManager;
import edu.hitsz.c102c.util.Log;
import edu.hitsz.c102c.util.Util;
import edu.hitsz.c102c.util.Util.Operator;
/**
 * Convolutional neural network, used here for three-class sentiment
 * analysis of Weibo (microblog) posts.
 */
public class CNN implements Serializable {

private static final long serialVersionUID = 337920299147929932L;
// learning rate (annealed during training)
private static double ALPHA = 0.85;
// L2 weight-decay coefficient (0 disables weight decay)
protected static final double LAMBDA = 0;
// the layers of the network
private List<Layer> layers;
// number of layers
private int layerNum;
// size of a mini-batch update
private int batchSize;
// divisor operator: divides every element of a matrix by batchSize
private Operator divide_batchSize;
// multiplier operator: multiplies every element of a matrix by ALPHA
private Operator multiply_alpha;
// multiplier operator: multiplies every element of a matrix by 1 - LAMBDA * ALPHA
private Operator multiply_lambda;
/**
 * Initialize the network.
 *
 * @param layerBuilder
 *            the stacked layers of the network
 * @param batchSize
 *            number of records per mini-batch update
 */
public CNN(LayerBuilder layerBuilder, final int batchSize) {
layers = layerBuilder.mLayers;
layerNum = layers.size();
this.batchSize = batchSize;
setup(batchSize);
initOperator();
}
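// A minimal construction sketch, assuming the static factory methods of the
// companion Layer class (buildInputLayer, buildConvLayer, buildSampLayer,
// buildOutputLayer) from the original repository:
//
// LayerBuilder builder = new LayerBuilder();
// builder.addLayer(Layer.buildInputLayer(new Size(28, 28)));
// builder.addLayer(Layer.buildConvLayer(6, new Size(5, 5)));
// builder.addLayer(Layer.buildSampLayer(new Size(2, 2)));
// builder.addLayer(Layer.buildOutputLayer(3));// three sentiment classes
// CNN cnn = new CNN(builder, 50);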
/**
 * Initialize the element-wise operators.
 */
private void initOperator() {
divide_batchSize = new Operator() {
private static final long serialVersionUID = 7424011281732651055L;
@Override
public double process(double value) {
return value / batchSize;
}
};
multiply_alpha = new Operator() {
private static final long serialVersionUID = 5761368499808006552L;
@Override
public double process(double value) {
return value * ALPHA;
}
};
multiply_lambda = new Operator() {
private static final long serialVersionUID = 4499087728362870577L;
@Override
public double process(double value) {
return value * (1 - LAMBDA * ALPHA);
}
};
}
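// Taken together, these operators implement a mini-batch SGD step with L2
// weight decay; updateKernels below applies them as
//
//     k <- k * (1 - LAMBDA * ALPHA) + ALPHA * (sum of batch deltas) / batchSize
//
// With LAMBDA = 0 the decay term vanishes and this reduces to plain SGD.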
/**
 * Train the network on the training set.
 *
 * @param trainset
 * @param repeat
 *            number of iterations over the training set
 */
public void train(Dataset trainset, int repeat) {
// listen on stdin for the stop command
new Listener().start();
for (int t = 0; t < repeat && !stopTrain.get(); t++) {
int epochsNum = trainset.size() / batchSize;
if (trainset.size() % batchSize != 0)
epochsNum++;// one extra batch to cover the remainder, i.e. ceiling division
Log.i("");
Log.i(t + "th iter, batches per epoch: " + epochsNum);
int right = 0;
int count = 0;
for (int i = 0; i < epochsNum; i++) {
int[] randPerm = Util.randomPerm(trainset.size(), batchSize);
Layer.prepareForNewBatch();
for (int index : randPerm) {
boolean isRight = train(trainset.getRecord(index));
if (isRight)
right++;
count++;
Layer.prepareForNewRecord();
}
// update the weights after finishing a batch
updateParas();
if (i % 50 == 0) {
System.out.print("..");
if (i + 50 > epochsNum)
System.out.println();
}
}
double p = 1.0 * right / count;
if (t % 10 == 1 && p > 0.96) {// dynamically anneal the learning rate
ALPHA = 0.001 + ALPHA * 0.9;
Log.i("Set alpha = " + ALPHA);
}
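// The annealing rule ALPHA <- 0.001 + 0.9 * ALPHA has fixed point
// 0.001 / (1 - 0.9) = 0.01, so repeated triggers decay the learning
// rate geometrically from its initial 0.85 toward 0.01.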
Log.i("precision " + right + "/" + count + "=" + p);
}
}
private static AtomicBoolean stopTrain = new AtomicBoolean(false);

static class Listener extends Thread {
Listener() {
setDaemon(true);
stopTrain.set(false);// reset in case an earlier run was stopped
}
@Override
public void run() {
System.out.println("Input & to stop train.");
while (true) {
try {
int a = System.in.read();
if (a == '&') {
stopTrain.compareAndSet(false, true);
break;
}
} catch (IOException e) {
e.printStackTrace();
}
}
System.out.println("Lisenter stop");
}
}
/**
 * Test the network on a dataset.
 *
 * @param trainset
 * @return the fraction of records classified correctly
 */
public double test(Dataset trainset) {
Layer.prepareForNewBatch();
Iterator<Record> iter = trainset.iter();
int right = 0;
while (iter.hasNext()) {
Record record = iter.next();
forward(record);
Layer outputLayer = layers.get(layerNum - 1);
int mapNum = outputLayer.getOutMapNum();
double[] out = new double[mapNum];
for (int m = 0; m < mapNum; m++) {
double[][] outmap = outputLayer.getMap(m);
out[m] = outmap[0][0];
}
if (record.getLable().intValue() == Util.getMaxIndex(out))
right++;
}
double p = 1.0 * right / trainset.size();
Log.i("precision", p + "");
return p;
}
/**
 * Predict labels for a test set and write them to a file, one label per
 * line.
 *
 * @param testset
 * @param fileName
 */
public void predict(Dataset testset, String fileName) {
Log.i("begin predict");
try {
PrintWriter writer = new PrintWriter(new File(fileName));
Layer.prepareForNewBatch();
Iterator<Record> iter = testset.iter();
while (iter.hasNext()) {
Record record = iter.next();
forward(record);
Layer outputLayer = layers.get(layerNum - 1);
int mapNum = outputLayer.getOutMapNum();
double[] out = new double[mapNum];
for (int m = 0; m < mapNum; m++) {
double[][] outmap = outputLayer.getMap(m);
out[m] = outmap[0][0];
}
int label = Util.getMaxIndex(out);
writer.write(label + "\n");
}
writer.flush();
writer.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
Log.i("end predict");
}
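// A typical end-to-end flow (a sketch; how Dataset instances are loaded is
// assumed to follow the companion Dataset class, which is not shown here):
//
// CNN cnn = new CNN(builder, 50);
// cnn.train(trainset, 100);// 100 passes over the training set
// cnn.test(validset);
// cnn.predict(testset, "predict.txt");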
private boolean isSame(double[] output, double[] target) {
boolean r = true;
for (int i = 0; i < output.length; i++)
if (Math.abs(output[i] - target[i]) > 0.5) {
r = false;
break;
}
return r;
}
/**
 * Train on a single record and return whether it was predicted correctly.
 *
 * @param record
 * @return
 */
private boolean train(Record record) {
forward(record);
return backPropagation(record);
}
/*
 * back propagation
 */
private boolean backPropagation(Record record) {
boolean result = setOutLayerErrors(record);
setHiddenLayerErrors();
return result;
}
/**
 * Update the parameters. Only conv and output layers carry trainable
 * weights and biases; input and subsampling layers have none.
 */
private void updateParas() {
for (int l = 1; l < layerNum; l++) {
Layer layer = layers.get(l);
Layer lastLayer = layers.get(l - 1);
switch (layer.getType()) {
case conv:
case output:
updateKernels(layer, lastLayer);
updateBias(layer, lastLayer);
break;
default:
break;
}
}
}
/**
 * Update the bias of each output map of the layer.
 *
 * @param layer
 * @param lastLayer
 */
private void updateBias(final Layer layer, Layer lastLayer) {
final double[][][][] errors = layer.getErrors();
int mapNum = layer.getOutMapNum();
new TaskManager(mapNum) {
@Override
public void process(int start, int end) {
for (int j = start; j < end; j++) {
double[][] error = Util.sum(errors, j);
// update the bias
double deltaBias = Util.sum(error) / batchSize;
double bias = layer.getBias(j) + ALPHA * deltaBias;
layer.setBias(j, bias);
}
}
}.start();
}
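// In symbols, for each output map j:
//
//     b_j <- b_j + ALPHA * sum(error_j) / batchSize
//
// where error_j is accumulated over every record in the batch and every
// position in the map, i.e. the batch-averaged delta scaled by ALPHA.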
/**
 * Update the convolution kernels (weights) of the layer; the bias is
 * handled separately in updateBias.
 *
 * @param layer
 *            the current layer
 * @param lastLayer
 *            the previous layer
 */
private void updateKernels(final Layer layer, final Layer lastLayer) {
int mapNum = layer.getOutMapNum();
final int lastMapNum = lastLayer.getOutMapNum();
new TaskManager(mapNum) {
@Override
public void process(int start, int end) {
for (int j = start; j < end; j++) {
for (int i = 0; i < lastMapNum; i++) {
// sum the kernel gradient (delta) over every record in the batch
double[][] deltaKernel = null;
for (int r = 0; r < batchSize; r++) {
double[][] error = layer.getError(r, j);
if (deltaKernel == null)
deltaKernel = Util.convnValid(lastLayer.getMap(r, i), error);
else {// accumulate over the batch
deltaKernel = Util.matrixOp(Util.convnValid(lastLayer.getMap(r, i), error), deltaKernel, null, null, Util.plus);
}
}
// average the accumulated gradient over the batch
deltaKernel = Util.matrixOp(deltaKernel, divide_batchSize);
// update the kernel: k <- k * (1 - LAMBDA * ALPHA) + ALPHA * deltaKernel
double[][] kernel = layer.getKernel(i, j);
deltaKernel = Util.matrixOp(kernel, deltaKernel, multiply_lambda, multiply_alpha, Util.plus);
layer.setKernel(i, j, deltaKernel);
}
}
}
}.start();
}