Skip to content

Commit e76a616

Browse files
committed
implement onnx conversion for aten::fake_quantize_per_channel_affine
1 parent b4bec49 commit e76a616

1 file changed

Lines changed: 46 additions & 4 deletions

File tree

  • onnxscript/function_libs/torch_lib/ops

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
COMPLEX128,
2424
DOUBLE,
2525
FLOAT,
26+
FLOAT16,
2627
INT8,
2728
INT16,
2829
INT32,
@@ -3316,17 +3317,58 @@ def aten_eye(n: int) -> TensorType:
33163317
raise NotImplementedError()
33173318

33183319

3320+
@torch_op("aten::fake_quantize_per_channel_affine", trace_only=True)
def aten_fake_quantize_per_channel_affine(
    self: TFloat,
    scale: FLOAT,  # float32 specifically!
    zero_point: Union[INT32, FLOAT, FLOAT16],  # int32, float32 or float16 only!
    axis: int,
    quant_min: int,
    quant_max: int,
) -> TensorType:
    """fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor"""

    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise NotImplementedError(
            "For (quant_min, quant_max), ONNX allows only "
            "(0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
        )

    # quant_min == 0 implies an unsigned 8-bit range; otherwise signed 8-bit.
    if quant_min == 0:
        int_dtype = ir.DataType.UINT8
    else:
        int_dtype = ir.DataType.INT8

    # TODO: When opset >= 19, remove this cast
    # QuantizeLinear below only accepts float32 (or int32) input at this
    # opset, so remember the original dtype and cast back at the end.
    orig_dtype = self.type.dtype
    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
        self = op.Cast(self, to=ir.DataType.FLOAT)

    # ONNX Quantize/DequantizeLinear require an integer zero_point; a float
    # zero_point (allowed by PyTorch) has no ONNX equivalent here.
    if zero_point.type.dtype == ir.DataType.INT32:
        zero_point = op.Cast(zero_point, to=int_dtype)
    else:
        raise NotImplementedError(
            "ONNX only supports integer values for the zero_point parameter. "
            f"Got {zero_point.type.dtype}",
        )

    quantized = op.QuantizeLinear(self, scale, zero_point, axis=axis)

    # See comment above: PyTorch-specific (0, 127) handling. Clamp the
    # uint8 result to 127 to emulate the narrowed activation range.
    if (quant_min, quant_max) == (0, 127):
        const_127 = op.Cast(127, to=int_dtype)
        quantized = op.Clip(quantized, max=const_127)

    output = op.DequantizeLinear(quantized, scale, zero_point, axis=axis)

    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
    if orig_dtype != ir.DataType.FLOAT:
        output = op.Cast(output, to=orig_dtype)

    return output
33303372

33313373

33323374
def aten_fake_quantize_per_channel_affine_cachemask(

0 commit comments

Comments
 (0)