Skip to content

Commit 847801c

Browse files
Copilotjustinchuby
andauthored
optimizer: Prevent constant folding of DynamicQuantizeLinear (#2865)
The constant folding pass was eliminating `DequantizeLinear` nodes that operated on constant weight tensors during `optimize()`, collapsing the quantization structure into a plain `Conv` and losing quantization semantics in QAT-exported models. ### Changes - **`optimizer/_constant_folding.py`**: Add `DynamicQuantizeLinear` to `DEFAULT_CONSTANT_FOLD_BLACKLIST` alongside the existing `QuantizeLinear` and `DequantizeLinear` entries; reorder alphabetically for consistency - **`optimizer/_constant_folding_test.py`**: Add tests verifying `QuantizeLinear` and `DequantizeLinear` are not folded when all inputs are constant initializers <!-- START COPILOT ORIGINAL PROMPT --> <details> <summary>Original prompt</summary> ---- *This section details on the original issue you should resolve* <issue_title>[ONNX] Optimize should not fold DequantizeLinear</issue_title> <issue_description>### 🐛 Describe the bug After the QAT model undergoes the onnx_program.optimize() process, there is a loss of quantization nodes. As shown in the figure on the left is the normal export, and on the right is the abnormal export graph. <img width="898" height="884" alt="Image" src="https://github.com/user-attachments/assets/481bc3c0-38fe-45f6-9fde-bc1a287617a3" /> This bug occurred in `torch/onnx/_internal/exporter/_onnx_program.py`: ``` def optimize(self) -> None: self.model = onnxscript_apis.optimize(self.model) ``` and it internally called the optimize_ir function in `onnxscript/optimizer/_optimizer.py`. The default value of `input_size_limit` is 512. Nodes with an input size less than this value will be collapsed. ``` def optimize_ir( model: ir.Model, num_iterations: int = 2, *, onnx_shape_inference: bool = True, stop_if_no_change: bool = True, input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, inline: bool = True, ) -> None: passes = [ ir.passes.PassManager( [ _constant_folding.FoldConstantsPass( shape_inference=onnx_shape_inference, input_size_limit=input_size_limit, output_size_limit=output_size_limit, ), rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES), common_passes.RemoveUnusedNodesPass(), common_passes.RemoveUnusedFunctionsPass(), common_passes.RemoveUnusedOpsetsPass(), ], steps=num_iterations, early_stop=stop_if_no_change, ), ...... ``` ⭐ Please enable the parameter `optimization` function in `torch/onnx/_internal/exporter/_onnx_program.py`. Otherwise, I will be able to install onnxscript only by referring to the source code. The smallest reproducible example: ``` import copy import torch import torch.nn as nn from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e from torch.ao.quantization.quantizer.xnnpack_quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, ) from onnxslim import slim import onnx class ConvBnReluModel(nn.Module): def __init__(self, eps=1e-3, momentum=0.03): super().__init__() self.conv = nn.Conv2d(4, 4, 3, padding=1, bias=False) self.bn = nn.BatchNorm2d(4, eps=eps, momentum=momentum) self.act = nn.ReLU() def forward(self, x): return self.act(self.bn(self.conv(x))) def get_batch_norm_node_args(gm): for node in gm.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.batch_norm.default: return tuple(node.args) raise RuntimeError("No aten.batch_norm.default node found") torch.manual_seed(0) device = 'cuda' model = ConvBnReluModel().train().to(device) inputs = (torch.randn(2, 4, 8, 8).to(device),) exported = torch.export.export_for_training(copy.deepcopy(model), inputs).module() print("before prepare:", get_batch_norm_node_args(exported)) quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True)) prepared = prepare_qat_pt2e(exported, quantizer) prepared.to(device) torch.ao.quantization.move_exported_model_to_eval(prepared) torch.ao.quantization.allow_exported_model_train_eval(prepared) prepared.eval() #---export quantized model to onnx--- qat_onnx_sp = './quant.onnx' quantized_model = convert_pt2e(prepared) print('convert_pt2e done!') onnx_program = torch.onnx.export(quantized_model, inputs, dynamo=True, opset_version=21) """ bug """ onnx_program.optimize() onnx_program.save(qat_onnx_sp) print(f'export qat model to [{qat_onnx_sp}] done!') model_simp = slim(onnx_program.model_proto) sim_path = qat_onnx_sp.replace('.onnx', '_slim.onnx') onnx.save(model_simp, sim_path) print(f"save onnx model to [{sim_path}] Successfully!") ``` ### Versions Versions of relevant libraries: [pip3] executorch==0.5.0 [pip3] numpy==1.23.5 [pip3] nvidia-cublas-cu11==11.11.3.6 [pip3] nvidia-cuda-cupti-cu11==11.8.87 [pip3] nvidia-cuda-nvrtc-cu11==11.8.89 [pip3] nvidia-cuda-runtime-cu11==11.8.89 [pip3] nvidia-cudnn-cu11==9.1.0.70 [pip3] nvidia-cufft-cu11==10.9.0.58 [pip3] nvidia-curand-cu11==10.3.0.86 [pip3] nvidia-cusolver-cu11==11.4.1.48 [pip3] nvidia-cusparse-cu11==11.7.5.86 [pip3] nvidia-nccl-cu11==2.21.5 [pip3] nvidia-nvtx-cu11==11.8.86 [pip3] onnx==1.17.0 [pip3] onnx_graphsurgeon==0.5.8 [pip3] onnx-ir==0.1.12 [pip3] onnx-simplifier==0.4.36 [pip3] onnxruntime==1.21.0 [pip3] onnxruntime-gpu==1.21.0 [pip3] onnxscript==0.4.0 [pip3] onnxslim==0.1.48 [pip3] torch==2.6.0+cu118 [pip3] torchao==0.14.1 [pip3] torchaudio==2.6.0+cu118 [pip3] torchvision==0.21.0+cu118 [pip3] ... </details> <!-- START COPILOT CODING AGENT SUFFIX --> - Fixes pytorch/pytorch#177611 <!-- START COPILOT CODING AGENT TIPS --> --- 📍 Connect Copilot coding agent with [Jira](https://gh.io/cca-jira-docs), [Azure Boards](https://gh.io/cca-azure-boards-docs) or [Linear](https://gh.io/cca-linear-docs) to delegate work to Copilot in one click without leaving your project management tool. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent 19e5284 commit 847801c

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@
3030
# ConstantOfShape is preserved to avoid increasing model size unnecessarily
3131
"ConstantOfShape",
3232
# Quantize/DequantizeLinear are preserved to keep the quantization info
33-
"QuantizeLinear",
3433
"DequantizeLinear",
34+
"DynamicQuantizeLinear",
35+
"QuantizeLinear",
3536
]
3637

3738
DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 8192

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,36 @@ def test_node_is_folded_if_specified_as_should_fold(self):
689689
np.ones((42, 42), dtype=np.int64),
690690
)
691691

692+
def test_quantize_linear_is_not_folded(self):
693+
model_text = """
694+
<ir_version: 10, opset_import: [ "" : 20]>
695+
agraph () => (uint8[4] z)
696+
<float[4] x = {1.0, 2.0, 3.0, 4.0}, float[1] scale = {1.0}, uint8[1] zero_point = {0}>
697+
{
698+
z = QuantizeLinear (x, scale, zero_point)
699+
}
700+
"""
701+
model = ir.from_onnx_text(model_text)
702+
optimized = self._fold(model)
703+
ops = [node.op_type for node in optimized.graph]
704+
# QuantizeLinear should not be folded even when all inputs are constants
705+
self.assertEqual(ops, ["QuantizeLinear"])
706+
707+
def test_dequantize_linear_is_not_folded(self):
708+
model_text = """
709+
<ir_version: 10, opset_import: [ "" : 20]>
710+
agraph () => (float[4] z)
711+
<uint8[4] x = {1, 2, 3, 4}, float[1] scale = {1.0}, uint8[1] zero_point = {0}>
712+
{
713+
z = DequantizeLinear (x, scale, zero_point)
714+
}
715+
"""
716+
model = ir.from_onnx_text(model_text)
717+
optimized = self._fold(model)
718+
ops = [node.op_type for node in optimized.graph]
719+
# DequantizeLinear should not be folded even when all inputs are constants
720+
self.assertEqual(ops, ["DequantizeLinear"])
721+
692722
def test_multi_graph_identity_output_preserves_output_name(self):
693723
model = """
694724
<ir_version: 10, opset_import: ["" : 20]>

0 commit comments

Comments
 (0)