2626
2727
2828class ALIGNNAtomWiseConfig (BaseSettings ):
29- """Hyperparameter schema for jarvisdgl .models.alignn ."""
29+ """Hyperparameter schema for alignn .models.alignn_atomwise ."""
3030
3131 name : Literal ["alignn_atomwise" ]
3232 alignn_layers : int = 2
3333 gcn_layers : int = 2
3434 atom_input_features : int = 1
35- # atom_input_features: int = 92
3635 edge_input_features : int = 80
3736 triplet_input_features : int = 40
3837 embedding_features : int = 64
3938 hidden_features : int = 64
40- # hidden_features: int = 256
41- # fc_layers: int = 1
42- # fc_features: int = 64
4339 output_features : int = 1
4440 grad_multiplier : int = - 1
4541 calculate_gradient : bool = True
@@ -48,9 +44,6 @@ class ALIGNNAtomWiseConfig(BaseSettings):
4844 gradwise_weight : float = 1.0
4945 stresswise_weight : float = 0.0
5046 atomwise_weight : float = 0.0
51- # if link == log, apply `exp` to final outputs
52- # to constrain predictions to be positive
53- link : Literal ["identity" , "log" , "logit" ] = "identity"
5447 zero_inflated : bool = False
5548 classification : bool = False
5649 force_mult_natoms : bool = False
@@ -64,12 +57,21 @@ class ALIGNNAtomWiseConfig(BaseSettings):
6457 batch_stress : bool = True
6558 multiply_cutoff : bool = False
6659 use_penalty : bool = True
60+ lighten_edges : bool = True
61+ backtracking : bool = True
6762 extra_features : int = 0
6863 exponent : int = 5
6964 penalty_factor : float = 0.1
7065 penalty_threshold : float = 1
7166 additional_output_features : int = 0
7267 additional_output_weight : float = 0
68+ link : Literal ["identity" , "log" , "logit" ] = "identity"
69+ # if link == log, apply `exp` to final outputs
70+ # to constrain predictions to be positive
71+ # atom_input_features: int = 92
72+ # hidden_features: int = 256
73+ # fc_layers: int = 1
74+ # fc_features: int = 64
7375
7476 class Config :
7577 """Configure model settings behavior."""
@@ -246,6 +248,55 @@ def forward(
246248 return x , y , z
247249
248250
251+ def get_line_graph (
252+ g , lat = [], inner_cutoff = 3.0 , lighten_edges = False , backtracking = True
253+ ):
254+ if not lighten_edges :
255+ lg = g .line_graph (shared = True , backtracking = backtracking )
256+ # lg.ndata["r"] = r
257+ lg .apply_edges (compute_bond_cosines )
258+ return lg
259+ else :
260+ g .ndata ["cart_coords" ] = compute_cartesian_coordinates (g , lat )
261+ g .ndata ["cart_coords" ].requires_grad_ (True )
262+ r , bondlength = compute_pair_vector_and_distance (g )
263+ dst_pos = g .ndata ["cart_coords" ][g .edges ()[1 ]] + g .edata ["images" ]
264+ src_pos = g .ndata ["cart_coords" ][g .edges ()[0 ]]
265+ bond_vec = dst_pos - src_pos
266+ bond_dist = torch .norm (bond_vec , dim = 1 )
267+ pos = g .ndata ["cart_coords" ]
268+ g .edata ["bond_dist" ] = bond_dist
269+ g .edata ["r" ] = bond_vec
270+ src , dst = g .edges () # shape: [E], [E]
271+ pos = g .ndata ["cart_coords" ]
272+ src1 = src .unsqueeze (1 ) # [E,1]
273+ dst1 = dst .unsqueeze (1 ) # [E,1]
274+ src2 = src .unsqueeze (0 ) # [1,E]
275+ dst2 = dst .unsqueeze (0 ) # [1,E]
276+ # Broadcasted match on center node
277+ center_match = dst1 == src2 # [E, E] -> bool matrix
278+ # Get u, v, w for matching triples
279+ u = src1 .expand (- 1 , len (src )) # [E, E]
280+ v = dst1 .expand (- 1 , len (src )) # [E, E]
281+ v2 = src2 .expand (len (src ), - 1 ) # [E, E]
282+ w = dst2 .expand (len (src ), - 1 ) # [E, E]
283+ # Mask out u == w (no backtracking)
284+ non_backtrack = u != w
285+ # Compute distance from u to w for all pairs (eid1, eid2)
286+ pos_u = pos [u ]
287+ pos_w = pos [w ]
288+ uw_dist = torch .norm (pos_u - pos_w , dim = - 1 ) # [E, E]
289+ # Apply angular cutoff
290+ angle_mask = center_match & non_backtrack & (uw_dist < inner_cutoff )
291+ # Get edge pairs (eid1, eid2) for the line graph
292+ eid1 , eid2 = angle_mask .nonzero (as_tuple = True )
293+ # Create the line graph
294+ lg = dgl .graph ((eid1 , eid2 ), num_nodes = len (src ))
295+ lg .ndata ["r" ] = bond_vec
296+ lg .apply_edges (compute_bond_cosines )
297+ return lg
298+
299+
249300class ALIGNNAtomWise (nn .Module ):
250301 """Atomistic Line graph network.
251302
@@ -370,24 +421,24 @@ def forward(
370421 y: bond features (g.edata and lg.ndata)
371422 z: angle features (lg.edata)
372423 """
373- if len (self .alignn_layers ) > 0 :
374- if len (g ) == 3 :
375- g , lg , lat = g
376- lg = lg .local_var ()
377- # z = self.angle_embedding(lg.edata.pop("h"))
378- z = self .angle_embedding (lg .edata ["h" ])
379- else :
380- g , lat = g
381- g .ndata ["cart_coords" ] = compute_cartesian_coordinates (g , lat )
382- g .ndata ["cart_coords" ].requires_grad_ (True )
383- r , bondlength = compute_pair_vector_and_distance (g )
384- lg = g .line_graph (shared = True )
385- lg .ndata ["r" ] = r
386- lg .apply_edges (compute_bond_cosines )
387- # print('lg',lg)
388- # angle features (fixed)
424+ if len (g ) == 3 :
425+ g , lg , lat = g
426+ lg = lg .local_var ()
427+ # z = self.angle_embedding(lg.edata.pop("h"))
428+ z = self .angle_embedding (lg .edata ["h" ])
389429 else :
390430 g , lat = g
431+ if len (self .alignn_layers ) > 0 :
432+ lg = get_line_graph (
433+ g ,
434+ lat = lat ,
435+ inner_cutoff = self .config .inner_cutoff ,
436+ lighten_edges = self .config .lighten_edges ,
437+ backtracking = self .config .backtracking ,
438+ )
439+ z = self .angle_embedding (lg .edata ["h" ])
440+ print ("lg" , lg )
441+ # angle features (fixed)
391442 if self .config .extra_features != 0 :
392443 features = g .ndata ["extra_features" ]
393444 # print('features',features,features.shape)
@@ -409,6 +460,7 @@ def forward(
409460 r , bondlength = compute_pair_vector_and_distance (g )
410461 lg = g .line_graph (shared = True )
411462 lg .ndata ["r" ] = r
463+ print ("lg" , lg )
412464 lg .apply_edges (compute_bond_cosines )
413465
414466 # bondlength = torch.norm(r, dim=1)
0 commit comments