
Commit 68f8ef8

Update DMSPress and KVzapPress (#177)

1 parent 8b3c2f7 commit 68f8ef8

File tree

9 files changed: +22 −33 lines changed

.github/workflows/test.yml

Lines changed: 7 additions & 15 deletions
@@ -11,27 +11,19 @@ jobs:
     runs-on: linux-amd64-gpu-l4-latest-1
     steps:
       - uses: actions/checkout@v3
-      - name: Setup Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: 3.10.11
-
-      - name: Setup CUDA
-        uses: Jimver/cuda-toolkit@v0.2.16
-        with:
-          cuda: '12.5.0'
-
-      - name: Set CUDA_HOME
-        run: echo "CUDA_HOME=/usr/local/cuda" >> $GITHUB_ENV
-
+      - name: Verify environment
+        run: |
+          nvidia-smi
+          python3 --version
       - name: Install uv
         uses: astral-sh/setup-uv@v6
         with:
           enable-cache: true

       - name: Install dependencies
-        run: uv sync --all-groups
-
+        run: |
+          uv sync --all-groups
+          uv pip install torch==2.10
       - run: make test
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}

Makefile

Lines changed: 1 addition & 7 deletions
@@ -41,20 +41,14 @@ reports:

 .PHONY: test
 test: reports
-	$(UV) pip install optimum-quanto
-	$(UV) pip install flash-attn
+	$(UV) pip install flash-attn --no-build-isolation --find-links https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/expanded_assets/v0.7.12
 	PYTHONPATH=. \
 	$(UV) run pytest \
 		--cov-report xml:reports/coverage.xml \
 		--cov=kvpress/ \
 		--junitxml=./reports/junit.xml \
 		-v \
 		tests/ | tee reports/pytest_output.log
-	@if grep -q "SKIPPED" reports/pytest_output.log; then \
-		echo "Error: Tests were skipped. All tests must run."; \
-		grep "SKIPPED" reports/pytest_output.log; \
-		exit 1; \
-	fi
 	@if grep -q "FAILED" reports/pytest_output.log; then \
 		echo "Error: Some tests failed."; \
 		grep "FAILED" reports/pytest_output.log; \

README.md

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ Finally we provide wrapper presses that can be combined with other presses:
 - `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments the input sequence into non-overlapping blocks and compresses iteratively.
 - `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allows for compression during decoding, see the decoding section in this README.
 - `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allows compression both during prefilling and during decoding.
-- `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evicts keys and values whose scores from any `ScorerPress` fall below a given threshold, instead of relying on top-k scores. Supports both prefilling and decoding (if decoding=True).
+- `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evicts keys and values whose scores from any `ScorerPress` fall below a given threshold, instead of relying on top-k scores. Supports both prefilling and decoding (if decoding=True), but only supports dense prefill, not sparse prefill.

 For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
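The difference between threshold-based eviction (as `DMSPress` describes) and fixed top-k selection can be sketched in a few lines. This is a minimal illustration with hypothetical scores, not the actual kvpress implementation:

```python
import numpy as np

# Hypothetical (heads, seq_len) array of ScorerPress importance scores.
scores = np.array([[0.9, 0.1, 0.6, 0.05, 0.8]])

# Threshold-based (DMSPress style): the number of kept KV pairs adapts to the
# scores, so the effective compression ratio varies per input.
threshold = 0.5
keep_mask = scores >= threshold  # keeps positions 0, 2, 4

# Top-k (fixed compression_ratio): always keeps exactly k pairs, regardless
# of how the scores are distributed.
k = 3
topk_idx = np.argsort(scores, axis=-1)[:, -k:]  # indices of the k largest scores
```

With these particular scores both strategies keep the same three positions, but shifting the threshold (or the score distribution) changes how many pairs the threshold variant retains, while top-k never does.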

kvpress/pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -311,7 +311,7 @@ def generate_answer(
             generated_ids.append(new_id)
             if new_id.item() in should_stop_token_ids:
                 break
-        answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)
+        answer = str(self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True))
         return answer

     def postprocess(self, model_outputs, single_question):

kvpress/presses/dms_press.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ class DMSPress(BasePress):
     """
     Based on Dynamic Memory Sparsification (DMS, https://arxiv.org/abs/2506.05345) inference.
     Wraps a ScorerPress and evicts keys/values with scores below a given threshold.
+    This press implements a dense-prefill version of DMS, not the sparse-prefill version.

     Unlike most presses that use a fixed compression_ratio, DMSPress uses a score threshold
     to determine which KV pairs to evict. This allows for adaptive compression where the actual

kvpress/presses/kvzap_press.py

Lines changed: 4 additions & 4 deletions
@@ -24,6 +24,7 @@ class KVzapModel(PreTrainedModel):

     def __init__(self, config):
         super().__init__(config)
+        self.all_tied_weights_keys = {}
         if config.hidden_dim is None:
             # Linear model
             self.layers = nn.ModuleList(
@@ -72,8 +73,7 @@ def score(
         attentions: torch.Tensor,
         kwargs: dict,
     ) -> torch.Tensor:
-        module = self.kvzap_model.layers[module.layer_idx]
-        module = module.to(hidden_states.device, dtype=hidden_states.dtype).eval()
-        with torch.no_grad():
-            scores = module(hidden_states).transpose(1, 2)
+        kvzap_module = self.kvzap_model.layers[module.layer_idx]
+        kvzap_module = kvzap_module.to(hidden_states.device, dtype=hidden_states.dtype).eval()
+        scores = kvzap_module(hidden_states).transpose(1, 2)
         return scores
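The rename in the hunk above avoids rebinding the hook's `module` argument, which refers to the attention layer. A tiny illustration with hypothetical stand-in classes (not the real kvpress types) shows why the original `module` must stay accessible:

```python
# Hypothetical stand-ins for the attention layer and the KVzap scorer layer.
class AttentionLayer:
    layer_idx = 3

class KVzapLayer:
    def __call__(self, hidden_states):
        return hidden_states

def score(module, hidden_states):
    # Binding the scorer to a new name (kvzap_module) keeps `module` pointing
    # at the attention layer, so attributes like layer_idx remain usable
    # later in the function. Rebinding `module` would shadow the argument.
    kvzap_module = KVzapLayer()
    scores = kvzap_module(hidden_states)
    return module.layer_idx, scores

idx, scores = score(AttentionLayer(), [1.0, 2.0])
```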

kvzap/data.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def _forward_hook(self, module, input, kwargs, output):
         scale = scale.repeat_interleave(module.o_proj.block_size[0], dim=0)
         scale = scale.repeat_interleave(module.o_proj.block_size[1], dim=1)
         Wo = Wo.to(V.dtype) * scale
-        Wo = Wo.view(module.config.num_attention_heads, module.head_dim, module.config.hidden_size)
+        Wo = Wo.view(module.config.num_attention_heads, V.shape[-1], module.config.hidden_size)
         WoV_norm = torch.einsum("h i j, b h t i -> b h t j", Wo.to(dtype=V.dtype), V).norm(dim=-1)
         scores = torch.einsum("b h t i, b h i -> b h t i", scores, WoV_norm)
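The reshape in this hunk splits the output projection into per-head matrices; reading the per-head value dimension from `V.shape[-1]` keeps the view consistent with the actual value tensor rather than a config attribute. A NumPy sketch with hypothetical sizes (not the real model dimensions):

```python
import numpy as np

num_heads, v_dim, hidden_size = 4, 8, 32  # hypothetical; hidden = heads * v_dim
Wo = np.random.randn(hidden_size, hidden_size)  # flat output projection weight
V = np.random.randn(1, num_heads, 10, v_dim)    # (batch, heads, tokens, v_dim)

# Split Wo into one (v_dim, hidden_size) matrix per head, sized from V itself.
Wo_heads = Wo.reshape(num_heads, V.shape[-1], hidden_size)

# Same contraction as the diff: project each head's values, then take norms.
WoV = np.einsum("h i j, b h t i -> b h t j", Wo_heads, V)
WoV_norm = np.linalg.norm(WoV, axis=-1)  # (batch, heads, tokens)
```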

kvzap/train.py

Lines changed: 4 additions & 2 deletions
@@ -115,8 +115,10 @@ def train_linear(X: torch.Tensor, y: torch.Tensor) -> KVzapModel:
         KVzapConfig(input_dim=X.shape[2], hidden_dim=None, output_dim=y.shape[2], n_modules=X.shape[1])
     )
     for layer_idx, (W, b) in enumerate(params):
-        linear_model.layers[layer_idx].weight.data = torch.tensor(W, dtype=X.dtype)  # type: ignore[index]
-        linear_model.layers[layer_idx].bias.data = torch.tensor(b, dtype=X.dtype)  # type: ignore[index]
+        W = torch.tensor(np.atleast_2d(W), dtype=X.dtype)
+        b = torch.tensor(np.atleast_1d(b), dtype=X.dtype)
+        linear_model.layers[layer_idx].weight.data = W  # type: ignore[index]
+        linear_model.layers[layer_idx].bias.data = b  # type: ignore[index]
     return linear_model
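The `np.atleast_2d` / `np.atleast_1d` calls in this hunk matter because a least-squares fit with a single output dimension typically yields 1-D weights and a scalar bias, while an `nn.Linear` layer expects a 2-D `(out_features, in_features)` weight and a 1-D bias. A NumPy-only sketch with hypothetical values:

```python
import numpy as np

W = np.array([0.5, -0.2, 1.0])  # 1-D: one output feature, three inputs
b = np.float64(0.1)             # 0-D scalar bias

W2d = np.atleast_2d(W)  # shape (1, 3): valid nn.Linear weight layout
b1d = np.atleast_1d(b)  # shape (1,):   valid nn.Linear bias layout
```

Without the promotion, assigning the raw 1-D weight to `weight.data` would silently change the parameter's shape and break later forward passes.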

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "kvpress"
-version = "0.4.2"
+version = "0.4.3"
 description = "Efficiently compress the KV cache of any pretrained transformer"
 authors = [
     { name = "Simon Jegou" },
@@ -15,7 +15,7 @@ dependencies = [
     "torch>=2.3.1,<3",
     # transformers<4.54 is not supported due to refactoring of the transformers library.
     # transformers 4.54-4.55.2 are not compatible with kvpress due to flash attention bugs in transformers
-    "transformers>=4.56",
+    "transformers>=4.56,<5.0.0",
     "sentencepiece>=0.2.0,<0.3",
     "protobuf>=5.27.2,<6",
     "datasets>=2.21.0,<3",
