@@ -602,6 +602,33 @@ def csd_embedding(self, charge, spin, dataset):
         )
         return torch.nn.SiLU()(self.mix_csd(torch.cat((chg_emb, spin_emb), dim=1)))
 
+    @staticmethod
+    @torch.compiler.disable
+    def _compute_a2a_partition(
+        pos: torch.Tensor,
+        total_atoms: int,
+        device: torch.device,
+        world_size: int,
+        rank: int,
+        strategy: PartitionStrategy,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute A2A rank assignments and node partition.
+
+        Separated from _generate_graph so that only the A2A-specific
+        partitioning is excluded from torch.compile. The BL (all-gather)
+        path stays fully compilable.
+        """
+        with record_function("a2a_partition"):
+            if strategy == PartitionStrategy.SPATIAL:
+                rank_assignments = partition_atoms_spatial(pos, world_size)
+            else:
+                rank_assignments = partition_atoms_index_split(
+                    total_atoms, world_size, device
+                )
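+            # node_partition: global indices of the atoms assigned to this rank.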
+            node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0]
+            return rank_assignments, node_partition
+
     def _generate_graph(self, data_dict):
         data_dict["gp_node_offset"] = 0
         node_partition = None
@@ -616,22 +643,14 @@ def _generate_graph(self, data_dict):
             # graph-generation partition and the GPContext partition
             # are identical, avoiding index mismatches that cause
             # OOB crashes.
-            total_atoms = len(atomic_numbers_full)
-            device = atomic_numbers_full.device
-            world_size = gp_utils.get_gp_world_size()
-
-            if self.gp_partition_strategy == PartitionStrategy.SPATIAL:
-                rank_assignments = partition_atoms_spatial(
-                    data_dict["pos"], world_size
-                )
-            else:
-                rank_assignments = partition_atoms_index_split(
-                    total_atoms, world_size, device
-                )
-
-            node_partition = (rank_assignments == gp_utils.get_gp_rank()).nonzero(
-                as_tuple=True
-            )[0]
+            rank_assignments, node_partition = self._compute_a2a_partition(
+                pos=data_dict["pos"],
+                total_atoms=len(atomic_numbers_full),
+                device=atomic_numbers_full.device,
+                world_size=gp_utils.get_gp_world_size(),
+                rank=gp_utils.get_gp_rank(),
+                strategy=self.gp_partition_strategy,
+            )
         else:
             # Legacy all-gather: use consecutive index split
             node_partition = torch.tensor_split(
@@ -668,6 +687,11 @@ def _generate_graph(self, data_dict):
                 radius_pbc_version=self.radius_pbc_version,
                 pbc=pbc,
                 node_partition=node_partition,
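+                # The remaining kwargs are A2A-only; the all-gather path passes None.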
+                rank_assignments=rank_assignments if self.use_all_to_all_gp else None,
+                rank=gp_utils.get_gp_rank() if self.use_all_to_all_gp else None,
+                world_size=gp_utils.get_gp_world_size()
+                if self.use_all_to_all_gp
+                else None,
             )
         else:
             # this assumes edge_index is provided
@@ -715,12 +739,14 @@ def _generate_graph(self, data_dict):
 
         # Build GPContext for all-to-all communication
         if self.use_all_to_all_gp:
-            gp_ctx = build_gp_context(
-                edge_index=graph_dict["edge_index"],
-                rank_assignments=rank_assignments,
-                rank=gp_utils.get_gp_rank(),
-                world_size=gp_utils.get_gp_world_size(),
-            )
+            with record_function("a2a_build_gp_context"):
+                gp_ctx = build_gp_context(
+                    edge_index=graph_dict["edge_index"],
+                    rank_assignments=rank_assignments,
+                    rank=gp_utils.get_gp_rank(),
+                    world_size=gp_utils.get_gp_world_size(),
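+                    # Reuse the exchange plan from graph generation when available.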
+                    send_info=graph_dict.get("send_info"),
+                )
             data_dict["gp_ctx"] = gp_ctx
             # All-to-all uses local indices via gp_ctx.edge_index_local,
             # so node_offset is always 0.
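
Note on the index-split strategy referenced above: the sketch below illustrates how rank assignments and the rank-local node partition can be derived. It is a hypothetical illustration, not the repo's partition_atoms_index_split; it assumes the same torch.tensor_split semantics used by the legacy all-gather branch.

    import torch

    def index_split_rank_assignments(
        total_atoms: int, world_size: int, device: torch.device
    ) -> torch.Tensor:
        # Hypothetical helper: give each rank a near-equal contiguous slice of
        # atom indices, mirroring torch.tensor_split in the legacy path.
        rank_assignments = torch.empty(total_atoms, dtype=torch.long, device=device)
        chunks = torch.tensor_split(torch.arange(total_atoms, device=device), world_size)
        for rank, idx in enumerate(chunks):
            rank_assignments[idx] = rank
        return rank_assignments

    # Rank-local selection, exactly as _compute_a2a_partition does it:
    assignments = index_split_rank_assignments(10, 4, torch.device("cpu"))
    node_partition = (assignments == 2).nonzero(as_tuple=True)[0]
    print(assignments.tolist())     # [0, 0, 0, 1, 1, 1, 2, 2, 3, 3]
    print(node_partition.tolist())  # [6, 7]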