Skip to content

Commit 3e4bf1b

Browse files
Copilotxadupre
andcommitted
Optimize repeat_interleave.self_int to use Tile directly on tensor instead of indices
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
1 parent 6d62024 commit 3e4bf1b

1 file changed

Lines changed: 40 additions & 34 deletions

File tree

  • onnxscript/function_libs/torch_lib/ops

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7399,46 +7399,52 @@ def aten_repeat_interleave_self_int(
73997399
if dim is None:
74007400
# Flatten the tensor first, then repeat each element 'repeats' times
74017401
self_flat = op.Reshape(self, [-1])
7402-
num_elements = op.Shape(self_flat, start=0, end=1)
7403-
7404-
# Create indices that repeat each original index 'repeats' times
7405-
# For input [a, b, c] with repeats=2, we want indices [0, 0, 1, 1, 2, 2]
7406-
original_indices = op.Range(
7407-
op.Constant(value_ints=[0]), num_elements, op.Constant(value_ints=[1])
7408-
)
7409-
7410-
# Repeat each index 'repeats' times
7411-
# We can use Tile with appropriate reshaping
7412-
indices_reshaped = op.Unsqueeze(original_indices, [1]) # Shape: [num_elements, 1]
7402+
7403+
# Add a new dimension and tile to repeat each element
7404+
self_expanded = op.Unsqueeze(self_flat, [1]) # Shape: [num_elements, 1]
74137405
repeat_pattern = op.Constant(value_ints=[1, repeats])
7414-
repeated_indices = op.Tile(
7415-
indices_reshaped, repeat_pattern
7416-
) # Shape: [num_elements, repeats]
7417-
final_indices = op.Reshape(repeated_indices, [-1]) # Shape: [num_elements * repeats]
7418-
7419-
# Gather elements from the flattened tensor
7420-
result = op.Gather(self_flat, final_indices, axis=0)
7406+
tiled = op.Tile(self_expanded, repeat_pattern) # Shape: [num_elements, repeats]
7407+
result = op.Reshape(tiled, [-1]) # Shape: [num_elements * repeats]
74217408
return result
74227409

74237410
else:
74247411
# Repeat along specific dimension
7425-
dim_size = op.Shape(self, start=dim, end=dim + 1)
7426-
7427-
# Create indices that repeat each original index 'repeats' times
7428-
original_indices = op.Range(
7429-
op.Constant(value_ints=[0]), dim_size, op.Constant(value_ints=[1])
7412+
# Apply Tile directly to the tensor instead of creating indices (more efficient)
7413+
7414+
# Expand tensor by adding dimension after target dim
7415+
self_expanded = op.Unsqueeze(self, [dim + 1])
7416+
7417+
# Get original shape to build tile pattern dynamically
7418+
original_shape = op.Shape(self)
7419+
num_dims = op.Size(original_shape)
7420+
7421+
# Build tile pattern: all 1s except position dim+1 which is 'repeats'
7422+
# Use ConstantOfShape to create array of 1s, then update specific position
7423+
ones_pattern = op.ConstantOfShape(
7424+
op.Add(num_dims, op.Constant(value_ints=[1])), # +1 for the new dimension
7425+
op.Constant(value_ints=[1])
74307426
)
7431-
7432-
# Repeat each index 'repeats' times
7433-
indices_reshaped = op.Unsqueeze(original_indices, [1]) # Shape: [dim_size, 1]
7434-
repeat_pattern = op.Constant(value_ints=[1, repeats])
7435-
repeated_indices = op.Tile(
7436-
indices_reshaped, repeat_pattern
7437-
) # Shape: [dim_size, repeats]
7438-
final_indices = op.Reshape(repeated_indices, [-1]) # Shape: [dim_size * repeats]
7439-
7440-
# Gather elements along the specified dimension
7441-
result = op.Gather(self, final_indices, axis=dim)
7427+
7428+
# Create indices and updates for ScatterND to set position dim+1 to 'repeats'
7429+
update_indices = op.Reshape(op.Constant(value_ints=[dim + 1]), [1, 1])
7430+
update_values = op.Constant(value_ints=[repeats])
7431+
7432+
tile_pattern = op.ScatterND(ones_pattern, update_indices, update_values)
7433+
7434+
# Tile the expanded tensor
7435+
tiled = op.Tile(self_expanded, tile_pattern)
7436+
7437+
# Reshape to merge the two dimensions
7438+
# Calculate new shape: original shape with target dimension multiplied by repeats
7439+
target_dim_size = op.Gather(original_shape, op.Constant(value_ints=[dim]))
7440+
new_target_size = op.Mul(target_dim_size, op.Constant(value_ints=[repeats]))
7441+
7442+
# Create new shape by updating the target dimension
7443+
update_shape_indices = op.Reshape(op.Constant(value_ints=[dim]), [1, 1])
7444+
new_shape = op.ScatterND(original_shape, update_shape_indices,
7445+
op.Reshape(new_target_size, [1]))
7446+
7447+
result = op.Reshape(tiled, new_shape)
74427448
return result
74437449

74447450

0 commit comments

Comments
 (0)