Commit 5f9aa79

rayg1234 and misko authored

Consolidate turbo and turbo_umas (#1898)

Adding the `turbo_umas` mode was not a good choice and is too confusing for users. This goes back to having a single `turbo` mode and selects the appropriate acceleration backend automatically. Mentioned in #1872.

Co-authored-by: misko <misko@meta.com>

1 parent c45dbcd commit 5f9aa79

7 files changed

Lines changed: 158 additions & 176 deletions

docs/core/common_tasks/ase_calculator.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -87,7 +87,7 @@ The advanced user might quickly see that **default** mode and **turbo** mode are
 | edge_chunk_size | Experimental. Used for padding edge sizes. This helps reduce re-compilations from torch compile, default to None |
 | use_quaternion_wigner | enable quaternion-based Wigner D matrix computation. If false we fall back to euler-angle based rotations. default True. |
 | base_precision_dtype | governs the main precision type of the computation, default to FP32, FP64 is also supported |
-| execution_mode | This allows manually toggling custom backends to maximize speed ups. default to "general". "umas-fast-gpu" will introduce 30-40% speedup for uma-s line of models. |
+| execution_mode | This allows manually toggling custom backends to maximize speed ups. default to "None", when set to "None", the predictor will automatically determine the best backend. For example, "umas-fast-gpu" will introduce 30-40% speedup for uma-s line of models. |
 
 For example, for an MD simulation use-case for a system of ~500 atoms, we can choose to use a custom mode like the following:
 
```
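The docs hunk ends just before the custom-mode example it references. As a hedged sketch of what such a configuration could look like under the new `None` default (the `pretrained_mlip.get_predict_unit` entry point, the model name, and the exact field values are assumptions, not part of this diff):

```python
# Hypothetical sketch: overriding the new auto-selecting default by pinning
# an execution mode explicitly. Entry-point and model names are assumptions.
from fairchem.core import FAIRChemCalculator, pretrained_mlip
from fairchem.core.units.mlip_unit import InferenceSettings

settings = InferenceSettings(
    merge_mole=True,                 # umas_fast_gpu requires merge_mole=True
    execution_mode="umas_fast_gpu",  # or leave as None to auto-select
)
predictor = pretrained_mlip.get_predict_unit(
    "uma-s-1p1", device="cuda", inference_settings=settings
)
calc = FAIRChemCalculator(predictor, task_name="omol")
```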

src/fairchem/core/models/uma/escn_md.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -976,7 +976,7 @@ def prepare_for_inference(self, data: AtomicData, settings: InferenceSettings):
         self._merged_composition = None
 
         # Validate settings against backend requirements (fail early)
-        self.backend.validate(self, settings)
+        self.backend.validate(self.lmax, self.mmax, settings)
 
         if settings.merge_mole:
             assert (
```
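This one-line change is the payoff of the new `validate` signature below: validation takes bare hyperparameters instead of the module itself, so the same check can also run on a checkpoint config before any model is built. A toy illustration of that decoupling (a hypothetical standalone function, not the library's code):

```python
# Toy illustration with a hypothetical function: taking plain ints means the
# same check works on a raw config dict or on a live model.
def check_lmax_mmax(lmax: int, mmax: int) -> None:
    if lmax != 2 or mmax != 2:
        raise ValueError("umas_fast_gpu requires lmax==2 and mmax==2")

config = {"backbone": {"lmax": 2, "mmax": 2}}
check_lmax_mmax(config["backbone"]["lmax"], config["backbone"]["mmax"])  # pre-build
# check_lmax_mmax(model.lmax, model.mmax)  # post-build, as escn_md.py now does
```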

src/fairchem/core/models/uma/nn/execution_backends.py

Lines changed: 52 additions & 20 deletions
```diff
@@ -7,6 +7,7 @@
 
 from __future__ import annotations
 
+from dataclasses import replace
 from enum import Enum
 from typing import TYPE_CHECKING
 
@@ -15,14 +16,17 @@
 from fairchem.core.models.uma.nn.unified_radial import UnifiedRadialMLP
 
 if TYPE_CHECKING:
-    from fairchem.core.units.mlip_unit.api.inference import InferenceSettings
+    from fairchem.core.units.mlip_unit.api.inference import (
+        InferenceSettings,
+    )
 
 __all__ = [
     "ExecutionMode",
     "ExecutionBackend",
     "UMASFastPytorchBackend",
     "UMASFastGPUBackend",
     "get_execution_backend",
+    "maybe_update_settings_backend",
 ]
 
 # Indices for m=0 spherical harmonic coefficients in L-major ordering (lmax=2)
@@ -58,18 +62,19 @@ class ExecutionBackend:
 
     @staticmethod
     def validate(
-        model: torch.nn.Module,
-        settings: InferenceSettings | None = None,
+        lmax: int,
+        mmax: int,
+        settings: InferenceSettings,
     ) -> None:
         """
-        Validate that model and settings are compatible with this backend.
+        Validate that model parameters and settings are compatible with this backend.
 
-        Called during model construction (settings=None) and before
-        first inference (settings provided).
+        Called before first inference.
 
         Args:
-            model: The backbone model to validate.
-            settings: Inference settings, or None at construction time.
+            lmax: Maximum degree of spherical harmonics.
+            mmax: Maximum order of spherical harmonics.
+            settings: Inference settings.
 
         Raises:
             ValueError: If incompatible with this backend.
@@ -265,17 +270,13 @@ class UMASFastPytorchBackend(ExecutionBackend):
 
     @staticmethod
     def validate(
-        model: torch.nn.Module,
-        settings: InferenceSettings | None = None,
+        lmax: int,
+        mmax: int,
+        settings: InferenceSettings,
     ) -> None:
         """
         Validate that settings are compatible with fast pytorch mode.
         """
-        # Check activation_checkpointing from model (chunk_size is None when disabled)
-        if model.edge_degree_embedding.activation_checkpoint_chunk_size is not None:
-            raise ValueError(
-                "UMASFastPytorchBackend requires activation_checkpointing=False"
-            )
         # Also reject if user tries to enable it via inference settings
         if settings is not None and settings.activation_checkpointing:
             raise ValueError(
@@ -338,15 +339,16 @@ class UMASFastGPUBackend(UMASFastPytorchBackend):
 
     @staticmethod
     def validate(
-        model: torch.nn.Module,
-        settings: InferenceSettings | None = None,
+        lmax: int,
+        mmax: int,
+        settings: InferenceSettings,
     ) -> None:
-        UMASFastPytorchBackend.validate(model, settings)
+        UMASFastPytorchBackend.validate(lmax, mmax, settings)
         if not torch.cuda.is_available():
             raise ValueError("umas_fast_gpu requires CUDA")
-        if model.lmax != 2 or model.mmax != 2:
+        if lmax != 2 or mmax != 2:
             raise ValueError("umas_fast_gpu requires lmax==2 and mmax==2")
-        if settings is not None and not settings.merge_mole:
+        if not settings.merge_mole:
             raise ValueError("umas_fast_gpu requires merge_mole=True")
 
     @staticmethod
@@ -446,3 +448,33 @@ def get_execution_backend(
         available = [m.value for m in _EXECUTION_BACKENDS]
         raise ValueError(f"Unknown execution mode: {mode}. Available: {available}")
     return _EXECUTION_BACKENDS[mode]()
+
+
+def maybe_update_settings_backend(
+    settings: InferenceSettings,
+    model_config: dict,
+) -> InferenceSettings:
+    """
+    Update inference settings to use UMAS_FAST_GPU if conditions are met.
+
+    Sets execution_mode to UMAS_FAST_GPU if:
+    - execution_mode is not already set
+    - UMASFastGPUBackend.validate passes for the model and settings
+
+    Args:
+        settings: Current inference settings.
+        model_config: The model configuration dictionary to validate.
+
+    Returns:
+        Updated inference settings with the appropriate execution mode.
+    """
+    if settings.execution_mode is not None:
+        return settings
+
+    try:
+        lmax = model_config["backbone"]["lmax"]
+        mmax = model_config["backbone"]["mmax"]
+        UMASFastGPUBackend.validate(lmax, mmax, settings)
+        return replace(settings, execution_mode=ExecutionMode.UMAS_FAST_GPU)
+    except (ValueError, KeyError):
+        return settings
```
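A minimal usage sketch of the new helper, assuming a CUDA machine and a config dict shaped like the key lookups above (the concrete values here are illustrative, not from this diff):

```python
# Illustrative only: how the predictor-side auto-selection behaves.
from fairchem.core.models.uma.nn.execution_backends import (
    maybe_update_settings_backend,
)
from fairchem.core.units.mlip_unit import InferenceSettings

settings = InferenceSettings(merge_mole=True)  # execution_mode left at None
model_config = {"backbone": {"lmax": 2, "mmax": 2}}

# If UMASFastGPUBackend.validate passes, this returns a copy (via
# dataclasses.replace) with execution_mode=ExecutionMode.UMAS_FAST_GPU; any
# ValueError/KeyError along the way (no CUDA, wrong lmax/mmax, missing keys)
# silently falls back to the unchanged settings.
settings = maybe_update_settings_backend(settings, model_config)
```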

src/fairchem/core/units/mlip_unit/api/inference.py

Lines changed: 4 additions & 17 deletions
```diff
@@ -101,10 +101,12 @@ class InferenceSettings:
     # Accepts a torch.dtype or a string in ALLOWED_DTYPES (e.g. "float32").
     base_precision_dtype: torch.dtype | str = torch.float32
 
-    # Execution backend mode for the backbone. The default is "general".
+    # Execution backend mode for the backbone.
+    # Set to "general" for the default execution mode that works across all models and hardware.
     # Set to "umas_fast_pytorch" to enable block-diagonal SO2 GEMM conversion for faster inference.
     # Set to "umas_fast_gpu" to enable highly optimized backend with triton kernels for maximum speed.
-    execution_mode: str = "general"
+    # If None, the predictor will decide the best execution mode based on the model and hardware capabilities (e.g., will choose "umas_fast_gpu" for uma-s if running on compatible Nvidia GPU).
+    execution_mode: str | None = None
 
     # New fields for untrained derivative properties
     # These flags request computation of properties NOT in the checkpoint's task list.
@@ -156,7 +158,6 @@ def inference_settings_default():
         compile=False,
         external_graph_gen=False,
         internal_graph_gen_version=2,
-        execution_mode="general",
     )
 
 
@@ -175,19 +176,6 @@ def inference_settings_turbo():
     )
 
 
-# this setting is specific for UMA-S on cuda for maximum speed.
-def inference_settings_turbo_umas():
-    return InferenceSettings(
-        tf32=True,
-        activation_checkpointing=False,
-        merge_mole=True,
-        compile=True,
-        external_graph_gen=False,
-        internal_graph_gen_version=2,
-        execution_mode="umas_fast_gpu",
-    )
-
-
 # this mode corresponds to the default settings used for training and evaluation
 def inference_settings_traineval():
     return InferenceSettings(
@@ -203,7 +191,6 @@ def inference_settings_traineval():
     "default": inference_settings_default(),
     "turbo": inference_settings_turbo(),
     "traineval": inference_settings_traineval(),
-    "turbo_umas": inference_settings_turbo_umas(),
 }
 
```
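With the `turbo_umas` entry gone from the named-settings registry, `"turbo"` is the single fast preset, and per the commit message the backend choice is deferred to the predictor. A short sketch using the factory functions named in the diff (the registry's variable name is not shown in the hunk, so the dict here is a stand-in):

```python
# Sketch: the registry now resolves only these presets; "turbo_umas" is gone.
from fairchem.core.units.mlip_unit.api.inference import (
    inference_settings_default,
    inference_settings_traineval,
    inference_settings_turbo,
)

presets = {
    "default": inference_settings_default(),
    "turbo": inference_settings_turbo(),
    "traineval": inference_settings_traineval(),
}
# With execution_mode no longer pinned in the factories, it defaults to None
# and the predictor picks the backend (e.g. umas_fast_gpu for UMA-S on CUDA).
```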

src/fairchem/core/units/mlip_unit/predict.py

Lines changed: 14 additions & 6 deletions
```diff
@@ -36,6 +36,9 @@
     setup_env_local_multi_gpu,
 )
 from fairchem.core.datasets.atomic_data import AtomicData, warn_if_upcasting
+from fairchem.core.models.uma.nn.execution_backends import (
+    maybe_update_settings_backend,
+)
 from fairchem.core.units.mlip_unit import InferenceSettings
 from fairchem.core.units.mlip_unit.mlip_unit import OutputSpec, Task
 from fairchem.core.units.mlip_unit.single_atom_patch import (
@@ -114,7 +117,7 @@ def __init__(
         self.inference_settings = inference_settings
         self._setup_threads(inference_settings)
 
-        if inference_settings.wigner_cuda:
+        if self.inference_settings.wigner_cuda:
             logging.warning(
                 "The wigner_cuda flag is deprecated and will be removed in future versions."
             )
@@ -124,16 +127,21 @@ def __init__(
             inference_model_path, map_location="cpu", weights_only=False
         )
 
+        # if the model is uma-s and the execution mode is not explicitly set, default to the optimized uma-s gpu execution mode
+        self.inference_settings = maybe_update_settings_backend(
+            self.inference_settings, checkpoint.model_config
+        )
+
         # Build model-specific overrides
         final_overrides = self._build_overrides_from_settings(
-            checkpoint, overrides, inference_settings
+            checkpoint, overrides, self.inference_settings
        )
 
         # Set default dtype during model construction so that non-persistent
         # buffers (SO3_Grid matrices, CoefficientMapping) are created at the
         # requested precision rather than being cast from float32 later.
         prev_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(inference_settings.base_precision_dtype)
+        torch.set_default_dtype(self.inference_settings.base_precision_dtype)
 
         try:
             # Load model with overrides, passing pre-loaded checkpoint
@@ -151,17 +159,17 @@ def __init__(
 
         # Get backbone's default untrained tasks (if supported and enabled)
         default_backbone_tasks = []
-        if inference_settings.auto_add_default_untrained_tasks:
+        if self.inference_settings.auto_add_default_untrained_tasks:
             backbone = self.model.module.backbone
             if hasattr(backbone, "get_default_untrained_tasks"):
                 default_backbone_tasks = backbone.get_default_untrained_tasks(
                     self.model.module.tasks,
-                    inference_settings,
+                    self.inference_settings,
                 )
 
         # Create explicitly requested untrained tasks
         untrained_tasks = self._create_untrained_tasks(
-            inference_settings, self.model.module.tasks
+            self.inference_settings, self.model.module.tasks
         )
 
         explicit_task_names = {t.name for t in untrained_tasks}
```
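End to end, the backend selection is now invisible to callers of the predict unit. A hedged sketch of the resulting user flow (the `pretrained_mlip.get_predict_unit` entry point and the model name are assumptions consistent with the fairchem docs, not shown in this diff):

```python
# Hedged sketch: a user asks for "turbo"; inside the predict unit's __init__
# the settings pass through maybe_update_settings_backend, which upgrades
# execution_mode to umas_fast_gpu only when the checkpoint's lmax/mmax and
# the hardware qualify. Entry-point and model names are assumptions.
from fairchem.core import pretrained_mlip

predictor = pretrained_mlip.get_predict_unit(
    "uma-s-1p1",
    device="cuda",
    inference_settings="turbo",  # replaces the removed "turbo_umas" preset
)
```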
