|
10 | 10 | import asyncio |
11 | 11 | import json |
12 | 12 | import logging |
| 13 | +import os |
13 | 14 | from enum import Enum |
14 | 15 | from typing import Any, ClassVar, cast |
15 | 16 |
|
@@ -50,6 +51,21 @@ class LLMProvider(str, Enum): |
50 | 51 | OLLAMA = "ollama" |
51 | 52 |
|
52 | 53 |
|
def _resolve_header_env_vars(headers: dict[str, str]) -> dict[str, str]:
    """Resolve ``{ENV_VAR}`` placeholders in header values.

    A value is treated as a placeholder only when it is exactly one
    environment-variable name wrapped in curly braces (e.g. ``"{ORG_ID}"``)
    and that name is an ASCII identifier. Such values are replaced by the
    variable's current value, read from ``os.environ`` at call time.

    Every other value is passed through unchanged. This includes literal
    values that merely happen to start with ``{`` and end with ``}``, such as
    ``"{}"`` or a JSON object, which must not be mistaken for a variable name.

    If a referenced variable is not set, the header resolves to an empty
    string and a warning is logged so the misconfiguration is visible.

    Args:
        headers: Mapping of header name to a literal value or placeholder.

    Returns:
        A new dict with the same keys and resolved values. The input mapping
        is not modified.
    """
    resolved: dict[str, str] = {}
    for header_name, value in headers.items():
        is_wrapped = value.startswith("{") and value.endswith("}")
        env_name = value[1:-1] if is_wrapped else ""

        # Only a bare identifier between the braces counts as a placeholder;
        # this also rejects the empty name produced by "{}".
        if not (env_name.isascii() and env_name.isidentifier()):
            resolved[header_name] = value
            continue

        env_value = os.environ.get(env_name)
        if env_value is None:
            # Keep the empty-string fallback, but do not stay silent about it.
            logging.getLogger(__name__).warning(
                "Environment variable %r referenced by header %r is not set; "
                "using an empty value",
                env_name,
                header_name,
            )
            env_value = ""
        resolved[header_name] = env_value
    return resolved
| 67 | + |
| 68 | + |
53 | 69 | # =========================================================================== |
54 | 70 | # LLMCall Executor |
55 | 71 | # =========================================================================== |
@@ -110,6 +126,10 @@ class LLMCallInput(BlockInput): |
110 | 126 | default=None, |
111 | 127 | description="Custom API endpoint URL (optional, for custom deployments)", |
112 | 128 | ) |
| 129 | + extra_headers: dict[str, str] = Field( |
| 130 | + default_factory=dict, |
| 131 | + description="Custom HTTP headers for provider requests (resolved from config)", |
| 132 | + ) |
113 | 133 | response_schema: dict[str, Any] | str | None = Field( |
114 | 134 | default=None, |
115 | 135 | description=( |
@@ -604,6 +624,7 @@ async def _resolve_profile_to_inputs( |
604 | 624 | temperature=resolved_config.temperature, |
605 | 625 | max_tokens=resolved_config.max_tokens, |
606 | 626 | validation_prompt_template=inputs.validation_prompt_template, |
| 627 | + extra_headers=resolved_config.extra_headers, |
607 | 628 | ) |
608 | 629 |
|
609 | 630 | def _resolve_profile_with_fallback( |
@@ -915,6 +936,10 @@ async def _call_openai( |
915 | 936 | base_url = base_url.rsplit("/chat/completions", 1)[0] |
916 | 937 | client_kwargs["base_url"] = base_url |
917 | 938 |
|
| 939 | + # Merge extra_headers (e.g., X-Org-Id, X-User-Id for proxy routing) |
| 940 | + if inputs.extra_headers: |
| 941 | + client_kwargs["default_headers"] = _resolve_header_env_vars(inputs.extra_headers) |
| 942 | + |
918 | 943 | # Prepare completion parameters (required parameters only) |
919 | 944 | completion_kwargs: dict[str, Any] = { |
920 | 945 | "model": inputs.model or "", |
@@ -1046,6 +1071,10 @@ async def _call_anthropic( |
1046 | 1071 | if inputs.api_key: |
1047 | 1072 | headers["x-api-key"] = inputs.api_key |
1048 | 1073 |
|
| 1074 | + # Merge extra_headers (e.g., X-Org-Id, X-User-Id for proxy routing) |
| 1075 | + if inputs.extra_headers: |
| 1076 | + headers.update(_resolve_header_env_vars(inputs.extra_headers)) |
| 1077 | + |
1049 | 1078 | async with httpx.AsyncClient(timeout=timeout) as client: |
1050 | 1079 | response = await client.post(url, json=body, headers=headers) |
1051 | 1080 | response.raise_for_status() |
@@ -1428,16 +1457,25 @@ async def execute( # type: ignore[override] |
1428 | 1457 | if model is None: |
1429 | 1458 | model = "text-embedding-3-small" |
1430 | 1459 |
|
| 1460 | + # Resolve extra_headers from profile config |
| 1461 | + default_headers: dict[str, str] | None = None |
| 1462 | + if resolved_config and resolved_config.extra_headers: |
| 1463 | + default_headers = _resolve_header_env_vars(resolved_config.extra_headers) |
| 1464 | + |
1431 | 1465 | # Resolve timeout |
1432 | 1466 | timeout = resolve_interpolatable_numeric(inputs.timeout, int, "timeout", ge=1, le=300) |
1433 | 1467 |
|
1434 | 1468 | try: |
1435 | 1469 | # Use OpenAI SDK which works with any OpenAI-compatible server |
1436 | | - client = AsyncOpenAI( |
1437 | | - api_key=api_key or "not-required", # Some local servers don't need API key |
1438 | | - base_url=api_url, # None = default OpenAI endpoint |
1439 | | - timeout=float(timeout), |
1440 | | - ) |
| 1470 | + client_kwargs: dict[str, Any] = { |
| 1471 | + "api_key": api_key or "not-required", |
| 1472 | + "base_url": api_url, |
| 1473 | + "timeout": float(timeout), |
| 1474 | + } |
| 1475 | + if default_headers: |
| 1476 | + client_kwargs["default_headers"] = default_headers |
| 1477 | + |
| 1478 | + client = AsyncOpenAI(**client_kwargs) |
1441 | 1479 |
|
1442 | 1480 | response = await client.embeddings.create( |
1443 | 1481 | model=model, |
|
0 commit comments