
Commit 28d9bd3

Author: rgao
Remove unused overlap/P2P fields and parameters

Clean up dead code left over from the overlap removal:

- Remove the local_edge_mask, num_local_edges, num_boundary_edges, and
  edge_reorder fields from the GPContext dataclass
- Remove the edge classification computation from build_gp_context()
- Remove the use_overlap_gp and use_p2p_gp parameters from the
  eSCNMDBlock and Edgewise constructors
- Remove the TestEdgeClassification test class (4 tests for removed fields)
- Simplify the batched GPU→CPU transfer (2 fewer scalars)
1 parent: dba967b

4 files changed: 7 additions & 129 deletions

src/fairchem/core/models/uma/escn_md.py (0 additions & 6 deletions)

@@ -325,8 +325,6 @@ def __init__(
         execution_mode: str = "general",
         use_all_to_all_gp: bool = False,
         gp_partition_strategy: str = "index_split",
-        use_overlap_gp: bool = False,
-        use_p2p_gp: bool = False,
     ) -> None:
         super().__init__()
         self.max_num_elements = max_num_elements
@@ -379,8 +377,6 @@ def __init__(
         )
         self.edge_chunk_size = edge_chunk_size
         self.use_all_to_all_gp = use_all_to_all_gp
-        self.use_overlap_gp = use_overlap_gp
-        self.use_p2p_gp = use_p2p_gp
         self.gp_partition_strategy = PartitionStrategy(gp_partition_strategy)

         self.backend = get_execution_backend(execution_mode)
@@ -514,8 +510,6 @@ def __init__(
                 self.ff_type,
                 activation_checkpoint_chunk_size=activation_checkpoint_chunk_size,
                 backend=self.backend,
-                use_overlap_gp=self.use_overlap_gp,
-                use_p2p_gp=self.use_p2p_gp,
             )
             self.blocks.append(block)
src/fairchem/core/models/uma/escn_md_block.py (0 additions & 8 deletions)

@@ -59,8 +59,6 @@ def __init__(
         activation_checkpoint_chunk_size: int | None,
         backend: ExecutionBackend,
         act_type: Literal["gate", "s2"] = "gate",
-        use_overlap_gp: bool = False,
-        use_p2p_gp: bool = False,
     ):
         super().__init__()

@@ -69,8 +67,6 @@ def __init__(
         self.lmax = lmax
         self.mmax = mmax
         self.activation_checkpoint_chunk_size = activation_checkpoint_chunk_size
-        self.use_overlap_gp = use_overlap_gp
-        self.use_p2p_gp = use_p2p_gp
         self.backend = backend

         self.mappingReduced = mappingReduced
@@ -347,8 +343,6 @@ def __init__(
         ff_type: Literal["spectral", "grid"],
         activation_checkpoint_chunk_size: int | None,
         backend: ExecutionBackend,
-        use_overlap_gp: bool = False,
-        use_p2p_gp: bool = False,
     ) -> None:
         super().__init__()
         self.sphere_channels = sphere_channels
@@ -372,8 +366,6 @@ def __init__(
             act_type=act_type,
             activation_checkpoint_chunk_size=activation_checkpoint_chunk_size,
             backend=backend,
-            use_overlap_gp=use_overlap_gp,
-            use_p2p_gp=use_p2p_gp,
         )

         self.norm_2 = get_normalization_layer(
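One caller-facing consequence of dropping these keywords: configs that still pass use_overlap_gp or use_p2p_gp will now fail fast at construction instead of being silently ignored. A self-contained sketch with a toy class standing in for the real constructors (whose full signatures are not shown in this diff):

class Toy:
    def __init__(self, lmax: int, mmax: int) -> None:  # removed kwargs are gone
        self.lmax, self.mmax = lmax, mmax

stale_cfg = {"lmax": 4, "mmax": 2, "use_overlap_gp": True}
try:
    Toy(**stale_cfg)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'use_overlap_gp'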

src/fairchem/core/models/uma/graph_parallel.py (7 additions & 44 deletions)

@@ -100,17 +100,6 @@ class GPContext:
         edge_index_local: Precomputed edge index remapped to local indices.
             None if not yet computed (set by build_gp_context when edge_index
             is provided).
-        local_edge_mask: Boolean mask identifying fully-local edges (both src
-            and tgt are local atoms). Used for comm-compute overlap. Shape:
-            (num_edges,). None if not yet computed.
-        num_local_edges: Number of fully-local edges (precomputed from
-            local_edge_mask). None if not yet computed.
-        num_boundary_edges: Number of boundary edges (src is remote). None
-            if not yet computed.
-        edge_reorder: Permutation that sorts edges so local edges come first,
-            then boundary edges. Shape: (num_edges,). Applied once to per-edge
-            tensors (wigner, x_edge, etc.) in the backbone forward. Enables
-            compile-friendly overlap via split() instead of boolean indexing.
     """

     rank: int
@@ -126,10 +115,6 @@ class GPContext:
     total_needed_atoms: int
     send_indices: torch.Tensor | None = None
     edge_index_local: torch.Tensor | None = None
-    local_edge_mask: torch.Tensor | None = None
-    num_local_edges: int | None = None
-    num_boundary_edges: int | None = None
-    edge_reorder: torch.Tensor | None = None
     # Precomputed Python lists to avoid repeated .tolist() in AllToAllCollect
     send_splits: list[int] | None = None
     recv_splits: list[int] | None = None
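Why send_splits/recv_splits are cached as Python lists: torch.distributed.all_to_all_single takes its split sizes as plain ints, so deriving them from GPU tensors inside every layer's forward would cost a .tolist() sync each time. A hedged sketch of the pattern (assumes an initialized process group; the exchange helper is illustrative, not the AllToAllCollect implementation):

import torch
import torch.distributed as dist

def exchange(x: torch.Tensor, send_splits: list[int],
             recv_splits: list[int], group=None) -> torch.Tensor:
    # Split sizes are already Python ints: no GPU->CPU sync happens here.
    out = x.new_empty((sum(recv_splits), *x.shape[1:]))
    dist.all_to_all_single(
        out, x,
        output_split_sizes=recv_splits,
        input_split_sizes=send_splits,
        group=group,
    )
    return out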
@@ -273,7 +258,7 @@ def build_gp_context(

     Args:
         edge_index: Edge index filtered to edges whose targets are in
-            this rank's partition, shape (2, num_local_edges).
+            this rank's partition, shape (2, num_edges).
             Row 0 = source, row 1 = target.
         rank_assignments: Rank assignment for each atom, shape (total_atoms,).
         rank: This rank's GP rank.
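A small worked example of this contract (mirroring the data used by the removed tests below): with rank_assignments [0, 0, 0, 1, 1, 1], rank 0 owns atoms 0-2, so every target (row 1) must be local while sources (row 0) may be remote:

import torch

rank_assignments = torch.tensor([0, 0, 0, 1, 1, 1])
edge_index = torch.tensor([[0, 1, 3, 4],   # sources: atoms 3, 4 are remote
                           [1, 2, 0, 1]])  # targets: all owned by rank 0
assert (rank_assignments[edge_index[1]] == 0).all()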
@@ -408,23 +393,9 @@ def build_gp_context(
     # Precompute edge_index_local
     edge_index_local = global_to_local[edge_index]

-    # Classify edges: fully-local (both endpoints local) vs boundary
-    # (source is remote). Used for communication-computation overlap.
-    src_is_local = edge_index_local[0] < total_local_atoms
-    tgt_is_local_edge = edge_index_local[1] < total_local_atoms
-    local_edge_mask = src_is_local & tgt_is_local_edge
-
-    # Pre-compute edge reorder permutation: local edges first, boundary
-    # edges last. This enables compile-friendly overlap via split()
-    # instead of boolean indexing. The reorder is applied in the
-    # backbone forward to all per-edge tensors simultaneously.
-    edge_reorder = torch.argsort((~local_edge_mask).to(torch.int32), stable=True)
-
     # Batch ALL GPU→CPU scalar extractions into a single transfer.
-    # This batches send_counts, recv_counts, local_edge_count, AND
-    # validation scalars into ONE .cpu() call, eliminating 2 extra
-    # GPU→CPU syncs from separate .all()/.any() validation checks.
-    local_edge_count = local_edge_mask.sum().unsqueeze(0).to(torch.long)
+    # This batches send_counts, recv_counts, AND validation scalars
+    # into ONE .cpu() call, eliminating extra GPU→CPU syncs.
     bad_edge_count = (edge_index_local < 0).sum().unsqueeze(0).to(torch.long)
     send_valid = (
         torch.ones(1, dtype=torch.long, device=device)
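For context, the removed edge_reorder trick in isolation: a stable argsort on the inverted mask keys local edges (0) before boundary edges (1) while preserving relative order, so one gather plus split() replaces per-tensor boolean indexing. A standalone sketch, not fairchem code:

import torch

local_edge_mask = torch.tensor([True, False, True, False, True])
edge_reorder = torch.argsort((~local_edge_mask).to(torch.int32), stable=True)
# edge_reorder == tensor([0, 2, 4, 1, 3]): local edges first, boundary last

wigner = torch.randn(5, 8)              # stand-in for a per-edge tensor
num_local = int(local_edge_mask.sum())  # 3
local_part, boundary_part = wigner[edge_reorder].split(
    [num_local, wigner.shape[0] - num_local]
)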
@@ -436,16 +407,12 @@ def build_gp_context(
             .to(torch.long)
         )
     )
-    all_cpu = torch.cat(
-        [send_counts, recv_counts, local_edge_count, bad_edge_count, send_valid]
-    ).cpu()
+    all_cpu = torch.cat([send_counts, recv_counts, bad_edge_count, send_valid]).cpu()
     send_splits = all_cpu[:world_size].tolist()
     recv_splits = all_cpu[world_size : 2 * world_size].tolist()
     total_recv = sum(recv_splits)
-    num_local_edges = int(all_cpu[2 * world_size].item())
-    num_boundary_edges = edge_index_local.shape[1] - num_local_edges
-    n_bad = int(all_cpu[2 * world_size + 1].item())
-    send_ok = int(all_cpu[2 * world_size + 2].item())
+    n_bad = int(all_cpu[2 * world_size].item())
+    send_ok = int(all_cpu[2 * world_size + 1].item())

     # Validate AFTER the batched CPU transfer (no extra GPU syncs).
     if not send_ok:
529496
total_needed_atoms=total_needed_atoms,
530497
send_indices=send_indices,
531498
edge_index_local=edge_index_local,
532-
local_edge_mask=local_edge_mask,
533-
num_local_edges=num_local_edges,
534-
num_boundary_edges=num_boundary_edges,
535-
edge_reorder=edge_reorder,
536499
# Precompute Python lists once (avoids .tolist() per layer per forward)
537500
send_splits=send_splits,
538501
recv_splits=recv_splits,
539502
total_recv=total_recv,
540-
# Precompute sparse neighbor lists for P2P communication
503+
# Precompute sparse neighbor lists for communication
541504
send_neighbors=[
542505
r for r in range(world_size) if send_splits[r] > 0 and r != rank
543506
],

tests/core/models/uma/test_graph_parallel.py (0 additions & 71 deletions)

@@ -568,77 +568,6 @@ def test_a2a_spatial_partition():
     )


-class TestEdgeClassification:
-    """
-    Tests for local_edge_mask precomputation in GPContext.
-    """
-
-    def test_edge_mask_types(self):
-        """
-        Verify that local_edge_mask is computed and has correct type/shape.
-        """
-        # 6 atoms, 2 ranks, edges cross the partition boundary.
-        # build_gp_context expects edges pre-filtered to targets in this
-        # rank's partition (atoms 0, 1, 2 for rank 0).
-        rank_assignments = torch.tensor([0, 0, 0, 1, 1, 1])
-        edge_index = torch.tensor(
-            [
-                [0, 1, 2, 3],
-                [1, 2, 0, 0],
-            ]
-        )
-        ctx = build_gp_context(edge_index, rank_assignments, rank=0, world_size=2)
-        assert ctx.local_edge_mask is not None
-        assert ctx.local_edge_mask.dtype == torch.bool
-        assert ctx.local_edge_mask.shape[0] == edge_index.shape[1]
-        assert ctx.num_local_edges is not None
-        assert ctx.num_boundary_edges is not None
-        assert ctx.num_local_edges + ctx.num_boundary_edges == edge_index.shape[1]
-
-    def test_all_local_edges(self):
-        """
-        When all edges are within the local partition, all should be local.
-        """
-        rank_assignments = torch.tensor([0, 0, 0, 1, 1, 1])
-        # All edges within rank 0's partition
-        edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
-        ctx = build_gp_context(edge_index, rank_assignments, rank=0, world_size=2)
-        assert ctx.num_local_edges == 3
-        assert ctx.num_boundary_edges == 0
-        assert ctx.local_edge_mask.all()
-
-    def test_all_boundary_edges(self):
-        """
-        When all edges have remote sources, all should be boundary.
-        """
-        rank_assignments = torch.tensor([0, 0, 0, 1, 1, 1])
-        # All edges from rank 1 atoms to rank 0 atoms
-        edge_index = torch.tensor([[3, 4, 5], [0, 1, 2]])
-        ctx = build_gp_context(edge_index, rank_assignments, rank=0, world_size=2)
-        assert ctx.num_local_edges == 0
-        assert ctx.num_boundary_edges == 3
-        assert not ctx.local_edge_mask.any()
-
-    def test_mixed_edges(self):
-        """
-        Verify correct classification of mixed local and boundary edges.
-        """
-        rank_assignments = torch.tensor([0, 0, 0, 1, 1, 1])
-        # 4 edges: 2 local (0->1, 1->2), 2 boundary (3->0, 4->1)
-        edge_index = torch.tensor(
-            [
-                [0, 1, 3, 4],
-                [1, 2, 0, 1],
-            ]
-        )
-        ctx = build_gp_context(edge_index, rank_assignments, rank=0, world_size=2)
-        assert ctx.num_local_edges == 2
-        assert ctx.num_boundary_edges == 2
-        # First 2 edges are local, last 2 are boundary
-        expected_mask = torch.tensor([True, True, False, False])
-        assert torch.equal(ctx.local_edge_mask, expected_mask)
-
-
 # =========================================================================
 # Distributed tests: send_info optimization correctness
 # =========================================================================
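Note that the removed classification stays one line away if it is ever needed again: it can be derived from the surviving edge_index_local field, exactly as the deleted build_gp_context code did. A hypothetical helper (num_local_atoms is assumed to be available to the caller, as total_local_atoms was inside build_gp_context):

import torch

def local_edge_mask(edge_index_local: torch.Tensor,
                    num_local_atoms: int) -> torch.Tensor:
    # Local atoms occupy indices [0, num_local_atoms); an edge is fully
    # local when both endpoints fall in that range.
    return (edge_index_local[0] < num_local_atoms) & (
        edge_index_local[1] < num_local_atoms
    )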
