Skip to content

Commit bf3e6a5

Browse files
committed
feat(engine): add generic extra_headers support for LLM providers
Add extra_headers field to ProviderConfig, ResolvedLLMConfig, and LLMCallInput. Values use {ENV_VAR} syntax for runtime resolution. Headers are merged into _call_openai() (default_headers), _call_anthropic() (headers dict), and EmbeddingExecutor (default_headers). This enables config-driven proxy routing where the YAML generator injects X-Org-Id and X-User-Id headers without engine-specific code. TASK-207
1 parent dc4caef commit bf3e6a5

2 files changed

Lines changed: 55 additions & 5 deletions

File tree

src/workflows_mcp/engine/executors_llm.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import asyncio
1111
import json
1212
import logging
13+
import os
1314
from enum import Enum
1415
from typing import Any, ClassVar, cast
1516

@@ -50,6 +51,21 @@ class LLMProvider(str, Enum):
5051
OLLAMA = "ollama"
5152

5253

54+
def _resolve_header_env_vars(headers: dict[str, str]) -> dict[str, str]:
55+
"""Resolve {ENV_VAR} placeholders in header values.
56+
57+
Values wrapped in curly braces (e.g., "{ORG_ID}") are resolved from
58+
environment variables at runtime. Other values are passed through as-is.
59+
"""
60+
resolved = {}
61+
for k, v in headers.items():
62+
if v.startswith("{") and v.endswith("}"):
63+
resolved[k] = os.environ.get(v[1:-1], "")
64+
else:
65+
resolved[k] = v
66+
return resolved
67+
68+
5369
# ===========================================================================
5470
# LLMCall Executor
5571
# ===========================================================================
@@ -110,6 +126,10 @@ class LLMCallInput(BlockInput):
110126
default=None,
111127
description="Custom API endpoint URL (optional, for custom deployments)",
112128
)
129+
extra_headers: dict[str, str] = Field(
130+
default_factory=dict,
131+
description="Custom HTTP headers for provider requests (resolved from config)",
132+
)
113133
response_schema: dict[str, Any] | str | None = Field(
114134
default=None,
115135
description=(
@@ -604,6 +624,7 @@ async def _resolve_profile_to_inputs(
604624
temperature=resolved_config.temperature,
605625
max_tokens=resolved_config.max_tokens,
606626
validation_prompt_template=inputs.validation_prompt_template,
627+
extra_headers=resolved_config.extra_headers,
607628
)
608629

609630
def _resolve_profile_with_fallback(
@@ -915,6 +936,10 @@ async def _call_openai(
915936
base_url = base_url.rsplit("/chat/completions", 1)[0]
916937
client_kwargs["base_url"] = base_url
917938

939+
# Merge extra_headers (e.g., X-Org-Id, X-User-Id for proxy routing)
940+
if inputs.extra_headers:
941+
client_kwargs["default_headers"] = _resolve_header_env_vars(inputs.extra_headers)
942+
918943
# Prepare completion parameters (required parameters only)
919944
completion_kwargs: dict[str, Any] = {
920945
"model": inputs.model or "",
@@ -1046,6 +1071,10 @@ async def _call_anthropic(
10461071
if inputs.api_key:
10471072
headers["x-api-key"] = inputs.api_key
10481073

1074+
# Merge extra_headers (e.g., X-Org-Id, X-User-Id for proxy routing)
1075+
if inputs.extra_headers:
1076+
headers.update(_resolve_header_env_vars(inputs.extra_headers))
1077+
10491078
async with httpx.AsyncClient(timeout=timeout) as client:
10501079
response = await client.post(url, json=body, headers=headers)
10511080
response.raise_for_status()
@@ -1428,16 +1457,25 @@ async def execute( # type: ignore[override]
14281457
if model is None:
14291458
model = "text-embedding-3-small"
14301459

1460+
# Resolve extra_headers from profile config
1461+
default_headers: dict[str, str] | None = None
1462+
if resolved_config and resolved_config.extra_headers:
1463+
default_headers = _resolve_header_env_vars(resolved_config.extra_headers)
1464+
14311465
# Resolve timeout
14321466
timeout = resolve_interpolatable_numeric(inputs.timeout, int, "timeout", ge=1, le=300)
14331467

14341468
try:
14351469
# Use OpenAI SDK which works with any OpenAI-compatible server
1436-
client = AsyncOpenAI(
1437-
api_key=api_key or "not-required", # Some local servers don't need API key
1438-
base_url=api_url, # None = default OpenAI endpoint
1439-
timeout=float(timeout),
1440-
)
1470+
client_kwargs: dict[str, Any] = {
1471+
"api_key": api_key or "not-required",
1472+
"base_url": api_url,
1473+
"timeout": float(timeout),
1474+
}
1475+
if default_headers:
1476+
client_kwargs["default_headers"] = default_headers
1477+
1478+
client = AsyncOpenAI(**client_kwargs)
14411479

14421480
response = await client.embeddings.create(
14431481
model=model,

src/workflows_mcp/engine/llm_config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@ class ProviderConfig(BaseModel):
117117
le=60.0,
118118
description="Initial retry delay in seconds (exponential backoff)",
119119
)
120+
extra_headers: dict[str, str] = Field(
121+
default_factory=dict,
122+
description=(
123+
"Custom HTTP headers to include in all requests to this provider. "
124+
"Values can use {ENV_VAR} syntax for runtime environment variable resolution."
125+
),
126+
)
120127
# Azure OpenAI specific fields
121128
deployment_name: str | None = Field(
122129
default=None,
@@ -237,6 +244,10 @@ class ResolvedLLMConfig(BaseModel):
237244
temperature: float | None = Field(default=None, description="Sampling temperature")
238245
max_tokens: int | None = Field(default=None, description="Maximum tokens to generate")
239246
system_instructions: str | None = Field(default=None, description="System instructions")
247+
extra_headers: dict[str, str] = Field(
248+
default_factory=dict,
249+
description="Custom HTTP headers for provider requests",
250+
)
240251
# Azure OpenAI specific
241252
deployment_name: str | None = Field(default=None)
242253
api_version: str | None = Field(default=None)
@@ -447,6 +458,7 @@ def resolve_profile(
447458
system_instructions=inline.get("system_instructions"),
448459
deployment_name=inline.get("deployment_name", provider_config.deployment_name),
449460
api_version=inline.get("api_version", provider_config.api_version),
461+
extra_headers=inline.get("extra_headers", provider_config.extra_headers),
450462
)
451463

452464
def get_default_profile(self) -> str | None:

0 commit comments

Comments
 (0)