Skip to content

Commit 2c4ecd6

Browse files
committed
Fix calc
1 parent 2773fba commit 2c4ecd6

10 files changed

Lines changed: 53 additions & 16 deletions

File tree

alignn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Version number."""
22

3-
__version__ = "2025.4.1"
3+
__version__ = "2025.4.3"

alignn/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class TrainingConfig(BaseSettings):
161161

162162
# training configuration
163163
dtype: str = "float32"
164+
device: str = "cpu"
164165
random_seed: Optional[int] = 123
165166
classification_threshold: Optional[float] = None
166167
# target_range: Optional[List] = None

alignn/ff/calculators.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
from tqdm import tqdm
2121
import torch
22+
from alignn.config import TrainingConfig
2223

2324
# Reference: https://doi.org/10.1039/D2DD00096B
2425

@@ -209,8 +210,8 @@ def __init__(
209210
if path is None and model is None:
210211
path = default_path()
211212
if self.config is None:
212-
config = loadjson(os.path.join(path, config_filename))
213-
self.config = config
213+
self.config = loadjson(os.path.join(path, config_filename))
214+
self.config = TrainingConfig(**self.config).model_dump()
214215
if self.force_mult_natoms:
215216
self.config["model"]["force_mult_natoms"] = True
216217

@@ -262,20 +263,23 @@ def __init__(
262263
torch.load(
263264
os.path.join(path, model_filename),
264265
map_location=self.device,
266+
weights_only=False,
265267
)
266268
)
267269
else:
268270
model.load_state_dict(
269271
torch.load(
270272
os.path.join(path, model_filename),
271273
map_location=self.device,
274+
weights_only=False,
272275
)["model"]
273276
)
274-
model.to(device)
275277
model.eval()
278+
model.to(device)
276279
self.model = model
277280
else:
278281
model = self.model
282+
self.model = self.model.to(self.device)
279283

280284
def calculate(self, atoms, properties=None, system_changes=None):
281285
"""Calculate properties."""
@@ -294,11 +298,13 @@ def calculate(self, atoms, properties=None, system_changes=None):
294298
)
295299
if self.config["compute_line_graph"]:
296300
g, lg = g
301+
# print('self.model',self.model.device)
302+
# print("delf.device",self.device)
297303
result = self.model(
298304
(
299305
g.to(self.device),
300306
lg.to(self.device),
301-
torch.tensor(atoms.cell)
307+
torch.tensor(np.array(atoms.cell))
302308
.type(torch.get_default_dtype())
303309
.to(self.device),
304310
)
@@ -454,6 +460,7 @@ def __init__(
454460
torch.load(
455461
os.path.join(ff_path, ff_model_filename),
456462
map_location=self.device,
463+
weights_only=False,
457464
)
458465
)
459466
ff_model.eval()
@@ -475,10 +482,13 @@ def __init__(
475482
torch.load(
476483
os.path.join(prop_path, prop_model_filename),
477484
map_location=self.device,
485+
weights_only=False,
478486
)
479487
)
480488
prop_model.eval()
481489
self.prop_model = prop_model
490+
self.ff_model = self.ff_model.to(self.device)
491+
self.prop_model = self.prop_model.to(self.device)
482492

483493
def calculate(
484494
self,

alignn/graphs.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,15 @@ def get_line_graph(
452452
# print(g)
453453
g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat)
454454
g.ndata["cart_coords"].requires_grad_(True)
455-
r, bondlength = compute_pair_vector_and_distance(g)
456-
dst_pos = g.ndata["cart_coords"][g.edges()[1]] + g.edata["images"]
457-
src_pos = g.ndata["cart_coords"][g.edges()[0]]
458-
bond_vec = dst_pos - src_pos
459-
bond_dist = torch.norm(bond_vec, dim=1)
460-
pos = g.ndata["cart_coords"]
461-
g.edata["bond_dist"] = bond_dist
462-
g.edata["r"] = bond_vec
455+
# r, bondlength = compute_pair_vector_and_distance(g)
456+
# dst_pos = g.ndata["cart_coords"][g.edges()[1]] + g.edata["images"]
457+
# src_pos = g.ndata["cart_coords"][g.edges()[0]]
458+
# bond_vec = dst_pos - src_pos
459+
# bond_dist = torch.norm(bond_vec, dim=1)
460+
# pos = g.ndata["cart_coords"]
461+
# g.edata["bond_dist"] = bond_dist
462+
# g.edata["r"] = bond_vec
463+
463464
src, dst = g.edges() # shape: [E], [E]
464465
pos = g.ndata["cart_coords"]
465466
src1 = src.unsqueeze(1) # [E,1]
@@ -485,7 +486,7 @@ def get_line_graph(
485486
eid1, eid2 = angle_mask.nonzero(as_tuple=True)
486487
# Create the line graph
487488
lg = dgl.graph((eid1, eid2), num_nodes=len(src))
488-
lg.ndata["r"] = bond_vec
489+
lg.ndata["r"] = g.edata["r"] # bond_vec
489490
lg.apply_edges(compute_bond_cosines)
490491
return lg
491492

alignn/models/alignn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class ALIGNNConfig(BaseSettings):
3838
classification: bool = False
3939
num_classes: int = 2
4040
extra_features: int = 0
41+
inner_cutoff: int = 3
42+
lighten_edges: int = False
4143

4244
class Config:
4345
"""Configure model settings behavior."""

alignn/models/alignn_atomwise.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,11 @@ def forward(
374374
"""
375375
if len(g) == 3:
376376
g, lg, lat = g
377-
lg = lg.local_var()
377+
# lg = lg.local_var()
378378
# print('lg',lg)
379379
# z = self.angle_embedding(lg.edata.pop("h"))
380380
z = self.angle_embedding(lg.edata["h"])
381+
# lg = lg.local_var()
381382
else:
382383
g, lat = g
383384
if len(self.alignn_layers) > 0:

alignn/models/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
# SchNet-style
3232
# set lengthscales relative to granularity of RBF expansion
3333
self.lengthscale = np.diff(self.centers).mean()
34+
3435
self.gamma = 1 / self.lengthscale
3536

3637
else:

alignn/tests/test_alignn_ff.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from jarvis.io.vasp.inputs import Poscar
2121
from alignn.ff.ff import get_figshare_model_prop, get_figshare_model_ff
2222
import os
23+
from elastic import get_elementary_deformations, get_elastic_tensor
24+
import elastic, ase
2325

2426
# JVASP-25139
2527
pos = """Rb8
@@ -154,6 +156,24 @@ def test_jdft_mbj_gap():
154156
val = atoms.get_potential_energy() # gap
155157

156158

159+
def test_elastic():
160+
model_path = get_figshare_model_ff(
161+
model_name="v12.2.2024_dft_3d_307k"
162+
) # default_path()
163+
calc = AlignnAtomwiseCalculator(path=model_path)
164+
atoms = Poscar.from_string(pos).atoms
165+
ase_atoms = atoms.ase_converter()
166+
ase_atoms.calc = calc
167+
systems = get_elementary_deformations(ase_atoms)
168+
cij_order = elastic.elastic.get_cij_order(ase_atoms)
169+
Cij, Bij = get_elastic_tensor(ase_atoms, systems)
170+
for i, j in zip(cij_order, Cij):
171+
# print(i, j / ase.units.GPa)
172+
assert j / ase.units.GPa > 0
173+
174+
175+
# test_elastic()
176+
# test_ev()
157177
# test_jdft_mbj_gap()
158178
# test_alexandria_gap()
159179
# test_alexandria_gap()

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ dependencies:
236236
- cryptography==42.0.5
237237
- dgl==2.1.0
238238
- dill==0.3.8
239+
- elastic==5.2.5.3
239240
- filelock==3.13.1
240241
- flake8==7.0.0
241242
- fsspec==2024.3.1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setuptools.setup(
1212
name="alignn",
13-
version="2025.4.1",
13+
version="2025.4.3",
1414
author="Kamal Choudhary, Brian DeCost",
1515
author_email="kamal.choudhary@nist.gov",
1616
description="alignn",

0 commit comments

Comments
 (0)