Skip to content

Commit 0f6d91d

Browse files
raimbekovm and zy1git authored
Vectorize masks_to_boxes for performance (#9358)
Co-authored-by: zy1git <zycoding1@gmail.com>
1 parent 326a11d commit 0f6d91d

1 file changed

Lines changed: 17 additions & 9 deletions

File tree

torchvision/ops/boxes.py

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -509,17 +509,25 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
509509
if masks.numel() == 0:
510510
return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
511511

512-
n = masks.shape[0]
512+
n, h, w = masks.shape
513513

514-
bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
514+
masks_bool = masks.bool()
515515

516-
for index, mask in enumerate(masks):
517-
y, x = torch.where(mask != 0)
516+
non_zero_rows = torch.any(masks_bool, dim=2)
517+
non_zero_cols = torch.any(masks_bool, dim=1)
518518

519-
if x.numel() > 0:
520-
bounding_boxes[index, 0] = torch.min(x)
521-
bounding_boxes[index, 1] = torch.min(y)
522-
bounding_boxes[index, 2] = torch.max(x)
523-
bounding_boxes[index, 3] = torch.max(y)
519+
empty_masks = ~torch.any(non_zero_rows, dim=1)
520+
521+
non_zero_rows_f = non_zero_rows.float()
522+
non_zero_cols_f = non_zero_cols.float()
523+
524+
y1 = non_zero_rows_f.argmax(dim=1)
525+
x1 = non_zero_cols_f.argmax(dim=1)
526+
y2 = (h - 1) - non_zero_rows_f.flip(dims=[1]).argmax(dim=1)
527+
x2 = (w - 1) - non_zero_cols_f.flip(dims=[1]).argmax(dim=1)
528+
529+
bounding_boxes = torch.stack([x1, y1, x2, y2], dim=1).float()
530+
531+
bounding_boxes[empty_masks] = 0
524532

525533
return bounding_boxes

0 commit comments

Comments (0)