@@ -369,9 +369,9 @@ def test_approx_igemm(self, dim1, dim2, quant_methods, batched):
369369 # print(mean(errors))
370370 # print(mean(relerrors))
371371
372- @pytest .mark .parametrize ("hidden_dim" , get_test_dims ( 32 , 256 , n = 2 ) , ids = id_formatter ("hidden_dim" ))
373- @pytest .mark .parametrize ("batch_dim" , get_test_dims ( 16 , 256 , n = 2 ) , ids = id_formatter ("batch_dim" ))
374- @pytest .mark .parametrize ("seq_dim" , get_test_dims ( 16 , 256 , n = 2 ) , ids = id_formatter ("seq_dim" ))
372+ @pytest .mark .parametrize ("hidden_dim" , [ 32 , 256 ] , ids = id_formatter ("hidden_dim" ))
373+ @pytest .mark .parametrize ("batch_dim" , [ 16 , 256 ] , ids = id_formatter ("batch_dim" ))
374+ @pytest .mark .parametrize ("seq_dim" , [ 16 , 256 ] , ids = id_formatter ("seq_dim" ))
375375 @pytest .mark .parametrize ("transpose" , BOOLEAN_TUPLES , ids = id_formatter ("transpose" ))
376376 def test_igemm (self , hidden_dim , batch_dim , transpose , seq_dim ):
377377 hidden_dim = hidden_dim - (hidden_dim % 32 )
@@ -415,9 +415,9 @@ def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim):
415415
416416 torch .testing .assert_close (out .float (), out2 )
417417
418- @pytest .mark .parametrize ("seq_dim" , get_test_dims ( 32 , 512 , n = 3 ) , ids = id_formatter ("seq_dim" ))
419- @pytest .mark .parametrize ("hidden_dim" , get_test_dims ( 32 , 1024 * 4 , n = 3 ) , ids = id_formatter ("hidden_dim" ))
420- @pytest .mark .parametrize ("batch_dim" , get_test_dims ( 2 , 16 , n = 3 ) , ids = id_formatter ("batch_dim" ))
418+ @pytest .mark .parametrize ("seq_dim" , [ 32 , 256 , 512 ] , ids = id_formatter ("seq_dim" ))
419+ @pytest .mark .parametrize ("hidden_dim" , [ 64 , 1024 , 4096 ] , ids = id_formatter ("hidden_dim" ))
420+ @pytest .mark .parametrize ("batch_dim" , [ 2 , 8 , 16 ] , ids = id_formatter ("batch_dim" ))
421421 def test_dim3_igemm (self , seq_dim , hidden_dim , batch_dim ):
422422 seq_dim = seq_dim - (seq_dim % 32 )
423423 hidden_dim = hidden_dim - (hidden_dim % 32 )
@@ -431,9 +431,9 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
431431
432432 torch .testing .assert_close (out .float (), out2 )
433433
434- @pytest .mark .parametrize ("seq_dim" , get_test_dims ( 32 , 512 , n = 2 ) , ids = id_formatter ("seq_dim" ))
435- @pytest .mark .parametrize ("hidden_dim" , get_test_dims ( 32 , 1024 * 4 , n = 2 ) , ids = id_formatter ("hidden_dim" ))
436- @pytest .mark .parametrize ("batch_dim" , get_test_dims ( 2 , 16 , n = 2 ) , ids = id_formatter ("batch_dim" ))
434+ @pytest .mark .parametrize ("seq_dim" , [ 32 , 512 ] , ids = id_formatter ("seq_dim" ))
435+ @pytest .mark .parametrize ("hidden_dim" , [ 32 , 1024 * 4 ] , ids = id_formatter ("hidden_dim" ))
436+ @pytest .mark .parametrize ("batch_dim" , [ 2 , 16 ] , ids = id_formatter ("batch_dim" ))
437437 @pytest .mark .parametrize ("transpose" , TRUE_FALSE , ids = id_formatter ("transpose" ))
438438 def test_minmax_igemm (self , seq_dim , hidden_dim , batch_dim , transpose ):
439439 def min_max (x ):
@@ -501,10 +501,10 @@ def min_max(x):
501501 assert mean (errs ) < 0.015
502502 assert mean (relerrs ) < 0.3
503503
504- @pytest .mark .parametrize ("dim1" , get_test_dims ( 1 , 64 , n = 2 ) , ids = id_formatter ("dim1" ))
505- @pytest .mark .parametrize ("dim2" , get_test_dims ( 32 , 128 , n = 2 ) , ids = id_formatter ("dim2" ))
506- @pytest .mark .parametrize ("dim3" , get_test_dims ( 32 , 256 , n = 2 ) , ids = id_formatter ("dim3" ))
507- @pytest .mark .parametrize ("dim4" , get_test_dims ( 32 , 256 , n = 2 ) , ids = id_formatter ("dim4" ))
504+ @pytest .mark .parametrize ("dim1" , [ 1 , 64 ] , ids = id_formatter ("dim1" ))
505+ @pytest .mark .parametrize ("dim2" , [ 32 , 128 ] , ids = id_formatter ("dim2" ))
506+ @pytest .mark .parametrize ("dim3" , [ 32 , 256 ] , ids = id_formatter ("dim3" ))
507+ @pytest .mark .parametrize ("dim4" , [ 32 , 256 ] , ids = id_formatter ("dim4" ))
508508 @pytest .mark .parametrize ("transpose" , BOOLEAN_TUPLES , ids = id_formatter ("transpose" ))
509509 def test_ibmm (self , dim1 , dim2 , dim3 , dim4 , transpose ):
510510 dim2 = dim2 - (dim2 % 16 )
@@ -760,8 +760,8 @@ def test_coo_int8_vectorwise_quant(self, dim1, dim2):
760760
761761
762762class TestSpMMFunctional :
763- @pytest .mark .parametrize ("dim1" , get_test_dims ( 1 , 1 * 1024 , n = 2 ) , ids = id_formatter ("dim1" ))
764- @pytest .mark .parametrize ("dim2" , get_test_dims ( 1 , 1 * 1024 , n = 2 ) , ids = id_formatter ("dim2" ))
763+ @pytest .mark .parametrize ("dim1" , [ 256 , 1024 ] , ids = id_formatter ("dim1" ))
764+ @pytest .mark .parametrize ("dim2" , [ 128 , 512 ] , ids = id_formatter ("dim2" ))
765765 @pytest .mark .parametrize ("transposed_B" , TRUE_FALSE , ids = id_formatter ("transposed_B" ))
766766 def test_spmm_coo (self , dim1 , dim2 , transposed_B ):
767767 threshold = 1.5
@@ -1096,37 +1096,34 @@ def test_4bit_quant(self, dtype, quant_type, blocksize):
10961096 assert err .item () < math .log2 (blocksize ) * 8e-2
10971097
10981098 @pytest .mark .parametrize ("quant_type" , ["fp4" , "nf4" ])
1099- def test_4bit_compressed_stats (self , quant_type ):
1100- for blocksize in [128 , 64 ]:
1101- errs1 = []
1102- errs2 = []
1103- for i in range (10 ):
1104- A1 = torch .randn (1024 , 1024 , device = "cuda" ).half ()
1105- q2 , SA2 = F .quantize_4bit (A1 , blocksize = blocksize , quant_type = quant_type )
1106- q3 , SA3 = F .quantize_4bit (A1 , blocksize = blocksize , compress_statistics = True , quant_type = quant_type )
1107- A2 = F .dequantize_4bit (q2 , SA2 , quant_type = quant_type )
1108- A3 = F .dequantize_4bit (q3 , SA3 , quant_type = quant_type )
1109-
1110- err = (A1 - A2 ).abs ().float ()
1111- relerr = (err / (A1 .abs ().float () + 1e-15 )).mean ()
1112- err = err .mean ()
1099+ @pytest .mark .parametrize ("blocksize" , [64 , 128 ], ids = id_formatter ("blocksize" ))
1100+ def test_4bit_compressed_stats (self , quant_type , blocksize ):
1101+ errs1 = []
1102+ errs2 = []
1103+ for i in range (10 ):
1104+ A1 = torch .randn (1024 , 1024 , device = "cuda" ).half ()
1105+ q2 , SA2 = F .quantize_4bit (A1 , blocksize = blocksize , quant_type = quant_type )
1106+ q3 , SA3 = F .quantize_4bit (A1 , blocksize = blocksize , compress_statistics = True , quant_type = quant_type )
1107+ A2 = F .dequantize_4bit (q2 , SA2 , quant_type = quant_type )
1108+ A3 = F .dequantize_4bit (q3 , SA3 , quant_type = quant_type )
11131109
1114- errs1 .append (err .item ())
1110+ err = (A1 - A2 ).abs ().float ()
1111+ relerr = (err / (A1 .abs ().float () + 1e-15 )).mean ()
1112+ err = err .mean ()
11151113
1116- assert err .item () < 0.11
1117- assert relerr .item () < 0.28
1114+ errs1 .append (err .item ())
11181115
1119- err = (A1 - A3 ).abs ().float ()
1120- relerr = (err / (A1 .abs ().float () + 1e-15 )).mean ()
1121- err = err .mean ()
1116+ assert err .item () < 0.11
1117+ assert relerr .item () < 0.28
11221118
1123- errs2 .append (err .item ())
1119+ err = (A1 - A3 ).abs ().float ()
1120+ relerr = (err / (A1 .abs ().float () + 1e-15 )).mean ()
1121+ err = err .mean ()
11241122
1125- assert err .item () < 0.11
1126- assert relerr .item () < 0.28
1123+ errs2 .append (err .item ())
11271124
1128- # print(sum(errs1)/len(errs1), blocksize, quant_type)
1129- # print(sum(errs2)/len(errs2), blocksize, quant_type)
1125+ assert err .item () < 0.11
1126+ assert relerr .item () < 0.28
11301127
11311128 # @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
11321129 @pytest .mark .parametrize ("quant_type" , ["nf4" ])
@@ -1169,135 +1166,133 @@ def test_bench_4bit_dequant(self, quant_type):
11691166 [torch .uint8 , torch .float16 , torch .bfloat16 , torch .float32 ],
11701167 ids = describe_dtype ,
11711168 )
1172- def test_gemv_4bit (self , dtype , storage_type , quant_storage , double_quant , kind ):
1173- for dim in [128 , 256 , 512 , 1024 ]:
1174- # for dim in [4*1024]:
1175- # for dim in [1*16]:
1176- errs1 = []
1177- errs2 = []
1178- errs3 = []
1179- relerrs1 = []
1180- relerrs2 = []
1181- relerrs3 = []
1182- max_errs1 = []
1183- max_errs2 = []
1184- max_errs3 = []
1169+ @pytest .mark .parametrize ("dim" , [128 , 256 , 512 , 1024 ], ids = id_formatter ("dim" ))
1170+ def test_gemv_4bit (self , dim , dtype , storage_type , quant_storage , double_quant , kind ):
1171+ errs1 = []
1172+ errs2 = []
1173+ errs3 = []
1174+ relerrs1 = []
1175+ relerrs2 = []
1176+ relerrs3 = []
1177+ max_errs1 = []
1178+ max_errs2 = []
1179+ max_errs3 = []
11851180
1186- for i in range (100 ):
1187- if kind == "fc1" :
1188- A = torch .randn (1 , dim , dtype = dtype , device = "cuda" )
1189- B = torch .randn (dim * 4 , dim , dtype = dtype , device = "cuda" ) / math .sqrt (dim )
1190- elif kind == "fc2" :
1191- A = torch .randn (1 , 4 * dim , dtype = dtype , device = "cuda" )
1192- B = torch .randn (dim , 4 * dim , dtype = dtype , device = "cuda" ) / math .sqrt (dim )
1193- elif kind == "attn" :
1194- A = torch .randn (1 , dim , dtype = dtype , device = "cuda" )
1195- B = torch .randn (dim , dim , dtype = dtype , device = "cuda" ) / math .sqrt (dim )
1196- elif kind == "attn_packed" :
1197- A = torch .randn (1 , dim , dtype = dtype , device = "cuda" )
1198- B = torch .randn (dim * 3 , dim , dtype = dtype , device = "cuda" ) / math .sqrt (dim )
1199-
1200- qB , state = F .quantize_4bit (
1201- B ,
1202- quant_type = storage_type ,
1203- compress_statistics = double_quant ,
1204- quant_storage = quant_storage ,
1205- )
1206- C3 = torch .matmul (A , B .t ())
1207- C2 = F .gemv_4bit (A , qB .t (), state = state )
1208- A .requires_grad = True
1209- C1 = bnb .matmul_4bit (A , qB .t (), state )
1210-
1211- err1 = (C1 - C2 ).abs ().float ()
1212- err2 = (C3 - C2 ).abs ().float ()
1213- err3 = (C3 - C1 ).abs ().float ()
1214-
1215- mag1 = torch .abs (C1 ).float () + 1e-5
1216- mag2 = torch .abs (C3 ).float () + 1e-5
1217- mag3 = torch .abs (C3 ).float () + 1e-5
1218-
1219- relerr1 = err1 / mag1
1220- relerr2 = err2 / mag2
1221- relerr3 = err3 / mag3
1222-
1223- max_err1 = err1 .max ()
1224- max_err2 = err2 .max ()
1225- max_err3 = err3 .max ()
1226-
1227- errs1 .append (err1 .mean ().item ())
1228- errs2 .append (err2 .mean ().item ())
1229- errs3 .append (err3 .mean ().item ())
1230-
1231- relerrs1 .append (relerr1 .mean ().item ())
1232- relerrs2 .append (relerr2 .mean ().item ())
1233- relerrs3 .append (relerr3 .mean ().item ())
1234-
1235- max_errs1 .append (max_err1 .item ())
1236- max_errs2 .append (max_err2 .item ())
1237- max_errs3 .append (max_err3 .item ())
1238-
1239- c = int (C1 .numel () * 0.0014 * (dim / 256 )) + 1
1240-
1241- c = assert_all_approx_close (C1 , C2 , 1e-5 , 0.01 , count = 0 , throw = False )
1242- err1 = sum (errs1 ) / len (errs1 ) / math .sqrt (dim )
1243- err2 = sum (errs2 ) / len (errs2 ) / math .sqrt (dim )
1244- err3 = sum (errs3 ) / len (errs3 ) / math .sqrt (dim )
1245- relerr1 = sum (relerrs1 ) / len (relerrs1 ) / math .sqrt (dim )
1246- relerr2 = sum (relerrs2 ) / len (relerrs2 ) / math .sqrt (dim )
1247- relerr3 = sum (relerrs3 ) / len (relerrs3 ) / math .sqrt (dim )
1248- maxerr1 = sum (max_errs1 ) / len (max_errs1 ) / math .sqrt (dim )
1249- maxerr2 = sum (max_errs2 ) / len (max_errs2 ) / math .sqrt (dim )
1250- maxerr3 = sum (max_errs3 ) / len (max_errs3 ) / math .sqrt (dim )
1251- absratio = err2 / err3
1252- relratio = relerr2 / relerr3
1253- maxratio = relerr2 / relerr3
1254-
1255- # for debugging if the tests fails
1256- #
1257- # print('='*80)
1258- # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
1259- # print(C1.flatten()[-20:])
1260- # print(C2.flatten()[-20:])
1261- # print(f'inference vs training abs: {err1}')
1262- # print(f'inference vs training rel: {relerr1}')
1263- # print(f'inference vs training max: {maxerr1}')
1264- # print(f'inference vs training vs torch err ratio abs: {absratio}')
1265- # print(f'inference vs training vs torch err ratio rel: {relratio}')
1266- # print(f'inference vs training vs torch err ratio max: {maxratio}')
1267- if dtype == torch .float16 :
1268- if dim <= 512 :
1269- assert err1 < 7e-5
1270- assert relerr1 < 0.0008
1271- else :
1272- assert err1 < 6e-5
1273- assert relerr1 < 2e-4
1274- assert absratio < 1.005 and absratio > 0.995
1275- assert relratio < 1.005 and relratio > 0.995
1276- assert maxratio < 1.005 and maxratio > 0.995
1277- elif dtype == torch .float32 :
1278- if dim <= 512 :
1279- assert err1 < 5e-8
1280- assert relerr1 < 1e-6
1281- assert maxerr1 < 1e-7
1282- else :
1283- assert err1 < 5e-8
1284- assert relerr1 < 8e-6
1285- assert maxerr1 < 1e-7
1286- assert absratio < 1.005 and absratio > 0.995
1287- assert relratio < 1.005 and relratio > 0.995
1288- assert maxratio < 1.005 and maxratio > 0.995
1289- elif dtype == torch .bfloat16 :
1290- if dim <= 512 :
1291- assert err1 < 6e-4
1292- assert relerr1 < 0.007
1293- assert maxerr1 < 0.015
1294- else :
1295- assert err1 < 2e-4
1296- assert relerr1 < 0.002
1297- assert maxerr1 < 0.0012
1298- assert absratio < 1.005 and absratio > 0.995
1299- assert relratio < 1.04 and relratio > 0.96
1300- assert maxratio < 1.02 and maxratio > 0.98
1181+ for i in range (100 ):
1182+ if kind == "fc1" :
1183+ A = torch .randn (1 , dim , dtype = dtype , device = "cuda" )
1184+ B = torch .randn (dim * 4 , dim , dtype = dtype , device = "cuda" ) / math .sqrt (dim )
1185+ elif kind == "fc2" :
1186+ A = torch .randn (1 , 4 * dim , dtype = dtype , device = "cuda" )
1187+ B = torch .randn (dim , 4 * dim , dtype = dtype , device = "cuda" ) / math .sqrt (dim )
1188+ elif kind == "attn" :
1189+ A = torch .randn (1 , dim , dtype = dtype , device = "cuda" )
1190+ B = torch .randn (dim , dim , dtype = dtype , device = "cuda" ) / math .sqrt (dim )
1191+ elif kind == "attn_packed" :
1192+ A = torch .randn (1 , dim , dtype = dtype , device = "cuda" )
1193+ B = torch .randn (dim * 3 , dim , dtype = dtype , device = "cuda" ) / math .sqrt (dim )
1194+
1195+ qB , state = F .quantize_4bit (
1196+ B ,
1197+ quant_type = storage_type ,
1198+ compress_statistics = double_quant ,
1199+ quant_storage = quant_storage ,
1200+ )
1201+ C3 = torch .matmul (A , B .t ())
1202+ C2 = F .gemv_4bit (A , qB .t (), state = state )
1203+ A .requires_grad = True
1204+ C1 = bnb .matmul_4bit (A , qB .t (), state )
1205+
1206+ err1 = (C1 - C2 ).abs ().float ()
1207+ err2 = (C3 - C2 ).abs ().float ()
1208+ err3 = (C3 - C1 ).abs ().float ()
1209+
1210+ mag1 = torch .abs (C1 ).float () + 1e-5
1211+ mag2 = torch .abs (C3 ).float () + 1e-5
1212+ mag3 = torch .abs (C3 ).float () + 1e-5
1213+
1214+ relerr1 = err1 / mag1
1215+ relerr2 = err2 / mag2
1216+ relerr3 = err3 / mag3
1217+
1218+ max_err1 = err1 .max ()
1219+ max_err2 = err2 .max ()
1220+ max_err3 = err3 .max ()
1221+
1222+ errs1 .append (err1 .mean ().item ())
1223+ errs2 .append (err2 .mean ().item ())
1224+ errs3 .append (err3 .mean ().item ())
1225+
1226+ relerrs1 .append (relerr1 .mean ().item ())
1227+ relerrs2 .append (relerr2 .mean ().item ())
1228+ relerrs3 .append (relerr3 .mean ().item ())
1229+
1230+ max_errs1 .append (max_err1 .item ())
1231+ max_errs2 .append (max_err2 .item ())
1232+ max_errs3 .append (max_err3 .item ())
1233+
1234+ c = int (C1 .numel () * 0.0014 * (dim / 256 )) + 1
1235+
1236+ c = assert_all_approx_close (C1 , C2 , 1e-5 , 0.01 , count = 0 , throw = False )
1237+ err1 = sum (errs1 ) / len (errs1 ) / math .sqrt (dim )
1238+ err2 = sum (errs2 ) / len (errs2 ) / math .sqrt (dim )
1239+ err3 = sum (errs3 ) / len (errs3 ) / math .sqrt (dim )
1240+ relerr1 = sum (relerrs1 ) / len (relerrs1 ) / math .sqrt (dim )
1241+ relerr2 = sum (relerrs2 ) / len (relerrs2 ) / math .sqrt (dim )
1242+ relerr3 = sum (relerrs3 ) / len (relerrs3 ) / math .sqrt (dim )
1243+ maxerr1 = sum (max_errs1 ) / len (max_errs1 ) / math .sqrt (dim )
1244+ maxerr2 = sum (max_errs2 ) / len (max_errs2 ) / math .sqrt (dim )
1245+ maxerr3 = sum (max_errs3 ) / len (max_errs3 ) / math .sqrt (dim )
1246+ absratio = err2 / err3
1247+ relratio = relerr2 / relerr3
1248+ maxratio = relerr2 / relerr3
1249+
1250+ # for debugging if the tests fails
1251+ #
1252+ # print('='*80)
1253+ # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
1254+ # print(C1.flatten()[-20:])
1255+ # print(C2.flatten()[-20:])
1256+ # print(f'inference vs training abs: {err1}')
1257+ # print(f'inference vs training rel: {relerr1}')
1258+ # print(f'inference vs training max: {maxerr1}')
1259+ # print(f'inference vs training vs torch err ratio abs: {absratio}')
1260+ # print(f'inference vs training vs torch err ratio rel: {relratio}')
1261+ # print(f'inference vs training vs torch err ratio max: {maxratio}')
1262+ if dtype == torch .float16 :
1263+ if dim <= 512 :
1264+ assert err1 < 7e-5
1265+ assert relerr1 < 0.0008
1266+ else :
1267+ assert err1 < 6e-5
1268+ assert relerr1 < 2e-4
1269+ assert absratio < 1.005 and absratio > 0.995
1270+ assert relratio < 1.005 and relratio > 0.995
1271+ assert maxratio < 1.005 and maxratio > 0.995
1272+ elif dtype == torch .float32 :
1273+ if dim <= 512 :
1274+ assert err1 < 5e-8
1275+ assert relerr1 < 1e-6
1276+ assert maxerr1 < 1e-7
1277+ else :
1278+ assert err1 < 5e-8
1279+ assert relerr1 < 8e-6
1280+ assert maxerr1 < 1e-7
1281+ assert absratio < 1.005 and absratio > 0.995
1282+ assert relratio < 1.005 and relratio > 0.995
1283+ assert maxratio < 1.005 and maxratio > 0.995
1284+ elif dtype == torch .bfloat16 :
1285+ if dim <= 512 :
1286+ assert err1 < 6e-4
1287+ assert relerr1 < 0.007
1288+ assert maxerr1 < 0.015
1289+ else :
1290+ assert err1 < 2e-4
1291+ assert relerr1 < 0.002
1292+ assert maxerr1 < 0.0012
1293+ assert absratio < 1.005 and absratio > 0.995
1294+ assert relratio < 1.04 and relratio > 0.96
1295+ assert maxratio < 1.02 and maxratio > 0.98
13011296
13021297 @pytest .mark .parametrize ("storage_type" , ["nf4" , "fp4" ], ids = ["nf4" , "fp4" ])
13031298 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ], ids = describe_dtype )
@@ -1363,9 +1358,9 @@ def test_managed():
13631358 assert (A == 17 * (2 ** 3 )).sum ().item () == n * n
13641359
13651360
1366- @pytest .mark .parametrize ("dim1" , get_test_dims ( 1 , 64 , n = 1 ) , ids = id_formatter ("dim1" ))
1367- @pytest .mark .parametrize ("dim2" , get_test_dims ( 32 , 128 , n = 1 ) , ids = id_formatter ("dim2" ))
1368- @pytest .mark .parametrize ("dim3" , get_test_dims ( 32 , 256 , n = 1 ) , ids = id_formatter ("dim3" ))
1361+ @pytest .mark .parametrize ("dim1" , [ 32 ] , ids = id_formatter ("dim1" ))
1362+ @pytest .mark .parametrize ("dim2" , [ 64 ] , ids = id_formatter ("dim2" ))
1363+ @pytest .mark .parametrize ("dim3" , [ 128 ] , ids = id_formatter ("dim3" ))
13691364@pytest .mark .deprecated
13701365def test_vector_quant (dim1 , dim2 , dim3 ):
13711366 dim2 = dim2 - (dim2 % 16 )