Skip to content

Commit aa60e1b

Browse files
committed
Fix BL compile regression: move @torch.compiler.disable to A2A-only code
The @torch.compiler.disable decorator on _generate_graph was wrapping the entire function, including the BL (all-gather) code path that is identical to main. This caused a graph break that made torch.compile 12x slower at 64 GPUs (92 min vs ~8 min on main). Fix: Extract A2A partitioning (spatial/index_split) into a new _compute_a2a_partition() static method with @torch.compiler.disable, leaving _generate_graph fully compilable for the BL path. Verified at 1-GPU and 8-GPU: compile time and inference performance match main branch exactly.
1 parent d7033ab commit aa60e1b

1 file changed

Lines changed: 48 additions & 22 deletions

File tree

src/fairchem/core/models/uma/escn_md.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,33 @@ def csd_embedding(self, charge, spin, dataset):
602602
)
603603
return torch.nn.SiLU()(self.mix_csd(torch.cat((chg_emb, spin_emb), dim=1)))
604604

605+
@staticmethod
@torch.compiler.disable
def _compute_a2a_partition(
    pos: torch.Tensor,
    total_atoms: int,
    device: torch.device,
    world_size: int,
    rank: int,
    strategy: PartitionStrategy,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute per-atom A2A rank assignments and this rank's node partition.

    Kept separate from _generate_graph so that only this A2A-specific
    partitioning carries @torch.compiler.disable; the BL (all-gather)
    path in _generate_graph remains fully compilable.

    Args:
        pos: Atom positions; consulted only for the SPATIAL strategy.
        total_atoms: Total atom count across the graph-parallel group.
        device: Device on which index-split assignments are allocated.
        world_size: Number of graph-parallel ranks.
        rank: This process's graph-parallel rank.
        strategy: Which PartitionStrategy to apply.

    Returns:
        A (rank_assignments, node_partition) pair: the per-atom rank
        assignment tensor, and the indices of atoms assigned to `rank`.
    """
    with record_function("a2a_partition"):
        # Dispatch on strategy: spatial clustering vs. consecutive index split.
        if strategy == PartitionStrategy.SPATIAL:
            assignments = partition_atoms_spatial(pos, world_size)
        else:
            assignments = partition_atoms_index_split(
                total_atoms, world_size, device
            )
        # Indices of the atoms this rank owns (assignments is 1-D, so
        # nonzero(as_tuple=True) yields a single index tensor).
        (local_nodes,) = (assignments == rank).nonzero(as_tuple=True)
    return assignments, local_nodes
631+
605632
def _generate_graph(self, data_dict):
606633
data_dict["gp_node_offset"] = 0
607634
node_partition = None
@@ -616,22 +643,14 @@ def _generate_graph(self, data_dict):
616643
# graph-generation partition and the GPContext partition
617644
# are identical, avoiding index mismatches that cause
618645
# OOB crashes.
619-
total_atoms = len(atomic_numbers_full)
620-
device = atomic_numbers_full.device
621-
world_size = gp_utils.get_gp_world_size()
622-
623-
if self.gp_partition_strategy == PartitionStrategy.SPATIAL:
624-
rank_assignments = partition_atoms_spatial(
625-
data_dict["pos"], world_size
626-
)
627-
else:
628-
rank_assignments = partition_atoms_index_split(
629-
total_atoms, world_size, device
630-
)
631-
632-
node_partition = (rank_assignments == gp_utils.get_gp_rank()).nonzero(
633-
as_tuple=True
634-
)[0]
646+
rank_assignments, node_partition = self._compute_a2a_partition(
647+
pos=data_dict["pos"],
648+
total_atoms=len(atomic_numbers_full),
649+
device=atomic_numbers_full.device,
650+
world_size=gp_utils.get_gp_world_size(),
651+
rank=gp_utils.get_gp_rank(),
652+
strategy=self.gp_partition_strategy,
653+
)
635654
else:
636655
# Legacy all-gather: use consecutive index split
637656
node_partition = torch.tensor_split(
@@ -668,6 +687,11 @@ def _generate_graph(self, data_dict):
668687
radius_pbc_version=self.radius_pbc_version,
669688
pbc=pbc,
670689
node_partition=node_partition,
690+
rank_assignments=rank_assignments if self.use_all_to_all_gp else None,
691+
rank=gp_utils.get_gp_rank() if self.use_all_to_all_gp else None,
692+
world_size=gp_utils.get_gp_world_size()
693+
if self.use_all_to_all_gp
694+
else None,
671695
)
672696
else:
673697
# this assume edge_index is provided
@@ -715,12 +739,14 @@ def _generate_graph(self, data_dict):
715739

716740
# Build GPContext for all-to-all communication
717741
if self.use_all_to_all_gp:
718-
gp_ctx = build_gp_context(
719-
edge_index=graph_dict["edge_index"],
720-
rank_assignments=rank_assignments,
721-
rank=gp_utils.get_gp_rank(),
722-
world_size=gp_utils.get_gp_world_size(),
723-
)
742+
with record_function("a2a_build_gp_context"):
743+
gp_ctx = build_gp_context(
744+
edge_index=graph_dict["edge_index"],
745+
rank_assignments=rank_assignments,
746+
rank=gp_utils.get_gp_rank(),
747+
world_size=gp_utils.get_gp_world_size(),
748+
send_info=graph_dict.get("send_info"),
749+
)
724750
data_dict["gp_ctx"] = gp_ctx
725751
# All-to-all uses local indices via gp_ctx.edge_index_local,
726752
# so node_offset is always 0.

0 commit comments

Comments (0)