diff --git a/src/praisonai/praisonai/agents_generator.py b/src/praisonai/praisonai/agents_generator.py index aaf361582..e05378f31 100644 --- a/src/praisonai/praisonai/agents_generator.py +++ b/src/praisonai/praisonai/agents_generator.py @@ -698,18 +698,14 @@ async def _arun_framework(self, config): if isinstance(t, str) and t.strip(): needed_tools.add(t.strip()) - # Resolve only the tools actually referenced in YAML + # Resolve only the tools actually referenced in YAML using ToolResolver with instantiation for tool_name in needed_tools: try: - resolved_tool = self.tool_resolver.resolve(tool_name) - if resolved_tool is None: - self.logger.warning(f"Tool '{tool_name}' not found") - continue - tools_dict[tool_name] = ( - resolved_tool() if inspect.isclass(resolved_tool) else resolved_tool - ) + resolved_tool = self.tool_resolver.resolve(tool_name, instantiate=True) + if resolved_tool is not None: + tools_dict[tool_name] = resolved_tool except Exception as e: - self.logger.warning(f"Failed to initialize tool '{tool_name}': {e}") + self.logger.warning(f"Failed to resolve or instantiate tool '{tool_name}': {e}") continue except Exception as e: diff --git a/src/praisonai/praisonai/bots/_pairing_ui.py b/src/praisonai/praisonai/bots/_pairing_ui.py index dcc7d7bcd..d99c9e96d 100644 --- a/src/praisonai/praisonai/bots/_pairing_ui.py +++ b/src/praisonai/praisonai/bots/_pairing_ui.py @@ -200,6 +200,20 @@ async def handle_approval_callback( success=False, message="Invalid or tampered callback data" ) + + # Only the configured bot owner may approve or deny pairing requests + config = getattr(bot_adapter, "config", None) + expected_owner = getattr(config, "owner_user_id", None) if config else None + if expected_owner and str(expected_owner) != str(owner_user_id): + logger.warning( + "Rejected pairing callback from non-owner user %s (expected %s)", + owner_user_id, + expected_owner, + ) + return PairingApprovalResult( + success=False, + message="Only the bot owner can approve pairing requests", + ) action = parsed["action"] channel = parsed["channel"] diff --git a/src/praisonai/praisonai/cli/session/unified.py b/src/praisonai/praisonai/cli/session/unified.py index fe613c61e..c98a0c10c 100644 --- a/src/praisonai/praisonai/cli/session/unified.py +++ b/src/praisonai/praisonai/cli/session/unified.py @@ -131,6 +131,7 @@ def __init__(self, session_dir: Optional[Path] = None): self.session_dir = Path(session_dir) if session_dir else DEFAULT_SESSION_DIR self.session_dir.mkdir(parents=True, exist_ok=True) self._cache: Dict[str, UnifiedSession] = {} + self._cache_mtimes: Dict[str, float] = {} self._last_session_id: Optional[str] = None def _get_session_path(self, session_id: str) -> Path: @@ -140,6 +141,20 @@ def _get_session_path(self, session_id: str) -> Path: def _get_last_session_path(self) -> Path: """Get the path to the last session marker file.""" return self.session_dir / ".last_session" + + def _is_cache_valid(self, session_id: str) -> bool: + """Return True if the in-memory cache matches the on-disk file.""" + if session_id not in self._cache: + return False + path = self._get_session_path(session_id) + if not path.exists(): + return False + try: + current_mtime = path.stat().st_mtime_ns + cached_mtime = self._cache_mtimes.get(session_id, 0) + return current_mtime <= cached_mtime + except OSError: + return False def save(self, session: UnifiedSession) -> None: """ @@ -200,8 +215,12 @@ def save(self, session: UnifiedSession) -> None: json_data = json.dumps(session.to_dict(), indent=2).encode('utf-8') f.write(json_data) - # Update cache + # Update cache and track file mtime for cross-process invalidation self._cache[session.session_id] = session + try: + self._cache_mtimes[session.session_id] = path.stat().st_mtime_ns + except OSError: + pass # Update last session marker self._update_last_session(session.session_id) @@ -221,8 +240,8 @@ def load(self, session_id: str) -> Optional[UnifiedSession]: Returns: Session if found, None otherwise """ - # Check cache first - if session_id in self._cache: + # Return cached session only when the on-disk file has not changed + if self._is_cache_valid(session_id): return self._cache[session_id] path = self._get_session_path(session_id) @@ -258,6 +277,10 @@ def load(self, session_id: str) -> Optional[UnifiedSession]: session = UnifiedSession.from_dict(data) self._cache[session_id] = session + try: + self._cache_mtimes[session_id] = path.stat().st_mtime_ns + except OSError: + pass logger.debug(f"Loaded session: {session_id}") return session except Exception as e: @@ -299,6 +322,7 @@ def delete(self, session_id: str) -> bool: if path.exists(): path.unlink() self._cache.pop(session_id, None) + self._cache_mtimes.pop(session_id, None) logger.debug(f"Deleted session: {session_id}") return True return False diff --git a/src/praisonai/praisonai/gateway/pairing.py b/src/praisonai/praisonai/gateway/pairing.py index 053288922..18a8758ac 100644 --- a/src/praisonai/praisonai/gateway/pairing.py +++ b/src/praisonai/praisonai/gateway/pairing.py @@ -146,8 +146,10 @@ def __init__( self._pending: Dict[str, dict] = {} # (channel_id, channel_type) -> PairedChannel self._paired: Dict[tuple, PairedChannel] = {} + self._loaded_mtime: float = 0.0 self._load() + self._loaded_mtime = self._get_store_mtime() # ── Code lifecycle ──────────────────────────────────────────────── @@ -230,6 +232,7 @@ def verify_and_pair( def is_paired(self, channel_id: str, channel_type: str) -> bool: """Check if a channel is authorised.""" with self._lock: + self._reload_if_stale() return (channel_id, channel_type) in self._paired def list_paired(self) -> List[PairedChannel]: @@ -331,6 +334,7 @@ def _save(self) -> None: with os.fdopen(fd, "w") as fh: json.dump(data, fh, indent=2) os.replace(tmp_path, self._path) # atomic on POSIX + self._loaded_mtime = self._get_store_mtime() except Exception: # Clean up temp file on failure try: @@ -341,6 +345,23 @@ def _save(self) -> None: except OSError as exc: logger.warning("Failed to save pairing store: %s", exc) + def _get_store_mtime(self) -> float: + """Return pairing store file mtime, or 0 if unavailable.""" + try: + return os.path.getmtime(self._path) if os.path.exists(self._path) else 0.0 + except OSError: + return 0.0 + + def _reload_if_stale(self) -> None: + """Reload from disk when another process has updated the store.""" + current_mtime = self._get_store_mtime() + if current_mtime <= self._loaded_mtime: + return + self._paired.clear() + self._pending.clear() + self._load() + self._loaded_mtime = current_mtime + def _load(self) -> None: """Load paired channels from disk.""" if not os.path.exists(self._path): diff --git a/src/praisonai/praisonai/gateway/server.py b/src/praisonai/praisonai/gateway/server.py index eaa735c94..739a21d11 100644 --- a/src/praisonai/praisonai/gateway/server.py +++ b/src/praisonai/praisonai/gateway/server.py @@ -1685,6 +1685,13 @@ async def start_channels(self, channels_cfg: Dict[str, Dict[str, Any]]) -> None: self._channel_tasks.append(task) logger.info(f"Started {len(self._channel_bots)} channel bot(s)") + def _wire_gateway_pairing_store(self, bot: Any) -> None: + """Share the gateway pairing store with a channel bot.""" + from praisonai.bots._pairing_ui import PairingCallbackHandler + + bot._pairing_store = self.pairing_store + bot._pairing_callback_handler = PairingCallbackHandler(self.pairing_store) + def _create_bot( self, channel_type: str, @@ -1714,14 +1721,20 @@ def _create_bot( if channel_type == "telegram": from praisonai.bots import TelegramBot - return TelegramBot(token=token, agent=agent, config=config) + bot = TelegramBot(token=token, agent=agent, config=config) + self._wire_gateway_pairing_store(bot) + return bot elif channel_type == "discord": from praisonai.bots import DiscordBot - return DiscordBot(token=token, agent=agent, config=config) + bot = DiscordBot(token=token, agent=agent, config=config) + self._wire_gateway_pairing_store(bot) + return bot elif channel_type == "slack": from praisonai.bots import SlackBot app_token = ch_cfg.get("app_token", os.environ.get("SLACK_APP_TOKEN", "")) - return SlackBot(token=token, agent=agent, config=config, app_token=app_token) + bot = SlackBot(token=token, agent=agent, config=config, app_token=app_token) + self._wire_gateway_pairing_store(bot) + return bot elif channel_type == "whatsapp": from praisonai.bots import WhatsAppBot wa_mode = ch_cfg.get("mode", "cloud").lower().strip() diff --git a/src/praisonai/tests/integration/bots/test_pairing_owner_dm.py b/src/praisonai/tests/integration/bots/test_pairing_owner_dm.py index 99caccc3c..c227b9aca 100644 --- a/src/praisonai/tests/integration/bots/test_pairing_owner_dm.py +++ b/src/praisonai/tests/integration/bots/test_pairing_owner_dm.py @@ -155,6 +155,30 @@ async def test_owner_approval_allows_future_messages(self): # Should not send another approval DM assert len(self.adapter.approval_dms) == 1 # Still only the original one + async def test_non_owner_cannot_approve_pairing(self): + """Only the configured owner may approve pairing callbacks.""" + code = self.pairing_store.generate_code(channel_type="telegram") + keyboard = PairingUIBuilder.create_telegram_keyboard( + user_name="Alice", + code=code, + channel="telegram", + user_id="new-user", + ) + callback_data = keyboard["inline_keyboard"][0][0]["callback_data"] + + self.adapter.config = self.config + + callback_handler = PairingCallbackHandler(self.pairing_store) + result = await callback_handler.handle_approval_callback( + callback_data=callback_data, + owner_user_id="attacker-999", + bot_adapter=self.adapter, + ) + + assert result.success is False + assert "owner" in result.message.lower() + assert not self.pairing_store.is_paired("new-user", "telegram") + async def test_no_owner_id_falls_back_to_cli(self): """Test fallback to CLI instructions when owner_user_id is not configured.""" # Configure bot without owner ID diff --git a/src/praisonai/tests/unit/cli/test_unified_session.py b/src/praisonai/tests/unit/cli/test_unified_session.py index 7bbf98b7a..ed92f0e8e 100644 --- a/src/praisonai/tests/unit/cli/test_unified_session.py +++ b/src/praisonai/tests/unit/cli/test_unified_session.py @@ -263,6 +263,32 @@ def test_load_nonexistent(self, temp_session_dir): assert session is None + def test_load_invalidates_stale_cache_after_external_write(self, temp_session_dir): + """Cross-process writes must not be overwritten by stale in-memory cache.""" + store_a = UnifiedSessionStore(session_dir=temp_session_dir) + store_b = UnifiedSessionStore(session_dir=temp_session_dir) + + session = UnifiedSession(session_id="shared-session") + session.add_user_message("first message") + store_a.save(session) + + # Process A keeps a warm cache + cached = store_a.load("shared-session") + assert cached is not None + assert len(cached.messages) == 1 + + # Process B appends a message and saves + updated = store_b.load("shared-session") + assert updated is not None + updated.add_user_message("second message") + store_b.save(updated) + + # Process A must see B's write instead of returning stale cache + reloaded = store_a.load("shared-session") + assert reloaded is not None + assert len(reloaded.messages) == 2 + assert reloaded.messages[1]["content"] == "second message" + class TestGlobalSessionStore: """Tests for global session store.""" diff --git a/tests/test_wrapper_layer_regression.py b/tests/test_wrapper_layer_regression.py index 933d8f605..b0291f54d 100644 --- a/tests/test_wrapper_layer_regression.py +++ b/tests/test_wrapper_layer_regression.py @@ -8,7 +8,7 @@ """ import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch class TestInteractiveRuntimeLifecycle: @@ -110,6 +110,39 @@ def test_bots_cli_uses_tool_resolver(self): MockResolver.assert_called_once() mock_resolver.resolve.assert_called_once_with('test_tool', instantiate=True) + def test_arun_framework_uses_instantiate_true(self): + """Async YAML tool resolution must mirror sync instantiate=True behaviour.""" + import asyncio + from praisonai.praisonai.agents_generator import AgentsGenerator + + generator = AgentsGenerator(agent_file="agents.yaml") + config = { + "roles": { + "researcher": { + "role": "Researcher", + "goal": "Research", + "backstory": "You research", + "tools": ["test_tool"], + } + } + } + + adapter = MagicMock() + adapter.arun = AsyncMock(return_value="ok") + generator.framework_adapter = adapter + generator.framework = "praisonaiagents" + + with patch.object(generator, "tool_resolver") as mock_resolver, \ + patch("praisonai.praisonai.agents_generator.is_available", return_value=True), \ + patch("praisonai.praisonai.framework_adapters.validators.assert_framework_available"), \ + patch.object(generator, "_validate_cli_backend_compatibility"): + mock_resolver.resolve.return_value = MagicMock() + mock_resolver.get_local_tool_classes.return_value = {} + + asyncio.run(generator._arun_framework(config)) + + mock_resolver.resolve.assert_called_once_with("test_tool", instantiate=True) + def test_job_workflow_uses_tool_resolver(self): """Test that job_workflow uses ToolResolver for tool resolution""" from praisonai.praisonai.cli.features.job_workflow import JobWorkflowExecutor