
Commit 4b6289d

Eliminate NCCL index-exchange from A2A setup via local send_info computation
Compute send_info (which local atoms to send to which ranks) directly in filter_edges_by_node_partition using the full pre-filter edge_index, instead of requiring an NCCL all_to_all collective in build_gp_context. This removes the most expensive collective from the per-step setup path. Also fixes Morton Z-order balanced partition (i*P//N instead of ceil-based) and adds record_function tracing annotations. Validated: 27 graph_parallel + 10 escn_md tests pass. Benchmarked: 4.7-11% speedup at 64 GPUs (exp 18), parity at 8 GPUs (exp 17).
1 parent 6c95830 · commit 4b6289d
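The trick in one picture: every rank builds the same global edge_index before filtering, so rank s can already see which of its atoms feed edges into rank r's partition and never needs to be told by r. A minimal standalone sketch of that observation (toy tensors and variable names are mine, not the repo's):

    import torch

    # Toy graph: 6 atoms on 2 ranks. Rank 0 owns atoms 0-2, rank 1 owns 3-5.
    rank_assignments = torch.tensor([0, 0, 0, 1, 1, 1])
    edge_index = torch.tensor([[0, 1, 4, 2],   # row 0: source atoms
                               [3, 2, 0, 5]])  # row 1: target atoms

    rank = 0  # pretend we are GP rank 0
    src_rank = rank_assignments[edge_index[0]]
    tgt_rank = rank_assignments[edge_index[1]]

    # Our atoms that source an edge into a remote partition must be sent there.
    must_send = (src_rank == rank) & (tgt_rank != rank)
    print(edge_index[0, must_send])                    # tensor([0, 2]) -> atoms to ship
    print(rank_assignments[edge_index[1, must_send]])  # tensor([1, 1]) -> destinations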

4 files changed

Lines changed: 294 additions & 41 deletions


src/fairchem/core/graph/compute.py

Lines changed: 99 additions & 14 deletions
@@ -22,27 +22,89 @@ def filter_edges_by_node_partition(
     cell_offsets: torch.Tensor,
     neighbors: torch.Tensor,
     num_atoms: int,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Filter edges to keep only those where the target atom belongs to the node partition.
-    edge_index is shape (2, num_edges) where the first row is the source atom index and the second row is the target atom index for each edge
-    cell_offsets is shape (num_edges, 3)
-    neighbors is cardinality of the edge_index per system in the batch
+    rank_assignments: torch.Tensor | None = None,
+    rank: int | None = None,
+    world_size: int | None = None,
+) -> (
+    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]
+):
+    """
+    Filter edges to keep only those where the target atom belongs to
+    the node partition.
+
+    When rank_assignments, rank, and world_size are provided, also
+    computes send_info: which local atoms need to be sent to which
+    ranks for all-to-all graph parallel communication. This exploits
+    access to the full (pre-filter) edge_index to derive send
+    metadata locally, eliminating the need for an NCCL index-exchange
+    collective in build_gp_context.
 
     Args:
-        node_partition: Tensor of atom indices belonging to the current rank's partition.
-        edge_index: Edge index tensor of shape (2, num_edges), where row 0 is the source and 1 is the target atom.
-        cell_offsets: Cell offsets tensor of shape (num_edges, 3).
-        neighbors: Tensor with edge count per system in the batch (length = num_systems).
-        num_atoms: Total number of atoms across all batches. Used to create a boolean mask for filtering.
+        node_partition: Atom indices in the current rank's partition.
+        edge_index: Full edge index, shape (2, num_edges).
+        cell_offsets: Cell offsets, shape (num_edges, 3).
+        neighbors: Edge count per system in the batch.
+        num_atoms: Total atoms across all batches.
+        rank_assignments: Rank for each atom, shape (num_atoms,).
+            If provided along with rank and world_size, send_info
+            is computed and returned as a 4th element.
+        rank: This rank's GP rank.
+        world_size: GP world size.
 
     Returns:
-        Filtered edge_index, cell_offsets, and neighbors tensors.
+        Filtered (edge_index, cell_offsets, neighbors).
+        If rank_assignments is provided, also returns send_info dict
+        with keys: send_counts, send_indices_global.
     """
     target_atoms = edge_index[1]
     node_mask = torch.zeros(num_atoms, dtype=torch.bool, device=target_atoms.device)
     node_mask[node_partition] = True
     local_edge_mask = node_mask[target_atoms]
 
+    # Compute send info BEFORE discarding non-local edges.
+    # An edge (src, tgt) where src is LOCAL and tgt is REMOTE means
+    # src must be sent to rank_assignments[tgt].
+    send_info = None
+    if rank_assignments is not None and rank is not None and world_size is not None:
+        src_is_local = node_mask[edge_index[0]]
+        tgt_is_remote = ~local_edge_mask
+        send_edge_mask = src_is_local & tgt_is_remote
+
+        if send_edge_mask.any():
+            send_src = edge_index[0, send_edge_mask]
+            send_dst_rank = rank_assignments[edge_index[1, send_edge_mask]]
+
+            # Unique (dst_rank, src_atom) pairs, sorted by rank then atom.
+            # Key layout: dst_rank * num_atoms + src_atom ensures rank-major
+            # ordering, matching what _fused_index_exchange produces.
+            key = send_dst_rank.to(torch.long) * num_atoms + send_src.to(torch.long)
+            unique_keys = key.unique(sorted=True)
+            send_ranks = unique_keys // num_atoms
+            send_atoms = unique_keys % num_atoms
+
+            send_counts = torch.zeros(
+                world_size, dtype=torch.long, device=edge_index.device
+            )
+            send_counts.scatter_add_(
+                0,
+                send_ranks,
+                torch.ones_like(send_ranks),
+            )
+            send_info = {
+                "send_counts": send_counts,
+                "send_indices_global": send_atoms,
+            }
+        else:
+            send_info = {
+                "send_counts": torch.zeros(
+                    world_size, dtype=torch.long, device=edge_index.device
+                ),
+                "send_indices_global": torch.empty(
+                    0, dtype=torch.long, device=edge_index.device
+                ),
+            }
+
     # Create system index for each edge to track which system each edge belongs to
     num_systems = neighbors.shape[0]
     edge_system_idx = torch.repeat_interleave(
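A worked instance of the key encoding above, with illustrative numbers: num_atoms = 10, world_size = 4, and three required (rank, atom) sends, one duplicated across edges. The flat key dst_rank * num_atoms + src_atom makes a plain sort produce rank-major order, unique() deduplicates pairs, and scatter_add_ recovers per-rank counts:

    import torch

    num_atoms, world_size = 10, 4
    send_dst_rank = torch.tensor([2, 0, 2, 2])  # per send-edge destination rank
    send_src = torch.tensor([7, 7, 2, 2])       # per send-edge source atom

    key = send_dst_rank * num_atoms + send_src  # tensor([27,  7, 22, 22])
    unique_keys = key.unique(sorted=True)       # tensor([ 7, 22, 27])
    send_ranks = unique_keys // num_atoms       # tensor([0, 2, 2])
    send_atoms = unique_keys % num_atoms        # tensor([7, 2, 7])

    send_counts = torch.zeros(world_size, dtype=torch.long)
    send_counts.scatter_add_(0, send_ranks, torch.ones_like(send_ranks))
    print(send_counts)  # tensor([1, 0, 2, 0]) -- one atom to rank 0, two to rank 2
    print(send_atoms)   # tensor([7, 2, 7])    -- grouped by destination rank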
@@ -55,6 +117,8 @@ def filter_edges_by_node_partition(
     if neighbors.shape[0] == 1:
         # If there's only one system, we can skip the scatter_add step and just return the count of remaining edges
         new_neighbors = local_edge_mask.sum(dtype=neighbors.dtype).unsqueeze(0)
+        if send_info is not None:
+            return edge_index, cell_offsets, new_neighbors, send_info
         return edge_index, cell_offsets, new_neighbors
 
     filtered_edge_system_idx = edge_system_idx[local_edge_mask]

@@ -69,6 +133,8 @@ def filter_edges_by_node_partition(
         torch.ones_like(filtered_edge_system_idx, dtype=neighbors.dtype),
     )
 
+    if send_info is not None:
+        return edge_index, cell_offsets, new_neighbors, send_info
     return edge_index, cell_offsets, new_neighbors
 
 

@@ -123,8 +189,12 @@ def generate_graph(
     radius_pbc_version: int,
     pbc: torch.Tensor,
     node_partition: torch.Tensor | None = None,
+    rank_assignments: torch.Tensor | None = None,
+    rank: int | None = None,
+    world_size: int | None = None,
 ) -> dict:
-    """Generate a graph representation from atomic structure data.
+    """
+    Generate a graph representation from atomic structure data.
 
     Args:
         data (dict): A dictionary containing a batch of molecular structures.

@@ -138,6 +208,9 @@ def generate_graph(
         radius_pbc_version: the version of radius_pbc impl (1, 2, or 3 for NVIDIA)
         pbc (list[bool]): The periodic boundary conditions in 3 dimensions, defaults to [True,True,True] for 3D pbc
         node_partition (torch.Tensor | None): The partitioning of the nodes (atoms) for distributed inference. If provided, returned graph will be filtered to keep only edges where the target atom (edge_index[1,:]) belongs to the current rank's partition.
+        rank_assignments: Rank for each atom (for A2A send_info).
+        rank: This rank's GP rank (for A2A send_info).
+        world_size: GP world size (for A2A send_info).
 
     Returns:
         dict: A dictionary containing the generated graph with the following keys:

@@ -147,6 +220,7 @@ def generate_graph(
             - 'cell_offsets' (torch.Tensor): Offsets of the cell vectors for each edge.
             - 'offset_distances' (torch.Tensor): Distances between the atoms connected by the edges, including the cell offsets.
             - 'neighbors' (torch.Tensor): Number of neighbors for each atom.
+            - 'send_info' (dict, optional): Send metadata for A2A GP when rank_assignments is provided.
     """
     if radius_pbc_version == 1:
         radius_graph_pbc_fn = radius_graph_pbc

@@ -168,14 +242,22 @@ def generate_graph(
     )
 
     # for v2 it is still faster right now to not do this post filtering, need to investigate further
+    send_info = None
     if node_partition is not None and radius_pbc_version != 2:
-        edge_index, cell_offsets, neighbors = filter_edges_by_node_partition(
+        filter_result = filter_edges_by_node_partition(
             node_partition,
             edge_index,
             cell_offsets,
             neighbors,
             num_atoms=data.pos.shape[0],
+            rank_assignments=rank_assignments,
+            rank=rank,
+            world_size=world_size,
         )
+        if rank_assignments is not None:
+            edge_index, cell_offsets, neighbors, send_info = filter_result
+        else:
+            edge_index, cell_offsets, neighbors = filter_result
 
     out = get_pbc_distances(
         data.pos,

@@ -192,11 +274,14 @@ def generate_graph(
     cell_offset_distances = out["offsets"]
     distance_vec = out["distance_vec"]
 
-    return {
+    result = {
         "edge_index": edge_index,
         "edge_distance": edge_dist,
         "edge_distance_vec": distance_vec,
         "cell_offsets": cell_offsets,
         "offset_distances": cell_offset_distances,
        "neighbors": neighbors,
     }
+    if send_info is not None:
+        result["send_info"] = send_info
+    return result
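The two return shapes in action, as a minimal sketch with toy tensors (the import path is inferred from the file header above; values are illustrative):

    import torch
    from fairchem.core.graph.compute import filter_edges_by_node_partition

    # 6 atoms, one system, 4 edges; rank 0 owns atoms 0-2.
    edge_index = torch.tensor([[0, 1, 4, 2],
                               [3, 2, 0, 5]])
    cell_offsets = torch.zeros(4, 3)
    neighbors = torch.tensor([4])
    node_partition = torch.tensor([0, 1, 2])
    rank_assignments = torch.tensor([0, 0, 0, 1, 1, 1])

    # Legacy 3-tuple: no GP arguments.
    edge_index_f, cell_offsets_f, neighbors_f = filter_edges_by_node_partition(
        node_partition, edge_index, cell_offsets, neighbors, num_atoms=6
    )

    # New 4-tuple: send_info is derived locally, no collective involved.
    *_, send_info = filter_edges_by_node_partition(
        node_partition, edge_index, cell_offsets, neighbors, num_atoms=6,
        rank_assignments=rank_assignments, rank=0, world_size=2,
    )
    print(send_info["send_counts"])          # tensor([0, 2]): two atoms go to rank 1
    print(send_info["send_indices_global"])  # tensor([0, 2])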

src/fairchem/core/models/uma/escn_md_block.py

Lines changed: 4 additions & 3 deletions
@@ -172,9 +172,10 @@ def forward(
             local_node_offset = 0
         elif gp_utils.initialized():
             # Legacy all-gather path
-            x_full = gp_utils.gather_from_model_parallel_region_sum_grad(
-                x, total_atoms_across_gp_ranks
-            )
+            with record_function("allgather_collect"):
+                x_full = gp_utils.gather_from_model_parallel_region_sum_grad(
+                    x, total_atoms_across_gp_ranks
+                )
             edge_index_local = edge_index
             local_node_offset = node_offset
         else:
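record_function only names a region for the PyTorch profiler; it changes no numerics. A self-contained illustration of what the annotation buys (the matmul stands in for the real all-gather; the label matches the one added above):

    import torch
    from torch.profiler import ProfilerActivity, profile, record_function

    x = torch.randn(128, 64)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with record_function("allgather_collect"):  # shows up as its own row
            y = x @ x.T

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))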

src/fairchem/core/models/uma/graph_parallel.py

Lines changed: 39 additions & 21 deletions
@@ -12,6 +12,7 @@
 
 import torch
 from torch import distributed as dist
+from torch.profiler import record_function
 
 from fairchem.core.common import gp_utils
 
@@ -199,15 +200,15 @@ def partition_atoms_spatial(
     x, y, z = norm[:, 0], norm[:, 1], norm[:, 2]
     morton = _expand_bits_10(x) | (_expand_bits_10(y) << 1) | (_expand_bits_10(z) << 2)
 
-    # Sort by Morton code and assign to ranks in equal chunks
+    # Sort by Morton code and assign to ranks in balanced chunks.
+    # Use ``i * P // N`` mapping (not ``i // ceil(N/P)``) to ensure
+    # EVERY rank receives at least ``floor(N/P)`` atoms. The ceil-based
+    # formula can leave trailing ranks empty when N is not a multiple of P
+    # (e.g. 1000 atoms / 64 ranks → rank 63 gets 0 atoms, causing a
+    # hang in collective communication).
     _, sorted_indices = morton.sort()
-    chunk_size = (N + num_ranks - 1) // num_ranks
     assignments = torch.empty(N, dtype=torch.long, device=device)
-    assignments[sorted_indices] = torch.div(
-        torch.arange(N, device=device),
-        chunk_size,
-        rounding_mode="floor",
-    ).clamp(max=num_ranks - 1)
+    assignments[sorted_indices] = torch.arange(N, device=device) * num_ranks // N
 
     return assignments
 
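A standalone check of the comment's arithmetic (N = 1000 atoms, P = 64 ranks):

    import torch

    N, P = 1000, 64

    # Old ceil-based mapping: chunk = ceil(N / P) = 16, so ranks 0-61 take
    # 16 atoms each (992 total), rank 62 takes 8, and rank 63 gets nothing.
    chunk = (N + P - 1) // P
    old = (torch.arange(N) // chunk).clamp(max=P - 1)
    print(torch.bincount(old, minlength=P)[-3:])      # tensor([16,  8,  0])

    # New i * P // N mapping: every rank gets floor(N/P) or ceil(N/P) atoms.
    new = torch.arange(N) * P // N
    print(torch.bincount(new, minlength=P).unique())  # tensor([15, 16])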
@@ -276,11 +277,13 @@ def partition_atoms_index_split(
     return assignments
 
 
+@torch.compiler.disable
 def build_gp_context(
     edge_index: torch.Tensor,
     rank_assignments: torch.Tensor,
     rank: int,
     world_size: int,
+    send_info: dict | None = None,
 ) -> GPContext:
     """
     Build the GP context from edge connectivity and atom assignments.
@@ -289,17 +292,26 @@ def build_gp_context(
     other ranks), exchanges atom indices via a single fused all-to-all,
     and computes all communication metadata.
 
-    Uses a single padded all-to-all collective (instead of the previous
-    2-step approach of count exchange + index exchange) by padding atom
-    index lists to a fixed size per rank. This halves the number of
-    collective operations in the setup path.
+    When send_info is provided (pre-computed during graph filtering in
+    filter_edges_by_node_partition), the NCCL index-exchange collective
+    is skipped entirely: send_counts and send_indices_global are taken
+    directly from send_info. This eliminates the most expensive collective
+    in the setup path.
 
     Args:
-        edge_index: Full graph edge index, shape (2, num_edges).
+        edge_index: Edge index filtered to edges whose targets are in
+            this rank's partition, shape (2, num_local_edges).
             Row 0 = source, row 1 = target.
         rank_assignments: Rank assignment for each atom, shape (total_atoms,).
         rank: This rank's GP rank.
         world_size: GP world size.
+        send_info: Pre-computed send metadata from graph filtering.
+            If provided, must contain:
+            - send_counts: Tensor of shape (world_size,) with count of
+              atoms to send to each rank.
+            - send_indices_global: Tensor of global atom indices to send,
+              sorted by destination rank.
+            When provided, _fused_index_exchange is skipped.
 
     Returns:
         GPContext with all metadata needed for all-to-all communication.
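For concreteness, a well-formed send_info under this contract (values are illustrative): the rank below ships atoms 5 and 9 to rank 0 and atom 42 to rank 2, so the indices are concatenated in destination-rank order and the counts sum to their length:

    import torch

    send_info = {
        # how many atoms this rank sends to each of 4 ranks
        "send_counts": torch.tensor([2, 0, 1, 0], dtype=torch.long),
        # global atom indices, grouped by destination rank (rank-major)
        "send_indices_global": torch.tensor([5, 9, 42], dtype=torch.long),
    }
    assert send_info["send_counts"].sum() == send_info["send_indices_global"].numel()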
@@ -341,15 +353,21 @@ def build_gp_context(
 
     # Fused count + index exchange: single padded all-to-all replaces
     # the old 2-step approach (count exchange + index exchange).
-    send_counts, send_indices_global = _fused_index_exchange(
-        needed_atoms=needed_atoms,
-        needed_from_ranks=needed_from_ranks,
-        recv_counts=recv_counts,
-        rank=rank,
-        world_size=world_size,
-        total_atoms=total_atoms,
-        device=device,
-    )
+    if send_info is not None:
+        # Pre-computed during graph filtering: skip the NCCL collective.
+        send_counts = send_info["send_counts"]
+        send_indices_global = send_info["send_indices_global"]
+    else:
+        with record_function("a2a_fused_index_exchange"):
+            send_counts, send_indices_global = _fused_index_exchange(
+                needed_atoms=needed_atoms,
+                needed_from_ranks=needed_from_ranks,
+                recv_counts=recv_counts,
+                rank=rank,
+                world_size=world_size,
+                total_atoms=total_atoms,
+                device=device,
+            )
 
     # Build global_to_local mapping:
     # Local atoms: index 0..total_local_atoms-1 (in order of node_partition)
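Why skipping the exchange is safe can be sanity-checked without NCCL: simulate every rank on one process and assert that what rank s derives locally to send to rank r equals the remote sources rank r keeps after filtering (what the collective used to deliver). Both are computed from the same global edge list, which is exactly why the exchange is redundant. A standalone sketch, with toy tensors and helper names of my own:

    import torch

    torch.manual_seed(0)
    N, P, E = 20, 4, 60
    rank_assignments = torch.randint(0, P, (N,))
    edge_index = torch.randint(0, N, (2, E))

    def sent_by(s, r):
        # Rank s, BEFORE filtering: its atoms sourcing edges into r's partition.
        m = (rank_assignments[edge_index[0]] == s) & (rank_assignments[edge_index[1]] == r)
        return set(edge_index[0, m].tolist())

    def needed_by(r, s):
        # Rank r, AFTER filtering: it kept only edges targeting its own atoms
        # and needs those edges' remote sources owned by s.
        kept = edge_index[:, rank_assignments[edge_index[1]] == r]
        return set(kept[0, rank_assignments[kept[0]] == s].tolist())

    assert all(
        sent_by(s, r) == needed_by(r, s)
        for s in range(P) for r in range(P) if s != r
    )
    print("locally derived send sets match the exchanged index sets")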
