Skip to content

Commit 3fcb941

Browse files
committed
fix(workflows-mcp): migrate Ollama executor to /api/chat with reliable structured output
Switch from /api/generate to /api/chat for native structured output support. Ollama's JSON Schema grammar constraint is only available on the chat endpoint. Key changes: - Use /api/chat with messages array instead of /api/generate with flat prompt - Pass format parameter (inputs.response_schema) for grammar-enforced JSON output - Force temperature=0 via options when a schema is requested for determinism - Append explicit field hint to user message: 'Respond with JSON containing fields: <required>' - Introduce validation_schema split: OpenAI uses prepared_schema (strict), all other providers use inputs.response_schema to keep format/validation/retry prompt consistent - Add robust regex markdown extractor in _validate_response to handle chatty models - Add Gemini proxy mode (api_url-based routing) with extra_headers support - Log Ollama response content for observability - Move import re to module level (was incorrectly placed inside staticmethod) - Remove dead prepared_schema parameter from _call_ollama - Fix test api_url to include /api/chat path
1 parent bf3e6a5 commit 3fcb941

2 files changed

Lines changed: 108 additions & 37 deletions

File tree

src/workflows_mcp/engine/executors_llm.py

Lines changed: 105 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import json
1212
import logging
1313
import os
14+
import re
1415
from enum import Enum
1516
from typing import Any, ClassVar, cast
1617

@@ -409,13 +410,27 @@ async def execute( # type: ignore[override]
409410
# Determine if we need schema validation
410411
needs_validation = effective_inputs.response_schema is not None
411412

412-
# Prepare schema ONCE before retry loop - this is the schema LLM receives
413-
# CRITICAL: Use the SAME prepared schema for validation to avoid mismatch
414-
# where LLM generates null (allowed by prepared schema) but validation
415-
# uses original schema (which doesn't allow null)
413+
# Prepare schema ONCE before retry loop
414+
# prepared_schema: strict OpenAI-compatible schema (for OpenAI's response_format param)
415+
# validation_schema: schema used for client-side validation AND retry error prompts
416+
#
417+
# For OpenAI: both are prepared_schema (strict, with anyOf/additionalProperties)
418+
# For all other providers (Ollama, Anthropic, Gemini): validation_schema stays
419+
# as inputs.response_schema so the retry prompt matches what format sends
416420
prepared_schema: dict[str, Any] | None = None
421+
validation_schema: dict[str, Any] | None = None
417422
if needs_validation:
418423
prepared_schema = self._prepare_schema_for_openai(effective_inputs.response_schema)
424+
# Resolve effective provider to determine which schema to validate against
425+
_provider = resolve_interpolatable_enum(
426+
effective_inputs.provider, LLMProvider, "provider"
427+
)
428+
if _provider == LLMProvider.OPENAI:
429+
# OpenAI sends prepared_schema to the API, validate against same schema
430+
validation_schema = prepared_schema
431+
else:
432+
# Other providers use inputs.response_schema as format — validate consistently
433+
validation_schema = effective_inputs.response_schema
419434

420435
for attempt in range(max_retries):
421436
attempts += 1
@@ -431,7 +446,7 @@ async def execute( # type: ignore[override]
431446
+ "\n\n"
432447
+ effective_inputs.validation_prompt_template.format(
433448
validation_error=validation_error,
434-
schema=json.dumps(prepared_schema, indent=2),
449+
schema=json.dumps(validation_schema, indent=2),
435450
)
436451
)
437452

@@ -449,10 +464,10 @@ async def execute( # type: ignore[override]
449464
# Validate response if schema provided (client-side for all providers)
450465
if needs_validation:
451466
try:
452-
# Use the SAME prepared schema that was sent to the LLM
467+
# Validate against the schema matching what the provider was given
453468
validated_response = self._validate_response(
454469
response_text=response_text,
455-
schema=cast(dict[str, Any], prepared_schema),
470+
schema=cast(dict[str, Any], validation_schema),
456471
)
457472

458473
# Success - return validated JSON structure directly
@@ -1114,15 +1129,24 @@ async def _call_gemini(
11141129
) -> tuple[str, dict[str, Any]]:
11151130
"""Call Google Gemini API with null safety.
11161131
1132+
Supports two modes:
1133+
- Direct: api_key required, constructs googleapis.com URL with ?key= param
1134+
- Proxy: api_url is set (via profile), proxy handles auth — no api_key needed
1135+
11171136
Raises:
1118-
ValueError: Missing API key, empty content, or null text
1137+
ValueError: Missing API key (direct mode), empty content, or null text
11191138
httpx.*: Network/API errors
11201139
"""
1121-
if not inputs.api_key:
1122-
raise ValueError("api_key is required for Gemini provider")
1123-
1124-
base_url = "https://generativelanguage.googleapis.com/v1beta"
1125-
url = f"{base_url}/models/{inputs.model}:generateContent?key={inputs.api_key}"
1140+
if inputs.api_url:
1141+
# Proxy mode — proxy manages the API key, use api_url as base
1142+
base_url = inputs.api_url.rstrip("/")
1143+
url = f"{base_url}/v1beta/models/{inputs.model}:generateContent"
1144+
else:
1145+
# Direct mode — api_key required
1146+
if not inputs.api_key:
1147+
raise ValueError("api_key is required for Gemini provider")
1148+
base_url = "https://generativelanguage.googleapis.com/v1beta"
1149+
url = f"{base_url}/models/{inputs.model}:generateContent?key={inputs.api_key}"
11261150

11271151
contents = [{"parts": [{"text": prompt}], "role": "user"}]
11281152

@@ -1142,12 +1166,14 @@ async def _call_gemini(
11421166
if generation_config:
11431167
body["generationConfig"] = generation_config
11441168

1169+
headers: dict[str, str] = {"Content-Type": "application/json"}
1170+
1171+
# Merge extra_headers (e.g., X-Org-Id, X-User-Id for proxy routing)
1172+
if inputs.extra_headers:
1173+
headers.update(_resolve_header_env_vars(inputs.extra_headers))
1174+
11451175
async with httpx.AsyncClient(timeout=timeout) as client:
1146-
response = await client.post(
1147-
url,
1148-
json=body,
1149-
headers={"Content-Type": "application/json"},
1150-
)
1176+
response = await client.post(url, json=body, headers=headers)
11511177
response.raise_for_status()
11521178

11531179
data = response.json()
@@ -1190,42 +1216,73 @@ async def _call_ollama(
11901216
temperature: float | None,
11911217
max_tokens: int | None,
11921218
) -> tuple[str, dict[str, Any]]:
1193-
"""Call Ollama local API with null safety.
1219+
"""Call Ollama API via /api/chat with native structured output.
1220+
1221+
Uses /api/chat (messages format) instead of /api/generate because
1222+
Ollama's structured output (``format`` parameter with JSON schema)
1223+
is only supported on the chat endpoint.
11941224
11951225
Raises:
11961226
ValueError: Null response
11971227
httpx.*: Network/API errors
11981228
"""
1199-
url = inputs.api_url or "http://localhost:11434/api/generate"
1229+
url = inputs.api_url or "http://localhost:11434/api/chat"
12001230

1201-
# Combine system instructions and prompt for Ollama
1202-
full_prompt = prompt
1231+
# Build messages list (Ollama /api/chat uses messages array)
1232+
messages: list[dict[str, str]] = []
12031233
if inputs.system_instructions:
1204-
full_prompt = f"{inputs.system_instructions}\n\n{prompt}"
1234+
messages.append({"role": "system", "content": inputs.system_instructions})
1235+
1236+
# Ollama tip: append JSON instruction to the user message when a schema is defined.
1237+
# Explicitly name the required fields so models without strict grammar enforcement
1238+
# still know what keys to include.
1239+
user_content = prompt
1240+
if inputs.response_schema:
1241+
required_fields = inputs.response_schema.get("required", [])
1242+
if required_fields:
1243+
fields_hint = ", ".join(required_fields)
1244+
user_content = f"{prompt}\n\nRespond with JSON containing fields: {fields_hint}"
1245+
else:
1246+
user_content = f"{prompt}\n\nRespond with JSON"
1247+
messages.append({"role": "user", "content": user_content})
12051248

12061249
body: dict[str, Any] = {
12071250
"model": inputs.model,
1208-
"prompt": full_prompt,
1251+
"messages": messages,
12091252
"stream": False,
12101253
}
12111254

1212-
if temperature is not None:
1255+
# Native structured output — Ollama uses 'format' parameter on /api/chat
1256+
# Use raw inputs.response_schema (NOT prepared_schema) because llama.cpp's grammar
1257+
# engine can't reliably handle OpenAI-specific patterns like anyOf, additionalProperties.
1258+
# prepared_schema is still used for client-side validation in execute().
1259+
if inputs.response_schema:
1260+
body["format"] = inputs.response_schema
1261+
# Ollama tip: force temperature=0 for deterministic structured output
1262+
body["options"] = {"temperature": 0}
1263+
elif temperature is not None:
12131264
body["options"] = {"temperature": temperature}
12141265

1266+
headers: dict[str, str] = {"Content-Type": "application/json"}
1267+
1268+
# Merge extra_headers (e.g., X-Org-Id, X-User-Id for proxy routing)
1269+
if inputs.extra_headers:
1270+
headers.update(_resolve_header_env_vars(inputs.extra_headers))
1271+
1272+
logger.info(f"Engine sending Ollama request to {url} with body: {json.dumps(body)}")
1273+
12151274
async with httpx.AsyncClient(timeout=timeout) as client:
1216-
response = await client.post(
1217-
url,
1218-
json=body,
1219-
headers={"Content-Type": "application/json"},
1220-
)
1275+
response = await client.post(url, json=body, headers=headers)
12211276
response.raise_for_status()
12221277

12231278
data = response.json()
12241279

1225-
# Extract content with null safety
1226-
response_text = data.get("response")
1280+
# Extract content from /api/chat response format
1281+
message = data.get("message", {})
1282+
response_text = message.get("content")
12271283
if response_text is None:
1228-
raise ValueError("Ollama returned null response")
1284+
raise ValueError("Ollama returned null content")
1285+
logger.info(f"Engine received Ollama response content: {response_text!r}")
12291286

12301287
provider_metadata = {
12311288
"model": data.get("model"),
@@ -1246,9 +1303,23 @@ def _validate_response(response_text: str, schema: dict[str, Any]) -> dict[str,
12461303
Raises:
12471304
ValueError: Invalid JSON, non-dict response, or schema validation failure
12481305
"""
1306+
# Robust extraction: find markdown json block anywhere in text
1307+
text = response_text.strip()
1308+
block_match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
1309+
1310+
if block_match:
1311+
text = block_match.group(1).strip()
1312+
else:
1313+
# Fallback for plain text containing JSON or conversational wrappers
1314+
# Find the outermost curly braces
1315+
start_idx = text.find("{")
1316+
end_idx = text.rfind("}")
1317+
if start_idx != -1 and end_idx != -1 and end_idx >= start_idx:
1318+
text = text[start_idx : end_idx + 1]
1319+
12491320
# Try to parse as JSON
12501321
try:
1251-
response = json.loads(response_text)
1322+
response = json.loads(text)
12521323
except json.JSONDecodeError as e:
12531324
raise ValueError(f"Response is not valid JSON: {e}")
12541325

tests/test_llm_executor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,20 +149,20 @@ async def test_gemini_basic_call(self, executor, mock_context):
149149

150150
@pytest.mark.asyncio
151151
async def test_ollama_basic_call(self, executor, mock_context):
152-
"""Test basic Ollama local API call."""
152+
"""Test basic Ollama API call via /api/chat."""
153153
inputs = LLMCallInput(
154154
provider="ollama",
155155
model="llama2",
156156
prompt="Hello",
157-
api_url="http://localhost:11434/api/generate",
157+
api_url="http://localhost:11434/api/chat",
158158
timeout=60,
159159
)
160160

161161
mock_response = Mock()
162162
mock_response.status_code = 200
163163
mock_response.json.return_value = {
164164
"model": "llama2",
165-
"response": "Hello! How are you?",
165+
"message": {"role": "assistant", "content": "Hello! How are you?"},
166166
"total_duration": 1234567890,
167167
"load_duration": 123456,
168168
"prompt_eval_count": 5,

0 commit comments

Comments
 (0)