@@ -517,7 +517,10 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]

         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        if getattr(quant_state, "ipex", False):
+            output = F.gemv_4bit(A, B, out, state=quant_state)
+        else:
+            output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

         # 3. Save state
         ctx.state = quant_state
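
For context (annotation, not part of the commit): the `getattr(quant_state, "ipex", False)` check keeps older states backward compatible, since a QuantState produced by plain quantize_4bit carries no ipex attribute and therefore stays on the existing dequantize-then-linear branch. A minimal sketch, assuming a CUDA device:

import torch
import bitsandbytes.functional as F

# A fresh QuantState has no `ipex` attribute, so the new branch above
# evaluates getattr(state, "ipex", False) to False and the original
# dequantize + linear path runs unchanged.
_, state = F.quantize_4bit(torch.randn(64, 64, dtype=torch.float16, device="cuda"))
print(getattr(state, "ipex", False))  # False
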
@@ -548,7 +551,10 @@ def backward(ctx, grad_output):
         # not supported by PyTorch. TODO: create work-around
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
-            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
+            if getattr(ctx.state, "ipex", False):
+                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state)
+            else:
+                grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

         return grad_A, grad_B, None, grad_bias, None

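The backward change mirrors the forward one: grad_A goes through the fused gemv_4bit kernel when the state is flagged for IPEX and through explicit dequantization otherwise. A sketch of driving this path end to end via autograd (shapes and device are illustrative assumptions; matmul_4bit is the public entry point shown in the next hunk):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

W = torch.randn(256, 128, dtype=torch.float16, device="cuda")  # (out, in)
packed, state = F.quantize_4bit(W)

# A batched input with requires_grad=True forces the MatMul4Bit autograd
# path, so backward() exercises the grad_A computation patched above.
x = torch.randn(4, 128, dtype=torch.float16, device="cuda", requires_grad=True)
y = bnb.matmul_4bit(x, packed.t(), quant_state=state)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 128])
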
@@ -575,7 +581,7 @@ def matmul_4bit(
     bias=None,
 ):
     assert quant_state is not None
-    if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
+    if A.numel() == A.shape[-1] and A.device.type != "cpu" and A.requires_grad == False:
         # CPU backend does not require A to be a vector
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
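
The revised condition removes CPU from the inference fast path: a single-row A with requires_grad == False now takes the direct gemv_4bit route only on non-CPU devices, while CPU inputs fall through to the autograd matmul, whose forward now handles the ipex dispatch itself. A usage sketch of the fast path (shapes and device are assumptions):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

W = torch.randn(256, 128, dtype=torch.float16, device="cuda")  # (out, in)
packed, state = F.quantize_4bit(W)

# Single token, no grad, non-CPU device: A.numel() == A.shape[-1] holds,
# so the revised condition is True and the fused 4-bit GEMV runs directly.
x = torch.randn(1, 128, dtype=torch.float16, device="cuda")
y = bnb.matmul_4bit(x, packed.t(), quant_state=state)
print(y.shape)  # torch.Size([1, 256])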