Skip to content

Commit 64bc6ff

Browse files
committed
Update markdown, lighten_edges for line graph.
1 parent 3679279 commit 64bc6ff

2 files changed

Lines changed: 79 additions & 26 deletions

File tree

alignn/graphs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,7 @@ def atom_dgl_multigraph(
483483
use_lattice_prop: bool = False,
484484
cutoff_extra=3.5,
485485
dtype="float32",
486+
backtracking=True,
486487
):
487488
"""Obtain a DGLGraph for Atoms object."""
488489
# print('id',id)
@@ -585,7 +586,7 @@ def atom_dgl_multigraph(
585586
# construct atomistic line graph
586587
# (nodes are bonds, edges are bond pairs)
587588
# and add bond angle cosines as edge features
588-
lg = g.line_graph(shared=True)
589+
lg = g.line_graph(shared=True, backtracking=backtracking)
589590
lg.apply_edges(compute_bond_cosines)
590591
return g, lg
591592
else:
@@ -1062,7 +1063,7 @@ def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
10621063

10631064
@staticmethod
10641065
def collate_line_graph(
1065-
samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]]
1066+
samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]],
10661067
):
10671068
"""Dataloader helper to batch graphs cross `samples`."""
10681069
graphs, line_graphs, lattices, labels = map(list, zip(*samples))

alignn/models/alignn_atomwise.py

Lines changed: 76 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,16 @@
2626

2727

2828
class ALIGNNAtomWiseConfig(BaseSettings):
29-
"""Hyperparameter schema for jarvisdgl.models.alignn."""
29+
"""Hyperparameter schema for alignn.models.alignn_atomwise."""
3030

3131
name: Literal["alignn_atomwise"]
3232
alignn_layers: int = 2
3333
gcn_layers: int = 2
3434
atom_input_features: int = 1
35-
# atom_input_features: int = 92
3635
edge_input_features: int = 80
3736
triplet_input_features: int = 40
3837
embedding_features: int = 64
3938
hidden_features: int = 64
40-
# hidden_features: int = 256
41-
# fc_layers: int = 1
42-
# fc_features: int = 64
4339
output_features: int = 1
4440
grad_multiplier: int = -1
4541
calculate_gradient: bool = True
@@ -48,9 +44,6 @@ class ALIGNNAtomWiseConfig(BaseSettings):
4844
gradwise_weight: float = 1.0
4945
stresswise_weight: float = 0.0
5046
atomwise_weight: float = 0.0
51-
# if link == log, apply `exp` to final outputs
52-
# to constrain predictions to be positive
53-
link: Literal["identity", "log", "logit"] = "identity"
5447
zero_inflated: bool = False
5548
classification: bool = False
5649
force_mult_natoms: bool = False
@@ -64,12 +57,21 @@ class ALIGNNAtomWiseConfig(BaseSettings):
6457
batch_stress: bool = True
6558
multiply_cutoff: bool = False
6659
use_penalty: bool = True
60+
lighten_edges: bool = True
61+
backtracking: bool = True
6762
extra_features: int = 0
6863
exponent: int = 5
6964
penalty_factor: float = 0.1
7065
penalty_threshold: float = 1
7166
additional_output_features: int = 0
7267
additional_output_weight: float = 0
68+
link: Literal["identity", "log", "logit"] = "identity"
69+
# if link == log, apply `exp` to final outputs
70+
# to constrain predictions to be positive
71+
# atom_input_features: int = 92
72+
# hidden_features: int = 256
73+
# fc_layers: int = 1
74+
# fc_features: int = 64
7375

7476
class Config:
7577
"""Configure model settings behavior."""
@@ -246,6 +248,55 @@ def forward(
246248
return x, y, z
247249

248250

251+
def get_line_graph(
252+
g, lat=[], inner_cutoff=3.0, lighten_edges=False, backtracking=True
253+
):
254+
if not lighten_edges:
255+
lg = g.line_graph(shared=True, backtracking=backtracking)
256+
# lg.ndata["r"] = r
257+
lg.apply_edges(compute_bond_cosines)
258+
return lg
259+
else:
260+
g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat)
261+
g.ndata["cart_coords"].requires_grad_(True)
262+
r, bondlength = compute_pair_vector_and_distance(g)
263+
dst_pos = g.ndata["cart_coords"][g.edges()[1]] + g.edata["images"]
264+
src_pos = g.ndata["cart_coords"][g.edges()[0]]
265+
bond_vec = dst_pos - src_pos
266+
bond_dist = torch.norm(bond_vec, dim=1)
267+
pos = g.ndata["cart_coords"]
268+
g.edata["bond_dist"] = bond_dist
269+
g.edata["r"] = bond_vec
270+
src, dst = g.edges() # shape: [E], [E]
271+
pos = g.ndata["cart_coords"]
272+
src1 = src.unsqueeze(1) # [E,1]
273+
dst1 = dst.unsqueeze(1) # [E,1]
274+
src2 = src.unsqueeze(0) # [1,E]
275+
dst2 = dst.unsqueeze(0) # [1,E]
276+
# Broadcasted match on center node
277+
center_match = dst1 == src2 # [E, E] -> bool matrix
278+
# Get u, v, w for matching triples
279+
u = src1.expand(-1, len(src)) # [E, E]
280+
v = dst1.expand(-1, len(src)) # [E, E]
281+
v2 = src2.expand(len(src), -1) # [E, E]
282+
w = dst2.expand(len(src), -1) # [E, E]
283+
# Mask out u == w (no backtracking)
284+
non_backtrack = u != w
285+
# Compute distance from u to w for all pairs (eid1, eid2)
286+
pos_u = pos[u]
287+
pos_w = pos[w]
288+
uw_dist = torch.norm(pos_u - pos_w, dim=-1) # [E, E]
289+
# Apply angular cutoff
290+
angle_mask = center_match & non_backtrack & (uw_dist < inner_cutoff)
291+
# Get edge pairs (eid1, eid2) for the line graph
292+
eid1, eid2 = angle_mask.nonzero(as_tuple=True)
293+
# Create the line graph
294+
lg = dgl.graph((eid1, eid2), num_nodes=len(src))
295+
lg.ndata["r"] = bond_vec
296+
lg.apply_edges(compute_bond_cosines)
297+
return lg
298+
299+
249300
class ALIGNNAtomWise(nn.Module):
250301
"""Atomistic Line graph network.
251302
@@ -370,24 +421,24 @@ def forward(
370421
y: bond features (g.edata and lg.ndata)
371422
z: angle features (lg.edata)
372423
"""
373-
if len(self.alignn_layers) > 0:
374-
if len(g) == 3:
375-
g, lg, lat = g
376-
lg = lg.local_var()
377-
# z = self.angle_embedding(lg.edata.pop("h"))
378-
z = self.angle_embedding(lg.edata["h"])
379-
else:
380-
g, lat = g
381-
g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat)
382-
g.ndata["cart_coords"].requires_grad_(True)
383-
r, bondlength = compute_pair_vector_and_distance(g)
384-
lg = g.line_graph(shared=True)
385-
lg.ndata["r"] = r
386-
lg.apply_edges(compute_bond_cosines)
387-
# print('lg',lg)
388-
# angle features (fixed)
424+
if len(g) == 3:
425+
g, lg, lat = g
426+
lg = lg.local_var()
427+
# z = self.angle_embedding(lg.edata.pop("h"))
428+
z = self.angle_embedding(lg.edata["h"])
389429
else:
390430
g, lat = g
431+
if len(self.alignn_layers) > 0:
432+
lg = get_line_graph(
433+
g,
434+
lat=lat,
435+
inner_cutoff=self.config.inner_cutoff,
436+
lighten_edges=self.config.lighten_edges,
437+
backtracking=self.config.backtracking,
438+
)
439+
z = self.angle_embedding(lg.edata["h"])
440+
print("lg", lg)
441+
# angle features (fixed)
391442
if self.config.extra_features != 0:
392443
features = g.ndata["extra_features"]
393444
# print('features',features,features.shape)
@@ -409,6 +460,7 @@ def forward(
409460
r, bondlength = compute_pair_vector_and_distance(g)
410461
lg = g.line_graph(shared=True)
411462
lg.ndata["r"] = r
463+
print("lg", lg)
412464
lg.apply_edges(compute_bond_cosines)
413465

414466
# bondlength = torch.norm(r, dim=1)

0 commit comments

Comments
 (0)