Commit 847801c
optimizer: Prevent constant folding of DynamicQuantizeLinear (#2865)
The constant folding pass was eliminating `DequantizeLinear` nodes that
operated on constant weight tensors during `optimize()`, collapsing the
quantization structure into a plain `Conv` and losing quantization
semantics in QAT-exported models.
### Changes
- **`optimizer/_constant_folding.py`**: Add `DynamicQuantizeLinear` to
`DEFAULT_CONSTANT_FOLD_BLACKLIST` alongside the existing
`QuantizeLinear` and `DequantizeLinear` entries; reorder alphabetically
for consistency
- **`optimizer/_constant_folding_test.py`**: Add tests verifying
`QuantizeLinear` and `DequantizeLinear` are not folded when all inputs
are constant initializers
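The gating behavior the blacklist produces can be sketched as follows. This is a minimal illustration, not onnxscript's actual internals: `DEFAULT_CONSTANT_FOLD_BLACKLIST` is the name used in the PR, but `should_fold` is a hypothetical helper standing in for the pass's real eligibility check.

```python
# Illustrative sketch of blacklist gating in a constant-folding pass.
# `should_fold` is a hypothetical helper; only the blacklist name and
# its entries come from the PR description.
DEFAULT_CONSTANT_FOLD_BLACKLIST = frozenset({
    "DequantizeLinear",
    "DynamicQuantizeLinear",
    "QuantizeLinear",
})

def should_fold(op_type: str, all_inputs_constant: bool) -> bool:
    """Fold only when every input is constant AND the op is not blacklisted."""
    if op_type in DEFAULT_CONSTANT_FOLD_BLACKLIST:
        return False
    return all_inputs_constant
```

With this gate in place, a `DequantizeLinear` over constant initializers is skipped even though it is otherwise foldable, which preserves the Q/DQ structure.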
<details>
<summary>Original prompt</summary>
----
*This section details the original issue being resolved.*
<issue_title>[ONNX] Optimize should not fold DequantizeLinear</issue_title>
<issue_description>### 🐛 Describe the bug
After the QAT model goes through `onnx_program.optimize()`, quantization nodes are lost. In the figure, the left side shows the normal export and the right side shows the abnormal export graph.
This bug occurred in `torch/onnx/_internal/exporter/_onnx_program.py`:
```python
def optimize(self) -> None:
    self.model = onnxscript_apis.optimize(self.model)
```
This internally calls the `optimize_ir` function in
`onnxscript/optimizer/_optimizer.py`.
The default value of `input_size_limit` is 512; constant inputs no larger
than this are folded, which is why the small Q/DQ weight tensors here get
collapsed.
```python
def optimize_ir(
    model: ir.Model,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> None:
    passes = [
        ir.passes.PassManager(
            [
                _constant_folding.FoldConstantsPass(
                    shape_inference=onnx_shape_inference,
                    input_size_limit=input_size_limit,
                    output_size_limit=output_size_limit,
                ),
                rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
                common_passes.RemoveUnusedNodesPass(),
                common_passes.RemoveUnusedFunctionsPass(),
                common_passes.RemoveUnusedOpsetsPass(),
            ],
            steps=num_iterations,
            early_stop=stop_if_no_change,
        ),
        ......
```
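The size gate that decides folding eligibility can be sketched like this; `within_fold_limit` is a hypothetical helper, while the default of 512 matches the `DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT` value noted above:

```python
import math

# Matches the default noted above; the helper name is illustrative,
# not onnxscript's actual API.
DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 512

def within_fold_limit(shape, limit=DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT):
    """A constant input is eligible for folding only if its element
    count does not exceed the size limit."""
    return math.prod(shape) <= limit
```

For the repro below, a 4x4x3x3 weight has 144 elements, well under 512, so without a blacklist the Q/DQ pair around it is folded away.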
⭐ Please add a parameter to the `optimize()` function in
`torch/onnx/_internal/exporter/_onnx_program.py` to control this behavior;
otherwise the only workaround is to patch onnxscript from source.
The smallest reproducible example:
```python
import copy

import torch
import torch.nn as nn
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from onnxslim import slim
import onnx


class ConvBnReluModel(nn.Module):
    def __init__(self, eps=1e-3, momentum=0.03):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(4, eps=eps, momentum=momentum)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


def get_batch_norm_node_args(gm):
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.batch_norm.default:
            return tuple(node.args)
    raise RuntimeError("No aten.batch_norm.default node found")


torch.manual_seed(0)
device = 'cuda'
model = ConvBnReluModel().train().to(device)
inputs = (torch.randn(2, 4, 8, 8).to(device),)
exported = torch.export.export_for_training(copy.deepcopy(model), inputs).module()
print("before prepare:", get_batch_norm_node_args(exported))

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
prepared = prepare_qat_pt2e(exported, quantizer)
prepared.to(device)
torch.ao.quantization.move_exported_model_to_eval(prepared)
torch.ao.quantization.allow_exported_model_train_eval(prepared)
prepared.eval()

# --- export quantized model to onnx ---
qat_onnx_sp = './quant.onnx'
quantized_model = convert_pt2e(prepared)
print('convert_pt2e done!')
onnx_program = torch.onnx.export(quantized_model, inputs, dynamo=True, opset_version=21)

""" bug """
onnx_program.optimize()

onnx_program.save(qat_onnx_sp)
print(f'export qat model to [{qat_onnx_sp}] done!')

model_simp = slim(onnx_program.model_proto)
sim_path = qat_onnx_sp.replace('.onnx', '_slim.onnx')
onnx.save(model_simp, sim_path)
print(f"save onnx model to [{sim_path}] Successfully!")
```
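To confirm whether the quantization structure survives `optimize()`, one can count the Q/DQ nodes in the graph before and after. A minimal sketch; `count_qdq_nodes` is a hypothetical helper, and the op-type strings are the standard ONNX operator names:

```python
def count_qdq_nodes(op_types):
    """Count quantization-related nodes given the graph's node op_type
    strings (e.g. `n.op_type for n in model_proto.graph.node`)."""
    qdq = {"QuantizeLinear", "DequantizeLinear", "DynamicQuantizeLinear"}
    return sum(1 for op in op_types if op in qdq)
```

With the fix applied, the count taken before and after `onnx_program.optimize()` should be unchanged; before the fix, the count drops as the DQ nodes over constant weights are folded into plain `Conv` inputs.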
### Versions
Versions of relevant libraries:
[pip3] executorch==0.5.0
[pip3] numpy==1.23.5
[pip3] nvidia-cublas-cu11==11.11.3.6
[pip3] nvidia-cuda-cupti-cu11==11.8.87
[pip3] nvidia-cuda-nvrtc-cu11==11.8.89
[pip3] nvidia-cuda-runtime-cu11==11.8.89
[pip3] nvidia-cudnn-cu11==9.1.0.70
[pip3] nvidia-cufft-cu11==10.9.0.58
[pip3] nvidia-curand-cu11==10.3.0.86
[pip3] nvidia-cusolver-cu11==11.4.1.48
[pip3] nvidia-cusparse-cu11==11.7.5.86
[pip3] nvidia-nccl-cu11==2.21.5
[pip3] nvidia-nvtx-cu11==11.8.86
[pip3] onnx==1.17.0
[pip3] onnx_graphsurgeon==0.5.8
[pip3] onnx-ir==0.1.12
[pip3] onnx-simplifier==0.4.36
[pip3] onnxruntime==1.21.0
[pip3] onnxruntime-gpu==1.21.0
[pip3] onnxscript==0.4.0
[pip3] onnxslim==0.1.48
[pip3] torch==2.6.0+cu118
[pip3] torchao==0.14.1
[pip3] torchaudio==2.6.0+cu118
[pip3] torchvision==0.21.0+cu118
[pip3] ...
</details>
- Fixes pytorch/pytorch#177611
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
File tree: `onnxscript/optimizer` — 2 files changed, +32 −1 lines changed