@@ -1471,81 +1471,38 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa
14711471
14721472
14731473def sample_inputs_roi_align (op_info , device , dtype , requires_grad , ** kwargs ):
1474- del op_info
1475- del kwargs
1476- # roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False)
1477-
1478- # Test 1: spatial_scale=1, sampling_ratio=2
1479- x1 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1480- roi1 = torch .tensor ([[0 , 1.5 , 1.5 , 3 , 3 ]], dtype = dtype , device = device )
1481- yield opinfo_core .SampleInput (
1482- x1 ,
1483- args = (roi1 , (5 , 5 )),
1484- kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : 2 , "aligned" : True },
1485- )
1486-
1487- # Test 2: spatial_scale=0.5, sampling_ratio=3
1488- x2 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1489- roi2 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1490- yield opinfo_core .SampleInput (
1491- x2 ,
1492- args = (roi2 , (5 , 5 )),
1493- kwargs = {"spatial_scale" : 0.5 , "sampling_ratio" : 3 , "aligned" : True },
1494- )
1495-
1496- # Test 3: spatial_scale=1.8, sampling_ratio=2
1497- x3 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1498- roi3 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1499- yield opinfo_core .SampleInput (
1500- x3 ,
1501- args = (roi3 , (5 , 5 )),
1502- kwargs = {"spatial_scale" : 1.8 , "sampling_ratio" : 2 , "aligned" : True },
1503- )
1504-
1505- # Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2)
1506- x4 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1507- roi4 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1508- yield opinfo_core .SampleInput (
1509- x4 ,
1510- args = (roi4 , (2 , 2 )),
1511- kwargs = {"spatial_scale" : 2.5 , "sampling_ratio" : 0 , "aligned" : True },
1512- )
1474+ del op_info , kwargs
15131475
1514- # Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2)
1515- x5 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1516- roi5 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1517- yield opinfo_core .SampleInput (
1518- x5 ,
1519- args = (roi5 , (2 , 2 )),
1520- kwargs = {"spatial_scale" : 2.5 , "sampling_ratio" : - 1 , "aligned" : True },
1521- )
1476+ def make_x ():
1477+ return torch .rand (
1478+ 1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad
1479+ )
15221480
1523- # Test 6: malformed boxes (test_roi_align_malformed_boxes)
1524- x6 = torch .randn (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1525- roi6 = torch .tensor ([[0 , 2 , 0.3 , 1.5 , 1.5 ]], dtype = dtype , device = device )
1526- yield opinfo_core .SampleInput (
1527- x6 ,
1528- args = (roi6 , (5 , 5 )),
1529- kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : 1 , "aligned" : True },
1530- )
1481+ # rois is [K, 5] = [batch_idx, x1, y1, x2, y2]
1482+ roi_a = torch .tensor ([[0 , 1.5 , 1.5 , 3.0 , 3.0 ]], dtype = dtype , device = device )
1483+ roi_b = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1484+ roi_int = torch .tensor ([[0 , 0.0 , 0.0 , 4.0 , 4.0 ]], dtype = dtype , device = device )
1485+ roi_malformed = torch .tensor (
1486+ [[0 , 2.0 , 0.3 , 1.5 , 1.5 ]], dtype = dtype , device = device
1487+ ) # x1 > x2-ish
15311488
1532- # Test 7: aligned=False, spatial_scale=1, sampling_ratio=2
1533- x7 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1534- roi7 = torch .tensor ([[0 , 0 , 0 , 4 , 4 ]], dtype = dtype , device = device )
1535- yield opinfo_core .SampleInput (
1536- x7 ,
1537- args = (roi7 , (5 , 5 )),
1538- kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : 2 , "aligned" : False },
1539- )
1489+ # (rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned)
1490+ cases = [
1491+ (roi_a , 1.0 , 5 , 5 , 2 , True ),
1492+ (roi_b , 0.5 , 5 , 5 , 3 , True ),
1493+ (roi_b , 1.8 , 5 , 5 , 2 , True ),
1494+ (roi_b , 2.5 , 2 , 2 , 0 , True ),
1495+ (roi_b , 2.5 , 2 , 2 , - 1 , True ),
1496+ (roi_malformed , 1.0 , 5 , 5 , 1 , True ),
1497+ (roi_int , 1.0 , 5 , 5 , 2 , False ),
1498+ (roi_int , 1.0 , 5 , 5 , - 1 , False ),
1499+ ]
15401500
1541- # Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1
1542- x8 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1543- roi8 = torch .tensor ([[0 , 0 , 0 , 4 , 4 ]], dtype = dtype , device = device )
1544- yield opinfo_core .SampleInput (
1545- x8 ,
1546- args = (roi8 , (5 , 5 )),
1547- kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : - 1 , "aligned" : False },
1548- )
1501+ for rois , spatial_scale , ph , pw , sr , aligned in cases :
1502+ yield opinfo_core .SampleInput (
1503+ make_x (),
1504+ args = (rois , float (spatial_scale ), int (ph ), int (pw ), int (sr ), bool (aligned )),
1505+ )
15491506
15501507
15511508def sample_inputs_roi_pool (op_info , device , dtype , requires_grad , ** kwargs ):
@@ -3132,7 +3089,7 @@ def __init__(self):
31323089 ),
31333090 opinfo_core .OpInfo (
31343091 "torchvision.ops.roi_align" ,
3135- op = torchvision .ops .roi_align ,
3092+ op = torch .ops .torchvision . roi_align . default ,
31363093 dtypes = common_dtype .floating_types (),
31373094 sample_inputs_func = sample_inputs_roi_align ,
31383095 supports_out = False ,
0 commit comments