
Commit d029820

lbluque, ericyuan00000, zulissimeta, and misko authored
Enable Untrained Property Predictions Including Hessians (#1811)
## Summary

This PR extends FAIRChem's inference capabilities to compute derivative properties (forces, stress, hessians) for models trained only on energy, eliminating the need to retrain models when derivative properties are needed. It adds infrastructure for computing "untrained" properties: derivative quantities that can be calculated via automatic differentiation from trained energy predictions. Users can now request forces, stress, and hessians at inference time even if these properties were not included in the original training task.

## Key Changes

### 1. New Output Computation Modules (+463 lines)

`src/fairchem/core/models/uma/outputs.py` - UMA-specific gradient computations:
- `get_displacement_and_cell()` - Prepares strain tensors for stress computation
- `compute_energy()` - Reduces node energies to system level
- `compute_forces()` - Computes forces as -∂E/∂pos
- `compute_forces_and_stress()` - Computes forces and stress from the virial
- `compute_hessian()` - Second-order derivatives with vmap and loop implementations
- Helper functions: `get_l_component()`, `reduce_node_to_system()`

`src/fairchem/core/models/utils/outputs.py` - General utilities:
- `get_numerical_hessian()` - Finite-difference hessian computation for validation

### 2. Enhanced Inference Settings & Predict Unit (+240 lines)

`src/fairchem/core/units/mlip_unit/predict.py`:
- `_create_untrained_tasks()` - Dynamically generates Task objects for untrained properties
- `_configure_head_gradients()` - Enables autograd in model heads for requested derivatives
- `_validate_untrained_property_requests()` - Validates inference settings
- Supports per-dataset property requests via `InferenceSettings`:
  - `predict_untrained_forces: set[str]` - Dataset names requiring forces
  - `predict_untrained_stress: set[str]` - Dataset names requiring stress
  - `predict_untrained_hessian: set[str]` - Dataset names requiring the hessian

### 3. Model Task Management (+72 lines)

`src/fairchem/core/models/base.py`:
- Added `add_tasks()` method to `HydraModel` and `HydraModelV2`
- Allows runtime addition of inference-only tasks
- Converted `tasks` to a property with a backing `_tasks` attribute
- Rebuilds the `dataset_to_tasks` map when tasks are added

### 4. ESCN-MD Refactoring (+381/-209 lines)

`src/fairchem/core/models/uma/escn_md.py`:
- Refactored to use centralized gradient computation from `outputs.py`
- Cleaner separation between energy prediction and derivative computation
- Maintains backward compatibility

## Usage Example

```python
from fairchem.core.units.mlip_unit import InferenceSettings, MLIPPredictUnit

# Load energy-only checkpoint and enable untrained derivatives
settings = InferenceSettings(
    predict_untrained_forces={"omol"},
    predict_untrained_stress={"omol"},
    predict_untrained_hessian={"omol"}
)

predictor = MLIPPredictUnit(
    "uma-s.pt", device="cuda", inference_settings=settings
)

# Predictions now include all requested properties
preds = predictor.predict(batch)
# preds = {"energy": ..., "forces": ..., "stress": ..., "hessian": ...}
```

---------

Co-authored-by: Eric Yuan <87563575+ericyuan00000@users.noreply.github.com>
Co-authored-by: zulissimeta <122578103+zulissimeta@users.noreply.github.com>
Co-authored-by: ericyuan00000 <ericyuan@berkeley.edu>
Co-authored-by: misko <misko@meta.com>
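The description above mentions `get_numerical_hessian()` as a finite-difference check on the autograd Hessian. As a rough, self-contained illustration of that validation idea (a sketch with a toy energy function, not FAIRChem's implementation), one can compare `torch.autograd.functional.hessian` against central finite differences of the gradient:

```python
import torch


def toy_energy(pos: torch.Tensor) -> torch.Tensor:
    # Toy quadratic pair energy so the example is self-contained (not a real potential).
    diff = pos.unsqueeze(0) - pos.unsqueeze(1)
    return 0.5 * (diff**2).sum()


def numerical_hessian(energy_fn, pos: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # Central finite differences of dE/dpos, one displaced coordinate at a time.
    flat = pos.detach().reshape(-1)
    n = flat.numel()
    hess = torch.zeros(n, n, dtype=pos.dtype)
    for i in range(n):
        for sign in (1.0, -1.0):
            shifted = flat.clone()
            shifted[i] += sign * eps
            p = shifted.reshape(pos.shape).requires_grad_(True)
            grad = torch.autograd.grad(energy_fn(p), p)[0].reshape(-1)
            hess[i] += sign * grad / (2 * eps)
    return hess


pos = torch.rand(3, 3, dtype=torch.double)
autograd_hessian = torch.autograd.functional.hessian(toy_energy, pos).reshape(9, 9)
assert torch.allclose(autograd_hessian, numerical_hessian(toy_energy, pos), atol=1e-6)
```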
1 parent 8f431ae commit d029820

11 files changed

Lines changed: 721 additions & 50 deletions


docs/core/common_tasks/ase_calculator.md

Lines changed: 23 additions & 0 deletions
````diff
@@ -108,6 +108,29 @@ predictor = pretrained_mlip.get_predict_unit(
 )
 ```
 
+## Enabling gradient stress or Hessian prediction
+
+Some tasks, for example omol, odac, or oc20/25, were not trained using stress labels. Similarly, no tasks were supervised to predict Hessians. However, predictions of untrained derivatives of energy, such as stress and Hessians, can be enabled by using the following inference settings flags,
+
+| Setting Flag | Description |
+| ----- | ----- |
+| predict_untrained_forces | A set of task/dataset names (e.g., `{"omol", "oc20"}`) for which forces will be computed via autograd even though the checkpoint was not trained with a forces head for those tasks. |
+| predict_untrained_stress | A set of task/dataset names for which stress tensors will be computed via autograd even though the checkpoint was not trained with a stress head for those tasks. The default empty set disables this. |
+| predict_untrained_hessian | A set of task/dataset names for which the Hessian matrix will be computed via autograd. |
+
+For example, to enable stress and Hessian predictions with the `omol` level of theory, the following settings can be used,
+
+```{code-cell} python3
+settings = InferenceSettings(
+    predict_untrained_stress={'omol'},
+    predict_untrained_hessian={'omol'}
+)
+
+predictor = pretrained_mlip.get_predict_unit(
+    "uma-s-1p1", device="cuda", inference_settings=settings
+)
+```
+
 ## Multi-GPU Inference
 
 UMA supports Graph Parallel inference natively. The graph is chunked into each rank and both the forward and backwards communication is handled by the built-in graph parallel algorithm with torch distributed. Because Multi-GPU inference requires special setup of communication protocols within a node and across nodes, we leverage [ray](https://www.ray.io/) to launch Ray Actors for each GPU-rank under the hood. This allows us to seamlessly scale to any infrastructure that can run Ray.
````

src/fairchem/core/calculate/__init__.py

Lines changed: 7 additions & 1 deletion
````diff
@@ -12,5 +12,11 @@
     FAIRChemCalculator,
     FormationEnergyCalculator,
 )
+from fairchem.core.units.mlip_unit.api.inference import InferenceSettings
 
-__all__ = ["FAIRChemCalculator", "FormationEnergyCalculator", "InferenceBatcher"]
+__all__ = [
+    "FAIRChemCalculator",
+    "FormationEnergyCalculator",
+    "InferenceBatcher",
+    "InferenceSettings",
+]
````
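With `InferenceSettings` added to `__all__`, the settings class can now be imported from `fairchem.core.calculate` alongside the calculator, for example:

```python
# InferenceSettings is re-exported next to the calculator (see the updated __all__ above).
from fairchem.core.calculate import FAIRChemCalculator, InferenceSettings

settings = InferenceSettings(predict_untrained_hessian={"omol"})
```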

src/fairchem/core/calculate/ase_calculator.py

Lines changed: 20 additions & 1 deletion
````diff
@@ -14,7 +14,7 @@
 from typing import TYPE_CHECKING, Literal
 
 import numpy as np
-from ase.calculators.calculator import Calculator
+from ase.calculators.calculator import Calculator, PropertyNotImplementedError
 from ase.stress import full_3x3_to_voigt_6_stress
 
 from fairchem.core.calculate import pretrained_mlip
@@ -181,6 +181,22 @@ def check_state(self, atoms: Atoms, tol: float = 1e-15) -> list:
         state.append("info")
         return state
 
+    def get_property(self, name, atoms=None, allow_calculation=True):
+        try:
+            result = super().get_property(
+                name, atoms=atoms, allow_calculation=allow_calculation
+            )
+        except PropertyNotImplementedError as exc:
+            msg = str(exc)
+            if name in ("forces", "stress", "hessian"):
+                msg += (
+                    f"\n {name} prediction can be enabled by setting `predict_untrained_{name}=set('{self.task_name}')` "
+                    f"in the InferenceSettings."
+                )
+            raise PropertyNotImplementedError(msg) from exc
+
+        return result
+
     def calculate(
         self, atoms: Atoms, properties: list[str], system_changes: list[str]
     ) -> None:
@@ -236,6 +252,9 @@ def calculate(
                 stress = pred[calc_key].detach().cpu().numpy().reshape(3, 3)
                 stress_voigt = full_3x3_to_voigt_6_stress(stress)
                 self.results["stress"] = stress_voigt
+            if calc_key == "hessian":
+                hessian = pred[calc_key].detach().cpu().numpy().squeeze()
+                self.results["hessian"] = hessian
 
     def _check_atoms_pbc(self, atoms) -> None:
         """
````

src/fairchem/core/models/base.py

Lines changed: 74 additions & 4 deletions
````diff
@@ -147,6 +147,8 @@ def __init__(
         # the old config system at some point, this will prevent the need to make major modifications to the trainer
         # because they all expect the name of the outputs directly instead of the head_name.property_name
         self.pass_through_head_outputs = pass_through_head_outputs
+        self._tasks = None
+        self._dataset_to_tasks = None
 
         # Does this model support inference on single atom systems
         self.supports_single_atoms = supports_single_atoms
@@ -255,6 +257,13 @@ def forward(self, data: AtomicData):
 
         return out
 
+    @property
+    def tasks(self) -> dict[str, Task]:
+        """
+        Mapping from task names to their associated Task objects.
+        """
+        return self._tasks
+
     @property
     def direct_forces(self) -> bool:
         """
@@ -317,18 +326,46 @@ def setup_tasks(self, tasks_config: list) -> None:
             tasks_config: List of task configurations from checkpoint
         """
         tasks = [hydra.utils.instantiate(task_config) for task_config in tasks_config]
-        self.tasks = {t.name: t for t in tasks}
+        self._tasks = {t.name: t for t in tasks}
         self._dataset_to_tasks = _get_dataset_to_tasks_map(tasks)
 
         # Let backbone validate tasks
         self.backbone.validate_tasks(self._dataset_to_tasks)
 
+    def add_tasks(self, tasks: Sequence[Task]) -> None:
+        """
+        Add additional tasks to the model.
+
+        This is useful for adding inference-only tasks that weren't in the
+        original checkpoint, such as untrained derivative properties.
+
+        Args:
+            tasks: List of Task objects to add
+        """
+        if not hasattr(self, "tasks"):
+            raise RuntimeError("setup_tasks() must be called before add_tasks()")
+
+        # Add new tasks to the tasks dict
+        for task in tasks:
+            if task.name in self.tasks:
+                logging.warning(
+                    f"Task '{task.name}' already exists, skipping adding as a new task."
+                )
+                continue
+            self.tasks[task.name] = task
+
+        # Rebuild dataset_to_tasks map
+        self._dataset_to_tasks = _get_dataset_to_tasks_map(self.tasks.values())
+
+        # Let backbone validate the updated task set
+        self.backbone.validate_tasks(self._dataset_to_tasks)
+
     @property
     def dataset_to_tasks(self) -> dict[str, list]:
         """
         Mapping from dataset names to their associated tasks.
         """
-        if not hasattr(self, "_dataset_to_tasks"):
+        if self._dataset_to_tasks is None:
             raise RuntimeError(
                 "setup_tasks() must be called before accessing dataset_to_tasks"
             )
@@ -346,6 +383,9 @@ def __init__(
         self.backbone = backbone
         self.output_heads = torch.nn.ModuleDict(heads)
         self.device = None
+        self._tasks = None
+        self._dataset_to_tasks = None
+
         if freeze_backbone:
             for param in self.backbone.parameters():
                 param.requires_grad = False
@@ -371,6 +411,13 @@ def forward(self, data):
             out[k] = self.output_heads[k](data, emb)
         return out
 
+    @property
+    def tasks(self) -> dict[str, Task]:
+        """
+        Mapping from task names to their associated Task objects.
+        """
+        return self._tasks
+
     @property
     def direct_forces(self) -> bool:
         """
@@ -430,18 +477,41 @@ def setup_tasks(self, tasks_config: list) -> None:
             tasks_config: List of task configurations from checkpoint
         """
         tasks = [hydra.utils.instantiate(task_config) for task_config in tasks_config]
-        self.tasks = {t.name: t for t in tasks}
+        self._tasks = {t.name: t for t in tasks}
         self._dataset_to_tasks = _get_dataset_to_tasks_map(tasks)
 
         # Let backbone validate tasks
        self.backbone.validate_tasks(self._dataset_to_tasks)
 
+    def add_tasks(self, tasks: Sequence[Task]) -> None:
+        """
+        Add additional tasks to the model.
+
+        This is useful for adding inference-only tasks that weren't in the
+        original checkpoint, such as untrained derivative properties.
+
+        Args:
+            tasks: List of Task objects to add
+        """
+        # Add new tasks to the tasks dict
+        for task in tasks:
+            if task.name in self.tasks:
+                logging.warning(f"Task '{task.name}' already exists, skipping addition")
+                continue
+            self._tasks[task.name] = task
+
+        # Rebuild dataset_to_tasks map
+        self._dataset_to_tasks = _get_dataset_to_tasks_map(self.tasks.values())
+
+        # Let backbone validate the updated task set
+        self.backbone.validate_tasks(self._dataset_to_tasks)
+
     @property
     def dataset_to_tasks(self) -> dict[str, list]:
         """
         Mapping from dataset names to their associated tasks.
         """
-        if not hasattr(self, "_dataset_to_tasks"):
+        if self._dataset_to_tasks is None:
            raise RuntimeError(
                "setup_tasks() must be called before accessing dataset_to_tasks"
            )
````
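The `add_tasks()` flow above reduces to a small pattern: a read-only `tasks` property backed by `_tasks`, plus a rebuild of the dataset-to-tasks map whenever new (inference-only) tasks are registered. A stripped-down, self-contained sketch of that pattern, not the FAIRChem classes themselves (the `Task` stand-in here is a minimal assumption):

```python
from __future__ import annotations

import logging
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class Task:
    # Minimal stand-in for a FAIRChem Task; the real class carries much more.
    name: str
    datasets: list[str] = field(default_factory=list)


class TaskRegistry:
    def __init__(self) -> None:
        self._tasks: dict[str, Task] = {}
        self._dataset_to_tasks: dict[str, list[Task]] = {}

    @property
    def tasks(self) -> dict[str, Task]:
        # Read-only view backed by _tasks, mirroring the property in the diff.
        return self._tasks

    def add_tasks(self, tasks: list[Task]) -> None:
        for task in tasks:
            if task.name in self._tasks:
                logging.warning("Task '%s' already exists, skipping", task.name)
                continue
            self._tasks[task.name] = task
        # Rebuild the dataset -> tasks map after every addition.
        mapping: dict[str, list[Task]] = defaultdict(list)
        for task in self._tasks.values():
            for dataset in task.datasets:
                mapping[dataset].append(task)
        self._dataset_to_tasks = dict(mapping)


registry = TaskRegistry()
registry.add_tasks([Task("omol_energy", ["omol"])])
registry.add_tasks([Task("omol_hessian", ["omol"])])  # e.g., an inference-only task
```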

src/fairchem/core/models/uma/escn_md.py

Lines changed: 6 additions & 11 deletions
````diff
@@ -1100,14 +1100,6 @@ def regress_forces(self) -> bool:
     def regress_stress(self) -> bool:
         return self.regress_config.stress
 
-    @property
-    def regress_hessian(self) -> bool:
-        return self.regress_config.hessian
-
-    @property
-    def hessian_vmap(self) -> bool:
-        return self.regress_config.hessian_vmap
-
     @conditional_grad(torch.enable_grad())
     def forward(
         self, data: AtomicData, emb: dict[str, torch.Tensor]
@@ -1139,7 +1131,7 @@ def forward(
 
         # Determine if we need create_graph for higher-order derivatives
        # Hessian computation requires second derivatives, so we need create_graph=True
-        create_graph = self.training or self.regress_hessian
+        create_graph = self.training or self.regress_config.hessian
 
         if self.regress_config.stress and not self.regress_config.direct_stress:
             forces, stress = compute_forces_and_stress(
@@ -1158,7 +1150,7 @@
         else:
             forces = None
 
-        if self.regress_hessian:
+        if self.regress_config.hessian:
            if forces is None:
                raise ValueError(
                    "Hessian computation requires forces. "
@@ -1171,7 +1163,10 @@ def forward(
                )
 
            hessian = compute_hessian(
-                forces, data["pos"], vmap=self.hessian_vmap, training=self.training
+                forces,
+                data["pos"],
+                vmap=self.regress_config.hessian_vmap,
+                training=create_graph,
            )
            outputs[hessian_key] = (
                {"hessian": hessian} if self.wrap_property else hessian
````

src/fairchem/core/models/uma/outputs.py

Lines changed: 7 additions & 4 deletions
````diff
@@ -244,11 +244,13 @@ def compute_grad_component(vec):
         )[0]
 
     # Use vmap to compute all components in parallel
+    # autograd.grad returns shape [N, 3] (same as pos), vmap gives [N*3, N, 3]
     hessian = torch.vmap(compute_grad_component)(
-        torch.eye(forces_flat.numel(), device=forces_flat.device)
+        torch.eye(forces_flat.shape[0], device=forces_flat.device)
     )
 
-    return hessian
+    n_dof = forces_flat.numel()
+    return hessian.reshape(n_dof, n_dof)
 
 
 def compute_hessian_loop(
@@ -306,7 +308,8 @@ def compute_hessian(
         training: Whether to create graph for third-order derivatives.
 
     Returns:
-        Hessian matrix of shape [N*3, N*3].
+        Hessian matrix of shape [1, N*3, N*3] (batch dim always 1 since
+        hessian requires single-system batches).
 
     Note:
         Graph parallel (GP) mode is not fully supported. The Hessian should
@@ -325,4 +328,4 @@
     else:
         hessian = compute_hessian_loop(forces_flat, pos, create_graph=training)
 
-    return hessian
+    return hessian.unsqueeze(0)
````
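For context on what `compute_hessian` produces: each Hessian row is the gradient of one negated force component with respect to all positions, which is why the vmap output of shape `[N*3, N, 3]` must be reshaped to `[N*3, N*3]`. A self-contained sketch of the row-by-row (loop) equivalent, using a toy energy rather than a model:

```python
import torch

pos = torch.randn(4, 3, requires_grad=True)
energy = (pos.unsqueeze(0) - pos.unsqueeze(1)).pow(2).sum()  # toy energy for the sketch
forces = -torch.autograd.grad(energy, pos, create_graph=True)[0]

forces_flat = forces.reshape(-1)  # [N*3]
n_dof = forces_flat.numel()

rows = []
for i in range(n_dof):
    # Each row is d(-F_i)/d pos, an [N, 3] tensor, flattened to length N*3.
    g = torch.autograd.grad(-forces_flat[i], pos, retain_graph=True)[0]
    rows.append(g.reshape(-1))

hessian = torch.stack(rows)  # [N*3, N*3], analogous to the reshape in the diff
```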

src/fairchem/core/units/mlip_unit/api/inference.py

Lines changed: 12 additions & 1 deletion
````diff
@@ -7,7 +7,7 @@
 
 from __future__ import annotations
 
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 
 import torch # - needed at runtime for dataclass field type resolution
 
@@ -106,6 +106,17 @@ class InferenceSettings:
     # Set to "umas_fast_gpu" to enable highly optimized backend with triton kernels for maximum speed.
     execution_mode: str = "general"
 
+    # New fields for untrained derivative properties
+    # These flags request computation of properties NOT in the checkpoint's task list.
+    # If a property is already in the checkpoint (e.g., omol_forces task exists),
+    # it will be computed regardless of these flags.
+    # Specify datasets as a set of strings (e.g., {"omol", "oc20"}).
+    # Empty set means no untrained properties will be computed (default).
+    predict_untrained_forces: set[str] = field(default_factory=set)
+    predict_untrained_stress: set[str] = field(default_factory=set)
+    predict_untrained_hessian: set[str] = field(default_factory=set)
+    hessian_vmap: bool = True # Use fast vmap vs memory-efficient loop
+
     def __post_init__(self):
         if isinstance(self.base_precision_dtype, str):
             self.base_precision_dtype = getattr(torch, self.base_precision_dtype)
````
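Putting the new dataclass fields together, a request for untrained derivatives that also opts into the memory-efficient loop Hessian would look roughly like this (the dataset name is just an example):

```python
from fairchem.core.units.mlip_unit.api.inference import InferenceSettings

settings = InferenceSettings(
    predict_untrained_forces={"omol"},
    predict_untrained_stress={"omol"},
    predict_untrained_hessian={"omol"},
    hessian_vmap=False,  # trade speed for memory with the loop implementation
)
```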
