Skip to content

Commit 9ba946e

Browse files
committed
rollback return types
1 parent 4789d56 commit 9ba946e

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

  • onnxscript/function_libs/torch_lib/ops

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2862,7 +2862,7 @@ def aten_embedding_bag(
28622862
per_sample_weights: Optional[TFloat] = None,
28632863
include_last_offset: bool = False,
28642864
padding_idx: Optional[int] = None,
2865-
) -> Tuple[TFloat, INT64, INT64, INT64]:
2865+
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
28662866
"""embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""
28672867

28682868
# assert(rank(indices) in [1,2])
@@ -2890,7 +2890,7 @@ def _aten_embedding_bag_onnx(
28902890
mode: int,
28912891
per_sample_weights: TFloat,
28922892
include_last_offset: bool,
2893-
) -> Tuple[TFloat, INT64, INT64, INT64]:
2893+
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
28942894
neg_1 = op.Constant(value_ints=[-1])
28952895
# Assume indices is shape(5,2), indices_1d is shape(10,)
28962896
indices_1d = op.Reshape(indices, neg_1)
@@ -2998,7 +2998,7 @@ def aten_embedding_bag_padding_idx(
29982998
per_sample_weights: Optional[TFloat] = None,
29992999
include_last_offset: bool = False,
30003000
padding_idx: Optional[int] = None,
3001-
) -> Tuple[TFloat, INT64, INT64, INT64]:
3001+
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
30023002
"""embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
30033003
30043004
We add default values for the attributes to accommodate _embedding_bag as well:
@@ -3045,7 +3045,7 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
30453045
per_sample_weights: TFloat,
30463046
include_last_offset: bool,
30473047
padding_idx: int,
3048-
) -> Tuple[TFloat, INT64, INT64, INT64]:
3048+
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
30493049
neg_1 = op.Constant(value_ints=[-1])
30503050

30513051
num_embeddings = op.Shape(weight, start=0, end=1) # Get number of rows in weight

0 commit comments

Comments
 (0)