
generate.py: GenerationBatch.filter — add else branches so logits_processors / samplers length stays in lockstep with uids #1225

Open

mloiterman wants to merge 1 commit into ml-explore:main from mloiterman:fix/generation-batch-filter-stale-length

Conversation

@mloiterman

Summary

`GenerationBatch.filter` (around line 1392 of `mlx_lm/generate.py`)
uses `if any(self.samplers)` and `if any(self.logits_processors)` to
guard whether it trims those per-sequence lists. When every current
slot is `None` / `[]` (the common shape for any in-flight request
that does not attach a custom sampler or logits processor),
`any(...)` is `False` and the trim is skipped. The lists keep their
stale length while `self.uids` is correctly trimmed.
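The misfire is plain Python truthiness; a minimal sketch, independent of mlx-lm:

```python
# any() over all-falsy per-slot sentinels is False, so the trim is skipped.
samplers = [None, None]        # two in-flight requests, no custom samplers
logits_processors = [[], []]   # ...and no logits processors either

print(any(samplers))           # False -> `if any(self.samplers)` skips the trim
print(any(logits_processors))  # False -> same guard, same skip

# One non-empty slot anywhere flips the guard and the trim runs.
print(any([[], [lambda toks, logits: logits]]))  # True
```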

When new sequences subsequently arrive via `extend`, their entries
are appended to the over-long list. `_step`'s per-sequence loop
(`for e in range(len(self.uids)):` / `for processor in self.logits_processors[e]:`)
then reads stale-index slots — silently bypassing the new sequence's
processor for exactly one generation step before the next `filter`
corrects the length.

The structurally symmetric `PromptProcessingBatch.filter` (around
line 1117) already handles this case with explicit `else` branches
that reset to `[None] * len(keep)` / `[[]] * len(keep)`. This PR
mirrors those branches onto `GenerationBatch.filter`.

Impact (downstream)

Discovered while running schema-constrained generation
(`response_format=json_schema` with `strict: true`) in a
production-shape FastAPI server. After any tool-bearing chat
completion (which carries no `logits_processors`), the very next
request's xgrammar-based grammar processor was silently bypassed —
the model returned plain unconstrained text instead of schema-valid
output. The bug is silent (no exception, no log, no error code) and
affects exactly one request before self-recovering, which made it
ordering-dependent and hard to spot.

Diff

--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -1391,8 +1391,12 @@ class GenerationBatch:
         self.tokens = [self.tokens[idx] for idx in keep]
         if any(self.samplers):
             self.samplers = [self.samplers[idx] for idx in keep]
+        else:
+            self.samplers = [None] * len(keep)
         if any(self.logits_processors):
             self.logits_processors = [self.logits_processors[idx] for idx in keep]
+        else:
+            self.logits_processors = [[]] * len(keep)
         self.max_tokens = [self.max_tokens[idx] for idx in keep]
         self.state_machines = [self.state_machines[idx] for idx in keep]

The `[[]]` (logits_processors) and `[None]` (samplers) defaults are
copied verbatim from `PromptProcessingBatch.filter` so the two
methods stay symmetric.
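One reviewer-facing note on those defaults: `[[]] * n` creates n references to the same empty list. That is harmless as long as the slots are only ever replaced wholesale (as `filter` and `extend` do), but would leak if any code mutated a slot in place. A minimal sketch of the Python semantics, not of upstream behavior:

```python
# [[]] * n aliases one empty list into every slot.
slots = [[]] * 3
slots[0].append("proc")   # in-place mutation leaks into all three slots
print(slots)              # [['proc'], ['proc'], ['proc']]

# A comprehension gives independent slots, at the cost of symmetry
# with the existing PromptProcessingBatch.filter code.
safe = [[] for _ in range(3)]
safe[0].append("proc")
print(safe)               # [['proc'], [], []]
```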

Standalone reproducer (no model required)

"""Standalone reproducer for the GenerationBatch.filter stale-length bug.

Uses object.__new__(GenerationBatch) to bypass the model-requiring
__init__ — runs in milliseconds, no model load, no GPU.

Run:
    python upstream_repro.py

Exit code 0 = bug present (current upstream main).
Exit code 1 = bug fixed.
"""
from __future__ import annotations
from mlx_lm.generate import GenerationBatch


def make_bare_batch() -> GenerationBatch:
    b = object.__new__(GenerationBatch)
    b.model = None
    b.uids = []
    b.prompt_cache = []
    b.tokens = []
    b.samplers = []
    b.fallback_sampler = None
    b.logits_processors = []
    b.state_machines = []
    b.max_tokens = []
    b._next_tokens = None
    b._next_logprobs = []
    b._token_context = []
    b._num_tokens = []
    b._matcher_states = []
    return b


def populate_no_processor_request(b: GenerationBatch, uid: int) -> None:
    """A 'tool-style' request: no logits processors, no custom sampler.
    This is the common shape for any request that does not ask for
    grammar / penalty / thinking-budget logits modification."""
    b.uids.append(uid)
    b.prompt_cache.append(None)
    b.tokens.append([1, 2, 3])
    b.samplers.append(None)
    b.logits_processors.append([])
    b.state_machines.append(None)
    b.max_tokens.append(16)
    b._token_context.append(None)
    b._num_tokens.append(0)
    b._matcher_states.append(None)
    b._next_logprobs.append(None)


def main() -> int:
    b = make_bare_batch()
    populate_no_processor_request(b, uid=42)

    assert len(b.uids) == 1
    assert b.logits_processors == [[]]
    assert b.samplers == [None]

    # Simulate the request finishing — keep=[] means drop everything.
    b.filter(keep=[])

    print("after filter(keep=[]):")
    print(f"  uids               = {b.uids} (len={len(b.uids)})")
    print(f"  logits_processors  = {b.logits_processors} (len={len(b.logits_processors)})")
    print(f"  samplers           = {b.samplers} (len={len(b.samplers)})")

    bug = (
        len(b.logits_processors) != len(b.uids)
        or len(b.samplers) != len(b.uids)
    )
    if bug:
        print("\nBUG REPRODUCED — list length(s) do NOT match len(uids).")
        return 0
    print("\nFIX VERIFIED — all per-sequence lists have len == len(uids).")
    return 1


if __name__ == "__main__":
    raise SystemExit(main())

Verified output (against mlx-lm 0.31.3 and main at the time of
this PR):

after filter(keep=[]):
  uids               = [] (len=0)
  logits_processors  = [[]] (len=1)
  samplers           = [None] (len=1)

BUG REPRODUCED — list length(s) do NOT match len(uids).

With the diff applied:

after filter(keep=[]):
  uids               = [] (len=0)
  logits_processors  = [] (len=0)
  samplers           = [] (len=0)

FIX VERIFIED — all per-sequence lists have len == len(uids).

Suggested unit test

def test_generation_batch_filter_clears_logits_processors_when_all_empty():
    """GenerationBatch.filter must keep self.logits_processors length
    in lockstep with self.uids even when every per-sequence slot is
    empty. Regression: prior versions guarded the trim with
    `if any(self.logits_processors)` and skipped on all-`[]`, leaving
    a stale-length list that would later mis-index `_step`."""
    from mlx_lm.generate import GenerationBatch

    b = object.__new__(GenerationBatch)
    b.uids = [42]
    b.prompt_cache = []
    b.tokens = [[1, 2, 3]]
    b.samplers = [None]
    b.fallback_sampler = None
    b.logits_processors = [[]]
    b.state_machines = [None]
    b.max_tokens = [16]
    b._next_tokens = None
    b._next_logprobs = [None]
    b._token_context = [None]
    b._num_tokens = [0]
    b._matcher_states = [None]

    b.filter(keep=[])

    assert len(b.uids) == 0
    assert len(b.logits_processors) == 0
    assert len(b.samplers) == 0

Related (different bugs)

The filter-list-length bug fixed here appears unreported.

@mloiterman (Author)

Cross-reference: #1230 lands the symmetric fix in `PromptProcessingBatch.extend`, where `[None] * N` for absent per-slot logits_processors produces a list shaped `[None, ..., [fn], ...]` after merging with a processor-equipped batch. That shape later crashes `GenerationBatch._step` at line 1346 with `TypeError: 'NoneType' object is not iterable`.

The two PRs touch different code paths (`filter` here, `extend` there) but share the same per-slot sentinel argument:

  • `samplers[e]` — type `Optional[Callable]`. Consumed at line 1358 as `self.samplers[e] or self.fallback_sampler`, so `None` is the correct sentinel.
  • `logits_processors[e]` — type `List[Callable]`. Consumed at line 1346 as `for processor in self.logits_processors[e]`, so `[]` is the correct sentinel (`None` crashes the iterator).

Both #1225 and #1230 use this asymmetric pair (`None` vs `[]`) and match the existing `PromptProcessingBatch.filter` pattern at line 1120. Safe to land independently or together.
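Both consumption sites reduce to ordinary Python semantics; a minimal sketch (the argmax `fallback_sampler` here is a stand-in, not the upstream implementation):

```python
# How the two sentinels are consumed (paraphrased, not upstream code).
fallback_sampler = lambda logits: max(range(len(logits)), key=logits.__getitem__)

sampler = None                        # sentinel for "no custom sampler"
chosen = sampler or fallback_sampler  # `or` falls through None to the fallback

processors = []                       # sentinel for "no logits processors"
logits = [0.1, 0.9]
for processor in processors:          # iterates zero times -- no TypeError
    logits = processor(logits)

print(chosen(logits))                 # 1 (argmax index via the fallback)
```

Swapping the sentinels (`[]` for samplers, `None` for processors) would break both sites: `[] or fallback` still selects the fallback, but `for processor in None` raises the `TypeError` that #1230 fixes.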
