Fix PSPNet ONNX export issue #1270
qubvel left a comment:

Thanks for the fix, appreciate your help! Please see the comment.
```python
self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))
self.conv = modules.Conv2dReLU(
    in_channels, out_channels, kernel_size=1, use_norm=use_norm
)
```
This will lead to an error while loading weights for those who trained their model with a previous version; let's keep it as it was in the original implementation. Instead, in the forward pass, let's do something like:
```python
if torch.onnx.is_in_onnx_export():
    x = F.interpolate(x, size=(self.pool_size, self.pool_size), mode='area')
    x = self.pool[1](x)  # use only ConvRelu block from pool
else:
    x = self.pool(x)
```
P.S. You can see the error in CI:
```
E Missing key(s) in state_dict: "decoder.psp.blocks.0.conv.0.bias", "decoder.psp.blocks.0.conv.0.weight", "decoder.psp.blocks.1.conv.0.weight", "decoder.psp.blocks.1.conv.1.bias", "decoder.psp.blocks.1.conv.1.running_mean", "decoder.psp.blocks.1.conv.1.running_var", "decoder.psp.blocks.1.conv.1.weight", "decoder.psp.blocks.2.conv.0.weight", "decoder.psp.blocks.2.conv.1.bias", "decoder.psp.blocks.2.conv.1.running_mean", "decoder.psp.blocks.2.conv.1.running_var", "decoder.psp.blocks.2.conv.1.weight", "decoder.psp.blocks.3.conv.0.weight", "decoder.psp.blocks.3.conv.1.bias", "decoder.psp.blocks.3.conv.1.running_mean", "decoder.psp.blocks.3.conv.1.running_var", "decoder.psp.blocks.3.conv.1.weight"
E Unexpected key(s) in state_dict: "decoder.psp.blocks.0.pool.1.0.bias", "decoder.psp.blocks.0.pool.1.0.weight", "decoder.psp.blocks.1.pool.1.0.weight", "decoder.psp.blocks.1.pool.1.1.bias", "decoder.psp.blocks.1.pool.1.1.num_batches_tracked", "decoder.psp.blocks.1.pool.1.1.running_mean", "decoder.psp.blocks.1.pool.1.1.running_var", "decoder.psp.blocks.1.pool.1.1.weight", "decoder.psp.blocks.2.pool.1.0.weight", "decoder.psp.blocks.2.pool.1.1.bias", "decoder.psp.blocks.2.pool.1.1.num_batches_tracked", "decoder.psp.blocks.2.pool.1.1.running_mean", "decoder.psp.blocks.2.pool.1.1.running_var", "decoder.psp.blocks.2.pool.1.1.weight", "decoder.psp.blocks.3.pool.1.0.weight", "decoder.psp.blocks.3.pool.1.1.bias", "decoder.psp.blocks.3.pool.1.1.num_batches_tracked", "decoder.psp.blocks.3.pool.1.1.running_mean", "decoder.psp.blocks.3.pool.1.1.running_var", "decoder.psp.blocks.3.pool.1.1.weight"
```
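The mismatch in that log (missing `conv.*` keys, unexpected `pool.1.*` keys) is exactly what happens when a conv layer is moved out of a `Sequential` named `pool`. A toy reproduction with made-up channel sizes, not the library's actual modules:

```python
import torch
import torch.nn as nn

# Old layout (as trained): the 1x1 conv lives inside a Sequential named "pool"
old_model = nn.Module()
old_model.pool = nn.Sequential(
    nn.AdaptiveAvgPool2d(1), nn.Conv2d(3, 8, kernel_size=1)
)

# New layout (the restructured version): conv promoted to its own "conv" attribute
new_model = nn.Module()
new_model.pool = nn.AdaptiveAvgPool2d(1)
new_model.conv = nn.Conv2d(3, 8, kernel_size=1)

# Loading an old checkpoint into the new layout reproduces the CI failure:
state = old_model.state_dict()  # keys: "pool.1.weight", "pool.1.bias"
missing, unexpected = new_model.load_state_dict(state, strict=False)
```

With `strict=True` (the default), this raises the `RuntimeError` shown in CI instead of returning the key lists.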
In addition, another test seems to be broken for PSP:
```
=================================== FAILURES ===================================
________________________ TestPspModel.test_torch_script ________________________
[gw0] linux -- Python 3.10.20 /home/runner/work/segmentation_models.pytorch/segmentation_models.pytorch/.venv/bin/python3

self = <tests.models.test_psp.TestPspModel testMethod=test_torch_script>

    @pytest.mark.torch_script
    def test_torch_script(self):
        if not check_run_test_on_diff_or_main(self.files_for_diff):
            self.skipTest("No diff and not on `main`.")

        sample = self._get_sample().to(default_device)
        model = self.get_default_model()
        model.eval()

        if not model._is_torch_scriptable:
            with self.assertRaises(RuntimeError):
                scripted_model = torch.jit.script(model)
            return

>       scripted_model = torch.jit.script(model)
```
Thanks for the review. You're right, I didn't think about this. I applied the modifications you recommended, my tests passed successfully, and I also added
Thanks for the feedback.
Thanks
Hello!
This PR addresses Issue #1266 reported by @staticplasma.
The problem
Exporting a `PSPNet` model to ONNX currently fails due to `AdaptiveAvgPool2d`. It typically throws a `SymbolicValueError` because ONNX cannot always compute the kernel/stride sizes dynamically for certain input shapes or opsets.

The fix
I replaced `AdaptiveAvgPool2d` with `F.interpolate(mode='area')` specifically during ONNX export, using `torch.onnx.is_in_onnx_export()`. This has several benefits.

Verification

I have verified the fix with the following environment: Python 3.13, PyTorch 2.5+, ONNX 1.17.
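As a quick sanity check on the substitution (a standalone snippet, not one of the scripts below): for a fixed output size, `F.interpolate(mode='area')` computes the same values as adaptive average pooling, which is why the swap is safe numerically:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 17, 17)

# Adaptive average pooling (the original layer's behavior)...
pooled = F.adaptive_avg_pool2d(x, output_size=(2, 2))
# ...versus area interpolation (the ONNX-friendly replacement)
interp = F.interpolate(x, size=(2, 2), mode="area")

assert torch.allclose(pooled, interp)
```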
Click to see the Reproduction/Export Script
Click to see the ONNX Runtime Validation Script
Best regards,
Raphaël
P.S.: This is my first ever contribution and pull request to an open-source project on GitHub! So I probably didn't do everything perfectly, even though it's a small change. Tell me what I could have done better!