package com.github.handong0123.tensorflow.deploy.session.model;
import com.github.handong0123.tensorflow.deploy.session.entity.ModelDataType;
import com.github.handong0123.tensorflow.deploy.session.entity.ModelInput;
import com.github.handong0123.tensorflow.deploy.session.entity.ModelOutput;
import com.github.handong0123.tensorflow.deploy.session.entity.ModelParam;
import com.google.common.primitives.Longs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import javax.annotation.PostConstruct;
import java.io.IOException;
import java.lang.reflect.Array;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
/**
* Tensorflow model service achieve
* <p>
* Provide some operations on the model
*
* @author handong
*/
public class TensorflowModelServiceImpl implements TensorflowModelService {
private static final Logger LOG = LoggerFactory.getLogger(TensorflowModelServiceImpl.class);
private static final String DEFAULT_GPU_ID = "-1";
private static final float DEFAULT_PER_GPU_MEMORY_FRACTION = 0.95f;
private String modelFile;
private String modelPath;
private String gpuId;
private float perGpuMemoryFraction;
private Session session;
private Graph graph;
public TensorflowModelServiceImpl(String modelFile, String modelPath) {
this(modelFile, modelPath, DEFAULT_GPU_ID, DEFAULT_PER_GPU_MEMORY_FRACTION);
}
public TensorflowModelServiceImpl(String modelFile, String modelPath, String gpuId) {
this(modelFile, modelPath, gpuId, DEFAULT_PER_GPU_MEMORY_FRACTION);
}
public TensorflowModelServiceImpl(String modelFile, String modelPath, String gpuId, float perGpuMemoryFraction) {
this.modelFile = modelFile;
this.modelPath = modelPath;
this.perGpuMemoryFraction = perGpuMemoryFraction;
this.gpuId = gpuId;
try {
this.init();
} catch (Exception e) {
e.printStackTrace();
}
}
@PostConstruct
public void init() throws IOException {
byte[] graphDef = Files.readAllBytes(Paths.get(this.modelPath, this.modelFile));
graph = new Graph();
if (DEFAULT_GPU_ID.equals(this.gpuId)) {
graph.importGraphDef(graphDef);
this.session = new Session(graph);
LOG.info("CPU:model init success,{}", Paths.get(this.modelPath, this.modelFile));
} else {
GPUOptions gpuOptions = GPUOptions.newBuilder()
.setVisibleDeviceList(this.gpuId)
.setPerProcessGpuMemoryFraction(this.perGpuMemoryFraction)
.setAllowGrowth(true)
.build();
ConfigProto configProto = ConfigProto.newBuilder()
.setGpuOptions(gpuOptions)
.setAllowSoftPlacement(true)
.build();
this.session = new Session(graph, configProto.toByteArray());
LOG.info("GPU:model init success,{}", Paths.get(this.modelPath, this.modelFile));
}
}
@Override
public ModelOutput predict(ModelInput modelInput) {
if (null == modelInput) {
return null;
}
ModelOutput outPut = new ModelOutput();
long startTime = System.currentTimeMillis();
List<Tensor> inputTensorList = new ArrayList<>();
List<Tensor<?>> outTensorList = new ArrayList<>();
try {
Session.Runner runner = this.session.runner();
for (ModelParam placeHolder : modelInput.getPlaceHolderInput()) {
Object data = placeHolder.getData();
Tensor tensorInput;
if (data instanceof String) {
tensorInput = Tensor.create(new byte[][]{((String) data).getBytes()});
} else if (data instanceof String[]) {
String[] originData = (String[]) data;
byte[][] res = new byte[originData.length][];
for (int i = 0; i < originData.length; i++) {
res[i] = originData[i].getBytes();
}
tensorInput = Tensor.create(res);
} else {
tensorInput = Tensor.create(data);
}
runner = runner.feed(placeHolder.getPlaceHolderName(), tensorInput);
inputTensorList.add(tensorInput);
}
List<String> tensorNameList = new ArrayList<>();
for (String name : modelInput.getExpectedOutput().keySet()) {
runner = runner.fetch(name);
tensorNameList.add(name);
}
outTensorList = runner.run();
LOG.info("Model Run Cost Time: {}", System.currentTimeMillis() - startTime);
if (tensorNameList.size() != outTensorList.size()) {
throw new Exception("Model Run Error: OutTensor Size Error");
}
for (int i = 0; i < outTensorList.size(); i++) {
Tensor tensor = outTensorList.get(i);
String outTensorName = tensorNameList.get(i);
int[] shape = Longs.asList(tensor.shape()).stream().mapToInt(Long::intValue).toArray();
ModelDataType type = modelInput.getExpectedOutput().get(outTensorName);
Object array = Array.newInstance(type.getType(), shape);
tensor.copyTo(array);
outPut.addOutput(outTensorName, array);
}
} catch (Exception e) {
e.printStackTrace();
} finally {
inputTensorList.forEach(Tensor::close);
outTensorList.forEach(Tensor::close);
}
return outPut;
}
@Override
public void modelReload() {
try {
Graph graph = new Graph();
byte[] graphDef = Files.readAllBytes(Paths.get(this.modelPath, this.modelFile));
graph.importGraphDef(graphDef);
Session session = new Session(graph);
synchronized (this) {
LOG.info("Start Model Reload...");
this.session.close();
this.session = session;
this.graph.close();
this.graph = graph;
LOG.info("Finish Model Reload");
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
简化与优化tensorflow模型的Java部署
版权申诉
131 浏览量
2024-02-06
09:42:04
上传
评论
收藏 21KB ZIP 举报
Java程序员-张凯
- 粉丝: 1w+
- 资源: 6727
最新资源
- 谷歌浏览器自动化测试版113.0.5672.0(包含linux,windows32/64,mac三个版本,不会自动更新)
- uniapp中tab切换,底部内容跟着移动,相反,底部移动,tab也跟着切换-组件
- 基于JS+TS实现跨平台3D相机控制器-附项目源码-优质项目分享.zip
- 跨相机-基于Rust实现的跨平台相机捕获-附项目源码-优质项目分享.zip
- odise 14离线安装包 大众斯柯达奥迪 5054 6153
- 网页设计期末作业-纯html加css+少量js-盗墓笔记旅游导航网站.rar
- 算法笔记模拟退火.rar
- MATLAB大数据仿真案例-蚁群算法(ACO)用于求解旅行商(TSP)问题.rar
- 基于yolov5的吸烟行为检测源码+模型.zip
- MySQL基础知识-个人笔记.rar
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈