
Commit cfeb88c

Remove GPContext caching, optimize spatial partition + collectives
- Remove cache_gp_context flag, _cached_gp_ctx, clear_gp_cache(): caching is invalid for MD, where graphs change every step
- Replace k-means with recursive coordinate bisection: O(N log P) with no iterations, no load-balancing pass, deterministic
- Replace list-based _safe_all_to_all for count exchange with direct all_to_all_single (NCCL) / all_gather (Gloo fallback)
- Vectorize _compute_send_indices: eliminate the Python loop over world_size, use sort + all_to_all_single instead
- Remove 3 caching tests (53 tests remain)
Parent: 133f33b

4 files changed: 154 additions & 308 deletions
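
The commit message's count-exchange change swaps a list-based helper for direct collectives. A minimal sketch of that pattern, assuming a generic `exchange_send_counts` helper (hypothetical name; the actual replacement for `_safe_all_to_all` in the repo may differ):

```python
import torch
import torch.distributed as dist

def exchange_send_counts(send_counts: torch.Tensor, group=None) -> torch.Tensor:
    """Tell every rank how many items to expect from each peer.

    send_counts: int64 tensor of shape (world_size,); entry r is the
    number of items this rank will send to rank r. Returns recv_counts
    of the same shape.
    """
    world_size = dist.get_world_size(group)
    if dist.get_backend(group) == "nccl":
        # NCCL supports all_to_all_single directly: one count per rank.
        recv_counts = torch.empty_like(send_counts)
        dist.all_to_all_single(recv_counts, send_counts, group=group)
    else:
        # Gloo fallback: no all_to_all, so gather every rank's full
        # count vector and keep the column addressed to this rank.
        gathered = [torch.empty_like(send_counts) for _ in range(world_size)]
        dist.all_gather(gathered, send_counts, group=group)
        recv_counts = torch.stack(gathered)[:, dist.get_rank(group)]
    return recv_counts
```

The resulting `recv_counts` are what `all_to_all_single` later needs as `output_split_sizes` when exchanging the actual payloads.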


CLAUDE.md

Lines changed: 0 additions & 3 deletions
@@ -287,9 +287,6 @@ To pass backbone config overrides (e.g., enabling all-to-all GP) through Inferen
 ### fairchem CLI cannot submit from within SLURM
 The `_cli.py` explicitly blocks SLURM submission from within an active SLURM job (`assert os.getenv("SLURM_SUBMIT_HOST") is None`). Always run `fairchem -c ... job=slurm` from a login node, not from an `srun` session. This also means profai-cli's `launch-experiment --qos` (which creates its own SLURM job) cannot be used to wrap `fairchem -c ... job=slurm` — it would create a double submission.
 
-### GPContext caching: cache_gp_context flag
-When using `use_all_to_all_gp=true`, set `cache_gp_context=true` to cache the GPContext after the first forward pass. This eliminates the per-forward overhead of k-means clustering, build_gp_context (2 all-to-all calls), and _compute_send_indices (1 all-to-all call). The cache auto-invalidates when the atom count changes. For MD where positions change, call `backbone.clear_gp_cache()` when the neighbor list changes, or implement lazy rebuild with displacement tracking.
-
 ### AllToAllCollect backward must match forward arg count
 The `AllToAllCollect.forward()` uses `torch.autograd.Function.apply()`. Every argument passed to `apply()` must have a corresponding `None` gradient returned in `backward()`. If you add new arguments to `forward()`, you MUST also add corresponding `None`s to the backward return tuple, or autograd will error: "returned an incorrect number of gradients".
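
The retained note above is easy to get wrong in practice. A minimal standalone illustration of the rule, using a toy `ScaleAndTag` function (not the real `AllToAllCollect`):

```python
import torch

class ScaleAndTag(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, tag):  # 3 args after ctx
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # Must return exactly 3 gradients, one per forward arg:
        # grad for x, None for scale, None for tag. Dropping a None
        # raises "returned an incorrect number of gradients".
        return grad_out * ctx.scale, None, None

y = ScaleAndTag.apply(torch.randn(4, requires_grad=True), 2.0, "debug")
y.sum().backward()
```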

src/fairchem/core/models/uma/escn_md.py

Lines changed: 22 additions & 65 deletions
@@ -325,7 +325,6 @@ def __init__(
         use_all_to_all_gp: bool = False,
         gp_partition_strategy: str = "index_split",
         use_overlap_gp: bool = False,
-        cache_gp_context: bool = False,
     ) -> None:
         super().__init__()
         self.max_num_elements = max_num_elements
@@ -380,11 +379,6 @@ def __init__(
         self.use_all_to_all_gp = use_all_to_all_gp
         self.use_overlap_gp = use_overlap_gp
         self.gp_partition_strategy = PartitionStrategy(gp_partition_strategy)
-        self.cache_gp_context = cache_gp_context
-        # Cached GPContext for repeated inference on the same structure.
-        # Only used when cache_gp_context=True and use_all_to_all_gp=True.
-        self._cached_gp_ctx: GPContext | None = None
-        self._cached_rank_assignments: torch.Tensor | None = None
 
         self.backend = get_execution_backend(execution_mode)
 
@@ -608,59 +602,32 @@ def csd_embedding(self, charge, spin, dataset):
         )
         return torch.nn.SiLU()(self.mix_csd(torch.cat((chg_emb, spin_emb), dim=1)))
 
-    def clear_gp_cache(self):
-        """
-        Invalidate the cached GPContext.
-
-        Call this when the input structure changes (different atom count,
-        positions, or cell) to force recomputation on the next forward pass.
-        The cache is automatically invalidated when the atom count changes.
-        """
-        self._cached_gp_ctx = None
-        self._cached_rank_assignments = None
-
     def _generate_graph(self, data_dict):
         data_dict["gp_node_offset"] = 0
         node_partition = None
         rank_assignments = None
-        _use_cached = False
         if gp_utils.initialized():
             # create the partitions
             atomic_numbers_full = data_dict["atomic_numbers_full"]
 
             if self.use_all_to_all_gp:
-                # Check if we can reuse a cached GPContext.
-                # Valid when cache_gp_context=True and the atom count
-                # matches (same system being inferred repeatedly).
-                _use_cached = (
-                    self.cache_gp_context
-                    and self._cached_gp_ctx is not None
-                    and self._cached_rank_assignments is not None
-                    and len(atomic_numbers_full)
-                    == self._cached_rank_assignments.shape[0]
-                )
-
-                if _use_cached:
-                    # Reuse cached partition (skip k-means / index-split)
-                    rank_assignments = self._cached_rank_assignments
+                # All-to-all: compute rank_assignments FIRST, then derive
+                # node_partition from them. This ensures the
+                # graph-generation partition and the GPContext partition
+                # are identical, avoiding index mismatches that cause
+                # OOB crashes.
+                total_atoms = len(atomic_numbers_full)
+                device = atomic_numbers_full.device
+                world_size = gp_utils.get_gp_world_size()
+
+                if self.gp_partition_strategy == PartitionStrategy.SPATIAL:
+                    rank_assignments = partition_atoms_spatial(
+                        data_dict["pos"], world_size
+                    )
                 else:
-                    # All-to-all: compute rank_assignments FIRST, then derive
-                    # node_partition from them. This ensures the
-                    # graph-generation partition and the GPContext partition
-                    # are identical, avoiding index mismatches that cause
-                    # OOB crashes.
-                    total_atoms = len(atomic_numbers_full)
-                    device = atomic_numbers_full.device
-                    world_size = gp_utils.get_gp_world_size()
-
-                    if self.gp_partition_strategy == PartitionStrategy.SPATIAL:
-                        rank_assignments = partition_atoms_spatial(
-                            data_dict["pos"], world_size
-                        )
-                    else:
-                        rank_assignments = partition_atoms_index_split(
-                            total_atoms, world_size, device
-                        )
+                    rank_assignments = partition_atoms_index_split(
+                        total_atoms, world_size, device
+                    )
 
             node_partition = (rank_assignments == gp_utils.get_gp_rank()).nonzero(
                 as_tuple=True
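
The SPATIAL branch above calls `partition_atoms_spatial`, which per the commit message now uses recursive coordinate bisection instead of k-means. A hedged sketch of that scheme, assuming the same signature; the repo's actual implementation may differ in tie-breaking and balancing details:

```python
import torch

def partition_atoms_spatial(pos: torch.Tensor, world_size: int) -> torch.Tensor:
    """Assign each atom to a rank by recursive coordinate bisection.

    Deterministic and iteration-free: one sort per tree level, no
    k-means refinement and no separate load-balancing pass.
    """
    assignments = torch.empty(pos.shape[0], dtype=torch.long, device=pos.device)

    def recurse(idx: torch.Tensor, lo: int, hi: int) -> None:
        if hi - lo == 1:
            assignments[idx] = lo
            return
        # Cut along the axis with the largest spatial extent.
        sub = pos[idx]
        axis = int((sub.max(dim=0).values - sub.min(dim=0).values).argmax())
        order = sub[:, axis].argsort()
        # Give the left subtree atoms in proportion to its rank count,
        # so non-power-of-two world sizes stay balanced.
        left_ranks = (hi - lo) // 2
        cut = idx.numel() * left_ranks // (hi - lo)
        recurse(idx[order[:cut]], lo, lo + left_ranks)
        recurse(idx[order[cut:]], lo + left_ranks, hi)

    recurse(torch.arange(pos.shape[0], device=pos.device), 0, world_size)
    return assignments
```

For example, `partition_atoms_spatial(torch.rand(1000, 3), 4)` yields four contiguous spatial blocks of 250 atoms each, and the same positions always produce the same assignment.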
@@ -748,22 +715,12 @@ def _generate_graph(self, data_dict):
 
         # Build GPContext for all-to-all communication
         if self.use_all_to_all_gp:
-            if _use_cached:
-                # Reuse cached GPContext (skip build_gp_context +
-                # _compute_send_indices — saves 2 all-to-all calls
-                # and all CPU-side index computation per forward).
-                gp_ctx = self._cached_gp_ctx
-            else:
-                gp_ctx = build_gp_context(
-                    edge_index=graph_dict["edge_index"],
-                    rank_assignments=rank_assignments,
-                    rank=gp_utils.get_gp_rank(),
-                    world_size=gp_utils.get_gp_world_size(),
-                )
-                # Cache for reuse on subsequent calls
-                if self.cache_gp_context:
-                    self._cached_gp_ctx = gp_ctx
-                    self._cached_rank_assignments = rank_assignments
+            gp_ctx = build_gp_context(
+                edge_index=graph_dict["edge_index"],
+                rank_assignments=rank_assignments,
+                rank=gp_utils.get_gp_rank(),
+                world_size=gp_utils.get_gp_world_size(),
+            )
         data_dict["gp_ctx"] = gp_ctx
         # All-to-all uses local indices via gp_ctx.edge_index_local,
         # so node_offset is always 0.