Skip to content

Commit df8fe41

Browse files
fix: critical session cache, async tools, and pairing security gaps
- UnifiedSessionStore: invalidate in-memory cache when on-disk mtime changes to prevent cross-process message loss between TUI and interactive CLI - _arun_framework: use ToolResolver.resolve(..., instantiate=True) to match sync path and avoid async YAML tool initialisation failures - Pairing: reject approval callbacks from non-owner users; share gateway PairingStore with channel bots and reload store before is_paired checks Adds regression tests for each fix. Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>
1 parent a9f4bd5 commit df8fe41

8 files changed

Lines changed: 167 additions & 16 deletions

File tree

src/praisonai/praisonai/agents_generator.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -698,18 +698,14 @@ async def _arun_framework(self, config):
698698
if isinstance(t, str) and t.strip():
699699
needed_tools.add(t.strip())
700700

701-
# Resolve only the tools actually referenced in YAML
701+
# Resolve only the tools actually referenced in YAML using ToolResolver with instantiation
702702
for tool_name in needed_tools:
703703
try:
704-
resolved_tool = self.tool_resolver.resolve(tool_name)
705-
if resolved_tool is None:
706-
self.logger.warning(f"Tool '{tool_name}' not found")
707-
continue
708-
tools_dict[tool_name] = (
709-
resolved_tool() if inspect.isclass(resolved_tool) else resolved_tool
710-
)
704+
resolved_tool = self.tool_resolver.resolve(tool_name, instantiate=True)
705+
if resolved_tool is not None:
706+
tools_dict[tool_name] = resolved_tool
711707
except Exception as e:
712-
self.logger.warning(f"Failed to initialize tool '{tool_name}': {e}")
708+
self.logger.warning(f"Failed to resolve or instantiate tool '{tool_name}': {e}")
713709
continue
714710

715711
except Exception as e:

src/praisonai/praisonai/bots/_pairing_ui.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,20 @@ async def handle_approval_callback(
200200
success=False,
201201
message="Invalid or tampered callback data"
202202
)
203+
204+
# Only the configured bot owner may approve or deny pairing requests
205+
config = getattr(bot_adapter, "config", None)
206+
expected_owner = getattr(config, "owner_user_id", None) if config else None
207+
if expected_owner and str(expected_owner) != str(owner_user_id):
208+
logger.warning(
209+
"Rejected pairing callback from non-owner user %s (expected %s)",
210+
owner_user_id,
211+
expected_owner,
212+
)
213+
return PairingApprovalResult(
214+
success=False,
215+
message="Only the bot owner can approve pairing requests",
216+
)
203217

204218
action = parsed["action"]
205219
channel = parsed["channel"]

src/praisonai/praisonai/cli/session/unified.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __init__(self, session_dir: Optional[Path] = None):
131131
self.session_dir = Path(session_dir) if session_dir else DEFAULT_SESSION_DIR
132132
self.session_dir.mkdir(parents=True, exist_ok=True)
133133
self._cache: Dict[str, UnifiedSession] = {}
134+
self._cache_mtimes: Dict[str, float] = {}
134135
self._last_session_id: Optional[str] = None
135136

136137
def _get_session_path(self, session_id: str) -> Path:
@@ -140,6 +141,20 @@ def _get_session_path(self, session_id: str) -> Path:
140141
def _get_last_session_path(self) -> Path:
141142
"""Get the path to the last session marker file."""
142143
return self.session_dir / ".last_session"
144+
145+
def _is_cache_valid(self, session_id: str) -> bool:
146+
"""Return True if the in-memory cache matches the on-disk file."""
147+
if session_id not in self._cache:
148+
return False
149+
path = self._get_session_path(session_id)
150+
if not path.exists():
151+
return False
152+
try:
153+
current_mtime = path.stat().st_mtime_ns
154+
cached_mtime = self._cache_mtimes.get(session_id, 0)
155+
return current_mtime <= cached_mtime
156+
except OSError:
157+
return False
143158

144159
def save(self, session: UnifiedSession) -> None:
145160
"""
@@ -200,8 +215,12 @@ def save(self, session: UnifiedSession) -> None:
200215
json_data = json.dumps(session.to_dict(), indent=2).encode('utf-8')
201216
f.write(json_data)
202217

203-
# Update cache
218+
# Update cache and track file mtime for cross-process invalidation
204219
self._cache[session.session_id] = session
220+
try:
221+
self._cache_mtimes[session.session_id] = path.stat().st_mtime_ns
222+
except OSError:
223+
pass
205224

206225
# Update last session marker
207226
self._update_last_session(session.session_id)
@@ -221,8 +240,8 @@ def load(self, session_id: str) -> Optional[UnifiedSession]:
221240
Returns:
222241
Session if found, None otherwise
223242
"""
224-
# Check cache first
225-
if session_id in self._cache:
243+
# Return cached session only when the on-disk file has not changed
244+
if self._is_cache_valid(session_id):
226245
return self._cache[session_id]
227246

228247
path = self._get_session_path(session_id)
@@ -258,6 +277,10 @@ def load(self, session_id: str) -> Optional[UnifiedSession]:
258277

259278
session = UnifiedSession.from_dict(data)
260279
self._cache[session_id] = session
280+
try:
281+
self._cache_mtimes[session_id] = path.stat().st_mtime_ns
282+
except OSError:
283+
pass
261284
logger.debug(f"Loaded session: {session_id}")
262285
return session
263286
except Exception as e:
@@ -299,6 +322,7 @@ def delete(self, session_id: str) -> bool:
299322
if path.exists():
300323
path.unlink()
301324
self._cache.pop(session_id, None)
325+
self._cache_mtimes.pop(session_id, None)
302326
logger.debug(f"Deleted session: {session_id}")
303327
return True
304328
return False

src/praisonai/praisonai/gateway/pairing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,10 @@ def __init__(
146146
self._pending: Dict[str, dict] = {}
147147
# (channel_id, channel_type) -> PairedChannel
148148
self._paired: Dict[tuple, PairedChannel] = {}
149+
self._loaded_mtime: float = 0.0
149150

150151
self._load()
152+
self._loaded_mtime = self._get_store_mtime()
151153

152154
# ── Code lifecycle ────────────────────────────────────────────────
153155

@@ -230,6 +232,7 @@ def verify_and_pair(
230232
def is_paired(self, channel_id: str, channel_type: str) -> bool:
231233
"""Check if a channel is authorised."""
232234
with self._lock:
235+
self._reload_if_stale()
233236
return (channel_id, channel_type) in self._paired
234237

235238
def list_paired(self) -> List[PairedChannel]:
@@ -331,6 +334,7 @@ def _save(self) -> None:
331334
with os.fdopen(fd, "w") as fh:
332335
json.dump(data, fh, indent=2)
333336
os.replace(tmp_path, self._path) # atomic on POSIX
337+
self._loaded_mtime = self._get_store_mtime()
334338
except Exception:
335339
# Clean up temp file on failure
336340
try:
@@ -341,6 +345,23 @@ def _save(self) -> None:
341345
except OSError as exc:
342346
logger.warning("Failed to save pairing store: %s", exc)
343347

348+
def _get_store_mtime(self) -> float:
349+
"""Return pairing store file mtime, or 0 if unavailable."""
350+
try:
351+
return os.path.getmtime(self._path) if os.path.exists(self._path) else 0.0
352+
except OSError:
353+
return 0.0
354+
355+
def _reload_if_stale(self) -> None:
356+
"""Reload from disk when another process has updated the store."""
357+
current_mtime = self._get_store_mtime()
358+
if current_mtime <= self._loaded_mtime:
359+
return
360+
self._paired.clear()
361+
self._pending.clear()
362+
self._load()
363+
self._loaded_mtime = current_mtime
364+
344365
def _load(self) -> None:
345366
"""Load paired channels from disk."""
346367
if not os.path.exists(self._path):

src/praisonai/praisonai/gateway/server.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,6 +1685,13 @@ async def start_channels(self, channels_cfg: Dict[str, Dict[str, Any]]) -> None:
16851685
self._channel_tasks.append(task)
16861686
logger.info(f"Started {len(self._channel_bots)} channel bot(s)")
16871687

1688+
def _wire_gateway_pairing_store(self, bot: Any) -> None:
1689+
"""Share the gateway pairing store with a channel bot."""
1690+
from praisonai.bots._pairing_ui import PairingCallbackHandler
1691+
1692+
bot._pairing_store = self.pairing_store
1693+
bot._pairing_callback_handler = PairingCallbackHandler(self.pairing_store)
1694+
16881695
def _create_bot(
16891696
self,
16901697
channel_type: str,
@@ -1714,14 +1721,20 @@ def _create_bot(
17141721

17151722
if channel_type == "telegram":
17161723
from praisonai.bots import TelegramBot
1717-
return TelegramBot(token=token, agent=agent, config=config)
1724+
bot = TelegramBot(token=token, agent=agent, config=config)
1725+
self._wire_gateway_pairing_store(bot)
1726+
return bot
17181727
elif channel_type == "discord":
17191728
from praisonai.bots import DiscordBot
1720-
return DiscordBot(token=token, agent=agent, config=config)
1729+
bot = DiscordBot(token=token, agent=agent, config=config)
1730+
self._wire_gateway_pairing_store(bot)
1731+
return bot
17211732
elif channel_type == "slack":
17221733
from praisonai.bots import SlackBot
17231734
app_token = ch_cfg.get("app_token", os.environ.get("SLACK_APP_TOKEN", ""))
1724-
return SlackBot(token=token, agent=agent, config=config, app_token=app_token)
1735+
bot = SlackBot(token=token, agent=agent, config=config, app_token=app_token)
1736+
self._wire_gateway_pairing_store(bot)
1737+
return bot
17251738
elif channel_type == "whatsapp":
17261739
from praisonai.bots import WhatsAppBot
17271740
wa_mode = ch_cfg.get("mode", "cloud").lower().strip()

src/praisonai/tests/integration/bots/test_pairing_owner_dm.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,30 @@ async def test_owner_approval_allows_future_messages(self):
155155
# Should not send another approval DM
156156
assert len(self.adapter.approval_dms) == 1 # Still only the original one
157157

158+
async def test_non_owner_cannot_approve_pairing(self):
159+
"""Only the configured owner may approve pairing callbacks."""
160+
code = self.pairing_store.generate_code(channel_type="telegram")
161+
keyboard = PairingUIBuilder.create_telegram_keyboard(
162+
user_name="Alice",
163+
code=code,
164+
channel="telegram",
165+
user_id="new-user",
166+
)
167+
callback_data = keyboard["inline_keyboard"][0][0]["callback_data"]
168+
169+
self.adapter.config = self.config
170+
171+
callback_handler = PairingCallbackHandler(self.pairing_store)
172+
result = await callback_handler.handle_approval_callback(
173+
callback_data=callback_data,
174+
owner_user_id="attacker-999",
175+
bot_adapter=self.adapter,
176+
)
177+
178+
assert result.success is False
179+
assert "owner" in result.message.lower()
180+
assert not self.pairing_store.is_paired("new-user", "telegram")
181+
158182
async def test_no_owner_id_falls_back_to_cli(self):
159183
"""Test fallback to CLI instructions when owner_user_id is not configured."""
160184
# Configure bot without owner ID

src/praisonai/tests/unit/cli/test_unified_session.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,32 @@ def test_load_nonexistent(self, temp_session_dir):
263263

264264
assert session is None
265265

266+
def test_load_invalidates_stale_cache_after_external_write(self, temp_session_dir):
267+
"""Cross-process writes must not be overwritten by stale in-memory cache."""
268+
store_a = UnifiedSessionStore(session_dir=temp_session_dir)
269+
store_b = UnifiedSessionStore(session_dir=temp_session_dir)
270+
271+
session = UnifiedSession(session_id="shared-session")
272+
session.add_user_message("first message")
273+
store_a.save(session)
274+
275+
# Process A keeps a warm cache
276+
cached = store_a.load("shared-session")
277+
assert cached is not None
278+
assert len(cached.messages) == 1
279+
280+
# Process B appends a message and saves
281+
updated = store_b.load("shared-session")
282+
assert updated is not None
283+
updated.add_user_message("second message")
284+
store_b.save(updated)
285+
286+
# Process A must see B's write instead of returning stale cache
287+
reloaded = store_a.load("shared-session")
288+
assert reloaded is not None
289+
assert len(reloaded.messages) == 2
290+
assert reloaded.messages[1]["content"] == "second message"
291+
266292

267293
class TestGlobalSessionStore:
268294
"""Tests for global session store."""

tests/test_wrapper_layer_regression.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
import pytest
11-
from unittest.mock import MagicMock, patch
11+
from unittest.mock import AsyncMock, MagicMock, patch
1212

1313

1414
class TestInteractiveRuntimeLifecycle:
@@ -110,6 +110,39 @@ def test_bots_cli_uses_tool_resolver(self):
110110
MockResolver.assert_called_once()
111111
mock_resolver.resolve.assert_called_once_with('test_tool', instantiate=True)
112112

113+
def test_arun_framework_uses_instantiate_true(self):
114+
"""Async YAML tool resolution must mirror sync instantiate=True behaviour."""
115+
import asyncio
116+
from praisonai.praisonai.agents_generator import AgentsGenerator
117+
118+
generator = AgentsGenerator(agent_file="agents.yaml")
119+
config = {
120+
"roles": {
121+
"researcher": {
122+
"role": "Researcher",
123+
"goal": "Research",
124+
"backstory": "You research",
125+
"tools": ["test_tool"],
126+
}
127+
}
128+
}
129+
130+
adapter = MagicMock()
131+
adapter.arun = AsyncMock(return_value="ok")
132+
generator.framework_adapter = adapter
133+
generator.framework = "praisonaiagents"
134+
135+
with patch.object(generator, "tool_resolver") as mock_resolver, \
136+
patch("praisonai.praisonai.agents_generator.is_available", return_value=True), \
137+
patch("praisonai.praisonai.framework_adapters.validators.assert_framework_available"), \
138+
patch.object(generator, "_validate_cli_backend_compatibility"):
139+
mock_resolver.resolve.return_value = MagicMock()
140+
mock_resolver.get_local_tool_classes.return_value = {}
141+
142+
asyncio.run(generator._arun_framework(config))
143+
144+
mock_resolver.resolve.assert_called_once_with("test_tool", instantiate=True)
145+
113146
def test_job_workflow_uses_tool_resolver(self):
114147
"""Test that job_workflow uses ToolResolver for tool resolution"""
115148
from praisonai.praisonai.cli.features.job_workflow import JobWorkflowExecutor

0 commit comments

Comments
 (0)