add FinchPress #64

Closed
giulio98 wants to merge 45 commits into NVIDIA:main from miriam-16:features/finch_press

Conversation

@giulio98 (Contributor) commented Apr 4, 2025

PR description

This pull request introduces the FinchPress implementation (#59), incorporating chunked forward propagation, normalization of scores, and key re-rotation in alignment with the authors' original specifications.

The provided implementation has been thoroughly validated through extensive testing against the original reference code from the authors, ensuring a precise 1-to-1 mapping in functionality.

Below is a plot of FinchPress compared with SnapKVPress (w/ question) on RULER-4k.
[image: FinchPress vs SnapKVPress (w/ question) on RULER-4k]

Checklist

  • Tests are failing due to QFilterPress
=================================================== short test summary info ===================================================
FAILED tests/integration/test_ruler.py::test_ruler_is_correct[dynamic-press_dict5] - ValueError: Could not load Q-filters for Meta-Llama-3.1-8B-Instruct. Available models: ['Llama-3.1-8B-Instruct', 'Llama-3....
FAILED tests/integration/test_ruler.py::test_ruler_is_correct[quantized-press_dict5] - ValueError: Could not load Q-filters for Meta-Llama-3.1-8B-Instruct. Available models: ['Llama-3.1-8B-Instruct', 'Llama-3....
==================================== 2 failed, 121 passed, 3 warnings in 76.48s (0:01:16) =====================================
  • Make style has errors

/home/jovyan/conda/envs/kvpress/bin/poetry run flake8 | tee -a reports/flake8_errors.log
The "poetry.dev-dependencies" section is deprecated and will be removed in a future version. Use "poetry.group.dev.dependencies" instead.
/home/jovyan/conda/envs/kvpress/bin/poetry run mypy . --check-untyped-defs | tee -a reports/mypy.log
The "poetry.dev-dependencies" section is deprecated and will be removed in a future version. Use "poetry.group.dev.dependencies" instead.
kvpress/presses/base_press.py:99: error: "QuantizedCache" has no attribute "_dequantize"  [attr-defined]
kvpress/presses/base_press.py:99: error: "QuantizedCache" has no attribute "_quantized_key_cache"  [attr-defined]
kvpress/presses/base_press.py:100: error: "QuantizedCache" has no attribute "_dequantize"  [attr-defined]
kvpress/presses/base_press.py:100: error: "QuantizedCache" has no attribute "_quantized_value_cache"  [attr-defined]
kvpress/presses/base_press.py:108: error: "QuantizedCache" has no attribute "_quantized_key_cache"  [attr-defined]
kvpress/presses/base_press.py:108: error: "QuantizedCache" has no attribute "_quantize"  [attr-defined]
kvpress/presses/base_press.py:108: error: "QuantizedCache" has no attribute "axis_key"  [attr-defined]
kvpress/presses/base_press.py:109: error: "QuantizedCache" has no attribute "_quantized_value_cache"  [attr-defined]
kvpress/presses/base_press.py:109: error: "QuantizedCache" has no attribute "_quantize"  [attr-defined]
kvpress/presses/base_press.py:109: error: "QuantizedCache" has no attribute "axis_value"  [attr-defined]
kvpress/presses/base_press.py:110: error: "QuantizedCache" has no attribute "key_cache"  [attr-defined]
kvpress/presses/base_press.py:111: error: "QuantizedCache" has no attribute "value_cache"  [attr-defined]
kvpress/presses/base_press.py:112: error: "QuantizedCache" has no attribute "_seen_tokens"  [attr-defined]
kvpress/presses/base_press.py:136: error: "PreTrainedModel" has no attribute "model"  [attr-defined]
kvpress/presses/base_press.py:137: error: "PreTrainedModel" has no attribute "model"  [attr-defined]
kvpress/presses/finch_press.py:213: error: "QuantizedCache" has no attribute "_dequantize"  [attr-defined]
kvpress/presses/finch_press.py:213: error: "QuantizedCache" has no attribute "_quantized_key_cache"  [attr-defined]
kvpress/presses/finch_press.py:214: error: "QuantizedCache" has no attribute "_dequantize"  [attr-defined]
kvpress/presses/finch_press.py:214: error: "QuantizedCache" has no attribute "_quantized_value_cache"  [attr-defined]
kvpress/presses/finch_press.py:222: error: "QuantizedCache" has no attribute "_quantized_key_cache"  [attr-defined]
kvpress/presses/finch_press.py:222: error: "QuantizedCache" has no attribute "_quantize"  [attr-defined]
kvpress/presses/finch_press.py:222: error: "QuantizedCache" has no attribute "axis_key"  [attr-defined]
kvpress/presses/finch_press.py:223: error: "QuantizedCache" has no attribute "_quantized_value_cache"  [attr-defined]
kvpress/presses/finch_press.py:223: error: "QuantizedCache" has no attribute "_quantize"  [attr-defined]
kvpress/presses/finch_press.py:223: error: "QuantizedCache" has no attribute "axis_value"  [attr-defined]
kvpress/presses/finch_press.py:224: error: "QuantizedCache" has no attribute "key_cache"  [attr-defined]
kvpress/presses/finch_press.py:225: error: "QuantizedCache" has no attribute "value_cache"  [attr-defined]
kvpress/presses/finch_press.py:226: error: "QuantizedCache" has no attribute "_seen_tokens"  [attr-defined]
kvpress/presses/finch_press.py:236: error: "PreTrainedModel" has no attribute "model"  [attr-defined]
kvpress/presses/finch_press.py:237: error: "PreTrainedModel" has no attribute "model"  [attr-defined]
kvpress/presses/finch_press.py:239: error: "PreTrainedModel" has no attribute "forward"  [attr-defined]
kvpress/presses/finch_press.py:242: error: Incompatible types in assignment (expression has type "list[Any]", variable has type "tuple[Any, ...]")  [assignment]
kvpress/presses/finch_press.py:243: error: "tuple[Any, ...]" has no attribute "pop"  [attr-defined]
kvpress/presses/finch_press.py:244: error: "tuple[Any, ...]" has no attribute "pop"  [attr-defined]
kvpress/presses/finch_press.py:286: error: "PreTrainedModel" has no attribute "model"  [attr-defined]
kvpress/presses/finch_press.py:303: error: "PreTrainedModel" has no attribute "forward"  [attr-defined]
kvpress/presses/finch_press.py:309: error: "PreTrainedModel" has no attribute "forward"  [attr-defined]
kvpress/pipeline.py:87: error: Signature of "preprocess" incompatible with supertype "Pipeline"  [override]
kvpress/pipeline.py:87: note:      Superclass:
kvpress/pipeline.py:87: note:          def preprocess(self, input_: Any, **preprocess_parameters: dict[Any, Any]) -> dict[str, GenericTensor]
kvpress/pipeline.py:87: note:      Subclass:
kvpress/pipeline.py:87: note:          def preprocess(self, context: str, questions: list[str], answer_prefix: str, max_context_length: int) -> Any
kvpress/pipeline.py:111: error: Incompatible types in assignment (expression has type "str | list[int] | list[str] | list[list[int]] | BatchEncoding", variable has type "str")  [assignment]
kvpress/pipeline.py:135: error: Signature of "_forward" incompatible with supertype "Pipeline"  [override]
kvpress/pipeline.py:135: note:      Superclass:
kvpress/pipeline.py:135: note:          def _forward(self, input_tensors: dict[str, GenericTensor], **forward_parameters: dict[Any, Any]) -> ModelOutput
kvpress/pipeline.py:135: note:      Subclass:
kvpress/pipeline.py:135: note:          def _forward(self, input_tensors: dict[str, GenericTensor], max_new_tokens: int = ..., press: BasePress | None = ..., cache: Cache | None = ...) -> Any
kvpress/pipeline.py:174: error: Incompatible types in assignment (expression has type "DynamicCache", variable has type "Cache | None")  [assignment]
kvpress/pipeline.py:211: error: Signature of "postprocess" incompatible with supertype "Pipeline"  [override]
kvpress/pipeline.py:211: note:      Superclass:
kvpress/pipeline.py:211: note:          def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: dict[Any, Any]) -> Any
kvpress/pipeline.py:211: note:      Subclass:
kvpress/pipeline.py:211: note:          def postprocess(self, model_outputs: Any, single_question: Any) -> Any
kvpress/pipeline.py:239: error: "Cache" has no attribute "get_seq_length"  [attr-defined]
kvpress/pipeline.py:272: error: "Cache" has no attribute "key_cache"  [attr-defined]
kvpress/pipeline.py:273: error: "Cache" has no attribute "key_cache"  [attr-defined]
kvpress/pipeline.py:276: error: "Cache" has no attribute "value_cache"  [attr-defined]
kvpress/pipeline.py:277: error: "Cache" has no attribute "value_cache"  [attr-defined]
kvpress/pipeline.py:285: error: "Cache" has no attribute "_quantized_value_cache"  [attr-defined]
kvpress/pipeline.py:286: error: "Cache" has no attribute "_quantized_value_cache"  [attr-defined]
Found 49 errors in 3 files (checked 42 source files)
make: *** [Makefile:28: style] Error 1
  • Copyright header is included
  • All commits are signed-off using git commit -s
  • (new press) mypress_press.py is in the presses directory
  • (new press) MyPress is in __init__.py
  • (new press) README.md is updated with a 1 liner about the new press in the Available presses section
  • (new press) new press is in the default_presses list in tests/default_presses.py

FaureElia and others added 30 commits April 4, 2025 15:29
Signed-off-by: giulio98 <[email protected]>
Co-authored-by: miriam-16 <[email protected]>
…inch_attention), introduce context/question handling in pipeline.
@SimJeg (Collaborator) commented Apr 10, 2025

Hi @giulio98,

Thanks for your contribution. I started looking at the PR. Two initial remarks:

  1. Your PR should not update pipeline.py, in particular because it silently concatenates context and question tokens. We'd rather modify evaluate.py as you initially suggested (adding a [SEP] special token).
  2. Your PR should not modify qfilter_press.py, duo_attention_press.py or longbench/calculate_metrics.py.

I started to look more into finch_press.py. Would it be possible to re-use components from other presses? For instance, you could delete compute_finch_attention and compute_normalization_factors and simply use:

    def score(self, module, hidden_states, keys, values, attentions, kwargs):

        bsz, num_key_value_heads, q_len, _ = keys.shape
        num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

        if attentions is not None:
            attn_weights = attentions[..., -self.condition_len :, : -self.condition_len]
        else:
            attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, self.condition_len, kwargs["position_embeddings"])

        if self.normalize_scores:
            non_zero_counts = torch.arange(q_len - self.condition_len, q_len)[None, None, :, None]
            non_zero_counts = non_zero_counts.to(attn_weights.device)
            attn_weights = attn_weights * non_zero_counts

        # Average per group
        scores = attn_weights.mean(dim=-2)
        scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.condition_len)
        scores = scores.mean(dim=2)
        
        # Add back the observation window. Use max score to make sure the window is not pruned.
        scores = F.pad(scores, (0, self.condition_len), value=scores.max().item())
        return scores

Comparison with the initial score method and the score method I propose:

[image: comparison of the two score methods]

Note that in the code above I replaced sum by mean to avoid too-large floats when using bfloat16. I also replaced compute_normalization_factors with a simple torch.arange, which is equivalent. Similarly, I will look at whether it's possible to use ChunkPress and KeyRerotationPress.
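The torch.arange equivalence mentioned above can be sanity-checked in isolation. The sketch below is standalone toy code (not kvpress code): in a causal attention matrix, query row i attends to exactly i + 1 keys, so the per-row non-zero counts for the last `condition_len` rows form a simple arange.

```python
import torch

# Standalone sanity check (not kvpress code): the number of non-masked keys per
# causal query row grows linearly with the row's position, so an explicit count
# can be replaced by a closed-form torch.arange.
q_len, condition_len = 8, 3
causal = torch.tril(torch.ones(q_len, q_len))
counts = causal.sum(dim=-1)[-condition_len:]  # keys visible to each window row
assert torch.equal(counts, torch.arange(q_len - condition_len + 1, q_len + 1).float())
```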

values = values.gather(2, indices).contiguous()
return keys, values

def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
(Collaborator)

Is there any update here compared to BasePress? If not, remove it.

@maxjeblick removed their request for review April 10, 2025 12:15
…te_metrics.py,

use mean instead of sum to avoid large floats,
replace compute_normalization_factors with torch.arange

Co-authored-by: SimJeg <[email protected]>

Signed-off-by: giulio98 <[email protected]>
@giulio98 (Contributor, Author):

Hello @SimJeg,
I deleted my modifications to longbench, duo_attention_press and qfilter_press; they were modified automatically when I ran black.
Regarding the score, I applied your suggestion; however, I keep the call to compute_finch_attention, because when split_size is different from 1, the diagonal for the attention mask in line 70 has to be computed like this:

attention_mask = torch.triu(attention_mask, diagonal=key_states.shape[-2] - condition_len + 1)

instead of

attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1)

because it needs to take into account the past cache from the previous iteration. Similarly, I had to override forward for the following lines:

        q_len = hidden_states.shape[1]

        # Don't compress after pre-filling
        if kwargs["cache_position"][-1] > q_len + cache.get_seq_length():
            return output

(lines 190-194).
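The diagonal difference described above can be illustrated with a toy mask. This is a standalone sketch with assumed shapes, not kvpress code: the window queries are the last `condition_len` positions, but with a past cache the key axis is `past_len` longer than the current input, so the mask's diagonal offset must be derived from the key length, not from q_len alone.

```python
import torch

# Toy illustration (assumed shapes, not kvpress code) of the diagonal shift:
# with a past cache, keys are longer than the current queries, so the causal
# diagonal is key_len - condition_len + 1 rather than q_len - window_size + 1.
q_len, past_len, condition_len = 4, 3, 2
k_len = past_len + q_len                      # keys include the past cache
mask = torch.full((condition_len, k_len), float("-inf"))
mask = torch.triu(mask, diagonal=k_len - condition_len + 1)
# window row j now attends to keys 0 .. k_len - condition_len + j
visible = (mask == 0).sum(dim=-1)
```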

For the same reason I couldn't use KeyRerotationPress, because the way n_kept is calculated has to take into account the previous cache length.

Finally, ChunkPress doesn't implement the chunked forward correctly; we wrote chunked_forward for this purpose.

Regarding the special [SEP] token, I thought it would make FinchPress a bit hacky for the user, but I'm open to discussing a more user-friendly approach.

Giulio

@SimJeg (Collaborator) commented Apr 14, 2025

In SnapKVPress we have

bsz, num_key_value_heads, q_len, _ = keys.shape

so in fact q_len is already keys.shape[-2] and you can remove compute_finch_attention.

Co-authored-by: SimJeg <[email protected]>

Signed-off-by: giulio98 <[email protected]>
@giulio98 (Contributor, Author):

> In SnapKVPress we have
>
>     bsz, num_key_value_heads, q_len, _ = keys.shape
>
> so in fact q_len is already keys.shape[-2] and you can remove compute_finch_attention.

You are right, I replaced it with SnapKVPress.compute_window_attention.

@SimJeg (Collaborator) commented Apr 14, 2025

@giulio98 do you think it would make a big difference (in terms of accuracy) to move from:

  1. Run forward pass for x=cat([chunk, question])
  2. Compute scores, compress and rerotate KV cache for chunk
  3. Go to next chunk

to

  1. Run forward pass for the whole inputs
  2. Compute scores by chunk (as in ChunkPress)
  3. Compress and rerotate (as in KeyRerotationPress)

I know it's different, but the code would be much easier to read (which is my main concern so far), and it might have similar performance.

@giulio98 (Contributor, Author):

> @giulio98 do you think it would make a big difference (in terms of accuracy) to move from:
>
>   1. Run forward pass for x=cat([chunk, question])
>   2. Compute scores, compress and rerotate KV cache for chunk
>   3. Go to next chunk
>
> to
>
>   1. Run forward pass for the whole inputs
>   2. Compute scores by chunk (as in ChunkPress)
>   3. Compress and rerotate (as in KeyRerotationPress)
>
> I know it's different but the code would be much easier to read (which is my main concern so far), and maybe have similar performances

My main concern about this implementation is that the chunked forward was designed in Finch explicitly to handle inputs that exceed the context window size of LLMs (e.g. 128k tokens): with the chunked forward, the input is split into manageable chunks whose length is less than the context window size, so each chunk can be processed correctly by the model.
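The chunked-forward idea described above can be sketched as a plain loop. The helper below is a hypothetical sketch with a toy model object, not the actual kvpress API (the real implementation lives in finch_press.py's chunked_forward): the long context is split into chunks shorter than the model's window, each chunk is prefilled with the running KV cache, and the press would compress the cache in between, so the total input can exceed the model's maximum context length.

```python
import torch

# Hypothetical sketch, not kvpress code: prefill a long input chunk by chunk,
# carrying the (to-be-compressed) KV cache across chunks.
def chunked_prefill(model, input_ids, chunk_size, cache=None):
    for start in range(0, input_ids.shape[1], chunk_size):
        chunk = input_ids[:, start:start + chunk_size]
        out = model(chunk, past_key_values=cache, use_cache=True)
        cache = out.past_key_values  # a press's forward hook would compress it here
    return cache

# Toy stand-in for a model call, just to show the chunk/cache flow:
# the "cache" is simply the number of tokens seen so far.
class _Out:
    def __init__(self, cache):
        self.past_key_values = cache

def _toy_model(chunk, past_key_values=None, use_cache=True):
    return _Out((past_key_values or 0) + chunk.shape[1])

seen_tokens = chunked_prefill(_toy_model, torch.zeros(1, 10, dtype=torch.long), chunk_size=4)
```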

@giulio98 (Contributor, Author):

On the other hand, I know that doing as you proposed would make it easier to apply, for example, AdaKV on top of Finch, with possible performance gains.

@SimJeg (Collaborator) commented Apr 14, 2025

Indeed, with a compression ratio of 50%, Finch could handle inputs up to 256k even if the LLM's max size is 128k (and more generally max_len / compression_ratio). However, even on RULER with 4k (which is far below 128k) we see degradations at 50%, so it might not be relevant to focus on the 128k regime. WDYT?

@giulio98 (Contributor, Author):

I suspect that RULER is a very information-dense dataset, so it is hard to apply compression beyond a certain limit. If instead we imagine compressing external knowledge from, say, Wikipedia pages, I think KV compression makes a lot of sense.

@SimJeg (Collaborator) commented Apr 14, 2025

Draft for a proposal of a simplified FinchPress:

# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import math
from dataclasses import dataclass
from typing import Optional

import torch
from torch.nn import functional as F

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress
from transformers.models.llama.modeling_llama import rotate_half


@dataclass
class FinchPress(BasePress):

    compression_ratio: float = 0.0
    split_size: int = 1
    normalize_scores: bool = True
    condition_len: Optional[int] = None

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Similar to SnapKVPress except it adds a normalization step before averaging on the context window.
        """

        bsz, num_key_value_heads, q_len, _ = keys.shape
        num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

        if attentions is not None:
            attn_weights = attentions[..., -self.condition_len :, : -self.condition_len]
        else:
            attn_weights = SnapKVPress.compute_window_attention(
                module, hidden_states, keys, self.condition_len, kwargs["position_embeddings"]
            )

        if self.normalize_scores:
            non_zero_counts = torch.arange(q_len - self.condition_len, q_len)[None, None, :, None]
            non_zero_counts = non_zero_counts.to(attn_weights.device)
            attn_weights = attn_weights * non_zero_counts

        # Average per group
        scores = attn_weights.mean(dim=-2)
        scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.condition_len)
        scores = scores.mean(dim=2)

        # Add back the observation window. Use max score to make sure the window is not pruned.
        scores = F.pad(scores, (0, self.condition_len), value=scores.max().item())
        return scores

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Scores are computed by chunks, keys and values are then compressed and re-rotated.
        """

        q_len = hidden_states.shape[1]

        if self.compression_ratio == 0:
            return keys, values
        assert (self.condition_len is not None) and (self.condition_len < q_len)

        # Compute scores
        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

        # Compute indices by chunks
        indices = []
        chunk_size = math.ceil(q_len / self.split_size)
        for i, chunk_scores in enumerate(torch.split(scores, chunk_size, dim=2)):
            n_kept = max(1, int(chunk_scores.shape[2] * (1 - self.compression_ratio)))
            chunk_indices = i * chunk_size + chunk_scores.topk(n_kept, dim=-1).indices
            indices.append(chunk_indices)

        indices = torch.cat(indices, dim=-1)
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

        # Rerotate keys and values
        cos, sin = kwargs["position_embeddings"]
        keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
        keys = keys.gather(2, indices).contiguous()
        cos, sin = cos[:, : indices.shape[2]], sin[:, : indices.shape[2]]
        keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))

        values = values.gather(2, indices).contiguous()

        return keys, values

I will run it for 50% compression to compare with what you reported
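The un-rotate / re-rotate step in compress above can be checked numerically. The sketch below uses standalone RoPE helpers mirroring the usual Llama pairing layout (toy code, not kvpress internals): applying RoPE with negated sin undoes the original rotation, so rotating back and then forward at the new contiguous positions equals rotating the unrotated keys directly at the new positions.

```python
import torch

# Standalone RoPE helpers (toy code, not kvpress internals).
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    # Rotates each (x_i, x_{i + d/2}) pair by the per-position angle.
    return x * cos + rotate_half(x) * sin

d = 4
base = torch.randn(1, 1, 3, d)           # unrotated keys for 3 kept positions
pos_old = torch.tensor([2.0, 5.0, 7.0])  # original positions of the kept keys
pos_new = torch.tensor([0.0, 1.0, 2.0])  # contiguous positions after compression
inv = 1.0 / (10000 ** (torch.arange(0, d, 2) / d))

def cos_sin(pos):
    ang = pos[:, None] * inv[None, :]
    ang = torch.cat((ang, ang), dim=-1)
    return ang.cos(), ang.sin()

cos_o, sin_o = cos_sin(pos_old)
cos_n, sin_n = cos_sin(pos_new)
k_old = apply_rope(base, cos_o, sin_o)
# Un-rotate (negated sin) then re-rotate at the new positions.
k_unrot = apply_rope(k_old, cos_o, -sin_o)
k_new = apply_rope(k_unrot, cos_n, sin_n)
assert torch.allclose(k_new, apply_rope(base, cos_n, sin_n), atol=1e-5)
```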

@giulio98 (Contributor, Author):

@SimJeg
Also, I noticed that for small compression ratios (e.g. 0.1, 0.25) FinchPress has degraded performance with respect to SnapKVPress, which is a bit weird; maybe the inclusion of the answer_prefix in the condition degrades performance?
Would you happen to have compute available to run a quick experiment to test this hypothesis?

@giulio98 (Contributor, Author):

> Draft for a proposal of a simplified FinchPress: […]
>
> I will run it for 50% compression to compare with what you reported

@SimJeg I just noticed that you may have to order the indices, as in line 114 of finch_press.py; as we saw, this will enhance performance.

Signed-off-by: giulio98 <[email protected]>
@SimJeg (Collaborator) commented Apr 16, 2025

@giulio98 I created a new branch here: https://github.com/NVIDIA/kvpress/tree/simon/finch

What it contains:

  • improved version of my draft above
  • new way to compute condition_len (renamed window_size to align with SnapKV), which avoids an important refactor of the pipeline with the silent concatenation of context_ids and question_ids. This is inspired by what you first suggested in the issue (FinchPress Scorer #59). I used <bos_token> instead of a new [SEP] token, as it's possible to access it directly from the model (so no need to add an extra param to the press).
  • moved from split_size to chunk_length to align with the paper

I will share results with 50% compression. Could you provide the detailed performance you obtained for each subtask?

update:

cwe= 94.29
fwe= 91.03
niah_multikey_1= 100.0
niah_multikey_2= 94.74
niah_multikey_3= 88.89
niah_multiquery= 97.5
niah_multivalue= 96.67
niah_single_1= 100.0
niah_single_2= 100.0
niah_single_3= 92.59
qa_1= 79.17
qa_2= 40.74
vt= 91.76

Average: 89.80

So it's slightly lower than your version. One slight difference, however, is that I compress the question too, but that might impact only ~0.5% of the compression ratio.

@giulio98 (Contributor, Author):

> @giulio98 I created a new branch here: https://github.com/NVIDIA/kvpress/tree/simon/finch […]
>
> So it's slightly lower than your version. One slight difference however is that I compress the question too but might impact only ~0.5% of compression ratio.

Hello,

I have to rerun the experiment. In the meantime, the first thing I noticed is that the sorting of the indices just after the topk is missing: topk returns indices ordered by score by default, but we may need to sort them back to their natural order, because otherwise they can take on a different meaning (this is one thing that could enhance performance in other presses as well).
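The missing step can be shown in a few lines (standalone sketch, toy scores): torch.topk orders indices by descending score, not by position, so they must be re-sorted to their natural order before gathering, and especially before re-rotating, the kept keys.

```python
import torch

# Toy example: topk returns indices ordered by score; sort restores positional order.
scores = torch.tensor([[0.1, 0.9, 0.3, 0.8, 0.2]])
indices = scores.topk(3, dim=-1).indices  # ordered by score: [1, 3, 2]
indices = indices.sort(dim=-1).values     # positional order:  [1, 2, 3]
```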

@SimJeg (Collaborator) commented Apr 17, 2025

Great catch! It also impacts KeyRerotationPress, so I will correct it; it won't impact other presses, as the order of keys and values does not matter there.

I re-ran your implementation and get 91.6. Also I made a mistake in the numbers I reported above (I reported results for 5% of the data). With the error, I get 90.8.

I will correct the error and report results for 4 options: with / without - normalization / rerotation.

@giulio98 (Contributor, Author):

> Great catch! It also impacts KeyRerotationPress so I will correct it, but it won't impact other presses as the order of keys and values does not matter. […]

Ah yes! If no rerotation is applied it is permutation invariant.

@SimJeg mentioned this pull request Apr 17, 2025
@SimJeg (Collaborator) commented Apr 17, 2025

Updated results. Fixing the bug led to slightly worse performance (especially for niah_multikey_3).
Also, again, the original implementation slightly over-estimates scores because the compression ratio is not exactly the correct one.
Results w/o re-rotation are not reported because I forgot it would imply updating pipeline.py too.

Will look again at what might explain the difference.

                  original   new    new (no norm)
cwe               98.8       97.5   97.6
fwe               94.4       91.3   91.1
niah_multikey_1   99.6       99.2   99.2
niah_multikey_2   80.2       79.4   80
niah_multikey_3   78.8       69.8   69.8
niah_multiquery   99.2       99.2   99
niah_multivalue   96.4       96     96.4
niah_single_1     100        99.8   99.8
niah_single_2     99.4       99.4   99.4
niah_single_3     94.6       91.6   92.2
qa_1              87.2       87     87.4
qa_2              62.8       64     64
vt                99.9       99.9   99.9
average           91.6       90.3   90.5

Comment on lines +106 to +109
n_kept_context = int(context_length * (1 - self.compression_ratio))
else:
past_cache_len = scores.shape[-1] - q_len
n_kept_context = int((q_len - self.condition_len) * (1 - self.compression_ratio)) + past_cache_len
(Collaborator)

Not correct; for instance, for last_iteration it should be:

n_kept_context = int(context_length * (1 - self.compression_ratio) - self.condition_len * self.compression_ratio)
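A quick numeric check of the identity behind this suggestion, with toy numbers (not from the PR): if the overall budget is (context_length + condition_len) * (1 - compression_ratio) and the condition tokens are always kept, the context share works out to context_length * (1 - r) - condition_len * r.

```python
# Toy numbers, not from the PR: verify the budget identity behind the suggestion.
context_length, condition_len, r = 1000, 100, 0.5
total_kept = (context_length + condition_len) * (1 - r)        # overall budget
n_kept_context = context_length * (1 - r) - condition_len * r  # context share
assert n_kept_context + condition_len == total_kept
```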

(Contributor, Author)

The condition_len tokens shouldn't be compressed, should they? In other presses the question is provided as-is, so if we also compress the question the comparison will not be fair.

(Collaborator)

You are using the question in the input so it should be compressed too, as it's done in other presses. Finch can use the information of the window size (which is a bit unfair) to exclude the question from the compression, but it should translate to a slightly higher compression for the context.

(Collaborator)

Note that the difference is very small (~0.5% in CR), so I don't think it should impact performance much.

(Contributor, Author)

> You are using the question in the input so it should be compressed too, as it's done in other presses. Finch can use the information of the window size (which is a bit unfair) to exclude the question from the compression, but it should translate to a slightly higher compression for the context.

Query-aware compression like Finch or SnapKV and question-agnostic compression come with trade-offs. Query-aware compression gives the highest performance, but the compression has to be rerun for each new query; question-agnostic presses can compress independently of the question, giving more throughput. This is a topic we explored in our recent paper: https://arxiv.org/pdf/2503.04973, where we propose a middle ground and compress using just the task description and a few shot examples. That being said, the final token budget should be equal across all approaches for a fair comparison IMO. What do you think?

(Collaborator)

> but we have to rerun the compression for each new query

Approaches like SnapKV or Finch are close to sparse attention (i.e. retrieve the right KV pairs from the full KV cache) and somehow a bit different from compression, because as you mention you have to re-run them for each query.

> That being said the final budget of tokens should be equal for all the approaches for a fair comparison IMO

I agree, that's why I proposed this update. Could you review #69 and comment? You can also open a new PR with the same code if you want to appear in the contributors of the repo.

(Contributor, Author)

Would it be possible to add Co-authored-by trailers to your commits, adding myself, @miriam-16 and @FaureElia?

(Collaborator)

Done, could you please close this branch?

@SimJeg (Collaborator) commented Apr 17, 2025

Ok, I found the difference: in your implementation the condition_len tokens include the question, but also the end of the chat template and the answer prefix. My version is aligned with what we did for other presses (only include the question). So I would favor the new PR I opened.

@giulio98 giulio98 closed this Apr 17, 2025