generate.py: GenerationBatch.filter - add else branches so logits_processors / samplers length stays in lockstep with uids #1225
Open
mloiterman wants to merge 1 commit into ml-explore:main from
…-sequence lists

`GenerationBatch.filter` (around line 1392) uses `if any(self.samplers)` and `if any(self.logits_processors)` to guard whether it trims those per-sequence lists. When every current slot is `None` / `[]` (the common shape for any in-flight request that does not attach a custom sampler or logits processor), `any(...)` is `False` and the trim is skipped. The lists keep stale length while `self.uids` is correctly trimmed.

When new sequences subsequently arrive via `extend`, their entries are appended to the over-long list. `_step`'s per-sequence loop (`for e in range(len(self.uids)): for processor in self.logits_processors[e]:`) then reads stale-index slots, silently bypassing the new sequence's processor for exactly one generation step before the next `filter` corrects the length.

The structurally-symmetric `PromptProcessingBatch.filter` (around line 1117) already handles this case with explicit `else` branches that reset to `[None] * len(keep)` / `[[]] * len(keep)`. This commit mirrors those branches onto `GenerationBatch.filter`.

Discovered while running schema-constrained generation (`response_format=json_schema`, `strict: true`) in a production-shape FastAPI server. After any tool-bearing chat completion (which carries no `logits_processors`), the very next request's grammar processor was silently bypassed: the model returned plain unconstrained text instead of schema-valid output. The bug was silent (no exception, no log) and affected exactly one request before self-recovering, which made it ordering-dependent and hard to spot.
Cross-reference: #1230 lands the symmetric fix in … The two PRs touch different code paths (…
Both #1225 and #1230 use this asymmetric pair (…
Summary
`GenerationBatch.filter` (around line 1392 of `mlx_lm/generate.py`) uses `if any(self.samplers)` and `if any(self.logits_processors)` to guard whether it trims those per-sequence lists. When every current slot is `None` / `[]` (the common shape for any in-flight request that does not attach a custom sampler or logits processor), `any(...)` is `False` and the trim is skipped. The lists keep stale length while `self.uids` is correctly trimmed.

When new sequences subsequently arrive via `extend`, their entries are appended to the over-long list. `_step`'s per-sequence loop (`for e in range(len(self.uids)): for processor in self.logits_processors[e]:`) then reads stale-index slots, silently bypassing the new sequence's processor for exactly one generation step before the next `filter` corrects the length.
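
To make the index drift concrete, here is a minimal sketch on a toy class. `ToyBatch`, `processors_seen_by_step`, and `grammar` are hypothetical names invented for illustration; only the guarded trim and the per-index read are taken from the description above, and the real `GenerationBatch` carries far more state.

```python
class ToyBatch:
    """Hypothetical stand-in for the batch described above; not mlx_lm code."""

    def __init__(self, uids, logits_processors):
        self.uids = list(uids)
        self.logits_processors = list(logits_processors)

    def filter(self, keep):
        # uids is always trimmed to the surviving indices.
        self.uids = [self.uids[k] for k in keep]
        # Guarded trim: any([[], []]) is False, so this branch is skipped
        # when no current sequence has a processor attached.
        if any(self.logits_processors):
            self.logits_processors = [self.logits_processors[k] for k in keep]

    def extend(self, uids, logits_processors):
        # New sequences are appended to whatever length the lists have now.
        self.uids.extend(uids)
        self.logits_processors.extend(logits_processors)

    def processors_seen_by_step(self):
        # Same indexing as the per-sequence loop: slot e is read for uids[e].
        return [self.logits_processors[e] for e in range(len(self.uids))]


def grammar(logits):
    """Hypothetical logits processor; identity is enough for the demo."""
    return logits


batch = ToyBatch(uids=[0, 1], logits_processors=[[], []])
batch.filter(keep=[1])           # sequence 0 finishes; processors list stays at length 2
batch.extend([2], [[grammar]])   # new request with a processor lands at index 2
seen = batch.processors_seen_by_step()

print(batch.uids)                     # [1, 2]
print(len(batch.logits_processors))   # 3
print(seen[1])                        # [] : uid 2 runs without its processor
```

The new sequence sits at index 1 of `uids`, but its processor sits at index 2 of `logits_processors`, so the loop reads an empty slot for it.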
The structurally-symmetric `PromptProcessingBatch.filter` (around line 1117) already handles this case with explicit `else` branches that reset to `[None] * len(keep)` / `[[]] * len(keep)`. This PR mirrors those branches onto `GenerationBatch.filter`.
Impact (downstream)
Discovered while running schema-constrained generation (`response_format=json_schema` with `strict: true`) in a production-shape FastAPI server. After any tool-bearing chat completion (which carries no `logits_processors`), the very next request's `xgrammar`-based grammar processor was silently bypassed: the model returned plain unconstrained text instead of schema-valid output. The bug is silent (no exception, no log, no error code) and affects exactly one request before self-recovering, which made it ordering-dependent and hard to spot.
Diff
The `[[]]` (`logits_processors`) and `[None]` (`samplers`) defaults are copied verbatim from `PromptProcessingBatch.filter` so the two methods stay symmetric.
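
As a rough illustration of the `if`/`else` shape, the following sketch applies it to plain lists. `filter_per_sequence` is a hypothetical free function, not the `mlx_lm` method, and the list-comprehension trim is an assumption about how surviving slots are selected; only the `[None] * len(keep)` and `[[]] * len(keep)` resets come from the description.

```python
def filter_per_sequence(samplers, logits_processors, keep):
    """Hypothetical helper showing the if/else shape; not the mlx_lm method."""
    if any(samplers):
        # At least one real sampler: keep the surviving slots.
        samplers = [samplers[k] for k in keep]
    else:
        # All slots are None: reset to the right length instead of skipping.
        samplers = [None] * len(keep)

    if any(logits_processors):
        logits_processors = [logits_processors[k] for k in keep]
    else:
        # All slots are empty lists: same reset, with the [] default.
        logits_processors = [[]] * len(keep)

    return samplers, logits_processors


# Two default-shaped sequences, one of which finishes.
samplers, processors = filter_per_sequence([None, None], [[], []], keep=[1])
print(len(samplers), len(processors))  # 1 1
```

With both branches present, the two lists always come out with `len(keep)` entries, which is the lockstep property the title refers to. Note that `[[]] * n` repeats one list object `n` times; that is harmless as long as slots are read or replaced rather than mutated in place.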
Standalone reproducer (no model required)
Verified output (against `mlx-lm 0.31.3` and `main` at the time of this PR):
With the diff applied:
Suggested unit test
Related (different bugs)
"mtp_generate_step: logits processors see stale prev_tokens on draft calls": different mechanism (token history not updated in the MTP path), distinct from the list-length bug fixed here.
"Stateful logits processors see stale tokens due to lazy evaluation in stream context": also a processor-staleness story, but in a different code path (lazy eval ordering in streaming generation).
The `filter`-list-length bug fixed here appears unreported.