> 知乎地址: [https://zhuanlan.zhihu.com/c_1101089619118026752](https://zhuanlan.zhihu.com/c_1101089619118026752)
> 作者: 小哲
> github: https://github.com/lxztju/notes
> 微信公众号: 小哲AI
提供两种pytorch模型的部署方式,一种为web部署,一种是c++部署
[TOC]
实现过程中,查找大量的资料,看了很多的代码和文章,参考很多内容,不记得主要参考哪些,如有侵权,联系删除
## 1. web部署
web部署就是采用REST API的形式进行接口调用。
web部署的方式采用flask+ redis的方法进行模型部署,pytorch为模型的框架,flask为后端框架,redis是采用键值的形式存储图像的数据库。
各package包的版本:
```
pytorch 1.2.0
flask 1.0.2
Redis 3.0.6
```
### 1. Redis安装,配置
ubuntu Redis的安装,下载地址:https://redis.io/download
安装教程: https://www.jianshu.com/p/bc84b2b71c1c
```shell
wget http://download.redis.io/releases/redis-6.0.6.tar.gz
# 拷贝到/usr/local目录下
cp redis-3.0.0.rar.gz /usr/local
# 解压
tar xzf redis-6.0.6.tar.gz
cd /usr/local/redis-6.0.6
# 安装至指定的目录下
make PREFIX=/usr/local/redis install
```
Redis配置:
```shell
# redis.conf是redis的配置文件,redis.conf在redis源码目录。
# 拷贝配置文件到安装目录下
# 进入源码目录,里面有一份配置文件 redis.conf,然后将其拷贝到安装路径下
cd /usr/local/redis
cp /usr/local/redis-3.0.0/redis.conf /usr/local/redis/bin
```
此时在/usr/local/redis/bin目录下,有如下文件:
```shell
redis-benchmark redis性能测试工具
redis-check-aof AOF文件修复工具
redis-check-rdb RDB文件修复工具
redis-cli redis命令行客户端
redis.conf redis配置文件
redis-sentinal redis集群管理工具
redis-server redis服务进程
```
Redis服务开启:
```shell
# 这是以前端方式启动,关闭终端,服务停止
./redis-server
# 后台方式启动
#修改redis.conf配置文件, daemonize yes 以后端模式启动
cd /usr/local/redis
./bin/redis-server ./redis.conf
```
连接Redis
```shell
/usr/local/redis/bin/redis-cli
```
关闭Redis
```shell
cd /usr/local/redis
./bin/redis-cli shutdown
```
强行中止Redis,(可能会丢失持久化数据)
```shell
pkill redis-server
```
### 2. server端
```python
@app.route('/predict', methods=['POST'])
def predict():
data = {'Success': False}
if request.files.get('image'):
now = time.strftime("%Y-%m-%d-%H_%M_%S",time.localtime(time.time()))
image = request.files['image'].read()
image = Image.open(io.BytesIO(image))
image = image_transform(InputSize)(image).numpy()
# 将数组以C语言存储顺序存储
image = image.copy(order="C")
# 生成图像ID
k = str(uuid.uuid4())
d = {"id": k, "image": base64_encode_image(image)}
# print(d)
db.rpush(ImageQueue, json.dumps(d))
# 运行服务
while True:
# 获取输出结果
output = db.get(k)
# print(output)
if output is not None:
output = output.decode("utf-8")
data["predictions"] = json.loads(output)
db.delete(k)
break
time.sleep(ClientSleep)
data["success"] = True
return jsonify(data)
if __name__ == '__main__':
app.run(host='127.0.0.1', port =5000,debug=True )
```
### 3. Redis服务器端
```python
def classify_process(filepath):
# 导入模型
print("* Loading model...")
model = load_checkpoint(filepath)
print("* Model loaded")
while True:
# 从数据库中创建预测图像队列
queue = db.lrange(ImageQueue, 0, BatchSize - 1)
imageIDs = []
batch = None
# 遍历队列
for q in queue:
# 获取队列中的图像并反序列化解码
q = json.loads(q.decode("utf-8"))
image = base64_decode_image(q["image"], ImageType,
(1, InputSize[0], InputSize[1], Channel))
# 检查batch列表是否为空
if batch is None:
batch = image
# 合并batch
else:
batch = np.vstack([batch, image])
# 更新图像ID
imageIDs.append(q["id"])
# print(imageIDs)
if len(imageIDs) > 0:
print("* Batch size: {}".format(batch.shape))
preds = model(torch.from_numpy(batch.transpose([0, 3,1,2])))
results = decode_predictions(preds)
# 遍历图像ID和预测结果并打印
for (imageID, resultSet) in zip(imageIDs, results):
# initialize the list of output predictions
output = []
# loop over the results and add them to the list of
# output predictions
print(resultSet)
for label in resultSet:
prob = label.item()
r = {"label": label.item(), "probability": float(prob)}
output.append(r)
# 保存结果到数据库
db.set(imageID, json.dumps(output))
# 从队列中删除已预测过的图像
db.ltrim(ImageQueue, len(imageIDs), -1)
time.sleep(ServeSleep)
def load_checkpoint(filepath):
checkpoint = torch.load(filepath, map_location='cpu')
model = checkpoint['model'] # 提取网络结构
model.load_state_dict(checkpoint['model_state_dict']) # 加载网络权重参数
for parameter in model.parameters():
parameter.requires_grad = False
model.eval()
return model
if __name__ == '__main__':
filepath = '../c/resnext101_32x8.pth'
classify_process(filepath)
```
### 4. 调用测试
```shell
curl -X POST -F image=@test.jpg 'http://127.0.0.1:5000/predict'
```
```python
from threading import Thread
import requests
import time
# 请求的URL
REST_API_URL = "http://127.0.0.1:5000/predict"
# 测试图片
IMAGE_PATH = "./test.jpg"
# 并发数
NUM_REQUESTS = 500
# 请求间隔
SLEEP_COUNT = 0.05
def call_predict_endpoint(n):
# 上传图像
image = open(IMAGE_PATH, "rb").read()
payload = {"image": image}
# 提交请求
r = requests.post(REST_API_URL, files=payload).json()
# 确认请求是否成功
if r["success"]:
print("[INFO] thread {} OK".format(n))
else:
print("[INFO] thread {} FAILED".format(n))
# 多线程进行
for i in range(0, NUM_REQUESTS):
# 创建线程来调用api
t = Thread(target=call_predict_endpoint, args=(i,))
t.daemon = True
t.start()
time.sleep(SLEEP_COUNT)
time.sleep(300)
```
## 2. c++模型部署
教程:https://pytorch.apachecn.org/docs/1.2/beginner/Intro_to_TorchScript_tutorial.html
利用TorchScript进行模型c++部署,
业界与学术界最大的区别在于工业界的模型需要落地部署,学界更多的是关心模型的精度要求,而不太在意模型的部署性能。一般来说,我们用深度学习框架训练出一个模型之后,使用Python就足以实现一个简单的推理演示了。但在生产环境下,Python的可移植性和速度性能远不如C++。所以对于深度学习算法工程师而言,Python通常用来做idea的快速实现以及模型训练,而用C++作为模型的生产工具。目前PyTorch能够完美的将二者结合在一起。实现PyTorch模型部署的核心技术组件就是TorchScript和libtorch。
所以基于PyTorch的深度学习算法工程化流程大体如下图所示:
![img](https://mmbiz.qpic.cn/mmbiz_png/4lN1XOZshffPolyqD9QQuyauzLAMibQdla7uP4gXjQyQC2mc5npAa3fJ1BayELlfNGPpvYzFADD91JxqtNFnaDQ/640?wx_fmt=png&tp=webp&wxfrom=5&wx_lazy=1&wx_co=1)
### 1. 安装libtorch
[pytorch官网](https://pytorch.org/) 下载libtorch
解压到指定的位置,我这里直接解压到`/home/xxx/`.
没有合适的资源?快使用搜索试试~ 我知道了~
pytorch_classification:利用pytorch实现图像分类的一个完整的代码,训练,预测,TTA,模型融合,模型...
共214个文件
pyc:108个
py:94个
csv:3个
4星 · 超过85%的资源 需积分: 49 116 下载量 117 浏览量
2021-02-05
00:27:15
上传
评论 24
收藏 4.39MB ZIP 举报
温馨提示
pytorch_classification 利用pytorch实现图像分类,其中包含的密集网,resnext,mobilenet,efficiencynet,resnet等图像分类网络,可以根据需要再行利用torchvision扩展其他的分类算法 实现功能 基础功能利用pytorch实现图像分类 包含带有warmup的cosine学习率调整 warmup的step学习率优调整 多模型融合预测,修正与投票融合 利用flask实现模型云端api部署 使用tta测试时增强进行预测 添加label smooth的pytorch实现(标签平滑) 添加使用cnn提取特征,并使用SVM,RF,MLP,KN
资源详情
资源评论
资源推荐
收起资源包目录
pytorch_classification:利用pytorch实现图像分类的一个完整的代码,训练,预测,TTA,模型融合,模型部署,cnn提取特征,svm或者随机森林等进行分类,模型蒸馏,一个完整的代码 (214个子文件)
main.cpp 3KB
method_3.csv 83KB
method_1.csv 83KB
method_2.csv 83KB
Feature_Visualization-checkpoint.ipynb 1.73MB
Feature_Visualization.ipynb 1.73MB
LICENSE 1KB
README.md 14KB
README.md 4KB
f1_conv1.png 1.51MB
test.png 143KB
roi_heads.py 22KB
roi_heads.py 22KB
rpn.py 17KB
rpn.py 17KB
mask_rcnn.py 16KB
mask_rcnn.py 16KB
keypoint_rcnn.py 16KB
keypoint_rcnn.py 16KB
faster_rcnn.py 15KB
faster_rcnn.py 15KB
utils.py 14KB
utils.py 14KB
inception.py 13KB
inception.py 13KB
_utils.py 12KB
_utils.py 12KB
resnet.py 11KB
resnet.py 11KB
model.py 9KB
model.py 9KB
densenet.py 8KB
densenet.py 8KB
googlenet.py 8KB
googlenet.py 8KB
shufflenetv2.py 7KB
shufflenetv2.py 7KB
build_model.py 7KB
build_model.py 7KB
train_val.py 6KB
vgg.py 6KB
vgg.py 6KB
train.py 6KB
squeezenet.py 5KB
squeezenet.py 5KB
transform.py 5KB
transform.py 5KB
cnn_ml.py 5KB
segmentation.py 5KB
segmentation.py 5KB
train_kd.py 5KB
mobilenet.py 4KB
mobilenet.py 4KB
warmup_lr.py 4KB
resnext_wsl.py 3KB
resnext_wsl.py 3KB
deeplabv3.py 3KB
deeplabv3.py 3KB
transform.py 3KB
predict.py 3KB
dataset.py 3KB
transform.py 3KB
redis_db.py 3KB
__init__.py 3KB
__init__.py 3KB
server.py 3KB
Feature_Visualization.py 3KB
_utils.py 2KB
_utils.py 2KB
backbone_utils.py 2KB
backbone_utils.py 2KB
generalized_rcnn.py 2KB
generalized_rcnn.py 2KB
utils.py 2KB
kaggle_vote.py 2KB
alexnet.py 2KB
alexnet.py 2KB
server.py 1KB
cfg.py 1KB
preprocess.py 1KB
fcn.py 1KB
fcn.py 1KB
label_smoothing_pytorch.py 1KB
streess_test.py 1012B
random_eraser.py 1004B
_utils.py 1000B
_utils.py 1000B
client.py 908B
SaveTorchscriptModel.py 778B
image_list.py 754B
image_list.py 754B
loss_kd.py 621B
cfg.py 497B
__init__.py 306B
__init__.py 306B
__init__.py 169B
__init__.py 169B
utils.py 151B
utils.py 151B
__init__.py 102B
共 214 条
- 1
- 2
- 3
三渔
- 粉丝: 23
- 资源: 4544
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论3