Skip to content

Commit 4dd9e92

Browse files
committed
remove cache
1 parent b2ff855 commit 4dd9e92

1 file changed

Lines changed: 16 additions & 68 deletions

File tree

src/fairchem/core/models/uma/nn/execution_backends.py

Lines changed: 16 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -428,20 +428,16 @@ class UMASFastGPUMixedBackend(UMASFastPytorchBackend):
428428
"""
429429
GPU backend for mixed-task / mixed-size batches.
430430
431-
Inherits the parent so loop experiments can call super() into
432-
UMASFastPytorchBackend's helpers, but the seed overrides
433-
prepare_model_for_inference with a true no-op: the parent's SO2
434-
block-diagonal conversion (convert_so2_conv1/2) assumes fc_m0 is a
435-
plain Linear with a `.weight` attribute, which is not true once
436-
moe_layer_type=fairchem_cpp wraps fc_m0 as MOLEFairchemCpp. SO2
437-
conversion adapted for fairchem_cpp MOLE is a candidate experiment.
431+
Skips the parent's SO2 block-diagonal conversion (convert_so2_conv1/2)
432+
because it assumes fc_m0 is a plain Linear, which is not true when
433+
moe_layer_type=fairchem_cpp wraps fc_m0 as MOLEFairchemCpp.
438434
439435
merge_mole is forbidden — it is incompatible with batches that mix
440436
tasks, charges, or spins (see
441437
eSCNMDMoeBackbone._assert_all_mole_info_consistent).
442438
443-
Seed body is otherwise a pure passthrough. Add overrides in this
444-
class body to experiment.
439+
Wraps graph-gen / MOLE / wigner helpers with dynamo.disable and
440+
compiles model.forward for throughput.
445441
446442
Requires CUDA, lmax==2, mmax==2, merge_mole=False.
447443
"""
@@ -465,63 +461,16 @@ def validate(
465461

466462
@staticmethod
467463
def prepare_model_for_inference(model: torch.nn.Module) -> None:
468-
# Cache graph topology after the first forward, then recompute
469-
# edge_distance / distance_vec from pos via get_pbc_distances on
470-
# every call so autograd still flows pos → distances → energy →
471-
# forces. ASSUMES TOPOLOGY IS CONSTANT across this predictor's
472-
# lifetime — true in our benchmark (50 iters of identical
473-
# batch); production users with changing topology need a Verlet-
474-
# skin invalidation or per-batch cache key.
475-
from fairchem.core.graph.compute import get_pbc_distances
476-
477-
cache: dict = {}
478-
orig_generate = model._generate_graph
479-
480-
def cached_generate(data_dict):
481-
if "edge_index" not in cache:
482-
graph_dict = orig_generate(data_dict)
483-
cache["edge_index"] = graph_dict["edge_index"]
484-
cache["cell_offsets"] = graph_dict.get("cell_offsets")
485-
cache["neighbors"] = graph_dict.get("neighbors")
486-
cache["gp_node_offset"] = data_dict.get("gp_node_offset", 0)
487-
return graph_dict
488-
data_dict["gp_node_offset"] = cache["gp_node_offset"]
489-
edge_index = cache["edge_index"]
490-
cell_offsets = cache["cell_offsets"]
491-
neighbors = cache["neighbors"]
492-
if cell_offsets is not None:
493-
out = get_pbc_distances(
494-
data_dict["pos"],
495-
edge_index,
496-
data_dict["cell"],
497-
cell_offsets,
498-
neighbors,
499-
return_offsets=True,
500-
return_distance_vec=True,
501-
skip_redundant_filter=True,
502-
)
503-
return {
504-
"edge_index": edge_index,
505-
"edge_distance": out["distances"],
506-
"edge_distance_vec": out["distance_vec"],
507-
"cell_offsets": cell_offsets,
508-
"offset_distances": out["offsets"],
509-
"neighbors": neighbors,
510-
}
511-
distance_vec = (
512-
data_dict["pos"][edge_index[0]] - data_dict["pos"][edge_index[1]]
513-
)
514-
return {
515-
"edge_index": edge_index,
516-
"edge_distance": torch.linalg.norm(distance_vec, dim=-1),
517-
"edge_distance_vec": distance_vec,
518-
}
519-
520-
# Make the cached_generate, MOLE setup, and wigner gen opaque to
521-
# dynamo — these are eager-only setup operations that compile
522-
# would otherwise repeatedly graph-break on (numpy.isclose, list
523-
# iteration over Python state, etc.).
524-
model._generate_graph = torch._dynamo.disable(cached_generate)
464+
# Skip the parent's SO2 block-diagonal conversion — it assumes
465+
# fc_m0 is a plain Linear, which isn't true when
466+
# moe_layer_type=fairchem_cpp wraps it as MOLEFairchemCpp.
467+
468+
# Make graph gen, MOLE setup, and wigner gen opaque to dynamo —
469+
# these are eager-only operations that compile would otherwise
470+
# repeatedly graph-break on (numpy.isclose, list iteration over
471+
# Python state, etc.).
472+
if hasattr(model, "_generate_graph"):
473+
model._generate_graph = torch._dynamo.disable(model._generate_graph)
525474
if hasattr(model, "_get_rotmat_and_wigner"):
526475
model._get_rotmat_and_wigner = torch._dynamo.disable(
527476
model._get_rotmat_and_wigner
@@ -536,8 +485,7 @@ def cached_generate(data_dict):
536485
# Compile the backbone forward. With segment_mm registered as a
537486
# custom_op (see fairchem_cpp/ops.py) and the _generate_graph
538487
# branch removed (escn_md.py), dynamo can trace the message-
539-
# passing loop cleanly. Static shapes are appropriate for the
540-
# mixed-batch use case (one fixed-shape forward per timed iter).
488+
# passing loop cleanly.
541489
torch._dynamo.config.recompile_limit = 32
542490
model.forward = torch.compile(model.forward, dynamic=False)
543491

0 commit comments

Comments
 (0)