diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 326075b2fe..0f9ee7366c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -23,6 +23,7 @@ COMPLEX128, DOUBLE, FLOAT, + FLOAT16, INT8, INT16, INT32, @@ -3317,17 +3318,58 @@ def aten_eye(n: int) -> TensorType: raise NotImplementedError() +@torch_op("aten::fake_quantize_per_channel_affine", trace_only=True) def aten_fake_quantize_per_channel_affine( - self: TensorType, - scale: TensorType, - zero_point: TensorType, + self: TFloat, + scale: FLOAT, # float32 specifically! + zero_point: Union[INT32, FLOAT, FLOAT16], # int32, float32 or float16 only! axis: int, quant_min: int, quant_max: int, ) -> TensorType: """fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor""" - raise NotImplementedError() + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise NotImplementedError( + "For (quant_min, quant_max), ONNX allows only " + "(0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + ) + + if quant_min == 0: + int_dtype = ir.DataType.UINT8 + else: + int_dtype = ir.DataType.INT8 + + # TODO: When opset >= 19, remove this cast + orig_dtype = self.type.dtype + if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}: + self = op.Cast(self, to=ir.DataType.FLOAT) + + if zero_point.type.dtype == ir.DataType.INT32: + zero_point = op.Cast(zero_point, to=int_dtype) + else: + raise NotImplementedError( + "ONNX only supports integer values for the zero_point parameter. 
" + f"Got {zero_point.type.dtype}", + ) + + quantized = op.QuantizeLinear(self, scale, zero_point, axis=axis) + + # See NOTE above: PyTorch-specific (0, 127) handling + if (quant_min, quant_max) == (0, 127): + const_127 = op.Cast(127, to=int_dtype) + quantized = op.Clip(quantized, max=const_127) + + output = op.DequantizeLinear(quantized, scale, zero_point, axis=axis) + + # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear + if orig_dtype != ir.DataType.FLOAT: + output = op.Cast(output, to=orig_dtype) + + return output def aten_fake_quantize_per_channel_affine_cachemask( @@ -3351,12 +3393,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward( raise NotImplementedError() +@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True) def aten_fake_quantize_per_tensor_affine( - self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int -) -> TensorType: + self: TFloat, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, +) -> TFloat: """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor""" - raise NotImplementedError() + return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max) + + +@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True) +def aten_fake_quantize_per_tensor_affine_tensor_qparams( + self: TFloat, + scale: TReal, + zero_point: TReal, + quant_min: int, + quant_max: int, +) -> TFloat: + """fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor""" + + return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max) + + +def _aten_fake_quantize_per_tensor_affine( + self: TFloat, + scale: Union[float, TReal], + zero_point: Union[int, TReal], + quant_min: int, + quant_max: int, +) -> TFloat: + # NOTE: (0, 127) is allowed as special case. 
PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise NotImplementedError( + "For (quant_min, quant_max), ONNX allows only " + "(0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + ) + + if quant_min == 0: + int_dtype = ir.DataType.UINT8 + else: + int_dtype = ir.DataType.INT8 + + # TODO: When opset >= 19, remove this cast + orig_dtype = self.type.dtype + if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}: + self = op.Cast(self, to=ir.DataType.FLOAT) + + # TODO: When opset >= 19, relax the condition for this cast + if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT: + scale = op.Cast(scale, to=ir.DataType.FLOAT) + + if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype: + zero_point = op.Cast(zero_point, to=int_dtype) + + quantized = op.QuantizeLinear(self, scale, zero_point) + + # See NOTE above: PyTorch-specific (0, 127) handling + if (quant_min, quant_max) == (0, 127): + const_127 = op.Cast(127, to=int_dtype) + quantized = op.Clip(quantized, max=const_127) + + output = op.DequantizeLinear(quantized, scale, zero_point) + + # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear + if orig_dtype != ir.DataType.FLOAT: + output = op.Cast(output, to=orig_dtype) + + return output def aten_fake_quantize_per_tensor_affine_cachemask( diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 5d7deb1695..2ce015b363 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -779,6 +779,109 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): ) +def sample_inputs_fake_quantize_per_tensor_affine( + op_info, device, dtype, requires_grad, 
**kwargs +): + del op_info, kwargs # Unused + make_arg = functools.partial( + opinfo_core.make_tensor, + device=device, + requires_grad=requires_grad, + ) + + # Test 1D, empty and scalar tensors (like sample_inputs_elementwise_unary) + shapes = [ + (S,), + (1, 0, 3), + (), + ] + + scale_zero_point_dtypes = [ + # default (float, int) + (None, None) + ] + [ + # tensor_qparams (tensor, tensor) + (t1, t2) + for t1 in common_dtype.all_types_and() + for t2 in common_dtype.all_types_and() + ] + + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + quant_vals = [(0, 255), (-128, 127), (0, 127)] + + cases = itertools.product(shapes, scale_zero_point_dtypes, quant_vals) + for shape, (scale_dtype, zero_point_dtype), (quant_min, quant_max) in cases: + scale = make_arg( + (), + dtype=scale_dtype or torch.float64, + ) + if scale_dtype is None: + scale = scale.item() + + zero_point = make_arg( + (), + dtype=zero_point_dtype or torch.int64, + # zero_point must be between quant_min and quant_max + low=quant_min, + high=quant_max, + ) + if zero_point_dtype is None: + zero_point = zero_point.item() + + args = (scale, zero_point, quant_min, quant_max) + yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args) + + +def sample_inputs_fake_quantize_per_channel_affine( + op_info, device, dtype, requires_grad, **kwargs +): + del op_info, kwargs # Unused + make_arg = functools.partial( + opinfo_core.make_tensor, + device=device, + requires_grad=requires_grad, + ) + + # Test 1D, 2D, 4D and empty tensors (scalar tensors not supported) + axes_and_shapes = [ + # 1D, 2D, 4D + (axis, (S,) * dims) + for dims in (1, 2, 4) + for axis in range(dims) + ] + [ + # empty + (0, (1, 0, 3)), + (2, (1, 0, 3)), + # empty channel axis causes an error due to + # an internal zero_point.min() calculation + # (1, (1, 0, 3)), 
+ ] + + # tensor_qparams + scale_dtype = torch.float + zero_point_dtypes = [torch.int32, torch.float, torch.half] + + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + quant_vals = [(0, 255), (-128, 127), (0, 127)] + + cases = itertools.product(axes_and_shapes, zero_point_dtypes, quant_vals) + for (axis, shape), zero_point_dtype, (quant_min, quant_max) in cases: + scale = make_arg((shape[axis],), dtype=scale_dtype) + + zero_point = make_arg( + (shape[axis],), + dtype=zero_point_dtype or torch.int64, + # zero_point must be between quant_min and quant_max + low=quant_min, + high=quant_max, + ) + + args = (scale, zero_point, axis, quant_min, quant_max) + yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args) + + def _index_variable_bool(shape, max_indices, device): if not isinstance(shape, tuple): shape = (shape,) @@ -2408,6 +2511,22 @@ def __init__(self): sample_inputs_func=sample_inputs__fft_r2c, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.fake_quantize_per_tensor_affine", + aten_name="fake_quantize_per_tensor_affine", + op=torch.fake_quantize_per_tensor_affine, + dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fake_quantize_per_tensor_affine, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.fake_quantize_per_channel_affine", + aten_name="fake_quantize_per_channel_affine", + op=torch.fake_quantize_per_channel_affine, + dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fake_quantize_per_channel_affine, + supports_out=False, + ), opinfo_core.BinaryUfuncInfo( "ops.aten.floor_divide", aten_name="floor_divide", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4ef7550b6e..e87a0cc232 
100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -698,6 +698,17 @@ def _where_input_wrangler( TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail( reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223" ), + TorchLibOpInfo( + "ops.aten.fake_quantize_per_channel_affine", + core_ops.aten_fake_quantize_per_channel_affine, + ).xfail( + reason="fixme: ONNX (De)QuantizeLinear only supports integer zero_point values", + matcher=lambda sample: sample.args[1].dtype != torch.int32, + ), + TorchLibOpInfo( + "ops.aten.fake_quantize_per_tensor_affine", + core_ops.aten_fake_quantize_per_tensor_affine, + ), TorchLibOpInfo("fill", core_ops.aten_fill), TorchLibOpInfo("flip", core_ops.aten_flip).skip( reason="fixme: size 0 inputs are not handled yet",