Skip to content

Commit c815e65

Browse files
author
rgao user
committed
Optimize _sparse_index_exchange: remove redundant sort, batch transfers
- Remove redundant argsort in _sparse_index_exchange: the caller (build_gp_context) already sorts needed_atoms by source rank, so the exchange function can use them directly as the send buffer.
- Batch send_counts and recv_counts into a single GPU→CPU transfer (torch.stack + .cpu()) instead of separate .sum().item() and .tolist() calls, eliminating 2 GPU→CPU synchronization points.
- Remove now-unused needed_from_ranks parameter from the function.
1 parent 11bb385 commit c815e65

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

src/fairchem/core/models/uma/graph_parallel.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,6 @@ def build_gp_context(
347347
with record_function("a2a_sparse_index_exchange"):
348348
send_counts, send_indices_global = _sparse_index_exchange(
349349
needed_atoms=needed_atoms,
350-
needed_from_ranks=needed_from_ranks,
351350
recv_counts=recv_counts,
352351
rank=rank,
353352
world_size=world_size,
@@ -497,7 +496,6 @@ def build_gp_context(
497496

498497
def _sparse_index_exchange(
499498
needed_atoms: torch.Tensor,
500-
needed_from_ranks: torch.Tensor,
501499
recv_counts: torch.Tensor,
502500
rank: int,
503501
world_size: int,
@@ -513,8 +511,8 @@ def _sparse_index_exchange(
513511
keeping communication volume minimal.
514512
515513
Args:
516-
needed_atoms: Global indices of atoms this rank needs, sorted.
517-
needed_from_ranks: Source rank for each needed atom.
514+
needed_atoms: Global indices of atoms this rank needs,
515+
pre-sorted by source rank (done by the caller).
518516
recv_counts: Number of atoms needed from each rank.
519517
rank: This rank's GP rank.
520518
world_size: GP world size.
@@ -546,19 +544,21 @@ def _sparse_index_exchange(
546544
_safe_all_to_all(recv_list, send_list, group=gp_group)
547545

548546
# Step 2: Exchange actual atom indices with variable splits.
549-
# Build send buffer: needed_atoms sorted by source rank.
547+
# needed_atoms is already sorted by source rank (done by the
548+
# caller in build_gp_context), so use it directly as send buffer.
550549
if needed_atoms.numel() > 0:
551-
sort_order = needed_from_ranks.argsort(stable=True)
552-
send_buf = needed_atoms[sort_order].contiguous()
550+
send_buf = needed_atoms.contiguous()
553551
else:
554552
send_buf = torch.empty(0, dtype=torch.long, device=device)
555553

556-
total_recv_indices = send_counts.sum().item()
554+
# Batch send_counts and recv_counts into a single GPU→CPU transfer.
555+
# This eliminates 2 extra GPU→CPU syncs vs separate .tolist() calls.
556+
counts_cpu = torch.stack([send_counts, recv_counts]).cpu()
557+
recv_splits = counts_cpu[0].tolist() # what we recv = what we need
558+
send_splits = counts_cpu[1].tolist() # what we send = what others need
559+
total_recv_indices = sum(recv_splits)
557560
recv_buf = torch.empty(total_recv_indices, dtype=torch.long, device=device)
558561

559-
send_splits = recv_counts.tolist() # what we send = what others need from us
560-
recv_splits = send_counts.tolist() # what we recv = what we need from others
561-
562562
if backend == "nccl":
563563
dist.all_to_all_single(
564564
recv_buf,

0 commit comments

Comments
 (0)