Skip to content

nn.PixelUnshuffle is wrongfully exported as SpaceToDepth #2891

@petwu

Description

@petwu

Problem

nn.PixelUnshuffle is converted into a SpaceToDepth op. However, there is a mismatch in the channel order. While the related nn.PixelShuffle is converted to DepthToSpace with mode="CRD", the SpaceToDepth op doesn't have a mode attribute and is implemented in DCR mode.

@torch_op("aten::pixel_shuffle", trace_only=True)
def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal:
"""pixel_shuffle(Tensor self, int upscale_factor) -> Tensor"""
if len(self.shape) == 4:
return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD")
# Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
batch_dims = op.Shape(self, end=-3)
chw_in_dims = op.Shape(self, start=-3)
reshaped_self = op.Reshape(
self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
)
depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
final_dims = op.Shape(depth_to_space, start=1)
output_shape = op.Concat(batch_dims, final_dims, axis=0)
return op.Reshape(depth_to_space, output_shape, allowzero=True)
@torch_op("aten::pixel_unshuffle", trace_only=True)
def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal:
"""pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor"""
if len(self.shape) == 4:
return op.SpaceToDepth(self, blocksize=downscale_factor)
# Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
batch_dims = op.Shape(self, end=-3)
chw_in_dims = op.Shape(self, start=-3)
reshaped_self = op.Reshape(
self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
)
space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
final_dims = op.Shape(space_to_depth, start=1)
output_shape = op.Concat(batch_dims, final_dims, axis=0)
return op.Reshape(space_to_depth, output_shape, allowzero=True)

Here is an MRE to demonstrate the issue:

Code
import torch
from torch import nn, Tensor


class PixelUnshuffle(nn.Module):
    def __init__(self, downscale_factor: int) -> None:
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input_tensor: Tensor) -> Tensor:
        r = self.downscale_factor
        if input_tensor.dim() < 3:
            raise ValueError("Input must have shape (*, C, H, W) with at least 3 dims.")

        leading_dims = input_tensor.shape[:-3]
        C, H, W = input_tensor.shape[-3:]
        assert H % r == 0 and W % r == 0, (
            "Spatial dimensions must be divisible by the downscale factor."
        )
        C_out = C * r * r
        H_out = H // r
        W_out = W // r
        return (
            input_tensor.view(-1, C, H_out, r, W_out, r)  # (*, C, H//r, r, W//r, r)
            .permute(0, 1, 3, 5, 2, 4)  # (*, C, r, r, H//r, W//r)
            .contiguous()
            .view(*leading_dims, C_out, H_out, W_out)  # (*, C*r*r, H//r, W//r)
        )


model_builtin = nn.PixelUnshuffle(2)
model_builtin.eval()
model_custom = PixelUnshuffle(2)
model_custom.eval()

torch.manual_seed(0)
x = torch.rand(1, 2, 2, 2)
with torch.no_grad():
    onnx_builtin = torch.onnx.export(model_builtin, (x,), None, dynamo=True)
    onnx_custom = torch.onnx.export(model_custom, (x,), None, dynamo=True)

print("-" * 80)
print("nn.PixelUnshuffle:\n", onnx_builtin.model.graph)
print("PixelUnshuffle:\n", onnx_custom.model.graph)
print("-" * 80)

out_builtin = model_builtin(x)
out_custom = model_custom(x)
out_onnx_builtin = onnx_builtin(x)[0]
out_onnx_custom = onnx_custom(x)[0]

check = lambda a, b: "✅" if torch.allclose(a, b) else "❌"
print("[torch] nn.PixelUnshuffle:", out_builtin.flatten())
print("[torch]    PixelUnshuffle:", out_custom.flatten(), check(out_builtin, out_custom))
print("[onnx]  nn.PixelUnshuffle:", out_onnx_builtin.flatten(), check(out_builtin, out_onnx_builtin))
print("[onnx]     PixelUnshuffle:", out_onnx_custom.flatten(), check(out_custom, out_onnx_custom))
Output
[torch.onnx] Obtain model graph for `PixelUnshuffle(downscale_factor=2)` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PixelUnshuffle(downscale_factor=2)` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decompositions...
/home/peter/.local/share/uv/python/cpython-3.13.13-linux-x86_64-gnu/lib/python3.13/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
[torch.onnx] Run decompositions... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Optimize the ONNX graph...
[torch.onnx] Optimize the ONNX graph... ✅
[torch.onnx] Obtain model graph for `PixelUnshuffle()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PixelUnshuffle()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decompositions...
/home/peter/.local/share/uv/python/cpython-3.13.13-linux-x86_64-gnu/lib/python3.13/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
[torch.onnx] Run decompositions... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Optimize the ONNX graph...
[torch.onnx] Optimize the ONNX graph... ✅
------------------------------------------------------------------------------------------------------------------------
nn.PixelUnshuffle:
 graph(
    name=main_graph,
    inputs=(
        %"input"<FLOAT,[1,2,2,2]>
    ),
    outputs=(
        %"pixel_unshuffle"<FLOAT,[1,8,1,1]>
    ),
) {
    0 |  # node_pixel_unshuffle
         %"pixel_unshuffle"<FLOAT,[1,8,1,1]> ⬅️ ::SpaceToDepth(%"input") {blocksize=2}
    return %"pixel_unshuffle"<FLOAT,[1,8,1,1]>
}
PixelUnshuffle:
 graph(
    name=main_graph,
    inputs=(
        %"input_tensor"<FLOAT,[1,2,2,2]>
    ),
    outputs=(
        %"view_1"<FLOAT,[1,8,1,1]>
    ),
    initializers=(
        %"val_7"<INT64,[6]>{Tensor<INT64,[6]>(array([-1,  2,  1,  2,  1,  2]), name='val_7')},
        %"val_13"<INT64,[4]>{Tensor<INT64,[4]>(array([1, 8, 1, 1]), name='val_13')}
    ),
) {
    0 |  # node_view
         %"view"<FLOAT,[1,2,1,2,1,2]> ⬅️ ::Reshape(%"input_tensor", %"val_7"{[-1, 2, 1, 2, 1, 2]}) {allowzero=1}
    1 |  # node_permute
         %"permute"<FLOAT,[1,2,2,2,1,1]> ⬅️ ::Transpose(%"view") {perm=(0, 1, 3, 5, 2, 4)}
    2 |  # node_view_1
         %"view_1"<FLOAT,[1,8,1,1]> ⬅️ ::Reshape(%"permute", %"val_13"{[1, 8, 1, 1]}) {allowzero=1}
    return %"view_1"<FLOAT,[1,8,1,1]>
}
------------------------------------------------------------------------------------------------------------------------
[torch] nn.PixelUnshuffle: tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964])
[torch]    PixelUnshuffle: tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964]) ✅
[onnx]  nn.PixelUnshuffle: tensor([0.4963, 0.3074, 0.7682, 0.6341, 0.0885, 0.4901, 0.1320, 0.8964]) ❌
[onnx]     PixelUnshuffle: tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964]) ✅

Potential Solution

So either ONNX needs to add a mode attribute to SpaceToDepth, or nn.PixelUnshuffle needs to be implemented with Reshape→Transpose→Reshape ops instead.

Metadata

Metadata

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions