@@ -238,6 +238,75 @@ def forward(self, x):
238238 )
239239 _testing .assert_onnx_program (onnx_program )
240240
241+ def test_aten_stft_1 (self ):
242+ class Model (torch .nn .Module ):
243+ def forward (self , x ):
244+ return torch .stft (x , n_fft = 4 , return_complex = True )
245+
246+ x = torch .randn (4 , 16 , dtype = torch .float32 )
247+
248+ onnx_program = torch .onnx .export (
249+ Model (),
250+ (x ,),
251+ dynamo = True ,
252+ verbose = False ,
253+ )
254+ _testing .assert_onnx_program (onnx_program )
255+
256+ def test_aten_stft_2 (self ):
257+ class Model (torch .nn .Module ):
258+ def forward (self , x ):
259+ return torch .stft (x , n_fft = 4 , return_complex = False )
260+
261+ x = torch .randn (4 , 16 , dtype = torch .float32 )
262+
263+ onnx_program = torch .onnx .export (
264+ Model (),
265+ (x ,),
266+ dynamo = True ,
267+ verbose = False ,
268+ )
269+ _testing .assert_onnx_program (onnx_program )
270+
271+ def test_aten_stft_3 (self ):
272+ class Model (torch .nn .Module ):
273+ def forward (self , x ):
274+ window = torch .ones (16 , dtype = torch .float32 )
275+ return torch .stft (x , n_fft = 16 , window = window , return_complex = False )
276+
277+ x = torch .randn (100 , dtype = torch .float32 )
278+
279+ onnx_program = torch .onnx .export (
280+ Model (),
281+ (x ,),
282+ dynamo = True ,
283+ verbose = False ,
284+ )
285+ _testing .assert_onnx_program (onnx_program )
286+
287+ def test_aten_stft_4 (self ):
288+ class Model (torch .nn .Module ):
289+ def forward (self , x ):
290+ return torch .stft (
291+ x ,
292+ n_fft = 4 ,
293+ hop_length = 1 ,
294+ win_length = 4 ,
295+ center = True ,
296+ onesided = True ,
297+ return_complex = True ,
298+ )
299+
300+ x = torch .randn (4 , 16 , dtype = torch .float32 )
301+
302+ onnx_program = torch .onnx .export (
303+ Model (),
304+ (x ,),
305+ dynamo = True ,
306+ verbose = False ,
307+ )
308+ _testing .assert_onnx_program (onnx_program )
309+
241310
242311if __name__ == "__main__" :
243312 unittest .main ()
0 commit comments