Skip to content

Commit c03f9cb

Browse files
ansobolev and Andrei Sobolev authored
Fix test_ratio = 0 case (#478)
* Add safeguard when test_ratio = 0
* Change `test_ratio` in failed tests
* Add `disable_testing` setting and sanity checks
* Add `disable_testing` logic to all jobs except fitting
* Changed `test_ratio` back to 0 in tests
* Added `disable_testing` to fitting jobs
* Test fixes
---------
Co-authored-by: Andrei Sobolev <sobolev@ms1p.org>
1 parent cd1a096 commit c03f9cb

10 files changed

Lines changed: 149 additions & 29 deletions

File tree

src/autoplex/auto/rss/flows.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""RSS (random structure searching) flow for exploring and learning potential energy surfaces from scratch."""
22

3+
import logging
34
from dataclasses import dataclass, field
45

56
from atomate2.forcefields.jobs import ForceFieldStaticMaker
@@ -12,6 +13,10 @@
1213
from autoplex.misc.castep.jobs import CastepStaticMaker
1314
from autoplex.settings import RssConfig
1415

16+
logging.basicConfig(
17+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
18+
)
19+
1520

1621
@dataclass
1722
class RssMaker(Maker):
@@ -309,6 +314,19 @@ def make(self, **kwargs):
309314
"'train_from_scratch' must be set in the configuration file or passed as a keyword argument!!"
310315
)
311316

317+
if config_params["disable_testing"] and config_params["test_ratio"] != 0.0:
318+
logging.warning("Testing disabled. Setting test_ratio to 0.0.")
319+
config_params["test_ratio"] = 0.0
320+
321+
if (
322+
config_params["train_from_scratch"]
323+
and config_params["test_ratio"] == 0.0
324+
and not config_params["disable_testing"]
325+
):
326+
raise ValueError(
327+
"A prebuilt test set should be present if testing is enabled and `test_ratio` is set to 0."
328+
)
329+
312330
rss_flow = []
313331

314332
if config_params["train_from_scratch"]:

src/autoplex/auto/rss/jobs.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,7 @@ def initial_rss(
8585
dft_ref_file: str = "dft_ref.extxyz",
8686
rss_group: str = "initial",
8787
test_ratio: float = 0.1,
88+
disable_testing: bool = False,
8889
regularization: bool = False,
8990
retain_existing_sigma: bool = False,
9091
scheme: str | None = None,
@@ -171,6 +172,8 @@ def initial_rss(
171172
test_ratio: float
172173
The proportion of the test set after splitting the data.
173174
If None, no splitting will be performed. Default is 0.1.
175+
disable_testing: bool
176+
Whether to disable running the model on test data. Default is False.
174177
regularization: bool
175178
If true, apply regularization. This only works for GAP. Default is False.
176179
retain_existing_sigma: bool
@@ -274,6 +277,7 @@ def initial_rss(
274277
)
275278
do_data_preprocessing = preprocess_data(
276279
test_ratio=test_ratio,
280+
disable_testing=disable_testing,
277281
regularization=regularization,
278282
retain_existing_sigma=retain_existing_sigma,
279283
scheme=scheme,
@@ -295,6 +299,7 @@ def initial_rss(
295299
apply_data_preprocessing=False,
296300
auto_delta=auto_delta,
297301
glue_xml=False,
302+
disable_testing=disable_testing,
298303
).make(
299304
isolated_atom_energies=do_data_collection.output["isolated_atom_energies"],
300305
database_dir=do_data_preprocessing.output,
@@ -352,6 +357,7 @@ def do_rss_iterations(
352357
dft_ref_file: str = "dft_ref.extxyz",
353358
rss_group: str = "rss",
354359
test_ratio: float = 0.1,
360+
disable_testing: bool = False,
355361
regularization: bool = False,
356362
retain_existing_sigma: bool = False,
357363
scheme: str | None = None,
@@ -479,6 +485,8 @@ def do_rss_iterations(
479485
Group name for GAP RSS. Default is 'rss'.
480486
test_ratio: float
481487
The proportion of the test set after splitting the data. Default is 0.1.
488+
disable_testing: bool
489+
Whether to disable running the model on test data. Default is False.
482490
regularization: bool
483491
If true, apply regularization. This only works for GAP. Default is False.
484492
retain_existing_sigma: bool
@@ -674,6 +682,7 @@ def do_rss_iterations(
674682
)
675683
do_data_preprocessing = preprocess_data(
676684
test_ratio=test_ratio,
685+
disable_testing=disable_testing,
677686
regularization=regularization,
678687
retain_existing_sigma=retain_existing_sigma,
679688
scheme=scheme,
@@ -695,6 +704,7 @@ def do_rss_iterations(
695704
apply_data_preprocessing=False,
696705
auto_delta=auto_delta,
697706
glue_xml=False,
707+
disable_testing=disable_testing,
698708
).make(
699709
database_dir=do_data_preprocessing.output,
700710
isolated_atom_energies=input["isolated_atom_energies"],
@@ -744,6 +754,7 @@ def do_rss_iterations(
744754
dft_ref_file=dft_ref_file,
745755
rss_group=rss_group,
746756
test_ratio=test_ratio,
757+
disable_testing=disable_testing,
747758
regularization=regularization,
748759
retain_existing_sigma=retain_existing_sigma,
749760
scheme=scheme,

src/autoplex/data/common/jobs.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -745,6 +745,7 @@ def safe_strip_hostname(value):
745745
def preprocess_data(
746746
dft_ref_dir: str,
747747
test_ratio: float | None = None,
748+
disable_testing: bool = False,
748749
regularization: bool = False,
749750
retain_existing_sigma: bool = False,
750751
scheme: str = "linear-hull",
@@ -758,7 +759,7 @@ def preprocess_data(
758759
isolated_atom_energies: dict | None = None,
759760
) -> Path:
760761
"""
761-
Preprocesse data to before fiting machine learning models.
762+
Preprocess data to before fitting machine learning models.
762763
763764
This function handles tasks such as splitting the dataset,
764765
applying regularization, accumulating database, and filtering
@@ -771,6 +772,8 @@ def preprocess_data(
771772
test_ratio: float
772773
The proportion of the test set after splitting the data.
773774
If None, no splitting will be performed.
775+
disable_testing: bool
776+
Whether to disable running the model on test data. Default is False.
774777
regularization: bool
775778
If true, apply regularization. This only works for GAP.
776779
retain_existing_sigma: bool
@@ -812,14 +815,20 @@ def preprocess_data(
812815
)
813816

814817
if test_ratio == 0 or test_ratio is None:
815-
train_structures, test_structures = atoms, atoms
818+
train_structures, test_structures = atoms, []
816819
else:
817820
train_structures, test_structures = stratified_dataset_split(
818821
atoms, test_ratio, energy_label
819822
)
820823

821824
if pre_database_dir and os.path.exists(pre_database_dir):
822-
files_to_copy = ["train.extxyz", "test.extxyz"]
825+
files_to_copy = [
826+
"train.extxyz",
827+
]
828+
if not disable_testing:
829+
files_to_copy += [
830+
"test.extxyz",
831+
]
823832
current_working_directory = os.getcwd()
824833

825834
for file_name in files_to_copy:
@@ -830,7 +839,8 @@ def preprocess_data(
830839
print(f"File {file_name} has been copied to {destination_file_path}")
831840

832841
write("train.extxyz", train_structures, format="extxyz", append=True)
833-
write("test.extxyz", test_structures, format="extxyz", append=True)
842+
if not disable_testing:
843+
write("test.extxyz", test_structures, format="extxyz", append=True)
834844

835845
if regularization:
836846
atoms_reg: list[Atoms] = read("train.extxyz", index=":")

src/autoplex/fitting/common/flows.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,8 @@ class MLIPFitMaker(Maker):
8686
Determine whether to preprocess the data.
8787
run_fits_on_different_cluster: bool
8888
If true, run fits on different clusters.
89+
disable_testing: bool
90+
Whether to disable running the model on test data.
8991
"""
9092

9193
name: str = "MLpotentialFit"
@@ -110,6 +112,7 @@ class MLIPFitMaker(Maker):
110112
num_processes_fit: int | None = None
111113
apply_data_preprocessing: bool = True
112114
run_fits_on_different_cluster: bool = False
115+
disable_testing: bool = False
113116

114117
def make(
115118
self,
@@ -188,6 +191,7 @@ def make(
188191
device=device,
189192
species_list=species_list,
190193
database_dict=data_prep_job.output["database_dict"],
194+
disable_testing=self.disable_testing,
191195
**fit_kwargs,
192196
)
193197
jobs.append(mlip_fit_job)
@@ -221,6 +225,7 @@ def make(
221225
ref_virial_name=self.ref_virial_name,
222226
device=device,
223227
species_list=species_list,
228+
disable_testing=self.disable_testing,
224229
**fit_kwargs,
225230
)
226231

src/autoplex/fitting/common/jobs.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ def machine_learning_fit(
3838
database_dict: dict | None = None,
3939
hyperpara_opt: bool = False,
4040
hyperparameters: MLIP_HYPERS = MLIP_HYPERS,
41+
disable_testing: bool = False,
4142
**fit_kwargs,
4243
):
4344
"""
@@ -84,6 +85,8 @@ def machine_learning_fit(
8485
run_fits_on_different_cluster: bool
8586
Indicates if fits are to be run on a different cluster.
8687
If True, the fitting data (train.extxyz, test.extxyz) is stored in the database.
88+
disable_testing: bool
89+
Whether to disable running the model on test data. Default is False.
8790
fit_kwargs: dict
8891
Additional keyword arguments for MLIP fitting.
8992
"""
@@ -125,8 +128,9 @@ def machine_learning_fit(
125128
if mlip_type == "GAP":
126129
for train_name, test_name in zip(train_files, test_files):
127130
if (database_dir / train_name).exists() and (
128-
database_dir / test_name
129-
).exists():
131+
(database_dir / test_name).exists() or disable_testing
132+
):
133+
130134
train_test_error = gap_fitting(
131135
db_dir=database_dir,
132136
hyperparameters=hyperparameters.GAP,
@@ -140,6 +144,7 @@ def machine_learning_fit(
140144
ref_virial_name=ref_virial_name,
141145
train_name=train_name,
142146
test_name=test_name,
147+
disable_testing=disable_testing,
143148
fit_kwargs=fit_kwargs,
144149
)
145150
mlip_paths.append(train_test_error["mlip_path"])
@@ -153,6 +158,7 @@ def machine_learning_fit(
153158
ref_force_name=ref_force_name,
154159
ref_virial_name=ref_virial_name,
155160
num_processes_fit=num_processes_fit,
161+
disable_testing=disable_testing,
156162
fit_kwargs=fit_kwargs,
157163
)
158164
mlip_paths.append(train_test_error["mlip_path"])
@@ -169,6 +175,7 @@ def machine_learning_fit(
169175
ref_virial_name=ref_virial_name,
170176
species_list=species_list,
171177
gpu_identifier_indices=gpu_identifier_indices,
178+
disable_testing=disable_testing,
172179
fit_kwargs=fit_kwargs,
173180
)
174181

@@ -182,6 +189,7 @@ def machine_learning_fit(
182189
ref_energy_name=ref_energy_name,
183190
ref_force_name=ref_force_name,
184191
ref_virial_name=ref_virial_name,
192+
disable_testing=disable_testing,
185193
fit_kwargs=fit_kwargs,
186194
device=device,
187195
)
@@ -194,6 +202,7 @@ def machine_learning_fit(
194202
ref_energy_name=ref_energy_name,
195203
ref_force_name=ref_force_name,
196204
ref_virial_name=ref_virial_name,
205+
disable_testing=disable_testing,
197206
fit_kwargs=fit_kwargs,
198207
device=device,
199208
)
@@ -211,7 +220,12 @@ def machine_learning_fit(
211220
)
212221
mlip_paths.append(train_test_error["mlip_path"])
213222

214-
check_conv = check_convergence(train_test_error["test_error"])
223+
error = (
224+
train_test_error["train_error"]
225+
if disable_testing
226+
else train_test_error["test_error"]
227+
)
228+
check_conv = check_convergence(error)
215229

216230
return {
217231
"mlip_path": mlip_paths,

0 commit comments

Comments
 (0)