|
10 | 10 |
|
11 | 11 | T = TypeVar("T") |
12 | 12 |
|
13 | | - |
14 | | -class ToolUseLimitExceeded(Exception): |
15 | | - """Raised when tool use limit is exceeded.""" |
16 | | - |
17 | | - pass |
18 | | - |
19 | | - |
20 | 13 | _progress_message: ContextVar[Any] = ContextVar("progress_message", default=None) |
21 | 14 | _tool_usage_counts: ContextVar[dict[str, int] | None] = ContextVar( |
22 | 15 | "tool_usage_counts", default=None |
@@ -95,30 +88,29 @@ async def wrapper(*args, **kwargs) -> T: |
95 | 88 | result = await result |
96 | 89 | return result |
97 | 90 |
|
| 91 | + # Get tool name for tracking |
| 92 | + tool_name = kwargs.get("name", "Unknown Tool") |
| 93 | + if not tool_name or tool_name == "Unknown Tool": |
| 94 | + if len(args) > 1: |
| 95 | + tool_name = args[1] |
| 96 | + |
| 97 | + # Always track and enforce tool usage limits |
| 98 | + counts = _tool_usage_counts.get() |
| 99 | + if counts is None: |
| 100 | + counts = defaultdict(int) |
| 101 | + _tool_usage_counts.set(counts) |
| 102 | + counts[tool_name] += 1 |
| 103 | + |
| 104 | + # Check if we've exceeded the limit |
| 105 | + total_calls = sum(counts.values()) |
| 106 | + if total_calls > max_tool_calls: |
| 107 | + # Return a message that tells the agent to stop using tools and continue |
| 108 | + # This will be returned as the tool result, which the agent will see |
| 109 | + return "Tool use limit reached. Please continue with the information you've gathered so far to answer the user's question." |
| 110 | + |
98 | 111 | _current_tool_token = None |
99 | 112 | if _progress := _progress_message.get(): |
100 | | - # The tool name is either in kwargs['name'] or args[1] |
101 | | - tool_name = kwargs.get("name", "Unknown Tool") |
102 | | - if not tool_name or tool_name == "Unknown Tool": |
103 | | - if len(args) > 1: |
104 | | - tool_name = args[1] |
105 | | - |
106 | | - # Update tool usage counts |
107 | | - counts = _tool_usage_counts.get() |
108 | | - if counts is None: |
109 | | - counts = defaultdict(int) |
110 | | - _tool_usage_counts.set(counts) |
111 | | - counts[tool_name] += 1 |
112 | | - |
113 | | - # Check if we've exceeded the limit |
114 | | - total_calls = sum(counts.values()) |
115 | | - if total_calls > max_tool_calls: |
116 | | - # Raise an exception to preserve type safety |
117 | | - raise ToolUseLimitExceeded( |
118 | | - "I've reached my tool use limit for this response. Please ask a follow-up question if you need more information." |
119 | | - ) |
120 | | - |
121 | | - # Set current tool |
| 113 | + # Set current tool for progress tracking |
122 | 114 | _current_tool_token = _current_tool.set(tool_name) |
123 | 115 |
|
124 | 116 | try: |
|
0 commit comments