Skip to content

Commit 9893baf

Browse files
committed
fixes
1 parent 9e21a91 commit 9893baf

2 files changed

Lines changed: 3 additions & 1 deletion

File tree

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3114,6 +3114,8 @@ def aten_embedding_bag_padding_idx(
31143114
per_sample_weights = op.CastLike(per_sample_weights, weight)
31153115

31163116
if padding_idx is not None:
3117+
if padding_idx < 0:
3118+
padding_idx = weight.shape[0] + padding_idx
31173119
# Call the existing function for handling padding_idx
31183120
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
31193121
weight,

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ def _where_input_wrangler(
10661066
),
10671067
TorchLibOpInfo(
10681068
"ops.aten.embedding_bag.padding_idx_int",
1069-
core_ops.aten_embedding_bag,
1069+
core_ops.aten_embedding_bag_padding_idx,
10701070
input_wrangler=_embedding_bag_input_wrangler,
10711071
),
10721072
TorchLibOpInfo(

0 commit comments

Comments
 (0)