@@ -129,26 +129,44 @@ class GPContext:
129129 total_recv : int | None = None
130130
131131
132+ def _expand_bits_10 (v : torch .Tensor ) -> torch .Tensor :
133+ """
134+ Expand a 10-bit integer so each bit is spaced by 2 zero bits.
135+
136+ Maps bit at position i to position 3*i, producing a 30-bit
137+ output suitable for interleaving with two other axes to form
138+ a Morton Z-order code.
139+
140+ Args:
141+ v: Integer tensor with values in [0, 1023].
142+
143+ Returns:
144+ Tensor with bits expanded (each input bit at position 3*i).
145+ """
146+ v = (v | (v << 16 )) & 0x030000FF
147+ v = (v | (v << 8 )) & 0x0300F00F
148+ v = (v | (v << 4 )) & 0x030C30C3
149+ v = (v | (v << 2 )) & 0x09249249
150+ return v
151+
152+
132153def partition_atoms_spatial (
133154 pos : torch .Tensor ,
134155 num_ranks : int ,
135156 num_iters : int = 10 ,
136157) -> torch .Tensor :
137158 """
138- Spatial partitioning via recursive coordinate bisection .
159+ Spatial partitioning via Morton Z-order curve on GPU .
139160
140- Recursively halves the atom set along the longest axis of each
141- sub-group until the desired number of partitions is reached.
142- Guarantees balanced partition sizes (differ by at most 1) and
143- runs in O(N log P) time with no iterative refinement .
161+ Computes a 30-bit Morton code per atom by interleaving 10 bits
162+ from each spatial axis. Sorting by Morton code groups spatially
163+ nearby atoms together, minimizing boundary edges. Atoms are
164+ then split into num_ranks equal chunks in sorted order .
144165
145- Falls back to the longest-axis sort-and-split for non-power-of-2
146- ranks, which is O(N log N).
147-
148- All computation is done on CPU to avoid GPU→CPU sync points
149- from .argmax().item() calls in the recursive bisection (up to
150- P-1 sync points eliminated). The position tensor is small
151- (N x 3 floats, ~768KB at 64k atoms), so the transfer is fast.
166+ Runs entirely on GPU with zero CPU transfers or sync points
167+ (unlike recursive coordinate bisection which requires CPU
168+ round-trips). The Morton curve provides O(N^{2/3}) surface
169+ fraction per partition, similar to recursive bisection.
152170
153171 Args:
154172 pos: Atom positions, shape (N, 3).
@@ -168,19 +186,30 @@ def partition_atoms_spatial(
168186 if num_ranks >= N :
169187 return torch .arange (N , dtype = torch .long , device = device )
170188
171- # Move to CPU to avoid GPU->CPU sync points in recursive bisection.
172- # Each recursion level calls .argmax().item() which forces a sync.
173- # For P=64, that's 63 sync points (~10-20us each = 0.6-1.2ms).
174- # CPU sort on 64k atoms takes <0.1ms, so this is a net win.
175- pos_cpu = pos .detach ().cpu ()
176- assignments = torch .zeros (N , dtype = torch .long )
177-
178- # Recursive coordinate bisection: O(N log P), no iterations.
179- # Handles both power-of-2 and non-power-of-2 rank counts via
180- # uneven splits (right_half = num_parts - left_half).
181- _recursive_bisect (pos_cpu , assignments , torch .arange (N ), 0 , num_ranks )
189+ # Normalize positions to [0, 1023] using a SINGLE global scale
190+ # factor (the largest bounding-box extent). Per-dimension
191+ # normalization would amplify noise in short dimensions, breaking
192+ # Morton locality (e.g. a 100-unit x-gap becomes indistinguishable
193+ # from a 2-unit y-gap after independent rescaling).
194+ min_pos = pos .min (0 )[0 ]
195+ extent = (pos .max (0 )[0 ] - min_pos ).max ().clamp (min = 1e-8 )
196+ norm = ((pos - min_pos ) / extent * 1023 ).long ().clamp (0 , 1023 )
197+
198+ # 30-bit Morton Z-order code: interleave x, y, z bits
199+ x , y , z = norm [:, 0 ], norm [:, 1 ], norm [:, 2 ]
200+ morton = _expand_bits_10 (x ) | (_expand_bits_10 (y ) << 1 ) | (_expand_bits_10 (z ) << 2 )
201+
202+ # Sort by Morton code and assign to ranks in equal chunks
203+ _ , sorted_indices = morton .sort ()
204+ chunk_size = (N + num_ranks - 1 ) // num_ranks
205+ assignments = torch .empty (N , dtype = torch .long , device = device )
206+ assignments [sorted_indices ] = torch .div (
207+ torch .arange (N , device = device ),
208+ chunk_size ,
209+ rounding_mode = "floor" ,
210+ ).clamp (max = num_ranks - 1 )
182211
183- return assignments . to ( device )
212+ return assignments
184213
185214
186215def _recursive_bisect (
@@ -358,13 +387,16 @@ def build_gp_context(
358387 tgt_is_local_edge = edge_index_local [1 ] < total_local_atoms
359388 local_edge_mask = src_is_local & tgt_is_local_edge
360389
361- # Batch scalar extractions to minimize GPU→CPU sync points:
362- # move counts to CPU in one transfer instead of per-element syncs.
363- send_recv_cpu = torch .stack ([send_counts , recv_counts ]).cpu ()
364- send_splits = send_recv_cpu [0 ].tolist ()
365- recv_splits = send_recv_cpu [1 ].tolist ()
390+ # Batch ALL GPU→CPU scalar extractions into a single transfer.
391+ # Stacking send_counts, recv_counts, and local_edge_count into
392+ # one tensor avoids 3 separate sync points (each .cpu()/.item()
393+ # is a GPU→CPU sync).
394+ local_edge_count = local_edge_mask .sum ().unsqueeze (0 ).to (torch .long )
395+ all_cpu = torch .cat ([send_counts , recv_counts , local_edge_count ]).cpu ()
396+ send_splits = all_cpu [:world_size ].tolist ()
397+ recv_splits = all_cpu [world_size : 2 * world_size ].tolist ()
366398 total_recv = sum (recv_splits )
367- num_local_edges = int (local_edge_mask . sum (). cpu () .item ())
399+ num_local_edges = int (all_cpu [ - 1 ] .item ())
368400 num_boundary_edges = edge_index_local .shape [1 ] - num_local_edges
369401
370402 return GPContext (