add FinchPress #64
Closed
Commits (45)
- 73dcb45 first commit (FaureElia)
- 81b53f9 created setup (FaureElia)
- 8f87626 first basic implementation of FinchPress (FaureElia)
- 1db9786 adjusted FinchPress (FaureElia)
- 98bc55a introduce new parameter condition_len (still have to modify compute_f… (FaureElia)
- 008e7e9 Adapted compute_finch_attention to condition_len. TODO: Normalization… (miriam-16)
- 5f2d089 fix question len, adjusted the question ids for generation. (giulio98)
- c512283 fix: parameter condition_len in function (miriam-16)
- 87ba07e add nomralization (FaureElia)
- fa0e28d clean FinchPress (FaureElia)
- c7dff0e normalization (giulio98)
- 8970c97 fix len question ids (giulio98)
- 84ce456 add self (giulio98)
- f0efe61 chunked forward -- not working(gibberish output), check shapes and le… (giulio98)
- 15db060 Fixed attention mask and binary mask in compute_finch_attention (FaureElia)
- 26c64b0 updated gitignore (FaureElia)
- a6ebc11 n_kept value at last iteration (TODO: further checks) (miriam-16)
- e2f3dd6 fixed pipeline for other presses and finished adjusting number of tok… (FaureElia)
- 54c991b Fix: n_kept at last iteration considers condition_len (miriam-16)
- 91e9f05 adjusted n_kept, at final chunk keeps 90% of context+full question (FaureElia)
- c48ea98 added docstring and general code cleaning (miriam-16)
- 7ec514a reorder indices, add normalization parameter, add sink tokens (giulio98)
- 657a27d partially fixed finch press (FaureElia)
- d0bb5d0 fixed finch press, aligned to original code (FaureElia)
- fdde74d preparing for pr (giulio98)
- c1acb0d add ChunkKV (Dominic789654)
- 8533beb Update copyright date (#60) (SimJeg)
- e71f4e9 Add QFilterPress (#54) (NathanGodey)
- 4f917cc Add longbench benchmark (Xnhyacinth)
- 8d95b14 Add DuoAttention on the fly (#63) (SimJeg)
- cace3a5 Resolve stash conflict in tests/default_presses.py (giulio98)
- a833ae3 Fix: remove conflict markers after saving changes (giulio98)
- 39417ad test passes (giulio98)
- 2cdd69c add readme (giulio98)
- 0e4d4c8 commit (giulio98)
- ea42501 Merge branch 'main' into features/finch_press (giulio98)
- c570844 make style (giulio98)
- 9afd7f8 remove leftovers (giulio98)
- e8d84fa make style check pass (giulio98)
- 689c9c5 fix: readme (giulio98)
- f628ca4 unmodify qfilter_press.py duo_attention_press.py and longbech/calcula… (giulio98)
- c5a9ca7 typo duo attention press import (giulio98)
- f8034e5 use SnapKV window attention for finch (giulio98)
- f830665 update past_key_values using last_output from original forward (giulio98)
- 3f2ed03 add import snapkv (giulio98)
New file (256 lines added):

```python
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import math
from contextlib import contextmanager
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel, QuantizedCache
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress


@dataclass
class FinchPress(BasePress):
    """
    Finch uses the attention information between the prompt and the document chunk to dynamically
    identify the most relevant KV pairs across different layers.
    This information is then stored in the KV cache for the processing of the next input chunk
    (https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)

    Parameters
    ----------
    condition_len : int
        The number of tokens in the prompt.
    split_size : int
        The number of chunks to split the context into.
    """

    compression_ratio: float = 0.0
    split_size: int = 1
    normalize_scores: bool = True
    condition_len: int = None  # calculate on length of question dynamically

    def score(self, module, hidden_states, keys, values, attentions, kwargs):

        bsz, num_key_value_heads, q_len, _ = keys.shape
        num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

        if attentions is not None:
            attn_weights = attentions[..., -self.condition_len:, : -self.condition_len]
        else:
            attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, self.condition_len, kwargs["position_embeddings"])
        if self.normalize_scores:
            non_zero_counts = torch.arange(q_len - self.condition_len, q_len)[None, None, :, None]
            non_zero_counts = non_zero_counts.to(attn_weights.device)
            attn_weights = attn_weights * non_zero_counts

        # Average per group
        scores = attn_weights.mean(dim=-2)
        scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.condition_len)
        scores = scores.mean(dim=2)

        # Add back the observation window. Use max score to make sure the window is not pruned.
        scores = F.pad(scores, (0, self.condition_len), value=scores.max().item())
        return scores

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)  # (B, H, M, 1)
        idx = idx[:, None, :].float().expand(B, H, L)  # (B, H, L)
        delta_pos = idx - important_pos_batch
        delta_pos = delta_pos.unsqueeze(2)  # (B, H, 1, L)

        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.compression_ratio == 0:
            return keys, values

        # Compute scores from base press
        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

        context_length = kwargs["context_length"]
        last_iteration = kwargs["split_idx"] == self.split_size - 1
        q_len = hidden_states.shape[1]

        if last_iteration:
            n_kept_context = int(context_length * (1 - self.compression_ratio))
        else:
            past_cache_len = scores.shape[-1] - q_len
            n_kept_context = int((q_len - self.condition_len) * (1 - self.compression_ratio)) + past_cache_len

        indices_context = scores[:, :, : -self.condition_len].topk(n_kept_context, dim=-1).indices
        indices_condition = torch.arange(scores.shape[-1] - self.condition_len, scores.shape[-1], device=scores.device)[
            None, None, :
        ].expand(scores.shape[0], scores.shape[1], -1)
        indices_context, _ = torch.sort(indices_context, dim=-1)

        # concatenate the indices
        indices = torch.cat([indices_context, indices_condition], dim=-1)

        # rerotate the positions
        new_cos, new_sin = self._rerotate_cos_sin(keys, module.rotary_emb.inv_freq, indices)
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
        keys = keys.gather(2, indices).contiguous()
        keys = (keys * new_cos) + (rotate_half(keys) * new_sin)
        values = values.gather(2, indices).contiguous()
        return keys, values

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
```
Collaborator (on `forward_hook`): is there any update here compared to BasePress? If not, remove it.
| """ | ||
| Default forward hook called after the forward pass of an attention layer. | ||
| The hook calls the compress method to compress the KV cache while ensuring: | ||
| - compression is only applied only during the pre-filling phase | ||
| - KV cache quantization is handled correctly | ||
|
|
||
| Parameters | ||
| ---------- | ||
| module : | ||
| Transformer attention layer. | ||
| input : | ||
| Input to the hook. This is the input to the forward pass of the layer. | ||
| kwargs : | ||
| Keyword arguments, as given to the forward pass of the layer. | ||
| output : | ||
| Output of the hook. This is the original output of the forward pass of the layer. | ||
|
|
||
| Returns | ||
| ------- | ||
| Modified output of the forward pass of the layer. | ||
| """ | ||
| hidden_states = kwargs["hidden_states"] | ||
| cache = kwargs["past_key_value"] | ||
| q_len = hidden_states.shape[1] | ||
|
|
||
| # Don't compress after pre-filling | ||
| if kwargs["cache_position"][-1] > q_len + cache.get_seq_length(): | ||
| return output | ||
|
|
||
| if isinstance(cache, QuantizedCache): | ||
| keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx]) | ||
| values = cache._dequantize(cache._quantized_value_cache[module.layer_idx]) | ||
| else: | ||
| keys = cache.key_cache[module.layer_idx] | ||
| values = cache.value_cache[module.layer_idx] | ||
|
|
||
| keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs) | ||
|
|
||
| if isinstance(cache, QuantizedCache): | ||
| cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) | ||
| cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value) | ||
| cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) | ||
| cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device) | ||
| cache._seen_tokens = keys.shape[2] | ||
| else: | ||
| cache.key_cache[module.layer_idx] = keys | ||
| cache.value_cache[module.layer_idx] = values | ||
| return output | ||
|
|
||
| @contextmanager | ||
| def __call__(self, model: PreTrainedModel): | ||
| hooks = [] | ||
| try: | ||
| for layer in model.model.layers: | ||
| layer.self_attn.rotary_emb = model.model.rotary_emb | ||
| hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) | ||
| original_forward = model.forward | ||
|
|
||
| def chunked_forward(*args, **kwargs): | ||
| args = list(args) | ||
| kwargs["input_ids"] = kwargs.get("input_ids", args.pop(0) if args else None) | ||
| kwargs["attention_mask"] = kwargs.get("attention_mask", args.pop(0) if args else None) | ||
| args = tuple(args) | ||
|
|
||
| input_ids = kwargs["input_ids"] | ||
| attention_mask = kwargs.get("attention_mask") | ||
|
|
||
| # Split input_ids into context and question tokens. | ||
| context_ids = input_ids[:, : -self.condition_len] | ||
| question_ids = input_ids[:, -self.condition_len :] | ||
|
|
||
| if attention_mask is not None: | ||
| context_attention_mask = attention_mask[:, : -self.condition_len] | ||
| question_attention_mask = attention_mask[:, -self.condition_len :] | ||
|
|
||
| # Calculate the total number of context tokens. | ||
| context_length = context_ids.shape[1] | ||
| kwargs["context_length"] = context_length | ||
|
|
||
| # Determine the chunk size so that we split context_ids into exactly split_size chunks. | ||
| chunk_size = context_length // self.split_size | ||
| last_output = None | ||
|
|
||
| for i in range(self.split_size): | ||
| kwargs["split_idx"] = i | ||
| start = i * chunk_size | ||
| # For the last chunk, include any remaining tokens. | ||
| end = start + chunk_size if i < self.split_size - 1 else context_length | ||
|
|
||
| # Get the current chunk from context_ids and combine with the question tokens. | ||
| context_chunk = context_ids[:, start:end] | ||
| kwargs["input_ids"] = torch.cat([context_chunk, question_ids], dim=1) | ||
|
|
||
| if attention_mask is not None: | ||
| context_attention_mask_chunk = context_attention_mask[:, start:end] | ||
| kwargs["attention_mask"] = torch.cat( | ||
| [context_attention_mask_chunk, question_attention_mask], dim=1 | ||
| ) | ||
|
|
||
| last_output = original_forward(use_cache=True, *args, **kwargs) | ||
|
|
||
| # Only adjust the past key/values caches if it's not the last iteration. | ||
| if i < self.split_size - 1: | ||
| for layer_idx, _ in enumerate(model.model.layers): | ||
| # Adjust the past key/values caches to remove the question tokens for the next iteration | ||
| last_output.past_key_values.key_cache[layer_idx] = ( | ||
| last_output.past_key_values | ||
| .key_cache[layer_idx][:, :, : -self.condition_len, :] | ||
| .contiguous() | ||
| ) | ||
| last_output.past_key_values.value_cache[layer_idx] = ( | ||
| last_output.past_key_values | ||
| .value_cache[layer_idx][:, :, : -self.condition_len, :] | ||
| .contiguous() | ||
| ) | ||
| last_output.past_key_values._seen_tokens = last_output.past_key_values.get_seq_length() | ||
| kwargs["past_key_values"] = last_output.past_key_values | ||
|
|
||
| return last_output | ||
|
|
||
| # Override the model's forward with the chunked version | ||
| model.forward = chunked_forward | ||
| yield | ||
| finally: | ||
| # Remove all hooks and restore the original forward method | ||
| for hook in hooks: | ||
| hook.remove() | ||
| model.forward = original_forward | ||
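For reference, a minimal usage sketch of the press as implemented in this diff (not part of the PR itself). It assumes a Llama-style checkpoint, that `FinchPress` is importable from `kvpress` like the other presses, and that `condition_len` is set to the number of question tokens appended at the end of the input; the model name and the placeholder strings are illustrative only.

```python
# Usage sketch only; import path, checkpoint and placeholder strings are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from kvpress import FinchPress  # import path assumed, adjust to wherever this file lands

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder Llama-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

context = "..."   # long document whose KV cache should be compressed
question = "..."  # query appended after the context

# condition_len must equal the number of question tokens at the end of the input.
question_ids = tokenizer(question, return_tensors="pt", add_special_tokens=False).input_ids
input_ids = tokenizer(context + question, return_tensors="pt").input_ids.to(model.device)

press = FinchPress(compression_ratio=0.5, split_size=4, condition_len=question_ids.shape[1])

with press(model):
    # model.forward is temporarily replaced by chunked_forward: the context is split into
    # split_size chunks, the question is appended to each chunk, and the forward hooks
    # compress the KV cache after every chunk.
    output = model(input_ids)

# output.past_key_values now holds the compressed cache for (context + question)
# and can be reused for decoding.
past_key_values = output.past_key_values
```

Note that `use_cache=True` is injected by `chunked_forward` itself, so it should not be passed again when calling the model.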
Not correct; for instance, `last_iteration` should be …
The condition_len tokens shouldn't be compressed, should they? Because in other presses the question is provided as is, so if we also compress the question the comparison will not be fair.
You are using the question in the input, so it should be compressed too, as is done in other presses. Finch can use the information of the window size (which is a bit unfair) to exclude the question from the compression, but it should translate to a slightly higher compression for the context.
Note that the difference is very small (~0.5% in CR), so I don't think it should impact performance much.
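(For illustration only, with made-up numbers: for a 10,000-token context, a 100-token question and compression_ratio=0.5, exempting the question keeps 5,000 + 100 of 10,100 tokens, i.e. an effective compression ratio of roughly 49.5% instead of 50%.)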
Query-aware compression like Finch or SnapKV and question-agnostic compression come with a trade-off. Query-aware compression gives the highest performance, but the compression has to be rerun for each new query; question-agnostic presses can instead compress independently of the question, giving more throughput. This is a topic we explored in our recent paper (https://arxiv.org/pdf/2503.04973), where we propose a middle ground and compress using just the task description and a few-shot examples. That being said, the final token budget should be equal for all approaches for a fair comparison, IMO. What do you think?
Approaches like SnapKV or Finch are close to sparse attention (i.e. retrieving the right KV pairs from the full KV cache) and somewhat different from compression, because as you mention you have to re-run them for each query.

I agree, that's why I proposed this update. Could you review #69 and comment? You can also open a new PR with the same code if you want to appear in the contributors of the repo.
Would it be possible to add `Co-authored-by` in your commits, adding myself, @miriam-16 and @FaureElia?
Done, could you please close this branch?