@@ -119,14 +119,6 @@ class GPContext:
     send_splits: list[int] | None = None
     recv_splits: list[int] | None = None
     total_recv: int | None = None
-    # Precomputed sparse neighbor lists for P2P communication.
-    # Only includes ranks with non-zero send/recv counts, avoiding
-    # O(world_size) iteration in the communication hot path.
-    send_neighbors: list[int] | None = None  # Ranks we send to (non-zero send)
-    recv_neighbors: list[int] | None = None  # Ranks we recv from (non-zero recv)
-    # Pre-allocated receive buffer (lazily initialized).
-    # Reused across layers within a step to avoid per-layer CUDA malloc.
-    _recv_buf: torch.Tensor | None = None
 
 
 def _expand_bits_10(v: torch.Tensor) -> torch.Tensor:
@@ -500,13 +492,6 @@ def build_gp_context(
         send_splits=send_splits,
         recv_splits=recv_splits,
         total_recv=total_recv,
-        # Precompute sparse neighbor lists for communication.
-        send_neighbors=[
-            r for r in range(world_size) if send_splits[r] > 0 and r != rank
-        ],
-        recv_neighbors=[
-            r for r in range(world_size) if recv_splits[r] > 0 and r != rank
-        ],
     )
 
 
@@ -884,96 +869,3 @@ def all_to_all_collect_compiled(
     )
 
     return x_recv
-
-
-@torch.compiler.disable
-def all_to_all_collect_p2p(
-    x_local: torch.Tensor,
-    gp_ctx: GPContext,
-    send_indices: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Collect remote embeddings using sparse P2P communication.
-
-    Instead of ``all_to_all_single`` (which creates P-1 send/recv pairs
-    even for zero-length messages), this uses ``batch_isend_irecv`` with
-    only the non-zero neighbors. At 64 GPUs with spatial partitioning,
-    each rank typically has ~10-15 actual neighbors, so this reduces
-    the number of NCCL operations from ~63 to ~25.
-
-    Uses the pre-allocated receive buffer stored on ``gp_ctx`` to avoid
-    per-layer CUDA memory allocation overhead (4 allocations/step saved).
-
-    Does NOT participate in autograd; intended for eval-mode inference
-    only. For training, use ``all_to_all_collect`` instead.
-
-    Args:
-        x_local: Local atom embeddings, shape (local_atoms, *features).
-        gp_ctx: Graph parallel context with precomputed neighbor lists.
-        send_indices: Local indices of atoms to send.
-
-    Returns:
-        x_received: Remote atom embeddings, shape (total_needed, *features).
-    """
-    if send_indices is None:
-        raise ValueError(
-            "send_indices is None; build_gp_context should always "
-            "compute send_indices. Check GP setup."
-        )
-    feature_shape = x_local.shape[1:]
-    send_splits = gp_ctx.send_splits
-    recv_splits = gp_ctx.recv_splits
-    total_recv = gp_ctx.total_recv
-
-    # Gather atoms to send.
-    # Note: cannot use torch.index_select with out= because x_local
-    # may require grad (for force computation), and out= doesn't
-    # support autograd. Use regular indexing, which creates a new
-    # tensor but supports the backward pass; pre-allocating the send
-    # buffer is not worth the autograd complexity.
-    if send_indices.numel() > 0:
-        x_send = x_local[send_indices].contiguous()
-    else:
-        x_send = torch.empty(
-            0, *feature_shape, device=x_local.device, dtype=x_local.dtype
-        )
-
-    # Reuse the pre-allocated recv buffer if available and the correct size.
-    recv_shape = (total_recv, *feature_shape)
-    if (
-        gp_ctx._recv_buf is not None
-        and gp_ctx._recv_buf.shape == recv_shape
-        and gp_ctx._recv_buf.dtype == x_local.dtype
-    ):
-        x_recv = gp_ctx._recv_buf
-    else:
-        x_recv = torch.empty(recv_shape, device=x_local.device, dtype=x_local.dtype)
-        gp_ctx._recv_buf = x_recv
-
-    # Sparse P2P communication: only talk to actual neighbors.
-    gp_group = gp_utils.get_gp_group()
-    backend = dist.get_backend(gp_group)
-
-    if backend == "nccl":
-        # Split into per-rank chunks (views into the contiguous buffers).
-        send_chunks = list(x_send.split(send_splits))
-        recv_chunks = list(x_recv.split(recv_splits))
-
-        ops = []
-        # Only create ops for non-zero neighbors.
-        for r in gp_ctx.send_neighbors:
-            ops.append(dist.P2POp(dist.isend, send_chunks[r], r, group=gp_group))
-        for r in gp_ctx.recv_neighbors:
-            ops.append(dist.P2POp(dist.irecv, recv_chunks[r], r, group=gp_group))
-
-        if ops:
-            reqs = dist.batch_isend_irecv(ops)
-            for req in reqs:
-                req.wait()
-    else:
-        # Gloo fallback: use pairwise send/recv.
-        send_chunks = list(x_send.split(send_splits))
-        recv_chunks = list(x_recv.split(recv_splits))
-        _safe_all_to_all(recv_chunks, send_chunks, group=gp_group)
-
-    return x_recv
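
Review note: for anyone skimming the removal, the deleted NCCL path reduces to the pattern below. This is a minimal, self-contained sketch under stated assumptions, not the project's API: `sparse_exchange` and its arguments are hypothetical names, `torch.distributed` is assumed already initialized, and self-traffic (`splits[rank]`) is assumed zero, as in the deleted code.

```python
import torch
import torch.distributed as dist


def sparse_exchange(
    x_send: torch.Tensor,    # rows to send, grouped contiguously by destination rank
    send_splits: list[int],  # rows going to each rank (len == world size)
    recv_splits: list[int],  # rows expected from each rank (len == world size)
    group: dist.ProcessGroup | None = None,
) -> torch.Tensor:
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    # The core of the deleted optimization: only ranks with non-zero
    # traffic become P2P ops, versus P-1 send/recv pairs posted by
    # all_to_all_single even when most messages are empty.
    send_neighbors = [r for r in range(world_size) if send_splits[r] > 0 and r != rank]
    recv_neighbors = [r for r in range(world_size) if recv_splits[r] > 0 and r != rank]

    x_recv = torch.empty(
        sum(recv_splits), *x_send.shape[1:], device=x_send.device, dtype=x_send.dtype
    )
    # split() returns views into the contiguous buffers, so each irecv
    # writes directly into its slice of x_recv with no extra copy.
    send_chunks = list(x_send.split(send_splits))
    recv_chunks = list(x_recv.split(recv_splits))

    ops = [dist.P2POp(dist.isend, send_chunks[r], r, group=group) for r in send_neighbors]
    ops += [dist.P2POp(dist.irecv, recv_chunks[r], r, group=group) for r in recv_neighbors]
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return x_recv
```

With ~12-13 active neighbors per rank this posts roughly 25 ops (one isend plus one irecv per neighbor), which is where the docstring's "~63 to ~25" figure comes from on a 64-rank group.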
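
Likewise, the inline comment about `index_select` with `out=` reflects stock PyTorch behavior rather than anything project-specific: `out=` variants don't support automatic differentiation, so they raise once an input requires grad. A small illustrative check (assumed standard PyTorch semantics; the variable names are made up):

```python
import torch

x = torch.randn(8, 4, requires_grad=True)  # stands in for x_local during force computation
idx = torch.tensor([1, 3])
buf = torch.empty(2, 4)

gathered = x[idx]  # plain indexing: allocates a new tensor, but is differentiable

try:
    # Raises RuntimeError: functions with out= arguments don't support
    # automatic differentiation when an input requires grad.
    torch.index_select(x, 0, idx, out=buf)
except RuntimeError as err:
    print(f"index_select(out=...) rejected: {err}")
```

This is why the deleted code gathered with `x_local[send_indices]` instead of reusing a pre-allocated send buffer.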