本文最后更新于:2024年3月18日 上午

ONNX 部署pytorch 模型时,可能会遇到 adaptive_avg_pool 算子不支持而报错的情况,本文记录解决方案。

简介

自适应平均池算子是自适应平均池的简称,是深度学习和神经网络体系结构中常用的一种数学运算。可将张量池化到任意的尺寸上。

问题复现

有时在模型转换到 ONNX 时报错:

1
Unsupported: ONNX export of operator adaptive_avg_pool1d, output size that are not factor of input size. Please feel free to request support or submit a pull request on PyTorch GitHub.

pytorch 仓库也有这个问题(2D算子):

https://github.com/pytorch/pytorch/issues/42653

解决方案

用朴实的 torch 语法重写这个算子

方案一

上述 issue 中有大神提到了解决方案(2D):

1
2
3
4
5
6
7
8
9
10
11
class AdaptiveAvgPool2dCustom(nn.Module):
def __init__(self, output_size):
super(AdaptiveAvgPool2dCustom, self).__init__()
self.output_size = np.array(output_size)

def forward(self, x: torch.Tensor):
stride_size = np.floor(np.array(x.shape[-2:]) / self.output_size).astype(np.int32)
kernel_size = np.array(x.shape[-2:]) - (self.output_size - 1) * stride_size
avg = nn.AvgPool2d(kernel_size=list(kernel_size), stride=list(stride_size))
x = avg(x)
return x

思路是将原始数据维度降维到新的目标维度,通过动态自适应调整池化的步长和窗口实现自适应池化。

我对照这份代码修改出了 1D 的算子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class AdaptiveAvgPool1dCustom(nn.Module):
def __init__(self, output_size):
super(AdaptiveAvgPool1dCustom, self).__init__()
self.output_size = np.array(output_size)

def forward(self, x: torch.Tensor):
cur_shape = np.array(x.shape[-1])
if cur_shape < self.output_size:
raise RuntimeError(f"AdaptiveAvgPool1dCustom is converting {cur_shape} feature to {self.output_size} by avgpool which is not supported, suggestion is to change outputsize to input shape {cur_shape}.")
stride_size = np.floor(np.array(x.shape[-1]) / self.output_size).astype(np.int32)
kernel_size = np.array(x.shape[-1]) - (self.output_size - 1) * stride_size
avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride_size)
x = avg(x)
return x

但是该方法导出的 onnx 模型有时会在 onnx 运行时报错,可能是因为输出维度更大时 stride_size 为 0 导致的

方案二

上述代码在数据降维的时候可以正常运行,但是当数据维度升高时无法正常工作,而且输出结果与原始自适应池化算子不一致。

这是由于原始自适应池化算子的计算原理与上述方案不同:
$$
lstart=floor(i*L_{in}/L_{out})
$$

$$
lend=ceil((i+1)*L_{in}/L_{out})
$$

$$
Output(i)=\frac{sum(Input[lstart:lend])}{(lstart-lend)}
$$

上述 issue 中也有 大神 提到了这种原理的计算方式,这篇博客 也提到了类似计算方法:

1
2
3
4
5
6
7
8
def torch_pool(inputs, target_size):
start_points = (torch.arange(target_size, dtype=torch.float32) * (inputs.size(-1) / target_size)).long()
end_points = ((torch.arange(target_size, dtype=torch.float32)+1) * (inputs.size(-1) / target_size)).ceil().long()
pooled = []
for idx in range(target_size):
pooled.append(torch.mean(inputs[:, :, start_points[idx]:end_points[idx]], dim=-1, keepdim=False))
pooled = torch.cat(pooled, -1)
return pooled

原理应该没有问题,不过这份代码我没有运行过

但是我考虑这些代码都执行了 for 循环,我觉得不够优雅,写了如下版本,可以正常运行,也可以保存 onnx 模型,供大家参考:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

class AdaptiveAvgPool1dCustomPlus(nn.Module):
def __init__(self, output_size):
super(AdaptiveAvgPool1dCustomPlus, self).__init__()
self.output_size = int(output_size)
assert self.output_size > 0

def forward(self, x: torch.Tensor):
L = x.shape[-1]
cum_res = torch.cumsum(x, dim=-1)
cum_res = torch.cat((torch.zeros(*x.shape[:-1], 1).to(x.device), cum_res), dim=-1)
indexs = torch.arange(0, self.output_size) * L /self.output_size
indexs_larger = indexs + L /self.output_size
lstart_t = torch.floor(indexs).to(torch.long).to(x.device)
lend_t = torch.ceil(indexs_larger).to(torch.long).to(x.device)

output = (cum_res[...,lend_t] - cum_res[...,lstart_t]) / (lend_t - lstart_t)
return output

实现原理是一致的,只是通过累加和 来计算起始和结束的下标,摒弃了 for 循环。

方案三

方案二的版本可以成功转换 onnx 模型并且可以正常运行,但是转为 Tensorrt 后速度很慢,可以尝试直接使用 F.interpolate 算子

1
2
3
4
5
6
7
8
class AdaptiveAvgPool1dVersion3_d3(nn.Module):
def __init__(self, output_size):
super(AdaptiveAvgPool1dVersion3_d3, self).__init__()
self.output_size = int(output_size)
assert self.output_size > 0
def forward(self, x):
x = F.interpolate(x, self.output_size, mode='linear')
return x

可以正常运行,速度快了一些,只是实现原理和原始函数稍有不同,采用的是差值方式。

参考资料



文章链接:
https://www.zywvvd.com/notes/study/deep-learning/deploy/onnx-adaavgpool-bug/onnx-adaavgpool-bug/


“觉得不错的话,给点打赏吧 ୧(๑•̀⌄•́๑)૭”

微信二维码

微信支付

支付宝二维码

支付宝支付

ONNX 不支持 adaptive_avg_pool 算子的解决方案
https://www.zywvvd.com/notes/study/deep-learning/deploy/onnx-adaavgpool-bug/onnx-adaavgpool-bug/
作者
Yiwei Zhang
发布于
2024年2月4日
许可协议