Skip to content

Commit e06dd92

Browse files
authored
[torchlib] Fix irfft (#2770)
This fixes the onnxscript export for the irfft function. Fixes onnx/onnx#5920, and adds support to the changes in onnx/onnx#7574 and microsoft/onnxruntime#27028. Most of the diff is due to the onnx_opset generated code changes from onnx/onnx#5920. That can be removed if you would prefer. --------- Signed-off-by: Simon Byrne <sbyrne@nvidia.com>
1 parent 1080d6a commit e06dd92

File tree

2 files changed

+28
-19
lines changed

2 files changed

+28
-19
lines changed

onnxscript/function_libs/torch_lib/ops/fft.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def aten__fft_c2r(
100100
101101
Complex to real inverse FFT. Assumes that input tensor is output of previous FFT operation.
102102
"""
103-
if len(dim) != 1:
104-
raise NotImplementedError("Only one dimension is supported for inverse FFT")
105103

106104
dimension = dim[0]
107105
unsqueeze_first_dim = dimension == 0
@@ -111,26 +109,34 @@ def aten__fft_c2r(
111109

112110
if unsqueeze_first_dim:
113111
transformed = op.Unsqueeze(self, axes=[0])
114-
dimension = 1
115112
else:
116113
transformed = self
117114

118-
# Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
119-
# into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
120-
# place no such restriction on the ONNX side.
121-
transformed = op.DFT(
122-
transformed,
123-
dft_length=last_dim_size,
124-
axis=dimension,
125-
inverse=True,
126-
onesided=False,
127-
)
128-
transformed = _fftn_onnx_normalization(
129-
transformed,
130-
normalization,
131-
op.Shape(transformed, start=dimension, end=dimension + 1),
132-
inverse=True,
133-
)
115+
for idx, dimension in enumerate(dim):
116+
# Adjust dimension if we unsqueezed at the beginning
117+
dimension += unsqueeze_first_dim
118+
119+
if idx < len(dim) - 1:
120+
transformed = op.DFT(transformed, axis=dimension, inverse=True, onesided=False)
121+
else:
122+
# last operation is one-sided, transform to real
123+
# Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
124+
# into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
125+
# place no such restriction on the ONNX side.
126+
transformed = op.DFT(
127+
transformed,
128+
dft_length=last_dim_size,
129+
axis=dimension,
130+
inverse=True,
131+
onesided=True,
132+
)
133+
134+
transformed = _fftn_onnx_normalization(
135+
transformed,
136+
normalization,
137+
op.Shape(transformed, start=dimension, end=dimension + 1),
138+
inverse=True,
139+
)
134140

135141
if unsqueeze_first_dim:
136142
transformed = op.Squeeze(transformed, axes=[0])

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,9 @@ def _where_input_wrangler(
463463
fft_ops.aten__fft_c2r,
464464
tolerance={torch.complex64: (3e-3, 1.8e-4)},
465465
complex=True,
466+
).xfail(
467+
matcher=lambda sample: True,
468+
reason="Requires ONNX with IRFFT support (onesided=True, inverse=True)",
466469
),
467470
TorchLibOpInfo(
468471
"ops.aten._fft_r2c", # Custom from extra_opinfo

0 commit comments

Comments
 (0)