Skip to content

Commit cf577a7

Browse files
committed
Better handle outlier matches, bugfixes
1 parent 58667bf commit cf577a7

File tree

7 files changed

+118
-24
lines changed

7 files changed

+118
-24
lines changed

gluefactory/geometry/depth.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def project(
6060
camera_i,
6161
camera_j,
6262
T_itoj,
63-
validi,
6463
ccth=None,
6564
sample_depth_fun=sample_depth,
6665
sample_depth_kwargs=None,
@@ -71,20 +70,22 @@ def project(
7170
kpi_3d_i = camera_i.image2cam(kpi)
7271
kpi_3d_i = kpi_3d_i * di[..., None]
7372
kpi_3d_j = T_itoj.transform(kpi_3d_i)
74-
kpi_j, validj = camera_j.cam2image(kpi_3d_j)
73+
kpi_j, valid = camera_j.cam2image(kpi_3d_j)
74+
invalid = ~valid
7575
# di_j = kpi_3d_j[..., -1]
76-
validi = validi & validj
7776
if depthj is None or ccth is None:
78-
return kpi_j, validi & validj
77+
return kpi_j, valid, invalid
7978
else:
8079
# circle consistency
8180
dj, validj = sample_depth_fun(kpi_j, depthj, **sample_depth_kwargs)
8281
kpi_j_3d_j = camera_j.image2cam(kpi_j) * dj[..., None]
8382
kpi_j_i, validj_i = camera_i.cam2image(T_itoj.inv().transform(kpi_j_3d_j))
84-
consistent = ((kpi - kpi_j_i) ** 2).sum(-1) < ccth
85-
visible = validi & consistent & validj_i & validj
83+
reproj_error = ((kpi - kpi_j_i) ** 2).sum(-1)
84+
consistent = reproj_error < ccth**2
85+
visible = valid & consistent & validj_i & validj
86+
invalid = invalid | (validj & ((~validj_i) | (~consistent)))
8687
# visible = validi
87-
return kpi_j, visible
88+
return kpi_j, visible, invalid
8889

8990

9091
def dense_warp_consistency(
@@ -100,7 +101,8 @@ def dense_warp_consistency(
100101
-2,
101102
)
102103
validi = di > 0
103-
kpir, validir = project(kpi, di, depthj, camerai, cameraj, T_itoj, validi, **kwargs)
104+
kpir, validir, _ = project(kpi, di, depthj, camerai, cameraj, T_itoj, **kwargs)
105+
validir = validir & validi
104106

105107
return kpir.unflatten(-2, depthi.shape[-2:]), validir.unflatten(
106108
-1, (depthi.shape[-2:])
@@ -120,12 +122,10 @@ def symmetric_reprojection_error(
120122
d0, valid0 = sample_depth(pts0, depth0)
121123
d1, valid1 = sample_depth(pts1, depth1)
122124

123-
pts0_1, visible0 = project(
124-
pts0, d0, depth1, camera0, camera1, T_0to1, valid0, ccth=None
125-
)
126-
pts1_0, visible1 = project(
127-
pts1, d1, depth0, camera1, camera0, T_1to0, valid1, ccth=None
128-
)
125+
pts0_1, visible0, _ = project(pts0, d0, depth1, camera0, camera1, T_0to1, ccth=None)
126+
visible0 = visible0 & valid0
127+
pts1_0, visible1, _ = project(pts1, d1, depth0, camera1, camera0, T_1to0, ccth=None)
128+
visible1 = visible1 & valid1
129129

130130
reprojection_errors_px = 0.5 * (
131131
(pts0_1 - pts1).norm(dim=-1) + (pts1_0 - pts0).norm(dim=-1)

gluefactory/geometry/epipolar.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,37 @@ def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0, eps=1e-10):
162162
r_err = angle_error_mat(R, R_gt)
163163

164164
return t_err, r_err
165+
166+
167+
def check_epipolar_intersection(
    x_i0: torch.Tensor,
    i1_F_i0: torch.Tensor,
    width1: int | torch.Tensor,
    height1: int | torch.Tensor,
) -> torch.BoolTensor:
    """Check whether the epipolar line of each point crosses the other image.

    Every point is lifted to homogeneous coordinates and mapped through
    ``i1_F_i0`` to line coefficients ``(l1, l2, l3)`` with
    ``l1 * x + l2 * y + l3 = 0``. The line counts as intersecting when it
    meets at least one of the four borders of the rectangle
    ``[0, width1] x [0, height1]`` inside that border's extent.

    Args:
        x_i0: Points in the first image, ``(..., 2)``.
        i1_F_i0: Matrix mapping homogeneous points to line coefficients,
            ``(..., 3, 3)`` (presumably the fundamental matrix from image 0
            to image 1 -- confirm against the caller).
        width1: Width of the second image; must broadcast against the
            per-point coefficients of shape ``(...)``.
        height1: Height of the second image; same broadcasting rule.

    Returns:
        Boolean tensor with one entry per point, ``True`` where the line
        touches the image rectangle.
    """
    x_i0 = tr.to_homogeneous(x_i0)  # (..., 3)
    # Work on the last axes so that un-batched, per-sample (vmap) and
    # explicitly batched inputs are all handled alike.
    L_B = x_i0 @ i1_F_i0.transpose(-1, -2)
    l1, l2, l3 = L_B.unbind(dim=-1)

    # Coefficients below this magnitude mean the line is (numerically)
    # parallel to the border being tested.
    eps = 1e-6

    # Vertical borders (x = 0 and x = width1): solve the line for y.
    mask_v = torch.abs(l2) > eps
    l2_inv = torch.where(mask_v, 1.0 / l2, torch.zeros_like(l2))
    y0 = -l3 * l2_inv
    yW = -(l1 * width1 + l3) * l2_inv
    check_v = mask_v & (
        ((y0 >= 0) & (y0 <= height1)) | ((yW >= 0) & (yW <= height1))
    )

    # Horizontal borders (y = 0 and y = height1): solve the line for x.
    mask_h = torch.abs(l1) > eps
    l1_inv = torch.where(mask_h, 1.0 / l1, torch.zeros_like(l1))
    x0 = -l3 * l1_inv
    xH = -(l2 * height1 + l3) * l1_inv
    check_h = mask_h & (
        ((x0 >= 0) & (x0 <= width1)) | ((xH >= 0) & (xH <= width1))
    )

    return check_v | check_h

gluefactory/geometry/gt_generation.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def gt_matches_from_pose_depth(
1616
epi_th=None,
1717
cc_th=None,
1818
min_overlap: float | None = None,
19+
add_epi_outliers=True,
1920
**kw,
2021
):
2122
if kp0.shape[1] == 0 or kp1.shape[1] == 0:
@@ -41,12 +42,32 @@ def gt_matches_from_pose_depth(
4142
d0, valid0 = depth.sample_depth(kp0, depth0)
4243
d1, valid1 = depth.sample_depth(kp1, depth1)
4344

44-
kp0_1, visible0 = depth.project(
45-
kp0, d0, depth1, camera0, camera1, T_0to1, valid0, ccth=cc_th
45+
kp0_1, visible0, unmatchable0 = depth.project(
46+
kp0, d0, depth1, camera0, camera1, T_0to1, ccth=cc_th
4647
)
47-
kp1_0, visible1 = depth.project(
48-
kp1, d1, depth0, camera1, camera0, T_1to0, valid1, ccth=cc_th
48+
visible0 = visible0 & valid0
49+
kp1_0, visible1, unmatchable1 = depth.project(
50+
kp1, d1, depth0, camera1, camera0, T_1to0, ccth=cc_th
4951
)
52+
visible1 = visible1 & valid1
53+
54+
unmatchable0 = valid0 & unmatchable0
55+
unmatchable1 = valid1 & unmatchable1
56+
57+
if add_epi_outliers:
58+
i1_F_i0 = epipolar.T_to_F(camera0, camera1, T_0to1)
59+
image_size0 = data["view0"]["image_size"]
60+
image_size1 = data["view1"]["image_size"]
61+
62+
evalid0 = torch.vmap(
63+
epipolar.check_epipolar_intersection,
64+
)(kp0, i1_F_i0, image_size1[:, 0], image_size1[:, 1])
65+
evalid1 = torch.vmap(
66+
epipolar.check_epipolar_intersection,
67+
)(kp1, i1_F_i0.transpose(-1, -2), image_size0[:, 0], image_size0[:, 1])
68+
69+
unmatchable0 = unmatchable0 | (~evalid0)
70+
unmatchable1 = unmatchable1 | (~evalid1)
5071
if min_overlap is not None and "overlap_0to1" in data:
5172
has_overlap = (
5273
torch.max(data["overlap_0to1"], data["overlap_1to0"]) > min_overlap
@@ -118,6 +139,8 @@ def gt_matches_from_pose_depth(
118139
"proj_1to0": kp1_0,
119140
"visible0": visible0,
120141
"visible1": visible1,
142+
"unmatchable0": unmatchable0,
143+
"unmatchable1": unmatchable1,
121144
"has_overlap": has_overlap,
122145
"xyz_keypoints0": c0_t_w.inv() @ (camera0.image2cam(kp0) * d0.unsqueeze(-1)),
123146
"xyz_keypoints1": c1_t_w.inv() @ (camera1.image2cam(kp1) * d1.unsqueeze(-1)),

gluefactory/utils/experiments.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def compose_config(
7272
default_config_dir: str = "configs/",
7373
overrides: Optional[list[str]] = None,
7474
sweep_idx: int | None = None,
75+
resolve: bool = True,
7576
) -> tuple[Path, OmegaConf]:
7677

7778
conf_path = parse_config_path(name_or_path, default_config_dir)
@@ -88,6 +89,8 @@ def compose_config(
8889
OmegaConf.set_struct(custom_conf, False)
8990
custom_conf = OmegaConf.merge(custom_conf, sweep_conf)
9091
del custom_conf["sweep"]
92+
if resolve:
93+
OmegaConf.resolve(custom_conf)
9194
return conf_path, custom_conf
9295

9396

gluefactory/utils/metrics.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,7 @@ def ranking_ap(m, gt_m, scores):
4848
f"{prefix}average_precision": ap,
4949
"num_matchable": (data[f"gt_{prefix_gt}matches0"] > -1).sum(1),
5050
"num_unmatchable": (data[f"gt_{prefix_gt}matches0"] == -1).sum(1),
51+
"num_matches": (pred[f"{prefix}matches0"] > -1).sum(1),
52+
"average_match_score": pred[f"{prefix}matching_scores0"].mean(1),
5153
}
5254
return metrics

gluefactory/utils/misc.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,13 @@ def resize_image(
471471
return image_out
472472

473473

474+
def l2_normalize(
    tensor: torch.Tensor, dim: int = -1, eps: float = 1e-10
) -> torch.Tensor:
    """Scale ``tensor`` to unit Euclidean norm along ``dim``.

    Args:
        tensor: Input tensor.
        dim: Axis along which the norm is computed.
        eps: Lower bound applied to the norm before dividing, so that
            all-zero slices map to zeros instead of NaN.

    Returns:
        Tensor of the same shape as ``tensor``, divided by its clamped norm.
    """
    # torch.norm is deprecated; torch.linalg.vector_norm is its documented
    # replacement and computes the same 2-norm.
    norm = torch.linalg.vector_norm(tensor, ord=2, dim=dim, keepdim=True)
    return tensor / norm.clamp_min(eps)
479+
480+
474481
def is_image_of_shape(image: torch.Tensor, hw: tuple[int, int]) -> bool:
475482
h, w = hw
476483
return h in image.shape and w in image.shape
@@ -806,7 +813,7 @@ def interpolate_patches(
806813
is_chw: bool = False,
807814
align_corners: bool = False,
808815
padding_mode: str = "zeros",
809-
) -> tuple[torch.Tensor, torch.Tensor]: # B x N x D x ps x ps, B x N x 2
816+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # B x N x D x ps x ps, B x N x 2
810817
if not is_chw:
811818
features = chw_from_hwc(features)
812819
if normalize:
@@ -817,8 +824,8 @@ def interpolate_patches(
817824
(1, 1, ps, ps), device=features.device, dtype=features.dtype
818825
)
819826
p_xy = get_image_coords(dummy_patch)
820-
corners = torch.round(pts_i - ps / 2 - 0.5)
821-
p_xy_i = corners[:, :, None, None, :] + p_xy[:, None]
827+
cxy_i = torch.round(pts_i - ps / 2 - 0.5)
828+
p_xy_i = cxy_i[:, :, None, None, :] + p_xy[:, None]
822829
p_xy_n = normalize_coords(p_xy_i, features.shape[-2:])
823830
patches = torch.vmap(grid_sample, in_dims=(None, 1), out_dims=1)(
824831
features,
@@ -829,7 +836,32 @@ def interpolate_patches(
829836
)
830837
if not is_chw:
831838
patches = hwc_from_chw(patches)
832-
return patches, corners
839+
return patches, p_xy_n, cxy_i
840+
841+
842+
def patch_interpolate_points(
    pts: torch.Tensor,  # B x N x 2
    patches: torch.Tensor,  # B x N x D x ps x ps OR B x N x ps x ps x D
    **kwargs,
):
    """Apply ``interpolate_points`` per batch element, one point per patch.

    ``pts`` gets a singleton axis after its second dimension, the call is
    mapped over the leading dimension of both inputs with ``torch.vmap``,
    and index 0 of the second-to-last output axis is returned. Extra keyword
    arguments are forwarded unchanged to ``interpolate_points``.
    """
    # B x N x 2 -> B x N x 1 x 2
    single_pts = pts[:, :, None]
    per_batch = torch.vmap(interpolate_points, in_dims=0, out_dims=0)
    sampled = per_batch(single_pts, patches, **kwargs)
    # Select index 0 on the second-to-last axis (presumably the singleton
    # point axis introduced above -- confirm against interpolate_points).
    return sampled[..., 0, :]
852+
853+
854+
def log_softmax(scores: torch.Tensor, dim: int | tuple = -1) -> torch.Tensor:
855+
"""Numerically stable log softmax."""
856+
if isinstance(dim, int):
857+
return torch.log_softmax(scores, dim=dim)
858+
else:
859+
last = tuple(range(-len(dim), 0))
860+
scores = scores.moveaxis(dim, last)
861+
log_probs = torch.log_softmax(scores.flatten(-len(dim)), dim=-1).reshape(
862+
*scores.shape
863+
)
864+
return log_probs.moveaxis(last, dim)
833865

834866

835867
def interpolate_matches(

gluefactory/visualization/viz2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def plot_image_grid(
164164
return axs
165165

166166

167-
def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
167+
def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0, **kwargs):
168168
"""Plot keypoints for existing images.
169169
Args:
170170
kpts: list of ndarrays of size (N, 2).
@@ -180,7 +180,7 @@ def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
180180
for ax, k, c, alpha in zip(axes, kpts, colors, a):
181181
if isinstance(k, torch.Tensor):
182182
k = k.detach().cpu().numpy()
183-
ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
183+
ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha, **kwargs)
184184

185185

186186
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):

0 commit comments

Comments
 (0)