Skip to content

Commit 0fae057

Browse files
author
rgao user
committed
Add all-to-all graph parallel communication as alternative to all-gather
Replace the all-gather collective in graph parallel message passing with an all-to-all alternative that exchanges only needed boundary atoms. This reduces communication volume from O(N*P) to O(boundary), where boundary atoms are ~25% of total with spatial partitioning.

Key components:
- graph_parallel.py: GPContext dataclass, Morton Z-order spatial partitioning, build_gp_context for communication metadata, AllToAllCollect autograd.Function, and compile-friendly funcoll path
- escn_md.py: Integration into eSCNMDBackbone with use_all_to_all_gp and gp_partition_strategy config flags, spatial partition computation, GPContext building, and force reordering for non-consecutive partitions
- escn_md_block.py: Dual-path Edgewise.forward() supporting both all-gather and all-to-all collection with local edge index remapping
- compute.py: send_info computation during edge filtering to skip NCCL index-exchange collective

Enabled via backbone config: +runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}

Performance (H200 turbo, UMA-S, 4k atoms/rank weak scaling):
- 8 GPU: +2.6% (0.6685 vs 0.6513 QPS)
- 16 GPU: +2.7% (0.6260 vs 0.6097 QPS)
- 32 GPU: +9.5% (0.6225 vs 0.5683 QPS)
- 64 GPU: +17.5% (0.5600 vs 0.4766 QPS)

Weak scaling efficiency at 64 GPUs: 83.8% vs 73.2% baseline.

Backward compatible: default behavior unchanged (use_all_to_all_gp=False). All 27 existing gp_utils and escn_md tests continue to pass.
1 parent 28292f0 commit 0fae057

6 files changed

Lines changed: 2545 additions & 49 deletions

File tree

src/fairchem/core/graph/compute.py

Lines changed: 113 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -22,27 +22,89 @@ def filter_edges_by_node_partition(
2222
cell_offsets: torch.Tensor,
2323
neighbors: torch.Tensor,
2424
num_atoms: int,
25-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
26-
"""Filter edges to keep only those where the target atom belongs to the node partition.
27-
edge_index is shape (2, num_edges) where the first row is the source atom index and the second row is the target atom index for each edge
28-
cell_offsets is shape (num_edges, 3)
29-
neighbors is cardinality of the edge_index per system in the batch
25+
rank_assignments: torch.Tensor | None = None,
26+
rank: int | None = None,
27+
world_size: int | None = None,
28+
) -> (
29+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
30+
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]
31+
):
32+
"""
33+
Filter edges to keep only those where the target atom belongs to
34+
the node partition.
35+
36+
When rank_assignments, rank, and world_size are provided, also
37+
computes send_info: which local atoms need to be sent to which
38+
ranks for all-to-all graph parallel communication. This exploits
39+
access to the full (pre-filter) edge_index to derive send
40+
metadata locally, eliminating the need for an NCCL index-exchange
41+
collective in build_gp_context.
3042
3143
Args:
32-
node_partition: Tensor of atom indices belonging to the current rank's partition.
33-
edge_index: Edge index tensor of shape (2, num_edges), where row 0 is the source and 1 is the target atom.
34-
cell_offsets: Cell offsets tensor of shape (num_edges, 3).
35-
neighbors: Tensor with edge count per system in the batch (length = num_systems).
36-
num_atoms: Total number of atoms across all batches. Used to create a boolean mask for filtering.
44+
node_partition: Atom indices in the current rank's partition.
45+
edge_index: Full edge index, shape (2, num_edges).
46+
cell_offsets: Cell offsets, shape (num_edges, 3).
47+
neighbors: Edge count per system in the batch.
48+
num_atoms: Total atoms across all batches.
49+
rank_assignments: Rank for each atom, shape (num_atoms,).
50+
If provided along with rank and world_size, send_info
51+
is computed and returned as a 4th element.
52+
rank: This rank's GP rank.
53+
world_size: GP world size.
3754
3855
Returns:
39-
Filtered edge_index, cell_offsets, and neighbors tensors.
56+
Filtered (edge_index, cell_offsets, neighbors).
57+
If rank_assignments is provided, also returns send_info dict
58+
with keys: send_counts, send_indices_global.
4059
"""
4160
target_atoms = edge_index[1]
4261
node_mask = torch.zeros(num_atoms, dtype=torch.bool, device=target_atoms.device)
4362
node_mask[node_partition] = True
4463
local_edge_mask = node_mask[target_atoms]
4564

65+
# Compute send info BEFORE discarding non-local edges.
66+
# An edge (src, tgt) where src is LOCAL and tgt is REMOTE means
67+
# src must be sent to rank_assignments[tgt].
68+
send_info = None
69+
if rank_assignments is not None and rank is not None and world_size is not None:
70+
src_is_local = node_mask[edge_index[0]]
71+
tgt_is_remote = ~local_edge_mask
72+
send_edge_mask = src_is_local & tgt_is_remote
73+
74+
if send_edge_mask.any():
75+
send_src = edge_index[0, send_edge_mask]
76+
send_dst_rank = rank_assignments[edge_index[1, send_edge_mask]]
77+
78+
# Unique (dst_rank, src_atom) pairs, sorted by rank then atom.
79+
# Key layout: dst_rank * num_atoms + src_atom ensures rank-major
80+
# ordering, matching what the index exchange produces.
81+
key = send_dst_rank.to(torch.long) * num_atoms + send_src.to(torch.long)
82+
unique_keys = key.unique(sorted=True)
83+
send_ranks = unique_keys // num_atoms
84+
send_atoms = unique_keys % num_atoms
85+
86+
send_counts = torch.zeros(
87+
world_size, dtype=torch.long, device=edge_index.device
88+
)
89+
send_counts.scatter_add_(
90+
0,
91+
send_ranks,
92+
torch.ones_like(send_ranks),
93+
)
94+
send_info = {
95+
"send_counts": send_counts,
96+
"send_indices_global": send_atoms,
97+
}
98+
else:
99+
send_info = {
100+
"send_counts": torch.zeros(
101+
world_size, dtype=torch.long, device=edge_index.device
102+
),
103+
"send_indices_global": torch.empty(
104+
0, dtype=torch.long, device=edge_index.device
105+
),
106+
}
107+
46108
# Create system index for each edge to track which system each edge belongs to
47109
num_systems = neighbors.shape[0]
48110
edge_system_idx = torch.repeat_interleave(
@@ -55,6 +117,8 @@ def filter_edges_by_node_partition(
55117
if neighbors.shape[0] == 1:
56118
# If there's only one system, we can skip the scatter_add step and just return the count of remaining edges
57119
new_neighbors = local_edge_mask.sum(dtype=neighbors.dtype).unsqueeze(0)
120+
if send_info is not None:
121+
return edge_index, cell_offsets, new_neighbors, send_info
58122
return edge_index, cell_offsets, new_neighbors
59123

60124
filtered_edge_system_idx = edge_system_idx[local_edge_mask]
@@ -69,6 +133,8 @@ def filter_edges_by_node_partition(
69133
torch.ones_like(filtered_edge_system_idx, dtype=neighbors.dtype),
70134
)
71135

136+
if send_info is not None:
137+
return edge_index, cell_offsets, new_neighbors, send_info
72138
return edge_index, cell_offsets, new_neighbors
73139

74140

@@ -123,8 +189,12 @@ def generate_graph(
123189
radius_pbc_version: int,
124190
pbc: torch.Tensor,
125191
node_partition: torch.Tensor | None = None,
192+
rank_assignments: torch.Tensor | None = None,
193+
rank: int | None = None,
194+
world_size: int | None = None,
126195
) -> dict:
127-
"""Generate a graph representation from atomic structure data.
196+
"""
197+
Generate a graph representation from atomic structure data.
128198
129199
Args:
130200
data (dict): A dictionary containing a batch of molecular structures.
@@ -138,6 +208,9 @@ def generate_graph(
138208
radius_pbc_version: the version of radius_pbc impl (1, 2, or 3 for NVIDIA)
139209
pbc (list[bool]): The periodic boundary conditions in 3 dimensions, defaults to [True,True,True] for 3D pbc
140210
node_partition (torch.Tensor | None): The partitioning of the nodes (atoms) for distributed inference. If provided, returned graph will be filtered to keep only edges where the target atom (edge_index[1,:]) belongs to the current rank's partition.
211+
rank_assignments: Rank for each atom (for A2A send_info).
212+
rank: This rank's GP rank (for A2A send_info).
213+
world_size: GP world size (for A2A send_info).
141214
142215
Returns:
143216
dict: A dictionary containing the generated graph with the following keys:
@@ -147,13 +220,19 @@ def generate_graph(
147220
- 'cell_offsets' (torch.Tensor): Offsets of the cell vectors for each edge.
148221
- 'offset_distances' (torch.Tensor): Distances between the atoms connected by the edges, including the cell offsets.
149222
- 'neighbors' (torch.Tensor): Number of neighbors for each atom.
223+
- 'send_info' (dict, optional): Send metadata for A2A GP when rank_assignments is provided.
150224
"""
151225
if radius_pbc_version == 1:
152226
radius_graph_pbc_fn = radius_graph_pbc
153227
elif radius_pbc_version == 2:
154228
radius_graph_pbc_fn = radius_graph_pbc_v2
155229
if node_partition is not None:
156-
data["node_partition"] = node_partition
230+
# Use setattr for compatibility with SimpleNamespace
231+
# (used by halo filtering) and regular data dicts.
232+
try:
233+
data["node_partition"] = node_partition
234+
except TypeError:
235+
data.node_partition = node_partition
157236
elif radius_pbc_version == 3:
158237
radius_graph_pbc_fn = radius_graph_pbc_nvidia
159238
else:
@@ -167,15 +246,30 @@ def generate_graph(
167246
pbc=pbc,
168247
)
169248

170-
# for v2 it is still faster right now to not do this post filtering, need to investigate further
249+
# V2 does its own internal edge filtering when node_partition is set,
250+
# which is faster than post-filtering. However, this means send_info
251+
# cannot be computed here for v2 (the full edge_index is needed).
252+
# Instead, build_gp_context falls back to _sparse_index_exchange
253+
# (~4ms NCCL collective) when send_info is None. Bypassing v2's
254+
# internal filter to compute send_info was benchmarked and is ~12ms
255+
# SLOWER because v2 generates edges for ALL atoms instead of local
256+
# partition.
257+
send_info = None
171258
if node_partition is not None and radius_pbc_version != 2:
172-
edge_index, cell_offsets, neighbors = filter_edges_by_node_partition(
259+
filter_result = filter_edges_by_node_partition(
173260
node_partition,
174261
edge_index,
175262
cell_offsets,
176263
neighbors,
177264
num_atoms=data.pos.shape[0],
265+
rank_assignments=rank_assignments,
266+
rank=rank,
267+
world_size=world_size,
178268
)
269+
if rank_assignments is not None:
270+
edge_index, cell_offsets, neighbors, send_info = filter_result
271+
else:
272+
edge_index, cell_offsets, neighbors = filter_result
179273

180274
out = get_pbc_distances(
181275
data.pos,
@@ -192,11 +286,14 @@ def generate_graph(
192286
cell_offset_distances = out["offsets"]
193287
distance_vec = out["distance_vec"]
194288

195-
return {
289+
result = {
196290
"edge_index": edge_index,
197291
"edge_distance": edge_dist,
198292
"edge_distance_vec": distance_vec,
199293
"cell_offsets": cell_offsets,
200294
"offset_distances": cell_offset_distances,
201295
"neighbors": neighbors,
202296
}
297+
if send_info is not None:
298+
result["send_info"] = send_info
299+
return result

0 commit comments

Comments
 (0)