Bug
During the prefill phase in KVPressTextGenerationPipeline._forward, the code calls the full model (self.model(...)). This wastes memory and compute because at this stage the final lm_head computation is not needed.
Proposed Fix
def _forward(...):
- with press(self.model) if press is not None else contextlib.nullcontext():
- # Full forward (includes lm_head) – not necessary for cache prefill
- self.model(
- input_ids=context_ids,
- past_key_values=cache,
- output_attentions=self.output_attentions(press),
- num_logits_to_keep=1,
- )
+ with press(self.model) if press is not None else contextlib.nullcontext():
+ # Run **only** the transformer backbone to build the KV-cache
+ self.model.model(
+ input_ids=context_ids,
+ past_key_values=cache,
+ output_attentions=self.output_attentions(press),
+ )
I tried this small fix and observed faster compute and lower peak GPU memory usage.
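To illustrate why the fix helps, here is a minimal toy sketch (hypothetical classes, not the real kvpress or transformers code) of the structure the diff relies on: a CausalLM wraps a backbone at self.model, and calling the backbone directly builds the same KV cache while skipping the lm_head projection over the vocabulary.

```python
class Backbone:
    """Toy stand-in for the transformer backbone (model.model)."""

    def __call__(self, input_ids, past_key_values=None):
        # Pretend to append key/value states for each input token.
        cache = (past_key_values or []) + [("kv", t) for t in input_ids]
        return {"past_key_values": cache}


class CausalLM:
    """Toy stand-in for the full model (backbone + lm_head)."""

    def __init__(self):
        self.model = Backbone()
        self.lm_head_calls = 0  # counts the vocab-size projections

    def __call__(self, input_ids, past_key_values=None):
        out = self.model(input_ids, past_key_values)
        self.lm_head_calls += 1  # the wasted work during prefill
        return out


# Full forward: builds the cache AND runs lm_head (unneeded for prefill).
full = CausalLM()
out_full = full([1, 2, 3])

# Backbone only: same cache contents, no lm_head projection at all.
lean = CausalLM()
out_lean = lean.model([1, 2, 3])

assert out_full["past_key_values"] == out_lean["past_key_values"]
assert full.lm_head_calls == 1
assert lean.lm_head_calls == 0
```

In the real model the skipped projection is a hidden_size × vocab_size matmul over the last position (or the whole sequence, depending on num_logits_to_keep), which is why peak memory and prefill latency drop.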