@@ -509,17 +509,25 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
509509 if masks .numel () == 0 :
510510 return torch .zeros ((0 , 4 ), device = masks .device , dtype = torch .float )
511511
512- n = masks .shape [ 0 ]
512+ n , h , w = masks .shape
513513
514- bounding_boxes = torch . zeros (( n , 4 ), device = masks .device , dtype = torch . float )
514+ masks_bool = masks .bool ( )
515515
516- for index , mask in enumerate ( masks ):
517- y , x = torch .where ( mask != 0 )
516+ non_zero_rows = torch . any ( masks_bool , dim = 2 )
517+ non_zero_cols = torch .any ( masks_bool , dim = 1 )
518518
519- if x .numel () > 0 :
520- bounding_boxes [index , 0 ] = torch .min (x )
521- bounding_boxes [index , 1 ] = torch .min (y )
522- bounding_boxes [index , 2 ] = torch .max (x )
523- bounding_boxes [index , 3 ] = torch .max (y )
519+ empty_masks = ~ torch .any (non_zero_rows , dim = 1 )
520+
521+ non_zero_rows_f = non_zero_rows .float ()
522+ non_zero_cols_f = non_zero_cols .float ()
523+
524+ y1 = non_zero_rows_f .argmax (dim = 1 )
525+ x1 = non_zero_cols_f .argmax (dim = 1 )
526+ y2 = (h - 1 ) - non_zero_rows_f .flip (dims = [1 ]).argmax (dim = 1 )
527+ x2 = (w - 1 ) - non_zero_cols_f .flip (dims = [1 ]).argmax (dim = 1 )
528+
529+ bounding_boxes = torch .stack ([x1 , y1 , x2 , y2 ], dim = 1 ).float ()
530+
531+ bounding_boxes [empty_masks ] = 0
524532
525533 return bounding_boxes
0 commit comments