Skip to content

Commit e26afe7

Browse files
author
lucasliu
committed
fix(scheduler): eliminate 4 concurrency races in FusedBatchScheduler
- Route all shared-state continuation.yield/finish through safeYield/safeFinish (FinishGuard prevents double-finish that caused SIGSEGV in _swift_release_dealloc)
- Guard budget release with _finishedByDecodeStep flag — preempt/abort paths release budget themselves, decode step skips if already released
- Replace TOCTOU canAdmit()+separate-increment with atomic canAdmitAndReserve() that checks concurrency limit AND reserves slot in one lock acquisition
- Fix fire-and-forget Task in admitQueued with defer+resumed guard to guarantee CheckedContinuation is resumed exactly once even on cancellation
1 parent 97d9bb6 commit e26afe7

2 files changed

Lines changed: 83 additions & 52 deletions

File tree

Sources/NovaMLXCore/Types.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import Logging
33

44
public enum NovaMLX {}
55

6-
public let version = "1.0.6"
6+
public let version = "1.0.7"
77

88
public var buildTimestamp: String {
99
guard let execURL = Bundle.main.executableURL,

Sources/NovaMLXEngine/FusedBatchScheduler.swift

Lines changed: 82 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ struct ActiveStreamSequence: @unchecked Sendable {
9191
/// finish — `isFinished` is advisory for fast reads, `finishGuard` is the
9292
/// truth for continuation lifecycle and budget release ownership.
9393
let finishGuard: FinishGuard
94+
/// True only when *this* decode step was the one that successfully
95+
/// called `safeFinish()` — prevents double budget release when
96+
/// preempt/abort already released the sequence.
97+
var _finishedByDecodeStep: Bool = false
9498
/// Recent token IDs for N-gram speculation context.
9599
var recentTokenIds: [Int]
96100
/// Frequency penalty value — prevents repetition collapse in small quantized models.
@@ -265,22 +269,13 @@ public final class FusedBatchScheduler: @unchecked Sendable {
265269

266270
public func submit(_ request: InferenceRequest) async throws -> InferenceResult {
267271
let modelId = request.model
268-
let bytesPerToken = engine.effectiveBytesPerToken(modelId: modelId)
269-
let estimatedTokens = engine.estimateRequestTokens(modelId: modelId, request: request)
270272

271-
// Quick admission check — if memory is tight, fall back to engine (no fused optimization)
272-
guard await canAdmit(request) else {
273+
// Atomic admission check + slot reserve — if memory/concurrency tight, fall back to engine
274+
guard await canAdmitAndReserve(request) else {
273275
NovaMLXLog.info("FusedScheduler: submit() can't admit — falling back to engine")
274276
return try await engine.generate(request)
275277
}
276278

277-
// Reserve budget
278-
lock.withLock { activeModelCounts[modelId] = (activeModelCounts[modelId] ?? 0) + 1 }
279-
await budgetTracker.reserve(
280-
modelId: modelId, sequenceId: request.id,
281-
weightsBytes: 0, estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
282-
)
283-
284279
let startTime = Date()
285280
let promptTokensBox = MutableSendableBox<Int>(0)
286281

@@ -304,6 +299,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
304299
await self.budgetTracker.release(sequenceId: request.id)
305300
}
306301
} catch {
302+
// SAFE: pre-prefill error path, continuation is exclusively owned
307303
continuation.finish(throwing: error)
308304
self.lock.withLock {
309305
self.activeModelCounts[modelId] = max(0, (self.activeModelCounts[modelId] ?? 1) - 1)
@@ -349,8 +345,8 @@ public final class FusedBatchScheduler: @unchecked Sendable {
349345
return AsyncThrowingStream { continuation in
350346
Task {
351347
do {
352-
// Check admission
353-
let canStart = await self.canAdmit(request)
348+
// Atomic admission check + slot reserve
349+
let canStart = await self.canAdmitAndReserve(request)
354350
if !canStart {
355351
// Queue for later admission
356352
NovaMLXLog.info("FusedScheduler: queuing stream \(reqTag) — memory/concurrency limit")
@@ -364,15 +360,6 @@ public final class FusedBatchScheduler: @unchecked Sendable {
364360
return
365361
}
366362

367-
// Reserve budget
368-
let bytesPerToken = self.engine.effectiveBytesPerToken(modelId: modelId)
369-
let estimatedTokens = self.engine.estimateRequestTokens(modelId: modelId, request: request)
370-
self.lock.withLock { self.activeModelCounts[modelId] = (self.activeModelCounts[modelId] ?? 0) + 1 }
371-
await self.budgetTracker.reserve(
372-
modelId: modelId, sequenceId: request.id,
373-
weightsBytes: 0, estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
374-
)
375-
376363
// Prefill and add to active sequences
377364
let seq = try await self.prefillSequence(request: request, continuation: continuation)
378365
if !seq.isFinished {
@@ -390,6 +377,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
390377
await self.budgetTracker.release(sequenceId: request.id)
391378
}
392379
} catch {
380+
// SAFE: pre-prefill error path, continuation is exclusively owned
393381
continuation.finish(throwing: error)
394382
self.lock.withLock {
395383
self.activeModelCounts[modelId] = max(0, (self.activeModelCounts[modelId] ?? 1) - 1)
@@ -415,6 +403,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
415403
}
416404
self.scheduleRunLoop()
417405
} catch {
406+
// SAFE: pre-prefill error path, continuation is exclusively owned
418407
continuation.finish(throwing: error)
419408
self.lock.withLock {
420409
self.activeModelCounts[modelId] = max(0, (self.activeModelCounts[modelId] ?? 1) - 1)
@@ -450,6 +439,47 @@ public final class FusedBatchScheduler: @unchecked Sendable {
450439
return canAdmitMemory
451440
}
452441

442+
/// Atomically check admission AND reserve the concurrency slot.
443+
/// Returns `false` if the concurrency limit is reached or memory budget insufficient.
444+
/// On `true`, the caller MUST proceed — the slot is already counted.
445+
private func canAdmitAndReserve(_ request: InferenceRequest) async -> Bool {
446+
let modelId = request.model
447+
let concurrentLimit = await optimalConcurrency(for: modelId)
448+
449+
// Atomically check limit AND reserve slot
450+
let reserved = lock.withLock { () -> Bool in
451+
let current = activeModelCounts[modelId] ?? 0
452+
guard current < concurrentLimit else { return false }
453+
activeModelCounts[modelId] = current + 1
454+
return true
455+
}
456+
guard reserved else {
457+
NovaMLXLog.info("FusedScheduler: queuing \(request.id.uuidString.prefix(8)) — model at concurrency limit")
458+
return false
459+
}
460+
461+
// Check memory budget — roll back slot if insufficient
462+
let bytesPerToken = engine.effectiveBytesPerToken(modelId: modelId)
463+
let estimatedTokens = engine.estimateRequestTokens(modelId: modelId, request: request)
464+
let canAdmitMemory = await budgetTracker.canAdmit(
465+
modelId: modelId, estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
466+
)
467+
468+
guard canAdmitMemory else {
469+
// Roll back the slot we just reserved
470+
lock.withLock { activeModelCounts[modelId] = max(0, (activeModelCounts[modelId] ?? 1) - 1) }
471+
NovaMLXLog.info("FusedScheduler: queuing \(request.id.uuidString.prefix(8)) — insufficient memory")
472+
return false
473+
}
474+
475+
// Reserve memory budget (slot already counted above)
476+
await budgetTracker.reserve(
477+
modelId: modelId, sequenceId: request.id,
478+
weightsBytes: 0, estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
479+
)
480+
return true
481+
}
482+
453483
// MARK: - Preemption
454484

455485
/// Preempt the newest active sequence for the given model to free memory.
@@ -484,7 +514,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
484514
}
485515

486516
// Finish the victim's stream with a retryable error
487-
victim.continuation.finish(throwing: NovaMLXError.inferenceFailed(
517+
victim.safeFinish(throwing: NovaMLXError.inferenceFailed(
488518
"Sequence preempted due to memory pressure — please retry"
489519
))
490520

@@ -690,6 +720,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
690720
let isEOS = eosId == firstTokenId
691721

692722
// Yield first token to client (unless it's EOS and we finish immediately)
723+
// SAFE: pre-shared-state yields — seq hasn't entered activeByModel yet
693724
let scrubbedFirstToken = MLXEngine.scrubControlTokens(firstTokenText)
694725
if !isEOS && !scrubbedFirstToken.isEmpty {
695726
continuation.yield(Token(id: 0, text: scrubbedFirstToken))
@@ -796,6 +827,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
796827
}
797828
NovaMLXLog.info("[Prefill:\(reqTag)] First token (remaining): id=\(firstTokenId), text='\(firstTokenText)'")
798829

830+
// SAFE: pre-shared-state yield — seq hasn't entered activeByModel yet
799831
let scrubbedPrefixToken = MLXEngine.scrubControlTokens(firstTokenText)
800832
if !scrubbedPrefixToken.isEmpty {
801833
continuation.yield(Token(id: 0, text: scrubbedPrefixToken))
@@ -903,7 +935,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
903935
)
904936

905937
for item in sortedGenerate {
906-
let canStart = await canAdmit(item.request)
938+
let canStart = await canAdmitAndReserve(item.request)
907939
if canStart {
908940
admittedGenerate.append(item)
909941
} else {
@@ -912,7 +944,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
912944
let preempted = await preemptNewest(for: item.request.model, toMakeRoomFor: item.request.id)
913945
if preempted {
914946
// Re-check admission after preemption freed some memory
915-
let retryCanAdmit = await canAdmit(item.request)
947+
let retryCanAdmit = await canAdmitAndReserve(item.request)
916948
if retryCanAdmit {
917949
admittedGenerate.append(item)
918950
continue
@@ -926,18 +958,16 @@ public final class FusedBatchScheduler: @unchecked Sendable {
926958

927959
for item in admittedGenerate {
928960
let modelId = item.request.model
929-
let bytesPerToken = engine.effectiveBytesPerToken(modelId: modelId)
930-
let estimatedTokens = engine.estimateRequestTokens(modelId: modelId, request: item.request)
931-
932-
lock.withLock { activeModelCounts[modelId] = (activeModelCounts[modelId] ?? 0) + 1 }
933-
await budgetTracker.reserve(
934-
modelId: modelId, sequenceId: item.request.id,
935-
weightsBytes: 0, estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
936-
)
937-
961+
// Budget already reserved atomically by canAdmitAndReserve()
938962
// Wrap generate as stream, collect result
939963
let stream = internalSubmitStream(item.request)
940964
Task {
965+
var resumed = false
966+
defer {
967+
if !resumed {
968+
item.continuation.resume(throwing: CancellationError())
969+
}
970+
}
941971
var text = ""
942972
var completionTokens = 0
943973
var finishReason: FinishReason = .stop
@@ -949,12 +979,14 @@ public final class FusedBatchScheduler: @unchecked Sendable {
949979
let promptTokens = lock.withLock {
950980
activeByModel[modelId]?.first?.promptTokenCount ?? 0
951981
}
982+
resumed = true
952983
item.continuation.resume(returning: InferenceResult(
953984
id: item.request.id, model: modelId, text: text,
954985
tokensPerSecond: 0, promptTokens: promptTokens,
955986
completionTokens: completionTokens, finishReason: finishReason
956987
))
957988
} catch {
989+
resumed = true
958990
item.continuation.resume(throwing: error)
959991
}
960992
}
@@ -972,15 +1004,15 @@ public final class FusedBatchScheduler: @unchecked Sendable {
9721004
)
9731005

9741006
for item in sortedStream {
975-
let canStart = await canAdmit(item.request)
1007+
let canStart = await canAdmitAndReserve(item.request)
9761008
if canStart {
9771009
admittedStream.append(item)
9781010
} else {
9791011
// Try preemption for high-priority items
9801012
if item.priority == .high {
9811013
let preempted = await preemptNewest(for: item.request.model, toMakeRoomFor: item.request.id)
9821014
if preempted {
983-
let retryCanAdmit = await canAdmit(item.request)
1015+
let retryCanAdmit = await canAdmitAndReserve(item.request)
9841016
if retryCanAdmit {
9851017
admittedStream.append(item)
9861018
continue
@@ -994,14 +1026,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
9941026

9951027
for item in admittedStream {
9961028
let modelId = item.request.model
997-
let bytesPerToken = engine.effectiveBytesPerToken(modelId: modelId)
998-
let estimatedTokens = engine.estimateRequestTokens(modelId: modelId, request: item.request)
999-
1000-
lock.withLock { activeModelCounts[modelId] = (activeModelCounts[modelId] ?? 0) + 1 }
1001-
await budgetTracker.reserve(
1002-
modelId: modelId, sequenceId: item.request.id,
1003-
weightsBytes: 0, estimatedTokens: estimatedTokens, bytesPerToken: bytesPerToken
1004-
)
1029+
// Budget already reserved atomically by canAdmitAndReserve()
10051030

10061031
do {
10071032
let seq = try await prefillSequence(request: item.request, continuation: item.continuation)
@@ -1010,6 +1035,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
10101035
activeByModel[modelId]?.append(seq)
10111036
}
10121037
} catch {
1038+
// SAFE: pre-prefill error path, continuation is exclusively owned
10131039
item.continuation.finish(throwing: error)
10141040
lock.withLock {
10151041
activeModelCounts[modelId] = max(0, (activeModelCounts[modelId] ?? 1) - 1)
@@ -1036,8 +1062,9 @@ public final class FusedBatchScheduler: @unchecked Sendable {
10361062
guard let container = engine.getContainer(for: modelId),
10371063
let mlxContainer = container.mlxContainer else {
10381064
for seq in active {
1039-
seq.continuation.finish(throwing: NovaMLXError.modelNotFound(modelId))
1040-
await budgetTracker.release(sequenceId: seq.id)
1065+
if seq.safeFinish(throwing: NovaMLXError.modelNotFound(modelId)) {
1066+
await budgetTracker.release(sequenceId: seq.id)
1067+
}
10411068
}
10421069
lock.withLock { activeByModel[modelId] = []; activeModelCounts[modelId] = 0 }
10431070
continue
@@ -1285,7 +1312,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
12851312
}
12861313
seq.generatedText = seq.lastDecodedText
12871314
if !decoded.isEmpty {
1288-
seq.continuation.yield(Token(id: 0, text: decoded))
1315+
seq.safeYield(Token(id: 0, text: decoded))
12891316
}
12901317

12911318
lock.withLock { totalTokensViaFused += 1 }
@@ -1337,13 +1364,14 @@ public final class FusedBatchScheduler: @unchecked Sendable {
13371364
// its fallback promotes the entire buffer to
13381365
// `content`, leaving `reasoning_content` empty.
13391366
if seq.harmonyInThinking {
1340-
seq.continuation.yield(Token(id: 0, text: "</think>"))
1367+
seq.safeYield(Token(id: 0, text: "</think>"))
13411368
seq.harmonyInThinking = false
13421369
}
1343-
seq.continuation.yield(Token(id: 0, text: "", finishReason: finishReason))
1344-
seq.continuation.finish()
1370+
seq.safeYield(Token(id: 0, text: "", finishReason: finishReason))
1371+
let finishedByUs = seq.safeFinish()
13451372
seq.isFinished = true
13461373
sequenceFinished = true
1374+
seq._finishedByDecodeStep = finishedByUs
13471375
break
13481376
}
13491377
}
@@ -1380,6 +1408,9 @@ public final class FusedBatchScheduler: @unchecked Sendable {
13801408
}
13811409

13821410
for seq in finished {
1411+
// Only release budget if WE finished the sequence in this decode step.
1412+
// Preempt/abort paths release budget themselves; skip to avoid double-release.
1413+
guard seq._finishedByDecodeStep else { continue }
13831414
await budgetTracker.release(sequenceId: seq.id)
13841415
lock.withLock {
13851416
activeModelCounts[modelId] = max(0, (activeModelCounts[modelId] ?? 1) - 1)
@@ -1452,7 +1483,7 @@ public final class FusedBatchScheduler: @unchecked Sendable {
14521483
for (_, var sequences) in activeByModel {
14531484
if let idx = sequences.firstIndex(where: { $0.id == requestId }) {
14541485
sequences[idx].isFinished = true
1455-
sequences[idx].continuation.finish()
1486+
sequences[idx].safeFinish()
14561487
sequences.remove(at: idx)
14571488
}
14581489
}

0 commit comments

Comments
 (0)