
     ipex_cpu = ipex if ipex._C._has_cpu() else None
     ipex_xpu = ipex if ipex._C._has_xpu() else None
+    ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
 except BaseException:
     ipex_cpu = None
     ipex_xpu = None
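
Note: the failure branch is cut off here and does not show the new flag being reset. A minimal self-contained sketch of the probe pattern, with an assumed False fallback for ipex_cpu_only:

try:
    import intel_extension_for_pytorch as ipex

    # Expose the module only for the device backends it was built with.
    ipex_cpu = ipex if ipex._C._has_cpu() else None
    ipex_xpu = ipex if ipex._C._has_xpu() else None
    # True only for a CPU-only build (no XPU support compiled in).
    ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
except BaseException:
    ipex_cpu = None
    ipex_xpu = None
    ipex_cpu_only = False  # assumed fallback; not visible in this hunk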
@@ -55,7 +56,7 @@ def _ipex_xpu_version_prereq(major, minor):

 def _maybe_torch_compile(func):
     # torch.compile requires g++ and pytorch >= 2.0
-    if gxx_available and _torch_version_prereq(2, 0):
+    if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu:
         options = {}
         # fx_graph_cache requires pytorch >= 2.2
         if _torch_version_prereq(2, 2):
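
Note: the hunk adds "and not ipex_xpu" so the XPU build never routes through torch.compile. A hedged sketch of the decorator as it reads after the change; the gxx_available check is omitted, the version parser below is an assumption (the real _torch_version_prereq lives elsewhere in this file), and the dynamic=True flag is assumed:

import torch

def _torch_version_prereq(major, minor):
    # Assumed helper: compare the running torch version to (major, minor).
    parts = torch.__version__.split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)

def _maybe_torch_compile(func):
    # Compile only when the prerequisites hold; otherwise return the
    # function unchanged so call sites need no special-casing.
    if _torch_version_prereq(2, 0):
        options = {}
        if _torch_version_prereq(2, 2):
            options["fx_graph_cache"] = True  # persistent FX graph cache
        return torch.compile(func, dynamic=True, options=options)
    return func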
@@ -181,7 +182,7 @@ def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32)
     A_reshaped = A.reshape(m, k)

     # torch._int_mm is available on CPU since torch 2.4
-    if _torch_version_prereq(2, 4):
+    if _torch_version_prereq(2, 4) and A.device.type == "cpu":
         C = torch._int_mm(A_reshaped, B.T).to(dtype)
     else:
         C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
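
Note: the added device check routes only CPU tensors to the native kernel, since torch._int_mm gained CPU support in torch 2.4. A standalone sketch of the dispatch (helper names are hypothetical; A is (m, k) int8, B is (n, k) int8):

import torch

def _torch_ge(major, minor):
    v = torch.__version__.split(".")
    return (int(v[0]), int(v[1])) >= (major, minor)

def int8_mm(A, B, dtype=torch.int32):
    if _torch_ge(2, 4) and A.device.type == "cpu":
        # Native int8 x int8 GEMM with an int32 accumulator.
        return torch._int_mm(A, B.T).to(dtype)
    # Portable fallback: emulate in float32, then cast the result.
    return torch.matmul(A.float(), B.t().float()).to(dtype)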
@@ -233,8 +234,10 @@ def mm_dequant_impl(
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])

     if compute_dtype not in [torch.float32, torch.bfloat16]:
-        warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
-        compute_dtype = torch.float32
+        warnings.warn(
+            f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead"
+        )
+        compute_dtype = torch.bfloat16
     A_reshaped = A.reshape(out_shape).to(compute_dtype)
     row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
     col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
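
Note: this function rescales the int32 accumulator from the int8 GEMM by the per-row and per-column absmax statistics. A rough sketch of just that rescale, assuming the standard absmax/127 scaling on both operands (bias handling and the function's remaining arguments omitted; the 127**2 factor is an assumption from the LLM.int8() scheme, not shown in this hunk):

import torch

def mm_dequant_sketch(C_i32, row_stats, col_stats, compute_dtype=torch.bfloat16):
    A = C_i32.to(compute_dtype)
    row = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)  # (m, 1)
    col = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)   # (1, n)
    # Each operand was scaled by absmax/127 at quantization time, so the
    # combined dequantization factor is row * col / 127**2.
    return A * row * col * (1.0 / (127 * 127))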
@@ -342,7 +345,7 @@ def quantize_4bit_impl(
         scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
         scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
     # map [-1, 1] to nf4/fp4
-    out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
+    out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
     if quant_type == "nf4":
         for i in range(len(NF4_QUANT_TABLE)):
             out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
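
Note: the loop maps each scaled value to a 4-bit code index by scanning the quant table in ascending order; the final index is the largest table entry the value still exceeds. A standalone sketch with a hypothetical boundary table:

import torch

def map_to_codes(scaled, lower_bounds):
    # lower_bounds[i] is the lower edge of code i's bin, in ascending
    # order, with lower_bounds[0] below the whole data range so every
    # element receives an index.
    out = torch.zeros(scaled.shape, dtype=torch.uint8, device=scaled.device)
    for i, b in enumerate(lower_bounds):
        out[scaled > b] = i
    return out

# e.g. map_to_codes(torch.tensor([-0.9, 0.0, 0.8]), [-2.0, -0.5, 0.5])
# -> tensor([0, 1, 2], dtype=torch.uint8)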
@@ -408,7 +411,6 @@ def dequantize_4bit_impl(
     torch.Tensor:
         Dequantized tensor.
     """
-
     if A.shape[0] == 1:
         transpose = False
         A = A.squeeze(0)
@@ -438,23 +440,18 @@ def dequantize_4bit_impl(
     if quant_state.nested:
         raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")

-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
-        assert quant_state.op_context is not None
-        A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
-        A = A.reshape(-1)
-        absmax = quant_state.op_context.get_scales().reshape(-1)
-
-    if out is None:
-        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
+    if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
+        A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
+        quant_state.ipex = False

-    n = out.numel()
     # Map nf4 to [-1, 1]
-    out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
-    out_uint8[::2] = A.bitwise_and(0xF)
-    out_uint8[1::2] = A.bitwise_right_shift(4)
-    out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
-    for i in range(len(quant_state.code)):
-        out_dq[out_uint8 == i] = quant_state.code[i]
+    out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
+    n = out_dq.numel()
+    out_dq[::2] = A & 0xF
+    out_dq[1::2] = A >> 4
+    # quant_state.code is fp32; cast it to quant_state.dtype to avoid a dtype mismatch
+    quant_state.code = quant_state.code.to(quant_state.dtype)
+    out_dq = quant_state.code[out_dq]

     # Apply scales
     if out_dq.numel() != n:
@@ -464,12 +461,17 @@ def dequantize_4bit_impl(
     blocks += 1 if n % blocksize > 0 else 0
     rem = n % blocksize
     has_rem = rem > 0
-    out_reshaped = out.reshape(-1)
-    out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
-        -1
-    )
+
     if has_rem:
+        if out is None:
+            out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
+        out_reshaped = out.reshape(-1)
+        out_reshaped[: n - rem] = (
+            out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)
+        ).reshape(-1)
         out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
+    else:
+        out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)

     # take transpose here because weight is transposed (again) for computation
     if transpose:
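
Note: after these hunks the non-IPEX dequantization path is: split each packed byte into two 4-bit indices (low nibble first), look them up in the 16-entry code table, then scale each blocksize-long run by its absmax. A standalone sketch of the exact-multiple case (the real function also handles a trailing partial block, the transpose, and the preallocated out tensor):

import torch

def dequant_nf4_sketch(packed, absmax, code, blocksize, shape, dtype):
    # packed: 1-D uint8 tensor holding two 4-bit code indices per byte.
    idx = torch.empty(packed.numel() * 2, dtype=torch.long, device=packed.device)
    idx[::2] = (packed & 0xF).long()   # low nibble
    idx[1::2] = (packed >> 4).long()   # high nibble
    out_dq = code.to(dtype)[idx]       # 16-entry table -> values in [-1, 1]
    # Per-block rescale; assumes idx.numel() % blocksize == 0.
    return (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(shape)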
@@ -510,9 +512,21 @@ def gemm_4bit_impl(
     torch.Tensor:
         GEMM output tensor.
     """
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
-        assert state.op_context is not None
-        output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
+    if getattr(state, "ipex", False):
+        output = torch.ops.torch_ipex.woq_linear(
+            A,
+            B,
+            "nf4",
+            state.shape,
+            state.new_scales,
+            state.new_zeros,
+            None,
+            None,
+            state.blocksize,
+            ipex_cpu.quantization.WoqLowpMode.BF16,
+            1,
+            state.compensation,
+        )
     else:
         dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
         output = torch.matmul(A, dqB.to(A.dtype))
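
Note: the dispatch now uses the fused IPEX weight-only-quantization kernel when the state was prepacked (state.new_scales, state.new_zeros and state.compensation come from prepacking), and otherwise dequantizes and falls back to a dense matmul. A condensed restatement of exactly what the hunk adds, relying on this file's dequantize_4bit_impl and ipex_cpu, not an independent implementation:

import torch

def gemm_4bit(A, B, state):
    if getattr(state, "ipex", False):
        # Fused path: B is an IPEX-prepacked nf4 weight.
        return torch.ops.torch_ipex.woq_linear(
            A, B, "nf4", state.shape,
            state.new_scales, state.new_zeros, None, None,
            state.blocksize, ipex_cpu.quantization.WoqLowpMode.BF16,
            1, state.compensation,
        )
    # Portable fallback: materialize the weight, then a dense matmul.
    dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
    return torch.matmul(A, dqB.to(A.dtype))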