|
| 1 | +import json |
1 | 2 | from unittest.mock import AsyncMock, MagicMock, patch |
2 | 3 |
|
3 | 4 | import pytest |
|
28 | 29 | from kiln_ai.datamodel.skill import Skill |
29 | 30 | from kiln_ai.datamodel.tool_id import KilnBuiltInToolId |
30 | 31 | from kiln_ai.tools.base_tool import KilnToolInterface |
| 32 | +from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam |
31 | 33 |
|
32 | 34 |
|
33 | 35 | class MockAdapter(BaseAdapter): |
@@ -446,6 +448,13 @@ def test_build_chat_formatter_with_prior_trace_returns_multiturn_formatter(adapt |
446 | 448 | assert formatter.initial_messages() == prior_trace |
447 | 449 |
|
448 | 450 |
|
| 451 | +def test_build_chat_formatter_empty_prior_trace_matches_none(adapter): |
| 452 | + fmt_empty = adapter.build_chat_formatter("new input", prior_trace=[]) |
| 453 | + fmt_none = adapter.build_chat_formatter("new input", prior_trace=None) |
| 454 | + assert type(fmt_empty) is type(fmt_none) |
| 455 | + assert fmt_empty.__class__.__name__ != "MultiturnFormatter" |
| 456 | + |
| 457 | + |
449 | 458 | @pytest.mark.asyncio |
450 | 459 | async def test_invoke_with_prior_trace_none_starts_fresh(base_project): |
451 | 460 | task = Task( |
@@ -493,6 +502,7 @@ async def test_invoke_with_prior_trace_none_starts_fresh(base_project): |
493 | 502 | ), |
494 | 503 | ): |
495 | 504 | run = await adapter.invoke("input", prior_trace=None) |
| 505 | + assert isinstance(run, TaskRun) |
496 | 506 | assert run.output.output == "ok" |
497 | 507 | adapter._run.assert_called_once() |
498 | 508 | assert adapter._run.call_args[1].get("prior_trace") is None |
@@ -549,6 +559,244 @@ async def mock_run(input, **kwargs): |
549 | 559 | assert captured_prior_trace == trace |
550 | 560 |
|
551 | 561 |
|
| 562 | +_INPUT_OBJECT_SCHEMA = json.dumps( |
| 563 | + { |
| 564 | + "type": "object", |
| 565 | + "properties": {"x": {"type": "number"}}, |
| 566 | + "required": ["x"], |
| 567 | + } |
| 568 | +) |
| 569 | + |
| 570 | +_MULTITURN_STRUCTURED_ERROR = ( |
| 571 | + "Cannot run multiturn execution with a task that has a structured input schema" |
| 572 | +) |
| 573 | + |
| 574 | + |
| 575 | +def test_normalize_prior_trace_empty_and_none(): |
| 576 | + assert BaseAdapter._normalize_prior_trace(None) is None |
| 577 | + assert BaseAdapter._normalize_prior_trace([]) is None |
| 578 | + trace: list[ChatCompletionMessageParam] = [{"role": "user", "content": "h"}] |
| 579 | + assert BaseAdapter._normalize_prior_trace(trace) == trace |
| 580 | + |
| 581 | + |
| 582 | +@pytest.mark.asyncio |
| 583 | +async def test_invoke_rejects_multiturn_with_structured_input(tmp_path): |
| 584 | + project = Project(name="proj", path=tmp_path / "proj.kiln") |
| 585 | + project.save_to_file() |
| 586 | + task = Task( |
| 587 | + name="t", |
| 588 | + instruction="i", |
| 589 | + parent=project, |
| 590 | + ) |
| 591 | + task.save_to_file() |
| 592 | + adapter = MockAdapter( |
| 593 | + task=task, |
| 594 | + run_config=KilnAgentRunConfigProperties( |
| 595 | + model_name="gpt_4o", |
| 596 | + model_provider_name=ModelProviderName.openai, |
| 597 | + prompt_id="simple_prompt_builder", |
| 598 | + structured_output_mode=StructuredOutputMode.json_schema, |
| 599 | + ), |
| 600 | + ) |
| 601 | + adapter.input_schema = _INPUT_OBJECT_SCHEMA |
| 602 | + adapter._run = AsyncMock() |
| 603 | + prior_trace: list[ChatCompletionMessageParam] = [ |
| 604 | + {"role": "user", "content": "hi"}, |
| 605 | + ] |
| 606 | + |
| 607 | + with pytest.raises(ValueError, match=_MULTITURN_STRUCTURED_ERROR): |
| 608 | + await adapter.invoke({"x": 1}, prior_trace=prior_trace) |
| 609 | + |
| 610 | + adapter._run.assert_not_called() |
| 611 | + |
| 612 | + |
| 613 | +@pytest.mark.asyncio |
| 614 | +@pytest.mark.parametrize("prior_trace", [None, []]) |
| 615 | +async def test_invoke_validates_input_schema_when_single_turn( |
| 616 | + tmp_path, prior_trace: list[ChatCompletionMessageParam] | None |
| 617 | +): |
| 618 | + project = Project(name="proj", path=tmp_path / "proj.kiln") |
| 619 | + project.save_to_file() |
| 620 | + task = Task( |
| 621 | + name="t", |
| 622 | + instruction="i", |
| 623 | + parent=project, |
| 624 | + ) |
| 625 | + task.save_to_file() |
| 626 | + adapter = MockAdapter( |
| 627 | + task=task, |
| 628 | + run_config=KilnAgentRunConfigProperties( |
| 629 | + model_name="gpt_4o", |
| 630 | + model_provider_name=ModelProviderName.openai, |
| 631 | + prompt_id="simple_prompt_builder", |
| 632 | + structured_output_mode=StructuredOutputMode.json_schema, |
| 633 | + ), |
| 634 | + ) |
| 635 | + adapter.input_schema = _INPUT_OBJECT_SCHEMA |
| 636 | + adapter._run = AsyncMock() |
| 637 | + |
| 638 | + with pytest.raises(ValueError, match="input schema"): |
| 639 | + await adapter.invoke({}, prior_trace=prior_trace) |
| 640 | + |
| 641 | + adapter._run.assert_not_called() |
| 642 | + |
| 643 | + |
| 644 | +@pytest.mark.asyncio |
| 645 | +@pytest.mark.parametrize("prior_trace", [None, []]) |
| 646 | +async def test_invoke_empty_prior_trace_like_none_allows_structured_input( |
| 647 | + tmp_path, prior_trace: list[ChatCompletionMessageParam] | None |
| 648 | +): |
| 649 | + project = Project(name="proj", path=tmp_path / "proj.kiln") |
| 650 | + project.save_to_file() |
| 651 | + task = Task( |
| 652 | + name="t", |
| 653 | + instruction="i", |
| 654 | + parent=project, |
| 655 | + ) |
| 656 | + task.save_to_file() |
| 657 | + adapter = MockAdapter( |
| 658 | + task=task, |
| 659 | + run_config=KilnAgentRunConfigProperties( |
| 660 | + model_name="gpt_4o", |
| 661 | + model_provider_name=ModelProviderName.openai, |
| 662 | + prompt_id="simple_prompt_builder", |
| 663 | + structured_output_mode=StructuredOutputMode.json_schema, |
| 664 | + ), |
| 665 | + ) |
| 666 | + adapter.input_schema = _INPUT_OBJECT_SCHEMA |
| 667 | + adapter._run = AsyncMock( |
| 668 | + return_value=( |
| 669 | + RunOutput(output="ok", intermediate_outputs=None, trace=None), |
| 670 | + None, |
| 671 | + ) |
| 672 | + ) |
| 673 | + |
| 674 | + with ( |
| 675 | + patch( |
| 676 | + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", |
| 677 | + return_value=MagicMock( |
| 678 | + parse_output=MagicMock( |
| 679 | + return_value=RunOutput( |
| 680 | + output="ok", intermediate_outputs=None, trace=None |
| 681 | + ) |
| 682 | + ) |
| 683 | + ), |
| 684 | + ), |
| 685 | + patch( |
| 686 | + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", |
| 687 | + ), |
| 688 | + patch.object( |
| 689 | + adapter, |
| 690 | + "model_provider", |
| 691 | + return_value=MagicMock( |
| 692 | + parser="default", |
| 693 | + formatter=None, |
| 694 | + reasoning_capable=False, |
| 695 | + ), |
| 696 | + ), |
| 697 | + ): |
| 698 | + await adapter.invoke({"x": 1}, prior_trace=prior_trace) |
| 699 | + |
| 700 | + adapter._run.assert_called_once() |
| 701 | + assert adapter._run.call_args[1].get("prior_trace") is None |
| 702 | + |
| 703 | + |
| 704 | +def test_prepare_stream_rejects_multiturn_with_structured_input(tmp_path): |
| 705 | + project = Project(name="proj", path=tmp_path / "proj.kiln") |
| 706 | + project.save_to_file() |
| 707 | + task = Task( |
| 708 | + name="t", |
| 709 | + instruction="i", |
| 710 | + parent=project, |
| 711 | + ) |
| 712 | + task.save_to_file() |
| 713 | + adapter = MockAdapter( |
| 714 | + task=task, |
| 715 | + run_config=KilnAgentRunConfigProperties( |
| 716 | + model_name="gpt_4o", |
| 717 | + model_provider_name=ModelProviderName.openai, |
| 718 | + prompt_id="simple_prompt_builder", |
| 719 | + structured_output_mode=StructuredOutputMode.json_schema, |
| 720 | + ), |
| 721 | + ) |
| 722 | + adapter.input_schema = _INPUT_OBJECT_SCHEMA |
| 723 | + prior_trace: list[ChatCompletionMessageParam] = [ |
| 724 | + {"role": "user", "content": "hi"}, |
| 725 | + ] |
| 726 | + |
| 727 | + with ( |
| 728 | + patch.object( |
| 729 | + adapter, |
| 730 | + "model_provider", |
| 731 | + return_value=MagicMock(formatter=None), |
| 732 | + ), |
| 733 | + pytest.raises(ValueError, match=_MULTITURN_STRUCTURED_ERROR), |
| 734 | + ): |
| 735 | + adapter._prepare_stream({"x": 1}, prior_trace=prior_trace) |
| 736 | + |
| 737 | + |
| 738 | +@pytest.mark.parametrize("prior_trace", [None, []]) |
| 739 | +def test_prepare_stream_validates_input_schema_when_single_turn( |
| 740 | + tmp_path, prior_trace: list[ChatCompletionMessageParam] | None |
| 741 | +): |
| 742 | + project = Project(name="proj", path=tmp_path / "proj.kiln") |
| 743 | + project.save_to_file() |
| 744 | + task = Task( |
| 745 | + name="t", |
| 746 | + instruction="i", |
| 747 | + parent=project, |
| 748 | + ) |
| 749 | + task.save_to_file() |
| 750 | + adapter = MockAdapter( |
| 751 | + task=task, |
| 752 | + run_config=KilnAgentRunConfigProperties( |
| 753 | + model_name="gpt_4o", |
| 754 | + model_provider_name=ModelProviderName.openai, |
| 755 | + prompt_id="simple_prompt_builder", |
| 756 | + structured_output_mode=StructuredOutputMode.json_schema, |
| 757 | + ), |
| 758 | + ) |
| 759 | + adapter.input_schema = _INPUT_OBJECT_SCHEMA |
| 760 | + invalid_input: dict = {} |
| 761 | + |
| 762 | + with ( |
| 763 | + patch.object( |
| 764 | + adapter, |
| 765 | + "model_provider", |
| 766 | + return_value=MagicMock(formatter=None), |
| 767 | + ), |
| 768 | + pytest.raises(ValueError, match="input schema"), |
| 769 | + ): |
| 770 | + adapter._prepare_stream(invalid_input, prior_trace=prior_trace) |
| 771 | + |
| 772 | + |
| 773 | +def test_build_chat_formatter_rejects_multiturn_with_structured_input(tmp_path): |
| 774 | + project = Project(name="proj", path=tmp_path / "proj.kiln") |
| 775 | + project.save_to_file() |
| 776 | + task = Task( |
| 777 | + name="t", |
| 778 | + instruction="i", |
| 779 | + parent=project, |
| 780 | + ) |
| 781 | + task.save_to_file() |
| 782 | + adapter = MockAdapter( |
| 783 | + task=task, |
| 784 | + run_config=KilnAgentRunConfigProperties( |
| 785 | + model_name="gpt_4o", |
| 786 | + model_provider_name=ModelProviderName.openai, |
| 787 | + prompt_id="simple_prompt_builder", |
| 788 | + structured_output_mode=StructuredOutputMode.json_schema, |
| 789 | + ), |
| 790 | + ) |
| 791 | + adapter.input_schema = _INPUT_OBJECT_SCHEMA |
| 792 | + prior_trace: list[ChatCompletionMessageParam] = [ |
| 793 | + {"role": "user", "content": "hi"}, |
| 794 | + ] |
| 795 | + |
| 796 | + with pytest.raises(ValueError, match=_MULTITURN_STRUCTURED_ERROR): |
| 797 | + adapter.build_chat_formatter("new input", prior_trace=prior_trace) |
| 798 | + |
| 799 | + |
552 | 800 | @pytest.mark.parametrize( |
553 | 801 | "initial_mode,expected_mode", |
554 | 802 | [ |
|
0 commit comments