Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
- `StreamingLLMPress` ([source](kvpress/presses/streaming_llm_press.py), [paper](https://arxiv.org/abs/2309.17453)): keep only the initial and recent tokens
- `TOVAPress` ([source](kvpress/presses/tova_press.py), [paper](https://arxiv.org/abs/2401.06104)): attention weight of the last query averaged across heads
- `ObservedAttentionPress` ([source](kvpress/presses/observed_attention_press.py), [paper](https://arxiv.org/abs/2306.14048)): average attention weight observed during the pre-filling phase
- `QFilterPress` ([source](kvpress/presses/qfilter_press.py), [paper](https://arxiv.org/abs/2503.02812)): project the Key representations on the main SVD component of the Query vectors to approximate the attention scores.

Some presses rely on a different logic:
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
Expand All @@ -81,6 +82,7 @@ Finally we provide wrapper presses that can be combined with other presses:
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can lead to more uniform compression across long sequences
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.


For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

## Evaluation
Expand Down
2 changes: 2 additions & 0 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.qfilter_press import QFilterPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()
Expand All @@ -49,4 +50,5 @@
"ChunkPress",
"DuoAttentionPress",
"ChunkKVPress",
"QFilterPress",
]
58 changes: 58 additions & 0 deletions kvpress/presses/qfilter_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
from dataclasses import dataclass

import torch
from huggingface_hub import PyTorchModelHubMixin, get_collection

from kvpress.presses.scorer_press import ScorerPress


class QFilters(torch.nn.Module, PyTorchModelHubMixin):
    """Container module holding the learned Q-filter vectors.

    The single parameter `q_filters` has shape
    (num_layers, num_kv_heads, kv_head_dim): one filter vector per KV head
    per layer, initialized randomly and expected to be loaded from the Hub.
    """

    def __init__(self, num_layers: int, num_kv_heads: int, kv_head_dim: int):
        super().__init__()
        filter_shape = (num_layers, num_kv_heads, kv_head_dim)
        self.q_filters = torch.nn.Parameter(torch.randn(*filter_shape))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path):
        """Load pretrained filters from the Hugging Face Hub."""
        return super().from_pretrained(pretrained_model_name_or_path)


@dataclass
class QFilterPress(ScorerPress):
    """
    Prune KV pairs with Q-filters (https://arxiv.org/abs/2503.02812).

    Keys are scored by projecting them on a pre-computed filter direction
    (the main SVD component of the query vectors, per layer and KV head),
    which approximates the attention scores without access to the queries.
    The filters are model-specific and fetched from the Hugging Face Hub.
    """

    def __post_init_from_model__(self, model):
        # Filters are published per model name (the part after the last "/"),
        # so strip any organization prefix. Cast to the model dtype now;
        # the device transfer is deferred to `score` (per-layer, per-call).
        model_name = model.config.name_or_path.split("/")[-1]
        self.q_filters = self.load_q_filters(model_name).to(model.dtype)

    @staticmethod
    def load_q_filters(model_name):
        """Load the pre-trained Q-filters tensor for `model_name` from the Hub.

        Raises:
            ValueError: if no Q-filter checkpoint exists for this model,
                listing the models that do have one.
        """
        try:
            return QFilters.from_pretrained(f"nthngdy/{model_name}_qfilt").q_filters
        except TypeError as e:
            # Chain the original exception so the root cause is not lost.
            raise ValueError(
                f"Could not load Q-filters for {model_name}. Available models: {QFilterPress.available_qfilters()}"
            ) from e

    @staticmethod
    def available_qfilters():
        """List model names with a published Q-filter checkpoint in the Hub collection."""
        collection = get_collection("nthngdy/q-filters-67a4994dcb302a3d37f3d119", token=False)
        # Item ids look like "nthngdy/<model_name>_qfilt": drop the namespace
        # and the trailing "_qfilt" (6 characters).
        return [x.item_id.split("/")[-1][:-6] for x in collection.items]

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        # (num_kv_heads, head_dim) -> (1, num_kv_heads, 1, head_dim) so it
        # broadcasts against keys — presumably (batch, num_kv_heads, seq_len,
        # head_dim) as elsewhere in kvpress; TODO confirm.
        q_filter = self.q_filters[module.layer_idx][None, :, None]
        q_filter = q_filter.to(keys.device)
        # Negated projection on the filter direction: a larger score marks a
        # more important KV pair for the ScorerPress selection.
        scores = -(q_filter * keys).sum(dim=-1)
        return scores

    @contextmanager
    def __call__(self, model):
        # Filters depend on the wrapped model, so load them lazily here
        # before delegating to the standard ScorerPress context manager.
        self.__post_init_from_model__(model)
        with super().__call__(model):
            yield
2 changes: 2 additions & 0 deletions tests/default_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
StreamingLLMPress,
ThinKPress,
TOVAPress,
QFilterPress,
)


Expand All @@ -31,6 +32,7 @@ def load_attention_pattern(model):
{"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": QFilterPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{
"cls": SnapKVPress,
"kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}],
Expand Down
8 changes: 8 additions & 0 deletions tests/presses/test_qfilters_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from kvpress.presses.qfilter_press import QFilterPress


def test_load_qfilters():
    """Every model advertised by `available_qfilters` must actually load."""
    advertised = QFilterPress.available_qfilters()
    for name in advertised:
        QFilterPress.load_q_filters(name)