-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathllm_streaming.py
More file actions
232 lines (182 loc) · 7.69 KB
/
llm_streaming.py
File metadata and controls
232 lines (182 loc) · 7.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
"""Helpers for consuming OpenAI-compatible chat completion streams."""
from __future__ import annotations
import time
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class AggregatedFunctionCall(BaseModel):
    """Function part of an aggregated tool call: a name plus argument text."""

    # Function name; None when no name has been supplied.
    name: Optional[str] = Field(default=None)
    # Argument text; an empty string by default.
    arguments: str = Field(default="")
class AggregatedToolCall(BaseModel):
    """One aggregated tool call: identifier, call type and function payload."""

    # Required identifier; there is no default.
    id: str
    # Call type; "function" unless another value is given.
    type: str = Field(default="function")
    # Required function name/arguments payload for this call.
    function: AggregatedFunctionCall
class AggregatedChatMessage(BaseModel):
    """Aggregated chat message in the shape of an OpenAI chat-completions message."""

    # Message role; "assistant" by default.
    role: str = Field(default="assistant")
    # Message text; an empty string by default.
    content: str = Field(default="")
    # Optional reasoning text; None when absent.
    reasoning_content: Optional[str] = Field(default=None)
    # Optional list of tool calls; None when absent.
    tool_calls: Optional[List[AggregatedToolCall]] = Field(default=None)
class AggregatedChoice(BaseModel):
    """A single choice entry of an aggregated completion."""

    # Position of the choice; 0 by default.
    index: int = Field(default=0)
    # Required message carried by this choice.
    message: AggregatedChatMessage
    # Optional finish reason; None when not provided.
    finish_reason: Optional[str] = Field(default=None)
class AggregatedChatCompletion(BaseModel):
    """Top-level aggregated response object."""

    # Response identifier; placeholder value unless one is supplied.
    id: str = Field(default="chatcmpl-stream")
    # Object tag; "chat.completion" by default.
    object: str = Field(default="chat.completion")
    # Integer timestamp; defaults to the current time truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Model name; "unknown" by default.
    model: str = Field(default="unknown")
    # Required list of choices.
    choices: List[AggregatedChoice]
    # Optional usage mapping; None when absent.
    usage: Optional[Dict[str, Any]] = Field(default=None)
def _to_plain_dict(obj: Any) -> Any:
    """Convert SDK-style objects into plain Python containers where possible.

    Objects exposing ``model_dump`` are dumped through that method. Dicts and
    lists are rebuilt with each value/item converted recursively. ``None`` and
    every other value are returned unchanged.
    """
    if obj is None:
        return None

    if hasattr(obj, "model_dump"):
        try:
            dumped = obj.model_dump(exclude_none=False)
        except TypeError:
            # The keyword call raised TypeError; retry with no arguments.
            dumped = obj.model_dump()
        return dumped

    if isinstance(obj, dict):
        plain: Dict[Any, Any] = {}
        for key, value in obj.items():
            plain[key] = _to_plain_dict(value)
        return plain

    if isinstance(obj, list):
        return list(map(_to_plain_dict, obj))

    return obj
def _extract_text_content(content: Any) -> str:
    """Return the text carried by ``content``.

    ``None`` yields an empty string and a string is returned as-is. A list is
    scanned item by item: plain strings, ``{"type": "text", "text": "..."}``
    entries and entries whose ``"text"`` is a mapping with a string ``"value"``
    contribute their text, in order; other items are ignored. Any other value
    is passed through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    pieces: List[str] = []
    for raw_entry in content:
        entry = _to_plain_dict(raw_entry)
        if isinstance(entry, str):
            pieces.append(entry)
        elif isinstance(entry, dict):
            text_field = entry.get("text")
            if entry.get("type") == "text" and isinstance(text_field, str):
                pieces.append(text_field)
            elif isinstance(text_field, dict) and isinstance(text_field.get("value"), str):
                pieces.append(text_field["value"])
    return "".join(pieces)
async def _iterate_stream(stream: Any):
    """Yield every chunk of ``stream``, whether it is an async or a plain iterable.

    A stream exposing ``__aiter__`` is consumed with ``async for``; anything
    else is consumed with an ordinary ``for`` loop.
    """
    if not hasattr(stream, "__aiter__"):
        # Plain iterable: items are pulled synchronously inside this coroutine.
        for item in stream:
            yield item
        return

    async for item in stream:
        yield item
def is_streaming_unsupported_error(exc: Exception) -> bool:
    """Tell whether ``exc`` reads like an explicit rejection of streaming.

    The exception's type name and message are joined as ``"<Type>: <message>"``,
    lower-cased, and searched for a fixed set of phrases. Returns ``True`` on
    the first phrase found, ``False`` when none is present.
    """
    markers = (
        "streaming not supported",
        "stream not supported",
        "does not support streaming",
        "stream must be false",
        "stream=true is not supported",
        "streaming unsupported",
        "invalid value for stream",
    )
    haystack = "{}: {}".format(type(exc).__name__, exc).lower()
    for marker in markers:
        if marker in haystack:
            return True
    return False
def _merge_tool_call_delta(
    tool_calls_by_index: Dict[int, Dict[str, Any]],
    tool_call: Dict[str, Any],
) -> None:
    """Fold one streamed tool-call fragment into the accumulator keyed by index."""
    raw_index = tool_call.get("index")
    # The key can be present with a None value (chunks are dumped with
    # exclude_none=False); treat that like a missing key instead of letting
    # int(None) raise TypeError.
    index = int(raw_index) if raw_index is not None else 0

    existing = tool_calls_by_index.get(index)
    if existing is None:
        # Only build the placeholder entry the first time an index is seen.
        existing = {
            "id": tool_call.get("id") or f"tool_call_{index}",
            "type": tool_call.get("type") or "function",
            "function": {"name": "", "arguments": ""},
        }
        tool_calls_by_index[index] = existing

    if isinstance(tool_call.get("id"), str) and tool_call["id"]:
        existing["id"] = tool_call["id"]
    if isinstance(tool_call.get("type"), str) and tool_call["type"]:
        existing["type"] = tool_call["type"]

    function_payload = tool_call.get("function") or {}
    if not isinstance(function_payload, dict):
        function_payload = _to_plain_dict(function_payload)
        if not isinstance(function_payload, dict):
            # Still not a mapping after conversion: nothing usable to merge.
            function_payload = {}

    function_name = function_payload.get("name")
    if isinstance(function_name, str) and function_name:
        # A later non-empty name replaces the earlier one.
        existing["function"]["name"] = function_name

    function_arguments = function_payload.get("arguments")
    if isinstance(function_arguments, str) and function_arguments:
        # Argument text arrives in pieces and is concatenated in arrival order.
        existing["function"]["arguments"] += function_arguments


async def collect_openai_chat_stream(stream: Any, *, model_name: str) -> AggregatedChatCompletion:
    """Consume a chat-completions stream and aggregate it back into one response.

    Args:
        stream: Async or plain iterable of chunks. Each chunk is converted with
            ``_to_plain_dict``; chunks that are not dicts afterwards are skipped.
        model_name: Model name used in the result unless a chunk carries a
            non-empty string ``"model"``, in which case the latest such value wins.

    Returns:
        An ``AggregatedChatCompletion`` with exactly one choice whose message
        holds the concatenated content, the concatenated reasoning content
        (``None`` when empty) and the merged tool calls ordered by index
        (``None`` when there were none).

    Raises:
        RuntimeError: If a chunk carries a truthy ``"error"`` entry, or if the
            stream produced no dict chunks at all.
    """
    chunk_count = 0
    content_parts: List[str] = []
    reasoning_parts: List[str] = []
    tool_calls_by_index: Dict[int, Dict[str, Any]] = {}
    finish_reason: Optional[str] = None
    response_id = "chatcmpl-stream"
    created = int(time.time())
    usage: Optional[Dict[str, Any]] = None
    role = "assistant"

    async for chunk in _iterate_stream(stream):
        chunk_dict = _to_plain_dict(chunk)
        if not isinstance(chunk_dict, dict):
            continue
        chunk_count += 1

        error_payload = chunk_dict.get("error")
        if error_payload:
            raise RuntimeError(f"stream returned error payload: {error_payload}")

        # Top-level metadata: the latest well-typed value seen wins.
        if isinstance(chunk_dict.get("id"), str) and chunk_dict["id"]:
            response_id = chunk_dict["id"]
        if isinstance(chunk_dict.get("model"), str) and chunk_dict["model"]:
            model_name = chunk_dict["model"]
        if isinstance(chunk_dict.get("created"), int):
            created = chunk_dict["created"]
        if isinstance(chunk_dict.get("usage"), dict):
            usage = chunk_dict["usage"]

        # NOTE(review): deltas from every entry in "choices" are merged into a
        # single message regardless of the entry's own index; presumably
        # callers only request one choice -- confirm.
        for choice in chunk_dict.get("choices") or []:
            if not isinstance(choice, dict):
                continue

            delta = choice.get("delta") or {}
            if not isinstance(delta, dict):
                delta = _to_plain_dict(delta)
                if not isinstance(delta, dict):
                    # A delta that is still not a mapping carries nothing we
                    # can read; treat it as empty instead of failing on .get.
                    delta = {}

            if isinstance(delta.get("role"), str) and delta["role"]:
                role = delta["role"]

            text_delta = _extract_text_content(delta.get("content"))
            if text_delta:
                content_parts.append(text_delta)

            reasoning_delta = _extract_text_content(delta.get("reasoning_content"))
            if reasoning_delta:
                reasoning_parts.append(reasoning_delta)

            for raw_tool_call in delta.get("tool_calls") or []:
                tool_call = _to_plain_dict(raw_tool_call)
                if not isinstance(tool_call, dict):
                    continue
                _merge_tool_call_delta(tool_calls_by_index, tool_call)

            if isinstance(choice.get("finish_reason"), str) and choice["finish_reason"]:
                finish_reason = choice["finish_reason"]

    if chunk_count == 0:
        raise RuntimeError("stream returned no chunks")

    tool_calls = None
    if tool_calls_by_index:
        tool_calls = [
            AggregatedToolCall.model_validate(tool_calls_by_index[index])
            for index in sorted(tool_calls_by_index)
        ]

    reasoning_content = "".join(reasoning_parts) or None
    message = AggregatedChatMessage(
        role=role,
        content="".join(content_parts),
        reasoning_content=reasoning_content,
        tool_calls=tool_calls,
    )
    return AggregatedChatCompletion(
        id=response_id,
        created=created,
        model=model_name,
        choices=[AggregatedChoice(message=message, finish_reason=finish_reason)],
        usage=usage,
    )