Skip to content

Commit 77670cd

Browse files
committed
Conceptions on lighten_edges
1 parent 8dfae25 commit 77670cd

3 files changed

Lines changed: 72 additions & 54 deletions

File tree

alignn/ff/calculators.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,14 +281,19 @@ def calculate(self, atoms, properties=None, system_changes=None):
281281
"""Calculate properties."""
282282
j_atoms = ase_to_atoms(atoms)
283283
num_atoms = j_atoms.num_atoms
284-
g, lg = Graph.atom_dgl_multigraph(
284+
g = Graph.atom_dgl_multigraph(
285285
j_atoms,
286286
neighbor_strategy=self.config["neighbor_strategy"],
287287
cutoff=self.config["cutoff"],
288288
max_neighbors=self.config["max_neighbors"],
289289
atom_features=self.config["atom_features"],
290290
use_canonize=self.config["use_canonize"],
291+
inner_cutoff=self.config["model"]["inner_cutoff"],
292+
lighten_edges=self.config["model"]["lighten_edges"],
293+
compute_line_graph=self.config["compute_line_graph"],
291294
)
295+
if self.config["compute_line_graph"]:
296+
g, lg = g
292297

293298
if self.config["model"]["alignn_layers"] > 0:
294299
result = self.model(

alignn/graphs.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
import dgl
1717
from tqdm import tqdm
1818
from jarvis.core.atoms import Atoms
19+
from alignn.models.utils import (
20+
compute_cartesian_coordinates,
21+
compute_pair_vector_and_distance,
22+
)
1923

2024
# import matgl
2125

@@ -435,6 +439,55 @@ def radius_graph_old(
435439
###
436440

437441

442+
def get_line_graph(
443+
g, lat=[], inner_cutoff=3.0, lighten_edges=False, backtracking=True
444+
):
445+
if not lighten_edges:
446+
lg = g.line_graph(shared=True, backtracking=backtracking)
447+
# lg.ndata["r"] = r
448+
lg.apply_edges(compute_bond_cosines)
449+
return lg
450+
else:
451+
g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat)
452+
g.ndata["cart_coords"].requires_grad_(True)
453+
r, bondlength = compute_pair_vector_and_distance(g)
454+
dst_pos = g.ndata["cart_coords"][g.edges()[1]] + g.edata["images"]
455+
src_pos = g.ndata["cart_coords"][g.edges()[0]]
456+
bond_vec = dst_pos - src_pos
457+
bond_dist = torch.norm(bond_vec, dim=1)
458+
pos = g.ndata["cart_coords"]
459+
g.edata["bond_dist"] = bond_dist
460+
g.edata["r"] = bond_vec
461+
src, dst = g.edges() # shape: [E], [E]
462+
pos = g.ndata["cart_coords"]
463+
src1 = src.unsqueeze(1) # [E,1]
464+
dst1 = dst.unsqueeze(1) # [E,1]
465+
src2 = src.unsqueeze(0) # [1,E]
466+
dst2 = dst.unsqueeze(0) # [1,E]
467+
# Broadcasted match on center node
468+
center_match = dst1 == src2 # [E, E] -> bool matrix
469+
# Get u, v, w for matching triples
470+
u = src1.expand(-1, len(src)) # [E, E]
471+
# v = dst1.expand(-1, len(src)) # [E, E]
472+
# v2 = src2.expand(len(src), -1) # [E, E]
473+
w = dst2.expand(len(src), -1) # [E, E]
474+
# Mask out u == w (no backtracking)
475+
non_backtrack = u != w
476+
# Compute distance from u to w for all pairs (eid1, eid2)
477+
pos_u = pos[u]
478+
pos_w = pos[w]
479+
uw_dist = torch.norm(pos_u - pos_w, dim=-1) # [E, E]
480+
# Apply angular cutoff
481+
angle_mask = center_match & non_backtrack & (uw_dist < inner_cutoff)
482+
# Get edge pairs (eid1, eid2) for the line graph
483+
eid1, eid2 = angle_mask.nonzero(as_tuple=True)
484+
# Create the line graph
485+
lg = dgl.graph((eid1, eid2), num_nodes=len(src))
486+
lg.ndata["r"] = bond_vec
487+
lg.apply_edges(compute_bond_cosines)
488+
return lg
489+
490+
438491
class Graph(object):
439492
"""Generate a graph object."""
440493

@@ -484,6 +537,8 @@ def atom_dgl_multigraph(
484537
cutoff_extra=3.5,
485538
dtype="float32",
486539
backtracking=True,
540+
inner_cutoff=3.0,
541+
lighten_edges=False,
487542
):
488543
"""Obtain a DGLGraph for Atoms object."""
489544
# print('id',id)
@@ -586,8 +641,15 @@ def atom_dgl_multigraph(
586641
# construct atomistic line graph
587642
# (nodes are bonds, edges are bond pairs)
588643
# and add bond angle cosines as edge features
589-
lg = g.line_graph(shared=True, backtracking=backtracking)
590-
lg.apply_edges(compute_bond_cosines)
644+
lg = get_line_graph(
645+
g,
646+
lat=atoms.lattice_mat,
647+
inner_cutoff=inner_cutoff,
648+
lighten_edges=lighten_edges,
649+
backtracking=backtracking,
650+
)
651+
# lg = g.line_graph(shared=True, backtracking=backtracking)
652+
# lg.apply_edges(compute_bond_cosines)
591653
return g, lg
592654
else:
593655
return g

alignn/models/alignn_atomwise.py

Lines changed: 2 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
compute_pair_vector_and_distance,
2222
MLPLayer,
2323
)
24-
from alignn.graphs import compute_bond_cosines
24+
from alignn.graphs import compute_bond_cosines, get_line_graph
2525
from alignn.utils import BaseSettings
2626

2727

@@ -248,55 +248,6 @@ def forward(
248248
return x, y, z
249249

250250

251-
def get_line_graph(
252-
g, lat=[], inner_cutoff=3.0, lighten_edges=False, backtracking=True
253-
):
254-
if not lighten_edges:
255-
lg = g.line_graph(shared=True, backtracking=backtracking)
256-
# lg.ndata["r"] = r
257-
lg.apply_edges(compute_bond_cosines)
258-
return lg
259-
else:
260-
g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat)
261-
g.ndata["cart_coords"].requires_grad_(True)
262-
r, bondlength = compute_pair_vector_and_distance(g)
263-
dst_pos = g.ndata["cart_coords"][g.edges()[1]] + g.edata["images"]
264-
src_pos = g.ndata["cart_coords"][g.edges()[0]]
265-
bond_vec = dst_pos - src_pos
266-
bond_dist = torch.norm(bond_vec, dim=1)
267-
pos = g.ndata["cart_coords"]
268-
g.edata["bond_dist"] = bond_dist
269-
g.edata["r"] = bond_vec
270-
src, dst = g.edges() # shape: [E], [E]
271-
pos = g.ndata["cart_coords"]
272-
src1 = src.unsqueeze(1) # [E,1]
273-
dst1 = dst.unsqueeze(1) # [E,1]
274-
src2 = src.unsqueeze(0) # [1,E]
275-
dst2 = dst.unsqueeze(0) # [1,E]
276-
# Broadcasted match on center node
277-
center_match = dst1 == src2 # [E, E] -> bool matrix
278-
# Get u, v, w for matching triples
279-
u = src1.expand(-1, len(src)) # [E, E]
280-
v = dst1.expand(-1, len(src)) # [E, E]
281-
v2 = src2.expand(len(src), -1) # [E, E]
282-
w = dst2.expand(len(src), -1) # [E, E]
283-
# Mask out u == w (no backtracking)
284-
non_backtrack = u != w
285-
# Compute distance from u to w for all pairs (eid1, eid2)
286-
pos_u = pos[u]
287-
pos_w = pos[w]
288-
uw_dist = torch.norm(pos_u - pos_w, dim=-1) # [E, E]
289-
# Apply angular cutoff
290-
angle_mask = center_match & non_backtrack & (uw_dist < inner_cutoff)
291-
# Get edge pairs (eid1, eid2) for the line graph
292-
eid1, eid2 = angle_mask.nonzero(as_tuple=True)
293-
# Create the line graph
294-
lg = dgl.graph((eid1, eid2), num_nodes=len(src))
295-
lg.ndata["r"] = bond_vec
296-
lg.apply_edges(compute_bond_cosines)
297-
return lg
298-
299-
300251
class ALIGNNAtomWise(nn.Module):
301252
"""Atomistic Line graph network.
302253
@@ -437,7 +388,7 @@ def forward(
437388
backtracking=self.config.backtracking,
438389
)
439390
z = self.angle_embedding(lg.edata["h"])
440-
print("lg", lg)
391+
# print("lg", lg)
441392
# angle features (fixed)
442393
if self.config.extra_features != 0:
443394
features = g.ndata["extra_features"]

0 commit comments

Comments
 (0)