Skip to content

Commit c200688

Browse files
authored
Fix virial sign in NPT ensemble, wrap coordinates to unit cell, and support charge/spin configuration (#1546)
* fix: wrap to unit cell * fix: resolve typing error * chore: Add comment * fix: replace torch.from_numpy to torch.as_tensor * test: Add NPT test * feat: Add charge/spin configuration * test: align external pressure * refactor: Replace 'wrap_to_unit_cell' with 'wrap_positions' * chore: Update comment the definition of stress in LAMMPS * test: Reduce NPT test timestep
1 parent a5ab43a commit c200688

4 files changed

Lines changed: 161 additions & 39 deletions

File tree

src/fairchem/lammps/lammps_fc.py

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import torch
1010
from ase.data import atomic_masses, chemical_symbols
11+
from ase.geometry import wrap_positions
1112

1213
from fairchem.core.datasets.atomic_data import AtomicData
1314
from lammps import lammps
@@ -32,16 +33,23 @@ def check_input_script(input_script: str):
3233

3334
def check_atom_id_match_masses(types_arr, masses):
3435
for atom_id in types_arr:
35-
assert np.allclose(
36-
masses[atom_id], atomic_masses[atom_id], atol=1e-1
37-
), f"Atom {chemical_symbols[atom_id]} (type {atom_id}) has mass {masses[atom_id]} but is expected to have mass {atomic_masses[atom_id]}."
36+
assert np.allclose(masses[atom_id], atomic_masses[atom_id], atol=1e-1), (
37+
f"Atom {chemical_symbols[atom_id]} (type {atom_id}) has mass {masses[atom_id]} but is expected to have mass {atomic_masses[atom_id]}."
38+
)
3839

3940

4041
def atomic_data_from_lammps_data(
41-
x, atomic_numbers, nlocal, cell, periodicity, task_name
42+
x: np.ndarray | torch.Tensor,
43+
atomic_numbers,
44+
nlocal,
45+
cell,
46+
periodicity,
47+
task_name,
48+
charge: int = 0,
49+
spin: int = 0,
4250
):
4351
# TODO: do we need to take care of wrapping atoms that are outside the cell?
44-
pos = torch.tensor(x, dtype=torch.float32)
52+
pos = torch.as_tensor(x, dtype=torch.float32)
4553
pbc = torch.tensor(periodicity, dtype=torch.bool).unsqueeze(0)
4654
edge_index = torch.empty((2, 0), dtype=torch.long)
4755
cell_offsets = torch.empty((0, 3), dtype=torch.float32)
@@ -58,8 +66,8 @@ def atomic_data_from_lammps_data(
5866
edge_index=edge_index,
5967
cell_offsets=cell_offsets,
6068
nedges=nedges,
61-
charge=torch.LongTensor([0]),
62-
spin=torch.LongTensor([0]),
69+
charge=torch.LongTensor([charge]),
70+
spin=torch.LongTensor([spin]),
6371
fixed=fixed,
6472
tags=tags,
6573
batch=batch,
@@ -116,7 +124,7 @@ def lookup_atomic_number_by_mass(mass_arr: np.ndarray | float) -> np.ndarray | i
116124
return atomic_numbers
117125

118126

119-
def separate_run_commands(input_script: str) -> str:
127+
def separate_run_commands(input_script: str) -> tuple[list[str], list[str]]:
120128
lines = input_script.splitlines()
121129
run_cmds = []
122130
script = []
@@ -145,52 +153,74 @@ def cell_from_lammps_box(boxlo, boxhi, xy, yz, xz):
145153
return unit_cell_matrix.unsqueeze(0)
146154

147155

148-
def fix_external_call_back(lmp, ntimestep, nlocal, tag, x, f):
149-
# force copy here, otherwise we can accidentally modify the original array in lammps
150-
# TODO: only need to get atomic numbers once and cache it?
151-
# is there a way to check atom types are mapped correctly?
152-
atom_type_np = lmp.numpy.extract_atom("type")
153-
masses = lmp.numpy.extract_atom("mass")
154-
atomic_mass_arr = masses[atom_type_np]
155-
atomic_numbers = lookup_atomic_number_by_mass(atomic_mass_arr)
156-
boxlo, boxhi, xy, yz, xz, periodicity, box_change = lmp.extract_box()
157-
cell = cell_from_lammps_box(boxlo, boxhi, xy, yz, xz)
158-
atomic_data = atomic_data_from_lammps_data(
159-
x, atomic_numbers, nlocal, cell, periodicity, lmp._task_name
160-
)
161-
results = lmp._predictor.predict(atomic_data)
162-
assert "forces" in results, "forces must be in results"
163-
f[:] = results["forces"].cpu().numpy()[:]
164-
lmp.fix_external_set_energy_global(FIX_EXT_ID, results["energy"].item())
165-
166-
# during NPT for example, box_change should be set to 1 by lammps to allow the cell to change
167-
if box_change:
168-
# stress is defined as virial/volume in lammps
169-
assert "stress" in results, "stress must be in results to compute virial"
170-
volume = torch.det(cell).abs().item()
171-
v = (results["stress"].cpu() * volume)[0]
172-
# virials need to be in this order: xx, yy, zz, xy, xz, yz. https://docs.lammps.org/Library_utility.html#_CPPv437lammps_fix_external_set_virial_globalPvPKcPd
173-
virial_arr = [v[0], v[4], v[8], v[1], v[2], v[5]]
174-
lmp.fix_external_set_virial_global(FIX_EXT_ID, virial_arr)
156+
class FixExternalCallback:
    """Callable handed to LAMMPS ``fix external`` that evaluates the predictor.

    Each call reads the current atom types, masses and box from ``lmp``,
    wraps the positions into the cell, runs ``lmp._predictor`` on the
    resulting ``AtomicData`` and hands forces, energy and (when the box may
    change) the virial back to LAMMPS.

    Args:
        charge: total charge forwarded to ``atomic_data_from_lammps_data``.
        spin: spin value forwarded to ``atomic_data_from_lammps_data``.
    """

    def __init__(self, charge: int = 0, spin: int = 0):
        self.charge = charge
        self.spin = spin

    @staticmethod
    def _virial_from_stress(stress: torch.Tensor, volume: float) -> list:
        """Convert a predicted stress into the six-component LAMMPS virial.

        The stress of the first system in the batch is used and flattened
        row-major, so both ``(1, 9)`` and ``(1, 3, 3)`` layouts are accepted.

        Returns:
            ``[xx, yy, zz, xy, xz, yz]`` as plain Python floats.
        """
        # stress is defined as -virial/volume in lammps
        flat = (-stress.detach().cpu() * volume)[0].reshape(-1).tolist()
        # virials need to be in this order: xx, yy, zz, xy, xz, yz. https://docs.lammps.org/Library_utility.html#_CPPv437lammps_fix_external_set_virial_globalPvPKcPd
        return [flat[0], flat[4], flat[8], flat[1], flat[2], flat[5]]

    def __call__(self, lmp, ntimestep, nlocal, tag, x, f):
        """Compute forces/energy/virial for the current LAMMPS configuration.

        Args:
            lmp: LAMMPS instance carrying ``_predictor`` and ``_task_name``.
            ntimestep: current timestep (unused here).
            nlocal: number of local atoms, forwarded to the data builder.
            tag: atom tags supplied by LAMMPS (unused here).
            x: atom positions supplied by LAMMPS; not written to here.
            f: force array owned by LAMMPS; overwritten in place.

        Raises:
            AssertionError: if the predictor returns no ``"forces"``, or no
                ``"stress"`` while LAMMPS reports a changing box.
        """
        # TODO: only need to get atomic numbers once and cache it?
        # is there a way to check atom types are mapped correctly?
        atom_type_np = lmp.numpy.extract_atom("type")
        masses = lmp.numpy.extract_atom("mass")
        atomic_numbers = lookup_atomic_number_by_mass(masses[atom_type_np])
        boxlo, boxhi, xy, yz, xz, periodicity, box_change = lmp.extract_box()
        cell = cell_from_lammps_box(boxlo, boxhi, xy, yz, xz)

        # Wrap into the unit cell; the wrapped coordinates are kept in a
        # separate variable so the array handed over by LAMMPS stays untouched.
        # NOTE(review): eps=0 presumably disables ASE's small default shift
        # near the cell boundary -- confirm against ase.geometry.wrap_positions.
        x_wrapped = wrap_positions(
            x, cell=cell.squeeze().numpy(), pbc=periodicity, eps=0
        )

        atomic_data = atomic_data_from_lammps_data(
            x_wrapped,
            atomic_numbers,
            nlocal,
            cell,
            periodicity,
            lmp._task_name,
            charge=self.charge,
            spin=self.spin,
        )
        results = lmp._predictor.predict(atomic_data)
        assert "forces" in results, "forces must be in results"
        # detach first: .numpy() raises on tensors that still require grad
        f[:] = results["forces"].detach().cpu().numpy()
        lmp.fix_external_set_energy_global(FIX_EXT_ID, results["energy"].item())

        # during NPT for example, box_change should be set to 1 by lammps to allow the cell to change
        if box_change:
            assert "stress" in results, "stress must be in results to compute virial"
            volume = torch.det(cell).abs().item()
            virial_arr = self._virial_from_stress(results["stress"], volume)
            lmp.fix_external_set_virial_global(FIX_EXT_ID, virial_arr)
175200

176201

177202
def run_lammps_with_fairchem(
178-
predictor: MLIPPredictUnitProtocol, lammps_input_path: str, task_name: str
203+
predictor: MLIPPredictUnitProtocol,
204+
lammps_input_path: str,
205+
task_name: str,
206+
charge: int = 0,
207+
spin: int = 0,
179208
):
180209
machine = None
181210
if "LAMMPS_MACHINE_NAME" in os.environ:
182211
machine = os.environ["LAMMPS_MACHINE_NAME"]
183212
lmp = lammps(name=machine, cmdargs=["-nocite", "-log", "none", "-echo", "screen"])
184213
lmp._predictor = predictor
185214
lmp._task_name = task_name
186-
run_cmds = []
215+
# run_cmds = []
187216
with open(lammps_input_path) as f:
188217
input_script = f.read()
189218
check_input_script(input_script)
190219
script, run_cmds = separate_run_commands(input_script)
191220
logging.info(f"Running input script: {input_script}")
192221
lmp.commands_list(script)
193222
lmp.command(FIX_EXTERNAL_CMD)
223+
fix_external_call_back = FixExternalCallback(charge=charge, spin=spin)
194224
lmp.set_fix_external_callback(FIX_EXT_ID, fix_external_call_back, lmp)
195225
lmp.commands_list(run_cmds)
196226
return lmp
@@ -203,7 +233,13 @@ def run_lammps_with_fairchem(
203233
)
204234
def main(cfg: DictConfig):
205235
predict_unit = hydra.utils.instantiate(cfg.predict_unit)
206-
lmp = run_lammps_with_fairchem(predict_unit, cfg.lmp_in, cfg.task_name)
236+
lmp = run_lammps_with_fairchem(
237+
predict_unit,
238+
cfg.lmp_in,
239+
cfg.task_name,
240+
cfg.charge,
241+
cfg.spin,
242+
)
207243
# this is required to cleanup the predictor
208244
del lmp._predictor
209245

src/fairchem/lammps/lammps_fc_config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@ parallel_predict_unit:
2929
predict_unit: ${local_predict_unit}
3030
lmp_in: "lammps_in_example.file"
3131
task_name: "omol"
32+
charge: 0
33+
spin: 0

tests/lammps/lammps_npt.file

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
units metal # Use metal units (Angstroms, eV, ps)
2+
atom_style atomic # Atoms have a single type and position
3+
lattice fcc 3.567
4+
boundary p p p
5+
6+
region simbox block 0 2 0 2 0 2
7+
create_box 1 simbox
8+
create_atoms 1 region simbox
9+
mass 1 12.011
10+
11+
velocity all create 300.0 12345 dist gaussian # Set initial velocities at 300 K
12+
13+
timestep 0.001 # 1 fs
14+
# Use NPT (isotropic) thermostat+barostat: target temp 300 K, target pressure 0 bar
15+
# fix npt syntax: fix ID group-ID npt temp Tstart Tstop Tdamp iso Pstart Pstop Pdamp
16+
fix 1 all npt temp 300.0 300.0 0.1 iso 0.0 0.0 1.0
17+
thermo_style custom step temp pe ke etotal press vol
18+
thermo 1
19+
run 100

tests/lammps/test_ase_vs_lammps.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from ase import units
66
from ase.build import bulk
77
from ase.md.langevin import Langevin
8+
from ase.md.nose_hoover_chain import IsotropicMTKNPT
89
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
910
from ase.md.verlet import VelocityVerlet
10-
from fairchem.lammps.lammps_fc import run_lammps_with_fairchem
1111

1212
from fairchem.core import FAIRChemCalculator
1313
from fairchem.core.calculate import pretrained_mlip
14+
from fairchem.lammps.lammps_fc import run_lammps_with_fairchem
1415

1516

1617
def run_ase_langevin():
@@ -74,6 +75,62 @@ def print_thermo(a=atoms):
7475
return atoms.get_kinetic_energy(), atoms.get_potential_energy()
7576

7677

78+
def run_ase_npt():
    """Run a short NPT trajectory of fcc carbon in ASE and return its energies.

    The setup mirrors ``tests/lammps/lammps_npt.file``: a 2x2x2 repeat of the
    cubic fcc cell (a = 3.567), velocities drawn at 300 K, a 1 fs timestep,
    zero target pressure, 0.1 ps / 1.0 ps thermostat / barostat damping and
    100 steps. The dynamics are integrated with ASE's ``IsotropicMTKNPT`` and
    the ``uma-s-1p1`` predictor on ``cuda`` with the ``omat`` task.

    Returns:
        ``(kinetic_energy, potential_energy)`` of the final frame, as given
        by ``atoms.get_kinetic_energy()`` and ``atoms.get_potential_energy()``.
    """
    atoms = bulk("C", "fcc", a=3.567, cubic=True)
    atoms = atoms.repeat((2, 2, 2))
    predictor = pretrained_mlip.get_predict_unit("uma-s-1p1", device="cuda")
    atoms.calc = FAIRChemCalculator(predictor, task_name="omat")
    initial_temperature_K = 300.0
    np.random.seed(12345)
    MaxwellBoltzmannDistribution(atoms, temperature_K=initial_temperature_K)

    # Damping times chosen to match Tdamp / Pdamp of the LAMMPS `fix npt`
    # line in the companion input file (LAMMPS metal units: ps).
    tdamp = 0.1  # ps
    pdamp = 1.0  # ps

    # ps -> fs (x1000), then into ASE's internal time unit via units.fs.
    tdamp_ase = tdamp * 1000.0 * units.fs
    pdamp_ase = pdamp * 1000.0 * units.fs

    # Timestep and target pressure are likewise expressed through ase.units.
    dyn = IsotropicMTKNPT(
        atoms,
        timestep=1.0 * units.fs,
        temperature_K=300,
        pressure_au=0.0 * units.bar,
        tdamp=tdamp_ase,
        pdamp=pdamp_ase,
    )

    def print_thermo(a=atoms):
        # Per-step thermo line; temperature from Ekin = 3/2 N kB T.
        ekin = a.get_kinetic_energy()
        epot = a.get_potential_energy()
        etot = ekin + epot
        temp = ekin / (1.5 * units.kB) / len(a)
        vol = a.get_volume()
        print(
            f"Step: {dyn.get_number_of_steps()}, Temp: {temp:.2f} K, "
            f"Ekin: {ekin:.4f} eV, Epot: {epot:.4f} eV, Etot: {etot:.4f} eV, Vol: {vol:.4f} Å^3"
        )

    dyn.attach(print_thermo, interval=1)
    dyn.run(100)
    return atoms.get_kinetic_energy(), atoms.get_potential_energy()
132+
133+
77134
def run_lammps(input_file):
78135
predictor = pretrained_mlip.get_predict_unit("uma-s-1p1", device="cuda")
79136
lmp = run_lammps_with_fairchem(predictor, input_file, "omat")
@@ -88,6 +145,14 @@ def test_ase_vs_lammps_nve():
88145
assert np.isclose(ase_pot, lammps_pot, rtol=0.1)
89146

90147

148+
@pytest.mark.gpu()
def test_ase_vs_lammps_npt():
    """Check that ASE and LAMMPS NPT runs end with similar energies.

    Both final kinetic and potential energies must agree within a relative
    tolerance of 0.5.
    """
    ke_ase, pe_ase = run_ase_npt()
    ke_lmp, pe_lmp = run_lammps("tests/lammps/lammps_npt.file")

    kinetic_matches = np.isclose(ke_ase, ke_lmp, rtol=0.5)
    assert kinetic_matches
    potential_matches = np.isclose(pe_ase, pe_lmp, rtol=0.5)
    assert potential_matches
154+
155+
91156
@pytest.mark.xfail(
92157
reason="This is more demo purposes, need to configure the right parameters for ASE langevin to match lammps"
93158
)

0 commit comments

Comments
 (0)