88import numpy as np
99import torch
1010from ase .data import atomic_masses , chemical_symbols
11+ from ase .geometry import wrap_positions
1112
1213from fairchem .core .datasets .atomic_data import AtomicData
1314from lammps import lammps
@@ -32,16 +33,23 @@ def check_input_script(input_script: str):
3233
3334def check_atom_id_match_masses (types_arr , masses ):
3435 for atom_id in types_arr :
35- assert np .allclose (
36- masses [atom_id ], atomic_masses [atom_id ], atol = 1e-1
37- ), f"Atom { chemical_symbols [ atom_id ] } (type { atom_id } ) has mass { masses [ atom_id ] } but is expected to have mass { atomic_masses [ atom_id ] } ."
36+ assert np .allclose (masses [ atom_id ], atomic_masses [ atom_id ], atol = 1e-1 ), (
37+ f"Atom { chemical_symbols [ atom_id ] } (type { atom_id } ) has mass { masses [atom_id ]} but is expected to have mass { atomic_masses [atom_id ]} ."
38+ )
3839
3940
4041def atomic_data_from_lammps_data (
41- x , atomic_numbers , nlocal , cell , periodicity , task_name
42+ x : np .ndarray | torch .Tensor ,
43+ atomic_numbers ,
44+ nlocal ,
45+ cell ,
46+ periodicity ,
47+ task_name ,
48+ charge : int = 0 ,
49+ spin : int = 0 ,
4250):
4351 # TODO: do we need to take of care of wrapping atoms that are outside the cell?
44- pos = torch .tensor (x , dtype = torch .float32 )
52+ pos = torch .as_tensor (x , dtype = torch .float32 )
4553 pbc = torch .tensor (periodicity , dtype = torch .bool ).unsqueeze (0 )
4654 edge_index = torch .empty ((2 , 0 ), dtype = torch .long )
4755 cell_offsets = torch .empty ((0 , 3 ), dtype = torch .float32 )
@@ -58,8 +66,8 @@ def atomic_data_from_lammps_data(
5866 edge_index = edge_index ,
5967 cell_offsets = cell_offsets ,
6068 nedges = nedges ,
61- charge = torch .LongTensor ([0 ]),
62- spin = torch .LongTensor ([0 ]),
69+ charge = torch .LongTensor ([charge ]),
70+ spin = torch .LongTensor ([spin ]),
6371 fixed = fixed ,
6472 tags = tags ,
6573 batch = batch ,
@@ -116,7 +124,7 @@ def lookup_atomic_number_by_mass(mass_arr: np.ndarray | float) -> np.ndarray | i
116124 return atomic_numbers
117125
118126
119- def separate_run_commands (input_script : str ) -> str :
127+ def separate_run_commands (input_script : str ) -> tuple [ list [ str ], list [ str ]] :
120128 lines = input_script .splitlines ()
121129 run_cmds = []
122130 script = []
@@ -145,52 +153,74 @@ def cell_from_lammps_box(boxlo, boxhi, xy, yz, xz):
145153 return unit_cell_matrix .unsqueeze (0 )
146154
147155
148- def fix_external_call_back (lmp , ntimestep , nlocal , tag , x , f ):
149- # force copy here, otherwise we can accident modify the original array in lammps
150- # TODO: only need to get atomic numbers once and cache it?
151- # is there a way to check atom types are mapped correctly?
152- atom_type_np = lmp .numpy .extract_atom ("type" )
153- masses = lmp .numpy .extract_atom ("mass" )
154- atomic_mass_arr = masses [atom_type_np ]
155- atomic_numbers = lookup_atomic_number_by_mass (atomic_mass_arr )
156- boxlo , boxhi , xy , yz , xz , periodicity , box_change = lmp .extract_box ()
157- cell = cell_from_lammps_box (boxlo , boxhi , xy , yz , xz )
158- atomic_data = atomic_data_from_lammps_data (
159- x , atomic_numbers , nlocal , cell , periodicity , lmp ._task_name
160- )
161- results = lmp ._predictor .predict (atomic_data )
162- assert "forces" in results , "forces must be in results"
163- f [:] = results ["forces" ].cpu ().numpy ()[:]
164- lmp .fix_external_set_energy_global (FIX_EXT_ID , results ["energy" ].item ())
165-
166- # during NPT for example, box_change should be set to 1 by lammps to allow the cell to change
167- if box_change :
168- # stress is defined as virial/volume in lammps
169- assert "stress" in results , "stress must be in results to compute virial"
170- volume = torch .det (cell ).abs ().item ()
171- v = (results ["stress" ].cpu () * volume )[0 ]
172- # virials need to be in this order: xx, yy, zz, xy, xz, yz. https://docs.lammps.org/Library_utility.html#_CPPv437lammps_fix_external_set_virial_globalPvPKcPd
173- virial_arr = [v [0 ], v [4 ], v [8 ], v [1 ], v [2 ], v [5 ]]
174- lmp .fix_external_set_virial_global (FIX_EXT_ID , virial_arr )
156+ class FixExternalCallback :
157+ def __init__ (self , charge : int = 0 , spin : int = 0 ):
158+ self .charge = charge
159+ self .spin = spin
160+
161+ def __call__ (self , lmp , ntimestep , nlocal , tag , x , f ):
162+ # force copy here, otherwise we can accident modify the original array in lammps
163+ # TODO: only need to get atomic numbers once and cache it?
164+ # is there a way to check atom types are mapped correctly?
165+ atom_type_np = lmp .numpy .extract_atom ("type" )
166+ masses = lmp .numpy .extract_atom ("mass" )
167+ atomic_mass_arr = masses [atom_type_np ]
168+ atomic_numbers = lookup_atomic_number_by_mass (atomic_mass_arr )
169+ boxlo , boxhi , xy , yz , xz , periodicity , box_change = lmp .extract_box ()
170+ cell = cell_from_lammps_box (boxlo , boxhi , xy , yz , xz )
171+
172+ x_wrapped = wrap_positions (
173+ x , cell = cell .squeeze ().numpy (), pbc = periodicity , eps = 0
174+ )
175+
176+ atomic_data = atomic_data_from_lammps_data (
177+ x_wrapped ,
178+ atomic_numbers ,
179+ nlocal ,
180+ cell ,
181+ periodicity ,
182+ lmp ._task_name ,
183+ charge = self .charge ,
184+ spin = self .spin ,
185+ )
186+ results = lmp ._predictor .predict (atomic_data )
187+ assert "forces" in results , "forces must be in results"
188+ f [:] = results ["forces" ].cpu ().numpy ()[:]
189+ lmp .fix_external_set_energy_global (FIX_EXT_ID , results ["energy" ].item ())
190+
191+ # during NPT for example, box_change should be set to 1 by lammps to allow the cell to change
192+ if box_change :
193+ # stress is defined as -virial/volume in lammps
194+ assert "stress" in results , "stress must be in results to compute virial"
195+ volume = torch .det (cell ).abs ().item ()
196+ v = (- results ["stress" ].detach ().cpu () * volume )[0 ].tolist ()
197+ # virials need to be in this order: xx, yy, zz, xy, xz, yz. https://docs.lammps.org/Library_utility.html#_CPPv437lammps_fix_external_set_virial_globalPvPKcPd
198+ virial_arr = [v [0 ], v [4 ], v [8 ], v [1 ], v [2 ], v [5 ]]
199+ lmp .fix_external_set_virial_global (FIX_EXT_ID , virial_arr )
175200
176201
177202def run_lammps_with_fairchem (
178- predictor : MLIPPredictUnitProtocol , lammps_input_path : str , task_name : str
203+ predictor : MLIPPredictUnitProtocol ,
204+ lammps_input_path : str ,
205+ task_name : str ,
206+ charge : int = 0 ,
207+ spin : int = 0 ,
179208):
180209 machine = None
181210 if "LAMMPS_MACHINE_NAME" in os .environ :
182211 machine = os .environ ["LAMMPS_MACHINE_NAME" ]
183212 lmp = lammps (name = machine , cmdargs = ["-nocite" , "-log" , "none" , "-echo" , "screen" ])
184213 lmp ._predictor = predictor
185214 lmp ._task_name = task_name
186- run_cmds = []
215+ # run_cmds = []
187216 with open (lammps_input_path ) as f :
188217 input_script = f .read ()
189218 check_input_script (input_script )
190219 script , run_cmds = separate_run_commands (input_script )
191220 logging .info (f"Running input script: { input_script } " )
192221 lmp .commands_list (script )
193222 lmp .command (FIX_EXTERNAL_CMD )
223+ fix_external_call_back = FixExternalCallback (charge = charge , spin = spin )
194224 lmp .set_fix_external_callback (FIX_EXT_ID , fix_external_call_back , lmp )
195225 lmp .commands_list (run_cmds )
196226 return lmp
@@ -203,7 +233,13 @@ def run_lammps_with_fairchem(
203233)
204234def main (cfg : DictConfig ):
205235 predict_unit = hydra .utils .instantiate (cfg .predict_unit )
206- lmp = run_lammps_with_fairchem (predict_unit , cfg .lmp_in , cfg .task_name )
236+ lmp = run_lammps_with_fairchem (
237+ predict_unit ,
238+ cfg .lmp_in ,
239+ cfg .task_name ,
240+ cfg .charge ,
241+ cfg .spin ,
242+ )
207243 # this is required to cleanup the predictor
208244 del lmp ._predictor
209245
0 commit comments