22#
33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
5+ from collections .abc import Iterable
56import ctypes as ct
67import itertools
78from math import prod
8- from typing import Any , Dict , Iterable , Optional , Tuple , Union
9+ from typing import Any , Optional , Union
910
1011import numpy as np
1112import torch
@@ -619,7 +620,7 @@ def __get_item__(self, idx):
619620 return list_repr [idx ]
620621
621622 @classmethod
622- def from_dict (cls , qs_dict : Dict [str , Any ], device : torch .device ) -> "QuantState" :
623+ def from_dict (cls , qs_dict : dict [str , Any ], device : torch .device ) -> "QuantState" :
623624 """
624625 unpacks components of state_dict into QuantState
625626 where necessary, convert into strings, torch.dtype, ints, etc.
@@ -741,7 +742,7 @@ def quantize_blockwise(
741742 out : Optional [torch .Tensor ] = None ,
742743 blocksize = 4096 ,
743744 nested = False ,
744- ) -> Tuple [torch .Tensor , QuantState ]:
745+ ) -> tuple [torch .Tensor , QuantState ]:
745746 """Quantize a tensor in blocks of values.
746747
747748 The input tensor is quantized by dividing it into blocks of `blocksize` values.
@@ -994,7 +995,7 @@ def quantize_4bit(
994995 compress_statistics = False ,
995996 quant_type = "fp4" ,
996997 quant_storage = torch .uint8 ,
997- ) -> Tuple [torch .Tensor , QuantState ]:
998+ ) -> tuple [torch .Tensor , QuantState ]:
998999 """Quantize tensor A in blocks of 4-bit values.
9991000
10001001 Quantizes tensor A by dividing it into blocks which are independently quantized.
@@ -1161,7 +1162,7 @@ def quantize(
11611162 A : Tensor ,
11621163 code : Optional [torch .Tensor ] = None ,
11631164 out : Optional [torch .Tensor ] = None ,
1164- ) -> Tuple [Tensor , Tuple [Tensor , Tensor ]]:
1165+ ) -> tuple [Tensor , tuple [Tensor , Tensor ]]:
11651166 if code is None :
11661167 if "dynamic" not in name2qmap :
11671168 name2qmap ["dynamic" ] = create_dynamic_map ().to (A .device )
@@ -1179,7 +1180,7 @@ def quantize(
11791180@deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
11801181def dequantize (
11811182 A : Tensor ,
1182- state : Optional [Tuple [Tensor , Tensor ]] = None ,
1183+ state : Optional [tuple [Tensor , Tensor ]] = None ,
11831184 absmax : Optional [torch .Tensor ] = None ,
11841185 code : Optional [torch .Tensor ] = None ,
11851186 out : Optional [torch .Tensor ] = None ,
@@ -2006,7 +2007,7 @@ def get_colrow_absmax(
20062007 col_stats : Optional [torch .Tensor ] = None ,
20072008 nnz_block_ptr : Optional [torch .Tensor ] = None ,
20082009 threshold = 0.0 ,
2009- ) -> Tuple [torch .Tensor , torch .Tensor , Optional [torch .Tensor ]]:
2010+ ) -> tuple [torch .Tensor , torch .Tensor , Optional [torch .Tensor ]]:
20102011 """ "Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
20112012
20122013 The row-wise and column-wise absmax values are determined.
@@ -2268,9 +2269,9 @@ def spmm_coo(
22682269 out : Optional [torch .Tensor ] = None ,
22692270):
22702271 if not isinstance (cooA , COOSparseTensor ):
2271- assert (
2272- cooA . is_sparse and cooA . layout == torch . sparse_coo
2273- ), "Tensor must be `COOSparseTensor or a PyTorch COO tensor."
2272+ assert cooA . is_sparse and cooA . layout == torch . sparse_coo , (
2273+ "Tensor must be `COOSparseTensor or a PyTorch COO tensor."
2274+ )
22742275
22752276 # Convert to custom COOSparseTensor
22762277 cooA = COOSparseTensor (
0 commit comments