Commit b2ff855

Author: Michael Dzamba (committed)
exp35: InferenceSettings.freeze_params (skip weight-grad backward kernels)
Adds an opt-in flag that calls requires_grad_(False) on every model parameter at predict-time _lazy_init. For gradient-force inference, autograd then skips the weight-grad path of every Linear / segment_mm backward, saving CUDA time and peak memory. The win is conditional: it helps when paired with moe_layer_type=fairchem_cpp, but can regress under tf32 + pytorch MOLE due to cuBLAS fused-(dx, dW) kernel selection. Off by default.
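The mechanism in isolation, as a minimal plain-PyTorch sketch (not fairchem code): once the weights are frozen, autograd never launches the dW / db backward kernels, while gradients with respect to the inputs (the force path) are still produced.

import torch

lin = torch.nn.Linear(16, 16)
for p in lin.parameters():
    p.requires_grad_(False)  # what freeze_params does at _lazy_init

pos = torch.randn(8, 16, requires_grad=True)  # stand-in for atomic positions
energy = lin(pos).sum()

# d(energy)/d(pos), i.e. the gradient-force path, is still computed, but
# no weight-grad (dW, db) kernels run because no parameter requires grad.
(forces,) = torch.autograd.grad(energy, pos)
assert all(p.grad is None for p in lin.parameters())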
1 parent 081a318 commit b2ff855

2 files changed: 15 additions & 0 deletions

src/fairchem/core/units/mlip_unit/api/inference.py

Lines changed: 11 additions & 0 deletions
@@ -72,6 +72,17 @@ class InferenceSettings:
     # Flag to enable or disable the compilation of the inference model.
     compile: bool = False
 
+    # Freeze model parameters at inference (requires_grad=False on all
+    # weights). For gradient-force models, autograd then short-circuits
+    # the dW backward kernels for every Linear/segment_mm in the
+    # backward pass that produces forces. With moe_layer_type=fairchem_cpp
+    # this is a clear win because segment_mm fwd / bwd are separate
+    # kernels and the bwd skip is pure savings; with pytorch MOLE under
+    # tf32 it can be a regression because cuBLAS dispatches a fused
+    # (dx, dW) Tensor Core path that's not selected when only dx is
+    # requested. Default off; opt-in.
+    freeze_params: bool = False
+
     # Deprecated
     # Flag to enable or disable the use of CUDA Graphs for compute
     # This flag is no longer used and will be removed in future versions
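Hypothetical opt-in usage; compile and freeze_params are the fields visible in this diff, and any other constructor arguments or defaults are assumptions:

from fairchem.core.units.mlip_unit.api.inference import InferenceSettings

settings = InferenceSettings(
    compile=False,
    freeze_params=True,  # skip dW backward kernels at predict time
)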

src/fairchem/core/units/mlip_unit/predict.py

Lines changed: 4 additions & 0 deletions
@@ -461,6 +461,10 @@ def _lazy_init(self, data: AtomicData) -> None:
 
         self.move_to_device()
 
+        if getattr(self.inference_settings, "freeze_params", False):
+            for p in self.model.parameters():
+                p.requires_grad_(False)
+
         if self.inference_settings.compile:
             logging.warning(
                 "Model is being compiled this might take a while for the first time"
