本文最后更新于:2024年5月7日 下午

Microsoft 和合作伙伴社区创建了 ONNX 作为表示机器学习模型的开放标准。 本文记录 Python 下 pytorch 模型转换 ONNX 的相关内容。

简介

ONNX Runtime是一个跨平台的推理和训练机器学习加速器。

在 Pytorch 框架中训练好模型后,在部署时可以转成 onnx,再进行下一步部署。

模型转换

核心代码:

  • 生成 onnx 模型: torch.onnx.export
  • 简化 onnx 模型: onnxsim.simplify
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import onnxsim
import onnx

def export_to_onnx(model, output_path, input_shape, input_name, output_names):
dummy_input = torch.rand(1, *input_shape)

model.eval()

temp_dict = dict()

temp_onnx_path = output_path.replace('.onnx', '_temp.onnx')

torch.onnx.export(model, # pytorch 模型
(dummy_input, 'ALL'), # 可以输入 tuple
temp_onnx_path, # 输出 onnx 模型路径
verbose=False, # 聒噪
opset_version=11, # onnx 版本
export_params=True, # 一个指示是否导出模型参数(权重)以及模型架构的标志。
do_constant_folding=True, # 一个指示是否在导出过程中折叠常量节点的标志
input_names=[input_name], # 输入节点名称列表(可选)
output_names=output_names # 输出节点名称列表(可选)
)

input_data = {'image': dummy_input.cpu().numpy()}
model_sim, flag = onnxsim.simplify(temp_onnx_path, input_data=input_data) # 简化 onnx

if flag:
onnx.save(model_sim, output_path)
print(f"simplify onnx model successfully !")
else:
print(f"simplify onnx model failed !!!")
  • 注意: torch.onnx.export 输入伪数据可以支持字符串,但是在 onnx 模型中仅会记录张量流转的路径,字符串、分支逻辑一般不会保存。

动态输出

上述转换方式导出的 onnx 模型仅支持 dummy_input 尺寸的输入数据,模型稳定,速度快,但是不够灵活,当我们需要向网络送入不同尺寸的输入数据时需要开启动态轴

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
...

input_name = 'image' # 输入节点名称

dynamic_axes = dict()
dynamic_axes[input_name] = [0, 2, 3] # 输入节点开启动态输入的轴

output_name = 'output' # 输出节点名称
dynamic_axes[output_name] = [2, 3] # 输出节点动态轴

torch.onnx.export(self,
(dummy_input, 'ALL'),
temp_onnx_path,
verbose=False,
opset_version=11,
export_params=True,
do_constant_folding=True,
input_names=[input_name],
output_names=output_names,
dynamic_axes=dynamic_axes # 添加参数
)

动态输入数据可能使得网络变复杂,如果在 onnx 模型中出现很多莫名其妙的 if 节点考虑是否有 squeeze 操作。

模型检查

onnx 加载模型后可以检测是否合法。

1
2
3
4
5
6
7
8
# onnx check
onnx_model = onnx.load(onnx_model_path)
try:
onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
print('The model is invalid: %s' % e)
else:
print('The model is valid!')

加载、运行 ONNX 模型

ONNXruntime 安装:

1
2
pip install onnxruntime       # CPU build
pip install onnxruntime-gpu # GPU build

onnxruntime 仅能在 cpu 上推理模型,onnxruntime-gpu 可以在 cpu 或 gpu 上推断模型。

CPU 运行

1
2
3
4
5
6
import onnxruntime

session = onnxruntime.InferenceSession("path to model")
session.get_modelmeta()
results = session.run(["output1""output2"], {"input1": indata1, "input2": indata2})
results = session.run([], {"input1": indata1, "input2": indata2})

可以对比 onnx 模型结果与 pytorch 模型结果的差异来对转换结果进行验证。

GPU 运行

  • 查看 onnx 是否支持 gpu 运行
1
2
3
4
5
import onnxruntime

print(onnxruntime.get_available_providers())

# >>> ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'AzureExecutionProvider', 'CPUExecutionProvider']

如果包含 'CUDAExecutionProvider' 表示当前环境支持 onnxruntime 在 gpu 上运行

1
2
3
4
5
# 检查是否有'CUDAExecutionProvider'
if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
print('ONNX Runtime GPU is available.')
else:
print('ONNX Runtime GPU is not available.')
  • 模型运行
1
2
3
4
5
6
7
import onnxruntime as ort
# 加载ONNX模型
session = ort.InferenceSession('your_model.onnx', providers=['CUDAExecutionProvider'])
# 准备输入数据
inputs = {session.get_inputs()[0].name: your_input_data}
# 执行模型
outputs = session.run(None, inputs)

常见错误

Status Message: CUDNN failure 9: CUDNN_STATUS_NOT_SUPPORTED ; GPU=0 ; hostname=XDHN-SH-026

可能的原因:

  1. CUDNN 没有安装或者版本不匹配或者没有成功识别。

    确定显卡驱动与显卡型号是否匹配,Cuda 是否和显卡驱动匹配, Cudnn 是否和 Cuda 匹配, Onnx 版本是否和 Onnxruntime-gpu 匹配, Onnxruntime 是否和 Cuda Cudnn匹配

    onnxruntime 版本信息参考 https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html

    ONNX Runtime CUDA cuDNN Notes
    1.17 12.2 8.9.2.26 (Linux) 8.9.2.26 (Windows) The default CUDA version for ORT 1.17 is CUDA 11.8. To install CUDA 12 package, please look at Install ORT. Due to low demand on Java GPU package, only C++/C# Nuget and Python packages are released with CUDA 12.2
    1.15 1.16 1.17 11.8 8.2.4 (Linux) 8.5.0.96 (Windows) Tested with CUDA versions from 11.6 up to 11.8, and cuDNN from 8.2.4 up to 8.7.0
    1.14 1.13.1 1.13 11.6 8.2.4 (Linux) 8.5.0.96 (Windows) libcudart 11.4.43 libcufft 10.5.2.100 libcurand 10.2.5.120 libcublasLt 11.6.5.2 libcublas 11.6.5.2 libcudnn 8.2.4
    1.12 1.11 11.4 8.2.4 (Linux) 8.2.2.26 (Windows) libcudart 11.4.43 libcufft 10.5.2.100 libcurand 10.2.5.120 libcublasLt 11.6.5.2 libcublas 11.6.5.2 libcudnn 8.2.4
    1.10 11.4 8.2.4 (Linux) 8.2.2.26 (Windows) libcudart 11.4.43 libcufft 10.5.2.100 libcurand 10.2.5.120 libcublasLt 11.6.1.51 libcublas 11.6.1.51 libcudnn 8.2.4
    1.9 11.4 8.2.4 (Linux) 8.2.2.26 (Windows) libcudart 11.4.43 libcufft 10.5.2.100 libcurand 10.2.5.120 libcublasLt 11.6.1.51 libcublas 11.6.1.51 libcudnn 8.2.4
    1.8 11.0.3 8.0.4 (Linux) 8.0.2.39 (Windows) libcudart 11.0.221 libcufft 10.2.1.245 libcurand 10.2.1.245 libcublasLt 11.2.0.252 libcublas 11.2.0.252 libcudnn 8.0.4
    1.7 11.0.3 8.0.4 (Linux) 8.0.2.39 (Windows) libcudart 11.0.221 libcufft 10.2.1.245 libcurand 10.2.1.245 libcublasLt 11.2.0.252 libcublas 11.2.0.252 libcudnn 8.0.4
    1.5-1.6 10.2 8.0.3 CUDA 11 can be built from source
    1.2-1.4 10.1 7.6.5 Requires cublas10-10.2.1.243; cublas 10.1.x will not work
    1.0-1.1 10.0 7.6.4 CUDA versions from 9.1 up to 10.1, and cuDNN versions from 7.1 up to 7.4 should also work with Visual Studio 2017
  2. 输入数据尺寸不合适

    排查了很久发现我的就是这种情况,将尺寸改小一倍就可以正常运行了。

参考资料



文章链接:
https://www.zywvvd.com/notes/study/deep-learning/deploy/onnx-transfer-infer/onnx-infer/


“觉得不错的话,给点打赏吧 ୧(๑•̀⌄•́๑)૭”

微信二维码

微信支付

支付宝二维码

支付宝支付

Python ONNX 模型转换、加载、简化、推断
https://www.zywvvd.com/notes/study/deep-learning/deploy/onnx-transfer-infer/onnx-infer/
作者
Yiwei Zhang
发布于
2024年2月2日
许可协议