@@ -91,6 +91,10 @@ struct ActiveStreamSequence: @unchecked Sendable {
9191 /// finish — `isFinished` is advisory for fast reads, `finishGuard` is the
9292 /// truth for continuation lifecycle and budget release ownership.
9393 let finishGuard : FinishGuard
94+ /// True only when *this* decode step was the one that successfully
95+ /// called `safeFinish()` — prevents double budget release when
96+ /// preempt/abort already released the sequence.
97+ var _finishedByDecodeStep : Bool = false
9498 /// Recent token IDs for N-gram speculation context.
9599 var recentTokenIds : [ Int ]
96100 /// Frequency penalty value — prevents repetition collapse in small quantized models.
@@ -265,22 +269,13 @@ public final class FusedBatchScheduler: @unchecked Sendable {
265269
266270 public func submit( _ request: InferenceRequest ) async throws -> InferenceResult {
267271 let modelId = request. model
268- let bytesPerToken = engine. effectiveBytesPerToken ( modelId: modelId)
269- let estimatedTokens = engine. estimateRequestTokens ( modelId: modelId, request: request)
270272
271- // Quick admission check — if memory is tight, fall back to engine (no fused optimization)
272- guard await canAdmit ( request) else {
273+ // Atomic admission check + slot reserve — if memory/concurrency tight, fall back to engine
274+ guard await canAdmitAndReserve ( request) else {
273275 NovaMLXLog . info ( " FusedScheduler: submit() can't admit — falling back to engine " )
274276 return try await engine. generate ( request)
275277 }
276278
277- // Reserve budget
278- lock. withLock { activeModelCounts [ modelId] = ( activeModelCounts [ modelId] ?? 0 ) + 1 }
279- await budgetTracker. reserve (
280- modelId: modelId, sequenceId: request. id,
281- weightsBytes: 0 , estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
282- )
283-
284279 let startTime = Date ( )
285280 let promptTokensBox = MutableSendableBox < Int > ( 0 )
286281
@@ -304,6 +299,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
304299 await self . budgetTracker. release ( sequenceId: request. id)
305300 }
306301 } catch {
302+ // SAFE: pre-prefill error path, continuation is exclusively owned
307303 continuation. finish ( throwing: error)
308304 self . lock. withLock {
309305 self . activeModelCounts [ modelId] = max ( 0 , ( self . activeModelCounts [ modelId] ?? 1 ) - 1 )
@@ -349,8 +345,8 @@ public final class FusedBatchScheduler: @unchecked Sendable {
349345 return AsyncThrowingStream { continuation in
350346 Task {
351347 do {
352- // Check admission
353- let canStart = await self . canAdmit ( request)
348+ // Atomic admission check + slot reserve
349+ let canStart = await self . canAdmitAndReserve ( request)
354350 if !canStart {
355351 // Queue for later admission
356352 NovaMLXLog . info ( " FusedScheduler: queuing stream \( reqTag) — memory/concurrency limit " )
@@ -364,15 +360,6 @@ public final class FusedBatchScheduler: @unchecked Sendable {
364360 return
365361 }
366362
367- // Reserve budget
368- let bytesPerToken = self . engine. effectiveBytesPerToken ( modelId: modelId)
369- let estimatedTokens = self . engine. estimateRequestTokens ( modelId: modelId, request: request)
370- self . lock. withLock { self . activeModelCounts [ modelId] = ( self . activeModelCounts [ modelId] ?? 0 ) + 1 }
371- await self . budgetTracker. reserve (
372- modelId: modelId, sequenceId: request. id,
373- weightsBytes: 0 , estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
374- )
375-
376363 // Prefill and add to active sequences
377364 let seq = try await self . prefillSequence ( request: request, continuation: continuation)
378365 if !seq. isFinished {
@@ -390,6 +377,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
390377 await self . budgetTracker. release ( sequenceId: request. id)
391378 }
392379 } catch {
380+ // SAFE: pre-prefill error path, continuation is exclusively owned
393381 continuation. finish ( throwing: error)
394382 self . lock. withLock {
395383 self . activeModelCounts [ modelId] = max ( 0 , ( self . activeModelCounts [ modelId] ?? 1 ) - 1 )
@@ -415,6 +403,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
415403 }
416404 self . scheduleRunLoop ( )
417405 } catch {
406+ // SAFE: pre-prefill error path, continuation is exclusively owned
418407 continuation. finish ( throwing: error)
419408 self . lock. withLock {
420409 self . activeModelCounts [ modelId] = max ( 0 , ( self . activeModelCounts [ modelId] ?? 1 ) - 1 )
@@ -450,6 +439,47 @@ public final class FusedBatchScheduler: @unchecked Sendable {
450439 return canAdmitMemory
451440 }
452441
/// Decides whether `request` may start now and, if so, reserves what it needs.
///
/// The concurrency slot is claimed first: the limit comparison and the
/// `activeModelCounts` increment run inside a single `lock.withLock` closure.
/// The memory budget is checked afterwards; when that check fails, the slot
/// claimed above is handed back before returning.
///
/// NOTE(review): the memory check (`budgetTracker.canAdmit`) and the
/// reservation (`budgetTracker.reserve`) are two separate awaited calls —
/// confirm whether another caller can reserve in between them.
///
/// - Parameter request: The inference request asking for admission.
/// - Returns: `false` if the model is at its concurrency limit or the memory
///   budget check fails (the slot is not left counted in either case);
///   `true` once both the slot and the memory budget have been reserved.
///   On `true`, the caller MUST proceed — the slot is already counted.
private func canAdmitAndReserve(_ request: InferenceRequest) async -> Bool {
    let modelId = request.model
    let slotLimit = await optimalConcurrency(for: modelId)

    // Compare against the limit and bump the counter under one lock acquisition.
    let slotTaken = lock.withLock { () -> Bool in
        let inFlight = activeModelCounts[modelId] ?? 0
        if inFlight >= slotLimit {
            return false
        }
        activeModelCounts[modelId] = inFlight + 1
        return true
    }
    if !slotTaken {
        NovaMLXLog.info(" FusedScheduler: queuing \(request.id.uuidString.prefix(8)) — model at concurrency limit ")
        return false
    }

    // The slot is held from here on; now see whether the memory budget allows it.
    let perTokenBytes = engine.effectiveBytesPerToken(modelId: modelId)
    let tokenEstimate = engine.estimateRequestTokens(modelId: modelId, request: request)
    let memoryFits = await budgetTracker.canAdmit(
        modelId: modelId, estimatedTokens: tokenEstimate, bytesPerToken: perTokenBytes
    )

    if !memoryFits {
        // Give back the slot claimed above; clamp at zero so the count never goes negative.
        lock.withLock {
            let counted = activeModelCounts[modelId] ?? 1
            activeModelCounts[modelId] = max(0, counted - 1)
        }
        NovaMLXLog.info(" FusedScheduler: queuing \(request.id.uuidString.prefix(8)) — insufficient memory ")
        return false
    }

    // Slot is already counted; record the memory reservation under the request's id.
    await budgetTracker.reserve(
        modelId: modelId, sequenceId: request.id,
        weightsBytes: 0, estimatedTokens: tokenEstimate, bytesPerToken: perTokenBytes
    )
    return true
}
482+
453483 // MARK: - Preemption
454484
455485 /// Preempt the newest active sequence for the given model to free memory.
@@ -484,7 +514,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
484514 }
485515
486516 // Finish the victim's stream with a retryable error
487- victim. continuation . finish ( throwing: NovaMLXError . inferenceFailed (
517+ victim. safeFinish ( throwing: NovaMLXError . inferenceFailed (
488518 " Sequence preempted due to memory pressure — please retry "
489519 ) )
490520
@@ -690,6 +720,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
690720 let isEOS = eosId == firstTokenId
691721
692722 // Yield first token to client (unless it's EOS and we finish immediately)
723+ // SAFE: pre-shared-state yields — seq hasn't entered activeByModel yet
693724 let scrubbedFirstToken = MLXEngine . scrubControlTokens ( firstTokenText)
694725 if !isEOS && !scrubbedFirstToken. isEmpty {
695726 continuation. yield ( Token ( id: 0 , text: scrubbedFirstToken) )
@@ -796,6 +827,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
796827 }
797828 NovaMLXLog . info ( " [Prefill: \( reqTag) ] First token (remaining): id= \( firstTokenId) , text=' \( firstTokenText) ' " )
798829
830+ // SAFE: pre-shared-state yield — seq hasn't entered activeByModel yet
799831 let scrubbedPrefixToken = MLXEngine . scrubControlTokens ( firstTokenText)
800832 if !scrubbedPrefixToken. isEmpty {
801833 continuation. yield ( Token ( id: 0 , text: scrubbedPrefixToken) )
@@ -903,7 +935,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
903935 )
904936
905937 for item in sortedGenerate {
906- let canStart = await canAdmit ( item. request)
938+ let canStart = await canAdmitAndReserve ( item. request)
907939 if canStart {
908940 admittedGenerate. append ( item)
909941 } else {
@@ -912,7 +944,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
912944 let preempted = await preemptNewest ( for: item. request. model, toMakeRoomFor: item. request. id)
913945 if preempted {
914946 // Re-check admission after preemption freed some memory
915- let retryCanAdmit = await canAdmit ( item. request)
947+ let retryCanAdmit = await canAdmitAndReserve ( item. request)
916948 if retryCanAdmit {
917949 admittedGenerate. append ( item)
918950 continue
@@ -926,18 +958,16 @@ public final class FusedBatchScheduler: @unchecked Sendable {
926958
927959 for item in admittedGenerate {
928960 let modelId = item. request. model
929- let bytesPerToken = engine. effectiveBytesPerToken ( modelId: modelId)
930- let estimatedTokens = engine. estimateRequestTokens ( modelId: modelId, request: item. request)
931-
932- lock. withLock { activeModelCounts [ modelId] = ( activeModelCounts [ modelId] ?? 0 ) + 1 }
933- await budgetTracker. reserve (
934- modelId: modelId, sequenceId: item. request. id,
935- weightsBytes: 0 , estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
936- )
937-
961+ // Budget already reserved atomically by canAdmitAndReserve()
938962 // Wrap generate as stream, collect result
939963 let stream = internalSubmitStream ( item. request)
940964 Task {
965+ var resumed = false
966+ defer {
967+ if !resumed {
968+ item. continuation. resume ( throwing: CancellationError ( ) )
969+ }
970+ }
941971 var text = " "
942972 var completionTokens = 0
943973 var finishReason : FinishReason = . stop
@@ -949,12 +979,14 @@ public final class FusedBatchScheduler: @unchecked Sendable {
949979 let promptTokens = lock. withLock {
950980 activeByModel [ modelId] ? . first? . promptTokenCount ?? 0
951981 }
982+ resumed = true
952983 item. continuation. resume ( returning: InferenceResult (
953984 id: item. request. id, model: modelId, text: text,
954985 tokensPerSecond: 0 , promptTokens: promptTokens,
955986 completionTokens: completionTokens, finishReason: finishReason
956987 ) )
957988 } catch {
989+ resumed = true
958990 item. continuation. resume ( throwing: error)
959991 }
960992 }
@@ -972,15 +1004,15 @@ public final class FusedBatchScheduler: @unchecked Sendable {
9721004 )
9731005
9741006 for item in sortedStream {
975- let canStart = await canAdmit ( item. request)
1007+ let canStart = await canAdmitAndReserve ( item. request)
9761008 if canStart {
9771009 admittedStream. append ( item)
9781010 } else {
9791011 // Try preemption for high-priority items
9801012 if item. priority == . high {
9811013 let preempted = await preemptNewest ( for: item. request. model, toMakeRoomFor: item. request. id)
9821014 if preempted {
983- let retryCanAdmit = await canAdmit ( item. request)
1015+ let retryCanAdmit = await canAdmitAndReserve ( item. request)
9841016 if retryCanAdmit {
9851017 admittedStream. append ( item)
9861018 continue
@@ -994,14 +1026,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
9941026
9951027 for item in admittedStream {
9961028 let modelId = item. request. model
997- let bytesPerToken = engine. effectiveBytesPerToken ( modelId: modelId)
998- let estimatedTokens = engine. estimateRequestTokens ( modelId: modelId, request: item. request)
999-
1000- lock. withLock { activeModelCounts [ modelId] = ( activeModelCounts [ modelId] ?? 0 ) + 1 }
1001- await budgetTracker. reserve (
1002- modelId: modelId, sequenceId: item. request. id,
1003- weightsBytes: 0 , estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
1004- )
1029+ // Budget already reserved atomically by canAdmitAndReserve()
10051030
10061031 do {
10071032 let seq = try await prefillSequence ( request: item. request, continuation: item. continuation)
@@ -1010,6 +1035,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
10101035 activeByModel [ modelId] ? . append ( seq)
10111036 }
10121037 } catch {
1038+ // SAFE: pre-prefill error path, continuation is exclusively owned
10131039 item. continuation. finish ( throwing: error)
10141040 lock. withLock {
10151041 activeModelCounts [ modelId] = max ( 0 , ( activeModelCounts [ modelId] ?? 1 ) - 1 )
@@ -1036,8 +1062,9 @@ public final class FusedBatchScheduler: @unchecked Sendable {
10361062 guard let container = engine. getContainer ( for: modelId) ,
10371063 let mlxContainer = container. mlxContainer else {
10381064 for seq in active {
1039- seq. continuation. finish ( throwing: NovaMLXError . modelNotFound ( modelId) )
1040- await budgetTracker. release ( sequenceId: seq. id)
1065+ if seq. safeFinish ( throwing: NovaMLXError . modelNotFound ( modelId) ) {
1066+ await budgetTracker. release ( sequenceId: seq. id)
1067+ }
10411068 }
10421069 lock. withLock { activeByModel [ modelId] = [ ] ; activeModelCounts [ modelId] = 0 }
10431070 continue
@@ -1285,7 +1312,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
12851312 }
12861313 seq. generatedText = seq. lastDecodedText
12871314 if !decoded. isEmpty {
1288- seq. continuation . yield ( Token ( id: 0 , text: decoded) )
1315+ seq. safeYield ( Token ( id: 0 , text: decoded) )
12891316 }
12901317
12911318 lock. withLock { totalTokensViaFused += 1 }
@@ -1337,13 +1364,14 @@ public final class FusedBatchScheduler: @unchecked Sendable {
13371364 // its fallback promotes the entire buffer to
13381365 // `content`, leaving `reasoning_content` empty.
13391366 if seq. harmonyInThinking {
1340- seq. continuation . yield ( Token ( id: 0 , text: " </think> " ) )
1367+ seq. safeYield ( Token ( id: 0 , text: " </think> " ) )
13411368 seq. harmonyInThinking = false
13421369 }
1343- seq. continuation . yield ( Token ( id: 0 , text: " " , finishReason: finishReason) )
1344- seq. continuation . finish ( )
1370+ seq. safeYield ( Token ( id: 0 , text: " " , finishReason: finishReason) )
1371+ let finishedByUs = seq. safeFinish ( )
13451372 seq. isFinished = true
13461373 sequenceFinished = true
1374+ seq. _finishedByDecodeStep = finishedByUs
13471375 break
13481376 }
13491377 }
@@ -1380,6 +1408,9 @@ public final class FusedBatchScheduler: @unchecked Sendable {
13801408 }
13811409
13821410 for seq in finished {
1411+ // Only release budget if WE finished the sequence in this decode step.
1412+ // Preempt/abort paths release budget themselves; skip to avoid double-release.
1413+ guard seq. _finishedByDecodeStep else { continue }
13831414 await budgetTracker. release ( sequenceId: seq. id)
13841415 lock. withLock {
13851416 activeModelCounts [ modelId] = max ( 0 , ( activeModelCounts [ modelId] ?? 1 ) - 1 )
@@ -1452,7 +1483,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
14521483 for (_, var sequences) in activeByModel {
14531484 if let idx = sequences. firstIndex ( where: { $0. id == requestId } ) {
14541485 sequences [ idx] . isFinished = true
1455- sequences [ idx] . continuation . finish ( )
1486+ sequences [ idx] . safeFinish ( )
14561487 sequences. remove ( at: idx)
14571488 }
14581489 }
0 commit comments