@@ -1579,6 +1579,44 @@ def shape(size, rank, with_batch_channel=True):
15791579 )
15801580
15811581
1582+ def sample_inputs_upsample_trilinear3d (op_info , device , dtype , requires_grad , ** kwargs ):
1583+ del op_info
1584+ del kwargs
1585+
1586+ N , C = 2 , 3
1587+ D = 4
1588+ SS = 3
1589+ L = 5
1590+
1591+ align_corners_options = (True , False )
1592+ rank = 3
1593+
1594+ def shape (size , rank , with_batch_channel = True ):
1595+ if with_batch_channel :
1596+ return tuple ([N , C ] + ([size ] * rank ))
1597+ return tuple ([size ] * rank )
1598+
1599+ make_arg = functools .partial (
1600+ torch_testing .make_tensor ,
1601+ device = device ,
1602+ dtype = dtype ,
1603+ requires_grad = requires_grad ,
1604+ low = - 1 ,
1605+ high = 1 ,
1606+ )
1607+
1608+ for align_corners in align_corners_options :
1609+ yield opinfo_core .SampleInput (
1610+ make_arg (shape (D , rank )), shape (SS , rank , False ), align_corners
1611+ )
1612+ yield opinfo_core .SampleInput (
1613+ make_arg (shape (D , rank )), shape (S , rank , False ), align_corners
1614+ )
1615+ yield opinfo_core .SampleInput (
1616+ make_arg (shape (D , rank )), shape (L , rank , False ), align_corners
1617+ )
1618+
1619+
15821620class _TestParamsMaxPoolEmptyStrideBase :
15831621 # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
15841622 def __init__ (self ):
@@ -2079,6 +2117,13 @@ def __init__(self):
20792117 sample_inputs_func = sample_inputs_upsample_linear1d ,
20802118 supports_out = False ,
20812119 ),
2120+ opinfo_core .OpInfo (
2121+ "ops.aten.upsample_trilinear3d" ,
2122+ aten_name = "upsample_trilinear3d" ,
2123+ dtypes = common_dtype .floating_types_and (torch .bfloat16 ),
2124+ sample_inputs_func = sample_inputs_upsample_trilinear3d ,
2125+ supports_out = False ,
2126+ ),
20822127 opinfo_core .OpInfo (
20832128 "nn.functional.max_pool1d_with_indices" ,
20842129 aten_name = "max_pool1d_with_indices" ,
0 commit comments