Skip to content

Commit efbd732

Browse files
committed
AddOp(upsample_trilinear3d)
1 parent 9a7ae80 commit efbd732

3 files changed

Lines changed: 64 additions & 3 deletions

File tree

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2491,17 +2491,28 @@ def aten_upsample_nearest3d_backward(
24912491
raise NotImplementedError()
24922492

24932493

2494+
@torch_op("aten::upsample_trilinear3d", trace_only=True)
24942495
def aten_upsample_trilinear3d(
2495-
self: TensorType,
2496+
self: TReal,
24962497
output_size: INT64,
24972498
align_corners: bool,
24982499
scales_d: Optional[float] = None,
24992500
scales_h: Optional[float] = None,
25002501
scales_w: Optional[float] = None,
2501-
) -> TensorType:
2502+
) -> TReal:
25022503
"""upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor"""
25032504

2504-
raise NotImplementedError()
2505+
del scales_d
2506+
del scales_h
2507+
del scales_w
2508+
2509+
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
2510+
return _aten_upsample_output_size(
2511+
self,
2512+
output_size,
2513+
mode="linear",
2514+
coordinate_transformation_mode=coordinate_transformation_mode,
2515+
)
25052516

25062517

25072518
def aten_upsample_trilinear3d_backward(

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,44 @@ def shape(size, rank, with_batch_channel=True):
15791579
)
15801580

15811581

1582+
def sample_inputs_upsample_trilinear3d(op_info, device, dtype, requires_grad, **kwargs):
1583+
del op_info
1584+
del kwargs
1585+
1586+
N, C = 2, 3
1587+
D = 4
1588+
SS = 3
1589+
L = 5
1590+
1591+
align_corners_options = (True, False)
1592+
rank = 3
1593+
1594+
def shape(size, rank, with_batch_channel=True):
1595+
if with_batch_channel:
1596+
return tuple([N, C] + ([size] * rank))
1597+
return tuple([size] * rank)
1598+
1599+
make_arg = functools.partial(
1600+
torch_testing.make_tensor,
1601+
device=device,
1602+
dtype=dtype,
1603+
requires_grad=requires_grad,
1604+
low=-1,
1605+
high=1,
1606+
)
1607+
1608+
for align_corners in align_corners_options:
1609+
yield opinfo_core.SampleInput(
1610+
make_arg(shape(D, rank)), shape(SS, rank, False), align_corners
1611+
)
1612+
yield opinfo_core.SampleInput(
1613+
make_arg(shape(D, rank)), shape(S, rank, False), align_corners
1614+
)
1615+
yield opinfo_core.SampleInput(
1616+
make_arg(shape(D, rank)), shape(L, rank, False), align_corners
1617+
)
1618+
1619+
15821620
class _TestParamsMaxPoolEmptyStrideBase:
15831621
# Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
15841622
def __init__(self):
@@ -2079,6 +2117,13 @@ def __init__(self):
20792117
sample_inputs_func=sample_inputs_upsample_linear1d,
20802118
supports_out=False,
20812119
),
2120+
opinfo_core.OpInfo(
2121+
"ops.aten.upsample_trilinear3d",
2122+
aten_name="upsample_trilinear3d",
2123+
dtypes=common_dtype.floating_types_and(torch.bfloat16),
2124+
sample_inputs_func=sample_inputs_upsample_trilinear3d,
2125+
supports_out=False,
2126+
),
20822127
opinfo_core.OpInfo(
20832128
"nn.functional.max_pool1d_with_indices",
20842129
aten_name="max_pool1d_with_indices",

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,6 +2140,11 @@ def _where_input_wrangler(
21402140
and sample.kwargs.get("scales") is not None,
21412141
reason="fixme: align_corners=False output mismatch when scales are provided",
21422142
),
2143+
TorchLibOpInfo(
2144+
"ops.aten.upsample_trilinear3d",
2145+
nn_ops.aten_upsample_trilinear3d,
2146+
trace_only=True,
2147+
),
21432148
TorchLibOpInfo(
21442149
"nn.functional.upsample_nearest2d",
21452150
nn_ops.aten_upsample_nearest2d,

0 commit comments

Comments
 (0)