Commit dba967b

rgao committed
Remove communication-computation overlap code
Benchmarked overlap approach (splitting forward_chunk into local + boundary edge passes while A2A runs async) on H200 turbo mode:
- 8 GPU: ~17% slower than non-overlap A2A
- 16 GPU: ~12% slower than non-overlap A2A

Root cause: splitting forward_chunk into two calls loses torch.compile kernel fusion efficiency, costing more than the ~2ms of communication latency hidden per layer. The overhead from two separate SO2Conv passes (each with its own kernel launch, sync, scatter/gather) dominates.

Removed:
- _forward_overlap() method from Edgewise
- start_all_to_all_collect() / finish_all_to_all_collect() from graph_parallel
- Edge reorder pre-sorting from escn_md backbone
- Overlap-related test functions
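For reference, below is a minimal, self-contained sketch of the comm-compute overlap pattern this commit removes. It is not the fairchem implementation: `overlapped_layer`, `compute_local_edges`, and `compute_boundary_edges` are placeholder names standing in for the local/boundary halves of `forward_chunk`, and an initialized NCCL process group is assumed; only `dist.all_to_all_single(..., async_op=True)` and `Work.wait()` are real PyTorch APIs.

```python
import torch
import torch.distributed as dist


def overlapped_layer(x_local, send_idx, send_splits, recv_splits,
                     compute_local_edges, compute_boundary_edges):
    """Sketch: start the boundary all-to-all, compute local edges while the
    communication is in flight, then finish with the boundary edges."""
    x_send = x_local[send_idx].contiguous()
    x_recv = torch.empty(sum(recv_splits), *x_local.shape[1:],
                         device=x_local.device, dtype=x_local.dtype)

    # 1. Launch the all-to-all asynchronously; returns a Work handle.
    work = dist.all_to_all_single(
        x_recv, x_send,
        output_split_sizes=recv_splits,
        input_split_sizes=send_splits,
        async_op=True,
    )

    # 2. Do useful work (local edges) while communication is in flight.
    local_out = compute_local_edges(x_local)

    # 3. Wait, then process boundary edges against the received embeddings.
    work.wait()
    boundary_out = compute_boundary_edges(torch.cat([x_local, x_recv], dim=0))

    # 4. Local + boundary contributions sum to the full layer output.
    return local_out + boundary_out
```

Per the benchmarks above, issuing the compute as two separate calls cost more (lost torch.compile fusion) than the ~2ms of latency the overlap hides, hence the removal below.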
1 parent f53e1e7 commit dba967b

4 files changed

Lines changed: 0 additions & 439 deletions


src/fairchem/core/models/uma/escn_md.py

Lines changed: 0 additions & 20 deletions
@@ -7,7 +7,6 @@
 
 from __future__ import annotations
 
-import dataclasses
 import logging
 import os
 import types
@@ -1013,25 +1012,6 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
         with record_function("layer_radial_emb"):
             x_edge_per_layer = self.backend.get_layer_radial_emb(x_edge, self)
 
-        # When overlap is enabled, pre-sort all per-edge tensors so local
-        # edges come first and boundary edges last. This lets the overlap
-        # path use compile-friendly split() instead of boolean indexing.
-        if (
-            gp_ctx is not None
-            and gp_ctx.edge_reorder is not None
-            and self.use_overlap_gp
-        ):
-            reorder = gp_ctx.edge_reorder
-            wigner = wigner[reorder]
-            wigner_inv_envelope = wigner_inv_envelope[reorder]
-            x_edge_per_layer = [xl[reorder] for xl in x_edge_per_layer]
-            gp_ctx = dataclasses.replace(
-                gp_ctx,
-                edge_index_local=gp_ctx.edge_index_local[:, reorder],
-                local_edge_mask=gp_ctx.local_edge_mask[reorder],
-                edge_reorder=None,  # consumed; don't reorder again
-            )
-
         for i in range(self.num_layers):
             with record_function(f"message passing {i}"):
                 x_message = self.blocks[i](
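The block deleted above applies an edge permutation (edge_reorder) prepared in build_gp_context so that local edges precede boundary edges. As a hypothetical illustration only (not the fairchem code), a stable argsort on the local-edge mask is one way to build such a permutation so later slicing can use compile-friendly contiguous splits instead of boolean indexing:

```python
import torch

# Assumed example mask: True = local edge, False = boundary edge.
local_edge_mask = torch.tensor([True, False, True, True, False])

# Stable sort keys: local edges (0) come before boundary edges (1).
reorder = torch.argsort((~local_edge_mask).long(), stable=True)
n_local = int(local_edge_mask.sum())

edge_feats = torch.randn(5, 8)[reorder]   # reorder any per-edge tensor once
local_feats = edge_feats[:n_local]        # contiguous local slice
boundary_feats = edge_feats[n_local:]     # contiguous boundary slice
```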

src/fairchem/core/models/uma/escn_md_block.py

Lines changed: 0 additions & 112 deletions
@@ -20,8 +20,6 @@
     GPContext,
     all_to_all_collect,
     all_to_all_collect_compiled,
-    finish_all_to_all_collect,
-    start_all_to_all_collect,
 )
 from fairchem.core.models.uma.nn.activation import (
     GateActivation,
@@ -138,38 +136,7 @@ def forward(
 
         When gp_ctx is provided, uses all-to-all to collect only the
         needed remote embeddings. Otherwise falls back to all-gather.
-
-        When use_overlap_gp is True and in eval mode, overlaps
-        communication with local edge computation for better latency.
         """
-        # Check if we should use the overlapped path:
-        # - gp_ctx must be provided (all-to-all mode)
-        # - use_overlap_gp must be enabled
-        # - must NOT be in training mode (overlap path doesn't support autograd)
-        # - must NOT need gradients (autograd forces/stress require
-        #   autograd-compatible communication, overlap path doesn't provide this)
-        # - must NOT use activation checkpointing (incompatible with edge split)
-        # - must have both local and boundary edges
-        needs_grad_for_overlap = torch.is_grad_enabled() and (
-            x.requires_grad if isinstance(x, torch.Tensor) else False
-        )
-        use_overlap = (
-            self.use_overlap_gp
-            and gp_ctx is not None
-            and gp_utils.initialized()
-            and not self.training
-            and not needs_grad_for_overlap
-            and self.activation_checkpoint_chunk_size is None
-            and gp_ctx.local_edge_mask is not None
-            and gp_ctx.num_local_edges > 0
-            and gp_ctx.num_boundary_edges > 0
-        )
-
-        if use_overlap:
-            return self._forward_overlap(
-                x, x_edge, wigner, wigner_inv_envelope, gp_ctx, send_indices
-            )
-
         if gp_ctx is not None and gp_utils.initialized():
             # All-to-all path: collect only needed remote embeddings.
             # When x requires grad (autograd forces/stress), we use the
@@ -252,85 +219,6 @@ def forward(
             new_embeddings = [torch.stack(new_embeddings).sum(axis=0)]
         return torch.stack(new_embeddings).sum(axis=0)
 
-    def _forward_overlap(
-        self,
-        x,
-        x_edge,
-        wigner,
-        wigner_inv_envelope,
-        gp_ctx: GPContext,
-        send_indices: torch.Tensor | None,
-    ):
-        """
-        Overlapped communication-computation forward pass.
-
-        Overlaps the all-to-all communication with local edge
-        computation for better inference latency. Only used in
-        eval mode (no autograd through the communication).
-
-        Edges are pre-sorted in build_gp_context (local edges first,
-        boundary edges last via edge_reorder). This allows using
-        compile-friendly split() instead of boolean indexing.
-
-        Steps:
-        1. Start async all-to-all to exchange boundary embeddings.
-        2. Compute local edges (both endpoints are local atoms)
-           while communication is in flight.
-        3. Wait for communication to complete.
-        4. Compute boundary edges (source is remote).
-        5. Sum local + boundary contributions.
-        """
-        edge_index_local = gp_ctx.edge_index_local
-        num_local_atoms = x.shape[0]
-        n_local = gp_ctx.num_local_edges
-
-        # Split pre-sorted per-edge data: local first, boundary last.
-        # No boolean indexing — compile-friendly.
-        local_edge_idx = edge_index_local[:, :n_local]
-        boundary_edge_idx = edge_index_local[:, n_local:]
-        local_x_edge = x_edge[:n_local]
-        boundary_x_edge = x_edge[n_local:]
-        local_wigner = wigner[:n_local]
-        boundary_wigner = wigner[n_local:]
-        local_wigner_inv = wigner_inv_envelope[:n_local]
-        boundary_wigner_inv = wigner_inv_envelope[n_local:]
-
-        # Step 1: Start async all-to-all
-        with record_function("a2a_collect_async_start"):
-            recv_buf, work_handles = start_all_to_all_collect(x, gp_ctx, send_indices)
-
-        # Step 2: Compute local edges while comm is in flight
-        with record_function("local_edges"):
-            local_contribution = self.forward_chunk(
-                x,
-                num_local_atoms,
-                local_x_edge,
-                local_edge_idx,
-                local_wigner,
-                local_wigner_inv,
-                0,
-            )
-
-        # Step 3: Wait for communication
-        with record_function("a2a_collect_async_wait"):
-            x_received = finish_all_to_all_collect(recv_buf, work_handles)
-            x_full = torch.cat([x, x_received], dim=0)
-
-        # Step 4: Compute boundary edges
-        with record_function("boundary_edges"):
-            boundary_contribution = self.forward_chunk(
-                x_full,
-                num_local_atoms,
-                boundary_x_edge,
-                boundary_edge_idx,
-                boundary_wigner,
-                boundary_wigner_inv,
-                0,
-            )
-
-        # Step 5: Sum contributions
-        return local_contribution + boundary_contribution
-
     def forward_chunk(
         self,
         x_full,

src/fairchem/core/models/uma/graph_parallel.py

Lines changed: 0 additions & 117 deletions
@@ -1014,120 +1014,3 @@ def all_to_all_collect_p2p(
     _safe_all_to_all(recv_chunks, send_chunks, group=gp_group)
 
     return x_recv
-
-
-@torch.compiler.disable
-def start_all_to_all_collect(
-    x_local: torch.Tensor,
-    gp_ctx: GPContext,
-    send_indices: torch.Tensor,
-) -> tuple[torch.Tensor, list[dist.Work]]:
-    """
-    Start async all-to-all communication for comm-compute overlap.
-
-    Launches the all-to-all without waiting for completion. Returns
-    the pre-allocated receive buffer and work handles. The caller
-    should do useful compute, then call ``finish_all_to_all_collect``
-    to wait for completion and get the received embeddings.
-
-    This function does NOT participate in autograd. For differentiable
-    all-to-all, use ``all_to_all_collect`` instead. This async variant
-    is intended for the overlap path where gradients are handled
-    separately.
-
-    Uses ``all_to_all_single`` on NCCL for efficiency (avoids Python
-    list creation from ``split()``).
-
-    Args:
-        x_local: Local atom embeddings, shape (local_atoms, *features).
-        gp_ctx: Graph parallel context.
-        send_indices: Local indices of atoms to send.
-
-    Returns:
-        Tuple of (recv_buffer, work_handles):
-            recv_buffer: Pre-allocated tensor for received embeddings.
-            work_handles: List of dist.Work handles to wait on.
-    """
-    feature_shape = x_local.shape[1:]
-
-    # Gather atoms to send
-    if send_indices.numel() > 0:
-        x_send = x_local[send_indices].contiguous()
-    else:
-        x_send = torch.empty(
-            0, *feature_shape, device=x_local.device, dtype=x_local.dtype
-        )
-
-    # Use precomputed splits if available (avoids .tolist() per layer)
-    send_splits = (
-        gp_ctx.send_splits
-        if gp_ctx.send_splits is not None
-        else gp_ctx.send_counts.tolist()
-    )
-    recv_splits = (
-        gp_ctx.recv_splits
-        if gp_ctx.recv_splits is not None
-        else gp_ctx.recv_counts.tolist()
-    )
-    total_recv = (
-        gp_ctx.total_recv if gp_ctx.total_recv is not None else sum(recv_splits)
-    )
-    x_recv = torch.empty(
-        total_recv, *feature_shape, device=x_local.device, dtype=x_local.dtype
-    )
-
-    # Launch async all-to-all
-    gp_group = gp_utils.get_gp_group()
-    backend = dist.get_backend(gp_group)
-
-    work_handles = []
-    if backend == "nccl":
-        # Use all_to_all_single for NCCL — packed tensor, no list creation
-        work = dist.all_to_all_single(
-            x_recv,
-            x_send,
-            output_split_sizes=recv_splits,
-            input_split_sizes=send_splits,
-            group=gp_group,
-            async_op=True,
-        )
-        work_handles.append(work)
-    else:
-        # Gloo fallback: use pairwise send/recv
-        send_list = list(x_send.split(send_splits))
-        recv_list = list(x_recv.split(recv_splits))
-        rank = dist.get_rank(gp_group)
-        world_size = dist.get_world_size(gp_group)
-        ops = []
-        for r in range(world_size):
-            if r == rank:
-                if send_list[r].numel() > 0:
-                    recv_list[r].copy_(send_list[r])
-            elif send_list[r].numel() > 0 or recv_list[r].numel() > 0:
-                # Skip zero-length P2P ops to avoid potential hangs
-                ops.append(dist.P2POp(dist.isend, send_list[r], r, group=gp_group))
-                ops.append(dist.P2POp(dist.irecv, recv_list[r], r, group=gp_group))
-        if ops:
-            work_handles = dist.batch_isend_irecv(ops)
-
-    return x_recv, work_handles
-
-
-@torch.compiler.disable
-def finish_all_to_all_collect(
-    recv_buffer: torch.Tensor,
-    work_handles: list[dist.Work],
-) -> torch.Tensor:
-    """
-    Wait for async all-to-all to complete and return received embeddings.
-
-    Args:
-        recv_buffer: Pre-allocated receive buffer from start_all_to_all_collect.
-        work_handles: Work handles from start_all_to_all_collect.
-
-    Returns:
-        x_received: Received remote atom embeddings.
-    """
-    for work in work_handles:
-        work.wait()
-    return recv_buffer
