没有合适的资源?快使用搜索试试~ 我知道了~
Shape层1
资源详情
资源评论
资源推荐
Shape
层
初
始
⽰
例
代
码
(
TensorRT8
中
的
⽤
法
)
TensorRT7 + static shape
模
式
中
的
⽤
法
TensorRT7 + dynamic shape
模
式
中
的
⽤
法
TensorRT6
中
的
⽤
法
(
已
废
弃
)
[TODO]
初
始
⽰
例
代
码
(
TensorRT8
中
的
⽤
法
)
import numpy as np
from cuda import cudart
import tensorrt as trt
nIn, cIn, hIn, wIn = 1, 3, 4, 5 #
输
⼊
张
量
NCHW
data = np.arange(cIn, dtype=np.float32).reshape(cIn, 1, 1) * 100 + np.arange(hIn).reshape(1, hIn, 1) *
10 + np.arange(wIn).reshape(1, 1, wIn) #
输
⼊
数
据
data = data.reshape(nIn, cIn, hIn, wIn).astype(np.float32)
np.set_printoptions(precision=8, linewidth=200, suppress=True)
cudart.cudaDeviceSynchronize()
logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
config = builder.create_builder_config()
inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, (nIn, cIn, hIn, wIn))
#---------------------------------------------------------- --------------------#
替
换
部
分
shapeLayer = network.add_shape(inputT0)
#---------------------------------------------------------- --------------------#
替
换
部
分
network.mark_output(shapeLayer.get_output(0))
engineString = builder.build_serialized_network(network, config)
engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
context = engine.create_execution_context()
_, stream = cudart.cudaStreamCreate()
inputH0 = np.ascontiguousarray(data.reshape(-1))
outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
_, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
_, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)
cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes,
cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
context.execute_async_v2([int(inputD0), int(outputD0)], stream)
cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes,
cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)
cudart.cudaStreamSynchronize(stream)
print("inputH0 :", data.shape)
print(data)
print("outputH0:", outputH0.shape)
print(outputH0)
cudart.cudaStreamDestroy(stream)
cudart.cudaFree(inputD0)
StoneChan
- 粉丝: 27
- 资源: 321
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论0