Skip to content

Commit 5f6156b

Browse files
committed
fix: resolve embedding_bag untyped ONNX outputs with explicit casts
1 parent b8aae85 commit 5f6156b

1 file changed

Lines changed: 37 additions & 25 deletions

File tree

  • onnxscript/function_libs/torch_lib/ops

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2861,7 +2861,8 @@ def aten_embedding_bag(
28612861
sparse: bool = False,
28622862
per_sample_weights: Optional[TFloat] = None,
28632863
include_last_offset: bool = False,
2864-
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
2864+
padding_idx: Optional[int] = None,
2865+
) -> Tuple[TFloat, INT64, INT64, INT64]:
28652866
"""embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""
28662867

28672868
# assert(rank(indices) in [1,2])
@@ -2889,7 +2890,7 @@ def _aten_embedding_bag_onnx(
28892890
mode: int,
28902891
per_sample_weights: TFloat,
28912892
include_last_offset: bool,
2892-
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
2893+
) -> Tuple[TFloat, INT64, INT64, INT64]:
28932894
neg_1 = op.Constant(value_ints=[-1])
28942895
# Assume indices is shape(5,2), indices_1d is shape(10,)
28952896
indices_1d = op.Reshape(indices, neg_1)
@@ -2957,23 +2958,24 @@ def _aten_embedding_bag_onnx(
29572958

29582959
# Only compute the shape of other 3 outputs, we don't care the value
29592960
if mode == 0: # sum
2960-
offset2bag = op.Shape(indices, start=0, end=0) # Generate empty tensor
2961+
offset2bag = op.Cast(op.Shape(indices, start=0, end=0), to=INT64.dtype)
29612962
if op.Equal(include_last_offset, True):
2962-
bag_size = op.Expand(0, op.Shape(offsets))
2963+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
2964+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
29632965
else:
2964-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
2965-
max_indices = op.Expand(0, op.Shape(bag_size))
2966+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
2967+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) -1), to=INT64.dtype)
29662968
elif mode == 1: # mean
2967-
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
2968-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
2969-
max_indices = op.Expand(0, op.Shape(bag_size))
2969+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
2970+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
2971+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
29702972
else: # max
2971-
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
2972-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
2973+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
2974+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
29732975
# shape = (bag_size.dim[0], weight.dim[1])
29742976
dim_0 = op.Shape(bag_size, start=0, end=1)
29752977
dim_1 = op.Shape(weight, start=1, end=2)
2976-
max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
2978+
max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)
29772979

29782980
return result, offset2bag, bag_size, max_indices
29792981

@@ -2996,7 +2998,7 @@ def aten_embedding_bag_padding_idx(
29962998
per_sample_weights: Optional[TFloat] = None,
29972999
include_last_offset: bool = False,
29983000
padding_idx: Optional[int] = None,
2999-
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
3001+
) -> Tuple[TFloat, INT64, INT64, INT64]:
30003002
"""embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
30013003
30023004
We add default values for the attributes to accommodate _embedding_bag as well:
@@ -3043,8 +3045,14 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
30433045
per_sample_weights: TFloat,
30443046
include_last_offset: bool,
30453047
padding_idx: int,
3046-
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
3048+
) -> Tuple[TFloat, INT64, INT64, INT64]:
30473049
neg_1 = op.Constant(value_ints=[-1])
3050+
3051+
num_embeddings = op.Shape(weight, start=0, end=1) # Get number of rows in weight
3052+
num_embeddings_scalar = op.Squeeze(num_embeddings)
3053+
if padding_idx < 0:
3054+
padding_idx = padding_idx + num_embeddings_scalar
3055+
30483056
# Get weight out according to indices,
30493057
# e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
30503058
indices_weight = op.Gather(weight, indices)
@@ -3080,7 +3088,10 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
30803088
cond_2 = j < end_pos
30813089
while cond_2:
30823090
index = op.Gather(indices, j)
3083-
if not op.Equal(index, padding_idx):
3091+
normalized_index = index
3092+
if index < 0:
3093+
normalized_index = index + num_embeddings_scalar
3094+
if not op.Equal(normalized_index, padding_idx):
30843095
# Something like the 'append' operation
30853096
curr_offsets = op.Concat(curr_offsets, op.Reshape(j, neg_1), axis=0)
30863097
j = j + 1
@@ -3109,23 +3120,24 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
31093120
result = op.CastLike(result, weight)
31103121

31113122
if mode == 0: # sum
3112-
offset2bag = op.Expand(0, op.Shape(indices))
3123+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices)), to=INT64.dtype)
31133124
if op.Equal(include_last_offset, True):
3114-
bag_size = op.Expand(0, op.Shape(offsets))
3125+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
3126+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
31153127
else:
3116-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
3117-
max_indices = op.Expand(0, op.Shape(bag_size))
3128+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
3129+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
31183130
elif mode == 1: # mean
3119-
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
3120-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
3121-
max_indices = op.Expand(0, op.Shape(bag_size))
3131+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
3132+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
3133+
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
31223134
else: # mode == 2, max
3123-
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
3124-
bag_size = op.Expand(0, op.Shape(offsets) - 1)
3135+
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
3136+
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
31253137
# shape = (bag_size.dim[0], weight.dim[1])
31263138
dim_0 = op.Shape(bag_size, start=0, end=1)
31273139
dim_1 = op.Shape(weight, start=1, end=2)
3128-
max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
3140+
max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)
31293141

31303142
return result, offset2bag, bag_size, max_indices
31313143

0 commit comments

Comments
 (0)