@@ -2861,7 +2861,8 @@ def aten_embedding_bag(
28612861 sparse : bool = False ,
28622862 per_sample_weights : Optional [TFloat ] = None ,
28632863 include_last_offset : bool = False ,
2864- ) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
2864+ padding_idx : Optional [int ] = None ,
2865+ ) -> Tuple [TFloat , INT64 , INT64 , INT64 ]:
28652866 """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""
28662867
28672868 # assert(rank(indices) in [1,2])
@@ -2889,7 +2890,7 @@ def _aten_embedding_bag_onnx(
28892890 mode : int ,
28902891 per_sample_weights : TFloat ,
28912892 include_last_offset : bool ,
2892- ) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
2893+ ) -> Tuple [TFloat , INT64 , INT64 , INT64 ]:
28932894 neg_1 = op .Constant (value_ints = [- 1 ])
28942895 # Assume indices is shape(5,2), indices_1d is shape(10,)
28952896 indices_1d = op .Reshape (indices , neg_1 )
@@ -2957,23 +2958,24 @@ def _aten_embedding_bag_onnx(
29572958
29582959 # Only compute the shape of other 3 outputs, we don't care the value
29592960 if mode == 0 : # sum
2960- offset2bag = op .Shape (indices , start = 0 , end = 0 ) # Generate empty tensor
2961+ offset2bag = op .Cast ( op . Shape (indices , start = 0 , end = 0 ), to = INT64 . dtype )
29612962 if op .Equal (include_last_offset , True ):
2962- bag_size = op .Expand (0 , op .Shape (offsets ))
2963+ bag_size = op .Cast (op .Expand (0 , op .Shape (offsets )), to = INT64 .dtype )
2964+ max_indices = op .Cast (op .Expand (0 , op .Shape (offsets )), to = INT64 .dtype )
29632965 else :
2964- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
2965- max_indices = op .Expand (0 , op .Shape (bag_size ) )
2966+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
2967+ max_indices = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
29662968 elif mode == 1 : # mean
2967- offset2bag = op .Expand (0 , op .Shape (indices , start = 0 , end = 1 ))
2968- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
2969- max_indices = op .Expand (0 , op .Shape (bag_size ) )
2969+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices , start = 0 , end = 1 )), to = INT64 . dtype )
2970+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
2971+ max_indices = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
29702972 else : # max
2971- offset2bag = op .Expand (0 , op .Shape (indices , start = 0 , end = 1 ))
2972- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
2973+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices , start = 0 , end = 1 )), to = INT64 . dtype )
2974+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
29732975 # shape = (bag_size.dim[0], weight.dim[1])
29742976 dim_0 = op .Shape (bag_size , start = 0 , end = 1 )
29752977 dim_1 = op .Shape (weight , start = 1 , end = 2 )
2976- max_indices = op .Expand (0 , op .Concat (dim_0 , dim_1 , axis = 0 ))
2978+ max_indices = op .Cast ( op . Expand (0 , op .Concat (dim_0 , dim_1 , axis = 0 )), to = INT64 . dtype )
29772979
29782980 return result , offset2bag , bag_size , max_indices
29792981
@@ -2996,7 +2998,7 @@ def aten_embedding_bag_padding_idx(
29962998 per_sample_weights : Optional [TFloat ] = None ,
29972999 include_last_offset : bool = False ,
29983000 padding_idx : Optional [int ] = None ,
2999- ) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
3001+ ) -> Tuple [TFloat , INT64 , INT64 , INT64 ]:
30003002 """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
30013003
30023004 We add default values for the attributes to accommodate _embedding_bag as well:
@@ -3043,8 +3045,14 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
30433045 per_sample_weights : TFloat ,
30443046 include_last_offset : bool ,
30453047 padding_idx : int ,
3046- ) -> Tuple [TFloat , TFloat , TFloat , TFloat ]:
3048+ ) -> Tuple [TFloat , INT64 , INT64 , INT64 ]:
30473049 neg_1 = op .Constant (value_ints = [- 1 ])
3050+
3051+ num_embeddings = op .Shape (weight , start = 0 , end = 1 ) # Get number of rows in weight
3052+ num_embeddings_scalar = op .Squeeze (num_embeddings )
3053+ if padding_idx < 0 :
3054+ padding_idx = padding_idx + num_embeddings_scalar
3055+
30483056 # Get weight out according to indices,
30493057 # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
30503058 indices_weight = op .Gather (weight , indices )
@@ -3080,7 +3088,10 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
30803088 cond_2 = j < end_pos
30813089 while cond_2 :
30823090 index = op .Gather (indices , j )
3083- if not op .Equal (index , padding_idx ):
3091+ normalized_index = index
3092+ if index < 0 :
3093+ normalized_index = index + num_embeddings_scalar
3094+ if not op .Equal (normalized_index , padding_idx ):
30843095 # Something like the 'append' operation
30853096 curr_offsets = op .Concat (curr_offsets , op .Reshape (j , neg_1 ), axis = 0 )
30863097 j = j + 1
@@ -3109,23 +3120,24 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
31093120 result = op .CastLike (result , weight )
31103121
31113122 if mode == 0 : # sum
3112- offset2bag = op .Expand (0 , op .Shape (indices ))
3123+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices )), to = INT64 . dtype )
31133124 if op .Equal (include_last_offset , True ):
3114- bag_size = op .Expand (0 , op .Shape (offsets ))
3125+ bag_size = op .Cast (op .Expand (0 , op .Shape (offsets )), to = INT64 .dtype )
3126+ max_indices = op .Cast (op .Expand (0 , op .Shape (offsets )), to = INT64 .dtype )
31153127 else :
3116- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
3117- max_indices = op .Expand (0 , op .Shape (bag_size ) )
3128+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
3129+ max_indices = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
31183130 elif mode == 1 : # mean
3119- offset2bag = op .Expand (0 , op .Shape (indices , start = 0 , end = 1 ))
3120- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
3121- max_indices = op .Expand (0 , op .Shape (bag_size ) )
3131+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices , start = 0 , end = 1 )), to = INT64 . dtype )
3132+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
3133+ max_indices = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
31223134 else : # mode == 2, max
3123- offset2bag = op .Expand (0 , op .Shape (indices , start = 0 , end = 1 ))
3124- bag_size = op .Expand (0 , op .Shape (offsets ) - 1 )
3135+ offset2bag = op .Cast ( op . Expand (0 , op .Shape (indices , start = 0 , end = 1 )), to = INT64 . dtype )
3136+ bag_size = op .Cast ( op . Expand (0 , op .Shape (offsets ) - 1 ), to = INT64 . dtype )
31253137 # shape = (bag_size.dim[0], weight.dim[1])
31263138 dim_0 = op .Shape (bag_size , start = 0 , end = 1 )
31273139 dim_1 = op .Shape (weight , start = 1 , end = 2 )
3128- max_indices = op .Expand (0 , op .Concat (dim_0 , dim_1 , axis = 0 ))
3140+ max_indices = op .Cast ( op . Expand (0 , op .Concat (dim_0 , dim_1 , axis = 0 )), to = INT64 . dtype )
31293141
31303142 return result , offset2bag , bag_size , max_indices
31313143
0 commit comments