Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging
from pathlib import Path
from typing import Optional, Any
from typing import Any, Optional

import torch
from datasets import load_dataset
Expand Down
9 changes: 4 additions & 5 deletions kvpress/presses/criticalkv_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
logger = logging.getLogger(__name__)


@dataclass
class CriticalKVPress(ScorerPress):
"""
CriticalKV: Two-stage compression with output projection weighting.
Expand All @@ -35,11 +34,11 @@ class CriticalKVPress(ScorerPress):
Remaining budget used in second stage with output projection weighting.
"""

press: ScorerPress = None
epsilon: float = 1e-4
first_stage_ratio: float = 0.5
def __init__(self, press: ScorerPress, epsilon: float = 1e-4, first_stage_ratio: float = 0.5):
self.press = press
self.epsilon = epsilon
self.first_stage_ratio = first_stage_ratio

def __post_init__(self):
assert isinstance(self.press, ScorerPress), "CriticalKVPress requires a ScorerPress as input"
if isinstance(self.press, ExpectedAttentionPress) and self.press.use_vnorm:
logger.warning("use_vnorm should be disabled for CriticalKVPress")
Expand Down
2 changes: 1 addition & 1 deletion kvpress/presses/finch_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from torch.nn import functional as F

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.snapkv_press import SnapKVPress


@dataclass
Expand Down
12 changes: 9 additions & 3 deletions kvpress/presses/key_rerotation_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ class KeyRerotationPress(BasePress):
def __post_init__(self):
assert isinstance(self.press, ScorerPress)

@property
def compression_ratio(self):
return self.press.compression_ratio

@compression_ratio.setter
def compression_ratio(self, value):
self.press.compression_ratio = value

@staticmethod
def _rerotate_cos_sin(x, inv_freq, selected_positions):
"""
Expand Down Expand Up @@ -108,9 +116,7 @@ def rerotate_keys(
The rerotated keys tensor of shape
``(bsz, num_heads, n_kept, d)``.
"""
new_cos, new_sin = KeyRerotationPress._rerotate_cos_sin(keys,
module.rotary_emb.inv_freq,
indices)
new_cos, new_sin = KeyRerotationPress._rerotate_cos_sin(keys, module.rotary_emb.inv_freq, indices)
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
keys = keys.gather(2, indices).contiguous()
return (keys * new_cos) + (rotate_half(keys) * new_sin)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Alessio Devoto", "Jiwei Liu", "David Austin"]
description = "Efficiently compress the KV cache of any pretrained transformer"
version = "0.2.7"
version = "0.2.8"
readme = "README.md"

[tool.poetry.dependencies]
Expand Down
9 changes: 2 additions & 7 deletions tests/presses/test_key_rerotation_press_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,14 @@
# variants with the smallest possible code changes.

import inspect
from dataclasses import dataclass
from copy import deepcopy
from dataclasses import dataclass

import pytest
import torch
from torch import nn
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaRotaryEmbedding,
rotate_half,
)
from transformers import Gemma3ForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaRotaryEmbedding, rotate_half

from kvpress import KeyRerotationPress, ScorerPress
from tests.fixtures import unit_test_model # noqa: F401
Expand Down
27 changes: 16 additions & 11 deletions tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,24 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
cls = press_dict["cls"]
for kwargs in press_dict["kwargs"]:
press = cls(**kwargs)
if isinstance(wrapper_press, ComposedPress):
press = ComposedPress(presses=[press])
if isinstance(wrapper_press, KeyRerotationPress):
press = KeyRerotationPress(press=press)
if isinstance(wrapper_press, (AdaKVPress, CriticalKVPress, CriticalAdaKVPress)):
if isinstance(press, ScorerPress):
press = wrapper_press(press=press)
else:
if wrapper_press is not None:
if hasattr(press, "__post_init_from_model__"):
# TODO: Handle __post_init_from_model__ in wrapper presses
return
if issubclass(wrapper_press, ComposedPress):
press = ComposedPress(presses=[press])
elif not isinstance(press, ScorerPress): # remaining wrapper presses only support ScorerPress
return
if isinstance(wrapper_press, ChunkPress):
press = ChunkPress(press=press, chunk_length=2)
elif issubclass(wrapper_press, (KeyRerotationPress, AdaKVPress, CriticalKVPress, CriticalAdaKVPress)):
press = wrapper_press(press=press)
elif issubclass(wrapper_press, ChunkPress):
press = ChunkPress(press=press, chunk_length=24)

# TODO: Handle __post_init_from_model__ differently
if hasattr(press, "__post_init_from_model__"):
press.__post_init_from_model__(unit_test_model)
with press(unit_test_model):
input_ids = unit_test_model.dummy_inputs["input_ids"]
input_ids = torch.randint(0, 1024, (1, 128))
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
# Check that the press has a compression_ratio attribute
assert hasattr(press, "compression_ratio")
Expand Down