99
1010import copy
1111from functools import partial
12- from typing import TYPE_CHECKING , Any , Literal , Optional
12+ from typing import TYPE_CHECKING , Any , Literal
1313
1414from ase .atoms import Atoms
1515from monty .dev import requires
@@ -55,11 +55,11 @@ def relax_job(initial_atoms, calc, optimizer_cls, fmax, steps):
5555@requires (data_oc_installed , message = "Requires `fairchem-data-oc` to be installed" )
5656def adsorb_ml_pipeline (
5757 slab : Slab ,
58- adsorbates_kwargs : dict [str , Any ],
58+ adsorbates : list [ Adsorbate | dict [str , Any ] ],
5959 multiple_adsorbate_slab_config_kwargs : dict [str , Any ],
6060 ml_slab_adslab_relax_job : Callable [..., Any ],
6161 reference_ml_energies : bool = True ,
62- atomic_reference_energies : Optional [ dict ] = None ,
62+ atomic_reference_energies : dict | None = None ,
6363 relaxed_slab_atoms : Atoms = None ,
6464 place_on_relaxed_slab : bool = False ,
6565):
@@ -76,8 +76,9 @@ def adsorb_ml_pipeline(
7676 ----------
7777 slab : Slab
7878 The slab structure to which adsorbates will be added.
79- adsorbates_kwargs : dict[str,Any]
80- Keyword arguments for generating adsorbate configurations.
79+ adsorbates : list[Adsorbate | dict[str, Any]]
80+ List of Adsorbate objects or keyword argument dicts for
81+ constructing them.
8182 multiple_adsorbate_slab_config_kwargs : dict[str, Any]
8283 Keyword arguments for generating multiple adsorbate-slab configurations.
8384 ml_slab_adslab_relax_job : Job
@@ -104,7 +105,7 @@ def adsorb_ml_pipeline(
104105
105106 unrelaxed_adslab_configurations = ocp_adslab_generator (
106107 ml_relaxed_slab_result ["atoms" ] if place_on_relaxed_slab else slab .atoms ,
107- adsorbates_kwargs ,
108+ adsorbates ,
108109 multiple_adsorbate_slab_config_kwargs ,
109110 )
110111
@@ -150,7 +151,7 @@ def adsorb_ml_pipeline(
150151@requires (data_oc_installed , message = "Requires `fairchem-data-oc` to be installed" )
151152def ocp_adslab_generator (
152153 slab : Slab | Atoms ,
153- adsorbates_kwargs : list [dict [str , Any ]] | None = None ,
154+ adsorbates : list [Adsorbate | dict [str , Any ]] | None = None ,
154155 multiple_adsorbate_slab_config_kwargs : dict [str , Any ] | None = None ,
155156) -> list [Atoms ]:
156157 """
@@ -160,8 +161,9 @@ def ocp_adslab_generator(
160161 ----------
161162 slab : Slab | Atoms
162163 The slab structure.
163- adsorbates_kwargs : list[dict[str,Any]], optional
164- List of keyword arguments for generating adsorbates, by default None.
164+ adsorbates : list[Adsorbate | dict[str, Any]], optional
165+ List of Adsorbate objects or keyword argument dicts for
166+ constructing them, by default None.
165167 multiple_adsorbate_slab_config_kwargs : dict[str,Any], optional
166168 Keyword arguments for generating multiple adsorbate-slab configurations, by default None.
167169
@@ -171,7 +173,7 @@ def ocp_adslab_generator(
171173 List of generated adsorbate-slab configurations.
172174 """
173175 adsorbates = [
174- Adsorbate (** adsorbate_kwargs ) for adsorbate_kwargs in adsorbates_kwargs
176+ ads if isinstance ( ads , Adsorbate ) else Adsorbate (** ads ) for ads in adsorbates
175177 ]
176178
177179 if isinstance (slab , Atoms ):
@@ -335,15 +337,17 @@ def run_adsorbml(
335337 reference_ml_energies : bool = True ,
336338 relaxed_slab_atoms : Atoms = None ,
337339 place_on_relaxed_slab : bool = False ,
340+ atomic_reference_energies : dict | None = None ,
338341):
339342 """
340343 Run the AdsorbML pipeline for a given slab and adsorbate using a pretrained ML model.
341344 Parameters
342345 ----------
343346 slab : ase.Atoms
344347 The clean slab structure to which the adsorbate will be added.
345- adsorbate : str
346- A string identifier for the adsorbate from the database (e.g., '*O').
348+ adsorbate : str | Adsorbate | list[Adsorbate]
349+ Either a SMILES string identifier for the adsorbate from the database
350+ (e.g., '*O'), a pre-constructed Adsorbate object, or list of Adsorbate objects.
347351 reference_ml_energies : bool, optional
348352 If True, assumes the model is a total energy model and references energies
349353 to gas phase and bare slab, by default True since the default model is a total energy model.
@@ -364,22 +368,30 @@ def run_adsorbml(
364368 energies, and validation results (matching the AdsorbMLSchema format).
365369 """
366370
367- # if we are using a total energy model, we need to set the DFT atomic reference energies
368- # obtained from the supplementary information of the OC20 paper
369- atomic_reference_energies = {
370- "H" : - 3.477 , # eV
371- "O" : - 7.204 , # eV
372- "C" : - 7.282 , # eV
373- "N" : - 8.083 , # eV
374- }
371+ if atomic_reference_energies is None :
372+ # if we are using a total energy model, we need to set the DFT atomic reference energies
373+ # obtained from the supplementary information of the OC20 paper
374+ atomic_reference_energies = {
375+ "H" : - 3.477 , # eV
376+ "O" : - 7.204 , # eV
377+ "C" : - 7.282 , # eV
378+ "N" : - 8.083 , # eV
379+ }
375380
376381 ml_relax_job = partial (
377382 relax_job , calc = calculator , optimizer_cls = optimizer_cls , fmax = fmax , steps = steps
378383 )
379384
385+ if isinstance (adsorbate , str ):
386+ adsorbates = [{"adsorbate_smiles_from_db" : adsorbate }]
387+ elif isinstance (adsorbate , Adsorbate ):
388+ adsorbates = [adsorbate ]
389+ else :
390+ adsorbates = adsorbate
391+
380392 outputs = adsorb_ml_pipeline (
381393 slab = slab ,
382- adsorbates_kwargs = [{ "adsorbate_smiles_from_db" : adsorbate }] ,
394+ adsorbates = adsorbates ,
383395 multiple_adsorbate_slab_config_kwargs = {"num_configurations" : num_placements },
384396 ml_slab_adslab_relax_job = ml_relax_job ,
385397 reference_ml_energies = reference_ml_energies ,
0 commit comments