|
16 | 16 | import dgl |
17 | 17 | from tqdm import tqdm |
18 | 18 | from jarvis.core.atoms import Atoms |
| 19 | +from alignn.models.utils import ( |
| 20 | + compute_cartesian_coordinates, |
| 21 | + compute_pair_vector_and_distance, |
| 22 | +) |
19 | 23 |
|
20 | 24 | # import matgl |
21 | 25 |
|
@@ -435,6 +439,55 @@ def radius_graph_old( |
435 | 439 | ### |
436 | 440 |
|
437 | 441 |
|
| 442 | +def get_line_graph( |
| 443 | + g, lat=[], inner_cutoff=3.0, lighten_edges=False, backtracking=True |
| 444 | +): |
| 445 | + if not lighten_edges: |
| 446 | + lg = g.line_graph(shared=True, backtracking=backtracking) |
| 447 | + # lg.ndata["r"] = r |
| 448 | + lg.apply_edges(compute_bond_cosines) |
| 449 | + return lg |
| 450 | + else: |
| 451 | + g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat) |
| 452 | + g.ndata["cart_coords"].requires_grad_(True) |
| 453 | + r, bondlength = compute_pair_vector_and_distance(g) |
| 454 | + dst_pos = g.ndata["cart_coords"][g.edges()[1]] + g.edata["images"] |
| 455 | + src_pos = g.ndata["cart_coords"][g.edges()[0]] |
| 456 | + bond_vec = dst_pos - src_pos |
| 457 | + bond_dist = torch.norm(bond_vec, dim=1) |
| 458 | + pos = g.ndata["cart_coords"] |
| 459 | + g.edata["bond_dist"] = bond_dist |
| 460 | + g.edata["r"] = bond_vec |
| 461 | + src, dst = g.edges() # shape: [E], [E] |
| 462 | + pos = g.ndata["cart_coords"] |
| 463 | + src1 = src.unsqueeze(1) # [E,1] |
| 464 | + dst1 = dst.unsqueeze(1) # [E,1] |
| 465 | + src2 = src.unsqueeze(0) # [1,E] |
| 466 | + dst2 = dst.unsqueeze(0) # [1,E] |
| 467 | + # Broadcasted match on center node |
| 468 | + center_match = dst1 == src2 # [E, E] -> bool matrix |
| 469 | + # Get u, v, w for matching triples |
| 470 | + u = src1.expand(-1, len(src)) # [E, E] |
| 471 | + # v = dst1.expand(-1, len(src)) # [E, E] |
| 472 | + # v2 = src2.expand(len(src), -1) # [E, E] |
| 473 | + w = dst2.expand(len(src), -1) # [E, E] |
| 474 | + # Mask out u == w (no backtracking) |
| 475 | + non_backtrack = u != w |
| 476 | + # Compute distance from u to w for all pairs (eid1, eid2) |
| 477 | + pos_u = pos[u] |
| 478 | + pos_w = pos[w] |
| 479 | + uw_dist = torch.norm(pos_u - pos_w, dim=-1) # [E, E] |
| 480 | + # Apply angular cutoff |
| 481 | + angle_mask = center_match & non_backtrack & (uw_dist < inner_cutoff) |
| 482 | + # Get edge pairs (eid1, eid2) for the line graph |
| 483 | + eid1, eid2 = angle_mask.nonzero(as_tuple=True) |
| 484 | + # Create the line graph |
| 485 | + lg = dgl.graph((eid1, eid2), num_nodes=len(src)) |
| 486 | + lg.ndata["r"] = bond_vec |
| 487 | + lg.apply_edges(compute_bond_cosines) |
| 488 | + return lg |
| 489 | + |
| 490 | + |
438 | 491 | class Graph(object): |
439 | 492 | """Generate a graph object.""" |
440 | 493 |
|
@@ -484,6 +537,8 @@ def atom_dgl_multigraph( |
484 | 537 | cutoff_extra=3.5, |
485 | 538 | dtype="float32", |
486 | 539 | backtracking=True, |
| 540 | + inner_cutoff=3.0, |
| 541 | + lighten_edges=False, |
487 | 542 | ): |
488 | 543 | """Obtain a DGLGraph for Atoms object.""" |
489 | 544 | # print('id',id) |
@@ -586,8 +641,15 @@ def atom_dgl_multigraph( |
586 | 641 | # construct atomistic line graph |
587 | 642 | # (nodes are bonds, edges are bond pairs) |
588 | 643 | # and add bond angle cosines as edge features |
589 | | - lg = g.line_graph(shared=True, backtracking=backtracking) |
590 | | - lg.apply_edges(compute_bond_cosines) |
| 644 | + lg = get_line_graph( |
| 645 | + g, |
| 646 | + lat=atoms.lattice_mat, |
| 647 | + inner_cutoff=inner_cutoff, |
| 648 | + lighten_edges=lighten_edges, |
| 649 | + backtracking=backtracking, |
| 650 | + ) |
| 651 | + # lg = g.line_graph(shared=True, backtracking=backtracking) |
| 652 | + # lg.apply_edges(compute_bond_cosines) |
591 | 653 | return g, lg |
592 | 654 | else: |
593 | 655 | return g |
|
0 commit comments