@@ -3392,12 +3392,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
33923392 raise NotImplementedError ()
33933393
33943394
@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
def aten_fake_quantize_per_tensor_affine(
    self: TFloat,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
) -> TFloat:
    """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""

    # Scalar-qparams overload: delegate to the shared quantize/dequantize
    # lowering used by the tensor_qparams variant as well.
    return _aten_fake_quantize_per_tensor_affine(
        self, scale, zero_point, quant_min, quant_max
    )
3406+
3407+
@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True)
def aten_fake_quantize_per_tensor_affine_tensor_qparams(
    self: TFloat,
    scale: TReal,
    zero_point: TReal,
    quant_min: int,
    quant_max: int,
) -> TFloat:
    """fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor"""

    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
3419+
3420+
def _aten_fake_quantize_per_tensor_affine(
    self: TFloat,
    scale: Union[float, TReal],
    zero_point: Union[int, TReal],
    quant_min: int,
    quant_max: int,
) -> TFloat:
    """Shared lowering for both fake_quantize_per_tensor_affine overloads.

    Emulates fake quantization as QuantizeLinear followed by
    DequantizeLinear, inserting casts to satisfy the opset's type
    constraints on the inputs and restoring the original dtype at the end.

    Raises:
        NotImplementedError: If (quant_min, quant_max) is not one of the
            ranges representable with ONNX (u)int8 quantization.
    """
    # NOTE: (0, 127) is allowed as a special case: PyTorch restricts
    # activations to the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    supported_ranges = [(0, 255), (-128, 127), (0, 127)]
    if (quant_min, quant_max) not in supported_ranges:
        raise NotImplementedError(
            "For (quant_min, quant_max), ONNX allows only "
            "(0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
        )

    # quant_min == 0 implies an unsigned 8-bit quantized type.
    int_dtype = ir.DataType.UINT8 if quant_min == 0 else ir.DataType.INT8

    # TODO: When opset >= 19, remove this cast
    orig_dtype = self.type.dtype
    if orig_dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
        self = op.Cast(self, to=ir.DataType.FLOAT)

    # TODO: When opset >= 19, relax the condition for this cast
    if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT:
        scale = op.Cast(scale, to=ir.DataType.FLOAT)

    if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype:
        zero_point = op.Cast(zero_point, to=int_dtype)

    quantized = op.QuantizeLinear(self, scale, zero_point)

    # See the NOTE above: clip to implement the PyTorch-specific (0, 127) range.
    if (quant_min, quant_max) == (0, 127):
        upper_bound = op.Cast(127, to=int_dtype)
        quantized = op.Clip(quantized, max=upper_bound)

    output = op.DequantizeLinear(quantized, scale, zero_point)

    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
    if orig_dtype != ir.DataType.FLOAT:
        output = op.Cast(output, to=orig_dtype)

    return output
34013468
34023469
34033470def aten_fake_quantize_per_tensor_affine_cachemask (
0 commit comments