Skip to content

Commit 2aea2be

Browse files
authored
test tool limiting (#1196)
* test tool limiting * address review comments
1 parent c52ff20 commit 2aea2be

3 files changed

Lines changed: 52 additions & 8 deletions

File tree

examples/slackbot/src/slackbot/api.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
post_slack_message,
3333
)
3434
from slackbot.strings import count_tokens, slice_tokens
35-
from slackbot.wrap import WatchToolCalls, _progress_message, _tool_usage_counts
35+
from slackbot.wrap import (
36+
ToolUseLimitExceeded,
37+
WatchToolCalls,
38+
_progress_message,
39+
_tool_usage_counts,
40+
)
3641

3742
BOT_MENTION = r"<@(\w+)>"
3843

@@ -69,7 +74,10 @@ async def run_agent(
6974
counts_token = _tool_usage_counts.set(defaultdict(int))
7075

7176
try:
72-
with WatchToolCalls(settings=decorator_settings):
77+
with WatchToolCalls(
78+
settings=decorator_settings,
79+
max_tool_calls=settings.max_tool_calls_per_turn,
80+
):
7381
result = await create_agent(model=settings.model_name).run(
7482
user_prompt=cleaned_message,
7583
message_history=conversation,
@@ -154,6 +162,19 @@ async def handle_message(payload: SlackPayload, db: Database):
154162
channel_id=event.channel,
155163
thread_ts=thread_ts,
156164
)
165+
except ToolUseLimitExceeded as e:
166+
logger.warning(f"Tool use limit exceeded: {e}")
167+
assert event.channel is not None, "No channel found"
168+
await task(post_slack_message)(
169+
message=str(e),
170+
channel_id=event.channel,
171+
thread_ts=thread_ts,
172+
)
173+
return Completed(
174+
message="Tool use limit exceeded",
175+
name="LIMIT_EXCEEDED",
176+
data=dict(user_context=user_context),
177+
)
157178
except Exception as e:
158179
logger.error(f"Error running agent: {e}")
159180
assert event.channel is not None, "No channel found"

examples/slackbot/src/slackbot/settings.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,24 @@ def validate_log_level(cls, v: str) -> str:
7878
description="Slack user ID to notify when discussions are created (e.g., U1234567890)",
7979
)
8080

81+
# Tool use limits
82+
max_tool_calls_per_turn: int = Field(
83+
default=50,
84+
description="Maximum number of tool calls allowed per agent turn to prevent runaway tool use",
85+
)
86+
8187
@model_validator(mode="after")
82-
def validate_temperature(self) -> "SlackbotSettings":
88+
def _apply_post_validation_defaults(self) -> "SlackbotSettings":
8389
if "gpt-5" in self.model_name:
8490
self.temperature = 1.0
85-
return self
86-
87-
@model_validator(mode="after")
88-
def set_turbopuffer_api_key(self) -> "SlackbotSettings":
8991
if not os.getenv("TURBOPUFFER_API_KEY"):
9092
try:
9193
api_key = Secret.load("tpuf-api-key", _sync=True).get() # type: ignore
9294
os.environ["TURBOPUFFER_API_KEY"] = api_key
9395
except Exception:
9496
pass # If secret doesn't exist, turbopuffer will handle the error
97+
if not self.admin_slack_user_id:
98+
self.admin_slack_user_id = Variable.get("admin-slack-id", _sync=True)
9599
return self
96100

97101
@property

examples/slackbot/src/slackbot/wrap.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010

1111
T = TypeVar("T")
1212

13+
14+
class ToolUseLimitExceeded(Exception):
15+
"""Raised when tool use limit is exceeded."""
16+
17+
pass
18+
19+
1320
_progress_message: ContextVar[Any] = ContextVar("progress_message", default=None)
1421
_tool_usage_counts: ContextVar[dict[str, int] | None] = ContextVar(
1522
"tool_usage_counts", default=None
@@ -64,6 +71,7 @@ def prefect_wrapped_function(
6471
decorator: Callable[..., Callable[..., T]] = task,
6572
tags: set[str] | None = None,
6673
settings: dict[str, Any] | None = None,
74+
max_tool_calls: int = 10, # Default limit per agent run (matches settings)
6775
) -> Callable[..., Callable[..., T]]:
6876
"""Decorator for wrapping a function with a prefect decorator."""
6977
tags = tags or set[str]()
@@ -102,6 +110,14 @@ async def wrapper(*args, **kwargs) -> T:
102110
_tool_usage_counts.set(counts)
103111
counts[tool_name] += 1
104112

113+
# Check if we've exceeded the limit
114+
total_calls = sum(counts.values())
115+
if total_calls > max_tool_calls:
116+
# Raise an exception to preserve type safety
117+
raise ToolUseLimitExceeded(
118+
"I've reached my tool use limit for this response. Please ask a follow-up question if you need more information."
119+
)
120+
105121
# Set current tool
106122
_current_tool_token = _current_tool.set(tool_name)
107123

@@ -161,11 +177,13 @@ def __init__(
161177
patch_method_name: str = "call_tool",
162178
tags: set[str] | None = None,
163179
settings: dict[str, Any] | None = None,
180+
max_tool_calls: int = 10,
164181
):
165182
"""Initialize the context manager.
166183
Args:
167184
tags: Prefect tags to apply to the flow.
168-
flow_kwargs: Keyword arguments to pass to the flow.
185+
settings: Settings to pass to the decorator.
186+
max_tool_calls: Maximum number of tool calls allowed per turn.
169187
"""
170188
# Import here to avoid circular imports
171189
from pydantic_ai.toolsets.abstract import AbstractToolset
@@ -176,4 +194,5 @@ def __init__(
176194
decorator=prefect_wrapped_function,
177195
tags=tags,
178196
settings=settings,
197+
max_tool_calls=max_tool_calls,
179198
)

0 commit comments

Comments
 (0)