Skip to content

Commit 8dd4622

Browse files
committed
Minor updates to DMS and KVzap
Signed-off-by: SimJeg <sjegou@nvidia.com>
1 parent 8b3c2f7 commit 8dd4622

File tree

5 files changed

+14
-9
lines changed

5 files changed

+14
-9
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ Finally we provide wrapper presses that can be combined with other presses:
150150
- `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
151151
- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allows for compression during decoding, see decoding section in this README.
152152
- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allows to compress both during prefilling and during decoding.
153-
- `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evict keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Support both prefilling and decoding (if decoding=True).
153+
- `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evicts keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Supports both prefilling and decoding (if decoding=True), but only supports dense-prefill and not sparse-prefill.
154154

155155
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
156156

kvpress/presses/dms_press.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class DMSPress(BasePress):
1717
"""
1818
Based on Dynamic Memory Sparsification (DMS, https://arxiv.org/abs/2506.05345) inference.
1919
Wraps a ScorerPress and evicts keys/values with scores below a given threshold.
20+
This press implements a dense-prefill version of DMS, not the sparse-prefill version.
2021
2122
Unlike most presses that use a fixed compression_ratio, DMSPress uses a score threshold
2223
to determine which KV pairs to evict. This allows for adaptive compression where the actual

kvpress/presses/kvzap_press.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class KVzapModel(PreTrainedModel):
2424

2525
def __init__(self, config):
2626
super().__init__(config)
27+
self.all_tied_weights_keys = {}
2728
if config.hidden_dim is None:
2829
# Linear model
2930
self.layers = nn.ModuleList(
@@ -72,8 +73,7 @@ def score(
7273
attentions: torch.Tensor,
7374
kwargs: dict,
7475
) -> torch.Tensor:
75-
module = self.kvzap_model.layers[module.layer_idx]
76-
module = module.to(hidden_states.device, dtype=hidden_states.dtype).eval()
77-
with torch.no_grad():
78-
scores = module(hidden_states).transpose(1, 2)
76+
kvzap_module = self.kvzap_model.layers[module.layer_idx]
77+
kvzap_module = kvzap_module.to(hidden_states.device, dtype=hidden_states.dtype).eval()
78+
scores = kvzap_module(hidden_states).transpose(1, 2)
7979
return scores

kvzap/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def _forward_hook(self, module, input, kwargs, output):
201201
scale = scale.repeat_interleave(module.o_proj.block_size[0], dim=0)
202202
scale = scale.repeat_interleave(module.o_proj.block_size[1], dim=1)
203203
Wo = Wo.to(V.dtype) * scale
204-
Wo = Wo.view(module.config.num_attention_heads, module.head_dim, module.config.hidden_size)
204+
Wo = Wo.view(module.config.num_attention_heads, V.shape[-1], module.config.hidden_size)
205205
WoV_norm = torch.einsum("h i j, b h t i -> b h t j", Wo.to(dtype=V.dtype), V).norm(dim=-1)
206206
scores = torch.einsum("b h t i, b h i -> b h t i", scores, WoV_norm)
207207

kvzap/train.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,21 @@ def train_linear(X: torch.Tensor, y: torch.Tensor) -> KVzapModel:
106106
# Train a linear model for each layer
107107
params = []
108108
for layer_idx in tqdm(range(X.shape[1]), desc="Training linear models"):
109+
X_train = X[:, layer_idx].clone().to(torch.float32).numpy()
110+
y_train = y[:, layer_idx].clone().to(torch.float32).numpy()
109111
linear = Ridge()
110-
linear.fit(X[:, layer_idx].float(), y[:, layer_idx].float())
112+
linear.fit(X_train, y_train)
111113
params.append((linear.coef_, linear.intercept_))
112114

113115
# Load the parameters into a KVzapModel
114116
linear_model = KVzapModel(
115117
KVzapConfig(input_dim=X.shape[2], hidden_dim=None, output_dim=y.shape[2], n_modules=X.shape[1])
116118
)
117119
for layer_idx, (W, b) in enumerate(params):
118-
linear_model.layers[layer_idx].weight.data = torch.tensor(W, dtype=X.dtype) # type: ignore[index]
119-
linear_model.layers[layer_idx].bias.data = torch.tensor(b, dtype=X.dtype) # type: ignore[index]
120+
W = torch.tensor(np.atleast_2d(W), dtype=X.dtype)
121+
b = torch.tensor(np.atleast_1d(b), dtype=X.dtype)
122+
linear_model.layers[layer_idx].weight.data = W # type: ignore[index]
123+
linear_model.layers[layer_idx].bias.data = b # type: ignore[index]
120124
return linear_model
121125

122126

0 commit comments

Comments
 (0)