Skip to content

Commit 5732445

Browse files
authored
fix tool limiting (#1198)
1 parent 4ae65fd commit 5732445

3 files changed

Lines changed: 23 additions & 49 deletions

File tree

examples/slackbot/src/slackbot/api.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,7 @@
3232
post_slack_message,
3333
)
3434
from slackbot.strings import count_tokens, slice_tokens
35-
from slackbot.wrap import (
36-
ToolUseLimitExceeded,
37-
WatchToolCalls,
38-
_progress_message,
39-
_tool_usage_counts,
40-
)
35+
from slackbot.wrap import WatchToolCalls, _progress_message, _tool_usage_counts
4136

4237
BOT_MENTION = r"<@(\w+)>"
4338

@@ -162,19 +157,6 @@ async def handle_message(payload: SlackPayload, db: Database):
162157
channel_id=event.channel,
163158
thread_ts=thread_ts,
164159
)
165-
except ToolUseLimitExceeded as e:
166-
logger.warning(f"Tool use limit exceeded: {e}")
167-
assert event.channel is not None, "No channel found"
168-
await task(post_slack_message)(
169-
message=str(e),
170-
channel_id=event.channel,
171-
thread_ts=thread_ts,
172-
)
173-
return Completed(
174-
message="Tool use limit exceeded",
175-
name="LIMIT_EXCEEDED",
176-
data=dict(user_context=user_context),
177-
)
178160
except Exception as e:
179161
logger.error(f"Error running agent: {e}")
180162
assert event.channel is not None, "No channel found"

examples/slackbot/src/slackbot/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def validate_log_level(cls, v: str) -> str:
8080

8181
# Tool use limits
8282
max_tool_calls_per_turn: int = Field(
83-
default=50,
83+
default=5,
8484
description="Maximum number of tool calls allowed per agent turn to prevent runaway tool use",
8585
)
8686

examples/slackbot/src/slackbot/wrap.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,6 @@
1010

1111
T = TypeVar("T")
1212

13-
14-
class ToolUseLimitExceeded(Exception):
15-
"""Raised when tool use limit is exceeded."""
16-
17-
pass
18-
19-
2013
_progress_message: ContextVar[Any] = ContextVar("progress_message", default=None)
2114
_tool_usage_counts: ContextVar[dict[str, int] | None] = ContextVar(
2215
"tool_usage_counts", default=None
@@ -95,30 +88,29 @@ async def wrapper(*args, **kwargs) -> T:
9588
result = await result
9689
return result
9790

91+
# Get tool name for tracking
92+
tool_name = kwargs.get("name", "Unknown Tool")
93+
if not tool_name or tool_name == "Unknown Tool":
94+
if len(args) > 1:
95+
tool_name = args[1]
96+
97+
# Always track and enforce tool usage limits
98+
counts = _tool_usage_counts.get()
99+
if counts is None:
100+
counts = defaultdict(int)
101+
_tool_usage_counts.set(counts)
102+
counts[tool_name] += 1
103+
104+
# Check if we've exceeded the limit
105+
total_calls = sum(counts.values())
106+
if total_calls > max_tool_calls:
107+
# Return a message that tells the agent to stop using tools and continue
108+
# This will be returned as the tool result, which the agent will see
109+
return "Tool use limit reached. Please continue with the information you've gathered so far to answer the user's question."
110+
98111
_current_tool_token = None
99112
if _progress := _progress_message.get():
100-
# The tool name is either in kwargs['name'] or args[1]
101-
tool_name = kwargs.get("name", "Unknown Tool")
102-
if not tool_name or tool_name == "Unknown Tool":
103-
if len(args) > 1:
104-
tool_name = args[1]
105-
106-
# Update tool usage counts
107-
counts = _tool_usage_counts.get()
108-
if counts is None:
109-
counts = defaultdict(int)
110-
_tool_usage_counts.set(counts)
111-
counts[tool_name] += 1
112-
113-
# Check if we've exceeded the limit
114-
total_calls = sum(counts.values())
115-
if total_calls > max_tool_calls:
116-
# Raise an exception to preserve type safety
117-
raise ToolUseLimitExceeded(
118-
"I've reached my tool use limit for this response. Please ask a follow-up question if you need more information."
119-
)
120-
121-
# Set current tool
113+
# Set current tool for progress tracking
122114
_current_tool_token = _current_tool.set(tool_name)
123115

124116
try:

0 commit comments

Comments
 (0)