@@ -7399,46 +7399,52 @@ def aten_repeat_interleave_self_int(
73997399 if dim is None :
74007400 # Flatten the tensor first, then repeat each element 'repeats' times
74017401 self_flat = op .Reshape (self , [- 1 ])
7402- num_elements = op .Shape (self_flat , start = 0 , end = 1 )
7403-
7404- # Create indices that repeat each original index 'repeats' times
7405- # For input [a, b, c] with repeats=2, we want indices [0, 0, 1, 1, 2, 2]
7406- original_indices = op .Range (
7407- op .Constant (value_ints = [0 ]), num_elements , op .Constant (value_ints = [1 ])
7408- )
7409-
7410- # Repeat each index 'repeats' times
7411- # We can use Tile with appropriate reshaping
7412- indices_reshaped = op .Unsqueeze (original_indices , [1 ]) # Shape: [num_elements, 1]
7402+
7403+ # Add a new dimension and tile to repeat each element
7404+ self_expanded = op .Unsqueeze (self_flat , [1 ]) # Shape: [num_elements, 1]
74137405 repeat_pattern = op .Constant (value_ints = [1 , repeats ])
7414- repeated_indices = op .Tile (
7415- indices_reshaped , repeat_pattern
7416- ) # Shape: [num_elements, repeats]
7417- final_indices = op .Reshape (repeated_indices , [- 1 ]) # Shape: [num_elements * repeats]
7418-
7419- # Gather elements from the flattened tensor
7420- result = op .Gather (self_flat , final_indices , axis = 0 )
7406+ tiled = op .Tile (self_expanded , repeat_pattern ) # Shape: [num_elements, repeats]
7407+ result = op .Reshape (tiled , [- 1 ]) # Shape: [num_elements * repeats]
74217408 return result
74227409
74237410 else :
74247411 # Repeat along specific dimension
7425- dim_size = op .Shape (self , start = dim , end = dim + 1 )
7426-
7427- # Create indices that repeat each original index 'repeats' times
7428- original_indices = op .Range (
7429- op .Constant (value_ints = [0 ]), dim_size , op .Constant (value_ints = [1 ])
7412+ # Apply Tile directly to the tensor instead of creating indices (more efficient)
7413+
7414+ # Expand tensor by adding dimension after target dim
7415+ self_expanded = op .Unsqueeze (self , [dim + 1 ])
7416+
7417+ # Get original shape to build tile pattern dynamically
7418+ original_shape = op .Shape (self )
7419+ num_dims = op .Size (original_shape )
7420+
7421+ # Build tile pattern: all 1s except position dim+1 which is 'repeats'
7422+ # Use ConstantOfShape to create array of 1s, then update specific position
7423+ ones_pattern = op .ConstantOfShape (
7424+ op .Add (num_dims , op .Constant (value_ints = [1 ])), # +1 for the new dimension
7425+ op .Constant (value_ints = [1 ])
74307426 )
7431-
7432- # Repeat each index 'repeats' times
7433- indices_reshaped = op .Unsqueeze (original_indices , [1 ]) # Shape: [dim_size, 1]
7434- repeat_pattern = op .Constant (value_ints = [1 , repeats ])
7435- repeated_indices = op .Tile (
7436- indices_reshaped , repeat_pattern
7437- ) # Shape: [dim_size, repeats]
7438- final_indices = op .Reshape (repeated_indices , [- 1 ]) # Shape: [dim_size * repeats]
7439-
7440- # Gather elements along the specified dimension
7441- result = op .Gather (self , final_indices , axis = dim )
7427+
7428+ # Create indices and updates for ScatterND to set position dim+1 to 'repeats'
7429+ update_indices = op .Reshape (op .Constant (value_ints = [dim + 1 ]), [1 , 1 ])
7430+ update_values = op .Constant (value_ints = [repeats ])
7431+
7432+ tile_pattern = op .ScatterND (ones_pattern , update_indices , update_values )
7433+
7434+ # Tile the expanded tensor
7435+ tiled = op .Tile (self_expanded , tile_pattern )
7436+
7437+ # Reshape to merge the two dimensions
7438+ # Calculate new shape: original shape with target dimension multiplied by repeats
7439+ target_dim_size = op .Gather (original_shape , op .Constant (value_ints = [dim ]))
7440+ new_target_size = op .Mul (target_dim_size , op .Constant (value_ints = [repeats ]))
7441+
7442+ # Create new shape by updating the target dimension
7443+ update_shape_indices = op .Reshape (op .Constant (value_ints = [dim ]), [1 , 1 ])
7444+ new_shape = op .ScatterND (original_shape , update_shape_indices ,
7445+ op .Reshape (new_target_size , [1 ]))
7446+
7447+ result = op .Reshape (tiled , new_shape )
74427448 return result
74437449
74447450
0 commit comments