@@ -428,20 +428,16 @@ class UMASFastGPUMixedBackend(UMASFastPytorchBackend):
428428 """
429429 GPU backend for mixed-task / mixed-size batches.
430430
431- Inherits the parent so loop experiments can call super() into
432- UMASFastPytorchBackend's helpers, but the seed overrides
433- prepare_model_for_inference with a true no-op: the parent's SO2
434- block-diagonal conversion (convert_so2_conv1/2) assumes fc_m0 is a
435- plain Linear with a `.weight` attribute, which is not true once
436- moe_layer_type=fairchem_cpp wraps fc_m0 as MOLEFairchemCpp. SO2
437- conversion adapted for fairchem_cpp MOLE is a candidate experiment.
431+ Skips the parent's SO2 block-diagonal conversion (convert_so2_conv1/2)
432+ because it assumes fc_m0 is a plain Linear, which is not true when
433+ moe_layer_type=fairchem_cpp wraps fc_m0 as MOLEFairchemCpp.
438434
439435 merge_mole is forbidden — it is incompatible with batches that mix
440436 tasks, charges, or spins (see
441437 eSCNMDMoeBackbone._assert_all_mole_info_consistent).
442438
443- Seed body is otherwise a pure passthrough. Add overrides in this
444- class body to experiment .
439+ Wraps graph-gen / MOLE / wigner helpers with dynamo.disable and
440+ compiles model.forward for throughput .
445441
446442 Requires CUDA, lmax==2, mmax==2, merge_mole=False.
447443 """
@@ -465,63 +461,16 @@ def validate(

     @staticmethod
     def prepare_model_for_inference(model: torch.nn.Module) -> None:
-        # Cache graph topology after the first forward, then recompute
-        # edge_distance / distance_vec from pos via get_pbc_distances on
-        # every call so autograd still flows pos → distances → energy →
-        # forces. ASSUMES TOPOLOGY IS CONSTANT across this predictor's
-        # lifetime — true in our benchmark (50 iters of identical
-        # batch); production users with changing topology need a Verlet-
-        # skin invalidation or per-batch cache key.
-        from fairchem.core.graph.compute import get_pbc_distances
-
-        cache: dict = {}
-        orig_generate = model._generate_graph
-
-        def cached_generate(data_dict):
-            if "edge_index" not in cache:
-                graph_dict = orig_generate(data_dict)
-                cache["edge_index"] = graph_dict["edge_index"]
-                cache["cell_offsets"] = graph_dict.get("cell_offsets")
-                cache["neighbors"] = graph_dict.get("neighbors")
-                cache["gp_node_offset"] = data_dict.get("gp_node_offset", 0)
-                return graph_dict
-            data_dict["gp_node_offset"] = cache["gp_node_offset"]
-            edge_index = cache["edge_index"]
-            cell_offsets = cache["cell_offsets"]
-            neighbors = cache["neighbors"]
-            if cell_offsets is not None:
-                out = get_pbc_distances(
-                    data_dict["pos"],
-                    edge_index,
-                    data_dict["cell"],
-                    cell_offsets,
-                    neighbors,
-                    return_offsets=True,
-                    return_distance_vec=True,
-                    skip_redundant_filter=True,
-                )
-                return {
-                    "edge_index": edge_index,
-                    "edge_distance": out["distances"],
-                    "edge_distance_vec": out["distance_vec"],
-                    "cell_offsets": cell_offsets,
-                    "offset_distances": out["offsets"],
-                    "neighbors": neighbors,
-                }
-            distance_vec = (
-                data_dict["pos"][edge_index[0]] - data_dict["pos"][edge_index[1]]
-            )
-            return {
-                "edge_index": edge_index,
-                "edge_distance": torch.linalg.norm(distance_vec, dim=-1),
-                "edge_distance_vec": distance_vec,
-            }
-
-        # Make the cached_generate, MOLE setup, and wigner gen opaque to
-        # dynamo — these are eager-only setup operations that compile
-        # would otherwise repeatedly graph-break on (numpy.isclose, list
-        # iteration over Python state, etc.).
-        model._generate_graph = torch._dynamo.disable(cached_generate)
+        # Skip the parent's SO2 block-diagonal conversion — it assumes
+        # fc_m0 is a plain Linear, which isn't true when
+        # moe_layer_type=fairchem_cpp wraps it as MOLEFairchemCpp.
+
+        # Make graph gen, MOLE setup, and wigner gen opaque to dynamo —
+        # these are eager-only operations that compile would otherwise
+        # repeatedly graph-break on (numpy.isclose, list iteration over
+        # Python state, etc.).
+        if hasattr(model, "_generate_graph"):
+            model._generate_graph = torch._dynamo.disable(model._generate_graph)
         if hasattr(model, "_get_rotmat_and_wigner"):
             model._get_rotmat_and_wigner = torch._dynamo.disable(
                 model._get_rotmat_and_wigner
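The removed cached_generate assumed constant topology, which holds for the 50-iteration fixed-batch benchmark; its comment points production users at a Verlet-skin invalidation or per-batch cache key. A rough sketch of the skin-based variant, simplified to the non-PBC branch (the PBC branch would go through get_pbc_distances as in the removed code; the helper name and `skin` parameter are made up for illustration):

    import torch

    def make_invalidating_cache(orig_generate, skin: float = 1.0):
        # Verlet-skin policy: rebuild the topology once any atom has
        # drifted more than skin/2 since the cache was filled; otherwise
        # reuse edge_index and recompute distances from the live
        # positions so autograd still flows pos -> distances.
        cache: dict = {}

        def generate(data_dict):
            pos = data_dict["pos"]
            if cache and (pos - cache["ref_pos"]).norm(dim=-1).max() > skin / 2:
                cache.clear()
            if not cache:
                graph = orig_generate(data_dict)
                cache["edge_index"] = graph["edge_index"]
                cache["ref_pos"] = pos.detach().clone()
                return graph
            edge_index = cache["edge_index"]
            vec = pos[edge_index[0]] - pos[edge_index[1]]  # non-PBC branch
            return {
                "edge_index": edge_index,
                "edge_distance": torch.linalg.norm(vec, dim=-1),
                "edge_distance_vec": vec,
            }

        return generate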
@@ -536,8 +485,7 @@ def cached_generate(data_dict):
         # Compile the backbone forward. With segment_mm registered as a
         # custom_op (see fairchem_cpp/ops.py) and the _generate_graph
         # branch removed (escn_md.py), dynamo can trace the message-
-        # passing loop cleanly. Static shapes are appropriate for the
-        # mixed-batch use case (one fixed-shape forward per timed iter).
+        # passing loop cleanly.
         torch._dynamo.config.recompile_limit = 32
         model.forward = torch.compile(model.forward, dynamic=False)

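As a self-contained illustration of the pattern this seed applies — dynamo.disable on eager-only helpers, then torch.compile(dynamic=False) on forward — here is a toy module; nothing in it is fairchem API:

    import torch

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 4)

        def _setup(self, x):
            # Stand-in for eager-only Python bookkeeping that would
            # otherwise graph-break under torch.compile.
            return x + 1.0

        def forward(self, x):
            return self.lin(self._setup(x))

    m = Toy()
    m._setup = torch._dynamo.disable(m._setup)  # opaque to dynamo
    m.forward = torch.compile(m.forward, dynamic=False)
    print(m(torch.randn(2, 4)).shape)  # torch.Size([2, 4])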