Motivation.
bitsandbytes and GGUF are two quantization/format backends in vLLM that see very low usage relative to the maintenance burden they impose (roughly 0.5% and 0.1% respectively from what I can tell).
Both predate the current weight loading architecture (`weight_loader_v2`) and have not been migrated to it. They inject conditional branches throughout the critical weight-loading path in shared code (`linear.py`, `fused_moe/layer.py`, `vocab_parallel_embedding.py`) in ways that make the codebase harder to maintain and refactor.
In addition, performance with these backends is not great; users often report that llama.cpp runs GGUF models faster than vLLM, since llama.cpp prioritizes bs=1 performance on consumer GPUs.
This RFC proposes deprecating both backends and eventually removing them, to simplify the core weight loading infrastructure and unblock further cleanup.
If we had to choose one over the other, I think removing GGUF should take priority, since BNB sees more usage. Another option is to move these methods out to out-of-tree (OOT) quantization plugins, but I doubt that is feasible given how much they need to reach into vLLM's internal structures.
Summary
| | bitsandbytes | GGUF |
|---|---|---|
| Dedicated Python | ~1,426 lines | ~1,464 lines |
| CUDA kernels | 0 | ~6,000 lines |
| Shared code branches | ~95 lines in 6 locations | ~75 lines in 5 locations |
| `weight_loader_v2` | not supported | not supported |
| TP support | limited (pre-quant doesn't work) | full |
| CUDA graph support | 8-bit forces eager | full |
| External dep | `bitsandbytes` pip package | `gguf` pip package |
| Model-specific hacks | 3 models | 8+ models |
Both formats add ~3,100 lines of dedicated Python, ~170 lines of branching in shared weight loading code, and block migration to weight_loader_v2. GGUF additionally carries ~6,000 lines of CUDA kernels.
The primary benefit of removal isn't the line count; it's making linear.py's weight loading methods readable and refactorable again, and unblocking the weight_loader_v2 migration.
Codebase cost
Dedicated files
These are self-contained and could be deleted as units:
| File | Lines | Purpose |
|---|---|---|
| `quantization/bitsandbytes.py` | 609 | Config, LinearMethod (4bit/8bit), MoEMethod |
| `model_loader/bitsandbytes_loader.py` | 817 | Full model loader with TP sharding, quant state mgmt, on-the-fly quantization |
| `quantization/gguf.py` | 691 | Config, LinearMethod, MoEMethod, EmbeddingMethod, kernel dispatch |
| `model_loader/gguf_loader.py` | 437 | Model loader, GGUF file discovery, tensor name mapping |
| `transformers_utils/gguf_utils.py` | 336 | GGUF detection, remote download, config patching |
| **Total** | ~2,890 | |
Also ~6,000 lines of GGUF-specific CUDA kernels in `csrc/quantization/gguf/` (a partial port of ggml ops).
Conditional branches in shared code
This is the real problem. Both formats add if branches in the hot path of weight loading that every other quantization method has to read around.
linear.py — the worst offender
bitsandbytes adds branches in 6 locations (~95 lines):
- `adjust_bitsandbytes_4bit_shard()` — a top-level helper that only exists for bnb
- `ColumnParallelLinear.weight_loader` — overloads `is_sharded_weight` with `use_bitsandbytes_4bit`
- `MergedColumnParallelLinear.weight_loader` — builds an offsets dict and calls `adjust_bitsandbytes_4bit_shard()`, duplicated for both the fused and per-shard paths
- `QKVParallelLinear.weight_loader` — same pattern again, duplicated for both paths
- `RowParallelLinear.weight_loader` — overloads `is_sharded_weight` again
The bnb pattern is essentially copy-pasted 4 times: build an offsets dict mapping shard IDs to original sizes, then call `adjust_bitsandbytes_4bit_shard()` to recompute those offsets in packed uint8 space.
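To make the repetition concrete, here is a minimal sketch of that logic, not the vLLM implementation: the helper name is taken from above, but the signature and exact rescaling are assumptions (4-bit weights pack two elements per uint8, so offsets and sizes get divided by the pack factor).

```python
# Hypothetical stand-in for adjust_bitsandbytes_4bit_shard(); the real helper in
# linear.py may differ in signature and details.
def adjust_bitsandbytes_4bit_shard_sketch(
    shard_offsets: dict[str, tuple[int, int]],
    pack_factor: int = 2,
) -> dict[str, tuple[int, int]]:
    # Rescale per-shard (offset, size) from original element space into packed
    # uint8 space, since two 4-bit values share one byte.
    return {
        shard_id: (offset // pack_factor, size // pack_factor)
        for shard_id, (offset, size) in shard_offsets.items()
    }

# Roughly what each affected weight_loader repeats today (illustrative only):
#   if getattr(param, "use_bitsandbytes_4bit", False):
#       offsets = {"q": (0, q_size), "k": (q_size, kv_size), "v": (q_size + kv_size, kv_size)}
#       offsets = adjust_bitsandbytes_4bit_shard_sketch(offsets)
```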
GGUF adds branches in 5 locations (~75 lines):
- `ReplicatedLinear.weight_loader` — `is_gguf_weight` / `is_gguf_weight_type` checks + materialize `UninitializedParameter`
- `ColumnParallelLinear.weight_loader` — same pattern
- `MergedColumnParallelLinear.weight_loader` — weight type dict, shard_id tracking, `data_container` append
- `QKVParallelLinear.weight_loader` — same, with a q/k/v index map
- `RowParallelLinear.weight_loader` — same materialize pattern
GGUF uses UninitializedParameter + a data_container list + shard_id_map — a lazy-init approach that forces every weight_loader to have special materialization logic.
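A minimal sketch of what that materialization step looks like with a plain PyTorch `UninitializedParameter` (the real loaders also track `data_container` and `shard_id_map` state, omitted here):

```python
import torch
from torch.nn.parameter import UninitializedParameter

def gguf_materialize_and_copy(param, loaded_weight: torch.Tensor) -> None:
    # The GGUF parameter has no shape until the tensor arrives, so it must be
    # materialized before copying; no other quant method needs this step.
    if isinstance(param, UninitializedParameter):
        param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
    param.data.copy_(loaded_weight)
```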
fused_moe/layer.py
The weight_loader method has two early-return blocks before the normal loading path (rough shape sketched after this list):
- GGUF (~10 lines): `is_gguf_weight_type` check + `UninitializedParameter` materialization for MoE experts
- bnb (~35 lines): flat-packed BNB tensor handling with special w1/w2/w3 logic
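The rough shape of that control flow, as a sketch rather than the real `FusedMoE.weight_loader` (the guard conditions here are assumptions based on the attribute names above):

```python
def moe_weight_loader_sketch(param, loaded_weight, weight_name, shard_id, expert_id):
    if getattr(param, "is_gguf_weight_type", False):
        # GGUF early return (~10 lines in the real code): materialize the
        # UninitializedParameter for this expert and record the weight type.
        return
    if getattr(param, "bnb_quant_state", None) is not None:
        # bnb early return (~35 lines): flat-packed tensor handling with special
        # w1/w2/w3 logic.
        return
    # Normal loading path: every other quantization method only ever reaches here.
```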
vocab_parallel_embedding.py
- GGUF: `is_gguf_weight_type` direct copy in `weight_loader`, bypassing the normal shard logic
- GGUF: `tie_weights()` returns `embed_tokens` instead of `self`, because quantized embeddings can't share raw weight tensors
config/model.py
- `_verify_bnb_config()`: 25 lines to force eager mode because bnb 8-bit doesn't support CUDA graphs
engine/arg_utils.py
- Auto-detection overrides for both formats: `if is_gguf(self.model): self.quantization = self.load_format = "gguf"`, and the equivalent for bnb (sketched below)
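For illustration, the override amounts to something like the following sketch; the `.gguf` suffix check is an assumption here, not necessarily how the real `is_gguf` helper detects the format.

```python
from pathlib import Path

def is_gguf_sketch(model: str) -> bool:
    # Assumption for this sketch: detect GGUF checkpoints by file extension.
    return Path(model).suffix == ".gguf"

def apply_auto_detection(args) -> None:
    # Mirrors the override described above; the bnb branch is analogous.
    if is_gguf_sketch(args.model):
        args.quantization = args.load_format = "gguf"
```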
Neither supports weight_loader_v2
linear.py has a WEIGHT_LOADER_V2_SUPPORTED allowlist. Neither BitsAndBytesLinearMethod nor GGUFLinearMethod is on it — they both use the legacy weight_loader path. This means any effort to migrate the codebase to the cleaner v2 API has to keep the old code path alive for these two backends.
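The gating works roughly like this simplified sketch; the allowlist entries shown are examples, not the full list.

```python
# Quant methods on the allowlist get the v2 loading machinery; everything else,
# including BitsAndBytesLinearMethod and GGUFLinearMethod, keeps the legacy path.
WEIGHT_LOADER_V2_SUPPORTED = {
    "CompressedTensorsLinearMethod",  # example entry
    "Fp8LinearMethod",                # example entry
}

def uses_weight_loader_v2(quant_method_cls_name: str) -> bool:
    return quant_method_cls_name in WEIGHT_LOADER_V2_SUPPORTED
```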
Additional GGUF-specific complexity
- `gguf_loader.py` instantiates a dummy HuggingFace model on the meta device to extract parameter names for tensor mapping (lines 219-227). This is fragile and breaks when HF model classes change (see the sketch after this list).
- The loader has ~70 lines of hardcoded model-type name remapping (deepseek_v2/v3, qwen2/3_moe, minimax_m2, cohere, gemma3) that must be updated for each new MoE architecture.
- `transformers_utils/gguf_utils.py` adds config patching (`maybe_patch_hf_config_from_gguf`) and tokenizer extraction from the GGUF container.
- ~8 model files (llama, llama4, gemma3, exaone, etc.) have GGUF-specific RoPE style detection branches.
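The meta-device trick looks roughly like the sketch below; the `AutoModelForCausalLM` call is an approximation of what `gguf_loader.py` actually does.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

def hf_param_names(model_id: str) -> list[str]:
    # Build the HF model on the meta device so no weights are allocated; the only
    # goal is to enumerate parameter names that GGUF tensor names get mapped onto.
    config = AutoConfig.from_pretrained(model_id)
    with torch.device("meta"):
        dummy = AutoModelForCausalLM.from_config(config)
    return [name for name, _ in dummy.named_parameters()]
```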
Additional bnb-specific complexity
- `bitsandbytes_loader.py` has its own TP sharding logic in `_unquantized_generator` (110 lines) that reimplements what the linear layer weight loaders already do.
- The loader attaches runtime state as parameter attributes (`bnb_quant_state`, `bnb_shard_offsets`, `matmul_state`) which the quantization method reads during inference. This attribute-passing pattern is unique to bnb and forces checks in every weight loading path (illustrated after this list).
- MoE quant state fusion (`_fuse_moe_quant_states`, 80 lines) manually merges per-expert quant states into fused w13/w2 representations.
- Pre-quantized bnb models don't work with tensor parallelism at all (hard error at lines 551-555).
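To illustrate the attribute-passing pattern (attribute names from the list above; the values here are placeholders rather than real quant state objects):

```python
import torch

weight = torch.nn.Parameter(torch.empty(4, 4, dtype=torch.uint8), requires_grad=False)

# The bnb loader hangs runtime state directly off the parameter object...
weight.bnb_quant_state = {}        # per-shard quant state (placeholder)
weight.bnb_shard_offsets = [0, 4]  # offsets into the flat packed buffer (placeholder)

# ...and the quantization method reads it back at inference time, so every shared
# weight_loader has to know these attributes exist and preserve them.
assert hasattr(weight, "bnb_quant_state")
```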
Proposed Change.
linear.py weight_loader cleanup
Remove ~170 lines of conditional branching across the 4 parallel linear classes. The weight_loader methods become straightforward: determine output/input dim, narrow, copy. No more `adjust_bitsandbytes_4bit_shard()`, no more `UninitializedParameter` materialization, no more `data_container` tracking.
This is the biggest win — these methods are read and modified by anyone working on a new quantization backend, and the bnb/GGUF branches are confusing because they work completely differently from every other quant method.
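As a sketch of the end state, a column-parallel loader without those branches could look roughly like this (assuming the parameter is already allocated at its sharded size; not the final vLLM code):

```python
import torch

def column_parallel_weight_loader_sketch(param: torch.nn.Parameter,
                                         loaded_weight: torch.Tensor,
                                         tp_rank: int,
                                         output_dim: int = 0) -> None:
    # Determine this rank's slice along the output dim, narrow, copy. No format-
    # specific materialization or offset rewriting.
    shard_size = param.data.shape[output_dim]
    start_idx = tp_rank * shard_size
    param.data.copy_(loaded_weight.narrow(output_dim, start_idx, shard_size))
```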
weight_loader_v2 migration
With bnb and GGUF gone, the legacy weight_loader path could potentially be removed entirely (or at least simplified), since the remaining quant methods are all on the v2 allowlist or could be migrated.
fused_moe/layer.py simplification
Remove ~45 lines of early-return branches from the weight_loader. The control flow becomes linear.
Model loader factory
Remove 2 of ~6 loader classes. The dispatch logic in model_loader/__init__.py gets simpler.
Config / arg_utils
Remove auto-detection branches, CUDA graph workarounds, and bnb/GGUF-specific validation.
Build system
Drop ~6,000 lines of CUDA kernels from csrc/quantization/gguf/ and the corresponding CMakeLists entry. Faster builds.
Dependencies
Drop bitsandbytes and gguf as pip dependencies.
Feedback Period.
Two weeks
CC List.
@robertgshaw2-redhat @simon-mo @Isotr0py @DarkLight1337
Any Other Things.
No response
Before submitting a new issue...