File tree Expand file tree Collapse file tree 1 file changed +13
-7
lines changed
onnxscript/function_libs/torch_lib/ops Expand file tree Collapse file tree 1 file changed +13
-7
lines changed Original file line number Diff line number Diff line change 1414
1515from typing import Optional , Sequence
1616
17- from onnxscript import INT64
17+ from onnxscript import INT64 , ir
1818from onnxscript .function_libs .torch_lib .registration import torch_op
1919from onnxscript .function_libs .torch_lib .tensor_typing import TFloat
2020from onnxscript .onnx_opset import opset18 as op
@@ -118,12 +118,18 @@ def aten__fft_c2r(
118118 # Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
119119 # into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
120120 # place no such restriction on the ONNX side.
121- transformed = op .DFT (
122- transformed ,
123- dft_length = last_dim_size ,
124- axis = dimension ,
125- inverse = True ,
126- onesided = False ,
121+ scale = (op .CastLike (last_dim_size , self )) / op .CastLike (
122+ op .Shape (transformed , start = dimension , end = dimension + 1 ), self
123+ )
124+ transformed = (
125+ op .DFT (
126+ transformed ,
127+ dft_length = last_dim_size ,
128+ axis = dimension ,
129+ inverse = True ,
130+ onesided = False ,
131+ )
132+ * scale
127133 )
128134 transformed = _fftn_onnx_normalization (
129135 transformed ,
You can’t perform that action at this time.
0 commit comments