@@ -99,7 +99,8 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
9999
100100
101101def prefetch_tensor (A : torch .Tensor , to_cpu = False ):
102- assert A .is_paged , "Only paged tensors can be prefetched!"
102+ if not A .is_paged :
103+ raise AssertionError ("Only paged tensors can be prefetched!" )
103104 if to_cpu :
104105 deviceid = - 1
105106 else :
@@ -218,7 +219,8 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
218219 values = values .sort ().values
219220 values /= values .max ()
220221
221- assert values .numel () == 256
222+ if values .numel () != 256 :
223+ raise AssertionError
222224
223225 return values
224226
@@ -254,7 +256,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
254256 e = exponent_bits
255257 p = precision_bits
256258 has_sign = 1 if signed else 0
257- assert e + p == total_bits - has_sign
259+ if e + p != total_bits - has_sign :
260+ raise AssertionError
258261 # the exponent is biased to 2^(e-1) -1 == 0
259262 evalues = []
260263 for i , val in enumerate (range (- (2 ** (exponent_bits - has_sign )), 2 ** (exponent_bits - has_sign ), 1 )):
@@ -279,7 +282,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
279282 if signed :
280283 values .append (- value )
281284
282- assert len (values ) == 2 ** total_bits
285+ if len (values ) != 2 ** total_bits :
286+ raise AssertionError
283287 values .sort ()
284288 if total_bits < 8 :
285289 gap = 256 - len (values )
@@ -337,7 +341,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
337341 data .append (0 )
338342 data .append (1.0 )
339343
340- assert len (data ) == 2 ** total_bits
344+ if len (data ) != 2 ** total_bits :
345+ raise AssertionError
341346
342347 gap = 256 - len (data )
343348 for i in range (gap ):
@@ -516,7 +521,8 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
516521 qs_dict .update (unpack_tensor_to_dict (qs_dict .pop (first_qs_key )))
517522
518523 qs_dict = {k .split ("." )[- 1 ]: v for k , v in qs_dict .items ()} # strip prefixes
519- assert set (qs_dict .keys ()).issubset (cls .valid_qs_keys )
524+ if not set (qs_dict .keys ()).issubset (cls .valid_qs_keys ):
525+ raise AssertionError
520526
521527 if "nested_absmax" in qs_dict :
522528 offset = torch .tensor (float (qs_dict ["nested_offset" ])).to (device )
@@ -721,7 +727,8 @@ def dequantize_blockwise(
721727 The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`.
722728 """
723729
724- assert quant_state is not None or absmax is not None
730+ if quant_state is None and absmax is None :
731+ raise AssertionError
725732 if code is None and quant_state is None :
726733 if "dynamic" not in name2qmap :
727734 name2qmap ["dynamic" ] = create_dynamic_map ().to (A .device )
@@ -842,7 +849,8 @@ def get_4bit_type(typename, device=None, blocksize=64):
842849 data = torch .tensor (data , device = device )
843850 data .div_ (data .abs ().max ())
844851
845- assert data .numel () == 16
852+ if data .numel () != 16 :
853+ raise AssertionError
846854
847855 return data
848856
@@ -1009,7 +1017,8 @@ def dequantize_4bit(
10091017 blocksize = 64
10101018
10111019 if quant_state is None :
1012- assert absmax is not None and out is not None
1020+ if absmax is None or out is None :
1021+ raise AssertionError
10131022
10141023 quant_state = QuantState (
10151024 absmax = absmax ,
@@ -1365,7 +1374,8 @@ def igemm(
13651374 ldc = sB [1 ]
13661375 elif len (sB ) == 3 :
13671376 # special case
1368- assert len (sA ) == 3
1377+ if len (sA ) != 3 :
1378+ raise AssertionError
13691379 if not (sA [0 ] == sB [0 ] and sA [1 ] == sB [1 ]):
13701380 raise ValueError (
13711381 f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: { sA } x { sB } " ,
@@ -1658,10 +1668,13 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
16581668 unpacked_w [::2 ] = qweight >> 4
16591669 qweight_final = unpacked_w .reshape (quant_state .shape ).to (torch .uint8 ) # (*, N, K)
16601670 # pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit
1661- assert len (qweight_final .shape ) == 2
1671+ if len (qweight_final .shape ) != 2 :
1672+ raise AssertionError
16621673 N , K = qweight_final .shape [0 ], qweight_final .shape [1 ]
1663- assert N % block_n == 0 , "N must be divisible by block_n"
1664- assert K % 2 == 0 , "K must be even"
1674+ if N % block_n != 0 :
1675+ raise AssertionError ("N must be divisible by block_n" )
1676+ if K % 2 != 0 :
1677+ raise AssertionError ("K must be even" )
16651678 BLOCK_N = block_n
16661679 BIT_COUNT = 32 # (=32 low +32 high)
16671680 new_shape = [N // BLOCK_N , BLOCK_N , K // 2 , 2 ]
@@ -1706,18 +1719,23 @@ def _convert_weight_packed_for_cpu_inverse(
17061719 qweight: [*, N, K] uint8, original qweight shape (quant_state.shape)
17071720 recovered_state: QuantState with partially restored fields (best-effort inverse)
17081721 """
1709- assert quant_state .packing_format_for_cpu , "only for packing format"
1710- assert packed_weight .dtype == torch .uint8
1711- assert len (packed_weight .shape ) == 2 , "packed_weight should be [N, K/2]"
1722+ if not quant_state .packing_format_for_cpu :
1723+ raise AssertionError ("only for packing format" )
1724+ if packed_weight .dtype != torch .uint8 :
1725+ raise AssertionError
1726+ if len (packed_weight .shape ) != 2 :
1727+ raise AssertionError ("packed_weight should be [N, K/2]" )
17121728 N , K_half = packed_weight .shape
17131729 K = K_half * 2
17141730
17151731 # 1) packed [N, K/2] -> [N//BLOCK_N, BLOCK_N, K/2, 2]
17161732 BLOCK_N = block_n
17171733 BIT_COUNT = 32 # (=32 low + 32 high)
17181734
1719- assert N % BLOCK_N == 0 , "N must be divisible by block_n"
1720- assert K % 2 == 0 , "K must be even"
1735+ if N % BLOCK_N != 0 :
1736+ raise AssertionError ("N must be divisible by block_n" )
1737+ if K % 2 != 0 :
1738+ raise AssertionError ("K must be even" )
17211739
17221740 # [N, K/2] -> [-1, 64] (32 low + 32 high)
17231741 packed = packed_weight .reshape (- 1 , BIT_COUNT ) # [-1, 64]
0 commit comments