Skip to content

Commit b279bd1

Browse files
committed
fix(generate): handle None entries in GenerationBatch logits_processors
self.logits_processors can be a mixed list of None and List[Callable] after batches with and without processors are merged via extend(). The any() guard at line 1337 returns True for [None, [fn]] but the inner loop assumes every element is iterable, raising TypeError on None entries. Reproduce: BatchGenerator with no constructor processors, insert one prompt with no per-element processors and another with logits_processors=[[fn]], then call next_generated(). Mirrors the existing 'samplers[e] or self.fallback_sampler' pattern at line 1358.
1 parent ed1fca4 commit b279bd1

2 files changed

Lines changed: 18 additions & 1 deletion

File tree

mlx_lm/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1343,7 +1343,7 @@ def _step(self) -> Tuple[List[int], List[mx.array]]:
         processed_logits = []
         for e in range(len(self.uids)):
             sample_logits = logits[e : e + 1]
-            for processor in self.logits_processors[e]:
+            for processor in self.logits_processors[e] or ():
                 sample_logits = processor(token_context[e], sample_logits)
             processed_logits.append(sample_logits)
         logits = mx.concatenate(processed_logits, axis=0)

tests/test_generate.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,23 @@ def test_batch_generate_with_logits_processors(self):
         self.assertEqual(responses[uid1].logprobs[1].item(), 0.0)
         self.assertEqual(responses[uid2].logprobs[2].item(), 0.0)

+    def test_batch_generate_mixed_processor_per_element(self):
+        # Regression: when one inserted prompt has logits_processors and another
+        # has none, the merged batch contains a mix of [None] and [fn] entries.
+        # _step previously crashed iterating self.logits_processors[e] when [e]
+        # was None.
+        prompt = self.tokenizer.encode("hello")
+
+        batch_gen = BatchGenerator(self.model, max_tokens=1)
+        batch_gen.insert([prompt])
+
+        logit_bias = {0: 2000.0}
+        processors = make_logits_processors(logit_bias)
+        batch_gen.insert([prompt], logits_processors=[processors])
+
+        responses = batch_gen.next_generated()
+        self.assertEqual(len(responses), 2)
+
     def test_batch_generate_processor_tokens_match_prompt_on_first_step(self):
         prompt = self.tokenizer.encode("hello")
         seen = []

0 commit comments

Comments
 (0)