本文最后更新于:2024年1月14日 晚上

pytorch 模型部署很重要的一步是转存pth模型为ONNX,本文记录方法。

转存 onnx

  • 建立自己的pytorch模型,并加载权重
1
2
model = create_model(num_classes=2)
model.load_state_dict(load(model_path, map_location='cpu')["model"])
  • 转存onnx文件
1
2
dummy_input = torch.randn(1, 3, 256, 256, device='cpu')
torch.onnx._export(model, dummy_input, "faster_rcnn.onnx", verbose=True, opset_version=11)

将模型保存在了当前目录的 faster_rcnn.onnx文件内

验证 onnx 有效性

  • 安装 onnxruntime
1
pip install onnxruntime
  • 加载onnx模型并测试
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import onnxruntime
from onnxruntime.datasets import get_example

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 测试数据
dummy_input = torch.randn(1, 3, 256, 256, device='cpu')

example_model = get_example(<absolute_root_to_your_onnx_model_file>)
# netron.start(example_model) 使用 netron python 包可视化网络
sess = onnxruntime.InferenceSession(example_model)

# onnx 网络输出
onnx_out = sess.run(None, {<input_layer_name_of_your_network>: to_numpy(dummy_input)})
print(onnx_out)

model.eval()
with torch.no_grad():
# pytorch model 网络输出
torch_out = model(dummy_input)
print(torch_out)
  • 输出:
1
2
3
4
5
6
7
8
9
onnx_out
[array([[ 0. , 93.246 , 228.95842 , 256. ],
[ 0. , 2.6370468, 209.39705 , 148.17822 ]],
dtype=float32), array([1, 1], dtype=int64), array([0.1501071 , 0.07568519], dtype=float32)]

torch_out
[{'boxes': tensor([[ 0.0000, 93.2459, 228.9584, 256.0000],
[ 0.0000, 2.6370, 209.3971, 148.1782]]), 'labels': tensor([1, 1]), 'scores': tensor([0.1501, 0.0757])}]

获取自己网络输入层名称

  • 有时对网络不熟悉的情况下不清楚模型输入层的名称,可以使用Netron可视化自己的网络,获取输入层名称,喂入onnx的sess中。

注意 !!!

  • pytorch 模型在转 ONNX 模型的过程中,使用的导出器是一个基于轨迹的导出器,这意味着它执行时需要运行一次模型,然后导出实际参与运算的运算符. 这也意味着, 如果你的模型是动态的,例如,改变一些依赖于输入数据的操作,这时的导出结果是不准确的.同样,一 个轨迹可能只对一个具体的输入尺寸有效 (这是为什么我们在轨迹中需要有明确的输入的原因之一.) 我们建议检查 模型的轨迹,确保被追踪的运算符是合理的. ——— pytorch 文档

  • 也就是说,如果网络模块中存在 if… else… 类似的分支,在生成ONNX模型时会依据所使用的初始数据来选择其中某一个分支,这样所生成的ONNX模型仅会保留这一个分支的结构,在原始pytorch模型中的其他逻辑能力在该模型中不复存在。

参考资料



文章链接:
https://www.zywvvd.com/notes/study/deep-learning/pytorch/covert-pt-to-onnx/covert-pt-to-onnx/


“觉得不错的话,给点打赏吧 ୧(๑•̀⌄•́๑)૭”

微信二维码

微信支付

支付宝二维码

支付宝支付

pytorch pth 模型转 onnx模型,并验证结果正确性
https://www.zywvvd.com/notes/study/deep-learning/pytorch/covert-pt-to-onnx/covert-pt-to-onnx/
作者
Yiwei Zhang
发布于
2020年11月19日
许可协议