Commit 31d0285: Batched simulation tutorial (#1689)

# Batched atomic simulations with an `InferenceBatcher`

````{admonition} Need to install fairchem-core, get UMA access, or fix permissions/401 errors?
:class: dropdown

1. Install the necessary packages using pip, uv, etc.

```{code-cell} ipython3
:tags: [skip-execution]

! pip install fairchem-core fairchem-data-oc fairchem-applications-cattsunami
```

2. Get access to any necessary huggingface gated models
   * Create and log in to your Huggingface account
   * Request access to https://huggingface.co/facebook/UMA
   * Create a Huggingface token at https://huggingface.co/settings/tokens/ with the permission "Permissions: Read access to contents of all public gated repos you can access"
   * Make the token available by running `huggingface-cli login` or by setting the `HF_TOKEN` environment variable.

```{code-cell} ipython3
:tags: [skip-execution]

# Login using the huggingface-cli utility
! huggingface-cli login

# alternatively, set the token programmatically
import os
os.environ['HF_TOKEN'] = 'MY_TOKEN'
```

````

```{admonition} Experimental feature
:class: note
The `InferenceBatcher` class and the underlying concurrent batching implementation are experimental and under active development. This tutorial provides a basic understanding of the class and its usage, but the API may change. If you have suggestions for improvements, please open an issue or submit a pull request.
```

When running many independent ASE calculations (relaxations, molecular dynamics, etc.) on small to medium-sized systems, you can significantly improve GPU utilization by batching model inference calls together. The `InferenceBatcher` class provides a high-level API to do this with minimal code changes.

The key idea is simple: instead of running each simulation sequentially, `InferenceBatcher` collects inference requests from multiple concurrent simulations and batches them together for more efficient GPU computation.
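The collect-and-batch idea can be sketched with a self-contained toy: worker threads enqueue requests, and a server thread evaluates whatever has accumulated in a single call. Plain Python threads and a squaring function stand in for the simulations and the batched GPU inference here; this is an illustration of the concept, not fairchem's actual implementation.

```python
import queue
import threading
from concurrent.futures import Future, ThreadPoolExecutor

request_q = queue.Queue()


def toy_model(xs):
    # Stand-in for a single batched GPU inference call.
    return [x * x for x in xs]


def serve(n_total):
    # Drain pending requests and evaluate them together, one batch per wake-up.
    served = 0
    while served < n_total:
        batch = [request_q.get()]          # block until the first request arrives
        while not request_q.empty():       # then grab everything else waiting
            batch.append(request_q.get())
        outputs = toy_model([x for x, _ in batch])
        for (_, fut), y in zip(batch, outputs):
            fut.set_result(y)              # hand each result back to its caller
        served += len(batch)


def simulate(x):
    # One "simulation step": enqueue a request and block until its result is ready.
    fut = Future()
    request_q.put((x, fut))
    return fut.result()


server = threading.Thread(target=serve, args=(8,))
server.start()
with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(simulate, range(8)))
server.join()
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]
```

Each `simulate` call blocks on its own future, so the server can group however many requests happen to be waiting into one model call, which is the same trade `InferenceBatcher` makes at the GPU level.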

## Basic setup

To use `InferenceBatcher`, you need to:

1. Create a predict unit as usual
2. Wrap it with `InferenceBatcher`
3. Use `batcher.batch_predict_unit` instead of the original predict unit in your simulation functions

```python
from fairchem.core import pretrained_mlip
from fairchem.core.calculate import FAIRChemCalculator, InferenceBatcher

# Create a predict unit
predict_unit = pretrained_mlip.get_predict_unit("uma-s-1p1")

# Wrap it with InferenceBatcher
batcher = InferenceBatcher(
    predict_unit, concurrency_backend_options=dict(max_workers=32)
)
```

The `max_workers` parameter controls how many simulations can run concurrently.

## Writing simulation functions

The only requirement for using `InferenceBatcher` is to write your simulation logic as a function that takes an `Atoms` object and a predict unit as arguments:

```python
from ase.build import bulk
from ase.filters import FrechetCellFilter
from ase.optimize import LBFGS


def run_relaxation(atoms, predict_unit):
    """Run a structure relaxation and return the final energy."""
    calc = FAIRChemCalculator(predict_unit, task_name="omat")
    atoms.calc = calc
    opt = LBFGS(FrechetCellFilter(atoms), logfile=None)
    opt.run(fmax=0.02, steps=100)
    return atoms.get_potential_energy()
```

## Running batched relaxations

Once you have your simulation function, you can run it in batched mode using the executor's `map` or `submit` methods:

### Using `executor.map`

```python
from functools import partial

import numpy as np
from ase.build import make_supercell

# Create a list of structures to relax
prim_atoms = [
    bulk("Cu"),
    bulk("MgO", "rocksalt", a=4.2),
    bulk("Si", "diamond", a=5.43),
    bulk("NaCl", "rocksalt", a=3.8),
]

atoms_list = [make_supercell(atoms, 3 * np.identity(3)) for atoms in prim_atoms]

for atoms in atoms_list:
    atoms.rattle(0.1)

# Create a partial function with the batch predict unit
run_relaxation_batched = partial(
    run_relaxation, predict_unit=batcher.batch_predict_unit
)

# Run all relaxations in parallel with batched inference
relaxed_energies = list(batcher.executor.map(run_relaxation_batched, atoms_list))
```

### Using `executor.submit` for more control

If you need more control over the execution or want to process results as they complete:

```python
# Create a new list of structures to relax
atoms_list = [make_supercell(atoms, 3 * np.identity(3)) for atoms in prim_atoms]

for atoms in atoms_list:
    atoms.rattle(0.1)

# Submit all jobs
futures = [
    batcher.executor.submit(run_relaxation, atoms, batcher.batch_predict_unit)
    for atoms in atoms_list
]

# Collect results
relaxed_energies = [future.result() for future in futures]
```
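A list comprehension over `future.result()` waits in submission order. If `batcher.executor` follows the standard `concurrent.futures` interface (an assumption based on the `submit`/`map` usage in this tutorial), `as_completed` lets you handle each result, or its exception, as soon as it is ready. Here is a self-contained sketch with a `ThreadPoolExecutor` and a toy task standing in for the batcher and the relaxation:

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_relaxation(i):
    # Toy stand-in for a relaxation: sleep a bit, return an "energy".
    time.sleep(0.01 * (i % 3))
    return -float(i)


executor = ThreadPoolExecutor(max_workers=4)  # stand-in for batcher.executor
futures = {executor.submit(fake_relaxation, i): i for i in range(8)}

results = {}
for future in as_completed(futures):          # yields futures as they finish
    i = futures[future]
    results[i] = future.result()              # raises here if the job failed

executor.shutdown()
print(sorted(results.items()))
```

Mapping each future back to its input (the `futures` dict) is the usual pattern, since completion order generally differs from submission order.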

## Running batched molecular dynamics

The same pattern works for molecular dynamics simulations:

```python
from ase import units
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution


def run_nvt_md(atoms, predict_unit, temperature, traj_fname):
    """Run an NVT molecular dynamics simulation."""
    calc = FAIRChemCalculator(predict_unit, task_name="omat")
    atoms.calc = calc
    MaxwellBoltzmannDistribution(atoms, temperature_K=temperature, force_temp=True)
    dyn = Langevin(
        atoms,
        timestep=2 * units.fs,
        temperature_K=temperature,
        friction=0.1,
        trajectory=traj_fname,
        loginterval=5,
    )
    dyn.run(100)


# Run batched MD simulations
run_md_batched = partial(
    run_nvt_md, predict_unit=batcher.batch_predict_unit, temperature=300
)

futures = [
    batcher.executor.submit(run_md_batched, atoms, traj_fname=f"traj_{i}.traj")
    for i, atoms in enumerate(atoms_list)
]

# Wait for all simulations to complete
[future.result() for future in futures]
```

## When to use an `InferenceBatcher`

`InferenceBatcher` is most beneficial when:

- You are running many independent simulations on small to medium-sized systems
- GPU utilization is low with serial execution
- Each individual simulation involves many inference steps (relaxations, MD)

When running batch inference over static structures, consider using the [batch inference approach](batch_inference.md) with `AtomicData` directly instead. For a single large system, consider using the `MLIPParallelPredictUnit` for graph-parallel inference.
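
A quick way to judge whether concurrency will help your workload is to time a serial loop against the concurrent version on a small sample. The sketch below is self-contained, using a sleep-based stand-in for an inference-bound simulation; with real simulations you would compare a plain `for` loop over your atoms list against `batcher.executor.map`:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_simulation(_):
    # Stand-in for a simulation whose wall time is dominated by waiting
    # on inference, modeled here as a sleep.
    time.sleep(0.05)
    return 0.0


inputs = list(range(8))

t0 = time.perf_counter()
serial = [fake_simulation(x) for x in inputs]       # one at a time
t_serial = time.perf_counter() - t0

t0 = time.perf_counter()
with ThreadPoolExecutor(max_workers=8) as pool:     # stand-in for batcher.executor
    concurrent_ = list(pool.map(fake_simulation, inputs))
t_concurrent = time.perf_counter() - t0

print(f"serial: {t_serial:.2f}s, concurrent: {t_concurrent:.2f}s")
```

If the concurrent time is not meaningfully lower, your systems are probably large enough that a single inference call already saturates the GPU, and batching adds little.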
