Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3101,27 +3101,40 @@ def aten_embedding_bag_padding_idx(
sparse: bool = False,
per_sample_weights: Optional[TFloat] = None,
include_last_offset: bool = False,
padding_idx: int = -1,
padding_idx: Optional[int] = None,
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
"""embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)

We add default values for the attributes to accommodate _embedding_bag as well:
_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
"""
assert padding_idx is not None, (
"padding_idx must not be None. This is likely a dispatcher error"
)

if per_sample_weights is None:
per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
per_sample_weights = op.CastLike(per_sample_weights, weight)

# Change padding_idx to positive value, -1 means the last index
if padding_idx < 0:
padding_idx = weight.shape[0] + padding_idx
if padding_idx is not None:
# Call the existing function for handling padding_idx
result, offset2bag, bag_size, max_indices =_aten_embedding_bag_1d_padding_idx_onnx(
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
weight,
indices,
offsets,
mode,
per_sample_weights,
include_last_offset,
padding_idx,
)

result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx
return result, offset2bag, bag_size, max_indices

# When padding_idx is None, use the standard embedding_bag implementation
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
weight,
indices,
offsets,
mode,
per_sample_weights,
include_last_offset,
)

return result, offset2bag, bag_size, max_indices
Expand Down
38 changes: 38 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,44 @@ def __init__(self):
sample_inputs_func=sample_inputs_embedding_bag_padding_idx,
supports_out=False,
),
opinfo_core.OpInfo(
"test_embedding_bag_with_padding_idx_none",
op=torch.nn.functional.embedding_bag,
dtypes=(torch.float32,),
sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
opinfo_core.SampleInput(
torch.tensor(
[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]],
dtype=dtype,
device=device,
),
args=(
torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device),
torch.tensor([0, 2], dtype=torch.int64, device=device),
),
kwargs={"padding_idx": None},
)
],
),
opinfo_core.OpInfo(
"test_embedding_bag_with_padding_idx_int",
Comment thread
crypto-a marked this conversation as resolved.
Outdated
op=torch.nn.functional.embedding_bag,
dtypes=(torch.float32,),
Comment thread
justinchuby marked this conversation as resolved.
Outdated
sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
opinfo_core.SampleInput(
torch.tensor(
[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
dtype=dtype,
device=device,
),
args=(
torch.tensor([0, 1, 2], dtype=torch.int64, device=device),
torch.tensor([0, 2], dtype=torch.int64, device=device),
),
kwargs={"padding_idx": 0},
)
],
),
opinfo_core.OpInfo(
"ops.aten.embedding_renorm",
aten_name="embedding_renorm",
Expand Down
41 changes: 41 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,24 @@ def xfail(
# Modify this section ##########################################################


def _embedding_bag_input_wrangler(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Final round of reviews: Is this necessary? I think it should accept a None input?

args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# ONNX attributes cannot be None; omit padding_idx if it's None.
if "padding_idx" in kwargs:
padding_idx = kwargs.pop("padding_idx")
if padding_idx is not None:
kwargs["padding_idx"] = int(padding_idx)

# Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...)
if len(args) >= 3:
if isinstance(args[1], torch.Tensor):
args[1] = args[1].to(torch.long)
if isinstance(args[2], torch.Tensor):
args[2] = args[2].to(torch.long)

return args, kwargs

def _amin_amax_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
Expand Down Expand Up @@ -1035,15 +1053,38 @@ def _where_input_wrangler(
core_ops.aten_embedding_bag,
tolerance={torch.float32: (1e-4, 5e-4)},
compare_shape_only_for_output=(1, 2, 3),
input_wrangler=_embedding_bag_input_wrangler,
).skip(
dtypes=(torch.float16,),
reason="fixme: results mismatch in torch nightly.",
),
TorchLibOpInfo(
"test_embedding_bag_with_padding_idx_none",
core_ops.aten_embedding_bag,
input_wrangler=_embedding_bag_input_wrangler,
),
TorchLibOpInfo(
"test_embedding_bag_with_padding_idx_int",
core_ops.aten_embedding_bag,
input_wrangler=_embedding_bag_input_wrangler,
),

TorchLibOpInfo(
"ops.aten.embedding_bag.padding_idx",
core_ops.aten_embedding_bag_padding_idx,
tolerance={torch.float16: (1e-2, 1e-2)},
compare_shape_only_for_output=(1, 2, 3),
input_wrangler=_embedding_bag_input_wrangler,
),
TorchLibOpInfo(
"test_embedding_bag_with_padding_idx_none",
core_ops.aten_embedding_bag,
input_wrangler=_embedding_bag_input_wrangler,
),
TorchLibOpInfo(
"test_embedding_bag_with_padding_idx_int",
core_ops.aten_embedding_bag,
input_wrangler=_embedding_bag_input_wrangler,
),
TorchLibOpInfo(
"ops.aten.embedding_renorm",
Expand Down
Loading