Skip to content

Commit d1664c9

Browse files
authored
Make FastCSP compatible w/ Genarris3 updates (#1959)
1 parent cb0652f commit d1664c9

14 files changed

Lines changed: 72 additions & 108 deletions

File tree

packages/fairchem-core/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"huggingface_hub>=0.27.1",
2020
"ase>=3.26.0",
2121
"ase-db-backends>=0.10.0",
22-
"monty>=v2025.1.3",
22+
"monty>=2026.2.18",
2323
"clusterscope==0.0.18",
2424
"setuptools<81.0.0",
2525
"requests",
@@ -38,7 +38,7 @@ dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"
3838
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "astroid<4", "umap-learn", "vdict", "ipywidgets", "jupyter_book>=2.0", "torch-dftd"]
3939
adsorbml = ["dscribe", "x3dase", "scikit-image"]
4040
torchsim = ["torch-sim-atomistic>=0.5.2; python_version >= '3.12'"]
41-
extras = ["pymatgen", "quacc[phonons]>=0.15.3", "pandas", "nvalchemi-toolkit-ops==0.2.0", "pyarrow"]
41+
extras = ["pymatgen>=2025.1.9", "quacc[phonons]>=0.15.3", "pandas", "nvalchemi-toolkit-ops==0.2.0", "pyarrow"]
4242

4343
[project.scripts]
4444
fairchem = "fairchem.core._cli:main"

packages/fairchem-demo-ocpapi/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies = [
2222
"requests",
2323
"responses == 0.23.2",
2424
"tenacity == 8.2.3",
25-
"tqdm == 4.66.1",
25+
"tqdm",
2626
]
2727

2828
[project.urls]

src/fairchem/applications/fastcsp/core/configs/example_config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ molecules: configs/example_systems.csv # File containing list of target molecule
77
# Genarris configuration for structure generation
88
genarris:
99
# Execution environment
10-
python_cmd: /path/to/python/with/genarris/installed # Python interpreter with Genarris installed
11-
genarris_script: /path/to/genarris_master.py # Path to Genarris master script
1210
mpi_launcher: /path/to/mpirun # MPI launcher: mpirun, srun, or custom path
13-
base_config: configs/gnrs_base.conf # Base Genarris configuration template
11+
python_cmd: /path/to/python/with/genarris/installed # Python interpreter with Genarris installed
12+
genarris_cli: /path/to/genarris_cli.py # Path to Genarris cli script
13+
genarris_base_config: configs/gnrs_base.conf # Base Genarris configuration template
1414
# Structure generation parameters
1515
vars:
1616
Z: [1, 2, 3, 4, 6, 8] # [1] by default

src/fairchem/applications/fastcsp/core/configs/genarris_base.conf

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,26 @@
66
name = ???
77
molecule_path = ???
88
Z = ???
9-
log_level = debug
10-
downselection_workflow = robust_flow
11-
restart = True
9+
log_level = info
1210

13-
[generation]
14-
num_structures_per_spg = ???
15-
spg_distribution_type = ???
16-
specific_radius_proportion = 0.95
17-
max_attempts_per_spg = 100000000
18-
tol = 0.01
19-
unit_cell_volume_mean = predict
20-
volume_mult = 1.5
21-
max_attempts_per_volume = 10000000
22-
generation_type = crystal
23-
natural_cutoff_mult = 1.2
11+
[workflow]
12+
tasks = ['generation', 'symm_rigid_press']
2413

25-
[symm_rigid_press]
26-
sr = 0.85
27-
method = BFGS
28-
natural_cutoff_mult = 1.2
14+
[generation]
15+
num_structures_per_spg = ???
16+
spg_distribution_type = ???
17+
specific_radius_proportion = 0.95
18+
max_attempts_per_spg = 100000000
2919
tol = 0.01
30-
debug_flag = False
20+
unit_cell_volume_mean = predict
21+
volume_mult = 1.5
22+
max_attempts_per_volume = 10000000
23+
generation_type = crystal
24+
natural_cutoff_mult = 1.2
25+
26+
[symm_rigid_press]
27+
sr = 0.85
28+
method = BFGS
29+
natural_cutoff_mult = 1.2
30+
tol = 0.01
31+
debug_flag = False

src/fairchem/applications/fastcsp/core/utils/configuration.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def validate_config(config: dict[str, Any], stages: list[str]) -> None:
5050
stage_requirements = {
5151
"generate": {
5252
"keys": ["molecules", "genarris"],
53-
"nested": {"genarris": ["python_cmd", "genarris_script", "base_config"]},
53+
"nested": {
54+
"genarris": ["python_cmd", "genarris_cli", "genarris_base_config"]
55+
},
5456
},
5557
"process_generated": {
5658
"keys": ["pre_relaxation_filter"],

src/fairchem/applications/fastcsp/core/workflow/generate.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from __future__ import annotations
2828

2929
import ast
30+
import json
3031
import shutil
3132
from configparser import ConfigParser
3233
from pathlib import Path
@@ -64,7 +65,7 @@ def create_gnrs_submit_script(
6465
gnrs_config: Genarris configuration containing execution parameters:
6566
- mpi_launcher: MPI command to use (default: "mpirun")
6667
- python_cmd: Python executable path (default: "python")
67-
- genarris_script: Genarris main script name (default: "genarris_master.py")
68+
- genarris_cli: Genarris main script name (default: "genarris_cli.py")
6869
genarris_slurm_config: SLURM resource allocation parameters
6970
single_gnrs_folder: Directory where the SLURM script will be created
7071
@@ -83,7 +84,7 @@ def create_gnrs_submit_script(
8384
export OMP_NUM_THREADS=1
8485
8586
{gnrs_config.get("mpi_launcher", "mpirun")} -np {genarris_slurm_config.get("nodes", 1) * genarris_slurm_config["ntasks-per-node"]} \\
86-
{gnrs_config.get("python_cmd", "python")} {gnrs_config.get("genarris_script", "genarris_master.py")} {single_gnrs_folder}/ui.conf > {single_gnrs_folder}/Genarris.out
87+
{gnrs_config.get("python_cmd", "python")} {gnrs_config.get("genarris_cli", "cli.py")} --config {single_gnrs_folder}/ui.conf > {single_gnrs_folder}/Genarris.out
8788
"""
8889

8990
with open(single_gnrs_folder / "slurm.sh", "w") as f:
@@ -96,7 +97,7 @@ def create_gnrs_config(
9697
mol_name: str,
9798
geometry_path: str | Path,
9899
num_structures: int,
99-
spg_info: str,
100+
spg_distribution_type: str | list[int],
100101
Z: int,
101102
):
102103
"""
@@ -112,7 +113,7 @@ def create_gnrs_config(
112113
mol_name: Identifier for the molecule being processed
113114
geometry_path: Path to the molecular geometry file (XYZ, SDF, etc.)
114115
num_structures: Number of crystal structures to generate
115-
spg_info: Space group specification (number or list or "standard")
116+
spg_distribution_type: Space group distribution type ("standard" or custom list[int])
116117
Z: Number of molecules per unit cell (Z-value)
117118
118119
Side Effects:
@@ -123,10 +124,10 @@ def create_gnrs_config(
123124
config.read_file(config_file)
124125

125126
config["master"]["name"] = mol_name
126-
config["master"]["molecule_path"] = str(geometry_path)
127+
config["master"]["molecule_path"] = json.dumps([str(geometry_path)])
127128
config["master"]["Z"] = str(Z)
128129
config["generation"]["num_structures_per_spg"] = str(num_structures)
129-
config["generation"]["spg_distribution_type"] = spg_info
130+
config["generation"]["spg_distribution_type"] = spg_distribution_type
130131

131132
with open(output_dir / "ui.conf", "w") as f:
132133
config.write(f)
@@ -143,9 +144,11 @@ def create_genarris_jobs(
143144
logger = get_central_logger()
144145
logger.info(f"Starting Genarris generation for {mol_info['name']}")
145146

146-
gnrs_base_config = gnrs_config.get("base_config")
147+
gnrs_base_config = gnrs_config.get("genarris_base_config")
147148
if gnrs_base_config is None:
148-
raise KeyError("Genarris 'base_config' section is missing in the config file.")
149+
raise KeyError(
150+
"Genarris 'genarris_base_config' section is missing in the config file."
151+
)
149152
logger.info(f"Using Genarris base configuration: {gnrs_base_config}")
150153

151154
# Parameters for each Genarris run
@@ -159,24 +162,24 @@ def create_genarris_jobs(
159162

160163
z_list = [str(z) for z in gnrs_vars.get("Z", [1])]
161164
num_structures_per_spg = gnrs_vars.get("num_structures_per_spg", 500)
162-
spg_info = gnrs_vars.get("spg_info", "standard")
165+
spg_distribution_type = gnrs_vars.get("spg_distribution_type", "standard")
163166

164167
# molecule specific spg and z_list info from csv file if provided
165168
# mol_info["z"] is a list of z_values
166169
# mol_info["spg"] is a list of lists of spg values that correspond to each z_value
167170
if gnrs_vars.get("read_z_from_file", False):
168171
z_list = [str(z) for z in ast.literal_eval(mol_info["z"])]
169172
if gnrs_vars.get("read_spg_from_file", False):
170-
spg_info = ast.literal_eval(mol_info["spg"])
171-
if spg_info == "standard":
172-
spg_info = ["standard"] * len(z_list)
173-
if isinstance(spg_info[0], int):
174-
spg_info = [spg_info]
175-
if len(spg_info) == 1 and len(z_list) > 1:
176-
spg_info = spg_info * len(z_list)
177-
if len(spg_info) != len(z_list):
173+
spg_distribution_type = ast.literal_eval(mol_info["spg"])
174+
if spg_distribution_type == "standard":
175+
spg_distribution_type = ["standard"] * len(z_list)
176+
if isinstance(spg_distribution_type[0], int):
177+
spg_distribution_type = [spg_distribution_type]
178+
if len(spg_distribution_type) == 1 and len(z_list) > 1:
179+
spg_distribution_type = spg_distribution_type * len(z_list)
180+
if len(spg_distribution_type) != len(z_list):
178181
raise ValueError(
179-
f"Length of spg_info {spg_info} does not match length of z_list {z_list} for molecule {mol_info['name']}."
182+
f"Length of spg_distribution_type {spg_distribution_type} does not match length of z_list {z_list} for molecule {mol_info['name']}."
180183
)
181184
mol = mol_info["name"] # System name
182185

@@ -222,7 +225,7 @@ def create_genarris_jobs(
222225
geometry_path=new_conf_path,
223226
Z=z,
224227
num_structures=num_structures_per_spg,
225-
spg_info=str(spg_info[i]),
228+
spg_distribution_type=str(spg_distribution_type[i]),
226229
)
227230

228231
# Create SLURM submission script if it doesn't exist

src/fairchem/core/common/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,13 @@ def decorator(func):
7575
@wraps(func)
7676
def cls_method(self, *args, **kwargs):
7777
f = func
78-
if self.regress_forces and not getattr(self, "direct_forces", 0):
78+
if hasattr(self, "regress_config"):
79+
regress_forces = self.regress_config.forces
80+
direct_forces = self.regress_config.direct_forces
81+
else:
82+
regress_forces = self.regress_forces
83+
direct_forces = getattr(self, "direct_forces", 0)
84+
if regress_forces and not direct_forces:
7985
f = dec(func)
8086
return f(self, *args, **kwargs)
8187

src/fairchem/core/models/base.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,6 @@ def tasks(self) -> dict[str, Task]:
161161
"""
162162
return self._tasks
163163

164-
@property
165-
def direct_forces(self) -> bool:
166-
"""
167-
Whether this model uses direct force prediction.
168-
"""
169-
return getattr(self.backbone, "direct_forces", False)
170-
171164
@property
172165
def dataset_to_tasks(self) -> dict[str, list]:
173166
"""
@@ -191,8 +184,10 @@ def _validate_task_compatibility(self, task: Task) -> None:
191184
"""
192185
derivative_properties = ("forces", "stress", "hessian")
193186

187+
backbone_regress_config = getattr(self.backbone, "regress_config", None)
194188
if (
195-
self.direct_forces
189+
backbone_regress_config is not None
190+
and backbone_regress_config.direct_forces
196191
and task.inference_only
197192
and task.property in derivative_properties
198193
):

src/fairchem/core/models/uma/escn_md.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -512,18 +512,6 @@ def __init__(
512512
)
513513
self.register_buffer("coefficient_index", coefficient_index, persistent=False)
514514

515-
@property # deprecate this
516-
def direct_forces(self) -> bool:
517-
return self.regress_config.direct_forces
518-
519-
@property # deprecate this
520-
def regress_forces(self) -> bool:
521-
return self.regress_config.forces
522-
523-
@property # deprecate this
524-
def regress_stress(self) -> bool:
525-
return self.regress_config.stress
526-
527515
def balance_channels(
528516
self,
529517
x_message_prime: torch.Tensor,
@@ -903,7 +891,7 @@ def get_default_untrained_tasks(
903891
stress computation requires energy-conserving force computation.
904892
"""
905893
# Direct force models can't compute stress via autograd
906-
if self.direct_forces:
894+
if self.regress_config.direct_forces:
907895
return []
908896

909897
tasks = []
@@ -1065,14 +1053,6 @@ def __init__(
10651053
backbone.force_block = None
10661054
self.regress_config = backbone.regress_config
10671055

1068-
@property
1069-
def regress_forces(self) -> bool:
1070-
return self.regress_config.forces
1071-
1072-
@property
1073-
def regress_stress(self) -> bool:
1074-
return self.regress_config.stress
1075-
10761056
@conditional_grad(torch.enable_grad())
10771057
def forward(
10781058
self, data: AtomicData, emb: dict[str, torch.Tensor]

src/fairchem/core/models/uma/escn_moe.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -449,14 +449,6 @@ def __init__(
449449
self.merged_on_dataset = None
450450
self.non_merged_dataset_names: list[str] = []
451451

452-
@property
453-
def regress_forces(self) -> bool:
454-
return self.regress_config.forces
455-
456-
@property
457-
def regress_stress(self) -> bool:
458-
return self.regress_config.stress
459-
460452
@staticmethod
461453
def _build_expert_mapping(
462454
dataset_names: list[str] | None,
@@ -608,14 +600,6 @@ def __init__(
608600
# keep track if this head has been merged or not
609601
self.merged_on_dataset = None
610602

611-
@property
612-
def regress_forces(self) -> bool:
613-
return self.regress_config.forces
614-
615-
@property
616-
def regress_stress(self) -> bool:
617-
return self.regress_config.stress
618-
619603
def merge_MOLE_model(self, data):
620604
self.merged_on_dataset = data.dataset[0]
621605
self.non_merged_dataset_names = [

0 commit comments

Comments
 (0)