Update aten_index_put implementation#2712
Conversation
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #2712      +/-   ##
==========================================
- Coverage   70.11%   70.06%   -0.05%
==========================================
  Files         226      226
  Lines       27230    27266      +36
  Branches     2748     2754       +6
==========================================
+ Hits        19092    19105      +13
- Misses       7193     7210      +17
- Partials      945      951       +6
```

☔ View full report in Codecov by Sentry.
@xadupre It looks like it has not been merged yet?
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@titaiwangms @xadupre PTAL
User: Explain this implementation in detail, with examples

GitHub Copilot: I'll explain the `aten_index_put` implementation in detail.

Overview

The function lowers PyTorch's `tensor[indices] = values` (advanced-index assignment) to ONNX `ScatterND`, handling index broadcasting, dimension reordering, and optional accumulation.

Key Concepts

Advanced Indexing: When indices are tensors (not scalars), PyTorch uses "advanced indexing", which has special broadcasting and dimension-placement rules.

Implementation Breakdown

1. Index Normalization (Lines 4541-4543)

```python
self_rank = len(self.shape)
if len(indices) < self_rank:
    indices = list(indices) + [None] * (self_rank - len(indices))
```

Example: If the tensor has rank 4 but only two indices are given, two `None` entries are appended so that `indices` has one entry per dimension.

2. Identifying Index Types (Lines 4548-4565)

```python
def is_advanced_index(index):
    return index is not None

advanced_indices: list[int] = []  # Positions with tensor indices
none_indices: list[int] = []      # Positions with None (slices)
```

Example: For `indices = [idx1, None, idx2, None]`, `advanced_indices = [0, 2]` and `none_indices = [1, 3]`.
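To make steps 1-2 concrete, here is a small standalone sketch (plain Python, not the actual torchlib code) that normalizes an index list and classifies positions:

```python
# Minimal sketch (not the torchlib code): pad an index list to the tensor
# rank and split positions into advanced (tensor) indices and None slots.
def normalize_and_classify(indices, self_rank):
    indices = list(indices) + [None] * (self_rank - len(indices))
    advanced_indices = [i for i, ix in enumerate(indices) if ix is not None]
    none_indices = [i for i, ix in enumerate(indices) if ix is None]
    return indices, advanced_indices, none_indices

# For indices = [idx1, None, idx2] on a rank-4 tensor (strings stand in for
# index tensors), this yields [idx1, None, idx2, None], [0, 2], [1, 3].
print(normalize_and_classify(["idx1", None, "idx2"], 4))
```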
3. Broadcasting Advanced Indices (Lines 4570-4600)

When multiple advanced indices exist, they must be broadcast to a common shape:

```python
# Broadcast advanced indices to a common shape
advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices)
# ... padding and Max operation to get common shape
```

Example:

```python
# tensor shape: [7, 3, 4, 5]
# indices = [idx1, None, idx2, None]
# idx1.shape = [2, 3], idx2.shape = [2, 1]
# After broadcasting: both become [2, 3]
```

4. Creating ONNX ScatterND Index (Lines 4602-4618)

ONNX's `ScatterND` expects a single index tensor whose last dimension addresses the indexed axes, so the broadcast indices are stacked along a new last axis:

```python
advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices]
onnx_index = op.Concat(*advanced_index_values, axis=-1)
```

Example:

```python
# idx1.shape = [2, 3], idx2.shape = [2, 3]
# After unsqueeze: [2, 3, 1] and [2, 3, 1]
# After concat: [2, 3, 2] (last dim = number of advanced indices)
```

5. Checking Contiguity (Lines 4620-4623)

Determines whether the advanced indices appear consecutively:

```python
contiguous = True
if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices):
    contiguous = False
```

Example: `advanced_indices = [0, 2]` is not contiguous (a `None` sits between them), whereas `[1, 2]` would be contiguous.
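A NumPy sketch of steps 3-5 (an illustration of the data layout, not the exported graph): broadcast the index tensors, stack them along a new last axis in the layout `ScatterND` expects, and check contiguity of the indexed positions.

```python
import numpy as np

# Illustration only: build a ScatterND-style index from two advanced indices.
idx1 = np.array([[0, 1, 2], [3, 4, 5]])      # shape [2, 3]
idx2 = np.array([[1], [2]])                  # shape [2, 1]

b1, b2 = np.broadcast_arrays(idx1, idx2)     # both become shape [2, 3]
onnx_index = np.concatenate(
    [b1[..., None], b2[..., None]], axis=-1  # unsqueeze then concat
)
print(onnx_index.shape)  # (2, 3, 2): last dim = number of advanced indices

advanced_positions = [0, 2]  # as in the running pattern [idx1, None, idx2, None]
contiguous = advanced_positions[-1] - advanced_positions[0] + 1 == len(advanced_positions)
print(contiguous)  # False: a None sits between the advanced indices
```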
6. Dimension Permutation (Lines 4625-4627)

ONNX `ScatterND` indexes the leading dimensions of its data input, so the dimensions with advanced indices are moved to the front:

```python
perm = advanced_indices + none_indices  # Bring advanced indices to front
transposed = op.Transpose(self, perm=perm)
```

Example:

```python
# Original: self[dim0, dim1, dim2, dim3]
# indices pattern: [idx, None, idx, None]
# perm = [0, 2, 1, 3] (advanced first: 0, 2; then none: 1, 3)
# Result: transposed[dim0, dim2, dim1, dim3]
```

7. Values Transposition (Lines 4629-4648)

If the advanced indices are contiguous, the values tensor must be transposed to match where PyTorch places the advanced indexing dimensions:

```python
if contiguous:
    # values must match where PyTorch places the advanced indexing dimensions
    values_perm = (
        advanced_index_replacement_positions  # Where advanced dims go
        + initial_none_index_positions        # Leading None dims
        + final_none_index_positions          # Trailing None dims
    )
    values = op.Transpose(values, perm=values_perm)
```

Example:

```python
# For tensor[:, idx1, idx2, :] where idx1, idx2 broadcast to [2, 3]
# target_shape = [2, 3, orig_dim0, orig_dim3]
# values needs to be [orig_dim0, 2, 3, orig_dim3] initially
# Transpose to [2, 3, orig_dim0, orig_dim3]
```

8. ScatterND Operation (Lines 4652-4654)

```python
updated = op.ScatterND(
    transposed, onnx_index, expanded_values,
    reduction="add" if accumulate else None
)
```
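For intuition about what `ScatterND` does here, a minimal NumPy emulation of its semantics, including the `reduction="add"` path used when `accumulate` is true; this is a sketch of the operator's behavior, not the ONNX runtime:

```python
import numpy as np

def scatter_nd(data, indices, updates, accumulate=False):
    # Sketch of ONNX ScatterND: the last axis of `indices` addresses the
    # leading axes of `data`; `updates` fills the remaining trailing axes.
    out = data.copy()
    idx = tuple(np.moveaxis(indices, -1, 0))
    if accumulate:
        np.add.at(out, idx, updates)   # like reduction="add"
    else:
        out[idx] = updates
    return out

data = np.zeros((4, 5, 6))
indices = np.array([[0, 1], [2, 3]])             # shape [2, 2]: two (i, j) pairs
updates = np.ones((2, 6))                        # one dim-2 row per pair
print(scatter_nd(data, indices, updates).sum())  # 12.0
```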
9. Inverse Transpose (Lines 4658-4661)

Restore the original dimension order:

```python
inverse_perm = [0] * self_rank
for i, p in enumerate(perm):
    inverse_perm[p] = i
result = op.Transpose(updated, perm=inverse_perm)
```

Complete Example

```python
# PyTorch operation:
#   tensor[idx1, :, idx2] = values
# Given:
tensor.shape = [7, 4, 5, 6]
idx1 = torch.tensor([[0, 1], [2, 3]])  # shape [2, 2]
idx2 = torch.tensor([[1], [2]])        # shape [2, 1]
values.shape = [2, 2, 4, 6]            # matches result shape

# Processing:
# 1. indices = [idx1, None, idx2, None]
# 2. advanced_indices = [0, 2], none_indices = [1, 3]
# 3. Broadcast idx1, idx2 to [2, 2]
# 4. onnx_index.shape = [2, 2, 2] (positions for dims 0 and 2)
# 5. perm = [0, 2, 1, 3] (advanced first)
# 6. transposed.shape = [7, 5, 4, 6]
# 7. target_shape = [2, 2, 4, 6]
# 8. ScatterND updates transposed
# 9. Inverse transpose back to [7, 4, 5, 6]
```

Why This Complexity?

PyTorch's advanced indexing has nuanced rules about dimension ordering that don't map directly to ONNX.
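As a sanity check on the walkthrough above, the shapes can be reproduced directly in PyTorch (the subscript assignment is the form that PyTorch lowers to `index_put`):

```python
import torch

tensor = torch.randn(7, 4, 5, 6)
idx1 = torch.tensor([[0, 1], [2, 3]])   # shape [2, 2]
idx2 = torch.tensor([[1], [2]])         # shape [2, 1], broadcasts to [2, 2]

# Non-contiguous advanced indices (a slice sits between them), so the
# broadcast index dims land in front of the remaining dims 1 and 3.
print(tensor[idx1, :, idx2].shape)      # torch.Size([2, 2, 4, 6])

values = torch.zeros(2, 2, 4, 6)
tensor[idx1, :, idx2] = values          # the write that aten_index_put exports
```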
User: Contrast and compare the implementation with _aten_index_onnx

GitHub Copilot: I'll contrast the two implementations.

Comparison:
| Aspect | `_aten_index_onnx` | `aten_index_put` |
|---|---|---|
| Operation | `result = tensor[indices]` (read/gather) | `tensor[indices] = values` (write/scatter) |
| ONNX Op | `GatherND` | `ScatterND` |
| Returns | New tensor with selected values | Modified tensor |
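A tiny PyTorch illustration of the read/write distinction in the table (not part of either implementation):

```python
import torch

t = torch.arange(12.0).reshape(3, 4)
idx = torch.tensor([0, 2])

read = t[idx]                  # what _aten_index_onnx maps: GatherND-style read
t[idx] = torch.zeros(2, 4)     # what aten_index_put maps: ScatterND-style write
print(read.shape, t[0].sum())  # torch.Size([2, 4]) tensor(0.)
```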
Key Similarities

1. Index Normalization

Both pad indices to match the tensor rank:

```python
# _aten_index_onnx: implicit (handled by caller)
# aten_index_put: explicit
if len(indices) < self_rank:
    indices = list(indices) + [None] * (self_rank - len(indices))
```

2. Index Broadcasting

Both broadcast advanced indices to a common shape:

```python
# Both do:
broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
# Then expand/broadcast each index to this shape
```

3. Dimension Reordering

Both move advanced indices to the front for ONNX compatibility (a quick check appears after this list):

```python
# _aten_index_onnx:
reordered_positions = sorted(range(len(indices)),
                             key=lambda i: (indices[i] is None, i))
# aten_index_put:
perm = advanced_indices + none_indices  # Simpler, same effect
```

4. Contiguity Check

Both check whether the advanced indices are consecutive to determine dimension placement:

```python
# Both use: _has_none_in_middle() or similar logic
# This affects the final transpose operation
```
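A quick check (illustrative, with placeholder entries standing in for index tensors) that the two reordering formulations in similarity 3 produce the same permutation:

```python
# Placeholder pattern: non-None entries stand in for index tensors.
indices = ["idx1", None, "idx2", None]

# _aten_index_onnx style: sort positions, None entries last.
reordered_positions = sorted(range(len(indices)),
                             key=lambda i: (indices[i] is None, i))

# aten_index_put style: advanced positions first, then None positions.
advanced_indices = [i for i, ix in enumerate(indices) if ix is not None]
none_indices = [i for i, ix in enumerate(indices) if ix is None]

assert reordered_positions == advanced_indices + none_indices == [0, 2, 1, 3]
```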
Key Differences

1. Operation Direction

`_aten_index_onnx` - Extracts values:

```python
self = op.GatherND(self, final_index, batch_dims=0)
# Returns subset of original tensor
```

`aten_index_put` - Inserts values:

```python
updated = op.ScatterND(transposed, onnx_index, expanded_values,
                       reduction="add" if accumulate else None)
# Returns full tensor with updates
```

2. Values Handling

`_aten_index_onnx`: No values input
- Only needs to know where to read from

`aten_index_put`: Must handle the values tensor:

```python
# Lines 4629-4648: Complex values transposition logic
if contiguous:
    values_perm = (advanced_index_replacement_positions
                   + initial_none_index_positions
                   + final_none_index_positions)
    values = op.Transpose(values, perm=values_perm)
expanded_values = op.Expand(values, target_shape)
```

3. Shape Computation

`_aten_index_onnx`:
- Output shape = broadcast index shape + remaining None dimensions
- Simpler calculation

`aten_index_put`:

```python
# Lines 4603-4607: Must compute target shape for values expansion
none_indices_constant = op.Constant(value_ints=none_indices)
none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0)
target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0)
```

4. Inverse Transpose

`_aten_index_onnx`: One-way transpose

```python
# Transpose input, gather, potentially transpose result
return op.Transpose(self, perm=perm)
```

`aten_index_put`: Round-trip transpose

```python
# Must transpose back to original dimension order
inverse_perm = [0] * self_rank
for i, p in enumerate(perm):
    inverse_perm[p] = i
result = op.Transpose(updated, perm=inverse_perm)
```

5. Accumulation Option

`_aten_index_onnx`: N/A

`aten_index_put`: Supports accumulation (a small PyTorch illustration follows this section):

```python
updated = op.ScatterND(..., reduction="add" if accumulate else None)
# Can add to existing values instead of replacing
```
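The accumulation option in difference 5 mirrors PyTorch's `accumulate` flag; a minimal PyTorch example, independent of the ONNX export:

```python
import torch

t = torch.zeros(5)
idx = torch.tensor([1, 1, 3])
vals = torch.tensor([1.0, 2.0, 4.0])

# accumulate=True sums contributions at repeated indices, which is what
# ScatterND with reduction="add" provides on the ONNX side.
t.index_put_((idx,), vals, accumulate=True)
print(t)  # tensor([0., 3., 0., 4., 0.])
```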
Example Comparison

```python
# Setup
tensor = torch.randn(7, 4, 5, 6)
idx1 = torch.tensor([[0, 1], [2, 3]])  # [2, 2]
idx2 = torch.tensor([[1], [2]])        # [2, 1]

# READ (_aten_index_onnx):
result = tensor[idx1, :, idx2]
# Steps:
# 1. Transpose: [7, 5, 4, 6] (advanced indices first: dims 0, 2)
# 2. Broadcast indices to [2, 2]
# 3. GatherND → [2, 2, 4, 6] (reads values at those positions)
# 4. The indices are not contiguous (a slice sits between them), so the
#    advanced dims stay in front and result.shape == [2, 2, 4, 6]

# WRITE (aten_index_put):
tensor[idx1, :, idx2] = values  # values.shape = [2, 2, 4, 6]
# Steps:
# 1. Transpose tensor: [7, 5, 4, 6]
# 2. Broadcast indices to [2, 2]
# 3. Expand values to target_shape [2, 2, 4, 6] (already in scatter layout;
#    no values transpose is needed since the indices are not contiguous)
# 4. ScatterND → writes values at those positions
# 5. Inverse transpose: [7, 4, 5, 6] (restore original order)
```
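To round out the comparison, a NumPy sketch of GatherND semantics, the read-side counterpart of the ScatterND sketch earlier (again an illustration, not the runtime):

```python
import numpy as np

def gather_nd(data, indices):
    # Sketch of ONNX GatherND (batch_dims=0): the last axis of `indices`
    # selects coordinates along the leading axes of `data`.
    idx = tuple(np.moveaxis(indices, -1, 0))
    return data[idx]

data = np.arange(4 * 5 * 6).reshape(4, 5, 6)
indices = np.array([[0, 1], [2, 3]])   # two (i, j) coordinate pairs
print(gather_nd(data, indices).shape)  # (2, 6): one dim-2 row per pair
```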
Complexity Comparison

| Aspect | `_aten_index_onnx` | `aten_index_put` |
|---|---|---|
| Lines of code | ~100 lines | ~160 lines |
| Transposes | 1-2 | 2-3 (includes inverse) |
| Shape calculations | Simpler | More complex (target_shape) |
| Values handling | None | Extensive (transpose + expand) |
Why More Complexity in aten_index_put?
- Bidirectional operation: Must preserve original tensor structure
- Values tensor: Additional input requiring careful alignment
- Shape preservation: Output shape = input shape (not reduced)
- Accumulation: Optional addition instead of replacement
The read operation (_aten_index_onnx) is fundamentally simpler because it only extracts data, while the write operation (aten_index_put) must carefully place values back into the correct positions while preserving the tensor's structure.
Update aten_index_put implementation in torchlib.