Commit 69a64c4

rgao committed:
Replace CPU recursive bisection with GPU-only Morton Z-order partitioning
Eliminates the pos.cpu() / assignments.to(device) round-trip that was the dominant setup cost in the A2A path. The Morton Z-order curve provides spatial locality comparable to recursive coordinate bisection (O(N^{2/3}) boundary fraction) while running entirely on GPU with no sync points.

Key changes:
- _expand_bits_10(): bit-spread helper for 30-bit Morton code encoding
- partition_atoms_spatial(): now uses Morton sort + equal chunking (no CPU)
- build_gp_context(): batches ALL GPU->CPU transfers into a single .cpu() call (send_counts + recv_counts + local_edge_count in one tensor)
- Uses global normalization (max extent) instead of per-dimension to preserve aspect ratio and spatial locality

_recursive_bisect() is retained for reference but no longer called.
1 parent b40f6f2 commit 69a64c4

1 file changed: 62 additions & 30 deletions

src/fairchem/core/models/uma/graph_parallel.py
@@ -129,26 +129,44 @@ class GPContext:
     total_recv: int | None = None


+def _expand_bits_10(v: torch.Tensor) -> torch.Tensor:
+    """
+    Expand a 10-bit integer so each bit is spaced by 2 zero bits.
+
+    Maps bit at position i to position 3*i, producing a 30-bit
+    output suitable for interleaving with two other axes to form
+    a Morton Z-order code.
+
+    Args:
+        v: Integer tensor with values in [0, 1023].
+
+    Returns:
+        Tensor with bits expanded (each input bit at position 3*i).
+    """
+    v = (v | (v << 16)) & 0x030000FF
+    v = (v | (v << 8)) & 0x0300F00F
+    v = (v | (v << 4)) & 0x030C30C3
+    v = (v | (v << 2)) & 0x09249249
+    return v
+
+
 def partition_atoms_spatial(
     pos: torch.Tensor,
     num_ranks: int,
     num_iters: int = 10,
 ) -> torch.Tensor:
     """
-    Spatial partitioning via recursive coordinate bisection.
+    Spatial partitioning via Morton Z-order curve on GPU.

-    Recursively halves the atom set along the longest axis of each
-    sub-group until the desired number of partitions is reached.
-    Guarantees balanced partition sizes (differ by at most 1) and
-    runs in O(N log P) time with no iterative refinement.
+    Computes a 30-bit Morton code per atom by interleaving 10 bits
+    from each spatial axis. Sorting by Morton code groups spatially
+    nearby atoms together, minimizing boundary edges. Atoms are
+    then split into num_ranks equal chunks in sorted order.

-    Falls back to the longest-axis sort-and-split for non-power-of-2
-    ranks, which is O(N log N).
-
-    All computation is done on CPU to avoid GPU→CPU sync points
-    from .argmax().item() calls in the recursive bisection (up to
-    P-1 sync points eliminated). The position tensor is small
-    (N x 3 floats, ~768KB at 64k atoms), so the transfer is fast.
+    Runs entirely on GPU with zero CPU transfers or sync points
+    (unlike recursive coordinate bisection, which requires CPU
+    round-trips). The Morton curve provides O(N^{2/3}) surface
+    fraction per partition, similar to recursive bisection.

     Args:
         pos: Atom positions, shape (N, 3).
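For intuition, here is a minimal standalone sanity check of the bit-spread masks (a sketch in plain PyTorch; expand_bits_10 below simply mirrors the _expand_bits_10 body added above). Each input bit i must land at output position 3*i, which a brute-force reference confirms for all 1024 possible inputs:

```python
import torch

def expand_bits_10(v: torch.Tensor) -> torch.Tensor:
    # Same shift-and-mask sequence as _expand_bits_10 in the diff above.
    v = (v | (v << 16)) & 0x030000FF
    v = (v | (v << 8)) & 0x0300F00F
    v = (v | (v << 4)) & 0x030C30C3
    v = (v | (v << 2)) & 0x09249249
    return v

v = torch.arange(1024, dtype=torch.long)  # every valid 10-bit input
out = expand_bits_10(v)

# Brute-force reference: place bit i of the input at position 3*i.
ref = torch.zeros_like(v)
for i in range(10):
    ref |= ((v >> i) & 1) << (3 * i)

assert torch.equal(out, ref)
```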
@@ -168,19 +186,30 @@ def partition_atoms_spatial(
     if num_ranks >= N:
         return torch.arange(N, dtype=torch.long, device=device)

-    # Move to CPU to avoid GPU->CPU sync points in recursive bisection.
-    # Each recursion level calls .argmax().item() which forces a sync.
-    # For P=64, that's 63 sync points (~10-20us each = 0.6-1.2ms).
-    # CPU sort on 64k atoms takes <0.1ms, so this is a net win.
-    pos_cpu = pos.detach().cpu()
-    assignments = torch.zeros(N, dtype=torch.long)
-
-    # Recursive coordinate bisection: O(N log P), no iterations.
-    # Handles both power-of-2 and non-power-of-2 rank counts via
-    # uneven splits (right_half = num_parts - left_half).
-    _recursive_bisect(pos_cpu, assignments, torch.arange(N), 0, num_ranks)
+    # Normalize positions to [0, 1023] using a SINGLE global scale
+    # factor (the largest bounding-box extent). Per-dimension
+    # normalization would amplify noise in short dimensions, breaking
+    # Morton locality (e.g. a 100-unit x-gap becomes indistinguishable
+    # from a 2-unit y-gap after independent rescaling).
+    min_pos = pos.min(0)[0]
+    extent = (pos.max(0)[0] - min_pos).max().clamp(min=1e-8)
+    norm = ((pos - min_pos) / extent * 1023).long().clamp(0, 1023)
+
+    # 30-bit Morton Z-order code: interleave x, y, z bits
+    x, y, z = norm[:, 0], norm[:, 1], norm[:, 2]
+    morton = _expand_bits_10(x) | (_expand_bits_10(y) << 1) | (_expand_bits_10(z) << 2)
+
+    # Sort by Morton code and assign to ranks in equal chunks
+    _, sorted_indices = morton.sort()
+    chunk_size = (N + num_ranks - 1) // num_ranks
+    assignments = torch.empty(N, dtype=torch.long, device=device)
+    assignments[sorted_indices] = torch.div(
+        torch.arange(N, device=device),
+        chunk_size,
+        rounding_mode="floor",
+    ).clamp(max=num_ranks - 1)

-    return assignments.to(device)
+    return assignments


 def _recursive_bisect(
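A quick usage sketch of the new partitioner (assuming the package's src layout, so the module path matches the file tree above): when N divides evenly by num_ranks, every rank receives exactly N / num_ranks atoms, each covering a contiguous span of the Morton-sorted order.

```python
import torch
from fairchem.core.models.uma.graph_parallel import partition_atoms_spatial

# 10k synthetic atom positions across 8 ranks. Works on CPU or GPU
# tensors; the point of the commit is that no device transfer happens.
pos = torch.rand(10_000, 3)
assignments = partition_atoms_spatial(pos, num_ranks=8)

counts = torch.bincount(assignments, minlength=8)
print(counts)  # N=10_000, P=8 divides evenly: 1250 atoms per rank
```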
@@ -358,13 +387,16 @@ def build_gp_context(
     tgt_is_local_edge = edge_index_local[1] < total_local_atoms
     local_edge_mask = src_is_local & tgt_is_local_edge

-    # Batch scalar extractions to minimize GPU→CPU sync points:
-    # move counts to CPU in one transfer instead of per-element syncs.
-    send_recv_cpu = torch.stack([send_counts, recv_counts]).cpu()
-    send_splits = send_recv_cpu[0].tolist()
-    recv_splits = send_recv_cpu[1].tolist()
+    # Batch ALL GPU→CPU scalar extractions into a single transfer.
+    # Stacking send_counts, recv_counts, and local_edge_count into
+    # one tensor avoids 3 separate sync points (each .cpu()/.item()
+    # is a GPU→CPU sync).
+    local_edge_count = local_edge_mask.sum().unsqueeze(0).to(torch.long)
+    all_cpu = torch.cat([send_counts, recv_counts, local_edge_count]).cpu()
+    send_splits = all_cpu[:world_size].tolist()
+    recv_splits = all_cpu[world_size : 2 * world_size].tolist()
     total_recv = sum(recv_splits)
-    num_local_edges = int(local_edge_mask.sum().cpu().item())
+    num_local_edges = int(all_cpu[-1].item())
     num_boundary_edges = edge_index_local.shape[1] - num_local_edges

     return GPContext(
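The batching pattern generalizes beyond this function: any group of scalars the host needs can be packed into one tensor and moved with a single .cpu() call. A self-contained sketch with stand-in tensors (world_size, send_counts, recv_counts, and local_edge_mask here are hypothetical placeholders for the real ones in build_gp_context):

```python
import torch

world_size = 4
send_counts = torch.randint(1, 10, (world_size,))  # stand-ins for the
recv_counts = torch.randint(1, 10, (world_size,))  # real GPU tensors
local_edge_mask = torch.rand(1_000) > 0.5

# Before: three separate host transfers, i.e. three sync points.
# send_splits = send_counts.cpu().tolist()
# recv_splits = recv_counts.cpu().tolist()
# num_local_edges = int(local_edge_mask.sum().cpu().item())

# After: pack everything, pay for one transfer, slice on the host.
local_edge_count = local_edge_mask.sum().unsqueeze(0).to(torch.long)
all_cpu = torch.cat([send_counts, recv_counts, local_edge_count]).cpu()
send_splits = all_cpu[:world_size].tolist()
recv_splits = all_cpu[world_size : 2 * world_size].tolist()
num_local_edges = int(all_cpu[-1].item())
```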
