@@ -3,37 +3,110 @@ import MLX
33import MLXLMCommon
44import NovaMLXCore
55
6+ /// Composes up to four LogitProcessors into a single chain.
7+ ///
8+ /// **Process order:** `penalty → grammar → turnStop → thinkingBudget`
9+ ///
10+ /// 1. **penalty** modifies logits in place (repetition/freq/presence) — it
11+ /// never masks, so its result is a strictly-monotonic perturbation of the
12+ /// input distribution.
13+ /// 2. **grammar** (JSON / Schema / GBNF / Regex) masks logits to its allowed
14+ /// vocabulary subset. May force EOS via mask.
15+ /// 3. **turnStop** masks to EOS-only when a turn separator string is observed
16+ /// in the decoded output (works regardless of whether the model emits its
17+ /// configured EOS token id).
18+ /// 4. **thinkingBudget** masks to close-marker-only when the configured
19+ /// thinking budget is exhausted and the model has not yet emitted a
20+ /// close-marker token. Placed LAST so its mask overrides any prior
21+ /// constraint — when the budget fires, "emit close marker now" wins
22+ /// unconditionally. In production we ensure budget is only built when
23+ /// grammar is absent, so the override is benign.
24+ ///
25+ /// All four slots are optional. Caller must ensure at least one is non-nil
26+ /// (otherwise compose() returns nil — see the convenience builder below).
627public final class ComposedLogitProcessor : LogitProcessor , @unchecked Sendable {
7- private var grammarProcessor : any LogitProcessor
28+ private var grammarProcessor : ( any LogitProcessor ) ?
829 private var penaltyProcessor : ( any LogitProcessor ) ?
930 private var turnStopProcessor : ( any LogitProcessor ) ?
31+ private var thinkingBudgetProcessor : ( any LogitProcessor ) ?
1032
1133 public init (
1234 grammarProcessor: any LogitProcessor ,
1335 penaltyProcessor: ( any LogitProcessor ) ? ,
14- turnStopProcessor: ( any LogitProcessor ) ? = nil
36+ turnStopProcessor: ( any LogitProcessor ) ? = nil ,
37+ thinkingBudgetProcessor: ( any LogitProcessor ) ? = nil
1538 ) {
1639 self . grammarProcessor = grammarProcessor
1740 self . penaltyProcessor = penaltyProcessor
1841 self . turnStopProcessor = turnStopProcessor
42+ self . thinkingBudgetProcessor = thinkingBudgetProcessor
43+ }
44+
45+ /// Designated initializer accepting an optional grammar slot. Used when
46+ /// the chain consists of `penalty + turnStop + thinkingBudget` only
47+ /// (no grammar / JSON / schema constraint).
48+ public init (
49+ grammarProcessor: ( any LogitProcessor ) ? ,
50+ penaltyProcessor: ( any LogitProcessor ) ? ,
51+ turnStopProcessor: ( any LogitProcessor ) ? ,
52+ thinkingBudgetProcessor: ( any LogitProcessor ) ? ,
53+ _ allowAllNil: Bool = false
54+ ) {
55+ precondition (
56+ allowAllNil
57+ || grammarProcessor != nil
58+ || penaltyProcessor != nil
59+ || turnStopProcessor != nil
60+ || thinkingBudgetProcessor != nil ,
61+ " ComposedLogitProcessor: at least one slot must be non-nil "
62+ )
63+ self . grammarProcessor = grammarProcessor
64+ self . penaltyProcessor = penaltyProcessor
65+ self . turnStopProcessor = turnStopProcessor
66+ self . thinkingBudgetProcessor = thinkingBudgetProcessor
67+ }
68+
69+ /// Convenience builder — flattens trivially-empty chains.
70+ /// - If exactly one slot is non-nil, returns that slot directly (avoids
71+ /// the per-call allocation overhead of the composed wrapper).
72+ /// - If all slots are nil, returns nil.
73+ /// - Otherwise returns a `ComposedLogitProcessor` with all non-nil slots.
74+ public static func compose(
75+ grammar: ( any LogitProcessor ) ? = nil ,
76+ penalty: ( any LogitProcessor ) ? = nil ,
77+ turnStop: ( any LogitProcessor ) ? = nil ,
78+ thinkingBudget: ( any LogitProcessor ) ? = nil
79+ ) -> ( any LogitProcessor ) ? {
80+ let nonNil = [ grammar, penalty, turnStop, thinkingBudget] . compactMap { $0 }
81+ if nonNil. isEmpty { return nil }
82+ if nonNil. count == 1 { return nonNil [ 0 ] }
83+ return ComposedLogitProcessor (
84+ grammarProcessor: grammar,
85+ penaltyProcessor: penalty,
86+ turnStopProcessor: turnStop,
87+ thinkingBudgetProcessor: thinkingBudget
88+ )
1989 }
2090
2191 public func prompt( _ prompt: MLXArray ) {
22- grammarProcessor. prompt ( prompt)
92+ grammarProcessor? . prompt ( prompt)
2393 penaltyProcessor? . prompt ( prompt)
2494 turnStopProcessor? . prompt ( prompt)
95+ thinkingBudgetProcessor? . prompt ( prompt)
2596 }
2697
2798 public func process( logits: MLXArray ) -> MLXArray {
2899 var result = penaltyProcessor? . process ( logits: logits) ?? logits
29- result = grammarProcessor. process ( logits: result)
100+ result = grammarProcessor? . process ( logits: result) ?? result
30101 result = turnStopProcessor? . process ( logits: result) ?? result
102+ result = thinkingBudgetProcessor? . process ( logits: result) ?? result
31103 return result
32104 }
33105
34106 public func didSample( token: MLXArray ) {
35- grammarProcessor. didSample ( token: token)
107+ grammarProcessor? . didSample ( token: token)
36108 penaltyProcessor? . didSample ( token: token)
37109 turnStopProcessor? . didSample ( token: token)
110+ thinkingBudgetProcessor? . didSample ( token: token)
38111 }
39112}
0 commit comments