-
Notifications
You must be signed in to change notification settings - Fork 128
add lagkv_press #77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
add lagkv_press #77
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
4984886
add lagkv_press
JoelSeniorLiang 89f8698
Add signed off.
JoelSeniorLiang b9aef80
add lagkv to __init__
JoelSeniorLiang d0bf4f6
add a no compression switch
JoelSeniorLiang 0e0ac52
Merge branch 'NVIDIA:main' into lagkv
JoelSeniorLiang b377bd0
update for minor changes and style fixs
JoelSeniorLiang 59d467c
fix doc
JoelSeniorLiang 99f39e6
update no compression behavior
JoelSeniorLiang 96d9dce
Remove the overrided compress method
JoelSeniorLiang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,58 +1,60 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
|
|
||
| from kvpress.attention_patch import patch_attention_functions | ||
| from kvpress.pipeline import KVPressTextGenerationPipeline | ||
| from kvpress.presses.adakv_press import AdaKVPress | ||
| from kvpress.presses.base_press import BasePress | ||
| from kvpress.presses.chunk_press import ChunkPress | ||
| from kvpress.presses.chunkkv_press import ChunkKVPress | ||
| from kvpress.presses.composed_press import ComposedPress | ||
| from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress | ||
| from kvpress.presses.duo_attention_press import DuoAttentionPress | ||
| from kvpress.presses.expected_attention_press import ExpectedAttentionPress | ||
| from kvpress.presses.key_rerotation_press import KeyRerotationPress | ||
| from kvpress.presses.knorm_press import KnormPress | ||
| from kvpress.presses.observed_attention_press import ObservedAttentionPress | ||
| from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress | ||
| from kvpress.presses.random_press import RandomPress | ||
| from kvpress.presses.scorer_press import ScorerPress | ||
| from kvpress.presses.simlayerkv_press import SimLayerKVPress | ||
| from kvpress.presses.snapkv_press import SnapKVPress | ||
| from kvpress.presses.streaming_llm_press import StreamingLLMPress | ||
| from kvpress.presses.think_press import ThinKPress | ||
| from kvpress.presses.tova_press import TOVAPress | ||
| from kvpress.presses.qfilter_press import QFilterPress | ||
| from kvpress.presses.pyramidkv_press import PyramidKVPress | ||
| from kvpress.presses.finch_press import FinchPress | ||
|
|
||
| # Patch the attention functions to support head-wise compression | ||
| patch_attention_functions() | ||
|
|
||
| __all__ = [ | ||
| "CriticalAdaKVPress", | ||
| "CriticalKVPress", | ||
| "AdaKVPress", | ||
| "BasePress", | ||
| "ComposedPress", | ||
| "ScorerPress", | ||
| "ExpectedAttentionPress", | ||
| "KnormPress", | ||
| "ObservedAttentionPress", | ||
| "RandomPress", | ||
| "SimLayerKVPress", | ||
| "SnapKVPress", | ||
| "StreamingLLMPress", | ||
| "ThinKPress", | ||
| "TOVAPress", | ||
| "KVPressTextGenerationPipeline", | ||
| "PerLayerCompressionPress", | ||
| "KeyRerotationPress", | ||
| "ChunkPress", | ||
| "DuoAttentionPress", | ||
| "ChunkKVPress", | ||
| "QFilterPress", | ||
| "PyramidKVPress", | ||
| "FinchPress", | ||
| ] | ||
| # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
|
|
||
| from kvpress.attention_patch import patch_attention_functions | ||
| from kvpress.pipeline import KVPressTextGenerationPipeline | ||
| from kvpress.presses.adakv_press import AdaKVPress | ||
| from kvpress.presses.base_press import BasePress | ||
| from kvpress.presses.chunk_press import ChunkPress | ||
| from kvpress.presses.chunkkv_press import ChunkKVPress | ||
| from kvpress.presses.composed_press import ComposedPress | ||
| from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress | ||
| from kvpress.presses.duo_attention_press import DuoAttentionPress | ||
| from kvpress.presses.expected_attention_press import ExpectedAttentionPress | ||
| from kvpress.presses.key_rerotation_press import KeyRerotationPress | ||
| from kvpress.presses.knorm_press import KnormPress | ||
| from kvpress.presses.observed_attention_press import ObservedAttentionPress | ||
| from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress | ||
| from kvpress.presses.random_press import RandomPress | ||
| from kvpress.presses.scorer_press import ScorerPress | ||
| from kvpress.presses.simlayerkv_press import SimLayerKVPress | ||
| from kvpress.presses.snapkv_press import SnapKVPress | ||
| from kvpress.presses.streaming_llm_press import StreamingLLMPress | ||
| from kvpress.presses.think_press import ThinKPress | ||
| from kvpress.presses.tova_press import TOVAPress | ||
| from kvpress.presses.qfilter_press import QFilterPress | ||
| from kvpress.presses.pyramidkv_press import PyramidKVPress | ||
| from kvpress.presses.finch_press import FinchPress | ||
| from kvpress.presses.lagkv_press import LagKVPress | ||
|
|
||
| # Patch the attention functions to support head-wise compression | ||
| patch_attention_functions() | ||
|
|
||
| __all__ = [ | ||
| "CriticalAdaKVPress", | ||
| "CriticalKVPress", | ||
| "AdaKVPress", | ||
| "BasePress", | ||
| "ComposedPress", | ||
| "ScorerPress", | ||
| "ExpectedAttentionPress", | ||
| "KnormPress", | ||
| "ObservedAttentionPress", | ||
| "RandomPress", | ||
| "SimLayerKVPress", | ||
| "SnapKVPress", | ||
| "StreamingLLMPress", | ||
| "ThinKPress", | ||
| "TOVAPress", | ||
| "KVPressTextGenerationPipeline", | ||
| "PerLayerCompressionPress", | ||
| "KeyRerotationPress", | ||
| "ChunkPress", | ||
| "DuoAttentionPress", | ||
| "ChunkKVPress", | ||
| "QFilterPress", | ||
| "PyramidKVPress", | ||
| "FinchPress", | ||
| "LagKVPress", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
|
|
||
| from dataclasses import dataclass | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
||
| from kvpress.presses.scorer_press import ScorerPress | ||
|
|
||
|
|
||
| @dataclass | ||
| class LagKVPress(ScorerPress): | ||
| """ | ||
| Prune KV pairs with lag-relative information (https://arxiv.org/abs/2504.04704) | ||
|
|
||
| Args: | ||
| n_sink (`int`): | ||
| The number of sink tokens. | ||
| lag_size (`int`): | ||
| The size of the partition. The subsequent partition will serve as a reference for the prior one. | ||
| cross_scoring (`bool`): | ||
| (experimental) if cross scoring is enabled, the score will not be limited to inside partion. | ||
| Since the score is totally normalized, it's possible use it to allocating KV among heads. | ||
| This switch will be useful for Press Wrapper like AdaKVPress. | ||
| """ | ||
| n_sink: int = 4 | ||
| lag_size: int = 128 | ||
| cross_scoring: bool = False | ||
|
|
||
| def score( | ||
| self, | ||
| module: nn.Module, | ||
| hidden_states: torch.Tensor, | ||
| keys: torch.Tensor, | ||
| values: torch.Tensor, | ||
| attentions: torch.Tensor, | ||
| kwargs, | ||
| ) -> torch.Tensor: | ||
| bsz, num_key_value_heads, q_len, d = keys.shape | ||
| if q_len < self.n_sink + 2 * self.lag_size: | ||
| # no compression | ||
| score = torch.ones((bsz, num_key_value_heads, q_len), | ||
| dtype=keys.dtype, device=keys.device) | ||
| if q_len > self.n_sink: | ||
| # make sure the sliding part will be selected. | ||
| score[:, :, self.n_sink:] = (torch.arange(q_len - self.n_sink, device=keys.device) | ||
| / (q_len - self.n_sink) | ||
| ).to(keys.dtype) | ||
| return score | ||
|
|
||
| end_idx = self.n_sink + ((q_len - self.n_sink) // self.lag_size) * self.lag_size | ||
| tail_len = self.lag_size + q_len - end_idx | ||
|
|
||
| key_score = self._get_states_score( | ||
| keys[:, :, self.n_sink:end_idx].view(bsz, num_key_value_heads, -1, self.lag_size, d)) | ||
| value_score = self._get_states_score( | ||
| values[:, :, self.n_sink:end_idx].view(bsz, num_key_value_heads, -1, self.lag_size, d)) | ||
| # score is in range [0, 1] | ||
| score = (key_score + value_score) / 2 | ||
JoelSeniorLiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if not self.cross_scoring: | ||
alessiodevoto marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| score = score.argsort(dim=-1).argsort(dim=-1) / self.lag_size | ||
| score = score.to(keys.dtype) | ||
| # the parts should always keep | ||
| sink_shape = (bsz, num_key_value_heads, self.n_sink) | ||
| sink_score = torch.ones(sink_shape, dtype=score.dtype, device=score.device) | ||
| tail_shape = (bsz, num_key_value_heads, tail_len) | ||
| tail_score = torch.ones(tail_shape, dtype=score.dtype, device=score.device) | ||
| score = torch.cat((sink_score, score.reshape(bsz, num_key_value_heads, -1), tail_score), dim=-1) | ||
| return score | ||
|
|
||
| def _get_states_score(self, target_v): | ||
| """evaluate the scores of keys and values for each token""" | ||
| ref = target_v[:, :, 1:, :, :] | ||
| v = target_v[:, :, :-1, :, :] | ||
| # lag-relative information | ||
| min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1) | ||
| max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1) | ||
|
|
||
| score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1) | ||
| return score | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.