1919import numpy as np
2020from tqdm import tqdm
2121import torch
22+ from alignn .config import TrainingConfig
2223
2324# Reference: https://doi.org/10.1039/D2DD00096B
2425
@@ -209,8 +210,8 @@ def __init__(
209210 if path is None and model is None :
210211 path = default_path ()
211212 if self .config is None :
212- config = loadjson (os .path .join (path , config_filename ))
213- self .config = config
213+ self . config = loadjson (os .path .join (path , config_filename ))
214+ self .config = TrainingConfig ( ** self . config ). model_dump ()
214215 if self .force_mult_natoms :
215216 self .config ["model" ]["force_mult_natoms" ] = True
216217
@@ -262,20 +263,23 @@ def __init__(
262263 torch .load (
263264 os .path .join (path , model_filename ),
264265 map_location = self .device ,
266+ weights_only = False ,
265267 )
266268 )
267269 else :
268270 model .load_state_dict (
269271 torch .load (
270272 os .path .join (path , model_filename ),
271273 map_location = self .device ,
274+ weights_only = False ,
272275 )["model" ]
273276 )
274- model .to (device )
275277 model .eval ()
278+ model .to (device )
276279 self .model = model
277280 else :
278281 model = self .model
282+ self .model = self .model .to (self .device )
279283
280284 def calculate (self , atoms , properties = None , system_changes = None ):
281285 """Calculate properties."""
@@ -294,11 +298,13 @@ def calculate(self, atoms, properties=None, system_changes=None):
294298 )
295299 if self .config ["compute_line_graph" ]:
296300 g , lg = g
301+ # print('self.model',self.model.device)
302+ # print("delf.device",self.device)
297303 result = self .model (
298304 (
299305 g .to (self .device ),
300306 lg .to (self .device ),
301- torch .tensor (atoms .cell )
307+ torch .tensor (np . array ( atoms .cell ) )
302308 .type (torch .get_default_dtype ())
303309 .to (self .device ),
304310 )
@@ -454,6 +460,7 @@ def __init__(
454460 torch .load (
455461 os .path .join (ff_path , ff_model_filename ),
456462 map_location = self .device ,
463+ weights_only = False ,
457464 )
458465 )
459466 ff_model .eval ()
@@ -475,10 +482,13 @@ def __init__(
475482 torch .load (
476483 os .path .join (prop_path , prop_model_filename ),
477484 map_location = self .device ,
485+ weights_only = False ,
478486 )
479487 )
480488 prop_model .eval ()
481489 self .prop_model = prop_model
490+ self .ff_model = self .ff_model .to (self .device )
491+ self .prop_model = self .prop_model .to (self .device )
482492
483493 def calculate (
484494 self ,
0 commit comments