Skip to content

Commit 6be9d18

Browse files
authored
Fix Op (convolution) | add nd support to convolution (#2108)
![Screenshot 2025-03-14 171007](https://github.com/user-attachments/assets/fc965055-9a29-44c6-a25d-b0e4a5867d0b) Ran into a case that aten.convolution.default takes 2D image with [0] as padding, which broke our assumption of it comes with the same rank of nd image.
1 parent 5bc7de5 commit 6be9d18

File tree

3 files changed

+44
-19
lines changed

3 files changed

+44
-19
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2074,16 +2074,30 @@ def aten_convolution(
20742074
) -> TFloat:
20752075
"""convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor"""
20762076

2077+
rank = len(input.shape)
2078+
2079+
image_d = rank - 2
2080+
2081+
# NOTE: We assume the sequence padding/dilation/stride
2082+
# from ATen op can only be either len == 1 or
2083+
# len == rank.
2084+
20772085
if not isinstance(padding, Sequence):
2078-
padding = (padding, padding)
2086+
padding = [padding] * image_d
2087+
elif len(padding) == 1:
2088+
padding = [padding[0]] * image_d
20792089
pads = [*padding, *padding]
20802090

20812091
if not isinstance(dilation, Sequence):
2082-
dilation = (dilation, dilation)
2092+
dilation = [dilation] * image_d
2093+
elif len(dilation) == 1:
2094+
dilation = [dilation[0]] * image_d
20832095
dilations = list(dilation)
20842096

20852097
if not isinstance(stride, Sequence):
2086-
stride = (stride, stride)
2098+
stride = [stride] * image_d
2099+
elif len(stride) == 1:
2100+
stride = [stride[0]] * image_d
20872101
strides = list(stride)
20882102

20892103
result = _aten_convolution_onnx(

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,19 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
239239
"groups": 1,
240240
},
241241
),
242+
(
243+
(1, 3, 224, 224),
244+
(32, 3, 3, 3),
245+
None,
246+
{
247+
"stride": (2,),
248+
"padding": (1,),
249+
"dilation": (1,),
250+
"transposed": False,
251+
"output_padding": (0, 0),
252+
"groups": 1,
253+
},
254+
),
242255
(
243256
(1, 3, 3, 224, 224),
244257
(32, 3, 3, 3, 3),
@@ -252,21 +265,19 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
252265
"groups": 1,
253266
},
254267
),
255-
# FIXME(jiz): Uncomment out these test data once
256-
# torch 2.0 is released.
257-
# (
258-
# (1, 3, 224, 224, 224),
259-
# (32, 3, 3, 3, 3),
260-
# (32,),
261-
# {
262-
# "stride": (2, 2, 2),
263-
# "padding": (1, 1, 1),
264-
# "dilation": (1, 1, 1),
265-
# "transposed": False,
266-
# "output_padding": (0, 0, 0),
267-
# "groups": 1,
268-
# },
269-
# ),
268+
(
269+
(1, 3, 224, 224, 224),
270+
(32, 3, 3, 3, 3),
271+
(32,),
272+
{
273+
"stride": (2, 2, 2),
274+
"padding": (1, 1, 1),
275+
"dilation": (1, 1, 1),
276+
"transposed": False,
277+
"output_padding": (0, 0, 0),
278+
"groups": 1,
279+
},
280+
),
270281
(
271282
(2, 4, 6, 6),
272283
(4, 1, 3, 3),

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1557,7 +1557,7 @@ def _where_input_wrangler(
15571557
TorchLibOpInfo(
15581558
"ops.aten.convolution",
15591559
core_ops.aten_convolution,
1560-
tolerance={torch.float32: (3.7e-5, 1.8e-4)},
1560+
tolerance={torch.float32: (2e-4, 9e-4)},
15611561
),
15621562
TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True),
15631563
TorchLibOpInfo(

0 commit comments

Comments
 (0)