forked from open-telemetry/opentelemetry-python-contrib
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompletion_hook.py
More file actions
404 lines (354 loc) · 14.2 KB
/
completion_hook.py
File metadata and controls
404 lines (354 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import hashlib
import json
import logging
import posixpath
import threading
from collections import OrderedDict
from concurrent.futures import ( # pylint: disable=no-name-in-module; TODO #4199
Future,
ThreadPoolExecutor,
)
from contextlib import ExitStack
from dataclasses import asdict, dataclass
from functools import partial
from time import time
from typing import Any, Callable, Final, Literal
from uuid import uuid4
import fsspec
from opentelemetry._logs import LogRecord
from opentelemetry.semconv._incubating.attributes import gen_ai_attributes
from opentelemetry.trace import Span
from opentelemetry.util.genai import types
from opentelemetry.util.genai.completion_hook import CompletionHook
from opentelemetry.util.genai.utils import gen_ai_json_dump
# "_ref" attributes point at the uploaded file instead of carrying the
# (potentially large) message payload inline on the span/log record.
GEN_AI_INPUT_MESSAGES_REF: Final = (
    gen_ai_attributes.GEN_AI_INPUT_MESSAGES + "_ref"
)
GEN_AI_OUTPUT_MESSAGES_REF: Final = (
    gen_ai_attributes.GEN_AI_OUTPUT_MESSAGES + "_ref"
)
GEN_AI_SYSTEM_INSTRUCTIONS_REF: Final = (
    gen_ai_attributes.GEN_AI_SYSTEM_INSTRUCTIONS + "_ref"
)
# Fall back to the literal attribute name when the installed semconv
# version does not yet define GEN_AI_TOOL_DEFINITIONS.
GEN_AI_TOOL_DEFINITIONS = getattr(
    gen_ai_attributes, "GEN_AI_TOOL_DEFINITIONS", "gen_ai.tool.definitions"
)
GEN_AI_TOOL_DEFINITIONS_REF: Final = GEN_AI_TOOL_DEFINITIONS + "_ref"

# Key stamped on each jsonl line so streaming readers can order messages.
_MESSAGE_INDEX_KEY = "index"
_DEFAULT_MAX_QUEUE_SIZE = 20
_DEFAULT_FORMAT = "json"

# Supported on-disk formats for uploaded completion files.
Format = Literal["json", "jsonl"]
_FORMATS: tuple[Format, ...] = ("json", "jsonl")

_logger = logging.getLogger(__name__)
@dataclass
class Completion:
    """Payloads captured from a single GenAI completion call.

    Each field is ``None`` when the corresponding payload was empty or
    absent, so empty lists are never uploaded.
    """

    inputs: list[types.InputMessage] | None
    outputs: list[types.OutputMessage] | None
    system_instruction: list[types.MessagePart] | None
    tool_definitions: list[types.ToolDefinition] | None
@dataclass
class CompletionRefs:
    """Fully qualified upload paths for each part of a completion."""

    inputs_ref: str
    outputs_ref: str
    system_instruction_ref: str
    tool_definitions_ref: str
# A JSON-serializable list of message dicts, one dict per message.
JsonEncodeable = list[dict[str, Any]]

# Maps (upload path, contents-hashed-to-filename?) to a zero-arg callable
# that lazily produces the JSON-encodeable data to upload to that path.
UploadData = dict[tuple[str, bool], Callable[[], JsonEncodeable]]
def is_message_part_list_hashable(
    message_parts: list[types.MessagePart] | None,
) -> bool:
    """Return True when *message_parts* is non-empty and all-text.

    Only pure-text content can be hashed into a stable, deduplicating
    filename; an empty/None list or any non-``Text`` part disqualifies it.
    """
    if not message_parts:
        return False
    return all(isinstance(part, types.Text) for part in message_parts)
def hash_tool_definitions(
    tool_definitions: list[types.ToolDefinition] | None,
) -> str | None:
    """Return a stable SHA-256 hex digest identifying *tool_definitions*.

    ``None``-valued dataclass fields are dropped before hashing so two
    definitions that differ only in unset fields share a digest. Returns
    ``None`` for an empty/absent list, or when the items are not
    dataclasses / not JSON-serializable.
    """
    if not tool_definitions:
        return None
    try:
        serializable: list[dict[str, Any]] = []
        for tool in tool_definitions:
            fields = dataclasses.asdict(tool)
            serializable.append(
                {name: value for name, value in fields.items() if value is not None}
            )
        payload = json.dumps(serializable, sort_keys=True)
        digest = hashlib.sha256(
            payload.encode("utf-8"),
            usedforsecurity=False,
        )
        return digest.hexdigest()
    except (TypeError, AttributeError):
        # Not dataclasses, or contents not JSON-serializable; the caller
        # falls back to a random (uuid) filename instead.
        return None
class UploadCompletionHook(CompletionHook):
    """A completion hook using ``fsspec`` to upload to external storage

    This function can be used as the
    :func:`~opentelemetry.util.genai.completion_hook.load_completion_hook` implementation by
    setting :envvar:`OTEL_INSTRUMENTATION_GENAI_COMPLETION_HOOK` to ``upload``.
    :envvar:`OTEL_INSTRUMENTATION_GENAI_UPLOAD_BASE_PATH` must be configured to specify the
    base path for uploads.

    Both the ``fsspec`` and ``opentelemetry-sdk`` packages should be installed, or a no-op
    implementation will be used instead. You can use ``opentelemetry-util-genai[upload]``
    as a requirement to achieve this.
    """

    def __init__(
        self,
        *,
        base_path: str,
        max_queue_size: int = _DEFAULT_MAX_QUEUE_SIZE,
        upload_format: Format = _DEFAULT_FORMAT,
        lru_cache_max_size: int = 1024,
    ) -> None:
        """Initialize the hook and verify the destination is writable.

        Raises:
            ValueError: if ``upload_format`` is not one of ``_FORMATS``, or
                if a test file cannot be written under ``base_path``.
        """
        self._max_queue_size = max_queue_size
        self._fs, base_path = fsspec.url_to_fs(base_path)
        self._base_path = self._fs.unstrip_protocol(base_path)
        # LRU cache of remote paths known to already exist; used to skip
        # re-uploading content-hashed files (system instructions, tool defs).
        self.lru_dict: OrderedDict[str, bool] = OrderedDict()
        self.lru_cache_max_size = lru_cache_max_size
        if upload_format not in _FORMATS:
            raise ValueError(
                f"Invalid {upload_format=}. Must be one of {_FORMATS}"
            )
        self._format = upload_format
        self._content_type = (
            "application/json"
            if self._format == "json"
            else "application/jsonl"
        )
        # Fail fast at construction time if uploads would not work.
        test_path = posixpath.join(
            self._base_path,
            f".one_off_test_to_see_if_upload_works.{self._format}",
        )
        try:
            with self._fs.open(
                test_path, "w", content_type=self._content_type
            ) as file:
                file.write("\n")
        except Exception as exception:  # pylint: disable=broad-exception-caught
            raise ValueError(
                f"Failed to write file to the following path, upload is not working: {test_path}.\n Got error: {exception}"
            )
        # Try to delete the file.. But we don't explicitly ask people to grant the GCS delete IAM permission in our
        # docs, so if delete fails just leave the file..
        try:
            self._fs.rm_file(test_path)  # pyright: ignore[reportUnknownMemberType]
        except Exception:  # pylint: disable=broad-exception-caught
            pass
        # Use a ThreadPoolExecutor for its queueing and thread management. The semaphore
        # limits the number of queued tasks. If the queue is full, data will be dropped.
        self._executor = ThreadPoolExecutor(
            max_workers=min(self._max_queue_size, 64)
        )
        self._semaphore = threading.BoundedSemaphore(self._max_queue_size)

    def _submit_all(self, upload_data: UploadData) -> None:
        """Enqueue one background upload per entry in *upload_data*.

        A semaphore permit is acquired per submitted task and released by
        the task's done-callback, bounding the number of in-flight uploads;
        when no permit is available the upload is dropped with a warning.
        """

        def done(future: Future[None]) -> None:
            # Runs when an upload finishes (success or failure); always
            # returns the semaphore permit taken at submission time.
            try:
                future.result()
            except Exception:  # pylint: disable=broad-except
                _logger.exception("uploader failed")
            finally:
                self._semaphore.release()

        for (
            path,
            contents_hashed_to_filename,
        ), json_encodeable in upload_data.items():
            # Content-hashed path already known to exist: refresh LRU, skip.
            if contents_hashed_to_filename and path in self.lru_dict:
                self.lru_dict.move_to_end(path)
                continue
            # could not acquire, drop data
            if not self._semaphore.acquire(blocking=False):  # pylint: disable=consider-using-with
                _logger.warning(
                    "upload queue is full, dropping upload %s",
                    path,
                )
                continue
            try:
                fut = self._executor.submit(
                    self._do_upload,
                    path,
                    contents_hashed_to_filename,
                    json_encodeable,
                )
                fut.add_done_callback(done)
            except RuntimeError:
                # Executor already shut down; give the permit back since no
                # done-callback will ever run for this entry.
                _logger.info(
                    "attempting to upload file after UploadCompletionHook.shutdown() was already called"
                )
                self._semaphore.release()

    def _calculate_ref_path(
        self,
        system_instruction: list[types.MessagePart],
        tool_definitions: list[types.ToolDefinition] | None = None,
    ) -> CompletionRefs:
        """Generate the upload paths for each part of a completion.

        Inputs/outputs always get a fresh uuid-based name; system
        instructions and tool definitions use a content hash when
        hashable, so identical content deduplicates to the same file.
        """
        # TODO: experimental with using the trace_id and span_id, or fetching
        # gen_ai.response.id from the active span.
        system_instruction_hash = None
        if is_message_part_list_hashable(system_instruction):
            # Get a hash of the text.
            system_instruction_hash = hashlib.sha256(
                "\n".join(x.content for x in system_instruction).encode(  # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue, reportUnknownArgumentType, reportCallIssue, reportArgumentType]
                    "utf-8"
                ),
                usedforsecurity=False,
            ).hexdigest()
        tool_definitions_hash = hash_tool_definitions(tool_definitions)
        uuid_str = str(uuid4())
        return CompletionRefs(
            inputs_ref=posixpath.join(
                self._base_path, f"{uuid_str}_inputs.{self._format}"
            ),
            outputs_ref=posixpath.join(
                self._base_path, f"{uuid_str}_outputs.{self._format}"
            ),
            system_instruction_ref=posixpath.join(
                self._base_path,
                f"{system_instruction_hash or uuid_str}_system_instruction.{self._format}",
            ),
            tool_definitions_ref=posixpath.join(
                self._base_path,
                f"{tool_definitions_hash or uuid_str}_tool.definitions.{self._format}",
            ),
        )

    def _file_exists(self, path: str) -> bool:
        """Return True if *path* exists remotely, consulting the LRU first.

        Positive results are cached; negative results are not, because the
        file is typically about to be created.
        """
        if path in self.lru_dict:
            self.lru_dict.move_to_end(path)
            return True
        # https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.exists
        file_exists = self._fs.exists(path)
        # don't cache this because soon the file will exist..
        if not file_exists:
            return False
        self.lru_dict[path] = True
        if len(self.lru_dict) > self.lru_cache_max_size:
            self.lru_dict.popitem(last=False)
        return True

    def _do_upload(
        self,
        path: str,
        contents_hashed_to_filename: bool,
        json_encodeable: Callable[[], JsonEncodeable],
    ) -> None:
        """Serialize and write one completion part to *path*.

        Runs on an executor thread. Content-hashed files that already
        exist are skipped, and successful content-hashed uploads are
        recorded in the LRU cache.
        """
        if contents_hashed_to_filename and self._file_exists(path):
            return
        if self._format == "json":
            # output as a single line with the json messages array
            message_lines = [json_encodeable()]
        else:
            # output as one line per message in the array
            message_lines = json_encodeable()
            # add an index for streaming readers of jsonl
            for message_idx, line in enumerate(message_lines):
                line[_MESSAGE_INDEX_KEY] = message_idx
        with self._fs.open(path, "w", content_type=self._content_type) as file:
            for message in message_lines:
                gen_ai_json_dump(message, file)
                file.write("\n")
        if contents_hashed_to_filename:
            self.lru_dict[path] = True
            if len(self.lru_dict) > self.lru_cache_max_size:
                self.lru_dict.popitem(last=False)

    def on_completion(
        self,
        *,
        inputs: list[types.InputMessage],
        outputs: list[types.OutputMessage],
        system_instruction: list[types.MessagePart],
        tool_definitions: list[types.ToolDefinition] | None = None,
        span: Span | None = None,
        log_record: LogRecord | None = None,
        **kwargs: Any,
    ) -> None:
        """Queue uploads for a completion and stamp ``*_ref`` attributes.

        Empty parts are neither uploaded nor referenced. The reference
        attributes (upload paths) are set on *span* and merged into
        *log_record* when provided.
        """
        if not any([inputs, outputs, system_instruction, tool_definitions]):
            return
        # An empty list will not be uploaded.
        completion = Completion(
            inputs=inputs or None,
            outputs=outputs or None,
            system_instruction=system_instruction or None,
            tool_definitions=tool_definitions or None,
        )
        # generate the paths to upload to
        ref_names = self._calculate_ref_path(
            system_instruction, tool_definitions
        )

        def to_dict(
            dataclass_list: list[types.InputMessage]
            | list[types.OutputMessage]
            | list[types.MessagePart]
            | list[types.ToolDefinition],
        ) -> JsonEncodeable:
            # Deferred serialization: only runs on the uploader thread.
            return [asdict(dc) for dc in dataclass_list]

        references = [
            (ref_name, ref, ref_attr, contents_hashed_to_filename)
            for ref_name, ref, ref_attr, contents_hashed_to_filename in [
                (
                    ref_names.inputs_ref,
                    completion.inputs,
                    GEN_AI_INPUT_MESSAGES_REF,
                    False,
                ),
                (
                    ref_names.outputs_ref,
                    completion.outputs,
                    GEN_AI_OUTPUT_MESSAGES_REF,
                    False,
                ),
                (
                    ref_names.system_instruction_ref,
                    completion.system_instruction,
                    GEN_AI_SYSTEM_INSTRUCTIONS_REF,
                    is_message_part_list_hashable(
                        completion.system_instruction
                    ),
                ),
                (
                    ref_names.tool_definitions_ref,
                    completion.tool_definitions,
                    GEN_AI_TOOL_DEFINITIONS_REF,
                    bool(completion.tool_definitions),
                ),
            ]
            if ref  # Filter out empty input/output/sys instruction/tool defs
        ]
        self._submit_all(
            {
                (ref_name, contents_hashed_to_filename): partial(to_dict, ref)
                for ref_name, ref, _, contents_hashed_to_filename in references
            }
        )
        # stamp the refs on telemetry
        references = {ref_attr: name for name, _, ref_attr, _ in references}
        if span:
            span.set_attributes(references)
        if log_record:
            log_record.attributes = {
                **(log_record.attributes or {}),
                **references,
            }

    def shutdown(self, *, timeout_sec: float = 10.0) -> None:
        """Flush pending uploads (up to *timeout_sec*) and stop the executor.

        Acquiring all semaphore permits proves every queued upload has
        completed; the ExitStack releases them after shutdown begins.
        """
        deadline = time() + timeout_sec
        # Wait for all tasks to finish to flush the queue
        with ExitStack() as stack:
            for _ in range(self._max_queue_size):
                remaining = deadline - time()
                if not self._semaphore.acquire(timeout=remaining):  # pylint: disable=consider-using-with
                    # Couldn't finish flushing all uploads before timeout
                    break
                stack.callback(self._semaphore.release)
            # Queue is flushed and blocked, start shutdown
            self._executor.shutdown(wait=False)