Skip to content

Commit 7e6182b

Browse files
committed
Test fft normalization
1 parent a3e9cbe commit 7e6182b

File tree

1 file changed

+13
-7
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+13
-7
lines changed

onnxscript/function_libs/torch_lib/ops/fft.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from typing import Optional, Sequence
1616

17-
from onnxscript import INT64
17+
from onnxscript import INT64, ir
1818
from onnxscript.function_libs.torch_lib.registration import torch_op
1919
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
2020
from onnxscript.onnx_opset import opset18 as op
@@ -118,12 +118,18 @@ def aten__fft_c2r(
118118
# Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
119119
# into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
120120
# place no such restriction on the ONNX side.
121-
transformed = op.DFT(
122-
transformed,
123-
dft_length=last_dim_size,
124-
axis=dimension,
125-
inverse=True,
126-
onesided=False,
121+
scale = (op.CastLike(last_dim_size, self)) / op.CastLike(
122+
op.Shape(transformed, start=dimension, end=dimension + 1), self
123+
)
124+
transformed = (
125+
op.DFT(
126+
transformed,
127+
dft_length=last_dim_size,
128+
axis=dimension,
129+
inverse=True,
130+
onesided=False,
131+
)
132+
* scale
127133
)
128134
transformed = _fftn_onnx_normalization(
129135
transformed,

0 commit comments

Comments
 (0)