Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
- `ObservedAttentionPress` ([source](kvpress/presses/observed_attention_press.py), [paper](https://arxiv.org/abs/2306.14048)): average attention weight observed during in pre-filling phase
- `QFilterPress` ([source](kvpress/presses/qfilter_press.py), [paper](https://arxiv.org/abs/2503.02812)): project the Key representations on the main SVD component of the Query vectors to approximate the attention scores.
- `PyramidKVPress` ([source](kvpress/presses/pyramidkv_press.py), [paper](https://arxiv.org/abs/2406.02069)): maintain pyramid-like cache sizes, allocating more cache budget to lower layers and less to higher layers
- `LagKVPress` ([source](kvpress/presses/lagkv_press.py), [paper](https://arxiv.org/abs/2504.04704)): leverage on the KV lag-relative information to compress. It's query free, attention-weight free, and flash-attention compatible.

Some presses rely on a different logic:
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
Expand Down
118 changes: 60 additions & 58 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from kvpress.attention_patch import patch_attention_functions
from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.chunk_press import ChunkPress
from kvpress.presses.chunkkv_press import ChunkKVPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.finch_press import FinchPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()

__all__ = [
"CriticalAdaKVPress",
"CriticalKVPress",
"AdaKVPress",
"BasePress",
"ComposedPress",
"ScorerPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
"RandomPress",
"SimLayerKVPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"KeyRerotationPress",
"ChunkPress",
"DuoAttentionPress",
"ChunkKVPress",
"QFilterPress",
"PyramidKVPress",
"FinchPress",
]
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from kvpress.attention_patch import patch_attention_functions
from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.chunk_press import ChunkPress
from kvpress.presses.chunkkv_press import ChunkKVPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.lagkv_press import LagKVPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()

__all__ = [
"CriticalAdaKVPress",
"CriticalKVPress",
"AdaKVPress",
"BasePress",
"ComposedPress",
"ScorerPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
"RandomPress",
"SimLayerKVPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"KeyRerotationPress",
"ChunkPress",
"DuoAttentionPress",
"ChunkKVPress",
"QFilterPress",
"PyramidKVPress",
"FinchPress",
"LagKVPress",
]
83 changes: 83 additions & 0 deletions kvpress/presses/lagkv_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.scorer_press import ScorerPress


@dataclass
class LagKVPress(ScorerPress):
"""
Prune KV pairs with lag-relative information (https://arxiv.org/abs/2504.04704)

Args:
n_sink (`int`):
The number of sink tokens.
lag_size (`int`):
The size of the partition. The subsequent partition will serve as a reference for the prior one.
cross_scoring (`bool`):
(experimental) if cross scoring is enabled, the score will not be limited to inside partion.
Since the score is totally normalized, it's possible use it to allocating KV among heads.
This switch will be useful for Press Wrapper like AdaKVPress.
"""
n_sink: int = 4
lag_size: int = 128
cross_scoring: bool = False

def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
bsz, num_key_value_heads, q_len, d = keys.shape
if q_len < self.n_sink + 2 * self.lag_size:
# no compression
score = torch.ones((bsz, num_key_value_heads, q_len),
dtype=keys.dtype, device=keys.device)
if q_len > self.n_sink:
# make sure the sliding part will be selected.
score[:, :, self.n_sink:] = (torch.arange(q_len - self.n_sink, device=keys.device)
/ (q_len - self.n_sink)
).to(keys.dtype)
return score

end_idx = self.n_sink + ((q_len - self.n_sink) // self.lag_size) * self.lag_size
tail_len = self.lag_size + q_len - end_idx

key_score = self._get_states_score(
keys[:, :, self.n_sink:end_idx].view(bsz, num_key_value_heads, -1, self.lag_size, d))
value_score = self._get_states_score(
values[:, :, self.n_sink:end_idx].view(bsz, num_key_value_heads, -1, self.lag_size, d))
# score is in range [0, 1]
score = (key_score + value_score) / 2

if not self.cross_scoring:
score = score.argsort(dim=-1).argsort(dim=-1) / self.lag_size
score = score.to(keys.dtype)
# the parts should always keep
sink_shape = (bsz, num_key_value_heads, self.n_sink)
sink_score = torch.ones(sink_shape, dtype=score.dtype, device=score.device)
tail_shape = (bsz, num_key_value_heads, tail_len)
tail_score = torch.ones(tail_shape, dtype=score.dtype, device=score.device)
score = torch.cat((sink_score, score.reshape(bsz, num_key_value_heads, -1), tail_score), dim=-1)
return score

def _get_states_score(self, target_v):
"""evaluate the scores of keys and values for each token"""
ref = target_v[:, :, 1:, :, :]
v = target_v[:, :, :-1, :, :]
# lag-relative information
min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)

score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)
return score
8 changes: 8 additions & 0 deletions tests/default_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TOVAPress,
QFilterPress,
PyramidKVPress,
LagKVPress,
)


Expand Down Expand Up @@ -57,4 +58,11 @@ def load_attention_pattern(model):
"cls": PyramidKVPress,
"kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}],
},
{
"cls": LagKVPress,
"kwargs": [
{"compression_ratio": 0.5, "n_sink": 16, "lag_size": 128},
{"compression_ratio": 0.8, "n_sink": 16, "lag_size": 128}
],
},
]