package com.omega.engine.service.impl;
import java.io.File;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import com.omega.common.data.Tensor;
import com.omega.common.data.utils.DataTransforms;
import com.omega.common.utils.DataLoader;
import com.omega.common.utils.ImageUtils;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.LabelUtils;
import com.omega.common.utils.MathUtils;
import com.omega.engine.controller.TrainTask;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.CrossEntropyLoss;
import com.omega.engine.loss.LossType;
import com.omega.engine.loss.SoftmaxWithCrossEntropyLoss;
import com.omega.engine.nn.data.DataSet;
import com.omega.engine.nn.layer.AVGPoolingLayer;
import com.omega.engine.nn.layer.BasicBlockLayer;
import com.omega.engine.nn.layer.ConvolutionLayer;
import com.omega.engine.nn.layer.DropoutLayer;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.PoolingLayer;
import com.omega.engine.nn.layer.SoftmaxWithCrossEntropyLayer;
import com.omega.engine.nn.layer.active.LeakyReluLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.model.NetworkInit;
import com.omega.engine.nn.network.BPNetwork;
import com.omega.engine.nn.network.CNN;
import com.omega.engine.optimizer.MBSGDOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.pooling.PoolingType;
import com.omega.engine.service.BusinessService;
import com.omega.engine.updater.UpdaterType;
@Service
public class BusinessServiceImpl implements BusinessService {
/*
* @Autowired private NetworksDataBase dataBase;
*/
/**
 * Trains and then evaluates a small fully connected classifier on the Iris text data set.
 *
 * <p>Network shape: 4 inputs -> 40 (ReLU) -> 20 (ReLU) -> 2, followed by a softmax +
 * cross-entropy output layer, optimized with the Adam updater via mini-batch SGD.
 * Any exception raised while building the optimizer, training or testing is printed and
 * swallowed; CUDA memory is released afterwards in every case.
 */
@Override
public void bpNetwork_iris() {
    // Load the training and test splits. The files are comma separated; the numeric loader
    // arguments presumably describe a 1x1x4 sample with 2 label classes -- confirm against
    // DataLoader.loalDataByTxt.
    final String trainFile = "H:/dataset\\iris\\iris.txt";
    final String testFile = "H:/dataset\\iris\\iris_test.txt";
    final String[] labels = {"1", "-1"};
    DataSet train = DataLoader.loalDataByTxt(trainFile, ",", 1, 1, 4, 2, labels);
    DataSet test = DataLoader.loalDataByTxt(testFile, ",", 1, 1, 4, 2, labels);
    // Debug dump of the whole training set.
    System.out.println("train_data:" + JsonUtils.toJson(train));

    BPNetwork network = new BPNetwork(new SoftmaxWithCrossEntropyLoss(), UpdaterType.adam);

    // Layers are created first and then registered, in input-to-output order.
    InputLayer input = new InputLayer(1, 1, 4);
    FullyLayer fc1 = new FullyLayer(4, 40);
    ReluLayer relu1 = new ReluLayer();
    FullyLayer fc2 = new FullyLayer(40, 20);
    ReluLayer relu2 = new ReluLayer();
    FullyLayer fc3 = new FullyLayer(20, 2);
    SoftmaxWithCrossEntropyLayer output = new SoftmaxWithCrossEntropyLayer(2);

    network.addLayer(input);
    network.addLayer(fc1);
    network.addLayer(relu1);
    network.addLayer(fc2);
    network.addLayer(relu2);
    network.addLayer(fc3);
    network.addLayer(output);

    try {
        // Positional arguments presumably: epochs=10, learning rate=1e-5, batch size=10,
        // no learning-rate schedule -- confirm against MBSGDOptimizer.
        MBSGDOptimizer trainer = new MBSGDOptimizer(network, 10, 0.00001f, 10, LearnRateUpdate.NONE, false);
        trainer.train(train);
        trainer.test(test);
    } catch (Exception ex) {
        ex.printStackTrace();
    } finally {
        // Always release GPU buffers, even when training failed.
        try {
            CUDAMemoryManager.freeAll();
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}
/**
 * Asynchronously trains and evaluates a fully connected network (with batch normalization)
 * on the MNIST idx/ubyte files, printing train, test and total wall-clock times.
 *
 * <p>The directory holding the four MNIST files defaults to the original hard-coded
 * location and can be overridden with the {@code omega.dataset.mnist} system property,
 * e.g. {@code -Domega.dataset.mnist=/data/mnist}, so the job is not tied to one machine.
 * Any exception raised while building the optimizer, training or testing is printed and
 * swallowed; CUDA memory is released afterwards in every case.
 */
@Override
@Async
public void bpNetwork_mnist() {
    // Resolve the dataset directory. A missing or blank property falls back to the original path.
    String mnistDir = System.getProperty("omega.dataset.mnist");
    if (mnistDir == null || mnistDir.trim().isEmpty()) {
        mnistDir = "C:\\Users\\Administrator\\Desktop\\dataset\\mnist";
    }
    String mnist_train_data = new File(mnistDir, "train-images.idx3-ubyte").getPath();
    String mnist_train_label = new File(mnistDir, "train-labels.idx1-ubyte").getPath();
    String mnist_test_data = new File(mnistDir, "t10k-images.idx3-ubyte").getPath();
    String mnist_test_label = new File(mnistDir, "t10k-labels.idx1-ubyte").getPath();

    String[] labelSet = new String[] {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"};
    DataSet trainData = DataLoader.loadDataByUByte(mnist_train_data, mnist_train_label, labelSet, 1, 1, 784, true);
    DataSet testData = DataLoader.loadDataByUByte(mnist_test_data, mnist_test_label, labelSet, 1, 1, 784, true);

    BPNetwork netWork = new BPNetwork(new SoftmaxWithCrossEntropyLoss(), UpdaterType.adam);
    netWork.learnRate = 0.001f;

    // Hidden width heuristic: sqrt(784 inputs + 10 outputs) + 10, truncated to int (= 38).
    int inputCount = (int) (Math.sqrt(784 + 10) + 10);

    // Three FC -> BN -> ReLU stages followed by a 10-way output. The `false` argument on the
    // hidden FullyLayers presumably disables the bias term (BN follows) -- confirm in FullyLayer.
    InputLayer inputLayer = new InputLayer(1, 1, 784);
    FullyLayer hidden1 = new FullyLayer(784, inputCount, false);
    BNLayer bn1 = new BNLayer();
    ReluLayer active1 = new ReluLayer();
    FullyLayer hidden2 = new FullyLayer(inputCount, inputCount, false);
    BNLayer bn2 = new BNLayer();
    ReluLayer active2 = new ReluLayer();
    FullyLayer hidden3 = new FullyLayer(inputCount, inputCount, false);
    BNLayer bn3 = new BNLayer();
    ReluLayer active3 = new ReluLayer();
    FullyLayer hidden4 = new FullyLayer(inputCount, 10);
    SoftmaxWithCrossEntropyLayer softmax = new SoftmaxWithCrossEntropyLayer(10);

    netWork.addLayer(inputLayer);
    netWork.addLayer(hidden1);
    netWork.addLayer(bn1);
    netWork.addLayer(active1);
    netWork.addLayer(hidden2);
    netWork.addLayer(bn2);
    netWork.addLayer(active2);
    netWork.addLayer(hidden3);
    netWork.addLayer(bn3);
    netWork.addLayer(active3);
    netWork.addLayer(hidden4);
    netWork.addLayer(softmax);

    try {
        // Positional arguments presumably: epochs=10, learning rate=1e-3, batch size=128,
        // no learning-rate schedule -- confirm against MBSGDOptimizer.
        MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 10, 0.001f, 128, LearnRateUpdate.NONE, false);
        long start = System.nanoTime();
        long trainTime = System.nanoTime();
        optimizer.train(trainData);
        System.out.println("trainTime:" + ((System.nanoTime() - trainTime) / 1e9) + "s.");
        long testTime = System.nanoTime();
        optimizer.test(testData);
        System.out.println("testTime:" + ((System.nanoTime() - testTime) / 1e9) + "s.");
        System.out.println(((System.nanoTime() - start) / 1e9) + "s.");
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        // Always release GPU buffers, even when training failed.
        try {
            CUDAMemoryManager.freeAll();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
@Override
public void cnnNetwork_mnist_demo() {
// TODO Auto-generated method stub
/**
* 读取训练数据集
*/
String mnist_train_data = "C:\\Users\\Administrator\\Desktop\\dataset\\mnist\\train-images.idx3-ubyte";
String mnist_train_label = "C:\\Users\\Administrator\\Desktop\\dataset\\mnist\\train-labels.idx1-ubyte";
String mnist_test_data = "C:\\Users\\Administrator\\Desktop\\dataset\\mnist\\t10k-images.idx3-ubyte";
String mnist_test_label = "C:\\Users\\Administrator\\Desktop\\dataset\\mnist\\t10k-labels.idx1-ubyte";
String[] labelSet = new String[] {"0","1","2","3","4","5","6","7","8","9"};
DataSet trainData = DataLoader.loadDataByUByte(mnist_train_data, mnist_train_label, labelSet, 1, 1 ,784,true);
DataSet testData = DataLoader.loadDataByUByte(mnist_test_data, mnist_test_label, labelSet, 1, 1 ,784,true);
int channel = 1;
int height = 28;
int width = 28;
CNN netWork = new CNN(new CrossEntropyLoss(), UpdaterType.momentum);
netWork.learnRate = 0.1f;
InputLayer inputLayer = new InputLayer(channel, 1, 784);
ConvolutionLayer conv1 = new ConvolutionLayer(channel, 6, width, height, 5, 5, 2, 1);
BNLayer bn1 = new BNLayer();
Relu
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
Omega-AI：基于 Java 打造的深度学习框架，帮助你快速搭建神经网络，实现模型的训练或测试，并支持多线程运算。框架目前支持 BP 神经网络、卷积神经网络、VGG16、ResNet、YOLO 等模型的构建。
资源推荐
资源详情
资源评论
收起资源包目录
基于 Java 打造的深度学习框架（889 个子文件）
test_batch.bin 29.31MB
data_batch_1.bin 29.31MB
data_batch_3.bin 29.31MB
data_batch_5.bin 29.31MB
data_batch_2.bin 29.31MB
data_batch_4.bin 29.31MB
test_batch.bin 29.31MB
data_batch_1.bin 29.31MB
data_batch_3.bin 29.31MB
data_batch_5.bin 29.31MB
data_batch_2.bin 29.31MB
data_batch_4.bin 29.31MB
test_batch.bin 29.31MB
data_batch_1.bin 29.31MB
data_batch_3.bin 29.31MB
data_batch_5.bin 29.31MB
data_batch_2.bin 29.31MB
data_batch_4.bin 29.31MB
test_batch.bin 29.31MB
data_batch_1.bin 29.31MB
data_batch_3.bin 29.31MB
data_batch_5.bin 29.31MB
data_batch_2.bin 29.31MB
data_batch_4.bin 29.31MB
yolov1.cfg 4KB
yolov1-tiny.cfg 2KB
styles.css 6KB
style.css 2KB
demo.css 1KB
BNKernel.cu 17KB
BNKernel.cu 14KB
BNKernel2.cu 14KB
BNKernel3.cu 7KB
Im2colKernelTmp.cu 5KB
Im2colKernelTmp.cu 5KB
PoolingV2Kernel.cu 5KB
MathKernel.cu 4KB
MathKernel.cu 4KB
MathKernel2.cu 4KB
updater.cu 3KB
Col2imKernel.cu 3KB
Col2imKernel.cu 3KB
PoolingKernel.cu 3KB
PoolingKernel.cu 3KB
CrossEntropyKernel.cu 2KB
Im2colKernel.cu 2KB
Im2colKernel.cu 2KB
Im2colKernel.cu 2KB
SoftmaxKernel.cu 2KB
activeFunction.cu 2KB
SoftmaxKernel.cu 1KB
BiasKernel.cu 1KB
BaseKernel.cu 1KB
RNNKernel.cu 1KB
updater.cu 1KB
AVGPoolingKernel.cu 1005B
ShortcutKernel.cu 707B
JCudaVectorAddKernel.cu 218B
JCudaVectorAddKernel.cu 218B
JCudaVectorAddKernel.cu 218B
JCudaVectorAddKernel.cu 181B
JCudaVectorAddKernel.cu 181B
JCudaVectorAddKernel.cu 181B
test_cuda.dll 141KB
test_cuda.dll 141KB
test_cuda.dll 135KB
.gitignore 368B
MNIST_train-images-idx3-ubyte.gz 9.45MB
MNIST_train-images-idx3-ubyte.gz 9.45MB
MNIST_train-images-idx3-ubyte.gz 9.45MB
MNIST_train-images-idx3-ubyte.gz 9.45MB
MNIST_t10k-images-idx3-ubyte.gz 1.57MB
MNIST_t10k-images-idx3-ubyte.gz 1.57MB
MNIST_t10k-images-idx3-ubyte.gz 1.57MB
MNIST_t10k-images-idx3-ubyte.gz 1.57MB
MNIST_train-labels-idx1-ubyte.gz 28KB
MNIST_train-labels-idx1-ubyte.gz 28KB
MNIST_train-labels-idx1-ubyte.gz 28KB
MNIST_train-labels-idx1-ubyte.gz 28KB
MNIST_t10k-labels-idx1-ubyte.gz 4KB
MNIST_t10k-labels-idx1-ubyte.gz 4KB
MNIST_t10k-labels-idx1-ubyte.gz 4KB
MNIST_t10k-labels-idx1-ubyte.gz 4KB
com_omega_engine_gpu_JNITest.h 1011B
com_omega_engine_gpu_JNITest.h 1011B
com_omega_engine_gpu_JNITest.h 1001B
AICar.html 22KB
AICar.html 22KB
AICar.html 22KB
AICar.html 22KB
AICarDefautMap.html 22KB
AICarDefautMap.html 22KB
AICarDefautMap.html 22KB
AICarDefautMap.html 22KB
origin2.html 17KB
origin2.html 17KB
origin2.html 17KB
origin2.html 17KB
MapEditor.html 14KB
MapEditor.html 14KB
共 889 条
- 1
- 2
- 3
- 4
- 5
- 6
- 9
资源评论
- lanpiao_872023-06-07超赞的资源,感谢资源主分享,大家一起进步!
Java程序员-张凯
- 粉丝: 1w+
- 资源: 6742
下载权益
C知道特权
VIP文章
课程特权
开通VIP
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功