Skip to content

Commit a6dc6ba

Browse files
authored
Merge pull request #1141 from Kiln-AI/leonard/kil-466-refactor-adapter-should-expose-toolcalls
refactor: client control toolcalls
2 parents e022209 + 4d94731 commit a6dc6ba

15 files changed

+1650
-98
lines changed

libs/core/kiln_ai/adapters/chat/chat_formatter.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ class MultiturnFormatter(ChatFormatter):
256256
Takes prior_trace (existing conversation) and appends the new user message.
257257
Produces a single turn: the new user message. Tool calls and multi-turn
258258
model responses are handled by _run_model_turn's internal loop.
259+
260+
When user_input is a dict or list with tool_call_id keys, the input is
261+
treated as tool call results (role "tool") rather than a user message.
262+
This supports resuming after a return_on_tool_call interrupt.
259263
"""
260264

261265
def __init__(
@@ -274,14 +278,44 @@ def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
274278
"""Messages to seed the conversation (prior trace)."""
275279
return list(self._prior_trace)
276280

281+
@property
282+
def _is_tool_result(self) -> bool:
283+
"""Return True if user_input looks like one or more tool call results."""
284+
input = self.user_input
285+
if isinstance(input, dict):
286+
return "tool_call_id" in input
287+
if isinstance(input, list):
288+
return bool(input) and all(
289+
isinstance(item, dict) and "tool_call_id" in item for item in input
290+
)
291+
return False
292+
277293
def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
278294
if self._state == "start":
279-
# prior trace is already in the messages list and contains system and so on, we only need
280-
# to append the latest new user message
281-
user_msg = BasicChatMessage("user", format_user_message(self.user_input))
282295
self._state = "awaiting_final"
283-
self._messages.append(user_msg)
284-
return ChatTurn(messages=[user_msg], final_call=True)
296+
if self._is_tool_result:
297+
if isinstance(self.user_input, dict):
298+
raw_items: list[dict] = [self.user_input]
299+
else:
300+
raw_items = list(self.user_input) # type: ignore[arg-type]
301+
msgs: list[ChatMessage] = [
302+
ToolResponseMessage(
303+
role="tool",
304+
content=str(item.get("content", "")),
305+
tool_call_id=item["tool_call_id"],
306+
)
307+
for item in raw_items
308+
]
309+
self._messages.extend(msgs)
310+
return ChatTurn(messages=msgs, final_call=True)
311+
else:
312+
# prior trace is already in the messages list and contains system and so on, we only need
313+
# to append the latest new user message
314+
user_msg = BasicChatMessage(
315+
"user", format_user_message(self.user_input)
316+
)
317+
self._messages.append(user_msg)
318+
return ChatTurn(messages=[user_msg], final_call=True)
285319

286320
if self._state == "awaiting_final":
287321
if previous_output is None:

libs/core/kiln_ai/adapters/chat/test_chat_formatter.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,62 @@ def test_multiturn_formatter_preserves_tool_call_messages():
190190
assert first.final_call
191191

192192

193+
def test_multiturn_formatter_single_tool_result():
194+
"""Tool result dict with tool_call_id should produce a ToolResponseMessage."""
195+
prior_trace = [
196+
{"role": "assistant", "content": None, "tool_calls": [{"id": "call_1"}]},
197+
]
198+
formatter = MultiturnFormatter(
199+
prior_trace=prior_trace,
200+
user_input={"tool_call_id": "call_1", "content": "42"},
201+
)
202+
203+
first = formatter.next_turn()
204+
assert first is not None
205+
assert len(first.messages) == 1
206+
msg = first.messages[0]
207+
assert msg.role == "tool"
208+
assert msg.content == "42"
209+
assert msg.tool_call_id == "call_1"
210+
assert first.final_call
211+
212+
213+
def test_multiturn_formatter_multiple_tool_results():
214+
"""List of tool result dicts should produce multiple ToolResponseMessages."""
215+
prior_trace = [
216+
{"role": "assistant", "content": None, "tool_calls": []},
217+
]
218+
tool_results = [
219+
{"tool_call_id": "call_1", "content": "15"},
220+
{"tool_call_id": "call_2", "content": "36"},
221+
]
222+
formatter = MultiturnFormatter(prior_trace=prior_trace, user_input=tool_results)
223+
224+
first = formatter.next_turn()
225+
assert first is not None
226+
assert len(first.messages) == 2
227+
assert first.messages[0].role == "tool"
228+
assert first.messages[0].tool_call_id == "call_1"
229+
assert first.messages[0].content == "15"
230+
assert first.messages[1].role == "tool"
231+
assert first.messages[1].tool_call_id == "call_2"
232+
assert first.messages[1].content == "36"
233+
assert first.final_call
234+
235+
236+
def test_multiturn_formatter_user_input_not_confused_with_tool_result():
237+
"""A regular dict input (no tool_call_id) is treated as a user message."""
238+
prior_trace = [{"role": "system", "content": "sys"}]
239+
formatter = MultiturnFormatter(
240+
prior_trace=prior_trace,
241+
user_input={"question": "what is 2+2?"},
242+
)
243+
first = formatter.next_turn()
244+
assert first is not None
245+
assert len(first.messages) == 1
246+
assert first.messages[0].role == "user"
247+
248+
193249
def test_format_user_message():
194250
# String
195251
assert format_user_message("test input") == "test input"

libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414

1515
from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM
16-
from kiln_ai.adapters.chat.chat_formatter import ChatFormatter
16+
from kiln_ai.adapters.chat.chat_formatter import ChatFormatter, ToolResponseMessage
1717
from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion
1818
from kiln_ai.adapters.ml_model_list import KilnModelProvider
1919
from kiln_ai.adapters.model_adapters.stream_events import (
@@ -101,23 +101,30 @@ async def __aiter__(self) -> AsyncIterator[AdapterStreamEvent]:
101101
for message in turn.messages:
102102
if message.content is None:
103103
raise ValueError("Empty message content isn't allowed")
104-
self._messages.append(
105-
{"role": message.role, "content": message.content} # type: ignore[arg-type]
106-
)
104+
msg_dict: dict = {"role": message.role, "content": message.content}
105+
if isinstance(message, ToolResponseMessage):
106+
msg_dict["tool_call_id"] = message.tool_call_id
107+
self._messages.append(msg_dict) # type: ignore[arg-type]
107108

108109
skip_response_format = not turn.final_call
109110
turn_top_logprobs = self._top_logprobs if turn.final_call else None
110111

112+
interrupted = False
111113
async for event in self._stream_model_turn(
112114
skip_response_format, turn_top_logprobs
113115
):
114116
if isinstance(event, _ModelTurnComplete):
115117
usage += event.usage
116118
prior_output = event.assistant_message
117119
final_choice = event.model_choice
120+
if event.interrupted_by_tool_calls:
121+
interrupted = True
118122
else:
119123
yield event
120124

125+
if interrupted:
126+
break
127+
121128
if not prior_output:
122129
raise RuntimeError("No assistant message/output returned from model")
123130

@@ -176,6 +183,39 @@ async def _stream_model_turn(
176183
self._messages.append(response_choice.message)
177184

178185
if tool_calls and len(tool_calls) > 0:
186+
# Check for return_on_tool_call BEFORE processing
187+
if self._adapter.base_adapter_config.return_on_tool_call:
188+
real_tool_calls = [
189+
tc for tc in tool_calls if tc.function.name != "task_response"
190+
]
191+
if real_tool_calls:
192+
# Yield INPUT_AVAILABLE events for each tool call
193+
for tc in real_tool_calls:
194+
try:
195+
parsed_args = json.loads(tc.function.arguments)
196+
except (json.JSONDecodeError, TypeError):
197+
parsed_args = None
198+
yield ToolCallEvent(
199+
event_type=ToolCallEventType.INPUT_AVAILABLE,
200+
tool_call_id=tc.id,
201+
tool_name=tc.function.name or "unknown",
202+
arguments=parsed_args,
203+
error=(
204+
f"Failed to parse arguments: {tc.function.arguments}"
205+
if parsed_args is None
206+
else None
207+
),
208+
)
209+
210+
yield _ModelTurnComplete(
211+
assistant_message="",
212+
model_choice=response_choice,
213+
usage=usage,
214+
interrupted_by_tool_calls=True,
215+
)
216+
return
217+
218+
# Existing flow: handle tool calls internally
179219
async for event in self._handle_tool_calls(tool_calls):
180220
yield event
181221

@@ -265,6 +305,7 @@ class _ModelTurnComplete:
265305
assistant_message: str
266306
model_choice: Choices | None
267307
usage: Usage
308+
interrupted_by_tool_calls: bool = False
268309

269310

270311
def _validate_response(

0 commit comments

Comments
 (0)