@@ -2103,4 +2103,138 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
21032103 return out
21042104
21052105
def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32):
    """
    Repack a 4-bit quantized weight into the CPU-friendly tiled layout.

    qweight: (K * N / 2) uint8
    return: packed_weight
    """
    # Remember how the weight arrived so the inverse can undo the uint8 view.
    if qweight.dtype != torch.uint8:
        quant_state.original_storage_type = qweight.dtype
        qweight = qweight.view(torch.uint8)
    quant_state.original_dtype = quant_state.dtype
    quant_state.original_nested = quant_state.nested
    quant_state.original_qshape = qweight.shape

    # Split every byte into its two 4-bit halves: high nibble lands at the
    # even slots, low nibble at the odd slots.
    flat = qweight.reshape(-1)
    nibbles = torch.empty(flat.shape[0] * 2, dtype=torch.int32, device=flat.device)
    nibbles[::2] = flat >> 4
    nibbles[1::2] = flat & 0xF
    qweight_final = nibbles.reshape(quant_state.shape).to(torch.uint8)  # (*, N, K)

    # pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit
    assert len(qweight_final.shape) == 2
    N, K = qweight_final.shape[0], qweight_final.shape[1]
    assert N % block_n == 0, "N must be divisible by block_n"
    assert K % 2 == 0, "K must be even"
    BLOCK_N = block_n
    BIT_COUNT = 32  # (=32 low +32 high)
    tiled = qweight_final.reshape(N // BLOCK_N, BLOCK_N, K // 2, 2)  # (N/B, B, K/2, 2)
    tiled = tiled.transpose(-3, -2).contiguous()  # (N/B, K/2, B, 2)
    pairs = tiled.reshape(-1, BIT_COUNT * 2)  # [-1, 64]
    low = pairs[:, :BIT_COUNT]  # low 32
    high = pairs[:, BIT_COUNT:]  # high 32
    final_qweight = ((high << 4) | low).to(torch.uint8).reshape(N, K // 2)

    # Flatten a nested quant state: materialize the true absmax values first.
    if quant_state.nested:
        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
        absmax += quant_state.offset
        if absmax.dtype != torch.float32:
            absmax = absmax.float()
        quant_state.absmax = absmax
        quant_state.nested = False
        delattr(quant_state, "state2")

    # Reshape the per-block scales to (rows, blocks-per-row), transpose, and
    # store as bf16 for the CPU kernel.
    quant_state.absmax = (
        quant_state.absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
        .T.to(torch.bfloat16)
        .contiguous()
    )

    quant_state.dtype = torch.bfloat16
    quant_state.packing_format_for_cpu = True
    return final_qweight, quant_state
2158+
2159+
def _convert_weight_packed_for_cpu_inverse(
    packed_weight: torch.Tensor,
    quant_state: QuantState,
    block_n: int = 32,
) -> tuple[torch.Tensor, QuantState]:
    """
    Best-effort inverse of `_convert_weight_packed_for_cpu`.

    packed_weight: [N, K/2] uint8, output of `_convert_weight_packed_for_cpu` (final_qweight)
    quant_state: QuantState that was modified by `_convert_weight_packed_for_cpu`
    Returns:
        qweight: [*, N, K] uint8, original qweight shape (quant_state.shape)
        recovered_state: QuantState with partially restored fields (best-effort inverse)
    """
    assert quant_state.packing_format_for_cpu, "only for packing format"
    assert packed_weight.dtype == torch.uint8
    assert len(packed_weight.shape) == 2, "packed_weight should be [N, K/2]"
    N, K_half = packed_weight.shape
    K = K_half * 2

    BLOCK_N = block_n
    BIT_COUNT = 32  # nibbles per packed row (each packed byte holds one low + one high nibble)

    assert N % BLOCK_N == 0, "N must be divisible by block_n"
    assert K % 2 == 0, "K must be even"

    # 1) packed [N, K/2] -> rows of 32 bytes, each byte = (high << 4) | low
    packed = packed_weight.reshape(-1, BIT_COUNT)  # [-1, 32]
    # split high/low nibbles
    high = (packed >> 4) & 0xF
    low = packed & 0xF
    # re-create the 64-nibble rows of the forward pass: first 32 low, last 32 high
    qw = torch.cat([low, high], dim=-1).to(torch.uint8)  # [-1, 64]

    # undo the tile transpose: [N/B, K/2, B, 2] -> [N/B, B, K/2, 2] -> [N, K]
    qw = qw.reshape(N // BLOCK_N, K_half, BLOCK_N, 2)  # [N/B, K/2, B, 2]
    qw = qw.transpose(-3, -2).contiguous()  # [N/B, B, K/2, 2]
    qweight = qw.reshape(N, K)  # [N, K]

    # re-pack two 4-bit values per byte: even slots carry the high nibble,
    # odd slots the low nibble (mirrors the forward unpacking order)
    unpacked_w = qweight.reshape(-1).to(torch.int32)  # [N*K]
    high4 = (unpacked_w[::2] & 0xF).to(torch.uint8)
    low4 = (unpacked_w[1::2] & 0xF).to(torch.uint8)
    qweight = (high4 << 4) | low4  # [N*K/2]

    # 2) Best-effort restore of quant_state fields (absmax / dtype / nested flags, etc.)
    recovered_state = quant_state
    qweight = qweight.to(torch.uint8).reshape(recovered_state.original_qshape)

    if recovered_state.original_nested:
        # undo the packed (transposed bf16) layout, then re-quantize absmax
        # into a nested state as the original had
        absmax = recovered_state.absmax.T.reshape(-1).to(recovered_state.original_dtype)
        offset = absmax.mean()
        qabsmax, state2 = quantize_blockwise(absmax - offset, blocksize=256)
        recovered_state.absmax = qabsmax
        recovered_state.offset = offset
        recovered_state.state2 = state2
        recovered_state.nested = True
    else:
        # BUGFIX: the non-nested path previously left absmax in the packed
        # (transposed, bf16) layout even though the forward pass applies that
        # layout unconditionally; restore the flat layout here as well.
        # NOTE(review): assumes non-nested absmax was stored as float32 — confirm.
        recovered_state.absmax = recovered_state.absmax.T.reshape(-1).contiguous().to(torch.float32)

    recovered_state.dtype = recovered_state.original_dtype
    recovered_state.packing_format_for_cpu = False

    if getattr(recovered_state, "original_storage_type", None):
        qweight = qweight.view(recovered_state.original_storage_type)

    return qweight, recovered_state
2226+
2227+
def has_avx512bf16():
    """
    Probe the native library for AVX512-BF16 support via lib.has_avx512bf16_cpu().

    Returns False explicitly when the symbol is missing or the call fails.
    """
    try:
        return lib.has_avx512bf16_cpu()
    except (AttributeError, RuntimeError, OSError):
        return False
2238+
2239+
21062240C = 127.0
0 commit comments