
Fix PSPNet ONNX export issue#1270

Merged
qubvel merged 4 commits intoqubvel-org:mainfrom
lapertor:fix/pspnet-onnx-export
Mar 19, 2026

Conversation

@lapertor
Contributor

@lapertor lapertor commented Mar 9, 2026

Hello!

This PR addresses Issue #1266 reported by @staticplasma.

The problem

Exporting a PSPNet model to ONNX currently fails on AdaptiveAvgPool2d: the exporter raises a SymbolicValueError because it cannot always derive the static kernel/stride sizes that ONNX requires from the adaptive output size, in particular with dynamic input shapes and some opsets.

The fix

I replaced AdaptiveAvgPool2d with F.interpolate(mode='area') specifically during ONNX export using torch.onnx.is_in_onnx_export(). This has several benefits:

  • Area interpolation is mathematically equivalent to average pooling.
  • This primitive is much better supported by the ONNX exporter across different opsets and dynamic input shapes.
  • The original code remains untouched for standard PyTorch inference.
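
A quick way to check the equivalence claim above (a standalone sketch; the shapes are arbitrary examples, not the ones PSPNet uses):

```python
import torch
import torch.nn.functional as F

# Arbitrary example shapes: any (N, C, H, W) input works.
x = torch.randn(2, 16, 32, 32)
pool_size = 4

# What PSPNet uses in eager mode.
pooled = torch.nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))(x)
# What the fix uses during ONNX export.
area = F.interpolate(x, size=(pool_size, pool_size), mode="area")

# PyTorch documents mode="area" downsampling as equivalent to
# adaptive average pooling, so the two outputs match.
print(torch.allclose(pooled, area, atol=1e-6))  # True
```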

Verification

I have verified the fix with the following environment: Python 3.13, PyTorch 2.5+, ONNX 1.17.

Click to see the Reproduction/Export Script
import torch
import segmentation_models_pytorch as smp

def reproduce():
    ENCODER = "mobilenet_v2"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    INPUT_SIZE = (128, 128)
    OUTPUT_PATH = "model.onnx"

    print(f"--- Attempting reproduction with DEVICE={DEVICE} ---")

    model = smp.PSPNet(
        encoder_name=ENCODER,
        encoder_weights=None, 
        in_channels=3,
        classes=1,
        activation=None
    )
    model.to(DEVICE)
    model.eval()

    dummy_input = torch.randn(1, 3, *INPUT_SIZE).to(DEVICE)

    print("Testing PyTorch inference...")
    with torch.no_grad():
        output = model(dummy_input)
    print(f"Inference succeeded. Output shape: {output.shape}")

    print("\nAttempting ONNX export...")
    try:
        torch.onnx.export(
            model,
            dummy_input,
            OUTPUT_PATH,
            export_params=True,
            opset_version=17, 
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        )
        print(f"SUCCESS: Model exported to {OUTPUT_PATH}")
    except Exception as e:
        print("\n--- ERROR DETECTED ---")
        print(e)
        print("-----------------------")

if __name__ == "__main__":
    reproduce()
Click to see the ONNX Runtime Validation Script
import onnxruntime as ort
import numpy as np
import torch

try:
    session = ort.InferenceSession("model.onnx")
    print("ONNX model loaded successfully!")
except Exception as e:
    print(f"Loading error: {e}")
    exit()

input_name = session.get_inputs()[0].name
dummy_in = np.random.randn(1, 3, 128, 128).astype(np.float32)

outputs = session.run(None, {input_name: dummy_in})

print("Inference succeeded!")
print(f"Output shape: {outputs[0].shape}")

Best regards,
Raphaël

P.S.: This is my first ever contribution and pull request to an open-source project on GitHub! So I probably didn't do everything perfectly, even though it's a small change. Tell me what I could do / have done better!

@codecov

codecov Bot commented Mar 10, 2026

Codecov Report

❌ Patch coverage is 87.50000% with 1 line in your changes missing coverage. Please review.

File                                                    Patch %   Missing lines
...entation_models_pytorch/decoders/pspnet/decoder.py   87.50%    1 ⚠️

File                                                    Coverage Δ
...entation_models_pytorch/decoders/pspnet/decoder.py   97.56% <87.50%> (-2.44%) ⬇️

Collaborator

@qubvel qubvel left a comment


Thanks for the fix, appreciate your help! please see the comment

Comment on lines 24 to 27
self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))
self.conv = modules.Conv2dReLU(
in_channels, out_channels, kernel_size=1, use_norm=use_norm
)
Collaborator


This will lead to an error when loading weights for those who trained their model with a previous version, so let's keep the module layout as it was in the original implementation. Instead, in the forward pass, let's do something like:

if torch.onnx.is_in_onnx_export():
    x = F.interpolate(x, size=(self.pool_size, self.pool_size), mode="area")
    x = self.pool[1](x)  # use only the ConvRelu block from pool
else:
    x = self.pool(x)

Collaborator


P.S. You can see the error in CI:

E               Missing key(s) in state_dict: "decoder.psp.blocks.0.conv.0.bias", "decoder.psp.blocks.0.conv.0.weight", "decoder.psp.blocks.1.conv.0.weight", "decoder.psp.blocks.1.conv.1.bias", "decoder.psp.blocks.1.conv.1.running_mean", "decoder.psp.blocks.1.conv.1.running_var", "decoder.psp.blocks.1.conv.1.weight", "decoder.psp.blocks.2.conv.0.weight", "decoder.psp.blocks.2.conv.1.bias", "decoder.psp.blocks.2.conv.1.running_mean", "decoder.psp.blocks.2.conv.1.running_var", "decoder.psp.blocks.2.conv.1.weight", "decoder.psp.blocks.3.conv.0.weight", "decoder.psp.blocks.3.conv.1.bias", "decoder.psp.blocks.3.conv.1.running_mean", "decoder.psp.blocks.3.conv.1.running_var", "decoder.psp.blocks.3.conv.1.weight"
E               Unexpected key(s) in state_dict: "decoder.psp.blocks.0.pool.1.0.bias", "decoder.psp.blocks.0.pool.1.0.weight", "decoder.psp.blocks.1.pool.1.0.weight", "decoder.psp.blocks.1.pool.1.1.bias", "decoder.psp.blocks.1.pool.1.1.num_batches_tracked", "decoder.psp.blocks.1.pool.1.1.running_mean", "decoder.psp.blocks.1.pool.1.1.running_var", "decoder.psp.blocks.1.pool.1.1.weight", "decoder.psp.blocks.2.pool.1.0.weight", "decoder.psp.blocks.2.pool.1.1.bias", "decoder.psp.blocks.2.pool.1.1.num_batches_tracked", "decoder.psp.blocks.2.pool.1.1.running_mean", "decoder.psp.blocks.2.pool.1.1.running_var", "decoder.psp.blocks.2.pool.1.1.weight", "decoder.psp.blocks.3.pool.1.0.weight", "decoder.psp.blocks.3.pool.1.1.bias", "decoder.psp.blocks.3.pool.1.1.num_batches_tracked", "decoder.psp.blocks.3.pool.1.1.running_mean", "decoder.psp.blocks.3.pool.1.1.running_var", "decoder.psp.blocks.3.pool.1.1.weight"

Collaborator


In addition, another test seems to be broken for PSP:

=================================== FAILURES ===================================
________________________ TestPspModel.test_torch_script ________________________
[gw0] linux -- Python 3.10.20 /home/runner/work/segmentation_models.pytorch/segmentation_models.pytorch/.venv/bin/python3

self = <tests.models.test_psp.TestPspModel testMethod=test_torch_script>

    @pytest.mark.torch_script
    def test_torch_script(self):
        if not check_run_test_on_diff_or_main(self.files_for_diff):
            self.skipTest("No diff and not on `main`.")
    
        sample = self._get_sample().to(default_device)
        model = self.get_default_model()
        model.eval()
    
        if not model._is_torch_scriptable:
            with self.assertRaises(RuntimeError):
                scripted_model = torch.jit.script(model)
            return
    
>       scripted_model = torch.jit.script(model)

@lapertor
Contributor Author

Thanks for the review. You're right, I hadn't thought about this. I applied the modifications you recommended, my tests pass, and I also added self.pool_size as an attribute to support the dynamic interpolation. One concern, though: adding this new self.pool_size attribute might break things somewhere else, so this might not be over yet. Let me know what you think.

Collaborator

@qubvel qubvel left a comment


Thanks, looks fine now

@qubvel
Collaborator

qubvel commented Mar 12, 2026

Can you please

  1. run make fixup
  2. skip the torch.script test for PSP; I believe torch.script is deprecated and ONNX is more important. However, if you have an idea how to fix it and keep both, that would be nice:
=================================== FAILURES ===================================
________________________ TestPspModel.test_torch_script ________________________
[gw1] linux -- Python 3.14.3 /home/runner/work/segmentation_models.pytorch/segmentation_models.pytorch/.venv/bin/python3

self = <tests.models.test_psp.TestPspModel testMethod=test_torch_script>

    @pytest.mark.torch_script
    def test_torch_script(self):
        if not check_run_test_on_diff_or_main(self.files_for_diff):
            self.skipTest("No diff and not on `main`.")
    
        sample = self._get_sample().to(default_device)
        model = self.get_default_model()
        model.eval()
    
        if not model._is_torch_scriptable:
            with self.assertRaises(RuntimeError):
                scripted_model = torch.jit.script(model)
            return
    
>       scripted_model = torch.jit.script(model)

@lapertor
Contributor Author

Thanks for the feedback.

  • make fixup has been applied.
  • Regarding torch.script: I found a way to support both TorchScript and ONNX export simultaneously. PSPBlock.forward now uses torch.jit.is_scripting() and torch.onnx.is_in_onnx_export() to branch between the paths:
    • TorchScript: uses the standard AdaptiveAvgPool2d via self.pool
    • ONNX export: uses F.interpolate with mode="area", which is more robustly supported by the ONNX exporter
    • Default (eager): same as the TorchScript path
  • All 8 tests pass locally, including test_torch_script and test_torch_export.

@qubvel qubvel merged commit 2a13579 into qubvel-org:main Mar 19, 2026
16 of 17 checks passed
@qubvel
Collaborator

qubvel commented Mar 19, 2026

Thanks

@lapertor lapertor deleted the fix/pspnet-onnx-export branch March 23, 2026 13:24
