在在C++中加载中加载TorchScript模型的方法模型的方法
主要介绍了在C++中加载TorchScript模型的方法,本文通过实例代码给大家介绍的非常详细,具有一定的参考
借鉴价值,需要的朋友可以参考下
本教程已更新为可与PyTorch 1.2一起使用
顾名思义,PyTorch的主要接口是Python编程语言。尽管Python是合适于许多需要动态性和易于迭代的场景,并且是首选的语
言,但同样的,在许多情况下,Python的这些属性恰恰是不利的。后者通常适用的一种环境是要求生产-低延迟和严格部署。
对于生产场景,即使只将C ++绑定到Java,Rust或Go之类的另一种语言中,它也是经常选择的语言。以下各段将概述
PyTorch提供的从现有Python模型到可以完全从C ++加载和执行的序列化表示形式的路径,而无需依赖Python。
步骤步骤1:将:将PyTorch模型转换为模型转换为Torch脚本脚本
PyTorch模型从Python到C ++的旅程由Torch Script启动,Torch Script是PyTorch模型的一种表示形式,可以由Torch Script编
译器理解,编译和序列化。如果您是从使用vanilla“eager” API编写的现有PyTorch模型开始的,则必须首先将模型转换为
Torch脚本。在最常见的情况下(如下所述),这只需要花费很少的功夫。如果您已经有了Torch脚本模块,则可以跳到本教
程的下一部分。
有两种将PyTorch模型转换为Torch脚本的方法。第一种称为跟踪,一种机制,其中通过使用示例输入对模型的结构进行一次
评估,并记录这些输入在模型中的流量,从而捕获模型的结构。这适用于有限使用控制流的模型。第二种方法是在模型中添加
显式批注,以告知Torch Script编译器可以根据Torch Script语言施加的约束直接解析和编译模型代码。
提示:您可以在官方 Torch脚本参考 中找到有关这两种方法的完整文档,以及使用方法的进一步指导。
方法方法1:通过跟踪转换为:通过跟踪转换为Torch脚本脚本
要将PyTorch模型通过跟踪转换为Torch脚本,必须将模型的实例以及示例输入传递给 torch.jit.trace 函数。这将产生一
个 torch.jit.ScriptModule 对象,该对象的模型评估痕迹将嵌入模块的 forward 方法中:
import torch
import torchvision
# 你模型的一个实例.
model = torchvision.models.resnet18()
# 您通常会提供给模型的forward()方法的示例输入。
example = torch.rand(1, 3, 224, 224)
# 使用`torch.jit.trace `来通过跟踪生成`torch.jit.ScriptModule`
traced_script_module = torch.jit.trace(model, example)
现在可以对跟踪的 ScriptModule 进行评估,使其与常规PyTorch模块相同:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
方法方法2:通过注释转换为:通过注释转换为Torch脚本脚本
在某些情况下,例如,如果模型采用特定形式的控制流,则可能需要直接在Torch脚本中编写模型并相应地注释模型。例如,
假设您具有以下vanilla Pytorch模型:
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
因为此模块的前向方法使用取决于输入的控制流,所以它不适合跟踪。相反,我们可以将其转换为 ScriptModule 。为了将模
块转换为 ScriptModule ,需要使用 torch.jit.script 编译模块,如下所示:
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)