Skip to content

Commit 408bb6e

Browse files
authored
Merge pull request #181 from usnistgov/redw
Faster training with compute_line_graph False/lg_on_fly
2 parents b4b68c5 + 3679279 commit 408bb6e

5 files changed

Lines changed: 81 additions & 22 deletions

File tree

alignn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Version number."""
22

3-
__version__ = "2024.12.12"
3+
__version__ = "2025.4.1"

alignn/ff/calculators.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def __init__(
213213
self.config = config
214214
if self.force_mult_natoms:
215215
self.config["model"]["force_mult_natoms"] = True
216+
216217
if self.include_stress:
217218
self.implemented_properties = ["energy", "forces", "stress"]
218219
if (
@@ -226,6 +227,11 @@ def __init__(
226227

227228
else:
228229
self.implemented_properties = ["energy", "forces"]
230+
if (
231+
"calculate_gradient" in self.config["model"]
232+
and self.config["model"]["calculate_gradient"]
233+
):
234+
self.trained_stress = True
229235

230236
if (
231237
batch_stress is not None
@@ -296,24 +302,57 @@ def calculate(self, atoms, properties=None, system_changes=None):
296302
)
297303
else:
298304
result = self.model(
299-
(g.to(self.device, torch.tensor(atoms.cell).to(self.device)))
305+
(g.to(self.device), torch.tensor(atoms.cell).to(self.device))
300306
)
307+
# print("result",result)
301308
if "atomwise" in self.config["model"]["name"]:
302309
forces = forces = (
303310
result["grad"].detach().cpu().numpy() * self.force_multiplier
304311
)
305312
else:
306313
forces = np.zeros((3, 3))
307-
if "atomwise" in self.config["model"]["name"] and self.trained_stress:
308-
stress = (
309-
full_3x3_to_voigt_6_stress(
310-
result["stresses"][:3].reshape(3, 3).detach().cpu().numpy()
314+
# print("self.trained_stress",self.trained_stress)
315+
# if self.trained_stress:
316+
# # if "atomwise" in self.config["model"]["name"]
317+
# # and self.trained_stress:
318+
# stress = (
319+
# full_3x3_to_voigt_6_stress(
320+
# result["stresses"][:3].r
321+
# eshape(3, 3).detach().cpu().numpy()
322+
# )
323+
# * self.stress_wt
324+
# / 160.21766208
325+
# )
326+
# else:
327+
# stress = np.zeros((3, 3))
328+
if (
329+
"calculate_gradient" in self.config["model"]
330+
and self.config["model"]["calculate_gradient"]
331+
):
332+
try:
333+
stress = (
334+
full_3x3_to_voigt_6_stress(
335+
result["stresses"][:3]
336+
.reshape(3, 3)
337+
.detach()
338+
.cpu()
339+
.numpy()
340+
)
341+
* self.stress_wt
342+
/ 160.21766208
311343
)
312-
* self.stress_wt
313-
/ 160.21766208
314-
)
344+
except Exception:
345+
stress = np.zeros((3, 3))
346+
pass
315347
else:
316348
stress = np.zeros((3, 3))
349+
# stress = (
350+
# full_3x3_to_voigt_6_stress(
351+
# result["stresses"][:3].reshape(3, 3).detach().cpu().numpy()
352+
# )
353+
# * self.stress_wt
354+
# / 160.21766208
355+
# )
317356
if "atomwise" in self.config["model"]["name"]:
318357
energy = result["out"].detach().cpu().numpy()
319358
else:
@@ -325,6 +364,7 @@ def calculate(self, atoms, properties=None, system_changes=None):
325364
if self.force_mult_batchsize:
326365
forces *= self.config["batch_size"]
327366

367+
# print("stress cal",stress)
328368
self.results = {
329369
"energy": energy,
330370
"forces": forces,
@@ -449,6 +489,7 @@ def calculate(
449489
atom_features=self.ff_config["atom_features"],
450490
use_canonize=self.ff_config["use_canonize"],
451491
)
492+
# print("config",self.ff_config)
452493
result_ff = self.ff_model(
453494
(
454495
g.to(self.device),
@@ -459,6 +500,7 @@ def calculate(
459500
)
460501
)
461502
forces = forces = result_ff["grad"].detach().cpu().numpy()
503+
462504
stress = (
463505
full_3x3_to_voigt_6_stress(
464506
result_ff["stresses"][:3].reshape(3, 3).detach().cpu().numpy()

alignn/models/alignn_atomwise.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -371,20 +371,32 @@ def forward(
371371
z: angle features (lg.edata)
372372
"""
373373
if len(self.alignn_layers) > 0:
374-
g, lg, lat = g
375-
lg = lg.local_var()
376-
# print('lg',lg)
377-
# angle features (fixed)
378-
z = self.angle_embedding(lg.edata.pop("h"))
374+
if len(g) == 3:
375+
g, lg, lat = g
376+
lg = lg.local_var()
377+
# z = self.angle_embedding(lg.edata.pop("h"))
378+
z = self.angle_embedding(lg.edata["h"])
379+
else:
380+
g, lat = g
381+
g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat)
382+
g.ndata["cart_coords"].requires_grad_(True)
383+
r, bondlength = compute_pair_vector_and_distance(g)
384+
lg = g.line_graph(shared=True)
385+
lg.ndata["r"] = r
386+
lg.apply_edges(compute_bond_cosines)
387+
# print('lg',lg)
388+
# angle features (fixed)
389+
else:
390+
g, lat = g
379391
if self.config.extra_features != 0:
380392
features = g.ndata["extra_features"]
381393
# print('features',features,features.shape)
382394
features = self.extra_feature_embedding(features)
383-
g = g.local_var()
395+
# g = g.local_var()
384396
result = {}
385-
386397
# initial node features: atom feature network...
387-
x = g.ndata.pop("atom_features")
398+
x = g.ndata["atom_features"]
399+
# x = g.ndata.pop("atom_features")
388400
# print('x1',x,x.shape)
389401

390402
x = self.atom_embedding(x)
@@ -416,7 +428,8 @@ def forward(
416428

417429
lg.ndata["r"] = r # overwrites precomputed r values
418430
lg.apply_edges(compute_bond_cosines) # overwrites precomputed h
419-
z = self.angle_embedding(lg.edata.pop("h"))
431+
z = self.angle_embedding(lg.edata["h"])
432+
# z = self.angle_embedding(lg.edata.pop("h"))
420433

421434
# r = g.edata["r"].clone().detach().requires_grad_(True)
422435
if self.config.use_cutoff_function:
@@ -623,6 +636,7 @@ def forward(
623636
stress = self.config.stress_multiplier * torch.stack(
624637
stresses
625638
)
639+
# print("stress",stress)
626640
# print("stress2", stress, stress.shape)
627641
# virial = (
628642
# 160.21766208

alignn/tests/test_alignn_ff.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def test_jdft_mbj_gap():
154154
val = atoms.get_potential_energy() # gap
155155

156156

157+
# test_jdft_mbj_gap()
158+
# test_alexandria_gap()
159+
# test_alexandria_gap()
157160
# print('test_graph_builder')
158161
# test_graph_builder()
159162
# print('test_ev')

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@
1010

1111
setuptools.setup(
1212
name="alignn",
13-
version="2024.12.12",
13+
version="2025.4.1",
1414
author="Kamal Choudhary, Brian DeCost",
1515
author_email="kamal.choudhary@nist.gov",
1616
description="alignn",
1717
install_requires=[
18-
"numpy>=1.19.5",
18+
"numpy<2.0",
1919
# "numpy>=1.19.5,<2.0.0",
2020
"scipy>=1.6.1",
2121
"jarvis-tools>=2021.07.19",
22-
"torch>=2.0.0",
22+
"torch<=2.2.1",
2323
"mpmath<=1.3.0",
24-
"dgl>=0.6.0",
24+
"dgl<=1.1.1",
2525
"spglib>=2.0.2",
2626
"scikit-learn>=0.22.2",
2727
"matplotlib>=3.4.1",

0 commit comments

Comments
 (0)