Skip to content

Commit 1f84cc7

Browse files
authored
Merge branch 'fix-2219' into main
2 parents da967e3 + d2da96e commit 1f84cc7

3 files changed

Lines changed: 95 additions & 10 deletions

File tree

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 22 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2995,27 +2995,40 @@ def aten_embedding_bag_padding_idx(
29952995
sparse: bool = False,
29962996
per_sample_weights: Optional[TFloat] = None,
29972997
include_last_offset: bool = False,
2998-
padding_idx: int = -1,
2998+
padding_idx: Optional[int] = None,
29992999
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
30003000
"""embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
30013001
30023002
We add default values for the attributes to accommodate _embedding_bag as well:
30033003
_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
30043004
"""
3005-
assert padding_idx is not None, (
3006-
"padding_idx must not be None. This is likely a dispatcher error"
3007-
)
30083005

30093006
if per_sample_weights is None:
30103007
per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
30113008
per_sample_weights = op.CastLike(per_sample_weights, weight)
30123009

3013-
# Change padding_idx to positive value, -1 means the last index
3014-
if padding_idx < 0:
3015-
padding_idx = weight.shape[0] + padding_idx
3010+
if padding_idx is not None:
3011+
# Call the existing function for handling padding_idx
3012+
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
3013+
weight,
3014+
indices,
3015+
offsets,
3016+
mode,
3017+
per_sample_weights,
3018+
include_last_offset,
3019+
padding_idx,
3020+
)
30163021

3017-
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
3018-
weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx
3022+
return result, offset2bag, bag_size, max_indices
3023+
3024+
# When padding_idx is None, use the standard embedding_bag implementation
3025+
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(
3026+
weight,
3027+
indices,
3028+
offsets,
3029+
mode,
3030+
per_sample_weights,
3031+
include_last_offset,
30193032
)
30203033

30213034
return result, offset2bag, bag_size, max_indices

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2483,6 +2483,44 @@ def __init__(self):
24832483
sample_inputs_func=sample_inputs_embedding_bag_padding_idx,
24842484
supports_out=False,
24852485
),
2486+
opinfo_core.OpInfo(
2487+
"ops.aten.embedding_bag.padding_idx_none",
2488+
op=torch.nn.functional.embedding_bag,
2489+
dtypes=common_dtype.floating_types_and_half(),
2490+
sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
2491+
opinfo_core.SampleInput(
2492+
torch.tensor(
2493+
[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]],
2494+
dtype=dtype,
2495+
device=device,
2496+
),
2497+
args=(
2498+
torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device),
2499+
torch.tensor([0, 2], dtype=torch.int64, device=device),
2500+
),
2501+
kwargs={"padding_idx": None},
2502+
)
2503+
],
2504+
),
2505+
opinfo_core.OpInfo(
2506+
"ops.aten.embedding_bag.padding_idx_int",
2507+
op=torch.nn.functional.embedding_bag,
2508+
dtypes=common_dtype.floating_types_and_half(),
2509+
sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
2510+
opinfo_core.SampleInput(
2511+
torch.tensor(
2512+
[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
2513+
dtype=dtype,
2514+
device=device,
2515+
),
2516+
args=(
2517+
torch.tensor([0, 1, 2], dtype=torch.int64, device=device),
2518+
torch.tensor([0, 2], dtype=torch.int64, device=device),
2519+
),
2520+
kwargs={"padding_idx": 0},
2521+
)
2522+
],
2523+
),
24862524
opinfo_core.OpInfo(
24872525
"ops.aten.embedding_renorm",
24882526
aten_name="embedding_renorm",

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 35 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,25 @@ def xfail(
184184
# Modify this section ##########################################################
185185

186186

187+
def _embedding_bag_input_wrangler(
188+
args: list[Any], kwargs: dict[str, Any]
189+
) -> tuple[list[Any], dict[str, Any]]:
190+
# ONNX attributes cannot be None; omit padding_idx if it's None.
191+
if "padding_idx" in kwargs:
192+
padding_idx = kwargs.pop("padding_idx")
193+
if padding_idx is not None:
194+
kwargs["padding_idx"] = int(padding_idx)
195+
196+
# Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...)
197+
if len(args) >= 3:
198+
if isinstance(args[1], torch.Tensor):
199+
args[1] = args[1].to(torch.long)
200+
if isinstance(args[2], torch.Tensor):
201+
args[2] = args[2].to(torch.long)
202+
203+
return args, kwargs
204+
205+
187206
def _amin_amax_input_wrangler(
188207
args: list[Any], kwargs: dict[str, Any]
189208
) -> tuple[list[Any], dict[str, Any]]:
@@ -908,12 +927,27 @@ def _where_input_wrangler(
908927
core_ops.aten_embedding_bag,
909928
tolerance={torch.float32: (1e-4, 5e-4)},
910929
compare_shape_only_for_output=(1, 2, 3),
911-
).skip(dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly."),
930+
input_wrangler=_embedding_bag_input_wrangler,
931+
).skip(
932+
dtypes=(torch.float16,),
933+
reason="fixme: results mismatch in torch nightly.",
934+
),
935+
TorchLibOpInfo(
936+
"ops.aten.embedding_bag.padding_idx_none",
937+
core_ops.aten_embedding_bag,
938+
input_wrangler=_embedding_bag_input_wrangler,
939+
),
940+
TorchLibOpInfo(
941+
"ops.aten.embedding_bag.padding_idx_int",
942+
core_ops.aten_embedding_bag,
943+
input_wrangler=_embedding_bag_input_wrangler,
944+
),
912945
TorchLibOpInfo(
913946
"ops.aten.embedding_bag.padding_idx",
914947
core_ops.aten_embedding_bag_padding_idx,
915948
tolerance={torch.float16: (1e-2, 1e-2)},
916949
compare_shape_only_for_output=(1, 2, 3),
950+
input_wrangler=_embedding_bag_input_wrangler,
917951
),
918952
TorchLibOpInfo(
919953
"ops.aten.embedding_renorm",

0 commit comments

Comments (0)