Commit 11bb385

Committed by rgao
Remove dead P2P communication code
- Remove all_to_all_collect_p2p() function (never called)
- Remove send_neighbors, recv_neighbors, _recv_buf fields from GPContext
- Remove neighbor list precomputation from build_gp_context()

P2P was an experimental alternative to all_to_all_single that used
batch_isend_irecv with sparse neighbor lists. It was never integrated into
the forward path, and benchmarking showed no improvement over
all_to_all_single (NCCL handles sparse communication internally).
1 parent 28d9bd3 commit 11bb385
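For reference, the path the model keeps after this commit is the dense all_to_all_single collect driven by the per-rank split lists. The sketch below is illustrative only, not the fairchem implementation: it reuses the send_splits / recv_splits / total_recv / send_indices names from the GPContext fields visible in the diff, while the function name all_to_all_collect_dense and the exact packing are assumptions.

```python
import torch
import torch.distributed as dist


def all_to_all_collect_dense(
    x_local: torch.Tensor,       # (local_atoms, *features)
    send_indices: torch.Tensor,  # local indices of atoms other ranks need
    send_splits: list[int],      # per-rank send counts, len == world_size
    recv_splits: list[int],      # per-rank recv counts, len == world_size
    total_recv: int,             # sum(recv_splits)
    group=None,                  # graph-parallel process group
) -> torch.Tensor:
    """Illustrative dense collect; names follow GPContext, logic is a sketch."""
    feature_shape = x_local.shape[1:]
    # Pack the atoms to send, ordered by destination rank.
    x_send = x_local[send_indices].contiguous()
    x_recv = torch.empty(
        total_recv, *feature_shape, device=x_local.device, dtype=x_local.dtype
    )
    # A single collective call; NCCL handles zero-length peer pairs internally,
    # which is why the sparse P2P variant removed here showed no improvement.
    dist.all_to_all_single(
        x_recv,
        x_send,
        output_split_sizes=recv_splits,
        input_split_sizes=send_splits,
        group=group,
    )
    return x_recv
```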

1 file changed: 0 additions & 108 deletions

src/fairchem/core/models/uma/graph_parallel.py
@@ -119,14 +119,6 @@ class GPContext:
     send_splits: list[int] | None = None
     recv_splits: list[int] | None = None
     total_recv: int | None = None
-    # Precomputed sparse neighbor lists for P2P communication
-    # Only includes ranks with non-zero send/recv counts, avoiding
-    # O(world_size) iteration in the communication hot path.
-    send_neighbors: list[int] | None = None  # Ranks we send to (non-zero send)
-    recv_neighbors: list[int] | None = None  # Ranks we recv from (non-zero recv)
-    # Pre-allocated receive buffer (lazily initialized).
-    # Reused across layers within a step to avoid per-layer CUDA malloc.
-    _recv_buf: torch.Tensor | None = None


 def _expand_bits_10(v: torch.Tensor) -> torch.Tensor:

@@ -500,13 +492,6 @@ def build_gp_context(
         send_splits=send_splits,
         recv_splits=recv_splits,
         total_recv=total_recv,
-        # Precompute sparse neighbor lists for communication
-        send_neighbors=[
-            r for r in range(world_size) if send_splits[r] > 0 and r != rank
-        ],
-        recv_neighbors=[
-            r for r in range(world_size) if recv_splits[r] > 0 and r != rank
-        ],
     )


@@ -884,96 +869,3 @@ def all_to_all_collect_compiled(
     )

     return x_recv
-
-
-@torch.compiler.disable
-def all_to_all_collect_p2p(
-    x_local: torch.Tensor,
-    gp_ctx: GPContext,
-    send_indices: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Collect remote embeddings using sparse P2P communication.
-
-    Instead of ``all_to_all_single`` (which creates P-1 send/recv pairs
-    even for zero-length messages), this uses ``batch_isend_irecv`` with
-    only the non-zero neighbors. At 64 GPUs with spatial partitioning,
-    each rank typically has ~10-15 actual neighbors, so this reduces
-    the number of NCCL operations from ~63 to ~25.
-
-    Uses pre-allocated send/recv buffers stored on ``gp_ctx`` to avoid
-    per-layer CUDA memory allocation overhead (4 allocations/step saved).
-
-    Does NOT participate in autograd — intended for eval-mode inference
-    only. For training, use ``all_to_all_collect`` instead.
-
-    Args:
-        x_local: Local atom embeddings, shape (local_atoms, *features).
-        gp_ctx: Graph parallel context with precomputed neighbor lists.
-        send_indices: Local indices of atoms to send.
-
-    Returns:
-        x_received: Remote atom embeddings, shape (total_needed, *features).
-    """
-    if send_indices is None:
-        raise ValueError(
-            "send_indices is None — build_gp_context should always "
-            "compute send_indices. Check GP setup."
-        )
-    feature_shape = x_local.shape[1:]
-    send_splits = gp_ctx.send_splits
-    recv_splits = gp_ctx.recv_splits
-    total_recv = gp_ctx.total_recv
-
-    # Gather atoms to send
-    # Note: cannot use torch.index_select with out= because x_local
-    # may require grad (for force computation), and out= doesn't
-    # support autograd. Use regular indexing which creates a new tensor
-    # but supports the backward pass. The send buffer pre-allocation
-    # is not worth the autograd complexity.
-    if send_indices.numel() > 0:
-        x_send = x_local[send_indices].contiguous()
-    else:
-        x_send = torch.empty(
-            0, *feature_shape, device=x_local.device, dtype=x_local.dtype
-        )
-
-    # Reuse pre-allocated recv buffer if available and correct size
-    recv_shape = (total_recv, *feature_shape)
-    if (
-        gp_ctx._recv_buf is not None
-        and gp_ctx._recv_buf.shape == recv_shape
-        and gp_ctx._recv_buf.dtype == x_local.dtype
-    ):
-        x_recv = gp_ctx._recv_buf
-    else:
-        x_recv = torch.empty(recv_shape, device=x_local.device, dtype=x_local.dtype)
-        gp_ctx._recv_buf = x_recv
-
-    # Sparse P2P communication: only talk to actual neighbors
-    gp_group = gp_utils.get_gp_group()
-    backend = dist.get_backend(gp_group)
-
-    if backend == "nccl":
-        # Split into per-rank chunks (views into contiguous buffers)
-        send_chunks = list(x_send.split(send_splits))
-        recv_chunks = list(x_recv.split(recv_splits))
-
-        ops = []
-        # Only create ops for non-zero neighbors
-        for r in gp_ctx.send_neighbors:
-            ops.append(dist.P2POp(dist.isend, send_chunks[r], r, group=gp_group))
-        for r in gp_ctx.recv_neighbors:
-            ops.append(dist.P2POp(dist.irecv, recv_chunks[r], r, group=gp_group))
-
-        if ops:
-            reqs = dist.batch_isend_irecv(ops)
-            for req in reqs:
-                req.wait()
-    else:
-        # Gloo fallback: use pairwise send/recv
-        send_chunks = list(x_send.split(send_splits))
-        recv_chunks = list(x_recv.split(recv_splits))
-        _safe_all_to_all(recv_chunks, send_chunks, group=gp_group)
-
-    return x_recv
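The commit message cites benchmarking against all_to_all_single. Below is a minimal sketch of the kind of side-by-side timing that could support such a comparison; it is not the benchmark actually used. It assumes an already-initialized NCCL process group, CUDA tensors, split lists whose own-rank entry is zero, and the same op pattern as the removed function. The helper names time_collective and compare_collectives are hypothetical.

```python
import time

import torch
import torch.distributed as dist


def time_collective(fn, iters: int = 50, warmup: int = 10) -> float:
    """Average wall-clock seconds per call, with CUDA sync around the timed loop."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


def compare_collectives(x_send, x_recv, send_splits, recv_splits, group):
    """Time dense all_to_all_single vs. sparse batch_isend_irecv on the same buffers.

    Assumes send_splits[rank] == recv_splits[rank] == 0, so both variants move
    the same data between distinct ranks.
    """
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    # Neighbor lists mirror the send_neighbors / recv_neighbors fields removed above.
    send_neighbors = [r for r in range(world_size) if send_splits[r] > 0 and r != rank]
    recv_neighbors = [r for r in range(world_size) if recv_splits[r] > 0 and r != rank]
    send_chunks = list(x_send.split(send_splits))
    recv_chunks = list(x_recv.split(recv_splits))

    def dense():
        dist.all_to_all_single(
            x_recv,
            x_send,
            output_split_sizes=recv_splits,
            input_split_sizes=send_splits,
            group=group,
        )

    def sparse_p2p():
        ops = [dist.P2POp(dist.isend, send_chunks[r], r, group=group) for r in send_neighbors]
        ops += [dist.P2POp(dist.irecv, recv_chunks[r], r, group=group) for r in recv_neighbors]
        if ops:
            for req in dist.batch_isend_irecv(ops):
                req.wait()

    t_dense = time_collective(dense)
    t_sparse = time_collective(sparse_p2p)
    if rank == 0:
        print(f"all_to_all_single: {t_dense * 1e3:.3f} ms  batch_isend_irecv: {t_sparse * 1e3:.3f} ms")
```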
