Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,35 +35,34 @@
)
from opentelemetry.semconv_ai import (
SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY,
LLMRequestTypeValues,
Meters,
SpanAttributes,
)
from opentelemetry.trace import SpanKind, Tracer, get_tracer
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer
from opentelemetry.trace.status import Status, StatusCode
from wrapt import wrap_function_wrapper

from groq._streaming import AsyncStream, Stream
from groq.types.completion_usage import CompletionUsage

logger = logging.getLogger(__name__)

# Requirement specifier for the instrumented package
# (presumably what the instrumentor reports as its dependency — confirm against the instrumentor class).
_instruments = ("groq >= 0.9.0",)

# Provider-name and operation-name values taken from the GenAI semantic-convention
# enums; they are used as span attribute values, and _CHAT also prefixes the span name.
_GROQ = GenAIAttributes.GenAiProviderNameValues.GROQ.value
_CHAT = GenAIAttributes.GenAiOperationNameValues.CHAT.value

# Synchronous entry points to wrap. Each entry names the module ("package"),
# the class ("object") and the method to patch, plus a "span_name" label.
WRAPPED_METHODS = [
{
"package": "groq.resources.chat.completions",
"object": "Completions",
"method": "create",
"span_name": "groq.chat",
},
]
# Asynchronous counterpart of WRAPPED_METHODS (same module, AsyncCompletions.create).
WRAPPED_AMETHODS = [
{
"package": "groq.resources.chat.completions",
"object": "AsyncCompletions",
"method": "create",
"span_name": "groq.chat",
},
]

Expand Down Expand Up @@ -125,53 +124,90 @@ def _create_metrics(meter: Meter):


def _process_streaming_chunk(chunk):
"""Extract content, finish_reason and usage from a streaming chunk."""
"""Extract content, tool_calls_delta, finish_reasons and usage from a streaming chunk."""
if not chunk.choices:
return None, None, None

delta = chunk.choices[0].delta
content = delta.content if hasattr(delta, "content") else None
finish_reason = chunk.choices[0].finish_reason
return None, [], [], None

content = ""
tool_calls_delta = []
finish_reasons = []
for choice in chunk.choices:
delta = choice.delta
if delta.content:
content += delta.content
if delta.tool_calls:
tool_calls_delta.extend(delta.tool_calls)
if choice.finish_reason:
finish_reasons.append(choice.finish_reason)

# Extract usage from x_groq if present in the final chunk
usage = None
if hasattr(chunk, "x_groq") and chunk.x_groq and chunk.x_groq.usage:
usage = chunk.x_groq.usage

return content, finish_reason, usage
return content, tool_calls_delta, finish_reasons, usage


def _accumulate_tool_calls(accumulated: dict, tool_calls_delta: list) -> None:
"""Merge a list of streaming tool_call delta objects into the accumulator dict.

The accumulator maps tool call index → {id, function: {name, arguments}}.
Arguments arrive as JSON fragments and are concatenated across chunks.
"""
for tc in tool_calls_delta:
idx = tc.index or 0
tc_id = tc.id or ""
fn = tc.function
fn_name = (fn.name or "") if fn else ""
fn_args = (fn.arguments or "") if fn else ""

if idx not in accumulated:
accumulated[idx] = {"id": tc_id, "function": {"name": fn_name, "arguments": ""}}
else:
if tc_id:
accumulated[idx]["id"] = tc_id
if fn_name:
accumulated[idx]["function"]["name"] = fn_name
accumulated[idx]["function"]["arguments"] += fn_args


def _handle_streaming_response(
    span: Span,
    accumulated_content: str,
    tool_calls: Union[list, None],
    finish_reasons: list[str],
    usage: Union[CompletionUsage, None],
    event_logger: Union[Logger, None],
) -> None:
    """Record the outcome of a fully consumed streaming response.

    Model-level response attributes are always set on the span. The message
    itself is then either emitted as events (when event emission is enabled
    and an event logger is available) or written as span attributes.

    Args:
        span: Span covering the streaming request.
        accumulated_content: Text concatenated from all chunks.
        tool_calls: Accumulated tool calls ordered by index, or ``None`` when
            the stream produced none.
        finish_reasons: Every finish reason seen across chunks, in order.
        usage: Usage reported by the stream, if any.
        event_logger: Logger used for event emission, if configured.
    """
    # finish_reasons is a list; use first entry for message-level finish_reason
    finish_reason = finish_reasons[0] if finish_reasons else None
    set_model_streaming_response_attributes(span, usage, finish_reasons)
    if should_emit_events() and event_logger:
        emit_streaming_response_events(
            accumulated_content, finish_reason, event_logger, tool_calls=tool_calls
        )
    else:
        set_streaming_response_attributes(
            span, accumulated_content, finish_reason, tool_calls=tool_calls
        )


def _create_stream_processor(response, span, event_logger):
"""Create a generator that processes a stream while collecting telemetry."""
accumulated_content = ""
finish_reason = None
accumulated_tool_calls: dict = {}
accumulated_finish_reasons: list = []
usage = None

for chunk in response:
content, chunk_finish_reason, chunk_usage = _process_streaming_chunk(chunk)
content, tool_calls_delta, chunk_finish_reasons, chunk_usage = _process_streaming_chunk(chunk)
if content:
accumulated_content += content
if chunk_finish_reason:
finish_reason = chunk_finish_reason
if tool_calls_delta:
_accumulate_tool_calls(accumulated_tool_calls, tool_calls_delta)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You do not use the return value here.

Copy link
Copy Markdown
Member Author

@lenatraceloop lenatraceloop Apr 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed: _accumulate_tool_calls now returns None — modifies the dict in-place, no return value needed.

accumulated_finish_reasons.extend(chunk_finish_reasons)
if chunk_usage:
usage = chunk_usage
yield chunk

_handle_streaming_response(
span, accumulated_content, finish_reason, usage, event_logger
)
tool_calls = [accumulated_tool_calls[i] for i in sorted(accumulated_tool_calls)] or None
_handle_streaming_response(span, accumulated_content, tool_calls, accumulated_finish_reasons, usage, event_logger)

if span.is_recording():
span.set_status(Status(StatusCode.OK))
Expand All @@ -182,22 +218,23 @@ def _create_stream_processor(response, span, event_logger):
async def _create_async_stream_processor(response, span, event_logger):
"""Create an async generator that processes a stream while collecting telemetry."""
accumulated_content = ""
finish_reason = None
accumulated_tool_calls: dict = {}
accumulated_finish_reasons: list = []
usage = None

async for chunk in response:
content, chunk_finish_reason, chunk_usage = _process_streaming_chunk(chunk)
content, tool_calls_delta, chunk_finish_reasons, chunk_usage = _process_streaming_chunk(chunk)
if content:
accumulated_content += content
if chunk_finish_reason:
finish_reason = chunk_finish_reason
if tool_calls_delta:
_accumulate_tool_calls(accumulated_tool_calls, tool_calls_delta)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here - you do not use the return value.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed: _accumulate_tool_calls now returns None — modifies the dict in-place, no return value needed.

accumulated_finish_reasons.extend(chunk_finish_reasons)
if chunk_usage:
usage = chunk_usage
yield chunk

_handle_streaming_response(
span, accumulated_content, finish_reason, usage, event_logger
)
tool_calls = [accumulated_tool_calls[i] for i in sorted(accumulated_tool_calls)] or None
_handle_streaming_response(span, accumulated_content, tool_calls, accumulated_finish_reasons, usage, event_logger)

if span.is_recording():
span.set_status(Status(StatusCode.OK))
Expand Down Expand Up @@ -240,13 +277,14 @@ def _wrap(
):
return wrapped(*args, **kwargs)

name = to_wrap.get("span_name")
llm_model = kwargs.get("model", "")
span = tracer.start_span(
name,
f"{_CHAT} {llm_model}",
kind=SpanKind.CLIENT,
attributes={
GenAIAttributes.GEN_AI_SYSTEM: "groq",
SpanAttributes.LLM_REQUEST_TYPE: LLMRequestTypeValues.COMPLETION.value,
GenAIAttributes.GEN_AI_PROVIDER_NAME: _GROQ,
GenAIAttributes.GEN_AI_OPERATION_NAME: _CHAT,
GenAIAttributes.GEN_AI_REQUEST_MODEL: llm_model,
},
)

Expand All @@ -255,14 +293,17 @@ def _wrap(
start_time = time.time()
try:
response = wrapped(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
except Exception as e:
end_time = time.time()
attributes = error_metrics_attributes(e)

if duration_histogram:
duration = end_time - start_time
duration_histogram.record(duration, attributes=attributes)

if span.is_recording():
span.set_status(Status(StatusCode.ERROR))
span.end()
raise e

end_time = time.time()
Expand Down Expand Up @@ -291,7 +332,7 @@ def _wrap(

_handle_response(span, response, token_histogram, event_logger)

except Exception as ex: # pylint: disable=broad-except
except Exception as ex:
logger.warning(
"Failed to set response attributes for groq span, error: %s",
str(ex),
Expand Down Expand Up @@ -322,13 +363,14 @@ async def _awrap(
):
return await wrapped(*args, **kwargs)

name = to_wrap.get("span_name")
llm_model = kwargs.get("model", "")
span = tracer.start_span(
name,
f"{_CHAT} {llm_model}",
kind=SpanKind.CLIENT,
attributes={
GenAIAttributes.GEN_AI_SYSTEM: "groq",
SpanAttributes.LLM_REQUEST_TYPE: LLMRequestTypeValues.COMPLETION.value,
GenAIAttributes.GEN_AI_PROVIDER_NAME: _GROQ,
GenAIAttributes.GEN_AI_OPERATION_NAME: _CHAT,
GenAIAttributes.GEN_AI_REQUEST_MODEL: llm_model,
},
)

Expand All @@ -338,21 +380,24 @@ async def _awrap(

try:
response = await wrapped(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
except Exception as e:
end_time = time.time()
attributes = error_metrics_attributes(e)

if duration_histogram:
duration = end_time - start_time
duration_histogram.record(duration, attributes=attributes)

if span.is_recording():
span.set_status(Status(StatusCode.ERROR))
span.end()
raise e

end_time = time.time()

if is_streaming_response(response):
try:
return await _create_async_stream_processor(response, span, event_logger)
return _create_async_stream_processor(response, span, event_logger)
except Exception as ex:
logger.warning(
"Failed to process streaming response for groq span, error: %s",
Expand All @@ -362,16 +407,23 @@ async def _awrap(
span.end()
raise
elif response:
metric_attributes = shared_metrics_attributes(response)
try:
metric_attributes = shared_metrics_attributes(response)

if duration_histogram:
duration = time.time() - start_time
duration_histogram.record(
duration,
attributes=metric_attributes,
)
if duration_histogram:
duration = time.time() - start_time
duration_histogram.record(
duration,
attributes=metric_attributes,
)

_handle_response(span, response, token_histogram, event_logger)
_handle_response(span, response, token_histogram, event_logger)

except Exception as ex:
logger.warning(
"Failed to set response attributes for groq span, error: %s",
str(ex),
)

if span.is_recording():
span.set_status(Status(StatusCode.OK))
Expand Down Expand Up @@ -424,9 +476,7 @@ def _instrument(self, **kwargs):
event_logger = None
if not Config.use_legacy_attributes:
logger_provider = kwargs.get("logger_provider")
event_logger = get_logger(
__name__, __version__, logger_provider=logger_provider
)
event_logger = get_logger(__name__, __version__, logger_provider=logger_provider)

for wrapped_method in WRAPPED_METHODS:
wrap_package = wrapped_method.get("package")
Expand Down
Loading
Loading