@@ -2991,8 +2991,8 @@ def _aten_embedding_bag_onnx(
29912991 indices_1d = op .Reshape (indices , neg_1 )
29922992 # Get weight out according to indices_1d,
29932993 new_weight = op .Gather (weight , indices_1d )
2994- # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
2995- new_weight = op .Mul (new_weight , op .Unsqueeze (per_sample_weights , axes = 1 ))
2994+ # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
2995+ new_weight = op .Mul (new_weight , op .Unsqueeze (per_sample_weights , axes = [ 1 ] ))
29962996 weight_dim_1 = op .Reshape (op .Shape (weight , start = 1 ), neg_1 )
29972997 indices_size = op .Shape (indices_1d )
29982998
@@ -3131,8 +3131,8 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
31313131 # Get weight out according to indices,
31323132 # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
31333133 indices_weight = op .Gather (weight , indices )
3134- # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
3135- indices_weight = op .Mul (indices_weight , op .Unsqueeze (per_sample_weights , axes = 1 ))
3134+ # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
3135+ indices_weight = op .Mul (indices_weight , op .Unsqueeze (per_sample_weights , axes = [ 1 ] ))
31363136
31373137 # The element in sequence must be FLOAT32 dtype due to ORT bug
31383138 indices_weight = op .Cast (indices_weight , to = FLOAT .dtype )
@@ -4145,7 +4145,6 @@ def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
41454145 return op .Shape (broadcasted )
41464146
41474147
4148- @torch_op ("aten::index.Tensor" , private = True , trace_only = True )
41494148def _aten_index_onnx (
41504149 self : TensorType ,
41514150 indices : Sequence [Optional [INT64 ]],
@@ -4173,7 +4172,7 @@ def _aten_index_onnx(
41734172 not_none_indices = [idx for idx in indices if idx is not None ]
41744173 broadcast_shape = _shape_of_broadcast_tensors (* not_none_indices )
41754174 final_index = op .Concat (
4176- * (op .Unsqueeze (op .Expand (idx , broadcast_shape ), - 1 ) for idx in not_none_indices ),
4175+ * (op .Unsqueeze (op .Expand (idx , broadcast_shape ), [ - 1 ] ) for idx in not_none_indices ),
41774176 axis = - 1 ,
41784177 )
41794178
@@ -7706,13 +7705,13 @@ def aten_select_backward(
77067705 raise NotImplementedError ()
77077706
77087707
@torch_op("aten::select_scatter", trace_only=True)
def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int) -> TensorType:
    """select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor"""

    # Insert a size-1 axis at `dim` so src matches self's rank.
    # e.g. self is [2,3,4], src is [2,4], dim=1 -> src_unsqueezed is [2,1,4].
    # ONNX Unsqueeze takes the axes as a list, hence [dim].
    src_unsqueezed = op.Unsqueeze(src, axes=[dim])
    # Broadcast the scalar index to the same shape as the update tensor
    # so ScatterElements writes src into the selected slice along `dim`.
    index_expanded = op.Expand(index, op.Shape(src_unsqueezed))
    return op.ScatterElements(self, index_expanded, src_unsqueezed, axis=dim, reduction="none")
@@ -7880,7 +7879,7 @@ def aten_slice_scatter(
78807879 zero ,
78817880 op .Unsqueeze (step , zero ),
78827881 )
7883- index_base = op .Unsqueeze (index_base , - 1 )
7882+ index_base = op .Unsqueeze (index_base , [ - 1 ] )
78847883
78857884 # Use trace only to construct the perm attribute in Transpose
78867885 dims = None
@@ -8623,7 +8622,7 @@ def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
86238622
86248623 self_rank = len (self .shape )
86258624 if self_rank == 0 :
8626- result = op .Unsqueeze (self , 0 )
8625+ result = op .Unsqueeze (self , [ 0 ] )
86278626 else :
86288627 # Handle negative dimension
86298628 if dimension < 0 :
@@ -8792,8 +8791,7 @@ def aten_unsafe_split_with_sizes(
def aten_unsqueeze(self: TTensor, dim: int) -> TTensor:
    """unsqueeze(Tensor(a) self, int dim) -> Tensor(a)"""

    # ONNX Unsqueeze expects the axes as a list of ints; wrap the single dim.
    axes = [dim]
    return op.Unsqueeze(self, axes)
87978795
87988796
87998797def aten_unsqueeze_copy (self : TensorType , dim : int ) -> TensorType :
0 commit comments