@@ -256,6 +256,70 @@ def forward(
         return x0, x1


+class UniCrossBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        flash: bool = False,
+        bias: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.heads = num_heads
+        dim_head = embed_dim // num_heads
+        self.scale = dim_head**-0.5
+        inner_dim = dim_head * num_heads
+        self.to_k = nn.Linear(embed_dim, inner_dim, bias=bias)
+        self.to_q = nn.Linear(embed_dim, inner_dim, bias=bias)
+        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
+        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
+        self.ffn = nn.Sequential(
+            nn.Linear(2 * embed_dim, 2 * embed_dim),
+            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
+            nn.GELU(),
+            nn.Linear(2 * embed_dim, embed_dim),
+        )
+        if flash and FLASH_AVAILABLE:
+            self.flash = Attention(True)
+        else:
+            self.flash = None
+
+        if dropout > 1.0e-4:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = nn.Identity()
+
+    def forward(
+        self,
+        x0: torch.Tensor,
+        x1: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        encoding0: Optional[torch.Tensor] = None,
+        encoding1: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        q = self.to_q(x0)
+        k = self.to_k(x1)
+        v = self.to_v(x1)
+        q, k, v = map(
+            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
+            (q, k, v),
+        )
+        if encoding0 is not None and encoding1 is not None:
+            q = apply_cached_rotary_emb(encoding0, q)
+            k = apply_cached_rotary_emb(encoding1, k)
+        if self.flash is not None and q.device.type == "cuda":
+            m0 = self.flash(q, k, v, mask)
+        else:
+            raise NotImplementedError(
+                "Non-flash attention not implemented for UniCrossBlock"
+            )
+        m0 = m0.transpose(1, 2).flatten(start_dim=-2)
+        m0 = self.to_out(m0)
+        x0 = x0 + self.dropout(self.ffn(torch.cat([x0, m0], -1)))
+        return x0
+
+
 class TransformerLayer(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
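For orientation, here is a minimal usage sketch of the new `UniCrossBlock` (not part of this commit). It assumes the module-level `FLASH_AVAILABLE` flag is true and a CUDA device is available, since the non-flash path deliberately raises `NotImplementedError`; all shapes and hyperparameters are illustrative.

```python
import torch

# Hypothetical values; only the signature comes from this commit.
block = UniCrossBlock(embed_dim=256, num_heads=4, flash=True, dropout=0.1).cuda()
x0 = torch.randn(2, 512, 256, device="cuda")  # (batch, n_points0, embed_dim)
x1 = torch.randn(2, 384, 256, device="cuda")  # (batch, n_points1, embed_dim)

# Passing encoding0/encoding1 as None skips the rotary embeddings; the block
# only updates x0 (uni-directional cross attention), so the output matches x0.
out = block(x0, x1, mask=None, encoding0=None, encoding1=None)
assert out.shape == x0.shape
```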
@@ -272,7 +336,22 @@ def forward(
         cross_encoding1: torch.Tensor | None = None,
         mask0: Optional[torch.Tensor] = None,
         mask1: Optional[torch.Tensor] = None,
+        checkpointed: bool = False,
+        reentrant: bool = True,
     ):
+        if checkpointed and self.training:
+            return torch.utils.checkpoint.checkpoint(
+                self.forward,
+                desc0,
+                desc1,
+                encoding0,
+                encoding1,
+                cross_encoding0,
+                cross_encoding1,
+                mask0,
+                mask1,
+                use_reentrant=reentrant,
+            )
         if mask0 is not None and mask1 is not None:
             return self.masked_forward(
                 desc0,
                 desc1,
@@ -337,18 +416,26 @@ def sigmoid_log_double_softmax(


 class MatchAssignment(nn.Module):
-    def __init__(self, dim: int) -> None:
+    def __init__(
+        self, dim: int, out_dim: int | None = None, temperature: float | None = None
+    ) -> None:
         super().__init__()
-        self.dim = dim
         self.matchability = nn.Linear(dim, 1, bias=True)
-        self.final_proj = nn.Linear(dim, dim, bias=True)
+        self.final_proj = nn.Linear(dim, out_dim or dim, bias=True)
+
+        if temperature is not None:
+            self.temperature = nn.Parameter(torch.Tensor([temperature]))
+        else:
+            self.temperature = None

     def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
         """build assignment matrix from descriptors"""
         mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
         _, _, d = mdesc0.shape
         mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
         sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
+        if self.temperature is not None:
+            sim = sim * self.temperature
         z0 = self.matchability(desc0)
         z1 = self.matchability(desc1)
         scores = sigmoid_log_double_softmax(sim, z0, z1)
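For completeness, a sketch exercising the extended `MatchAssignment`; the `out_dim` and `temperature` values are illustrative. `out_dim` lets `final_proj` change the descriptor width before the similarity einsum (so `d` in the `d**0.25` normalization becomes `out_dim`), while `temperature` is a single learnable scalar that rescales `sim` before `sigmoid_log_double_softmax`.

```python
import torch

# Illustrative values; only the signature comes from this commit.
assign = MatchAssignment(dim=256, out_dim=128, temperature=0.1)
desc0 = torch.randn(2, 512, 256)
desc1 = torch.randn(2, 384, 256)

# Produces the log assignment scores; the return statement itself falls
# outside the visible hunk.
out = assign(desc0, desc1)
```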