@@ -1499,81 +1499,38 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa
14991499
15001500
15011501def sample_inputs_roi_align (op_info , device , dtype , requires_grad , ** kwargs ):
1502- del op_info
1503- del kwargs
1504- # roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False)
1505-
1506- # Test 1: spatial_scale=1, sampling_ratio=2
1507- x1 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1508- roi1 = torch .tensor ([[0 , 1.5 , 1.5 , 3 , 3 ]], dtype = dtype , device = device )
1509- yield opinfo_core .SampleInput (
1510- x1 ,
1511- args = (roi1 , (5 , 5 )),
1512- kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : 2 , "aligned" : True },
1513- )
1514-
1515- # Test 2: spatial_scale=0.5, sampling_ratio=3
1516- x2 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1517- roi2 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1518- yield opinfo_core .SampleInput (
1519- x2 ,
1520- args = (roi2 , (5 , 5 )),
1521- kwargs = {"spatial_scale" : 0.5 , "sampling_ratio" : 3 , "aligned" : True },
1522- )
1523-
1524- # Test 3: spatial_scale=1.8, sampling_ratio=2
1525- x3 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1526- roi3 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1527- yield opinfo_core .SampleInput (
1528- x3 ,
1529- args = (roi3 , (5 , 5 )),
1530- kwargs = {"spatial_scale" : 1.8 , "sampling_ratio" : 2 , "aligned" : True },
1531- )
1532-
1533- # Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2)
1534- x4 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1535- roi4 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1536- yield opinfo_core .SampleInput (
1537- x4 ,
1538- args = (roi4 , (2 , 2 )),
1539- kwargs = {"spatial_scale" : 2.5 , "sampling_ratio" : 0 , "aligned" : True },
1540- )
1502+ del op_info , kwargs
15411503
1542- # Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2)
1543- x5 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1544- roi5 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1545- yield opinfo_core .SampleInput (
1546- x5 ,
1547- args = (roi5 , (2 , 2 )),
1548- kwargs = {"spatial_scale" : 2.5 , "sampling_ratio" : - 1 , "aligned" : True },
1549- )
1504+ def make_x ():
1505+ return torch .rand (
1506+ 1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad
1507+ )
15501508
1551- # Test 6: malformed boxes (test_roi_align_malformed_boxes)
1552- x6 = torch .randn (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1553- roi6 = torch .tensor ([[0 , 2 , 0.3 , 1.5 , 1.5 ]], dtype = dtype , device = device )
1554- yield opinfo_core .SampleInput (
1555- x6 ,
1556- args = (roi6 , (5 , 5 )),
1557- kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : 1 , "aligned" : True },
1558- )
1509+ # rois is [K, 5] = [batch_idx, x1, y1, x2, y2]
1510+ roi_a = torch .tensor ([[0 , 1.5 , 1.5 , 3.0 , 3.0 ]], dtype = dtype , device = device )
1511+ roi_b = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1512+ roi_int = torch .tensor ([[0 , 0.0 , 0.0 , 4.0 , 4.0 ]], dtype = dtype , device = device )
1513+ roi_malformed = torch .tensor (
1514+ [[0 , 2.0 , 0.3 , 1.5 , 1.5 ]], dtype = dtype , device = device
1515+ ) # x1 > x2-ish
15591516
1560- # Test 7: aligned=False, spatial_scale=1, sampling_ratio=2
1561- x7 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1562- roi7 = torch .tensor ([[0 , 0 , 0 , 4 , 4 ]], dtype = dtype , device = device )
1563- yield opinfo_core .SampleInput (
1564- x7 ,
1565- args = (roi7 , (5 , 5 )),
1566- kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : 2 , "aligned" : False },
1567- )
1517+ # (rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned)
1518+ cases = [
1519+ (roi_a , 1.0 , 5 , 5 , 2 , True ),
1520+ (roi_b , 0.5 , 5 , 5 , 3 , True ),
1521+ (roi_b , 1.8 , 5 , 5 , 2 , True ),
1522+ (roi_b , 2.5 , 2 , 2 , 0 , True ),
1523+ (roi_b , 2.5 , 2 , 2 , - 1 , True ),
1524+ (roi_malformed , 1.0 , 5 , 5 , 1 , True ),
1525+ (roi_int , 1.0 , 5 , 5 , 2 , False ),
1526+ (roi_int , 1.0 , 5 , 5 , - 1 , False ),
1527+ ]
15681528
1569- # Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1
1570- x8 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1571- roi8 = torch .tensor ([[0 , 0 , 0 , 4 , 4 ]], dtype = dtype , device = device )
1572- yield opinfo_core .SampleInput (
1573- x8 ,
1574- args = (roi8 , (5 , 5 )),
1575- kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : - 1 , "aligned" : False },
1576- )
1529+ for rois , spatial_scale , ph , pw , sr , aligned in cases :
1530+ yield opinfo_core .SampleInput (
1531+ make_x (),
1532+ args = (rois , float (spatial_scale ), int (ph ), int (pw ), int (sr ), bool (aligned )),
1533+ )
15771534
15781535
15791536def sample_inputs_roi_pool (op_info , device , dtype , requires_grad , ** kwargs ):
@@ -3160,7 +3117,7 @@ def __init__(self):
31603117 ),
31613118 opinfo_core .OpInfo (
31623119 "torchvision.ops.roi_align" ,
3163- op = torchvision .ops .roi_align ,
3120+ op = torch .ops .torchvision . roi_align . default ,
31643121 dtypes = common_dtype .floating_types (),
31653122 sample_inputs_func = sample_inputs_roi_align ,
31663123 supports_out = False ,
0 commit comments