 
     ipex_cpu = ipex if ipex._C._has_cpu() else None
     ipex_xpu = ipex if ipex._C._has_xpu() else None
+    ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
 except BaseException:
     ipex_cpu = None
     ipex_xpu = None
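For context, these flags follow the usual optional-import pattern: probe IPEX's capability checks inside a `try`, and null the backends out on any failure. One thing worth noting: the patch assigns `ipex_cpu_only` only inside the `try` block, so a failed import would leave it undefined. A minimal standalone sketch (otherwise identical to the hunk above) that also initializes it on the failure path:

```python
# Sketch of the optional-import guard above; ipex._C._has_cpu() and
# ipex._C._has_xpu() are IPEX's internal capability probes.
try:
    import intel_extension_for_pytorch as ipex

    ipex_cpu = ipex if ipex._C._has_cpu() else None
    ipex_xpu = ipex if ipex._C._has_xpu() else None
    ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
except BaseException:
    # Any failure (missing package, broken install) disables the IPEX paths.
    ipex_cpu = None
    ipex_xpu = None
    ipex_cpu_only = False  # not set by the patch's except branch

print("IPEX CPU:", ipex_cpu is not None, "XPU:", ipex_xpu is not None)
```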
@@ -56,7 +57,7 @@ def _ipex_xpu_version_prereq(major, minor):
 
 def _maybe_torch_compile(func):
     # torch.compile requires g++ and pytorch >= 2.0
-    if gxx_available and _torch_version_prereq(2, 0) and os.getenv('PT_HPU_LAZY_MODE', 1) == 0:
+    if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu and os.getenv('PT_HPU_LAZY_MODE', 1) == 0:
         options = {}
         # fx_graph_cache requires pytorch >= 2.2
         if _torch_version_prereq(2, 2):
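The decorator above compiles only when the toolchain supports it and the IPEX XPU path is not active; otherwise it returns the function untouched. A standalone sketch of the same fall-through pattern (the g++ check is elided and version parsing is simplified; `_maybe_compile` and `scale` are illustrative names):

```python
import os
import torch

def _maybe_compile(func):
    """Return torch.compile(func) when supported, else func unchanged."""
    version = tuple(int(p) for p in torch.__version__.split(".")[:2])
    # Note: os.getenv returns a *string* when the variable is set, so the
    # patched line's `os.getenv('PT_HPU_LAZY_MODE', 1) == 0` compares an int
    # or str to 0 and is never true; this sketch compares strings instead.
    hpu_lazy_off = os.getenv("PT_HPU_LAZY_MODE", "1") == "0"
    if version >= (2, 0) and hpu_lazy_off:
        options = {}
        if version >= (2, 2):
            options["fx_graph_cache"] = True  # cache Inductor FX graphs
        return torch.compile(func, dynamic=True, options=options)
    return func

@_maybe_compile
def scale(x):
    return x * 2

print(scale(torch.ones(3)))
```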
@@ -182,7 +183,7 @@ def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32)
     A_reshaped = A.reshape(m, k)
 
     # torch._int_mm is available on CPU since torch 2.4
-    if _torch_version_prereq(2, 4):
+    if _torch_version_prereq(2, 4) and A.device.type == "cpu":
         C = torch._int_mm(A_reshaped, B.T).to(dtype)
     else:
         C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
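The guard above routes the int8 GEMM through `torch._int_mm` only on CPU with torch >= 2.4, falling back to a float matmul cast back to int32 elsewhere. A quick equivalence check of the two paths (requires torch >= 2.4 on CPU; the fp32 fallback is exact here because the accumulated products stay well below 2^24):

```python
import torch

A = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
B = torch.randint(-128, 127, (16, 64), dtype=torch.int8)  # (n, k), multiplied as B.T

fast = torch._int_mm(A, B.T)                                   # int32 accumulator
slow = torch.matmul(A.float(), B.t().float()).to(torch.int32)  # portable fallback

assert torch.equal(fast, slow)
```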
@@ -234,8 +235,10 @@ def mm_dequant_impl(
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
 
     if compute_dtype not in [torch.float32, torch.bfloat16]:
-        warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
-        compute_dtype = torch.float32
+        warnings.warn(
+            f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead"
+        )
+        compute_dtype = torch.bfloat16
     A_reshaped = A.reshape(out_shape).to(compute_dtype)
     row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
     col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
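For reference, the dequantization these lines set up rescales the int32 accumulator by per-row and per-column absmax statistics. A hedged numeric sketch of that math, assuming symmetric int8 quantization with scale `absmax / 127` on both operands (the `1 / 127**2` factor stands in for the compensation constant the implementation applies; all names here are illustrative):

```python
import torch

x = torch.randn(4, 16)
w = torch.randn(8, 16)

row_stats = x.abs().amax(dim=1, keepdim=True)  # (4, 1) per-row absmax of x
col_stats = w.abs().amax(dim=1).unsqueeze(0)   # (1, 8) per-output-column absmax

x_q = torch.round(x * 127 / row_stats).to(torch.int8)
w_q = torch.round(w * 127 / col_stats.T).to(torch.int8)

c_i32 = torch.matmul(x_q.to(torch.int32), w_q.to(torch.int32).T)  # int8 GEMM result
out = c_i32.to(torch.bfloat16) * row_stats.to(torch.bfloat16) * col_stats.to(torch.bfloat16) / (127 * 127)

print((out.float() - x @ w.T).abs().max())  # small quantization error
```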
@@ -408,7 +411,6 @@ def dequantize_4bit_impl(
     torch.Tensor:
         Dequantized tensor.
     """
-
     if A.shape[0] == 1:
         transpose = False
         A = A.squeeze(0)
@@ -438,23 +440,18 @@ def dequantize_4bit_impl(
     if quant_state.nested:
         raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
 
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
-        assert quant_state.op_context is not None
-        A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
-        A = A.reshape(-1)
-        absmax = quant_state.op_context.get_scales().reshape(-1)
-
-    if out is None:
-        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
+    if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
+        A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
+        quant_state.ipex = False
 
-    n = out.numel()
     # Map nf4 to [-1, 1]
-    out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
-    out_uint8[::2] = A.bitwise_and(0xF)
-    out_uint8[1::2] = A.bitwise_right_shift(4)
-    out_dq = torch.empty(out_uint8.shape, dtype=quant_state.code.dtype, device=quant_state.code.device)
-    for i in range(len(quant_state.code)):
-        out_dq[out_uint8 == i] = quant_state.code[i]
+    out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
+    n = out_dq.numel()
+    out_dq[::2] = A & 0xF
+    out_dq[1::2] = A >> 4
+    # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
+    quant_state.code = quant_state.code.to(quant_state.dtype)
+    out_dq = quant_state.code[out_dq]
 
     # Apply scales
     if out_dq.numel() != n:
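The replacement block above swaps the 16-iteration masked loop for a single codebook gather: each packed byte holds two 4-bit indices (low nibble first), which index a 16-entry table. In isolation (the codebook here is a `linspace` placeholder, not the real NF4 table, and `torch.long` indices are used for portability across PyTorch versions):

```python
import torch

code = torch.linspace(-1.0, 1.0, 16)                     # stand-in for quant_state.code
packed = torch.tensor([0x21, 0xF0], dtype=torch.uint8)   # indices (1, 2) and (0, 15)

idx = torch.empty(packed.numel() * 2, dtype=torch.long)
idx[::2] = (packed & 0xF).to(torch.long)   # low nibble -> even positions
idx[1::2] = (packed >> 4).to(torch.long)   # high nibble -> odd positions

out_dq = code[idx]  # one gather replaces the per-codeword masked loop
print(idx.tolist(), out_dq.tolist())
```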
@@ -464,12 +461,17 @@ def dequantize_4bit_impl(
     blocks += 1 if n % blocksize > 0 else 0
     rem = n % blocksize
     has_rem = rem > 0
-    out_reshaped = out.reshape(-1)
-    out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
-        -1
-    )
+
     if has_rem:
+        if out is None:
+            out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
+        out_reshaped = out.reshape(-1)
+        out_reshaped[: n - rem] = (
+            out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)
+        ).reshape(-1)
         out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
+    else:
+        out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)
 
     # take transpose here because weight is transposed (again) for computation
     if transpose:
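The restructured scaling above keeps a remainder-aware slow path and adds a fast path when the element count divides evenly into blocks. The core operation in the fast path is one absmax scale per `blocksize` elements, a reshape-and-multiply; a tiny worked example:

```python
import torch

blocksize = 4
vals = torch.tensor([0.5, -1.0, 0.25, 1.0, 0.5, -0.5, 1.0, 0.125])  # codebook values in [-1, 1]
absmax = torch.tensor([2.0, 8.0])                                    # one scale per block of 4

out = (vals.view(-1, blocksize) * absmax.view(-1, 1)).reshape(-1)
print(out.tolist())  # [1.0, -2.0, 0.5, 2.0, 4.0, -4.0, 8.0, 1.0]
```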
@@ -510,9 +512,21 @@ def gemm_4bit_impl(
     torch.Tensor:
         GEMM output tensor.
     """
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
-        assert state.op_context is not None
-        output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
+    if getattr(state, "ipex", False):
+        output = torch.ops.torch_ipex.woq_linear(
+            A,
+            B,
+            "nf4",
+            state.shape,
+            state.new_scales,
+            state.new_zeros,
+            None,
+            None,
+            state.blocksize,
+            ipex_cpu.quantization.WoqLowpMode.BF16,
+            1,
+            state.compensation,
+        )
     else:
         dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
         output = torch.matmul(A, dqB.to(A.dtype))
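Without a prepacked IPEX weight, the `else` branch dequantizes the 4-bit weight to a dense matrix and runs a plain matmul. A toy end-to-end sketch of that fallback (placeholder codebook and scales; the inline unpack-and-scale stands in for `dequantize_4bit_impl(B, state, ...)`, and fp32 is used here for portability where the real path may run in bf16):

```python
import torch

blocksize = 4
code = torch.linspace(-1.0, 1.0, 16)                      # placeholder codebook
packed = torch.randint(0, 256, (8,), dtype=torch.uint8)   # 16 packed 4-bit values
absmax = torch.rand(4) + 0.5                              # one scale per block

idx = torch.empty(16, dtype=torch.long)
idx[::2] = (packed & 0xF).to(torch.long)
idx[1::2] = (packed >> 4).to(torch.long)
dqB = (code[idx].view(-1, blocksize) * absmax.view(-1, 1)).reshape(4, 4)

A = torch.randn(2, 4)
output = torch.matmul(A, dqB.to(A.dtype).t())  # mirrors the else-branch shape flow
print(output.shape)  # torch.Size([2, 4])
```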