|
18 | 18 | import asyncio |
19 | 19 | import logging |
20 | 20 | import time |
| 21 | +import weakref |
21 | 22 | from typing import Any, Dict, List, Optional, TYPE_CHECKING |
| 23 | +from .._lockmap import LockMap |
22 | 24 |
|
23 | 25 | if TYPE_CHECKING: |
24 | 26 | from praisonaiagents import Agent |
@@ -74,8 +76,8 @@ def __init__( |
74 | 76 | run_timeout: float = 300.0, # 5 minutes default timeout |
75 | 77 | ) -> None: |
76 | 78 | self._histories: Dict[str, List[Dict[str, Any]]] = {} |
77 | | - self._locks: Dict[str, asyncio.Lock] = {} |
78 | | - self._agent_locks: Dict[int, asyncio.Lock] = {} |
| 79 | + self._locks = LockMap() |
| 80 | + self._agent_locks: "weakref.WeakKeyDictionary[Any, asyncio.Lock]" = weakref.WeakKeyDictionary() |
79 | 81 | self._max_history = max_history |
80 | 82 | self._store = store |
81 | 83 | self._platform = platform |
@@ -131,16 +133,15 @@ def _session_key(self, user_id: str) -> str: |
131 | 133 | def _get_lock(self, user_id: str) -> asyncio.Lock: |
132 | 134 | """Get or create an asyncio.Lock for *user_id* (storage-keyed).""" |
133 | 135 | key = self._storage_key(user_id) |
134 | | - if key not in self._locks: |
135 | | - self._locks[key] = asyncio.Lock() |
136 | | - return self._locks[key] |
| 136 | + return self._locks.get(key) |
137 | 137 |
|
138 | 138 | def _get_agent_lock(self, agent: "Agent") -> asyncio.Lock: |
139 | | - """Get or create a lock for the *agent* instance (by id).""" |
140 | | - agent_id = id(agent) |
141 | | - if agent_id not in self._agent_locks: |
142 | | - self._agent_locks[agent_id] = asyncio.Lock() |
143 | | - return self._agent_locks[agent_id] |
| 139 | + """Get or create a lock for the *agent* instance (using WeakKeyDictionary).""" |
| 140 | + lock = self._agent_locks.get(agent) |
| 141 | + if lock is None: |
| 142 | + lock = asyncio.Lock() |
| 143 | + self._agent_locks[agent] = lock |
| 144 | + return lock |
144 | 145 |
|
145 | 146 | def _load_history(self, user_id: str) -> List[Dict[str, Any]]: |
146 | 147 | """Load user history from store (if available) or in-memory cache.""" |
@@ -537,7 +538,7 @@ def reap_stale(self, max_age_seconds: int) -> int: |
537 | 538 | for storage_key in stale: |
538 | 539 | self._histories.pop(storage_key, None) |
539 | 540 | self._last_active.pop(storage_key, None) |
540 | | - self._locks.pop(storage_key, None) |
| 541 | + self._locks.drop(storage_key) |
541 | 542 | if self._store is not None: |
542 | 543 | key = self._persist_key(storage_key) |
543 | 544 | try: |
@@ -571,7 +572,7 @@ def reset(self, user_id: str) -> bool: |
571 | 572 | existed = storage_key in self._histories |
572 | 573 | self._histories.pop(storage_key, None) |
573 | 574 | self._last_active.pop(storage_key, None) |
574 | | - self._locks.pop(storage_key, None) |
| 575 | + self._locks.drop(storage_key) |
575 | 576 |
|
576 | 577 | if self._store is not None: |
577 | 578 | persist_key = self._session_key(user_id) |
|
0 commit comments