Skip to content

Commit 0188d23

Browse files
Merge pull request #291 from drp8226/drp/agent-api
Adding support for agent abstraction and associated tooling
2 parents 695242a + 9eb2298 commit 0188d23

19 files changed

Lines changed: 1496 additions & 68 deletions

aisuite/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
11
from .client import Client
2+
from .agents import (
3+
Agent,
4+
Runner,
5+
RunResult,
6+
RunState,
7+
RunStep,
8+
ToolPolicyContext,
9+
ToolPolicyDecision,
10+
)
211
from .framework.message import Message
312
from .utils.tools import Tools

aisuite/agents/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from .runner import Runner
2+
from .types import (
3+
Agent,
4+
RunResult,
5+
RunState,
6+
RunStep,
7+
ToolPolicyContext,
8+
ToolPolicyDecision,
9+
)
10+
11+
__all__ = [
12+
"Agent",
13+
"Runner",
14+
"RunResult",
15+
"RunState",
16+
"RunStep",
17+
"ToolPolicyContext",
18+
"ToolPolicyDecision",
19+
]

aisuite/agents/runner.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
from __future__ import annotations
2+
3+
import copy
4+
from typing import Any, Callable, Optional
5+
6+
from ..client import Client
7+
from .types import Agent, RunResult, RunState, RunStatus, RunStep, ToolPolicy
8+
from .utils import (
9+
build_input_messages,
10+
extract_final_message,
11+
extract_final_output,
12+
extract_response_messages,
13+
merge_tags,
14+
new_id,
15+
now,
16+
)
17+
18+
19+
class Runner:
20+
@staticmethod
21+
async def run(
22+
agent: Agent,
23+
input: str | list[dict[str, Any]] | RunState,
24+
*,
25+
client: Optional[Client] = None,
26+
max_turns: int = 5,
27+
run_name: Optional[str] = None,
28+
group_id: Optional[str] = None,
29+
tags: Optional[list[str]] = None,
30+
metadata: Optional[dict[str, Any]] = None,
31+
tool_policy: Optional[ToolPolicy | Callable] = None,
32+
tracing_disabled: bool = False,
33+
**kwargs: Any,
34+
) -> RunResult:
35+
return Runner.run_sync(
36+
agent,
37+
input,
38+
client=client,
39+
max_turns=max_turns,
40+
run_name=run_name,
41+
group_id=group_id,
42+
tags=tags,
43+
metadata=metadata,
44+
tool_policy=tool_policy,
45+
tracing_disabled=tracing_disabled,
46+
**kwargs,
47+
)
48+
49+
@staticmethod
50+
def run_sync(
51+
agent: Agent,
52+
input: str | list[dict[str, Any]] | RunState,
53+
*,
54+
client: Optional[Client] = None,
55+
max_turns: int = 5,
56+
run_name: Optional[str] = None,
57+
group_id: Optional[str] = None,
58+
tags: Optional[list[str]] = None,
59+
metadata: Optional[dict[str, Any]] = None,
60+
tool_policy: Optional[ToolPolicy | Callable] = None,
61+
tracing_disabled: bool = False,
62+
**kwargs: Any,
63+
) -> RunResult:
64+
active_client = client or Client()
65+
trace_id = None if tracing_disabled else new_id("trace")
66+
if isinstance(input, RunState):
67+
messages = copy.deepcopy(input.messages)
68+
effective_run_name = run_name if run_name is not None else input.run_name
69+
effective_group_id = group_id if group_id is not None else input.group_id
70+
effective_tags = merge_tags(agent.tags, input.tags, tags)
71+
effective_metadata = {
72+
**agent.metadata,
73+
**input.metadata,
74+
**(metadata or {}),
75+
}
76+
effective_max_turns = max_turns if max_turns != 5 else input.max_turns
77+
prior_steps = copy.deepcopy(input.steps)
78+
else:
79+
messages = Runner._build_messages(agent, input)
80+
effective_run_name = run_name
81+
effective_group_id = group_id
82+
effective_tags = merge_tags(agent.tags, tags)
83+
effective_metadata = {**agent.metadata, **(metadata or {})}
84+
effective_max_turns = max_turns
85+
prior_steps = []
86+
87+
request_kwargs = {**agent.model_settings, **kwargs}
88+
if agent.tools:
89+
request_kwargs["tools"] = agent.tools
90+
request_kwargs["max_turns"] = effective_max_turns
91+
if tool_policy is not None:
92+
request_kwargs["tool_policy"] = tool_policy
93+
request_kwargs["tool_policy_context"] = {
94+
"agent_name": agent.name,
95+
"run_name": effective_run_name,
96+
"trace_id": trace_id,
97+
"group_id": effective_group_id,
98+
"tags": effective_tags,
99+
"metadata": effective_metadata,
100+
"messages": copy.deepcopy(messages),
101+
}
102+
103+
agent_step = RunStep(
104+
id=new_id("step"),
105+
type="agent",
106+
name=agent.name,
107+
trace_id=trace_id or "",
108+
started_at=now(),
109+
data={
110+
"agent_name": agent.name,
111+
"model": agent.model,
112+
"run_name": effective_run_name,
113+
},
114+
)
115+
116+
try:
117+
response = active_client.chat.completions.create(
118+
model=agent.model,
119+
messages=copy.deepcopy(messages),
120+
**request_kwargs,
121+
)
122+
status: RunStatus = "completed"
123+
except Exception:
124+
agent_step.ended_at = now()
125+
raise
126+
127+
agent_step.ended_at = now()
128+
all_messages = extract_response_messages(response, messages)
129+
raw_responses = [
130+
*getattr(response, "intermediate_responses", []),
131+
response,
132+
]
133+
steps = [
134+
*prior_steps,
135+
agent_step,
136+
*Runner._build_response_steps(raw_responses, trace_id or ""),
137+
*Runner._build_tool_steps(response, trace_id or ""),
138+
]
139+
140+
return RunResult(
141+
final_output=extract_final_output(response),
142+
status=status,
143+
agent=agent,
144+
last_agent=agent,
145+
input=input,
146+
messages=all_messages,
147+
new_items=all_messages[len(messages) :],
148+
raw_responses=raw_responses,
149+
run_name=effective_run_name,
150+
trace_id=trace_id or "",
151+
group_id=effective_group_id,
152+
tags=effective_tags,
153+
metadata=effective_metadata,
154+
steps=steps,
155+
max_turns=effective_max_turns,
156+
_client=active_client,
157+
)
158+
159+
@staticmethod
160+
async def continue_run(
161+
result: RunResult,
162+
input: str | list[dict[str, Any]],
163+
**overrides: Any,
164+
) -> RunResult:
165+
return Runner.continue_sync(result, input, **overrides)
166+
167+
@staticmethod
168+
def continue_sync(
169+
result: RunResult,
170+
input: str | list[dict[str, Any]],
171+
**overrides: Any,
172+
) -> RunResult:
173+
state = result.to_state()
174+
state.add_user_message(input)
175+
return Runner.run_sync(
176+
result.last_agent,
177+
state,
178+
client=overrides.pop("client", result._client),
179+
**overrides,
180+
)
181+
182+
@staticmethod
183+
def _build_messages(
184+
agent: Agent, input: str | list[dict[str, Any]]
185+
) -> list[dict[str, Any]]:
186+
messages = build_input_messages(input)
187+
if not agent.instructions:
188+
return messages
189+
if messages and messages[0].get("role") == "system":
190+
return messages
191+
return [{"role": "system", "content": agent.instructions}, *messages]
192+
193+
@staticmethod
194+
def _build_response_steps(raw_responses: list[Any], trace_id: str) -> list[RunStep]:
195+
steps = []
196+
for response in raw_responses:
197+
message = extract_final_message(response)
198+
data = {
199+
"has_message": message is not None,
200+
"finish_reason": getattr(
201+
getattr(response, "choices", [None])[0], "finish_reason", None
202+
)
203+
if getattr(response, "choices", None)
204+
else None,
205+
}
206+
ended_at = now()
207+
steps.append(
208+
RunStep(
209+
id=new_id("step"),
210+
type="model_response",
211+
name="model_response",
212+
trace_id=trace_id,
213+
started_at=ended_at,
214+
ended_at=ended_at,
215+
data=data,
216+
)
217+
)
218+
return steps
219+
220+
@staticmethod
221+
def _build_tool_steps(response: Any, trace_id: str) -> list[RunStep]:
222+
events = getattr(response, "tool_events", [])
223+
steps = []
224+
for event in events:
225+
ended_at = now()
226+
step_type = (
227+
"tool_result" if event.get("type") == "tool_result" else "tool_call"
228+
)
229+
steps.append(
230+
RunStep(
231+
id=new_id("step"),
232+
type=step_type,
233+
name=event.get("tool_name"),
234+
trace_id=trace_id,
235+
started_at=ended_at,
236+
ended_at=ended_at,
237+
data=copy.deepcopy(event),
238+
)
239+
)
240+
return steps

0 commit comments

Comments
 (0)