|
23 | 23 | COMPLEX128, |
24 | 24 | DOUBLE, |
25 | 25 | FLOAT, |
| 26 | + FLOAT16, |
26 | 27 | INT8, |
27 | 28 | INT16, |
28 | 29 | INT32, |
@@ -3317,17 +3318,58 @@ def aten_eye(n: int) -> TensorType: |
3317 | 3318 | raise NotImplementedError() |
3318 | 3319 |
|
3319 | 3320 |
|
@torch_op("aten::fake_quantize_per_channel_affine", trace_only=True)
def aten_fake_quantize_per_channel_affine(
    self: TFloat,
    scale: FLOAT,  # float32 specifically!
    zero_point: Union[INT32, FLOAT, FLOAT16],  # int32, float32 or float16 only!
    axis: int,
    quant_min: int,
    quant_max: int,
) -> TensorType:
    """fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor"""

    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise NotImplementedError(
            "For (quant_min, quant_max), ONNX allows only "
            "(0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
        )

    # quant_min == 0 implies an unsigned 8-bit range; otherwise use signed int8.
    if quant_min == 0:
        int_dtype = ir.DataType.UINT8
    else:
        int_dtype = ir.DataType.INT8

    # TODO: When opset >= 19, remove this cast
    # QuantizeLinear below only accepts float32/int32 inputs at this opset.
    orig_dtype = self.type.dtype
    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
        self = op.Cast(self, to=ir.DataType.FLOAT)

    if zero_point.type.dtype == ir.DataType.INT32:
        zero_point = op.Cast(zero_point, to=int_dtype)
    else:
        raise NotImplementedError(
            "ONNX only supports integer values for the zero_point parameter. "
            f"Got {zero_point.type.dtype}",
        )

    quantized = op.QuantizeLinear(self, scale, zero_point, axis=axis)

    # See NOTE above: PyTorch-specific (0, 127) range is emulated by clipping
    # the quantized values to an upper bound of 127.
    if (quant_min, quant_max) == (0, 127):
        const_127 = op.Cast(127, to=int_dtype)
        quantized = op.Clip(quantized, max=const_127)

    output = op.DequantizeLinear(quantized, scale, zero_point, axis=axis)

    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
    if orig_dtype != ir.DataType.FLOAT:
        output = op.Cast(output, to=orig_dtype)

    return output
3331 | 3373 |
|
3332 | 3374 |
|
3333 | 3375 | def aten_fake_quantize_per_channel_affine_cachemask( |
|
0 commit comments