[torch.onnx] Obtain model graph for `PixelUnshuffle(downscale_factor=2)` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PixelUnshuffle(downscale_factor=2)` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decompositions...
/home/peter/.local/share/uv/python/cpython-3.13.13-linux-x86_64-gnu/lib/python3.13/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
[torch.onnx] Run decompositions... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Optimize the ONNX graph...
[torch.onnx] Optimize the ONNX graph... ✅
[torch.onnx] Obtain model graph for `PixelUnshuffle()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PixelUnshuffle()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decompositions...
/home/peter/.local/share/uv/python/cpython-3.13.13-linux-x86_64-gnu/lib/python3.13/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
[torch.onnx] Run decompositions... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Optimize the ONNX graph...
[torch.onnx] Optimize the ONNX graph... ✅
------------------------------------------------------------------------------------------------------------------------
nn.PixelUnshuffle:
graph(
name=main_graph,
inputs=(
%"input"<FLOAT,[1,2,2,2]>
),
outputs=(
%"pixel_unshuffle"<FLOAT,[1,8,1,1]>
),
) {
0 | # node_pixel_unshuffle
%"pixel_unshuffle"<FLOAT,[1,8,1,1]> ⬅️ ::SpaceToDepth(%"input") {blocksize=2}
return %"pixel_unshuffle"<FLOAT,[1,8,1,1]>
}
PixelUnshuffle:
graph(
name=main_graph,
inputs=(
%"input_tensor"<FLOAT,[1,2,2,2]>
),
outputs=(
%"view_1"<FLOAT,[1,8,1,1]>
),
initializers=(
%"val_7"<INT64,[6]>{Tensor<INT64,[6]>(array([-1, 2, 1, 2, 1, 2]), name='val_7')},
%"val_13"<INT64,[4]>{Tensor<INT64,[4]>(array([1, 8, 1, 1]), name='val_13')}
),
) {
0 | # node_view
%"view"<FLOAT,[1,2,1,2,1,2]> ⬅️ ::Reshape(%"input_tensor", %"val_7"{[-1, 2, 1, 2, 1, 2]}) {allowzero=1}
1 | # node_permute
%"permute"<FLOAT,[1,2,2,2,1,1]> ⬅️ ::Transpose(%"view") {perm=(0, 1, 3, 5, 2, 4)}
2 | # node_view_1
%"view_1"<FLOAT,[1,8,1,1]> ⬅️ ::Reshape(%"permute", %"val_13"{[1, 8, 1, 1]}) {allowzero=1}
return %"view_1"<FLOAT,[1,8,1,1]>
}
------------------------------------------------------------------------------------------------------------------------
[torch] nn.PixelUnshuffle: tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964])
[torch] PixelUnshuffle: tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964]) ✅
[onnx] nn.PixelUnshuffle: tensor([0.4963, 0.3074, 0.7682, 0.6341, 0.0885, 0.4901, 0.1320, 0.8964]) ❌
[onnx] PixelUnshuffle: tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964]) ✅
Problem
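`nn.PixelUnshuffle` is converted into a `SpaceToDepth` op. However, the resulting channel order does not match PyTorch's. The related `nn.PixelShuffle` is converted to `DepthToSpace` with `mode="CRD"`, but `SpaceToDepth` has no `mode` attribute and always uses DCR ordering. With downscale factor `r`, input channel `c` (out of `C`) and offset `(i, j)` inside each `r×r` block, ONNX places the value in output channel `(i*r + j)*C + c`, whereas PyTorch places it in channel `c*r*r + i*r + j`. The two orders only coincide when the input has a single channel, so the mismatch shows up for any input with more than one channel (the MRE input has 2).
See `onnxscript/onnxscript/function_libs/torch_lib/ops/core.py`, lines 7599 to 7634 (commit c6e8ec6).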
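A minimal NumPy sketch of the two channel orders, assuming NumPy is available. `pixel_unshuffle_reference` is a direct element-wise restatement of `nn.PixelUnshuffle`, and `space_to_depth_dcr` follows the reshape/transpose decomposition given for `SpaceToDepth` in the ONNX operator documentation. Both names are hypothetical helpers for illustration, not torch or onnxscript APIs.

```python
import numpy as np


def pixel_unshuffle_reference(x: np.ndarray, r: int) -> np.ndarray:
    """torch.nn.PixelUnshuffle semantics: out[:, c*r*r + i*r + j] = x[:, c, i::r, j::r]."""
    n, c, h, w = x.shape
    out = np.empty((n, c * r * r, h // r, w // r), dtype=x.dtype)
    for ci in range(c):
        for i in range(r):
            for j in range(r):
                out[:, ci * r * r + i * r + j] = x[:, ci, i::r, j::r]
    return out


def space_to_depth_dcr(x: np.ndarray, r: int) -> np.ndarray:
    """ONNX SpaceToDepth: out[:, (i*r + j)*C + c] = x[:, c, i::r, j::r]."""
    n, c, h, w = x.shape
    y = x.reshape(n, c, h // r, r, w // r, r)   # (N, C, H/r, r, W/r, r)
    y = y.transpose(0, 3, 5, 1, 2, 4)           # (N, r, r, C, H/r, W/r): block offset outermost
    return y.reshape(n, c * r * r, h // r, w // r)


# Same shape as the MRE input [1, 2, 2, 2] with downscale_factor=2;
# values 0..7 make the channel permutation visible.
x = np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2)
print(pixel_unshuffle_reference(x, 2).ravel())  # [0. 1. 2. 3. 4. 5. 6. 7.]
print(space_to_depth_dcr(x, 2).ravel())         # [0. 4. 1. 5. 2. 6. 3. 7.]

# With a single input channel the two orders coincide, so C = 1 inputs hide the bug.
x1 = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
print(np.array_equal(pixel_unshuffle_reference(x1, 2), space_to_depth_dcr(x1, 2)))  # True
```

On the MRE-shaped input, the second line reproduces the ❌ `[onnx] nn.PixelUnshuffle` row: the input is read back in the order 0, 4, 1, 5, 2, 6, 3, 7.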
Here is an MRE to demonstrate the issue:
Code
Output
Potential Solution
So either ONNX needs to add a `mode` attribute to `SpaceToDepth`, or `nn.PixelUnshuffle` needs to be implemented with Reshape → Transpose → Reshape ops instead, as in the `PixelUnshuffle` graph in the output above.
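A minimal NumPy sketch of the second option, assuming NumPy is available. `pixel_unshuffle_rtr` mirrors the `Reshape` → `Transpose(perm=(0, 1, 3, 5, 2, 4))` → `Reshape` nodes of the `PixelUnshuffle` graph in the output above, and `depth_to_space_crd` follows the `mode="CRD"` decomposition of `DepthToSpace` from the ONNX operator documentation; it is used here to check that the result inverts what `nn.PixelShuffle` is exported to. Both helper names are hypothetical, and this is not the onnxscript implementation.

```python
import numpy as np


def pixel_unshuffle_rtr(x: np.ndarray, r: int) -> np.ndarray:
    """Hypothetical Reshape -> Transpose -> Reshape lowering of nn.PixelUnshuffle."""
    n, c, h, w = x.shape
    y = x.reshape(n, c, h // r, r, w // r, r)       # Reshape:   (N, C, H/r, r, W/r, r)
    y = y.transpose(0, 1, 3, 5, 2, 4)               # Transpose: (N, C, r, r, H/r, W/r)
    return y.reshape(n, c * r * r, h // r, w // r)  # Reshape:   (N, C*r*r, H/r, W/r)


def depth_to_space_crd(x: np.ndarray, r: int) -> np.ndarray:
    """ONNX DepthToSpace with mode="CRD", the op nn.PixelShuffle is exported to."""
    n, c, h, w = x.shape
    y = x.reshape(n, c // (r * r), r, r, h, w)
    y = y.transpose(0, 1, 4, 2, 5, 3)
    return y.reshape(n, c // (r * r), h * r, w * r)


# On the MRE-shaped input the channel order matches torch.
x = np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2)
print(pixel_unshuffle_rtr(x, 2).ravel())  # [0. 1. 2. 3. 4. 5. 6. 7.]

# PixelUnshuffle should invert PixelShuffle; check the round trip on a larger input.
rng = np.random.default_rng(0)
z = rng.random((2, 3, 4, 6), dtype=np.float32)
y = pixel_unshuffle_rtr(z, 2)
print(y.shape)                                      # (2, 12, 2, 3)
print(np.array_equal(depth_to_space_crd(y, 2), z))  # True
```

The exported graph above stores the target shapes as constant initializers (`val_7`, `val_13`), presumably because the model was exported with static input shapes; a general lowering would derive those shapes from the input at runtime, as the `reshape` calls here do from `x.shape`.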