Skip to content

Commit fc38fbc

Browse files
committed
Add freeze_model_parameters flag to InferenceSettings
Add optional flag to disable requires_grad on all model parameters during inference. When enabled, only pos and cell retain requires_grad for autograd-based force/stress computation. The freeze is applied before torch.compile so the compiled graph can optimize accordingly. Defaults to False pending further benchmarking.
1 parent 28292f0 commit fc38fbc

2 files changed

Lines changed: 11 additions & 0 deletions

File tree

src/fairchem/core/units/mlip_unit/api/inference.py

Lines changed: 5 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -123,6 +123,11 @@ class InferenceSettings:
123123
# (e.g., eSCNMDBackbone adds stress for all energy tasks by default)
124124
auto_add_default_untrained_tasks: bool = True
125125

126+
# When True, set requires_grad=False on all model parameters before
127+
# inference. Only pos and cell retain requires_grad for force/stress
128+
# computation via autograd.
129+
freeze_model_parameters: bool = False
130+
126131
# Maximum number of atoms per system for padding. Required when
127132
# compile=True for models that use padding (e.g., AllScAIP).
128133
# All inputs will be padded to this size. Larger values consume more

src/fairchem/core/units/mlip_unit/predict.py

Lines changed: 6 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -461,6 +461,12 @@ def _lazy_init(self, data: AtomicData) -> None:
461461

462462
self.move_to_device()
463463

464+
# Freeze all model parameters before compile — only pos/cell need
465+
# requires_grad for force/stress computation via autograd.
466+
if self.inference_settings.freeze_model_parameters:
467+
for p in self.model.parameters():
468+
p.requires_grad_(False)
469+
464470
if self.inference_settings.compile:
465471
logging.warning(
466472
"Model is being compiled this might take a while for the first time"

0 commit comments

Comments (0)