Skip to content

Commit c333a11

Browse files
author
lucasliu
committed
feat(api): add logprobs and top_logprobs support (OpenAI standard)
Implements OpenAI-compatible logprobs/top_logprobs across the inference pipeline, closing a P0 Ollama-parity gap for downstream eval tooling (LangChain, RAG evaluators, classifier wrappers). Pipeline: - InferenceRequest gains includeLogprobs + topLogprobsCount; Token gains logprob + topLogprobs; InferenceResult aggregates tokenLogprobs. - FusedBatchScheduler.computeRawLogprobs runs inside mlxContainer.perform; raw (Sendable) data crosses the actor boundary and tokenizer decoding happens outside. - APIServer maps to OpenAILogprobs/Entry/TopLogprob with snake_case keys; populated on both non-stream choices and stream chunks. Correctness: - Spec decoding bypassed when logprobs requested so every accepted token carries logprob data (no silent gaps). - Streaming logprobs attached to only the first SSE chunk per decode token, preventing double-counting when ThinkingParser splits a token into reasoning + content chunks. - bytes field populated from UTF-8 per OpenAI spec. Performance: - Top-K via argPartition (O(N) + O(K log K)) instead of full argSort over the entire vocabulary on every decode step. - Sampled-token-only path (top_logprobs=0) skips eval(logSoftmax) and pulls a single scalar. Tests: +15 (4 semantic invariant tests on top of plumbing/round-trip): order preservation, nil aggregation, UTF-8 bytes (ASCII + CJK), descending top_logprobs and non-positive log-probability invariants.
1 parent 5d1e8bc commit c333a11

5 files changed

Lines changed: 513 additions & 31 deletions

File tree

Sources/NovaMLXAPI/APIServer.swift

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1728,7 +1728,9 @@ public final class NovaMLXAPIServer: @unchecked Sendable {
17281728
regexPattern: regexPattern, gbnfGrammar: gbnfGrammar,
17291729
thinkingBudget: openAIReq.resolvedThinkingBudget,
17301730
enableThinking: openAIReq.resolvedEnableThinking,
1731-
preserveThinking: openAIReq.resolvedPreserveThinking
1731+
preserveThinking: openAIReq.resolvedPreserveThinking,
1732+
includeLogprobs: openAIReq.logprobs == true,
1733+
topLogprobsCount: openAIReq.topLogprobs
17321734
)
17331735

17341736
CurrentInferenceModel.shared.modelID = request.model
@@ -1819,7 +1821,8 @@ public final class NovaMLXAPIServer: @unchecked Sendable {
18191821
OpenAIChoice(
18201822
index: 0,
18211823
message: message,
1822-
finishReason: finishReason
1824+
finishReason: finishReason,
1825+
logprobs: result.tokenLogprobs.map { Self.buildLogprobs(from: $0) } ?? nil
18231826
)
18241827
],
18251828
usage: {
@@ -1865,7 +1868,9 @@ public final class NovaMLXAPIServer: @unchecked Sendable {
18651868
regexPattern: regexPattern, gbnfGrammar: gbnfGrammar,
18661869
thinkingBudget: openAIReq.resolvedThinkingBudget,
18671870
enableThinking: openAIReq.resolvedEnableThinking,
1868-
preserveThinking: openAIReq.resolvedPreserveThinking
1871+
preserveThinking: openAIReq.resolvedPreserveThinking,
1872+
includeLogprobs: openAIReq.logprobs == true,
1873+
topLogprobsCount: openAIReq.topLogprobs
18691874
)
18701875

18711876
let keepAliveStream = Self.withSSEKeepAlive(inference.stream(request))
@@ -1899,6 +1904,13 @@ public final class NovaMLXAPIServer: @unchecked Sendable {
18991904
for try await event in keepAliveStream {
19001905
switch event {
19011906
case .token(let token):
1907+
// Compute logprob data once per token. A single decode token
1908+
// can split into multiple SSE chunks (e.g., ThinkingParser
1909+
// emits thinking + content separately). Attach logprobs to
1910+
// only the FIRST emitted chunk to avoid double-counting.
1911+
var tokenLogprobs: OpenAILogprobs? = Self.tokenToLogprobEntry(token).map {
1912+
OpenAILogprobs(content: [$0])
1913+
}
19021914
if let tc = token.toolCall {
19031915
let idx = toolCallCounter.increment()
19041916
let tcDelta = OpenAIToolCallDelta(
@@ -1993,8 +2005,9 @@ public final class NovaMLXAPIServer: @unchecked Sendable {
19932005
let chunk = OpenAIStreamChunk(
19942006
id: chunkId,
19952007
model: openAIReq.model,
1996-
choices: [OpenAIStreamChoice(index: 0, delta: delta)]
2008+
choices: [OpenAIStreamChoice(index: 0, delta: delta, logprobs: tokenLogprobs)]
19972009
)
2010+
tokenLogprobs = nil // consume — only first chunk carries logprobs
19982011
let data = try JSONEncoder().encode(chunk)
19992012
try await writer.write(ByteBuffer(string: "data: \(String(data: data, encoding: .utf8) ?? "")\n\n"))
20002013
}
@@ -2005,8 +2018,9 @@ public final class NovaMLXAPIServer: @unchecked Sendable {
20052018
let chunk = OpenAIStreamChunk(
20062019
id: chunkId,
20072020
model: openAIReq.model,
2008-
choices: [OpenAIStreamChoice(index: 0, delta: delta)]
2021+
choices: [OpenAIStreamChoice(index: 0, delta: delta, logprobs: tokenLogprobs)]
20092022
)
2023+
tokenLogprobs = nil // consume — only first chunk carries logprobs
20102024
let data = try JSONEncoder().encode(chunk)
20112025
try await writer.write(ByteBuffer(string: "data: \(String(data: data, encoding: .utf8) ?? "")\n\n"))
20122026
}
@@ -2586,6 +2600,32 @@ public final class NovaMLXAPIServer: @unchecked Sendable {
25862600
try jsonResponse(value, httpStatus: .ok)
25872601
}
25882602

2603+
/// Convert a stream Token's logprob data to OpenAI response format.
2604+
/// Populates `bytes` with UTF-8 byte values per OpenAI spec.
2605+
static func tokenToLogprobEntry(_ token: Token) -> OpenAILogprobEntry? {
2606+
guard let logprob = token.logprob else { return nil }
2607+
let topEntries: [OpenAITopLogprob] = (token.topLogprobs ?? []).map { tp in
2608+
OpenAITopLogprob(
2609+
token: tp.tokenText,
2610+
logprob: tp.logprob,
2611+
bytes: tp.tokenText.utf8.map(Int.init)
2612+
)
2613+
}
2614+
return OpenAILogprobEntry(
2615+
token: token.text,
2616+
logprob: logprob,
2617+
bytes: token.text.utf8.map(Int.init),
2618+
topLogprobs: topEntries
2619+
)
2620+
}
2621+
2622+
/// Build `OpenAILogprobs` from a collection of tokens with logprob data.
2623+
static func buildLogprobs(from tokens: [Token]) -> OpenAILogprobs? {
2624+
let entries = tokens.compactMap { tokenToLogprobEntry($0) }
2625+
guard !entries.isEmpty else { return nil }
2626+
return OpenAILogprobs(content: entries)
2627+
}
2628+
25892629
private static func jsonResponse<T: Encodable>(_ value: T, httpStatus: HTTPResponse.Status) throws -> Response {
25902630
let data = try JSONEncoder().encode(value)
25912631
return Response(

Sources/NovaMLXAPI/OpenAITypes.swift

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,11 @@ public struct OpenAIRequest: Codable, Sendable {
180180
public let preserveThinking: Bool?
181181
public let chatTemplateKwargs: [String: AnyCodable]?
182182
public let reasoningEffort: String?
183+
public let logprobs: Bool?
184+
public let topLogprobs: Int?
183185

184186
private enum CodingKeys: String, CodingKey {
185-
case model, messages, temperature, stream, stop, n, seed, tools
187+
case model, messages, temperature, stream, stop, n, seed, tools, logprobs
186188
case toolChoice = "tool_choice"
187189
case topP = "top_p"
188190
case topK = "top_k"
@@ -199,6 +201,7 @@ public struct OpenAIRequest: Codable, Sendable {
199201
case preserveThinking = "preserve_thinking"
200202
case chatTemplateKwargs = "chat_template_kwargs"
201203
case reasoningEffort = "reasoning_effort"
204+
case topLogprobs = "top_logprobs"
202205
}
203206

204207
public init(
@@ -225,7 +228,9 @@ public struct OpenAIRequest: Codable, Sendable {
225228
enableThinking: Bool? = nil,
226229
preserveThinking: Bool? = nil,
227230
chatTemplateKwargs: [String: AnyCodable]? = nil,
228-
reasoningEffort: String? = nil
231+
reasoningEffort: String? = nil,
232+
logprobs: Bool? = nil,
233+
topLogprobs: Int? = nil
229234
) {
230235
self.model = model
231236
self.messages = messages
@@ -251,6 +256,8 @@ public struct OpenAIRequest: Codable, Sendable {
251256
self.preserveThinking = preserveThinking
252257
self.chatTemplateKwargs = chatTemplateKwargs
253258
self.reasoningEffort = reasoningEffort
259+
self.logprobs = logprobs
260+
self.topLogprobs = topLogprobs
254261
}
255262

256263
/// Resolve thinking toggle from multiple client formats:
@@ -475,16 +482,18 @@ public struct OpenAIChoice: Codable, Sendable {
475482
public let index: Int
476483
public let message: OpenAIChatMessage
477484
public let finishReason: String?
485+
public let logprobs: OpenAILogprobs?
478486

479487
private enum CodingKeys: String, CodingKey {
480-
case index, message
488+
case index, message, logprobs
481489
case finishReason = "finish_reason"
482490
}
483491

484-
public init(index: Int, message: OpenAIChatMessage, finishReason: String? = nil) {
492+
public init(index: Int, message: OpenAIChatMessage, finishReason: String? = nil, logprobs: OpenAILogprobs? = nil) {
485493
self.index = index
486494
self.message = message
487495
self.finishReason = finishReason
496+
self.logprobs = logprobs
488497
}
489498
}
490499

@@ -528,16 +537,18 @@ public struct OpenAIStreamChoice: Codable, Sendable {
528537
public let index: Int
529538
public let delta: OpenAIDelta
530539
public let finishReason: String?
540+
public let logprobs: OpenAILogprobs?
531541

532542
private enum CodingKeys: String, CodingKey {
533-
case index, delta
543+
case index, delta, logprobs
534544
case finishReason = "finish_reason"
535545
}
536546

537-
public init(index: Int, delta: OpenAIDelta, finishReason: String? = nil) {
547+
public init(index: Int, delta: OpenAIDelta, finishReason: String? = nil, logprobs: OpenAILogprobs? = nil) {
538548
self.index = index
539549
self.delta = delta
540550
self.finishReason = finishReason
551+
self.logprobs = logprobs
541552
}
542553
}
543554

@@ -561,6 +572,45 @@ public struct OpenAIDelta: Codable, Sendable {
561572
}
562573
}
563574

575+
public struct OpenAILogprobs: Codable, Sendable {
576+
public let content: [OpenAILogprobEntry]?
577+
578+
public init(content: [OpenAILogprobEntry]? = nil) {
579+
self.content = content
580+
}
581+
}
582+
583+
public struct OpenAILogprobEntry: Codable, Sendable {
584+
public let token: String
585+
public let logprob: Float
586+
public let bytes: [Int]?
587+
public let topLogprobs: [OpenAITopLogprob]
588+
589+
private enum CodingKeys: String, CodingKey {
590+
case token, logprob, bytes
591+
case topLogprobs = "top_logprobs"
592+
}
593+
594+
public init(token: String, logprob: Float, bytes: [Int]? = nil, topLogprobs: [OpenAITopLogprob] = []) {
595+
self.token = token
596+
self.logprob = logprob
597+
self.bytes = bytes
598+
self.topLogprobs = topLogprobs
599+
}
600+
}
601+
602+
public struct OpenAITopLogprob: Codable, Sendable {
603+
public let token: String
604+
public let logprob: Float
605+
public let bytes: [Int]?
606+
607+
public init(token: String, logprob: Float, bytes: [Int]? = nil) {
608+
self.token = token
609+
self.logprob = logprob
610+
self.bytes = bytes
611+
}
612+
}
613+
564614
public struct OpenAIToolCallDelta: Codable, Sendable {
565615
public let index: Int
566616
public let id: String?

Sources/NovaMLXCore/Types.swift

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ public struct InferenceRequest: @unchecked Sendable {
217217
public let draftModel: String?
218218
/// Number of tokens the draft model proposes per speculation round (default: 4).
219219
public let numDraftTokens: Int?
220+
/// When true, compute log probabilities for sampled tokens and top-K alternatives.
221+
public let includeLogprobs: Bool
222+
/// Number of top logprobs to return per token (only used when includeLogprobs is true).
223+
public let topLogprobsCount: Int?
220224

221225
public init(
222226
id: UUID = UUID(),
@@ -243,7 +247,9 @@ public struct InferenceRequest: @unchecked Sendable {
243247
enableThinking: Bool? = nil,
244248
preserveThinking: Bool? = nil,
245249
draftModel: String? = nil,
246-
numDraftTokens: Int? = nil
250+
numDraftTokens: Int? = nil,
251+
includeLogprobs: Bool = false,
252+
topLogprobsCount: Int? = nil
247253
) {
248254
self.id = id
249255
self.model = model
@@ -270,6 +276,8 @@ public struct InferenceRequest: @unchecked Sendable {
270276
self.preserveThinking = preserveThinking
271277
self.draftModel = draftModel
272278
self.numDraftTokens = numDraftTokens
279+
self.includeLogprobs = includeLogprobs
280+
self.topLogprobsCount = topLogprobsCount
273281
}
274282
}
275283

@@ -307,6 +315,7 @@ public struct InferenceResult: Codable, Sendable {
307315
public let promptTokens: Int
308316
public let completionTokens: Int
309317
public let finishReason: FinishReason
318+
public let tokenLogprobs: [Token]?
310319

311320
public init(
312321
id: UUID,
@@ -316,7 +325,8 @@ public struct InferenceResult: Codable, Sendable {
316325
tokensPerSecond: Double,
317326
promptTokens: Int,
318327
completionTokens: Int,
319-
finishReason: FinishReason
328+
finishReason: FinishReason,
329+
tokenLogprobs: [Token]? = nil
320330
) {
321331
self.id = id
322332
self.model = model
@@ -326,6 +336,7 @@ public struct InferenceResult: Codable, Sendable {
326336
self.promptTokens = promptTokens
327337
self.completionTokens = completionTokens
328338
self.finishReason = finishReason
339+
self.tokenLogprobs = tokenLogprobs
329340
}
330341
}
331342

@@ -335,17 +346,31 @@ public enum FinishReason: String, Codable, Sendable {
335346
case toolCalls = "tool_calls"
336347
}
337348

349+
public struct TopLogprob: Codable, Sendable {
350+
public let tokenId: Int
351+
public let tokenText: String
352+
public let logprob: Float
353+
354+
public init(tokenId: Int, tokenText: String, logprob: Float) {
355+
self.tokenId = tokenId
356+
self.tokenText = tokenText
357+
self.logprob = logprob
358+
}
359+
}
360+
338361
public struct Token: Codable, Sendable {
339362
public let id: Int
340363
public let text: String
341364
public let logprob: Float?
365+
public let topLogprobs: [TopLogprob]?
342366
public let finishReason: FinishReason?
343367
public let toolCall: ToolCallResult?
344368

345-
public init(id: Int, text: String, logprob: Float? = nil, finishReason: FinishReason? = nil, toolCall: ToolCallResult? = nil) {
369+
public init(id: Int, text: String, logprob: Float? = nil, topLogprobs: [TopLogprob]? = nil, finishReason: FinishReason? = nil, toolCall: ToolCallResult? = nil) {
346370
self.id = id
347371
self.text = text
348372
self.logprob = logprob
373+
self.topLogprobs = topLogprobs
349374
self.finishReason = finishReason
350375
self.toolCall = toolCall
351376
}

0 commit comments

Comments
 (0)