You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Eliminate NCCL index-exchange from A2A setup via local send_info computation
Compute send_info (which local atoms to send to which ranks) directly in
filter_edges_by_node_partition using the full pre-filter edge_index, instead
of requiring an NCCL all_to_all collective in build_gp_context. This removes
the most expensive collective from the per-step setup path.
Also fixes Morton Z-order balanced partition (i*P//N instead of ceil-based)
and adds record_function tracing annotations.
Validated: 27 graph_parallel + 10 escn_md tests pass.
Benchmarked: 4.7-11% speedup at 64 GPUs (exp 18), parity at 8 GPUs (exp 17).
"""Generate a graph representation from atomic structure data.
196
+
"""
197
+
Generate a graph representation from atomic structure data.
128
198
129
199
Args:
130
200
data (dict): A dictionary containing a batch of molecular structures.
@@ -138,6 +208,9 @@ def generate_graph(
138
208
radius_pbc_version: the version of radius_pbc impl (1, 2, or 3 for NVIDIA)
139
209
pbc (list[bool]): The periodic boundary conditions in 3 dimensions, defaults to [True,True,True] for 3D pbc
140
210
node_partition (torch.Tensor | None): The partitioning of the nodes (atoms) for distributed inference. If provided, returned graph will be filtered to keep only edges where the target atom (edge_index[1,:]) belongs to the current rank's partition.
211
+
rank_assignments: Rank for each atom (for A2A send_info).
212
+
rank: This rank's GP rank (for A2A send_info).
213
+
world_size: GP world size (for A2A send_info).
141
214
142
215
Returns:
143
216
dict: A dictionary containing the generated graph with the following keys:
@@ -147,6 +220,7 @@ def generate_graph(
147
220
- 'cell_offsets' (torch.Tensor): Offsets of the cell vectors for each edge.
148
221
- 'offset_distances' (torch.Tensor): Distances between the atoms connected by the edges, including the cell offsets.
149
222
- 'neighbors' (torch.Tensor): Number of neighbors for each atom.
223
+
- 'send_info' (dict, optional): Send metadata for A2A GP when rank_assignments is provided.
150
224
"""
151
225
ifradius_pbc_version==1:
152
226
radius_graph_pbc_fn=radius_graph_pbc
@@ -168,14 +242,22 @@ def generate_graph(
168
242
)
169
243
170
244
# for v2 it is still faster right now to not do this post filtering, need to investigate further
0 commit comments