2121from onnxscript .onnx_types import TensorType
2222
2323
24- @torch_op (
25- ("aten::_fft_c2c" , "aten::_fft_c2r" , "aten::_fft_r2c" ),
26- private = True ,
27- complex = True ,
28- trace_only = True ,
29- )
3024def _fftn_onnx_normalization (
31- self ,
32- transformed : TFloat ,
25+ self : TFloat ,
3326 normalization : int ,
34- forward : bool ,
35- dims : Sequence [int ],
36- ) -> TFloat :
37- # Obtain the total_sample_count (n) for normalization
38- self_shape = op .Shape (self )
39- total_sample_count = op .ReduceProd (op .Gather (self_shape , dims ), keepdims = 0 )
40- total_sample_count = op .CastLike (total_sample_count , transformed )
41-
42- # Normalize the result
43- # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
44- # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
45- if normalization == 1 :
46- # "forward" - normalize by 1/n
47- if forward :
48- result = op .Div (transformed , op .Sqrt (total_sample_count ))
49- else :
50- result = op .Mul (transformed , op .Sqrt (total_sample_count ))
51- elif normalization == 2 :
52- # "ortho" - normalize by 1/sqrt(n)
53- if forward :
54- result = op .Div (transformed , total_sample_count )
55- else :
56- result = transformed
57- else :
58- # "backward" - no normalization
59- if forward :
60- result = transformed
61- else :
62- result = op .Mul (transformed , total_sample_count )
63-
64- return result
65-
66-
67- @torch_op (
68- ("aten::_fft_c2c" , "aten::_fft_c2r" , "aten::_fft_r2c" ),
69- trace_only = True ,
70- private = True ,
71- complex = True ,
72- )
73- def _fftn_onnx (
74- self : TFloat , dims : Sequence [int ], normalization : int , inverse : bool , onesided : bool
27+ signal_size : INT64 ,
28+ inverse : bool = False ,
7529) -> TFloat :
76- """Standard complex to complex or real to complex FFT (forward or backward).
77-
78- This is a private shared function for implementing the various FFT functions.
79-
80- Args:
81- self: The input tensor.
82- dims: The dimensions to apply FFT.
83- normalization: The normalization mode.
84- inverse: Whether to compute the inverse FFT.
85- onesided: Whether to compute the one-sided FFT, which retains only the
86- positive frequencies.
87-
88- Returns:
89- The transformed tensor.
90- """
91-
92- # NOTE: trace_only because we need to process each dimension in a loop
93- # NOTE: SymInt dim is not support because DFT-17 needs a static axis
94- # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
95-
96- # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
97- # dimension at the beginning to represent the batch dimension.
98- transformed = op .Unsqueeze (self , axes = [0 ])
99-
100- # Add 1 to account for the batch dimension when counting axes from the left
101- new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims ]
102-
103- for dim in new_dims [:- 1 ]:
104- transformed = op .DFT (transformed , axis = dim , inverse = inverse , onesided = False )
105-
106- # Torch computers one-sided FFT on the last dimension only.
107- if onesided :
108- transformed = op .DFT (transformed , axis = new_dims [- 1 ], inverse = inverse , onesided = True )
30+ """Normalize in forward or backward direction."""
31+ # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
32+ # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
33+ # Modes:
34+ # 0: no normalization (backward)
35+ # 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
36+ # 2: divide by signal_size (forward)
37+ signal_size = op .CastLike (signal_size , self )
38+ if not inverse :
39+ # Forward normalization
40+ if normalization == 1 :
41+ self = op .Div (self , op .Sqrt (signal_size ))
42+ elif normalization == 2 :
43+ self = op .Div (self , signal_size )
10944 else :
110- transformed = op .DFT ( transformed , axis = new_dims [ - 1 ], inverse = inverse , onesided = False )
111-
112- # Remove the batch dimension
113- transformed = op . Squeeze ( transformed , axes = [ 0 ])
114-
115- return _fftn_onnx_normalization ( self , transformed , normalization , not inverse , dims )
45+ # Backward normalization, accounting for op.DFT already dividing by signal_size
46+ if normalization == 0 :
47+ self = op . Mul ( self , signal_size )
48+ elif normalization == 1 :
49+ self = op . Mul ( self , op . Sqrt ( signal_size ))
50+ return self
11651
11752
11853@torch_op ("aten::_fft_c2c" , trace_only = True , complex = True )
@@ -124,39 +59,87 @@ def aten__fft_c2c(
12459 Standard complex to complex FFT (forward or backward).
12560 """
12661
127- # NOTE: trace_only because we need to negate forward
128- # NOTE: SymInt dim is not support because DFT-17 needs a static axis
129- # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
62+ # NOTE: SymInt dim is not supported because DFT-17 needs a static axis
13063
13164 # ONNX DFT input assumes the last dimension is the complex dimension.
132- # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
133- dim = [d - 1 if d < 0 else d for d in dim ]
134- return _fftn_onnx (self , dim , normalization , inverse = not forward , onesided = False )
65+
66+ unsqueeze_first_dim = 0 in dim
67+ # 1. Add a new dimension for the end and batch dimension, if needed
68+ # 2. ONNX DFT input assumes the last dimension is the complex dimension.
69+ # If needed, add 1 to account for the batch dimension.
70+
71+ if unsqueeze_first_dim :
72+ transformed = op .Unsqueeze (self , axes = [0 ])
73+ dim = [d + 1 for d in dim ]
74+ else :
75+ transformed = self
76+
77+ for dimension in reversed (dim ):
78+ transformed = op .DFT (transformed , axis = dimension , inverse = not forward , onesided = False )
79+ transformed = _fftn_onnx_normalization (
80+ transformed ,
81+ normalization ,
82+ op .Shape (transformed , start = dimension , end = dimension + 1 ),
83+ not forward ,
84+ )
85+
86+ if unsqueeze_first_dim :
87+ transformed = op .Squeeze (transformed , axes = [0 ])
88+
89+ return transformed
13590
13691
13792@torch_op ("aten::_fft_c2r" , trace_only = True , complex = True )
13893def aten__fft_c2r (
13994 self : TFloat ,
14095 dim : Sequence [int ],
14196 normalization : int ,
142- last_dim_size : INT64 , # pylint: disable=unused-argument
97+ last_dim_size : INT64 ,
14398) -> TFloat :
14499 """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
145100
146- Complex to real inverse FFT.
101+ Complex to real inverse FFT. Assumes that input tensor is output of previous FFT operation.
147102 """
148-
149- # TODO(justinchuby): Figure out what last_dim_size does
150-
151- self_rank = len (self .shape )
152- # ONNX DFT input assumes the last dimension is the complex dimension.
153- # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
154- dim = [(d - 1 ) + self_rank if d < 0 else d for d in dim ]
155- transformed = _fftn_onnx (self , dim , normalization , inverse = True , onesided = False )
156- # Take only the real part
157- real_part = op .Slice (transformed , axes = [- 1 ], starts = [0 ], ends = [1 ])
158-
159- return op .Squeeze (real_part , axes = [- 1 ])
103+ if len (dim ) != 1 :
104+ raise NotImplementedError ("Only one dimension is supported for inverse FFT" )
105+
106+ dimension = dim [0 ]
107+ unsqueeze_first_dim = dimension == 0
108+ # 1. Add a new dimension for batch dimension, if needed
109+ # 2. ONNX DFT input assumes the last dimension is the complex dimension.
110+ # If needed, add 1 to account for the batch dimension.
111+
112+ if unsqueeze_first_dim :
113+ transformed = op .Unsqueeze (self , axes = [0 ])
114+ dimension = 1
115+ else :
116+ transformed = self
117+
118+ # Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
119+ # into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
120+ # place no such restriction on the ONNX side.
121+ transformed = op .DFT (
122+ transformed ,
123+ dft_length = last_dim_size ,
124+ axis = dimension ,
125+ inverse = True ,
126+ onesided = False ,
127+ )
128+ transformed = _fftn_onnx_normalization (
129+ transformed ,
130+ normalization ,
131+ op .Shape (transformed , start = dimension , end = dimension + 1 ),
132+ inverse = True ,
133+ )
134+
135+ if unsqueeze_first_dim :
136+ transformed = op .Squeeze (transformed , axes = [0 ])
137+
138+ # Remove the imaginary part
139+ transformed = op .Slice (transformed , [0 ], [1 ], [- 1 ])
140+ transformed = op .Squeeze (transformed , axes = [- 1 ])
141+
142+ return transformed
160143
161144
162145@torch_op ("aten::_fft_r2c" , trace_only = True )
@@ -168,17 +151,37 @@ def aten__fft_r2c(
168151 Real to complex forward FFT.
169152 """
170153
171- # Add a new dimension at the end
172- signal = op .Unsqueeze (self , axes = [- 1 ])
173154 # No need to fill the imaginary part because ONNX DFT accepts real inputs
174155 # https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
175156
176- self_rank = len (self .shape )
177- # ONNX DFT input assumes the last dimension is the complex dimension.
178- # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
179- dim = [(d - 1 ) + self_rank if d < 0 else d for d in dim ]
157+ unsqueeze_first_dim = 0 in dim
158+ # 1. Add a new dimension for the end and batch dimension, if needed
159+ # 2. ONNX DFT input assumes the last dimension is the complex dimension.
160+ # If needed, add 1 to account for the batch dimension.
161+
162+ if unsqueeze_first_dim :
163+ transformed = op .Unsqueeze (self , axes = [0 , - 1 ])
164+ dim = [d + 1 for d in dim ]
165+ else :
166+ transformed = op .Unsqueeze (self , axes = [- 1 ])
167+
168+ for idx , dimension in enumerate (reversed (dim )):
169+ transformed = _fftn_onnx_normalization (
170+ transformed ,
171+ normalization ,
172+ op .Shape (transformed , start = dimension , end = dimension + 1 ),
173+ inverse = False ,
174+ )
175+ if idx > 0 :
176+ transformed = op .DFT (transformed , axis = dimension , inverse = False , onesided = False )
177+ else :
178+ # Torch computes one-sided FFT on the last dimension only.
179+ transformed = op .DFT (transformed , axis = dimension , inverse = False , onesided = onesided )
180+
181+ if unsqueeze_first_dim :
182+ transformed = op .Squeeze (transformed , axes = [0 ])
180183
181- return _fftn_onnx ( signal , dim , normalization , inverse = False , onesided = onesided )
184+ return transformed
182185
183186
184187def aten_fft_fft (
0 commit comments