Skip to content

Commit 423301b

Browse files
committed
fix(generate): avoid None entries in merged logits_processors
PromptProcessingBatch.extend filled missing per-slot logits_processors with [None] when either side lacked configured processors. Merging an unconfigured batch with a processor-equipped batch then produced a list shaped like [None, ..., [fn], ...]. GenerationBatch._step at line 1346 iterates self.logits_processors[e] under the any() guard at line 1337, which raises TypeError on the None slots. Fill with [[]] instead. Matches the existing pattern at line 1120 (filter() restoring [[]] * len(keep)) and the per-slot type List[Callable]. Reproduce: construct two PromptProcessingBatch instances, one without processors and one with, then call extend; the merged self.logits_processors contains None entries. New unit test covers this shape directly.
1 parent ed1fca4 commit 423301b

2 files changed

Lines changed: 31 additions & 2 deletions

File tree

mlx_lm/generate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,12 +1063,12 @@ def extend(self, batch):
10631063
if not any(self.samplers):
10641064
self.samplers = [None] * len(self.uids)
10651065
if not any(self.logits_processors):
1066-
self.logits_processors = [None] * len(self.uids)
1066+
self.logits_processors = [[]] * len(self.uids)
10671067
samplers = batch.samplers if any(batch.samplers) else [None] * len(batch.uids)
10681068
logits_processors = (
10691069
batch.logits_processors
10701070
if any(batch.logits_processors)
1071-
else [None] * len(batch.uids)
1071+
else [[]] * len(batch.uids)
10721072
)
10731073

10741074
self.uids.extend(batch.uids)

tests/test_generate.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from mlx_lm.generate import (
1010
BatchGenerator,
1111
GenerationResponse,
12+
PromptProcessingBatch,
1213
SequenceStateMachine,
1314
batch_generate,
1415
generate,
@@ -402,6 +403,34 @@ def test_batch_generate_with_logits_processors(self):
402403
self.assertEqual(responses[uid1].logprobs[1].item(), 0.0)
403404
self.assertEqual(responses[uid2].logprobs[2].item(), 0.0)
404405

406+
def test_prompt_processing_batch_extend_mixes_logits_processors(self):
407+
"""Test PromptProcessingBatch.extend produces a per-slot list with no None entries when merging an unconfigured batch with a processor-equipped batch."""
408+
fallback = lambda x: mx.argmax(x, axis=-1)
409+
a = PromptProcessingBatch.empty(self.model, fallback)
410+
a.uids = [0]
411+
a.tokens = [[]]
412+
a.samplers = []
413+
a.logits_processors = []
414+
a.max_tokens = [1]
415+
a.state_machines = [SequenceStateMachine()]
416+
a.prompt_cache = []
417+
418+
procs = make_logits_processors({0: 2000.0})
419+
b = PromptProcessingBatch.empty(self.model, fallback)
420+
b.uids = [1]
421+
b.tokens = [[]]
422+
b.samplers = []
423+
b.logits_processors = [procs]
424+
b.max_tokens = [1]
425+
b.state_machines = [SequenceStateMachine()]
426+
b.prompt_cache = []
427+
428+
a.extend(b)
429+
430+
self.assertEqual(len(a.logits_processors), 2)
431+
for entry in a.logits_processors:
432+
self.assertIsInstance(entry, list)
433+
405434
def test_batch_generate_processor_tokens_match_prompt_on_first_step(self):
406435
prompt = self.tokenizer.encode("hello")
407436
seen = []

0 commit comments

Comments
 (0)