
[ROCm] Try multiple hipBLASLt heuristic algos and fall back to fp32 for unsupported int8 shapes #1934

Open

Abdennacer-Badaoui wants to merge 1 commit into bitsandbytes-foundation:main from Abdennacer-Badaoui:fix-hipblast-algos

Conversation

@Abdennacer-Badaoui
Member

Description

On AMD MI300X (gfx942), igemmlt's HIP path was crashing for certain small-batch int8 GEMM shapes (e.g. m=2048, n=4, k=6144 from BLOOM's QKV projection).

Root cause: With max_workspace_size = 0, hipBLASLt's heuristic returns very few candidate algorithms (3 in our case), and on MI300X all of them fail at hipblasLtMatmul time with HIPBLAS_STATUS_INVALID_VALUE. The code requested only a single candidate (request_solutions = 1) and used it unconditionally, so the failure was unrecoverable.
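For orientation, here is a minimal sketch of the failing pattern, assuming the standard hipBLASLt heuristic API; the handle and descriptor names are placeholders, not the exact bitsandbytes source:

```cpp
#include <hipblaslt/hipblaslt.h>

// Pre-PR pattern (illustrative): request exactly one candidate algorithm
// from the heuristic and use it unconditionally.
hipblasLtMatmulAlgo_t pick_single_algo(hipblasLtHandle_t handle,
                                       hipblasLtMatmulDesc_t op_desc,
                                       hipblasLtMatrixLayout_t a_desc,
                                       hipblasLtMatrixLayout_t b_desc,
                                       hipblasLtMatrixLayout_t c_desc,
                                       hipblasLtMatmulPreference_t pref) {
    hipblasLtMatmulHeuristicResult_t result = {};
    int returned = 0;
    hipblasLtMatmulAlgoGetHeuristic(handle, op_desc, a_desc, b_desc, c_desc, c_desc,
                                    pref, /*requestedAlgoCount=*/1, &result, &returned);
    // result.algo goes straight into the matmul call; on MI300X that single
    // algo can fail with HIPBLAS_STATUS_INVALID_VALUE, and nothing catches it.
    return result.algo;
}
```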

This PR makes the HIP path resilient, in three steps (a combined sketch of the new control flow follows the list):

1. Try multiple algos
Bump request_solutions from 1 to 8 and loop through the heuristic results, accepting the first one that actually runs (i.e., bnb_blasLtMatmul returns BNB_BLAS_STATUS_SUCCESS). Most shapes succeed on algo 0; the loop only kicks in for pathological cases.

2. Drain the HIP error flag
Call hipGetLastError() after a fully failed loop so the next unrelated HIP call doesn't inherit a stale "invalid device ordinal" error.

3. Signal Python to fall back to fp32
Return ERR_NOT_IMPLEMENTED (100) when no algo works. The Python wrapper in bitsandbytes/backends/cuda/ops.py previously raised NotImplementedError on code 100; this PR replaces that with the same fp32 torch.matmul fallback already used for lda % 4 != 0, plus a one-time RuntimeWarning. This implements the existing TODO in the file: "Warn and implement a fallback to fp32 compute?"
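Putting the three steps together, here is a hedged sketch of the new native control flow. request_solutions, ERR_NOT_IMPLEMENTED (100), bnb_blasLtMatmul, and the hipGetLastError() drain are from this PR; every parameter name is a placeholder, and hipblasLtMatmul stands in for the bnb_blasLtMatmul wrapper, whose exact signature isn't shown here:

```cpp
#include <hip/hip_runtime.h>
#include <hipblaslt/hipblaslt.h>

// Return code from the PR description; reported to Python as 100.
constexpr int ERR_NOT_IMPLEMENTED = 100;

// Hedged sketch, not the exact bitsandbytes source.
int try_igemmlt_hip(hipblasLtHandle_t handle, hipblasLtMatmulDesc_t op_desc,
                    const void* alpha, const void* A, hipblasLtMatrixLayout_t a_desc,
                    const void* B, hipblasLtMatrixLayout_t b_desc,
                    const void* beta, void* C, hipblasLtMatrixLayout_t c_desc,
                    hipblasLtMatmulPreference_t pref, hipStream_t stream) {
    constexpr int request_solutions = 8;  // bumped from 1
    hipblasLtMatmulHeuristicResult_t results[request_solutions] = {};
    int found = 0;
    // With max_workspace_size = 0 the heuristic may return only a handful
    // of candidates; ask for up to 8 and take however many come back.
    hipblasLtMatmulAlgoGetHeuristic(handle, op_desc, a_desc, b_desc, c_desc, c_desc,
                                    pref, request_solutions, results, &found);

    for (int i = 0; i < found; ++i) {
        // Accept the first candidate algorithm that actually runs.
        if (hipblasLtMatmul(handle, op_desc, alpha, A, a_desc, B, b_desc, beta,
                            C, c_desc, C, c_desc, &results[i].algo,
                            /*workspace=*/nullptr, /*workspaceSizeInBytes=*/0,
                            stream) == HIPBLAS_STATUS_SUCCESS)
            return 0;
    }
    // Every candidate failed: drain the sticky HIP error flag so the next
    // unrelated HIP call doesn't inherit it, then signal Python to fall back.
    (void)hipGetLastError();
    return ERR_NOT_IMPLEMENTED;
}
```

The retry loop is effectively free in the common case, since the first candidate usually succeeds; only the pathological shapes pay for extra attempts before returning code 100 and dropping to the fp32 path.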

Note: The CUDA path is untouched; every change is gated by #if BNB_HIP.

Context

Found while validating accelerate end-to-end on 8× MI300X (ROCm 7.1, PyTorch 2.8.0+rocm7.1.0), specifically:

tests/test_quantization.py::MixedInt8EmptyModelTest::test_cpu_gpu_disk_loading_custom_device_map_kwargs
This test exercises bnb 8-bit inference on BLOOM. Without this fix, the int8 path is unusable for any model whose QKV projection produces this shape on MI300X. With it, the test passes — small-n cases silently degrade to fp32 (slower but correct) instead of crashing.
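For anyone reproducing: assuming a standard pytest setup in an accelerate checkout, the failing test should be runnable directly as:

```
pytest tests/test_quantization.py::MixedInt8EmptyModelTest::test_cpu_gpu_disk_loading_custom_device_map_kwargs
```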

@github-actions

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
