Skip to content

Commit 4d94731

Browse files
committed
refactor: reject multiturn execution for tasks with a structured input schema (not supported)
1 parent 2edf7b8 commit 4d94731

File tree

3 files changed

+277
-31
lines changed

3 files changed

+277
-31
lines changed

libs/core/kiln_ai/adapters/chat/test_chat_formatter.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Any
2-
31
from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter
42
from kiln_ai.adapters.chat.chat_formatter import (
53
COT_FINAL_ANSWER_PROMPT,
@@ -235,24 +233,6 @@ def test_multiturn_formatter_multiple_tool_results():
235233
assert first.final_call
236234

237235

238-
def _make_formatter(user_input: Any) -> MultiturnFormatter:
239-
return MultiturnFormatter(
240-
prior_trace=[{"role": "system", "content": "sys"}], user_input=user_input
241-
) # type: ignore[arg-type]
242-
243-
244-
def test_multiturn_formatter_is_tool_result_detection():
245-
"""_is_tool_result correctly identifies tool result inputs."""
246-
assert _make_formatter({"tool_call_id": "x", "content": "y"})._is_tool_result
247-
assert _make_formatter(
248-
[{"tool_call_id": "x", "content": "y"}, {"tool_call_id": "z", "content": "w"}]
249-
)._is_tool_result
250-
assert not _make_formatter("plain string")._is_tool_result
251-
assert not _make_formatter({"content": "no id"})._is_tool_result
252-
assert not _make_formatter([])._is_tool_result
253-
assert not _make_formatter([{"content": "no id"}])._is_tool_result
254-
255-
256236
def test_multiturn_formatter_user_input_not_confused_with_tool_result():
257237
"""A regular dict input (no tool_call_id) is treated as a user message."""
258238
prior_trace = [{"role": "system", "content": "sys"}]

libs/core/kiln_ai/adapters/model_adapters/base_adapter.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,24 @@ def model_provider(self) -> KilnModelProvider:
158158
)
159159
return self._model_provider
160160

161+
@staticmethod
162+
def _normalize_prior_trace(
163+
prior_trace: list[ChatCompletionMessageParam] | None,
164+
) -> list[ChatCompletionMessageParam] | None:
165+
if not prior_trace:
166+
return None
167+
return prior_trace
168+
169+
def _reject_multiturn_with_structured_input(
170+
self,
171+
prior_trace: list[ChatCompletionMessageParam] | None,
172+
) -> None:
173+
if prior_trace is not None and self.input_schema is not None:
174+
raise ValueError(
175+
"Cannot run multiturn execution with a task that has a structured input schema. "
176+
"Use an unstructured task, or call without prior_trace."
177+
)
178+
161179
async def invoke(
162180
self,
163181
input: InputType,
@@ -177,19 +195,18 @@ async def _run_returning_run_output(
177195
prior_trace: list[ChatCompletionMessageParam] | None = None,
178196
parent_task_run: TaskRun | None = None,
179197
) -> Tuple[TaskRun, RunOutput]:
180-
# validate input, allowing arrays.
181-
# Skip when prior_trace is provided: the input may be a tool result or a
182-
# follow-up message that shouldn't be validated against the task input schema.
183-
if self.input_schema is not None and prior_trace is None:
198+
prior_trace = self._normalize_prior_trace(prior_trace)
199+
self._reject_multiturn_with_structured_input(prior_trace)
200+
201+
# validate input, allowing arrays
202+
if self.input_schema is not None:
184203
validate_schema_with_value_error(
185204
input,
186205
self.input_schema,
187206
"This task requires a specific input schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.",
188207
require_object=False,
189208
)
190209

191-
prior_trace = prior_trace if prior_trace else None
192-
193210
# Format model input for model call (we save the original input in the task without formatting)
194211
formatted_input = input
195212
formatter_id = self.model_provider().formatter
@@ -339,18 +356,17 @@ def _prepare_stream(
339356
input: InputType,
340357
prior_trace: list[ChatCompletionMessageParam] | None,
341358
) -> AdapterStream:
342-
# Skip input schema validation when prior_trace is provided: the input may be
343-
# a tool result or follow-up message not matching the task input schema.
344-
if self.input_schema is not None and prior_trace is None:
359+
prior_trace = self._normalize_prior_trace(prior_trace)
360+
self._reject_multiturn_with_structured_input(prior_trace)
361+
362+
if self.input_schema is not None:
345363
validate_schema_with_value_error(
346364
input,
347365
self.input_schema,
348366
"This task requires a specific input schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.",
349367
require_object=False,
350368
)
351369

352-
prior_trace = prior_trace if prior_trace else None
353-
354370
formatted_input = input
355371
formatter_id = self.model_provider().formatter
356372
if formatter_id is not None:
@@ -533,6 +549,8 @@ def build_chat_formatter(
533549
input: InputType,
534550
prior_trace: list[ChatCompletionMessageParam] | None = None,
535551
) -> ChatFormatter:
552+
prior_trace = self._normalize_prior_trace(prior_trace)
553+
self._reject_multiturn_with_structured_input(prior_trace)
536554
if prior_trace is not None:
537555
return MultiturnFormatter(prior_trace, input)
538556
if self.prompt_builder is None:

libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from unittest.mock import AsyncMock, MagicMock, patch
23

34
import pytest
@@ -28,6 +29,7 @@
2829
from kiln_ai.datamodel.skill import Skill
2930
from kiln_ai.datamodel.tool_id import KilnBuiltInToolId
3031
from kiln_ai.tools.base_tool import KilnToolInterface
32+
from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam
3133

3234

3335
class MockAdapter(BaseAdapter):
@@ -446,6 +448,13 @@ def test_build_chat_formatter_with_prior_trace_returns_multiturn_formatter(adapt
446448
assert formatter.initial_messages() == prior_trace
447449

448450

451+
def test_build_chat_formatter_empty_prior_trace_matches_none(adapter):
    """An empty prior trace yields the same formatter type as no prior trace."""
    from_empty, from_none = (
        adapter.build_chat_formatter("new input", prior_trace=trace)
        for trace in ([], None)
    )
    assert type(from_empty) is type(from_none)
    # Neither call may fall into the multiturn branch.
    assert from_empty.__class__.__name__ != "MultiturnFormatter"
456+
457+
449458
@pytest.mark.asyncio
450459
async def test_invoke_with_prior_trace_none_starts_fresh(base_project):
451460
task = Task(
@@ -493,6 +502,7 @@ async def test_invoke_with_prior_trace_none_starts_fresh(base_project):
493502
),
494503
):
495504
run = await adapter.invoke("input", prior_trace=None)
505+
assert isinstance(run, TaskRun)
496506
assert run.output.output == "ok"
497507
adapter._run.assert_called_once()
498508
assert adapter._run.call_args[1].get("prior_trace") is None
@@ -549,6 +559,244 @@ async def mock_run(input, **kwargs):
549559
assert captured_prior_trace == trace
550560

551561

562+
_INPUT_OBJECT_SCHEMA = json.dumps(
563+
{
564+
"type": "object",
565+
"properties": {"x": {"type": "number"}},
566+
"required": ["x"],
567+
}
568+
)
569+
570+
_MULTITURN_STRUCTURED_ERROR = (
571+
"Cannot run multiturn execution with a task that has a structured input schema"
572+
)
573+
574+
575+
def test_normalize_prior_trace_empty_and_none():
    """None and [] both normalize to None; a non-empty trace is returned as given."""
    for no_history in (None, []):
        assert BaseAdapter._normalize_prior_trace(no_history) is None

    messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": "h"}]
    assert BaseAdapter._normalize_prior_trace(messages) == messages
580+
581+
582+
@pytest.mark.asyncio
async def test_invoke_rejects_multiturn_with_structured_input(tmp_path):
    """invoke() raises on a non-empty prior_trace when an input schema is set.

    The adapter's ``_run`` must not be called once the request is rejected.
    """
    proj = Project(name="proj", path=tmp_path / "proj.kiln")
    proj.save_to_file()
    saved_task = Task(name="t", instruction="i", parent=proj)
    saved_task.save_to_file()

    run_config = KilnAgentRunConfigProperties(
        model_name="gpt_4o",
        model_provider_name=ModelProviderName.openai,
        prompt_id="simple_prompt_builder",
        structured_output_mode=StructuredOutputMode.json_schema,
    )
    adapter = MockAdapter(task=saved_task, run_config=run_config)
    # The schema is assigned on the adapter directly rather than via the task.
    adapter.input_schema = _INPUT_OBJECT_SCHEMA
    adapter._run = AsyncMock()
    history: list[ChatCompletionMessageParam] = [{"role": "user", "content": "hi"}]

    with pytest.raises(ValueError, match=_MULTITURN_STRUCTURED_ERROR):
        await adapter.invoke({"x": 1}, prior_trace=history)

    adapter._run.assert_not_called()
611+
612+
613+
@pytest.mark.asyncio
@pytest.mark.parametrize("prior_trace", [None, []])
async def test_invoke_validates_input_schema_when_single_turn(
    tmp_path, prior_trace: list[ChatCompletionMessageParam] | None
):
    """With no (or an empty) prior trace, invoke() still validates the input.

    An input that fails the schema raises ``ValueError`` and ``_run`` is never
    called.
    """
    proj = Project(name="proj", path=tmp_path / "proj.kiln")
    proj.save_to_file()
    saved_task = Task(name="t", instruction="i", parent=proj)
    saved_task.save_to_file()

    run_config = KilnAgentRunConfigProperties(
        model_name="gpt_4o",
        model_provider_name=ModelProviderName.openai,
        prompt_id="simple_prompt_builder",
        structured_output_mode=StructuredOutputMode.json_schema,
    )
    adapter = MockAdapter(task=saved_task, run_config=run_config)
    adapter.input_schema = _INPUT_OBJECT_SCHEMA
    adapter._run = AsyncMock()

    # ``{}`` omits the required "x" property of _INPUT_OBJECT_SCHEMA.
    with pytest.raises(ValueError, match="input schema"):
        await adapter.invoke({}, prior_trace=prior_trace)

    adapter._run.assert_not_called()
642+
643+
644+
@pytest.mark.asyncio
@pytest.mark.parametrize("prior_trace", [None, []])
async def test_invoke_empty_prior_trace_like_none_allows_structured_input(
    tmp_path, prior_trace: list[ChatCompletionMessageParam] | None
):
    """A missing or empty prior trace does not block a structured-input run.

    ``_run`` must be called exactly once and receive ``prior_trace=None`` in
    both parametrized cases.
    """
    proj = Project(name="proj", path=tmp_path / "proj.kiln")
    proj.save_to_file()
    saved_task = Task(name="t", instruction="i", parent=proj)
    saved_task.save_to_file()

    run_config = KilnAgentRunConfigProperties(
        model_name="gpt_4o",
        model_provider_name=ModelProviderName.openai,
        prompt_id="simple_prompt_builder",
        structured_output_mode=StructuredOutputMode.json_schema,
    )
    adapter = MockAdapter(task=saved_task, run_config=run_config)
    adapter.input_schema = _INPUT_OBJECT_SCHEMA

    # Separate RunOutput instances for the adapter result and the parser
    # result, as two distinct objects.
    run_result = RunOutput(output="ok", intermediate_outputs=None, trace=None)
    adapter._run = AsyncMock(return_value=(run_result, None))

    parsed_result = RunOutput(output="ok", intermediate_outputs=None, trace=None)
    parser_patch = patch(
        "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
        return_value=MagicMock(parse_output=MagicMock(return_value=parsed_result)),
    )
    formatter_patch = patch(
        "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id",
    )
    provider_patch = patch.object(
        adapter,
        "model_provider",
        return_value=MagicMock(
            parser="default",
            formatter=None,
            reasoning_capable=False,
        ),
    )

    with parser_patch, formatter_patch, provider_patch:
        await adapter.invoke({"x": 1}, prior_trace=prior_trace)

    adapter._run.assert_called_once()
    forwarded_kwargs = adapter._run.call_args[1]
    assert forwarded_kwargs.get("prior_trace") is None
702+
703+
704+
def test_prepare_stream_rejects_multiturn_with_structured_input(tmp_path):
    """_prepare_stream() raises on a non-empty prior_trace when an input schema is set."""
    proj = Project(name="proj", path=tmp_path / "proj.kiln")
    proj.save_to_file()
    saved_task = Task(name="t", instruction="i", parent=proj)
    saved_task.save_to_file()

    run_config = KilnAgentRunConfigProperties(
        model_name="gpt_4o",
        model_provider_name=ModelProviderName.openai,
        prompt_id="simple_prompt_builder",
        structured_output_mode=StructuredOutputMode.json_schema,
    )
    adapter = MockAdapter(task=saved_task, run_config=run_config)
    adapter.input_schema = _INPUT_OBJECT_SCHEMA
    history: list[ChatCompletionMessageParam] = [{"role": "user", "content": "hi"}]

    provider_patch = patch.object(
        adapter,
        "model_provider",
        return_value=MagicMock(formatter=None),
    )
    with provider_patch:
        with pytest.raises(ValueError, match=_MULTITURN_STRUCTURED_ERROR):
            adapter._prepare_stream({"x": 1}, prior_trace=history)
736+
737+
738+
@pytest.mark.parametrize("prior_trace", [None, []])
def test_prepare_stream_validates_input_schema_when_single_turn(
    tmp_path, prior_trace: list[ChatCompletionMessageParam] | None
):
    """With no (or an empty) prior trace, _prepare_stream() still validates the input."""
    proj = Project(name="proj", path=tmp_path / "proj.kiln")
    proj.save_to_file()
    saved_task = Task(name="t", instruction="i", parent=proj)
    saved_task.save_to_file()

    run_config = KilnAgentRunConfigProperties(
        model_name="gpt_4o",
        model_provider_name=ModelProviderName.openai,
        prompt_id="simple_prompt_builder",
        structured_output_mode=StructuredOutputMode.json_schema,
    )
    adapter = MockAdapter(task=saved_task, run_config=run_config)
    adapter.input_schema = _INPUT_OBJECT_SCHEMA

    # An empty object omits the required "x" property of _INPUT_OBJECT_SCHEMA.
    bad_payload: dict = {}

    provider_patch = patch.object(
        adapter,
        "model_provider",
        return_value=MagicMock(formatter=None),
    )
    with provider_patch:
        with pytest.raises(ValueError, match="input schema"):
            adapter._prepare_stream(bad_payload, prior_trace=prior_trace)
771+
772+
773+
def test_build_chat_formatter_rejects_multiturn_with_structured_input(tmp_path):
    """build_chat_formatter() raises on a non-empty prior_trace when an input schema is set."""
    proj = Project(name="proj", path=tmp_path / "proj.kiln")
    proj.save_to_file()
    saved_task = Task(name="t", instruction="i", parent=proj)
    saved_task.save_to_file()

    run_config = KilnAgentRunConfigProperties(
        model_name="gpt_4o",
        model_provider_name=ModelProviderName.openai,
        prompt_id="simple_prompt_builder",
        structured_output_mode=StructuredOutputMode.json_schema,
    )
    adapter = MockAdapter(task=saved_task, run_config=run_config)
    adapter.input_schema = _INPUT_OBJECT_SCHEMA
    history: list[ChatCompletionMessageParam] = [{"role": "user", "content": "hi"}]

    with pytest.raises(ValueError, match=_MULTITURN_STRUCTURED_ERROR):
        adapter.build_chat_formatter("new input", prior_trace=history)
798+
799+
552800
@pytest.mark.parametrize(
553801
"initial_mode,expected_mode",
554802
[

0 commit comments

Comments
 (0)