Skip to content

Commit 792a56c

Browse files
committed
updates to GT and LG
1 parent a158591 commit 792a56c

File tree

5 files changed: +3 additions, −12 deletions

gluefactory/geometry/gt_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def gt_matches_from_pose_depth(
1616
epi_th=None,
1717
cc_th=None,
1818
min_overlap: float | None = None,
19-
add_epi_outliers=True,
19+
add_epi_outliers: bool = True,
2020
**kw,
2121
):
2222
if kp0.shape[1] == 0 or kp1.shape[1] == 0:

gluefactory/models/extractors/mixed.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,8 @@ class MixedExtractor(BaseModel):
3939
def _init(self, conf):
4040
if conf.detector.name:
4141
self.detector = get_model(conf.detector.name)(to_ctr(conf.detector))
42-
else:
43-
self.required_data_keys += ["cache"]
44-
self.required_cache_keys += ["keypoints"]
45-
4642
if conf.descriptor.name:
4743
self.descriptor = get_model(conf.descriptor.name)(to_ctr(conf.descriptor))
48-
else:
49-
self.required_data_keys += ["cache"]
50-
self.required_cache_keys += ["descriptors"]
51-
5244
if conf.refiner.name:
5345
self.refiner = get_model(conf.refiner.name)(to_ctr(conf.refiner))
5446

gluefactory/models/matchers/depth_matcher.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def _init(self, conf):
3737
]
3838

3939
@misc.filter_batch_for_jit
40-
@misc.AMP_CUSTOM_FWD_F32
4140
def _forward(self, data):
4241
return self.match_with_depth(data)
4342

gluefactory/models/matchers/lightglue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
119119
if self.enable_flash and q.device.type == "cuda":
120120
# use torch 2.0 scaled_dot_product_attention with flash
121121
if FLASH_AVAILABLE:
122-
args = [x.half().contiguous() for x in [q, k, v]]
122+
args = [x.contiguous() for x in [q, k, v]]
123123
v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
124124
return v if mask is None else v.nan_to_num()
125125
elif FLASH_AVAILABLE:

gluefactory/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ def train_step(
710710
for k, v in loss_metrics.items():
711711
val = v.detach()
712712
if self.distributed:
713-
torch.distributed.all_reduce(val)
713+
torch.distributed.all_reduce(val.contiguous())
714714
val = val / self.num_gpus
715715
loss_metrics[k] = val
716716
self.step_timer.measure("loss_fn")

0 commit comments

Comments (0)