@@ -2862,7 +2862,7 @@ def aten_embedding_bag(
28622862 per_sample_weights : Optional [TFloat ] = None ,
28632863 include_last_offset : bool = False ,
28642864 padding_idx : Optional [int ] = None ,
2865- ) -> Tuple [TFloat , INT64 , INT64 , INT64 ]:
2865+ ) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
28662866 """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""
28672867
28682868 # assert(rank(indices) in [1,2])
@@ -2890,7 +2890,7 @@ def _aten_embedding_bag_onnx(
28902890 mode : int ,
28912891 per_sample_weights : TFloat ,
28922892 include_last_offset : bool ,
2893- ) -> Tuple [TFloat , INT64 , INT64 , INT64 ]:
2893+ ) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
28942894 neg_1 = op .Constant (value_ints = [- 1 ])
28952895 # Assume indices is shape(5,2), indices_1d is shape(10,)
28962896 indices_1d = op .Reshape (indices , neg_1 )
@@ -2998,7 +2998,7 @@ def aten_embedding_bag_padding_idx(
29982998 per_sample_weights : Optional [TFloat ] = None ,
29992999 include_last_offset : bool = False ,
30003000 padding_idx : Optional [int ] = None ,
3001- ) -> Tuple [TFloat , INT64 , INT64 , INT64 ]:
3001+ ) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
30023002 """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
30033003
30043004 We add default values for the attributes to accommodate _embedding_bag as well:
@@ -3045,7 +3045,7 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
30453045 per_sample_weights : TFloat ,
30463046 include_last_offset : bool ,
30473047 padding_idx : int ,
3048- ) -> Tuple [TFloat , INT64 , INT64 , INT64 ]:
3048+ ) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
30493049 neg_1 = op .Constant (value_ints = [- 1 ])
30503050
30513051 num_embeddings = op .Shape (weight , start = 0 , end = 1 ) # Get number of rows in weight
0 commit comments