trick to waste less compute and memory #91

@giulio98

Description

Bug

During the prefill phase in KVPressTextGenerationPipeline._forward, the code calls the full model (self.model(...)). This wastes memory and compute because the final lm_head computation is not needed at this stage.
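For intuition, here is a minimal mock of the usual Hugging Face causal-LM layout (illustrative only, not the real transformers API): the decoder stack is exposed as `model.model` and fills the KV-cache on its own, while the full `model(...)` call additionally projects hidden states through the lm_head to vocabulary logits.

```python
# Mock of a causal LM: `model.model` is the transformer backbone that fills
# the KV-cache; the full call also materializes logits. Names and behavior
# are illustrative assumptions, not the actual transformers implementation.

class Backbone:
    def __call__(self, input_ids, past_key_values):
        # Pretend each forward writes one cache entry per input token.
        past_key_values.extend(input_ids)
        # Returns hidden states only -- no vocabulary projection.
        return [f"hidden({t})" for t in input_ids]

class CausalLM:
    def __init__(self):
        self.model = Backbone()

    def __call__(self, input_ids, past_key_values):
        hidden = self.model(input_ids, past_key_values)
        # The full forward additionally computes logits from hidden states.
        return [f"logits({h})" for h in hidden]

lm = CausalLM()
cache = []
# Backbone-only call, as in the proposed fix: cache is filled, no logits.
lm.model(["tok0", "tok1"], cache)
assert cache == ["tok0", "tok1"]
```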

Proposed Fix

def _forward(...):
-        with press(self.model) if press is not None else contextlib.nullcontext():
-            # Full forward (includes lm_head) – not necessary for cache prefill
-            self.model(
-                input_ids=context_ids,
-                past_key_values=cache,
-                output_attentions=self.output_attentions(press),
-                num_logits_to_keep=1,
-            )
+        with press(self.model) if press is not None else contextlib.nullcontext():
+            # Run **only** the transformer backbone to build the KV-cache
+            self.model.model(
+                input_ids=context_ids,
+                past_key_values=cache,
+                output_attentions=self.output_attentions(press),
+            )

I tried this small fix and got faster compute and lower peak GPU memory usage.
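For a rough sense of scale, this is what a full [batch, seq, vocab] logits tensor would cost during prefill on a model that does not restrict the projection (e.g. when num_logits_to_keep is unavailable). The dimensions below are illustrative assumptions (Llama-3-style vocabulary, fp16), not measurements from this issue:

```python
# Back-of-the-envelope size of the logits tensor that a full forward could
# materialize during prefill. All numbers here are illustrative assumptions.

def logits_bytes(batch: int, seq_len: int, vocab_size: int,
                 dtype_bytes: int = 2) -> int:
    """Size in bytes of a [batch, seq_len, vocab_size] logits tensor."""
    return batch * seq_len * vocab_size * dtype_bytes

# 8k-token prompt, 128k vocab, fp16: about 2 GiB of logits that the
# backbone-only call never allocates.
full = logits_bytes(batch=1, seq_len=8192, vocab_size=128256)
# With num_logits_to_keep=1 only the last position is projected,
# so the memory cost is already tiny; the backbone-only call also
# skips the lm_head matmul itself.
last_only = logits_bytes(batch=1, seq_len=1, vocab_size=128256)
print(f"full: {full / 2**30:.1f} GiB, last-only: {last_only / 2**10:.0f} KiB")
```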


Labels

bug (Something isn't working)
