Skip to content

Commit 2a13579

Browse files
lapertorqubvel
andauthored
Fix PSPNet ONNX export issue (#1270)
* fix(pspnet): fix ONNX export by replacing AdaptiveAvgPool2d with interpolate * fix: restore state_dict compatibility and handle ONNX export in forward * Update segmentation_models_pytorch/decoders/pspnet/decoder.py * fix: support ONNX export and TorchScript for PSPNet --------- Co-authored-by: Raphael Lapertot <raphael.lapertot@gmail.com> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
1 parent 5c0278f commit 2a13579

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

  • segmentation_models_pytorch/decoders/pspnet

segmentation_models_pytorch/decoders/pspnet/decoder.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def __init__(
2020
if pool_size == 1:
2121
use_norm = "identity" # PyTorch does not support BatchNorm for 1x1 shape
2222

23+
self.pool_size = pool_size
24+
2325
self.pool = nn.Sequential(
2426
nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)),
2527
modules.Conv2dReLU(
@@ -29,7 +31,19 @@ def __init__(
2931

3032
def forward(self, x: torch.Tensor) -> torch.Tensor:
3133
height, width = x.shape[2:]
32-
x = self.pool(x)
34+
35+
if torch.jit.is_scripting():
36+
# TorchScript path: use standard AdaptiveAvgPool2d via self.pool
37+
x = self.pool(x)
38+
elif torch.onnx.is_in_onnx_export():
39+
# ONNX export path: AdaptiveAvgPool2d is often problematic during export.
40+
# Using F.interpolate with 'area' mode provides the same mathematical result
41+
# (average pooling) while being more robustly supported.
42+
x = F.interpolate(x, size=(self.pool_size, self.pool_size), mode="area")
43+
x = self.pool[1](x) # use only ConvRelu block from pool
44+
else:
45+
x = self.pool(x)
46+
3347
x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=True)
3448
return x
3549

0 commit comments

Comments
 (0)