Skip to content

Commit ba96190

Browse files
authored
Merge branch 'main' into add-grouped-mm-2795
2 parents 76eaefc + 19e5284 commit ba96190

3 files changed

Lines changed: 43 additions & 79 deletions

File tree

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2329,6 +2329,7 @@ def _aten_upsample_output_size(
23292329
mode: str,
23302330
coordinate_transformation_mode: str,
23312331
antialias: int = 0,
2332+
cubic_coeff_a: float = -0.75,
23322333
) -> TReal:
23332334
batch_and_channel = op.Shape(self, end=2, start=0)
23342335
# When output_size is passed in as a list of integers, the torch.onnx
@@ -2344,6 +2345,7 @@ def _aten_upsample_output_size(
23442345
output_size,
23452346
mode=mode,
23462347
coordinate_transformation_mode=coordinate_transformation_mode,
2348+
cubic_coeff_a=cubic_coeff_a,
23472349
nearest_mode="floor",
23482350
antialias=antialias,
23492351
)
@@ -2355,6 +2357,7 @@ def _aten_upsample_scales(
23552357
mode: str,
23562358
coordinate_transformation_mode: str,
23572359
antialias: int = 0,
2360+
cubic_coeff_a: float = -0.75,
23582361
) -> TReal:
23592362
return op.Resize(
23602363
self,
@@ -2365,6 +2368,7 @@ def _aten_upsample_scales(
23652368
None,
23662369
mode=mode,
23672370
coordinate_transformation_mode=coordinate_transformation_mode,
2371+
cubic_coeff_a=cubic_coeff_a,
23682372
nearest_mode="floor",
23692373
antialias=antialias,
23702374
)
@@ -2404,12 +2408,15 @@ def aten__upsample_bicubic2d_aa(
24042408
# NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch,
24052409
# unless when align_corners is True, in which case we do not know what is going on.
24062410
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
2411+
# PyTorch uses cubic_coeff_a=-0.5 (Keys interpolation, PIL-compatible) when
2412+
# antialias=True, as opposed to -0.75 (OpenCV-compatible) for the non-antialias case.
24072413
return _aten_upsample_output_size(
24082414
self,
24092415
output_size,
24102416
mode="cubic",
24112417
coordinate_transformation_mode=coordinate_transformation_mode,
24122418
antialias=1,
2419+
cubic_coeff_a=-0.5,
24132420
)
24142421

24152422

onnxscript/function_libs/torch_lib/ops/vision.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,16 @@ def _process_sampling_ratio_for_roi_align(sampling_ratio: int):
5555
@torch_op("torchvision::roi_align", trace_only=True)
5656
def torchvision_roi_align(
5757
input,
58-
boxes,
59-
output_size: Sequence[int],
60-
spatial_scale: float = 1.0,
58+
rois,
59+
spatial_scale: float,
60+
pooled_height: int,
61+
pooled_width: int,
6162
sampling_ratio: int = -1,
6263
aligned: bool = False,
6364
):
64-
"""roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor"""
65-
pooled_height, pooled_width = output_size
66-
batch_indices = _process_batch_indices_for_roi_align(boxes)
67-
rois_coords = _process_rois_for_roi_align(boxes)
65+
"""torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor"""
66+
batch_indices = _process_batch_indices_for_roi_align(rois)
67+
rois_coords = _process_rois_for_roi_align(rois)
6868
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
6969
sampling_ratio = _process_sampling_ratio_for_roi_align(sampling_ratio)
7070

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 29 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,81 +1499,38 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa
14991499

15001500

15011501
def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs):
1502-
del op_info
1503-
del kwargs
1504-
# roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False)
1505-
1506-
# Test 1: spatial_scale=1, sampling_ratio=2
1507-
x1 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1508-
roi1 = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=dtype, device=device)
1509-
yield opinfo_core.SampleInput(
1510-
x1,
1511-
args=(roi1, (5, 5)),
1512-
kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": True},
1513-
)
1514-
1515-
# Test 2: spatial_scale=0.5, sampling_ratio=3
1516-
x2 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1517-
roi2 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1518-
yield opinfo_core.SampleInput(
1519-
x2,
1520-
args=(roi2, (5, 5)),
1521-
kwargs={"spatial_scale": 0.5, "sampling_ratio": 3, "aligned": True},
1522-
)
1523-
1524-
# Test 3: spatial_scale=1.8, sampling_ratio=2
1525-
x3 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1526-
roi3 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1527-
yield opinfo_core.SampleInput(
1528-
x3,
1529-
args=(roi3, (5, 5)),
1530-
kwargs={"spatial_scale": 1.8, "sampling_ratio": 2, "aligned": True},
1531-
)
1532-
1533-
# Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2)
1534-
x4 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1535-
roi4 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1536-
yield opinfo_core.SampleInput(
1537-
x4,
1538-
args=(roi4, (2, 2)),
1539-
kwargs={"spatial_scale": 2.5, "sampling_ratio": 0, "aligned": True},
1540-
)
1502+
del op_info, kwargs
15411503

1542-
# Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2)
1543-
x5 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1544-
roi5 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1545-
yield opinfo_core.SampleInput(
1546-
x5,
1547-
args=(roi5, (2, 2)),
1548-
kwargs={"spatial_scale": 2.5, "sampling_ratio": -1, "aligned": True},
1549-
)
1504+
def make_x():
1505+
return torch.rand(
1506+
1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad
1507+
)
15501508

1551-
# Test 6: malformed boxes (test_roi_align_malformed_boxes)
1552-
x6 = torch.randn(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1553-
roi6 = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=dtype, device=device)
1554-
yield opinfo_core.SampleInput(
1555-
x6,
1556-
args=(roi6, (5, 5)),
1557-
kwargs={"spatial_scale": 1.0, "sampling_ratio": 1, "aligned": True},
1558-
)
1509+
# rois is [K, 5] = [batch_idx, x1, y1, x2, y2]
1510+
roi_a = torch.tensor([[0, 1.5, 1.5, 3.0, 3.0]], dtype=dtype, device=device)
1511+
roi_b = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1512+
roi_int = torch.tensor([[0, 0.0, 0.0, 4.0, 4.0]], dtype=dtype, device=device)
1513+
roi_malformed = torch.tensor(
1514+
[[0, 2.0, 0.3, 1.5, 1.5]], dtype=dtype, device=device
1515+
) # x1 > x2-ish
15591516

1560-
# Test 7: aligned=False, spatial_scale=1, sampling_ratio=2
1561-
x7 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1562-
roi7 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
1563-
yield opinfo_core.SampleInput(
1564-
x7,
1565-
args=(roi7, (5, 5)),
1566-
kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": False},
1567-
)
1517+
# (rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned)
1518+
cases = [
1519+
(roi_a, 1.0, 5, 5, 2, True),
1520+
(roi_b, 0.5, 5, 5, 3, True),
1521+
(roi_b, 1.8, 5, 5, 2, True),
1522+
(roi_b, 2.5, 2, 2, 0, True),
1523+
(roi_b, 2.5, 2, 2, -1, True),
1524+
(roi_malformed, 1.0, 5, 5, 1, True),
1525+
(roi_int, 1.0, 5, 5, 2, False),
1526+
(roi_int, 1.0, 5, 5, -1, False),
1527+
]
15681528

1569-
# Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1
1570-
x8 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1571-
roi8 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
1572-
yield opinfo_core.SampleInput(
1573-
x8,
1574-
args=(roi8, (5, 5)),
1575-
kwargs={"spatial_scale": 1.0, "sampling_ratio": -1, "aligned": False},
1576-
)
1529+
for rois, spatial_scale, ph, pw, sr, aligned in cases:
1530+
yield opinfo_core.SampleInput(
1531+
make_x(),
1532+
args=(rois, float(spatial_scale), int(ph), int(pw), int(sr), bool(aligned)),
1533+
)
15771534

15781535

15791536
def sample_inputs_roi_pool(op_info, device, dtype, requires_grad, **kwargs):
@@ -3160,7 +3117,7 @@ def __init__(self):
31603117
),
31613118
opinfo_core.OpInfo(
31623119
"torchvision.ops.roi_align",
3163-
op=torchvision.ops.roi_align,
3120+
op=torch.ops.torchvision.roi_align.default,
31643121
dtypes=common_dtype.floating_types(),
31653122
sample_inputs_func=sample_inputs_roi_align,
31663123
supports_out=False,

0 commit comments

Comments (0)