Commit c6f2575

fix
1 parent 2de5ec3 commit c6f2575

2 files changed: 36 additions & 8 deletions

bitsandbytes/backends/cuda/ops.py

Lines changed: 14 additions & 3 deletions
@@ -74,9 +74,20 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor
 
     if has_error:
         if has_error == 100:
-            # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
-            # TODO: Warn and implement a fallback to fp32 compute?
-            raise NotImplementedError("int8_linear_matmul not implemented!")
+            # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`. The HIP backend
+            # also returns this when no usable hipBLASLt algo exists for the shape
+            # (seen on MI300X for some small-n int8 gemms). Fall back to fp32 — same
+            # path used for the `lda % 4 != 0` case above.
+            import warnings
+
+            warnings.warn(
+                f"int8_linear_matmul has no usable (hip|cu)blasLt algo for shape "
+                f"{shapeA=} {shapeB=}; falling back to fp32 matmul.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+            result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
+            return out.copy_(result)
         else:
             raise RuntimeError(
                 f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}"

csrc/ops.cu

Lines changed: 22 additions & 5 deletions
@@ -327,7 +327,10 @@ int igemmlt(
         bnb_blasLtPrefSetAttr(pref, BNB_BLASLT_PREF_MAX_WORKSPACE, &max_workspace_size, sizeof(max_workspace_size))
     );
 
-    const int request_solutions = 1;
+    // hipBLASLt's first heuristic algo can be unusable for small-n int8 gemms
+    // (e.g. n=4 on MI300X) and fails at matmul time with INVALID_VALUE. Request
+    // several candidates and use the first one that actually runs successfully.
+    const int request_solutions = 8;
     bnb_blasLt_heuristic_t heuristicResult[request_solutions];
     int returnedAlgoCount = 0;
     checkBlasLtStatus(bnb_blasLtAlgoGetHeuristic(
@@ -340,10 +343,24 @@ int igemmlt(
         fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n");
     } else {
         int alpha = 1, beta = 0;
-        has_error |= checkBlasLtStatus(bnb_blasLtMatmul(
-            ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc,
-            &heuristicResult[0].algo, NULL, 0, stream
-        ));
+        bnb_blas_status_t matmul_status = BNB_BLAS_STATUS_SUCCESS;
+        for (int i = 0; i < returnedAlgoCount; ++i) {
+            matmul_status = bnb_blasLtMatmul(
+                ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc,
+                &heuristicResult[i].algo, NULL, 0, stream
+            );
+            if (matmul_status == BNB_BLAS_STATUS_SUCCESS)
+                break;
+        }
+        if (matmul_status != BNB_BLAS_STATUS_SUCCESS) {
+            // Every workspace-free algo hipBLASLt offered failed at runtime
+            // (seen on MI300X for some small-n int8 gemms). Drain the HIP
+            // last-error flag the failed launches set, otherwise the next
+            // unrelated HIP call will inherit it. Then signal the Python
+            // wrapper to take the fp32 fallback via ERR_NOT_IMPLEMENTED.
+            (void)hipGetLastError();
+            return ERR_NOT_IMPLEMENTED;
+        }
     }
 #else
     int alpha = 1, beta = 0;
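
Seen from the Python side, an affected shape should now warn and return a usable result instead of raising NotImplementedError. A hypothetical repro sketch, assuming a ROCm build where hipBLASLt rejects every offered algo for the shape, and assuming the torch.ops entry point registered by the multi-backend refactor:

import warnings

import torch
import bitsandbytes  # noqa: F401  (importing registers the bitsandbytes:: ops)

A = torch.randint(-128, 128, (16, 64), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 128, (4, 64), dtype=torch.int8, device="cuda")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = torch.ops.bitsandbytes.int8_linear_matmul(A, B)

# On an affected device the fallback now warns instead of raising.
if any(issubclass(w.category, RuntimeWarning) for w in caught):
    print("fp32 fallback path was taken")
print(out.shape, out.dtype)  # expected: torch.Size([16, 4]) torch.int32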
