@@ -449,6 +449,7 @@ def get_line_graph(
449449 lg .apply_edges (compute_bond_cosines )
450450 return lg
451451 else :
452+ # print(g)
452453 g .ndata ["cart_coords" ] = compute_cartesian_coordinates (g , lat )
453454 g .ndata ["cart_coords" ].requires_grad_ (True )
454455 r , bondlength = compute_pair_vector_and_distance (g )
@@ -642,15 +643,20 @@ def atom_dgl_multigraph(
642643 # construct atomistic line graph
643644 # (nodes are bonds, edges are bond pairs)
644645 # and add bond angle cosines as edge features
645- lg = get_line_graph (
646- g ,
647- lat = atoms .lattice_mat ,
648- inner_cutoff = inner_cutoff ,
649- lighten_edges = lighten_edges ,
650- backtracking = backtracking ,
651- )
652- # lg = g.line_graph(shared=True, backtracking=backtracking)
653- # lg.apply_edges(compute_bond_cosines)
646+ # print("lighten_edges",lighten_edges)
647+ if lighten_edges :
648+ lg = get_line_graph (
649+ g ,
650+ lat = torch .tensor (atoms .lattice_mat ).type (
651+ torch .get_default_dtype ()
652+ ),
653+ inner_cutoff = inner_cutoff ,
654+ lighten_edges = lighten_edges ,
655+ backtracking = backtracking ,
656+ )
657+ else :
658+ lg = g .line_graph (shared = True , backtracking = backtracking )
659+ lg .apply_edges (compute_bond_cosines )
654660 return g , lg
655661 else :
656662 return g
0 commit comments