@@ -3392,12 +3392,86 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
33923392 raise NotImplementedError ()
33933393
33943394
@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
def aten_fake_quantize_per_tensor_affine(
    self: TFloat,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
) -> TFloat:
    """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""

    # Scalar-qparams overload: delegate to the shared implementation, which
    # accepts both Python scalars and tensor qparams.
    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
3406+
3407+
@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True)
def aten_fake_quantize_per_tensor_affine_tensor_qparams(
    self: TFloat,
    scale: TReal,
    zero_point: TReal,
    quant_min: int,
    quant_max: int,
) -> TFloat:
    """fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor"""

    # Tensor-qparams overload: delegate to the shared implementation, which
    # accepts both Python scalars and tensor qparams.
    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
3419+
3420+
def _aten_fake_quantize_per_tensor_affine(
    self: TFloat,
    scale: Union[float, TReal],
    zero_point: Union[int, TReal],
    quant_min: int,
    quant_max: int,
) -> TFloat:
    """Shared implementation for the fake_quantize_per_tensor_affine overloads.

    Emulates PyTorch fake quantization by quantizing ``self`` with
    QuantizeLinear and immediately dequantizing with DequantizeLinear.

    Args:
        self: Input tensor to fake-quantize.
        scale: Quantization scale — a Python float (scalar overload) or a
            tensor (tensor_qparams overload).
        zero_point: Quantization zero point — a Python int or a tensor.
        quant_min: Lower bound of the quantized range.
        quant_max: Upper bound of the quantized range; (quant_min, quant_max)
            must be one of (0, 255), (-128, 127) or (0, 127).

    Raises:
        NotImplementedError: If (quant_min, quant_max) is not a range
            expressible with ONNX uint8/int8 quantization.
    """
    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations
    # to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise NotImplementedError(
            "For (quant_min, quant_max), ONNX allows only "
            "(0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
        )

    # A lower bound of 0 implies an unsigned 8-bit range; otherwise signed.
    if quant_min == 0:
        int_dtype = ir.DataType.UINT8
    else:
        int_dtype = ir.DataType.INT8

    # TODO: When opset >= 19, remove this cast
    orig_dtype = self.type.dtype
    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
        self = op.Cast(self, to=ir.DataType.FLOAT)

    # TODO: When opset >= 19, relax the condition for this cast
    if (
        isinstance(scale, float) or
        scale.type.dtype != ir.DataType.FLOAT
    ):
        scale = op.Cast(scale, to=ir.DataType.FLOAT)

    if (
        isinstance(zero_point, int) or
        zero_point.type.dtype != int_dtype
    ):
        zero_point = op.Cast(zero_point, to=int_dtype)

    quantized = op.QuantizeLinear(self, scale, zero_point)

    # See the NOTE above: PyTorch-specific (0, 127) handling — clip the upper
    # bound so quantized values never exceed 127 even in the uint8 case.
    if (quant_min, quant_max) == (0, 127):
        const_127 = op.Cast(127, to=int_dtype)
        quantized = op.Clip(quantized, max=const_127)

    output = op.DequantizeLinear(quantized, scale, zero_point)

    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
    if orig_dtype != ir.DataType.FLOAT:
        output = op.Cast(output, to=orig_dtype)

    return output
34013475
34023476
34033477def aten_fake_quantize_per_tensor_affine_cachemask (
0 commit comments