Commit 19e5284
Fix bicubic antialias export: use cubic_coeff_a=-0.5 instead of -0.75 (#2849)
When exporting `F.interpolate(mode='bicubic', antialias=True)`, the ONNX
Resize node was emitted with `cubic_coeff_a=-0.75` (OpenCV-compatible),
but PyTorch uses `-0.5` (Keys/PIL-compatible) for the antialias path.
This caused ~32x higher numerical error vs. PyTorch when running the
exported model in ONNX Runtime.
## Changes
- **`_aten_upsample_output_size` / `_aten_upsample_scales`**: Added a
  `cubic_coeff_a: float = -0.75` parameter (the default preserves existing
  behavior for non-antialias cases) and threaded it through to `op.Resize`.
- **`aten__upsample_bicubic2d_aa`**: Pass `cubic_coeff_a=-0.5` to match
PyTorch's runtime behavior when `antialias=True`.
```python
# antialias=True → cubic_coeff_a=-0.5 (Keys/PIL-compatible) ✓
# antialias=False → cubic_coeff_a=-0.75 (OpenCV-compatible) ✓
```
<!-- START COPILOT ORIGINAL PROMPT -->
<details>
<summary>Original prompt</summary>
---
*This section details the original issue you should resolve*
<issue_title>ONNX dynamo export writes cubic_coeff_a=-0.75 for bicubic
antialias=True (should be -0.5)</issue_title>
<issue_description>### 🐛 Describe the bug
## Bug
When exporting `F.interpolate(mode='bicubic', antialias=True)` to ONNX
via the dynamo exporter, the Resize node is written with
`cubic_coeff_a=-0.75`. However, PyTorch internally uses
`cubic_coeff_a=-0.5` (Keys interpolation) when `antialias=True`, as
documented in the source:
```cpp
// aten/src/ATen/native/cpu/UpSampleKernel.cpp, line ~1347
// We are using -0.5 for bicubic, antialiasing=true (compatibility with PIL)
// and using -0.75 for bicubic, antialiasing=false (compatibility with Opencv)
constexpr scalar_t a = use_keys_cubic ? -0.5 : -0.75;
```
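For context, `a` parameterizes the Keys cubic convolution kernel. The sketch below evaluates the standard Keys (1981) formula directly (background math, not code from PyTorch or the exporter): both values interpolate exactly at integer offsets, but they weight neighboring pixels differently in between, which is why the exported attribute must match the one PyTorch actually used.

```python
def cubic_kernel(x: float, a: float) -> float:
    """Keys cubic convolution kernel with free parameter a."""
    x = abs(x)
    if x <= 1.0:
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0


# Both values pass through the sample points exactly...
for a in (-0.5, -0.75):
    assert cubic_kernel(0.0, a) == 1.0
    assert cubic_kernel(1.0, a) == 0.0

# ...but differ between samples, so the two kernels resize differently.
w_keys = cubic_kernel(0.5, -0.5)     # 0.5625
w_opencv = cubic_kernel(0.5, -0.75)  # 0.59375
```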
The exported ONNX model therefore produces different results than
PyTorch when run in ONNX Runtime (or any runtime that correctly respects
the `cubic_coeff_a` attribute).
The `-0.75` value was originally hardcoded in PR pytorch/pytorch#24805
for the non-antialias case and was carried forward without accounting
for the antialias path. The distinction between `-0.5` (Keys,
PIL-compatible) and `-0.75` (OpenCV-compatible) based on the antialias
flag was introduced in the ATen kernels via pytorch/vision#3810 and
pytorch#68819.
The legacy TorchScript exporter does not support `antialias=True` at all
(`UnsupportedOperatorError`), so this only affects the dynamo exporter.
## To reproduce
```python
import io
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
import torch.nn.functional as F
class BicubicAA(nn.Module):
    def forward(self, x):
        return F.interpolate(x, size=[224, 224], mode="bicubic",
                             align_corners=False, antialias=True)
# Export
model = BicubicAA()
model.eval()
x = torch.rand(1, 3, 800, 600)
buf = io.BytesIO()
torch.onnx.export(model, (x,), buf, opset_version=18, dynamo=True)
buf.seek(0)
onnx_model = onnx.load(buf)
# Inspect: cubic_coeff_a is -0.75 (wrong for antialias=True)
for node in onnx_model.graph.node:
    if node.op_type == "Resize":
        for attr in node.attribute:
            if attr.name == "cubic_coeff_a":
                print(f"Exported cubic_coeff_a = {attr.f}")  # -0.75
            if attr.name == "antialias":
                print(f"Exported antialias = {attr.i}")  # 1
# Numerical impact
with torch.no_grad():
    pt_out = model(x).numpy()
buf.seek(0)
sess = ort.InferenceSession(buf.read())
ort_wrong = sess.run(None, {"x": x.numpy()})[0]
# Patch to correct value and re-run
for node in onnx_model.graph.node:
    if node.op_type == "Resize":
        for attr in node.attribute:
            if attr.name == "cubic_coeff_a":
                attr.f = -0.5
buf2 = io.BytesIO()
onnx.save(onnx_model, buf2)
buf2.seek(0)
sess2 = ort.InferenceSession(buf2.read())
ort_fixed = sess2.run(None, {"x": x.numpy()})[0]
print(f"PyTorch vs ONNX (exported a=-0.75): mean={np.abs(ort_wrong - pt_out).mean():.2e}")
print(f"PyTorch vs ONNX (patched a=-0.50): mean={np.abs(ort_fixed - pt_out).mean():.2e}")
```
Output:
```
Exported cubic_coeff_a = -0.75
Exported antialias = 1
PyTorch vs ONNX (exported a=-0.75): mean=5.31e-03
PyTorch vs ONNX (patched a=-0.50): mean=1.67e-04
```
Patching `cubic_coeff_a` to `-0.5` reduces mean error by 32x, confirming
that PyTorch uses `-0.5` at runtime but the exporter writes `-0.75`.
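The ~32x figure follows directly from the two reported means:

```python
# Mean absolute errors reported above, vs. PyTorch's own output.
mean_wrong = 5.31e-03  # exported cubic_coeff_a = -0.75
mean_fixed = 1.67e-04  # patched  cubic_coeff_a = -0.5

ratio = mean_wrong / mean_fixed
print(f"{ratio:.1f}x")  # ~31.8x, reported as ~32x
```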
## Expected behavior
When `antialias=True`, the ONNX Resize node should be exported with
`cubic_coeff_a=-0.5` to match PyTorch's runtime behavior. When
`antialias=False`, `cubic_coeff_a=-0.75` is correct.
### Versions
Collecting environment information...
PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 4.2.3
Libc version: glibc-2.31
Python version: 3.12.12 (main, Feb 3 2026, 22:51:04) [Clang 21.1.4 ]
(64-bit runtime)
Python platform: Linux-5.4.0-208-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB
Nvidia driver version: 565.57.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical...
</details>
<!-- START COPILOT CODING AGENT SUFFIX -->
- Fixes pytorch/pytorch#177138
<!-- START COPILOT CODING AGENT TIPS -->
---
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 file changed: +7 −0 lines