Skip to content

Commit 4034290

Browse files
Merge pull request #293 from drp8226/drp/agent-tooling
Add agent tooling, policies, toolkits, and persistence
2 parents c23f3ae + fafefd0 commit 4034290

31 files changed

Lines changed: 4754 additions & 78 deletions

aisuite/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,71 @@
11
from .client import Client
22
from .agents import (
33
Agent,
4+
Artifact,
5+
ArtifactRef,
6+
ArtifactStore,
7+
AllowAllToolPolicy,
8+
AllowToolsPolicy,
9+
CompactionRecord,
10+
DenyAllToolPolicy,
11+
FileArtifactStore,
12+
FileStateStore,
13+
InMemoryArtifactStore,
14+
InMemoryStateStore,
15+
PostgresStateStore,
16+
RequireApprovalPolicy,
417
Runner,
518
RunResult,
619
RunState,
720
RunStep,
21+
StateConflictError,
22+
StateNotFoundError,
23+
StateStore,
24+
StoredRunState,
25+
ThreadAlreadyExistsError,
26+
ToolMetadata,
827
ToolPolicyContext,
928
ToolPolicyDecision,
29+
agent_tool,
30+
tool,
1031
)
1132
from .framework.message import Message
1233
from . import tracing
34+
from . import toolkits
1335
from .utils.tools import Tools
1436

1537
__all__ = [
1638
"Agent",
39+
"Artifact",
40+
"ArtifactRef",
41+
"ArtifactStore",
42+
"AllowAllToolPolicy",
43+
"AllowToolsPolicy",
1744
"Client",
45+
"CompactionRecord",
46+
"DenyAllToolPolicy",
47+
"FileArtifactStore",
48+
"FileStateStore",
49+
"InMemoryArtifactStore",
50+
"InMemoryStateStore",
1851
"Message",
52+
"PostgresStateStore",
53+
"RequireApprovalPolicy",
1954
"RunResult",
2055
"RunState",
2156
"RunStep",
2257
"Runner",
58+
"StateConflictError",
59+
"StateNotFoundError",
60+
"StateStore",
61+
"StoredRunState",
62+
"ThreadAlreadyExistsError",
63+
"ToolMetadata",
2364
"ToolPolicyContext",
2465
"ToolPolicyDecision",
2566
"Tools",
67+
"agent_tool",
68+
"tool",
69+
"toolkits",
2670
"tracing",
2771
]

aisuite/agents/__init__.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,64 @@
1-
from .runner import Runner
1+
from .artifact_store import (
2+
Artifact,
3+
ArtifactRef,
4+
ArtifactStore,
5+
FileArtifactStore,
6+
InMemoryArtifactStore,
7+
)
8+
from .policies import (
9+
AllowAllToolPolicy,
10+
AllowToolsPolicy,
11+
DenyAllToolPolicy,
12+
RequireApprovalPolicy,
13+
tool,
14+
)
15+
from .postgres_state_store import CompactionRecord, PostgresStateStore
16+
from .runner import Runner, StateNotFoundError, ThreadAlreadyExistsError
17+
from .state_store import (
18+
FileStateStore,
19+
InMemoryStateStore,
20+
StateConflictError,
21+
StateStore,
22+
StoredRunState,
23+
)
24+
from .tools import agent_tool
225
from .types import (
326
Agent,
427
RunResult,
528
RunState,
629
RunStep,
30+
ToolMetadata,
731
ToolPolicyContext,
832
ToolPolicyDecision,
933
)
1034

1135
__all__ = [
1236
"Agent",
37+
"Artifact",
38+
"ArtifactRef",
39+
"ArtifactStore",
40+
"AllowAllToolPolicy",
41+
"AllowToolsPolicy",
42+
"CompactionRecord",
43+
"DenyAllToolPolicy",
44+
"RequireApprovalPolicy",
45+
"PostgresStateStore",
46+
"FileArtifactStore",
47+
"FileStateStore",
48+
"InMemoryArtifactStore",
49+
"InMemoryStateStore",
1350
"Runner",
1451
"RunResult",
1552
"RunState",
1653
"RunStep",
54+
"StateConflictError",
55+
"StateNotFoundError",
56+
"StateStore",
57+
"StoredRunState",
58+
"ThreadAlreadyExistsError",
59+
"ToolMetadata",
1760
"ToolPolicyContext",
1861
"ToolPolicyDecision",
62+
"agent_tool",
63+
"tool",
1964
]

aisuite/agents/artifact_store.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
from __future__ import annotations
2+
3+
import copy
4+
import hashlib
5+
import json
6+
from dataclasses import dataclass, field
7+
from pathlib import Path
8+
from typing import Any, Optional, Protocol
9+
10+
from .types import ensure_json_serializable
11+
from .utils import new_id, now
12+
13+
14+
@dataclass(kw_only=True)
15+
class ArtifactRef:
16+
artifact_id: str
17+
uri: str
18+
media_type: str
19+
size_bytes: int
20+
metadata: dict[str, Any] = field(default_factory=dict)
21+
22+
def to_dict(self) -> dict[str, Any]:
23+
return ensure_json_serializable(
24+
{
25+
"artifact_id": self.artifact_id,
26+
"uri": self.uri,
27+
"media_type": self.media_type,
28+
"size_bytes": self.size_bytes,
29+
"metadata": copy.deepcopy(self.metadata),
30+
}
31+
)
32+
33+
@classmethod
34+
def from_dict(cls, data: dict[str, Any]) -> "ArtifactRef":
35+
return cls(
36+
artifact_id=data["artifact_id"],
37+
uri=data["uri"],
38+
media_type=data["media_type"],
39+
size_bytes=data["size_bytes"],
40+
metadata=copy.deepcopy(data.get("metadata", {})),
41+
)
42+
43+
44+
@dataclass(kw_only=True)
45+
class Artifact:
46+
ref: ArtifactRef
47+
data: bytes
48+
created_at: str = ""
49+
50+
def text(self, encoding: str = "utf-8") -> str:
51+
return self.data.decode(encoding)
52+
53+
54+
class ArtifactStore(Protocol):
55+
def put(
56+
self,
57+
data: bytes | str,
58+
*,
59+
media_type: str,
60+
metadata: Optional[dict[str, Any]] = None,
61+
) -> ArtifactRef:
62+
...
63+
64+
def get(self, ref: ArtifactRef | str) -> Artifact:
65+
...
66+
67+
def delete(self, ref: ArtifactRef | str) -> None:
68+
...
69+
70+
71+
class InMemoryArtifactStore:
72+
def __init__(self):
73+
self._artifacts: dict[str, Artifact] = {}
74+
75+
def put(
76+
self,
77+
data: bytes | str,
78+
*,
79+
media_type: str,
80+
metadata: Optional[dict[str, Any]] = None,
81+
) -> ArtifactRef:
82+
payload = _to_bytes(data)
83+
artifact_id = new_id("artifact")
84+
ref = ArtifactRef(
85+
artifact_id=artifact_id,
86+
uri=f"memory://{artifact_id}",
87+
media_type=media_type,
88+
size_bytes=len(payload),
89+
metadata=_artifact_metadata(payload, metadata),
90+
)
91+
self._artifacts[artifact_id] = Artifact(
92+
ref=ArtifactRef.from_dict(ref.to_dict()),
93+
data=payload,
94+
created_at=now(),
95+
)
96+
return ref
97+
98+
def get(self, ref: ArtifactRef | str) -> Artifact:
99+
artifact_id = _artifact_id(ref)
100+
try:
101+
artifact = self._artifacts[artifact_id]
102+
except KeyError as exc:
103+
raise KeyError(f"Artifact {artifact_id!r} not found.") from exc
104+
return Artifact(
105+
ref=ArtifactRef.from_dict(artifact.ref.to_dict()),
106+
data=bytes(artifact.data),
107+
created_at=artifact.created_at,
108+
)
109+
110+
def delete(self, ref: ArtifactRef | str) -> None:
111+
self._artifacts.pop(_artifact_id(ref), None)
112+
113+
114+
class FileArtifactStore:
115+
def __init__(self, root: str | Path = ".aisuite/artifacts"):
116+
self.root = Path(root)
117+
118+
def put(
119+
self,
120+
data: bytes | str,
121+
*,
122+
media_type: str,
123+
metadata: Optional[dict[str, Any]] = None,
124+
) -> ArtifactRef:
125+
payload = _to_bytes(data)
126+
artifact_id = new_id("artifact")
127+
artifact_dir = self.root / artifact_id
128+
artifact_dir.mkdir(parents=True, exist_ok=False)
129+
data_path = artifact_dir / "data"
130+
meta_path = artifact_dir / "metadata.json"
131+
ref = ArtifactRef(
132+
artifact_id=artifact_id,
133+
uri=f"artifact://{artifact_id}",
134+
media_type=media_type,
135+
size_bytes=len(payload),
136+
metadata=_artifact_metadata(payload, metadata),
137+
)
138+
data_path.write_bytes(payload)
139+
meta_path.write_text(
140+
json.dumps(
141+
{
142+
"ref": ref.to_dict(),
143+
"created_at": now(),
144+
},
145+
sort_keys=True,
146+
)
147+
+ "\n",
148+
encoding="utf-8",
149+
)
150+
return ref
151+
152+
def get(self, ref: ArtifactRef | str) -> Artifact:
153+
artifact_id = _artifact_id(ref)
154+
artifact_dir = self.root / artifact_id
155+
data_path = artifact_dir / "data"
156+
meta_path = artifact_dir / "metadata.json"
157+
if not data_path.exists() or not meta_path.exists():
158+
raise KeyError(f"Artifact {artifact_id!r} not found.")
159+
metadata = json.loads(meta_path.read_text(encoding="utf-8"))
160+
return Artifact(
161+
ref=ArtifactRef.from_dict(metadata["ref"]),
162+
data=data_path.read_bytes(),
163+
created_at=metadata.get("created_at", ""),
164+
)
165+
166+
def delete(self, ref: ArtifactRef | str) -> None:
167+
artifact_id = _artifact_id(ref)
168+
artifact_dir = self.root / artifact_id
169+
try:
170+
(artifact_dir / "data").unlink()
171+
except FileNotFoundError:
172+
pass
173+
try:
174+
(artifact_dir / "metadata.json").unlink()
175+
except FileNotFoundError:
176+
pass
177+
try:
178+
artifact_dir.rmdir()
179+
except FileNotFoundError:
180+
pass
181+
except OSError:
182+
pass
183+
184+
185+
def _artifact_id(ref: ArtifactRef | str) -> str:
186+
if isinstance(ref, ArtifactRef):
187+
return ref.artifact_id
188+
if ref.startswith("artifact://"):
189+
return ref.removeprefix("artifact://")
190+
if ref.startswith("memory://"):
191+
return ref.removeprefix("memory://")
192+
return ref
193+
194+
195+
def _to_bytes(data: bytes | str) -> bytes:
196+
if isinstance(data, bytes):
197+
return data
198+
return data.encode("utf-8")
199+
200+
201+
def _artifact_metadata(
202+
payload: bytes,
203+
metadata: Optional[dict[str, Any]],
204+
) -> dict[str, Any]:
205+
return ensure_json_serializable(
206+
{
207+
**copy.deepcopy(metadata or {}),
208+
"sha256": hashlib.sha256(payload).hexdigest(),
209+
}
210+
)

0 commit comments

Comments
 (0)