|
7 | 7 |
|
8 | 8 | from __future__ import annotations |
9 | 9 |
|
| 10 | +from dataclasses import replace |
10 | 11 | from enum import Enum |
11 | 12 | from typing import TYPE_CHECKING |
12 | 13 |
|
|
15 | 16 | from fairchem.core.models.uma.nn.unified_radial import UnifiedRadialMLP |
16 | 17 |
|
17 | 18 | if TYPE_CHECKING: |
18 | | - from fairchem.core.units.mlip_unit.api.inference import InferenceSettings |
| 19 | + from fairchem.core.units.mlip_unit.api.inference import ( |
| 20 | + InferenceSettings, |
| 21 | + ) |
19 | 22 |
|
20 | 23 | __all__ = [ |
21 | 24 | "ExecutionMode", |
22 | 25 | "ExecutionBackend", |
23 | 26 | "UMASFastPytorchBackend", |
24 | 27 | "UMASFastGPUBackend", |
25 | 28 | "get_execution_backend", |
| 29 | + "maybe_update_settings_backend", |
26 | 30 | ] |
27 | 31 |
|
28 | 32 | # Indices for m=0 spherical harmonic coefficients in L-major ordering (lmax=2) |
@@ -58,18 +62,19 @@ class ExecutionBackend: |
58 | 62 |
|
59 | 63 | @staticmethod |
60 | 64 | def validate( |
61 | | - model: torch.nn.Module, |
62 | | - settings: InferenceSettings | None = None, |
| 65 | + lmax: int, |
| 66 | + mmax: int, |
| 67 | + settings: InferenceSettings, |
63 | 68 | ) -> None: |
64 | 69 | """ |
65 | | - Validate that model and settings are compatible with this backend. |
| 70 | + Validate that model parameters and settings are compatible with this backend. |
66 | 71 |
|
67 | | - Called during model construction (settings=None) and before |
68 | | - first inference (settings provided). |
| 72 | + Called before first inference. |
69 | 73 |
|
70 | 74 | Args: |
71 | | - model: The backbone model to validate. |
72 | | - settings: Inference settings, or None at construction time. |
| 75 | + lmax: Maximum degree of spherical harmonics. |
| 76 | + mmax: Maximum order of spherical harmonics. |
| 77 | + settings: Inference settings. |
73 | 78 |
|
74 | 79 | Raises: |
75 | 80 | ValueError: If incompatible with this backend. |
@@ -265,17 +270,13 @@ class UMASFastPytorchBackend(ExecutionBackend): |
265 | 270 |
|
266 | 271 | @staticmethod |
267 | 272 | def validate( |
268 | | - model: torch.nn.Module, |
269 | | - settings: InferenceSettings | None = None, |
| 273 | + lmax: int, |
| 274 | + mmax: int, |
| 275 | + settings: InferenceSettings, |
270 | 276 | ) -> None: |
271 | 277 | """ |
272 | 278 | Validate that settings are compatible with fast pytorch mode. |
273 | 279 | """ |
274 | | - # Check activation_checkpointing from model (chunk_size is None when disabled) |
275 | | - if model.edge_degree_embedding.activation_checkpoint_chunk_size is not None: |
276 | | - raise ValueError( |
277 | | - "UMASFastPytorchBackend requires activation_checkpointing=False" |
278 | | - ) |
279 | 280 | # Also reject if user tries to enable it via inference settings |
280 | 281 | if settings is not None and settings.activation_checkpointing: |
281 | 282 | raise ValueError( |
@@ -338,15 +339,16 @@ class UMASFastGPUBackend(UMASFastPytorchBackend): |
338 | 339 |
|
    @staticmethod
    def validate(
        lmax: int,
        mmax: int,
        settings: InferenceSettings,
    ) -> None:
        """
        Validate that model parameters and settings are compatible with the
        fast GPU backend.

        Delegates to ``UMASFastPytorchBackend.validate`` first, then applies
        the stricter GPU-specific requirements: CUDA must be available,
        lmax == mmax == 2, and ``merge_mole`` must be enabled.

        Args:
            lmax: Maximum degree of spherical harmonics.
            mmax: Maximum order of spherical harmonics.
            settings: Inference settings.

        Raises:
            ValueError: If incompatible with this backend.
        """
        UMASFastPytorchBackend.validate(lmax, mmax, settings)
        if not torch.cuda.is_available():
            raise ValueError("umas_fast_gpu requires CUDA")
        # The fast GPU kernels are specialized for lmax=2 / mmax=2
        # (see the L-major m=0 index constants earlier in this module).
        if lmax != 2 or mmax != 2:
            raise ValueError("umas_fast_gpu requires lmax==2 and mmax==2")
        if not settings.merge_mole:
            raise ValueError("umas_fast_gpu requires merge_mole=True")
351 | 353 |
|
352 | 354 | @staticmethod |
@@ -446,3 +448,33 @@ def get_execution_backend( |
446 | 448 | available = [m.value for m in _EXECUTION_BACKENDS] |
447 | 449 | raise ValueError(f"Unknown execution mode: {mode}. Available: {available}") |
448 | 450 | return _EXECUTION_BACKENDS[mode]() |
| 451 | + |
| 452 | + |
def maybe_update_settings_backend(
    settings: InferenceSettings,
    model_config: dict,
) -> InferenceSettings:
    """
    Update inference settings to use UMAS_FAST_GPU if conditions are met.

    Sets execution_mode to UMAS_FAST_GPU if:
    - execution_mode is not already set
    - model_config provides backbone lmax/mmax entries
    - UMASFastGPUBackend.validate passes for those parameters and settings

    Args:
        settings: Current inference settings.
        model_config: The model configuration dictionary to validate.

    Returns:
        Updated inference settings with the appropriate execution mode, or
        the original settings unchanged if the fast GPU backend does not
        apply.
    """
    # Respect an explicitly chosen execution mode.
    if settings.execution_mode is not None:
        return settings

    # Scope KeyError handling to the config lookup only, so a KeyError
    # raised inside validate() (which would be a bug) is not silently
    # misread as "config is missing these keys".
    try:
        lmax = model_config["backbone"]["lmax"]
        mmax = model_config["backbone"]["mmax"]
    except KeyError:
        return settings

    # ValueError here just means the model/settings are not eligible for
    # the fast GPU backend; fall back to the default execution mode.
    try:
        UMASFastGPUBackend.validate(lmax, mmax, settings)
    except ValueError:
        return settings

    return replace(settings, execution_mode=ExecutionMode.UMAS_FAST_GPU)
0 commit comments