@@ -184,6 +184,25 @@ def xfail(
184184# Modify this section ##########################################################
185185
186186
187+ def _embedding_bag_input_wrangler (
188+ args : list [Any ], kwargs : dict [str , Any ]
189+ ) -> tuple [list [Any ], dict [str , Any ]]:
190+ # ONNX attributes cannot be None; omit padding_idx if it's None.
191+ if "padding_idx" in kwargs :
192+ padding_idx = kwargs .pop ("padding_idx" )
193+ if padding_idx is not None :
194+ kwargs ["padding_idx" ] = int (padding_idx )
195+
196+ # Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...)
197+ if len (args ) >= 3 :
198+ if isinstance (args [1 ], torch .Tensor ):
199+ args [1 ] = args [1 ].to (torch .long )
200+ if isinstance (args [2 ], torch .Tensor ):
201+ args [2 ] = args [2 ].to (torch .long )
202+
203+ return args , kwargs
204+
205+
187206def _amin_amax_input_wrangler (
188207 args : list [Any ], kwargs : dict [str , Any ]
189208) -> tuple [list [Any ], dict [str , Any ]]:
@@ -908,12 +927,27 @@ def _where_input_wrangler(
908927 core_ops .aten_embedding_bag ,
909928 tolerance = {torch .float32 : (1e-4 , 5e-4 )},
910929 compare_shape_only_for_output = (1 , 2 , 3 ),
911- ).skip (dtypes = (torch .float16 ,), reason = "fixme: results mismatch in torch nightly." ),
930+ input_wrangler = _embedding_bag_input_wrangler ,
931+ ).skip (
932+ dtypes = (torch .float16 ,),
933+ reason = "fixme: results mismatch in torch nightly." ,
934+ ),
935+ TorchLibOpInfo (
936+ "ops.aten.embedding_bag.padding_idx_none" ,
937+ core_ops .aten_embedding_bag ,
938+ input_wrangler = _embedding_bag_input_wrangler ,
939+ ),
940+ TorchLibOpInfo (
941+ "ops.aten.embedding_bag.padding_idx_int" ,
942+ core_ops .aten_embedding_bag ,
943+ input_wrangler = _embedding_bag_input_wrangler ,
944+ ),
912945 TorchLibOpInfo (
913946 "ops.aten.embedding_bag.padding_idx" ,
914947 core_ops .aten_embedding_bag_padding_idx ,
915948 tolerance = {torch .float16 : (1e-2 , 1e-2 )},
916949 compare_shape_only_for_output = (1 , 2 , 3 ),
950+ input_wrangler = _embedding_bag_input_wrangler ,
917951 ),
918952 TorchLibOpInfo (
919953 "ops.aten.embedding_renorm" ,
0 commit comments