Skip to content

Commit 160320e

Browse files
committed
implement onnx conversion for aten::fake_quantize_per_tensor_affine
1 parent 303412f commit 160320e

1 file changed

Lines changed: 70 additions & 3 deletions

File tree

  • onnxscript/function_libs/torch_lib/ops

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3392,12 +3392,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
33923392
raise NotImplementedError()
33933393

33943394

3395+
@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
def aten_fake_quantize_per_tensor_affine(
    self: TFloat,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
) -> TFloat:
    """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""

    # Scalar-qparams overload: delegate to the shared implementation, which
    # accepts both Python scalars and tensor-valued scale/zero_point.
    return _aten_fake_quantize_per_tensor_affine(
        self,
        scale=scale,
        zero_point=zero_point,
        quant_min=quant_min,
        quant_max=quant_max,
    )
3406+
3407+
3408+
@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True)
def aten_fake_quantize_per_tensor_affine_tensor_qparams(
    self: TFloat,
    scale: TReal,
    zero_point: TReal,
    quant_min: int,
    quant_max: int,
) -> TFloat:
    """fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor"""

    # Tensor-qparams overload: scale and zero_point arrive as tensors; the
    # shared helper casts them to the dtypes QuantizeLinear expects.
    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
3419+
3420+
3421+
def _aten_fake_quantize_per_tensor_affine(
    self: TFloat,
    scale: Union[float, TReal],
    zero_point: Union[int, TReal],
    quant_min: int,
    quant_max: int,
) -> TFloat:
    """Shared implementation for both fake_quantize_per_tensor_affine overloads.

    Emulates fake quantization as ONNX QuantizeLinear followed by
    DequantizeLinear, inserting casts so inputs match the dtypes those ops
    accept at the current opset.

    Args:
        self: Input tensor to fake-quantize.
        scale: Quantization scale, as a Python float or a tensor.
        zero_point: Quantization zero point, as a Python int or a tensor.
        quant_min: Lower bound of the quantized integer range.
        quant_max: Upper bound of the quantized integer range.

    Raises:
        NotImplementedError: If (quant_min, quant_max) does not correspond to
            a range representable with ONNX uint8/int8 quantization.
    """
    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations
    # to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in {(0, 255), (-128, 127), (0, 127)}:
        raise NotImplementedError(
            "For (quant_min, quant_max), ONNX allows only "
            "(0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
        )

    # quant_min == 0 implies an unsigned quantized range; otherwise int8.
    if quant_min == 0:
        int_dtype = ir.DataType.UINT8
    else:
        int_dtype = ir.DataType.INT8

    # TODO: When opset >= 19, remove this cast
    orig_dtype = self.type.dtype
    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
        self = op.Cast(self, to=ir.DataType.FLOAT)

    # TODO: When opset >= 19, relax the condition for this cast
    if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT:
        scale = op.Cast(scale, to=ir.DataType.FLOAT)

    if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype:
        zero_point = op.Cast(zero_point, to=int_dtype)

    quantized = op.QuantizeLinear(self, scale, zero_point)

    # See NOTE above: PyTorch-specific (0, 127) range — clip the upper bound
    # since uint8 quantization alone would allow values up to 255.
    if (quant_min, quant_max) == (0, 127):
        const_127 = op.Cast(127, to=int_dtype)
        quantized = op.Clip(quantized, max=const_127)

    output = op.DequantizeLinear(quantized, scale, zero_point)

    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
    if orig_dtype != ir.DataType.FLOAT:
        output = op.Cast(output, to=orig_dtype)

    return output
34013468

34023469

34033470
def aten_fake_quantize_per_tensor_affine_cachemask(

0 commit comments

Comments
 (0)