Commit b8df1aa

use ipex op in backward
1 parent: 012c660

1 file changed: bitsandbytes/autograd/_functions.py (9 additions, 3 deletions)

@@ -517,7 +517,10 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]

         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        if getattr(quant_state, "ipex", False):
+            output = F.gemv_4bit(A, B, out, state=quant_state)
+        else:
+            output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

         # 3. Save state
         ctx.state = quant_state
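
For context, here is a minimal standalone sketch (mine, not part of the commit) of the else-branch above, i.e. the dequantize-then-matmul reference path that the fused IPEX call replaces. It assumes a CUDA build of bitsandbytes; the `.t()` in the diff compensates for the packed weight arriving pre-transposed from Linear4bit, which the sketch sidesteps by quantizing the weight in its natural (out_features, in_features) layout:

    import torch
    import bitsandbytes.functional as F

    # Quantize a toy weight to 4 bits: returns the packed uint8 tensor plus a
    # quant_state carrying absmax, blocksize, dtype and shape.
    W = torch.randn(32, 64, dtype=torch.float16, device="cuda")  # (out, in)
    B, quant_state = F.quantize_4bit(W)

    A = torch.randn(1, 64, dtype=torch.float16, device="cuda")

    # The else-branch: dequantize to A's dtype, then a plain dense linear.
    W_deq = F.dequantize_4bit(B, quant_state).to(A.dtype)  # back to (32, 64)
    out_ref = torch.nn.functional.linear(A, W_deq)         # (1, 32)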
@@ -548,7 +551,10 @@ def backward(ctx, grad_output):
         # not supported by PyTorch. TODO: create work-around
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
-            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
+            if getattr(ctx.state, "ipex", False):
+                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state)
+            else:
+                grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

         return grad_A, grad_B, None, grad_bias, None
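
The math the new branch reproduces: forward computes output = A · Wᵀ with W the dequantized weight, so grad_A = grad_output · W, which is again a product against the same 4-bit weight and can therefore reuse the fused kernel instead of materializing W. Below is a hedged end-to-end check (mine; the shapes, tolerances, and the Linear4bit convention of passing the packed weight transposed are my assumptions) through the public matmul_4bit entry point:

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    W = torch.randn(32, 64, dtype=torch.float16, device="cuda")
    B, qs = F.quantize_4bit(W)
    A = torch.randn(4, 64, dtype=torch.float16, device="cuda", requires_grad=True)

    # Batched, grad-requiring input -> routed to MatMul4Bit, exercising backward.
    out = bnb.matmul_4bit(A, B.t(), quant_state=qs)
    out.sum().backward()

    # d(sum)/dA = 1 @ W_deq; compare against the dequantized weight directly.
    W_deq = F.dequantize_4bit(B, qs).to(A.dtype)
    grad_ref = torch.ones_like(out) @ W_deq
    torch.testing.assert_close(A.grad, grad_ref, rtol=1e-2, atol=1e-2)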

@@ -575,7 +581,7 @@ def matmul_4bit(
     bias=None,
 ):
     assert quant_state is not None
-    if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
+    if A.numel() == A.shape[-1] and A.device.type != "cpu" and A.requires_grad == False:
         # CPU backend does not require A to be a vector
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
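
My reading of the condition change: previously any no-grad CPU input took the gemv fast path regardless of shape, while now the fast path is reserved for non-CPU, single-row, no-grad inputs, so CPU tensors always reach MatMul4Bit, whose forward and backward (above) dispatch to the IPEX op when quant_state.ipex is set. As a standalone predicate (function name is mine):

    import torch

    def takes_gemv_fast_path(A: torch.Tensor) -> bool:
        # "Effectively a vector": every dim but the last has size 1.
        is_single_row = A.numel() == A.shape[-1]
        return is_single_row and A.device.type != "cpu" and not A.requires_grad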
