package com.topprismcloud.rtm;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.XContentType;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;
/**
 * Image-similarity ("search by image") utility.
 *
 * <p>Embeds images with a model loaded through DJL from the classpath resource {@code model.pt}.
 * It indexes the resulting vectors into the Elasticsearch index {@value #INDEX} and returns the
 * indexed images whose vectors are most cosine-similar to a query image.
 *
 * <p>The model is loaded once, during class initialization. If loading fails, the error is
 * printed and remembered. {@link #upload()} and {@link #search(InputStream)} then throw an
 * {@link IllegalStateException} carrying that error as its cause, instead of failing later with
 * a bare {@link NullPointerException}.
 *
 * <p>The Elasticsearch connection settings and the upload directory default to the values
 * previously hard-coded here. They can be overridden with the system properties
 * {@code orc.es.host}, {@code orc.es.port}, {@code orc.es.username}, {@code orc.es.password} and
 * {@code orc.image.dir}.
 */
public class ORCUtil {
    /** Elasticsearch index that holds the image documents. */
    private static final String INDEX = "isi";
    /** Target size handed to DJL's {@link Resize} transform before embedding. */
    private static final int IMAGE_SIZE = 224;
    /** Maximum number of hits returned by {@link #search(InputStream)}. */
    private static final int TOP_K = 100;

    // Elasticsearch connection settings. The defaults are the previously hard-coded values.
    // NOTE(review): credentials in source code are a security risk; prefer supplying them
    // via the system properties below.
    private static final String ES_HOST = System.getProperty("orc.es.host", "124.220.30.196");
    private static final int ES_PORT = Integer.getInteger("orc.es.port", 9200);
    private static final String ES_USERNAME = System.getProperty("orc.es.username", "elastic");
    private static final String ES_PASSWORD = System.getProperty("orc.es.password", "123456");

    /** Directory whose files {@link #upload()} embeds and indexes. */
    private static final String IMAGE_DIR =
            System.getProperty("orc.image.dir", "D:\\001ENV\\nginx-1.24.0\\html\\resource\\new");

    /** The loaded model; {@code null} if initialization failed. */
    private static Model model;
    /**
     * Shared predictor; {@code predictor.predict(input)} is the Java counterpart of Python's
     * {@code model(input)}. It is {@code null} if initialization failed. Always access it
     * through {@link #embed(Image)}, which serializes calls.
     */
    private static Predictor<Image, float[]> predictor;
    /** Cause of a failed initialization, or {@code null} if the model loaded successfully. */
    private static Exception initError;

    static {
        try {
            model = Model.newInstance("model");
            // model.pt is expected to have been saved the way the project's Python export
            // script does it (presumably a.py, not visible here) -- TODO confirm.
            // try-with-resources closes the classpath stream; a null resource is tolerated by
            // the construct and reported explicitly below.
            try (InputStream modelStream =
                         ORCUtil.class.getClassLoader().getResourceAsStream("model.pt")) {
                if (modelStream == null) {
                    throw new FileNotFoundException("model.pt not found on the classpath");
                }
                model.load(modelStream);
            }
            Transform resize = new Resize(IMAGE_SIZE);
            Transform toTensor = new ToTensor();
            // Per-channel mean / std normalization.
            Transform normalize = new Normalize(new float[] { 0.485f, 0.456f, 0.406f },
                    new float[] { 0.229f, 0.224f, 0.225f });
            // The translator converts the input Image into a tensor (resize -> tensor ->
            // normalize) and converts the model's first output into a float[].
            Translator<Image, float[]> translator = new Translator<Image, float[]>() {
                @Override
                public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
                    NDManager ndManager = ctx.getNDManager();
                    System.out.println("input: " + input.getWidth() + ", " + input.getHeight());
                    NDArray transform = normalize
                            .transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));
                    System.out.println(transform.getShape());
                    NDList list = new NDList();
                    list.add(transform);
                    return list;
                }

                @Override
                public float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
                    return ndList.get(0).toFloatArray();
                }
            };
            predictor = new Predictor<>(model, translator, Device.cpu(), true);
        } catch (Exception e) {
            // Keep class initialization best-effort, as before. Remember the cause so that
            // callers receive a meaningful error instead of an NPE on the null predictor.
            initError = e;
            e.printStackTrace();
        }
    }

    /**
     * Computes the feature vector of one image with the shared predictor.
     *
     * <p>Synchronized because a single static {@link Predictor} is shared by every caller, and
     * DJL predictors are not safe for concurrent use.
     *
     * @param image image to embed
     * @return the model's output for {@code image}, flattened to a float array
     * @throws IllegalStateException if the model failed to load during class initialization
     * @throws Exception if DJL fails to translate or run the model
     */
    private static synchronized float[] embed(Image image) throws Exception {
        if (predictor == null) {
            throw new IllegalStateException("ORCUtil model failed to initialize", initError);
        }
        return predictor.predict(image);
    }

    /**
     * Creates a new Elasticsearch client for the configured host, using basic-auth credentials.
     * The caller must close the returned client.
     */
    private static RestHighLevelClient createClient() {
        CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY,
                new UsernamePasswordCredentials(ES_USERNAME, ES_PASSWORD));
        RestClientBuilder builder = RestClient
                .builder(new HttpHost(ES_HOST, ES_PORT, HttpHost.DEFAULT_SCHEME_NAME))
                .setHttpClientConfigCallback(b -> b.setDefaultCredentialsProvider(credentialsProvider));
        return new RestHighLevelClient(builder);
    }

    /**
     * Embeds every file in the configured image directory and bulk-indexes it.
     *
     * @throws Exception if the directory cannot be listed, an image cannot be read or embedded,
     *     or Elasticsearch reports a failure
     * @see #upload(File)
     */
    public static void upload() throws Exception {
        upload(new File(IMAGE_DIR));
    }

    /**
     * Embeds every regular file directly inside {@code dir} and bulk-indexes one document per
     * file into {@value #INDEX}. Each document has the fields {@code url}
     * ({@code "/resource/" + fileName}), {@code vector} and {@code user_id}.
     *
     * @param dir directory containing the images to index; subdirectories are skipped
     * @throws IOException if {@code dir} cannot be listed or some bulk items fail to index
     * @throws IllegalStateException if the model failed to load
     * @throws Exception if an image cannot be read or embedded, or the bulk call fails
     */
    public static void upload(File dir) throws Exception {
        File[] files = dir.listFiles(File::isFile);
        if (files == null) {
            throw new IOException("Cannot list image directory: " + dir.getAbsolutePath());
        }
        // Bulk upload request.
        BulkRequest bulkRequest = new BulkRequest(INDEX);
        for (File imageFile : files) {
            float[] vector;
            try (InputStream in = new FileInputStream(imageFile)) {
                vector = embed(ImageFactory.getInstance().fromInputStream(in));
            }
            // Build the document.
            Map<String, Object> jsonMap = new HashMap<>();
            jsonMap.put("url", "/resource/" + imageFile.getName());
            jsonMap.put("vector", vector);
            jsonMap.put("user_id", "user123");
            bulkRequest.add(new IndexRequest(INDEX).source(jsonMap, XContentType.JSON));
        }
        if (bulkRequest.numberOfActions() == 0) {
            // Nothing to index; Elasticsearch rejects an empty bulk request.
            return;
        }
        try (RestHighLevelClient client = createClient()) {
            BulkResponse response = client.bulk(bulkRequest, RequestOptions.DEFAULT);
            // The bulk call itself succeeds even when individual items fail, so check them.
            if (response.hasFailures()) {
                throw new IOException("Bulk indexing failed: " + response.buildFailureMessage());
            }
        }
    }

    /**
     * Reads the query image from {@code input} and returns up to {@value #TOP_K} indexed images
     * ranked by the cosine similarity of their vectors to the query image's vector.
     *
     * @param input stream containing the image to search with; it is not closed here
     * @return matching images, each with its {@code url} and its cosine-similarity score
     * @throws IllegalStateException if the model failed to load
     * @throws Throwable if the image cannot be read or embedded, or the search request fails
     */
    public static List<SearchResult> search(InputStream input) throws Throwable {
        float[] vector = embed(ImageFactory.getInstance().fromInputStream(input));
        System.out.println(Arrays.toString(vector));
        // function_score rejects negative script scores, and cosine similarity lies in
        // [-1, 1]. Shift it by +1 here and undo the shift below, so callers still receive
        // the raw cosine similarity.
        Script script = new Script(ScriptType.INLINE, "painless",
                "cosineSimilarity(params.queryVector, doc['vector']) + 1.0",
                Collections.singletonMap("queryVector", vector));
        FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders
                .functionScoreQuery(QueryBuilders.matchAllQuery(), ScoreFunctionBuilders.scriptFunction(script));
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
                .query(functionScoreQueryBuilder)
                // Don't return the vector field: it is large, unused here, and slow to transfer.
                .fetchSource(null, "vector")
                .size(TOP_K);
        SearchRequest searchRequest = new SearchRequest(INDEX).source(searchSourceBuilder);

        try (RestHighLevelClient client = createClient()) {
            SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
            SearchHits hits = searchResponse.getHits();
            List<SearchResult> list = new ArrayList<>();
            for (SearchHit hit : hits) {
                System.out.println(hit.toString());
                Map<String, Object> source = hit.getSourceAsMap();
                String url = source == null ? null : (String) source.get("url");
                list.add(new SearchResult(url, hit.getScore() - 1f));
            }
            return list;
        }
    }

    public static void main(String[] args) throws Throwable {
        ORCUtil.upload();
        System.out.println("hao");
    }
}
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
以图搜图Java+html源代码.rar (13个子文件)
以图搜图Java+html源代码
z-project(java后端)
pom.xml 2KB
src
test
resources
java
main
resources
a.py 760B
model.pt 98.15MB
application.yml 23B
java
com
topprismcloud
rtm
AppApplication.java 355B
SearchResult.java 201B
ORCUtil.java 7KB
SearchController.java 699B
html
js
jquery-3.6.0.js 85KB
resource
3.png 42KB
1.png 69KB
2.png 43KB
index.html 5KB
共 13 条
- 1
资源评论
老李笔记
- 粉丝: 122
- 资源: 19
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功