Skip to content

Commit 2773fba

Browse files
committed
Lg
1 parent f2048aa commit 2773fba

2 files changed

Lines changed: 16 additions & 9 deletions

File tree

alignn/graphs.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,7 @@ def get_line_graph(
449449
lg.apply_edges(compute_bond_cosines)
450450
return lg
451451
else:
452+
# print(g)
452453
g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat)
453454
g.ndata["cart_coords"].requires_grad_(True)
454455
r, bondlength = compute_pair_vector_and_distance(g)
@@ -642,15 +643,20 @@ def atom_dgl_multigraph(
642643
# construct atomistic line graph
643644
# (nodes are bonds, edges are bond pairs)
644645
# and add bond angle cosines as edge features
645-
lg = get_line_graph(
646-
g,
647-
lat=atoms.lattice_mat,
648-
inner_cutoff=inner_cutoff,
649-
lighten_edges=lighten_edges,
650-
backtracking=backtracking,
651-
)
652-
# lg = g.line_graph(shared=True, backtracking=backtracking)
653-
# lg.apply_edges(compute_bond_cosines)
646+
# print("lighten_edges",lighten_edges)
647+
if lighten_edges:
648+
lg = get_line_graph(
649+
g,
650+
lat=torch.tensor(atoms.lattice_mat).type(
651+
torch.get_default_dtype()
652+
),
653+
inner_cutoff=inner_cutoff,
654+
lighten_edges=lighten_edges,
655+
backtracking=backtracking,
656+
)
657+
else:
658+
lg = g.line_graph(shared=True, backtracking=backtracking)
659+
lg.apply_edges(compute_bond_cosines)
654660
return g, lg
655661
else:
656662
return g

alignn/models/alignn_atomwise.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def forward(
375375
if len(g) == 3:
376376
g, lg, lat = g
377377
lg = lg.local_var()
378+
# print('lg',lg)
378379
# z = self.angle_embedding(lg.edata.pop("h"))
379380
z = self.angle_embedding(lg.edata["h"])
380381
else:

0 commit comments

Comments
 (0)