Add all-to-all graph parallel communication as an alternative to all-gather

Replace the all-gather collective in graph parallel message passing with
an all-to-all alternative that exchanges only the boundary atoms each rank
actually needs. This reduces communication volume from O(N*P) to O(boundary),
where boundary atoms are ~25% of the total under spatial partitioning.
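As a rough, hedged sketch (not the code added in this PR), the exchange can be expressed with torch.distributed.all_to_all_single: each rank packs only the boundary rows its peers need and receives only the rows it needs, instead of all-gathering every rank's full feature tensor. The helper name collect_boundary_features and its arguments are illustrative assumptions.

import torch
import torch.distributed as dist

def collect_boundary_features(x, send_idx_per_rank, recv_counts):
    # x: [n_local, d] local atom features on this rank.
    # send_idx_per_rank: one LongTensor per peer rank listing the local
    #   boundary atoms that peer needs (derived from send metadata).
    # recv_counts: number of boundary atoms arriving from each peer rank.
    send_counts = [idx.numel() for idx in send_idx_per_rank]
    send_buf = torch.cat([x[idx] for idx in send_idx_per_rank], dim=0)
    recv_buf = x.new_empty((sum(recv_counts), x.shape[1]))
    # Traffic is O(boundary atoms) rather than all_gather's O(N * world_size).
    dist.all_to_all_single(
        recv_buf,
        send_buf,
        output_split_sizes=recv_counts,
        input_split_sizes=send_counts,
    )
    return recv_buf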
Key components:
- graph_parallel.py: GPContext dataclass, Morton Z-order spatial
partitioning (see the sketch after this list), build_gp_context for
communication metadata, AllToAllCollect autograd.Function, and a
compile-friendly funcoll path
- escn_md.py: Integration into eSCNMDBackbone with use_all_to_all_gp and
gp_partition_strategy config flags, spatial partition computation,
GPContext building, and force reordering for non-consecutive partitions
- escn_md_block.py: Dual-path Edgewise.forward() supporting both all-gather
and all-to-all collection with local edge index remapping
- compute.py: send_info computation during edge filtering to skip the NCCL
index-exchange collective
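To make the spatial partitioning concrete, here is a minimal sketch of Morton Z-order rank assignment, assuming fractional coordinates in [0, 1); the actual graph_parallel.py implementation may quantize and load-balance differently.

import torch

def morton_codes(frac_pos, bits=10):
    # Quantize each fractional coordinate to `bits` bits and interleave the
    # x/y/z bits into a single Z-order (Morton) key per atom.
    q = (frac_pos.clamp(0, 1 - 1e-9) * (1 << bits)).long()
    codes = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
    for b in range(bits):
        for dim in range(3):
            codes |= ((q[:, dim] >> b) & 1) << (3 * b + dim)
    return codes

def spatial_partition(frac_pos, world_size):
    # Sort atoms along the Z-order curve and cut the sorted order into
    # world_size contiguous chunks: nearby atoms tend to land on the same
    # rank, which keeps the boundary (and hence A2A traffic) small.
    order = torch.argsort(morton_codes(frac_pos))
    ranks = torch.empty_like(order)
    ranks[order] = (
        torch.arange(order.numel(), device=order.device) * world_size
    ) // order.numel()
    return ranks  # per-atom rank assignment, generally non-consecutive

The non-consecutive assignment is what makes the force reordering in escn_md.py necessary for spatial partitions.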
Enabled via backbone config:
+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}
Performance (H200 turbo, UMA-S, 4k atoms/rank weak scaling):
8 GPU: +2.6% (0.6685 vs 0.6513 QPS)
16 GPU: +2.7% (0.6260 vs 0.6097 QPS)
32 GPU: +9.5% (0.6225 vs 0.5683 QPS)
64 GPU: +17.5% (0.5600 vs 0.4766 QPS)
Weak scaling efficiency at 64 GPUs: 83.8% vs 73.2% baseline
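(The efficiency figures appear to be the 64-GPU per-rank throughput relative to the 8-GPU run: 0.5600 / 0.6685 ≈ 83.8% for all-to-all versus 0.4766 / 0.6513 ≈ 73.2% for the all-gather baseline.)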
Backward compatible: default behavior unchanged (use_all_to_all_gp=False).
All 27 existing gp_utils and escn_md tests continue to pass.
"""Generate a graph representation from atomic structure data.
196
+
"""
197
+
Generate a graph representation from atomic structure data.
128
198
129
199
Args:
130
200
data (dict): A dictionary containing a batch of molecular structures.
@@ -138,6 +208,9 @@ def generate_graph(
         radius_pbc_version: the version of radius_pbc impl (1, 2, or 3 for NVIDIA)
         pbc (list[bool]): The periodic boundary conditions in 3 dimensions, defaults to [True,True,True] for 3D pbc
         node_partition (torch.Tensor | None): The partitioning of the nodes (atoms) for distributed inference. If provided, returned graph will be filtered to keep only edges where the target atom (edge_index[1,:]) belongs to the current rank's partition.
+        rank_assignments: Rank for each atom (for A2A send_info).
+        rank: This rank's GP rank (for A2A send_info).
+        world_size: GP world size (for A2A send_info).
 
     Returns:
         dict: A dictionary containing the generated graph with the following keys:
@@ -147,13 +220,19 @@ def generate_graph(
             - 'cell_offsets' (torch.Tensor): Offsets of the cell vectors for each edge.
             - 'offset_distances' (torch.Tensor): Distances between the atoms connected by the edges, including the cell offsets.
             - 'neighbors' (torch.Tensor): Number of neighbors for each atom.
+            - 'send_info' (dict, optional): Send metadata for A2A GP when rank_assignments is provided.
     """
     if radius_pbc_version == 1:
         radius_graph_pbc_fn = radius_graph_pbc
     elif radius_pbc_version == 2:
         radius_graph_pbc_fn = radius_graph_pbc_v2
         if node_partition is not None:
-            data["node_partition"] = node_partition
+            # Use setattr for compatibility with SimpleNamespace
+            # (used by halo filtering) and regular data dicts.
+            try:
+                data["node_partition"] = node_partition
+            except TypeError:
+                data.node_partition = node_partition
     elif radius_pbc_version == 3:
         radius_graph_pbc_fn = radius_graph_pbc_nvidia
     else:
@@ -167,15 +246,30 @@ def generate_graph(
         pbc=pbc,
     )
 
-    # for v2 it is still faster right now to not do this post filtering, need to investigate further
+    # V2 does its own internal edge filtering when node_partition is set,
+    # which is faster than post-filtering. However, this means send_info
+    # cannot be computed here for v2 (the full edge_index is needed).
+    # Instead, build_gp_context falls back to _sparse_index_exchange
+    # (~4ms NCCL collective) when send_info is None. Bypassing v2's
+    # internal filter to compute send_info was benchmarked and is ~12ms
+    # SLOWER because v2 generates edges for ALL atoms instead of local
0 commit comments