Skip to content

Commit e47e8a6

Browse files
crypto-aAravind-11
authored andcommitted
added test cases for aten_embedding_bag_padding_idx
1 parent 76e217c commit e47e8a6

1 file changed

Lines changed: 0 additions & 20 deletions

File tree

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -203,26 +203,6 @@ def _embedding_bag_input_wrangler(
203203

204204
return args, kwargs
205205

206-
207-
def _embedding_bag_input_wrangler(
208-
args: list[Any], kwargs: dict[str, Any]
209-
) -> tuple[list[Any], dict[str, Any]]:
210-
# ONNX attributes cannot be None; omit padding_idx if it's None.
211-
if "padding_idx" in kwargs:
212-
padding_idx = kwargs.pop("padding_idx")
213-
if padding_idx is not None:
214-
kwargs["padding_idx"] = int(padding_idx)
215-
216-
# Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...)
217-
if len(args) >= 3:
218-
if isinstance(args[1], torch.Tensor):
219-
args[1] = args[1].to(torch.long)
220-
if isinstance(args[2], torch.Tensor):
221-
args[2] = args[2].to(torch.long)
222-
223-
return args, kwargs
224-
225-
226206
def _amin_amax_input_wrangler(
227207
args: list[Any], kwargs: dict[str, Any]
228208
) -> tuple[list[Any], dict[str, Any]]:

0 commit comments

Comments
 (0)