We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 90bbe14 · commit 4960932 (copy full SHA for 4960932)
1 file changed
bitsandbytes/autograd/_functions.py
@@ -84,6 +84,13 @@ def get_inverse_transform_indices(
     return permuted_tile_indices


+# torch.compiler.is_compiling() is available only in torch >= 2.3
+if hasattr(torch.compiler, "is_compiling"):
+    _is_compiling = torch.compiler.is_compiling
+else:
+    _is_compiling = torch._dynamo.is_compiling
+

 @deprecated(
     "This function is deprecated and will be removed in a future release.",
     category=FutureWarning,
@@ -174,7 +181,7 @@ def forward(
         input_shape = A.shape

         # Cast A to fp16
-        if A.dtype != torch.float16:
+        if A.dtype != torch.float16 and not _is_compiling():
             warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

         if len(A.shape) == 3:
0 commit comments