
Commit a158591

Add some more utils to LG

1 parent cf577a7

1 file changed

gluefactory/models/matchers/lightglue.py

Lines changed: 90 additions & 3 deletions
@@ -256,6 +256,70 @@ def forward(
         return x0, x1
 
 
+class UniCrossBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        flash: bool = False,
+        bias: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.heads = num_heads
+        dim_head = embed_dim // num_heads
+        self.scale = dim_head**-0.5
+        inner_dim = dim_head * num_heads
+        self.to_k = nn.Linear(embed_dim, inner_dim, bias=bias)
+        self.to_q = nn.Linear(embed_dim, inner_dim, bias=bias)
+        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
+        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
+        self.ffn = nn.Sequential(
+            nn.Linear(2 * embed_dim, 2 * embed_dim),
+            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
+            nn.GELU(),
+            nn.Linear(2 * embed_dim, embed_dim),
+        )
+        if flash and FLASH_AVAILABLE:
+            self.flash = Attention(True)
+        else:
+            self.flash = None
+
+        if dropout > 1.0e-4:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = nn.Identity()
+
+    def forward(
+        self,
+        x0: torch.Tensor,
+        x1: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        encoding0: Optional[torch.Tensor] = None,
+        encoding1: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        q = self.to_q(x0)
+        k = self.to_k(x1)
+        v = self.to_v(x1)
+        q, k, v = map(
+            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
+            (q, k, v),
+        )
+        if encoding0 is not None and encoding1 is not None:
+            q = apply_cached_rotary_emb(encoding0, q)
+            k = apply_cached_rotary_emb(encoding1, k)
+        if self.flash is not None and q.device.type == "cuda":
+            m0 = self.flash(q, k, v, mask)
+        else:
+            raise NotImplementedError(
+                "Non-flash attention not implemented for UniCrossBlock"
+            )
+        m0 = m0.transpose(1, 2).flatten(start_dim=-2)
+        m0 = self.to_out(m0)
+        x0 = x0 + self.dropout(self.ffn(torch.cat([x0, m0], -1)))
+        return x0
+
+
 class TransformerLayer(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
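The new UniCrossBlock is a one-directional variant of cross-attention: queries come from x0, keys and values from x1, and only x0 is updated (unlike the existing cross block, whose forward returns both x0 and x1). A minimal usage sketch, not part of the commit; shapes and hyperparameters are illustrative, and it assumes a CUDA device with flash attention available, since the fallback path raises NotImplementedError:

    import torch
    from gluefactory.models.matchers.lightglue import UniCrossBlock

    # hypothetical sizes: batch 2, 512/640 keypoints, 256-dim descriptors
    block = UniCrossBlock(embed_dim=256, num_heads=4, flash=True).cuda()
    x0 = torch.randn(2, 512, 256, device="cuda")  # descriptors of image 0
    x1 = torch.randn(2, 640, 256, device="cuda")  # descriptors of image 1
    out = block(x0, x1)  # x0 attends to x1; shape (2, 512, 256), same as x0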
@@ -272,7 +336,22 @@ def forward(
         cross_encoding1: torch.Tensor | None = None,
         mask0: Optional[torch.Tensor] = None,
         mask1: Optional[torch.Tensor] = None,
+        checkpointed: bool = False,
+        reentrant: bool = True,
     ):
+        if checkpointed and self.training:
+            return torch.utils.checkpoint.checkpoint(
+                self.forward,
+                desc0,
+                desc1,
+                encoding0,
+                encoding1,
+                cross_encoding0,
+                cross_encoding1,
+                mask0,
+                mask1,
+                use_reentrant=reentrant,
+            )
         if mask0 is not None and mask1 is not None:
             return self.masked_forward(
                 desc0,
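TransformerLayer.forward now takes checkpointed and reentrant flags: during training it re-enters itself through gradient checkpointing, recomputing activations in the backward pass instead of storing them. The recursion terminates because checkpointed defaults to False and is not forwarded to the inner call. A standalone sketch of the same pattern on a toy module (not the layer from this file):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    mlp = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
    x = torch.randn(8, 64, requires_grad=True)
    # Activations inside mlp are recomputed during backward, saving memory;
    # use_reentrant=False corresponds to reentrant=False in the new flag.
    y = checkpoint(mlp, x, use_reentrant=False)
    y.sum().backward()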
@@ -337,18 +416,26 @@ def sigmoid_log_double_softmax(
 
 
 class MatchAssignment(nn.Module):
-    def __init__(self, dim: int) -> None:
+    def __init__(
+        self, dim: int, out_dim: int | None = None, temperature: float | None = None
+    ) -> None:
         super().__init__()
-        self.dim = dim
         self.matchability = nn.Linear(dim, 1, bias=True)
-        self.final_proj = nn.Linear(dim, dim, bias=True)
+        self.final_proj = nn.Linear(dim, out_dim or dim, bias=True)
+
+        if temperature is not None:
+            self.temperature = nn.Parameter(torch.Tensor([temperature]))
+        else:
+            self.temperature = None
 
     def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
         """build assignment matrix from descriptors"""
         mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
         _, _, d = mdesc0.shape
         mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
         sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
+        if self.temperature is not None:
+            sim = sim * self.temperature
         z0 = self.matchability(desc0)
         z1 = self.matchability(desc1)
         scores = sigmoid_log_double_softmax(sim, z0, z1)
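MatchAssignment gains an optional out_dim for the final projection and an optional learnable temperature that rescales the similarity matrix before the double softmax. A standalone sketch of the temperature's effect, with illustrative dimensions:

    import torch

    d = 128
    mdesc0 = torch.randn(1, 512, d) / d**0.25  # scaled projected descriptors
    mdesc1 = torch.randn(1, 640, d) / d**0.25
    sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
    temperature = torch.nn.Parameter(torch.tensor([5.0]))  # learnable scalar
    sim = sim * temperature  # values > 1 sharpen the match distributions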
