From 048091138b491543663f66da589897e5e94f161d Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 14 Nov 2025 16:13:13 -0600 Subject: [PATCH 01/12] First pass: be able to shutdown homeserver that hasn't `setup` --- synapse/api/errors.py | 6 ++++++ synapse/server.py | 29 ++++++++++++++++++++++------- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c4339ebef89..88ce52acc98 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -856,6 +856,12 @@ def to_synapse_error(self) -> SynapseError: return ProxiedRequestError(self.code, errmsg, errcode, j) +class HomeServerNotSetupException(Exception): + """ + Raised when an operation is attempted on the HomeServer before setup() has been called. + """ + + class ShadowBanError(Exception): """ Raised when a shadow-banned user attempts to perform an action. diff --git a/synapse/server.py b/synapse/server.py index de0a2b098c6..bafd059fc14 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -54,6 +54,7 @@ from synapse.api.auth.internal import InternalAuth from synapse.api.auth.mas import MasDelegatedAuth from synapse.api.auth_blocking import AuthBlocking +from synapse.api.errors import HomeServerNotSetupException from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter from synapse.app._base import unregister_sighups @@ -399,7 +400,7 @@ def run_as_background_process( """ if self._is_shutdown: raise Exception( - f"Cannot start background process. HomeServer has been shutdown {len(self._background_processes)} {len(self.get_clock()._looping_calls)} {len(self.get_clock()._call_id_to_delayed_call)}" + "Cannot start background process. HomeServer has been shutdown" ) # Ignore linter error as this is the one location this should be called. @@ -466,7 +467,11 @@ async def shutdown(self) -> None: # TODO: Cleanup replication pieces - self.get_keyring().shutdown() + try: + self.get_keyring().shutdown() + except HomeServerNotSetupException: + # If the homeserver wasn't fully setup, keyring won't exist + pass # Cleanup metrics associated with the homeserver for later_gauge in all_later_gauges_to_clean_up_on_shutdown.values(): @@ -478,8 +483,12 @@ async def shutdown(self) -> None: self.config.server.server_name ) - for db in self.get_datastores().databases: - db.stop_background_updates() + try: + for db in self.get_datastores().databases: + db.stop_background_updates() + except HomeServerNotSetupException: + # If the homeserver wasn't fully setup, the datastores won't exist + pass if self.should_send_federation(): try: @@ -513,8 +522,12 @@ async def shutdown(self) -> None: pass self._background_processes.clear() - for db in self.get_datastores().databases: - db._db_pool.close() + try: + for db in self.get_datastores().databases: + db._db_pool.close() + except HomeServerNotSetupException: + # If the homeserver wasn't fully setup, the datastores won't exist + pass def register_async_shutdown_handler( self, @@ -677,7 +690,9 @@ def get_clock(self) -> Clock: def get_datastores(self) -> Databases: if not self.datastores: - raise Exception("HomeServer.setup must be called before getting datastores") + raise HomeServerNotSetupException( + "HomeServer.setup must be called before getting datastores" + ) return self.datastores From a0e7698535a4ba8cfab7dabc90783000ffb1d89f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 14 Nov 2025 17:08:21 -0600 Subject: [PATCH 02/12] Add test --- tests/app/test_homeserver_shutdown.py | 85 +++++++++++++++++ tests/server.py | 127 ++++++++++++++++---------- 2 files changed, 164 insertions(+), 48 deletions(-) diff --git a/tests/app/test_homeserver_shutdown.py b/tests/app/test_homeserver_shutdown.py index f127e5571de..9063a63bcb8 100644 --- a/tests/app/test_homeserver_shutdown.py +++ b/tests/app/test_homeserver_shutdown.py @@ -20,6 +20,7 @@ import gc import weakref +from unittest.mock import patch from synapse.app.homeserver import SynapseHomeServer from synapse.logging.context import LoggingContext @@ -204,3 +205,87 @@ async def shutdown() -> None: # # render a useful amount of information without taking an overly long time # # to generate the result. # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) + + @logcontext_clean + def test_clean_homeserver_shutdown_when_failed_to_setup(self) -> None: + """ + Ensure the `SynapseHomeServer` can be fully shutdown and garbage collected if it + fails to be `setup`. + """ + self.reactor, self.clock = get_clock() + + # Patch `hs.setup()` to do nothing, so that the homeserver is not fully setup. + with patch.object(SynapseHomeServer, "setup", return_value=None): + # Patch out the call to `start_test_homeserver` since we want access to the + # homeserver even before the server is setup (let alone started) + with patch("tests.server.start_test_homeserver", return_value=None): + self.hs = setup_test_homeserver( + cleanup_func=self.addCleanup, + reactor=self.reactor, + homeserver_to_use=SynapseHomeServer, + clock=self.clock, + ) + + hs_ref = weakref.ref(self.hs) + + # Run the reactor so any `callWhenRunning` functions can be cleared out. + self.reactor.run() + # This would normally happen as part of `HomeServer.shutdown` but the `MemoryReactor` + # we use in tests doesn't handle this properly (see doc comment) + cleanup_test_reactor_system_event_triggers(self.reactor) + + async def shutdown() -> None: + # Use a logcontext just to double-check that we don't mangle the logcontext + # during shutdown. + with LoggingContext(name="hs_shutdown", server_name=self.hs.hostname): + await self.hs.shutdown() + + self.get_success(shutdown()) + + # Cleanup the internal reference in our test case + del self.hs + + # Force garbage collection. + gc.collect() + + # Ensure the `HomeServer` hs been garbage collected by attempting to use the + # weakref to it. + if hs_ref() is not None: + self.fail("HomeServer reference should not be valid at this point") + + # To help debug this test when it fails, it is useful to leverage the + # `objgraph` module. + # The following code serves as an example of what I have found to be useful + # when tracking down references holding the `SynapseHomeServer` in memory: + # + # all_objects = gc.get_objects() + # for obj in all_objects: + # try: + # # These are a subset of types that are typically involved with + # # holding the `HomeServer` in memory. You may want to inspect + # # other types as well. + # if isinstance(obj, DataStore): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # db_obj = obj + # if isinstance(obj, SynapseHomeServer): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # synapse_hs = obj + # if isinstance(obj, SynapseSite): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # sysite = obj + # if isinstance(obj, DatabasePool): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # dbpool = obj + # except Exception: + # pass + # + # print(sys.getrefcount(hs_ref()), "refs to", hs_ref()) + # + # # The following values for `max_depth` and `too_many` have been found to + # # render a useful amount of information without taking an overly long time + # # to generate the result. + # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) diff --git a/tests/server.py b/tests/server.py index 30337f3e38e..2f22907ce8f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1074,10 +1074,10 @@ def setup_test_homeserver( If no datastore is supplied, one is created and given to the homeserver. Args: - cleanup_func : The function used to register a cleanup routine for - after the test. If the function returns a Deferred, the - test case will wait until the Deferred has fired before - proceeding to the next cleanup function. + cleanup_func: The function used to register a cleanup routine for + after the test. If the function returns a Deferred, the + test case will wait until the Deferred has fired before + proceeding to the next cleanup function. server_name: Homeserver name config: Homeserver config reactor: Twisted reactor @@ -1190,6 +1190,55 @@ def setup_test_homeserver( cur.close() db_conn.close() + def cleanup() -> None: + import psycopg2 + + dropped = False + + db_engine = create_engine(hs.database.config) + + # Drop the test database + db_conn = db_engine.module.connect( + dbname=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + password=POSTGRES_PASSWORD, + ) + db_engine.attempt_to_set_autocommit(db_conn, True) + cur = db_conn.cursor() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for _ in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), + category=UserWarning, + stacklevel=2, + ) + time.sleep(0.5) + + cur.close() + db_conn.close() + + if not dropped: + warnings.warn( + "Failed to drop old DB.", + category=UserWarning, + stacklevel=2, + ) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) + hs = homeserver_to_use( server_name, config=config, @@ -1224,6 +1273,32 @@ def shutdown_hs_on_cleanup() -> "Deferred[None]": with patch("synapse.storage.database.make_pool", side_effect=make_fake_db_pool): hs.setup() + # Ideally, setup/start would be separated but since this is historically used + # throughout tests, we keep the existing behavior for now. We probably just need to + # rename this function. + start_test_homeserver(hs=hs, cleanup_func=cleanup_func, reactor=reactor) + + return hs + + +def start_test_homeserver( + *, + hs: HomeServer, + cleanup_func: Callable[[Callable[[], Optional["Deferred[None]"]]], None], + reactor: ISynapseReactor | None = None, +) -> None: + """ + Start a homeserver for testing. + + Args: + hs: The homeserver to start. + cleanup_func: The function used to register a cleanup routine for + after the test. If the function returns a Deferred, the + test case will wait until the Deferred has fired before + proceeding to the next cleanup function. + reactor: Twisted reactor + """ + # Register background tasks required by this server. This must be done # somewhat manually due to the background tasks not being registered # unless handlers are instantiated. @@ -1245,53 +1320,11 @@ def shutdown_hs_on_cleanup() -> "Deferred[None]": # We need to do cleanup on PostgreSQL def cleanup() -> None: - import psycopg2 - # Close all the db pools db_pool = database_pool() if db_pool is not None: db_pool._db_pool.close() - dropped = False - - # Drop the test database - db_conn = db_engine.module.connect( - dbname=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - port=POSTGRES_PORT, - password=POSTGRES_PASSWORD, - ) - db_engine.attempt_to_set_autocommit(db_conn, True) - cur = db_conn.cursor() - - # Try a few times to drop the DB. Some things may hold on to the - # database for a few more seconds due to flakiness, preventing - # us from dropping it when the test is over. If we can't drop - # it, warn and move on. - for _ in range(5): - try: - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() - dropped = True - except psycopg2.OperationalError as e: - warnings.warn( - "Couldn't drop old db: " + str(e), - category=UserWarning, - stacklevel=2, - ) - time.sleep(0.5) - - cur.close() - db_conn.close() - - if not dropped: - warnings.warn( - "Failed to drop old DB.", - category=UserWarning, - stacklevel=2, - ) - if not LEAVE_DB: # Register the cleanup hook cleanup_func(cleanup) @@ -1330,5 +1363,3 @@ def thread_pool() -> threadpool.ThreadPool: load_legacy_third_party_event_rules(hs) load_legacy_presence_router(hs) load_legacy_password_auth_providers(hs) - - return hs From df0a5d1d0174e97c819e14c8e95de5a9e5ec28de Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 14 Nov 2025 17:21:58 -0600 Subject: [PATCH 03/12] Add some better debugging info when the test fails --- tests/app/test_homeserver_shutdown.py | 216 ++++++++++++-------------- 1 file changed, 98 insertions(+), 118 deletions(-) diff --git a/tests/app/test_homeserver_shutdown.py b/tests/app/test_homeserver_shutdown.py index 9063a63bcb8..0f5d1c73387 100644 --- a/tests/app/test_homeserver_shutdown.py +++ b/tests/app/test_homeserver_shutdown.py @@ -19,7 +19,9 @@ # import gc +import sys import weakref +from typing import Any from unittest.mock import patch from synapse.app.homeserver import SynapseHomeServer @@ -82,45 +84,12 @@ async def shutdown() -> None: # Ensure the `HomeServer` hs been garbage collected by attempting to use the # weakref to it. - if hs_ref() is not None: - self.fail("HomeServer reference should not be valid at this point") - - # To help debug this test when it fails, it is useful to leverage the - # `objgraph` module. - # The following code serves as an example of what I have found to be useful - # when tracking down references holding the `SynapseHomeServer` in memory: - # - # all_objects = gc.get_objects() - # for obj in all_objects: - # try: - # # These are a subset of types that are typically involved with - # # holding the `HomeServer` in memory. You may want to inspect - # # other types as well. - # if isinstance(obj, DataStore): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # db_obj = obj - # if isinstance(obj, SynapseHomeServer): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # synapse_hs = obj - # if isinstance(obj, SynapseSite): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # sysite = obj - # if isinstance(obj, DatabasePool): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # dbpool = obj - # except Exception: - # pass - # - # print(sys.getrefcount(hs_ref()), "refs to", hs_ref()) - # - # # The following values for `max_depth` and `too_many` have been found to - # # render a useful amount of information without taking an overly long time - # # to generate the result. - # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) + hs_after_shutdown = hs_ref() + if hs_after_shutdown is not None: + self.fail( + "HomeServer reference should not be valid at this point " + f"{get_memory_debug_info_for_object(hs_after_shutdown)}", + ) @logcontext_clean def test_clean_homeserver_shutdown_mid_background_updates(self) -> None: @@ -166,45 +135,12 @@ async def shutdown() -> None: # Ensure the `HomeServer` hs been garbage collected by attempting to use the # weakref to it. - if hs_ref() is not None: - self.fail("HomeServer reference should not be valid at this point") - - # To help debug this test when it fails, it is useful to leverage the - # `objgraph` module. - # The following code serves as an example of what I have found to be useful - # when tracking down references holding the `SynapseHomeServer` in memory: - # - # all_objects = gc.get_objects() - # for obj in all_objects: - # try: - # # These are a subset of types that are typically involved with - # # holding the `HomeServer` in memory. You may want to inspect - # # other types as well. - # if isinstance(obj, DataStore): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # db_obj = obj - # if isinstance(obj, SynapseHomeServer): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # synapse_hs = obj - # if isinstance(obj, SynapseSite): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # sysite = obj - # if isinstance(obj, DatabasePool): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # dbpool = obj - # except Exception: - # pass - # - # print(sys.getrefcount(hs_ref()), "refs to", hs_ref()) - # - # # The following values for `max_depth` and `too_many` have been found to - # # render a useful amount of information without taking an overly long time - # # to generate the result. - # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) + hs_after_shutdown = hs_ref() + if hs_after_shutdown is not None: + self.fail( + "HomeServer reference should not be valid at this point " + f"{get_memory_debug_info_for_object(hs_after_shutdown)}", + ) @logcontext_clean def test_clean_homeserver_shutdown_when_failed_to_setup(self) -> None: @@ -215,7 +151,7 @@ def test_clean_homeserver_shutdown_when_failed_to_setup(self) -> None: self.reactor, self.clock = get_clock() # Patch `hs.setup()` to do nothing, so that the homeserver is not fully setup. - with patch.object(SynapseHomeServer, "setup", return_value=None): + with patch.object(SynapseHomeServer, "setup", return_value=None) as mock_setup: # Patch out the call to `start_test_homeserver` since we want access to the # homeserver even before the server is setup (let alone started) with patch("tests.server.start_test_homeserver", return_value=None): @@ -225,6 +161,9 @@ def test_clean_homeserver_shutdown_when_failed_to_setup(self) -> None: homeserver_to_use=SynapseHomeServer, clock=self.clock, ) + # Sanity check that we patched the correct method (make sure it was the + # thing that was called) + mock_setup.assert_called_once_with() hs_ref = weakref.ref(self.hs) @@ -250,42 +189,83 @@ async def shutdown() -> None: # Ensure the `HomeServer` hs been garbage collected by attempting to use the # weakref to it. - if hs_ref() is not None: - self.fail("HomeServer reference should not be valid at this point") - - # To help debug this test when it fails, it is useful to leverage the - # `objgraph` module. - # The following code serves as an example of what I have found to be useful - # when tracking down references holding the `SynapseHomeServer` in memory: - # - # all_objects = gc.get_objects() - # for obj in all_objects: - # try: - # # These are a subset of types that are typically involved with - # # holding the `HomeServer` in memory. You may want to inspect - # # other types as well. - # if isinstance(obj, DataStore): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # db_obj = obj - # if isinstance(obj, SynapseHomeServer): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # synapse_hs = obj - # if isinstance(obj, SynapseSite): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # sysite = obj - # if isinstance(obj, DatabasePool): - # print(sys.getrefcount(obj), "refs to", obj) - # if not isinstance(obj, weakref.ProxyType): - # dbpool = obj - # except Exception: - # pass - # - # print(sys.getrefcount(hs_ref()), "refs to", hs_ref()) - # - # # The following values for `max_depth` and `too_many` have been found to - # # render a useful amount of information without taking an overly long time - # # to generate the result. - # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) + hs_after_shutdown = hs_ref() + if hs_after_shutdown is not None: + self.fail( + "HomeServer reference should not be valid at this point " + f"{get_memory_debug_info_for_object(hs_after_shutdown)}", + ) + + +def get_memory_debug_info_for_object(object: Any) -> dict[str, Any]: + """ + Gathers some useful information to make it easier to figure out why the `object` is + still in memory. + + Args: + object: The object to gather debug information for. + """ + debug: dict[str, Any] = {} + if object is not None: + # The simplest tracing we can do is show the reference count for the object. + debug["reference_count"] = sys.getrefcount(object) + + # Find the list of objects that directly refer to the object. + # + # Note: The `ref_count` can be >0 but `referrers` can be empty because + # the all of the objects were frozen. Look at the + # `frozen_object_count` to detect this scenario. + referrers = gc.get_referrers(object) + debug["gc_referrer_count"] = len(referrers) + debug["gc_referrers"] = referrers + + # We don't expect to see frozen objects in normal operation of the + # `multi_synapse` shard. + # + # We can see frozen objects if you forget to `freeze=False` when + # starting the `SynapseHomeServer`. Frozen objects mean they are + # never considered for garbage collection. If the + # `SynapseHomeServer` (or anything that references the homeserver) + # is frozen, the homeserver can never be garbage collected and will + # linger in memory forever. + freeze_count = gc.get_freeze_count() + debug["gc_global_frozen_object_count"] = freeze_count + + # To help debug this test when it fails, it is useful to leverage the + # `objgraph` module. + # The following code serves as an example of what I have found to be useful + # when tracking down references holding the `SynapseHomeServer` in memory: + # + # all_objects = gc.get_objects() + # for obj in all_objects: + # try: + # # These are a subset of types that are typically involved with + # # holding the `HomeServer` in memory. You may want to inspect + # # other types as well. + # if isinstance(obj, DataStore): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # db_obj = obj + # if isinstance(obj, SynapseHomeServer): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # synapse_hs = obj + # if isinstance(obj, SynapseSite): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # sysite = obj + # if isinstance(obj, DatabasePool): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # dbpool = obj + # except Exception: + # pass + # + # print(sys.getrefcount(hs_ref()), "refs to", hs_ref()) + # + # # The following values for `max_depth` and `too_many` have been found to + # # render a useful amount of information without taking an overly long time + # # to generate the result. + # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) + + return debug From d046476e97cde256a0ed11654cc72b9bc6c74e54 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 14 Nov 2025 17:26:16 -0600 Subject: [PATCH 04/12] Fix up lints --- tests/server.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/server.py b/tests/server.py index 2f22907ce8f..4fb7dea5ec0 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1195,8 +1195,6 @@ def cleanup() -> None: dropped = False - db_engine = create_engine(hs.database.config) - # Drop the test database db_conn = db_engine.module.connect( dbname=POSTGRES_BASE_DB, @@ -1285,7 +1283,7 @@ def start_test_homeserver( *, hs: HomeServer, cleanup_func: Callable[[Callable[[], Optional["Deferred[None]"]]], None], - reactor: ISynapseReactor | None = None, + reactor: ISynapseReactor, ) -> None: """ Start a homeserver for testing. From d55f86d066045b1f12b5cd2788901aa777596e88 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 14 Nov 2025 17:27:52 -0600 Subject: [PATCH 05/12] Add changelog --- changelog.d/19187.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/19187.misc diff --git a/changelog.d/19187.misc b/changelog.d/19187.misc new file mode 100644 index 00000000000..d831de38c8f --- /dev/null +++ b/changelog.d/19187.misc @@ -0,0 +1 @@ +Fix `HomeServer.shutdown()` failing if the homeserver hasn't been setup yet. From ea1757efd3b92a6aac18c2a9ce675b0c81c235bb Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 21 Nov 2025 18:08:48 -0600 Subject: [PATCH 06/12] Avoid partially initialized `Keyring` which can hold references to `hs` --- synapse/crypto/keyring.py | 153 ++++++++++++++++++++++++----------- synapse/server.py | 6 +- tests/crypto/test_keyring.py | 10 ++- 3 files changed, 115 insertions(+), 54 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 3abb644df5d..7d516a66543 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -21,7 +21,7 @@ import abc import logging -from typing import TYPE_CHECKING, Callable, Iterable +from typing import TYPE_CHECKING, Callable, Iterable, Sequence import attr from signedjson.key import ( @@ -150,57 +150,77 @@ class Keyring: """ def __init__( - self, hs: "HomeServer", key_fetchers: "Iterable[KeyFetcher] | None" = None + self, + hs: "HomeServer", + test_only_key_fetchers: "Sequence[KeyFetcher] | None" = None, ): - self.server_name = hs.hostname + """ + Args: + hs: The HomeServer instance + test_only_key_fetchers: Dependency injection for tests only. If provided, + these key fetchers will be used instead of the default ones. + """ - if key_fetchers is None: - # Always fetch keys from the database. - mutable_key_fetchers: list[KeyFetcher] = [StoreKeyFetcher(hs)] - # Fetch keys from configured trusted key servers, if any exist. - key_servers = hs.config.key.key_servers - if key_servers: - mutable_key_fetchers.append(PerspectivesKeyFetcher(hs)) - # Finally, fetch keys from the origin server directly. - mutable_key_fetchers.append(ServerKeyFetcher(hs)) - - self._key_fetchers: Iterable[KeyFetcher] = tuple(mutable_key_fetchers) - else: - self._key_fetchers = key_fetchers - - self._fetch_keys_queue: BatchingQueue[ - _FetchKeyRequest, dict[str, dict[str, FetchKeyResult]] - ] = BatchingQueue( - name="keyring_server", - hs=hs, - clock=hs.get_clock(), - # The method called to fetch each key - process_batch_callback=self._inner_fetch_key_requests, - ) + try: + self.server_name = hs.hostname + + self._key_fetchers: Sequence[KeyFetcher] = [] + if test_only_key_fetchers is None: + # Always fetch keys from the database. + self._key_fetchers.append(StoreKeyFetcher(hs)) + # Fetch keys from configured trusted key servers, if any exist. + key_servers = hs.config.key.key_servers + if key_servers: + self._key_fetchers.append(PerspectivesKeyFetcher(hs)) + # Finally, fetch keys from the origin server directly. + self._key_fetchers.append(ServerKeyFetcher(hs)) + else: + self._key_fetchers = test_only_key_fetchers + + self._fetch_keys_queue: BatchingQueue[ + _FetchKeyRequest, dict[str, dict[str, FetchKeyResult]] + ] = BatchingQueue( + name="keyring_server", + hs=hs, + clock=hs.get_clock(), + # The method called to fetch each key + process_batch_callback=self._inner_fetch_key_requests, + ) - self._is_mine_server_name = hs.is_mine_server_name + self._is_mine_server_name = hs.is_mine_server_name - # build a FetchKeyResult for each of our own keys, to shortcircuit the - # fetcher. - self._local_verify_keys: dict[str, FetchKeyResult] = {} - for key_id, key in hs.config.key.old_signing_keys.items(): - self._local_verify_keys[key_id] = FetchKeyResult( - verify_key=key, valid_until_ts=key.expired - ) + # build a FetchKeyResult for each of our own keys, to shortcircuit the + # fetcher. + self._local_verify_keys: dict[str, FetchKeyResult] = {} + for key_id, key in hs.config.key.old_signing_keys.items(): + self._local_verify_keys[key_id] = FetchKeyResult( + verify_key=key, valid_until_ts=key.expired + ) - vk = get_verify_key(hs.signing_key) - self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult( - verify_key=vk, - valid_until_ts=2**63, # fake future timestamp - ) + vk = get_verify_key(hs.signing_key) + self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult( + verify_key=vk, + valid_until_ts=2**63, # fake future timestamp + ) + except Exception: + self.shutdown() + raise def shutdown(self) -> None: """ Prepares the KeyRing for garbage collection by shutting down it's queues. + + This needs to be robust enough to be called even if `__init__` failed partway + through. """ - self._fetch_keys_queue.shutdown() - for key_fetcher in self._key_fetchers: - key_fetcher.shutdown() + _fetch_keys_queue = getattr(self, "_fetch_keys_queue", None) + if _fetch_keys_queue: + _fetch_keys_queue.shutdown() + + _key_fetchers = getattr(self, "_key_fetchers", None) + if _key_fetchers: + for key_fetcher in _key_fetchers: + key_fetcher.shutdown() async def verify_json_for_server( self, @@ -495,8 +515,13 @@ def __init__(self, hs: "HomeServer"): def shutdown(self) -> None: """ Prepares the KeyFetcher for garbage collection by shutting down it's queue. + + This needs to be robust enough to be called even if `__init__` failed partway + through. """ - self._queue.shutdown() + _queue = getattr(self, "_queue", None) + if _queue: + _queue.shutdown() async def get_keys( self, server_name: str, key_ids: list[str], minimum_valid_until_ts: int @@ -521,9 +546,24 @@ class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - self.store = hs.get_datastores().main + try: + super().__init__(hs) + + self.store = hs.get_datastores().main + + # `KeyFetcher` keeps a reference to `hs` which we need to clean up if something + # goes wrong so we can cleanly shutdown the homeserver. + # + # If something goes wrong while initializing the `KeyFetcher`, the caller won't + # be able to call shutdown on it because there won't be a reference to the + # `KeyFetcher`. + # + # An error can occur here if someone tries to create a `KeyFetcher` before the + # homeserver is fully set up (`HomeServerNotSetupException: HomeServer.setup + # must be called before getting datastores`). + except Exception: + self.shutdown() + raise async def _fetch_keys( self, keys_to_fetch: list[_FetchKeyRequest] @@ -543,9 +583,24 @@ async def _fetch_keys( class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - self.store = hs.get_datastores().main + try: + super().__init__(hs) + + self.store = hs.get_datastores().main + + # `KeyFetcher` keeps a reference to `hs` which we need to clean up if something + # goes wrong so we can cleanly shutdown the homeserver. + # + # If something goes wrong while initializing the `KeyFetcher`, the caller won't + # be able to call shutdown on it because there won't be a reference to the + # `KeyFetcher`. + # + # An error can occur here if someone tries to create a `KeyFetcher` before the + # homeserver is fully set up (`HomeServerNotSetupException: HomeServer.setup + # must be called before getting datastores`). + except Exception: + self.shutdown() + raise async def process_v2_response( self, from_server: str, response_json: JsonDict, time_added_ms: int diff --git a/synapse/server.py b/synapse/server.py index bafd059fc14..90c3ef6d067 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -467,12 +467,16 @@ async def shutdown(self) -> None: # TODO: Cleanup replication pieces + keyring: Keyring | None = None try: - self.get_keyring().shutdown() + keyring = self.get_keyring() except HomeServerNotSetupException: # If the homeserver wasn't fully setup, keyring won't exist pass + if keyring: + keyring.shutdown() + # Cleanup metrics associated with the homeserver for later_gauge in all_later_gauges_to_clean_up_on_shutdown.values(): later_gauge.unregister_hooks_for_homeserver_instance_id( diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d3e8da97f84..17281d6ad56 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -95,7 +95,7 @@ def check_context( def test_verify_json_objects_for_server_awaits_previous_requests(self) -> None: mock_fetcher = Mock() mock_fetcher.get_keys = Mock() - kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) + kr = keyring.Keyring(self.hs, test_only_key_fetchers=(mock_fetcher,)) # a signed object that we are going to try to validate key1 = signedjson.key.generate_signing_key("1") @@ -286,7 +286,7 @@ async def get_keys( mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) kr = keyring.Keyring( - self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) + self.hs, test_only_key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) ) # sign the json @@ -313,7 +313,7 @@ async def get_keys( mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) - kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) + kr = keyring.Keyring(self.hs, test_only_key_fetchers=(mock_fetcher,)) json1: JsonDict = {} signedjson.sign.sign_json(json1, "server1", key1) @@ -363,7 +363,9 @@ async def get_keys2( mock_fetcher1.get_keys = Mock(side_effect=get_keys1) mock_fetcher2 = Mock() mock_fetcher2.get_keys = Mock(side_effect=get_keys2) - kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) + kr = keyring.Keyring( + self.hs, test_only_key_fetchers=(mock_fetcher1, mock_fetcher2) + ) json1: JsonDict = {} signedjson.sign.sign_json(json1, "server1", key1) From fce7adadbc7a4988475da139035dfb7e548898fc Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 21 Nov 2025 18:50:16 -0600 Subject: [PATCH 07/12] Better comment --- synapse/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/synapse/server.py b/synapse/server.py index 90c3ef6d067..88662c5b28e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -471,7 +471,9 @@ async def shutdown(self) -> None: try: keyring = self.get_keyring() except HomeServerNotSetupException: - # If the homeserver wasn't fully setup, keyring won't exist + # If the homeserver wasn't fully setup, keyring won't have existed before + # this and will fail to be initialized but it cleans itself up for any + # partial initialization problem. pass if keyring: From 235b684bf2f1ebf96bf21258f52d47d07663dc46 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 1 Dec 2025 15:54:39 -0600 Subject: [PATCH 08/12] Use `ExitStack` for nice clean-up during `__init__` See https://github.com/element-hq/synapse/pull/19187#discussion_r2571776809 --- synapse/crypto/keyring.py | 91 ++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 49 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 7d516a66543..aad23ea3af8 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -21,6 +21,7 @@ import abc import logging +from contextlib import ExitStack from typing import TYPE_CHECKING, Callable, Iterable, Sequence import attr @@ -160,20 +161,27 @@ def __init__( test_only_key_fetchers: Dependency injection for tests only. If provided, these key fetchers will be used instead of the default ones. """ - - try: + with ExitStack() as exit: self.server_name = hs.hostname self._key_fetchers: Sequence[KeyFetcher] = [] if test_only_key_fetchers is None: # Always fetch keys from the database. - self._key_fetchers.append(StoreKeyFetcher(hs)) + store_key_fetcher = StoreKeyFetcher(hs) + exit.callback(store_key_fetcher.shutdown) + self._key_fetchers.append(store_key_fetcher) + # Fetch keys from configured trusted key servers, if any exist. key_servers = hs.config.key.key_servers if key_servers: - self._key_fetchers.append(PerspectivesKeyFetcher(hs)) + perspectives_key_fetcher = PerspectivesKeyFetcher(hs) + exit.callback(perspectives_key_fetcher.shutdown) + self._key_fetchers.append(perspectives_key_fetcher) + # Finally, fetch keys from the origin server directly. - self._key_fetchers.append(ServerKeyFetcher(hs)) + server_key_fetcher = ServerKeyFetcher(hs) + exit.callback(server_key_fetcher.shutdown) + self._key_fetchers.append(server_key_fetcher) else: self._key_fetchers = test_only_key_fetchers @@ -186,6 +194,7 @@ def __init__( # The method called to fetch each key process_batch_callback=self._inner_fetch_key_requests, ) + exit.callback(self._fetch_keys_queue.shutdown) self._is_mine_server_name = hs.is_mine_server_name @@ -202,9 +211,10 @@ def __init__( verify_key=vk, valid_until_ts=2**63, # fake future timestamp ) - except Exception: - self.shutdown() - raise + + # We reached the end of the block whichs means everything was successful, so + # no exit handlers are needed (remove them all). + exit.pop_all() def shutdown(self) -> None: """ @@ -213,14 +223,10 @@ def shutdown(self) -> None: This needs to be robust enough to be called even if `__init__` failed partway through. """ - _fetch_keys_queue = getattr(self, "_fetch_keys_queue", None) - if _fetch_keys_queue: - _fetch_keys_queue.shutdown() + self._fetch_keys_queue.shutdown() - _key_fetchers = getattr(self, "_key_fetchers", None) - if _key_fetchers: - for key_fetcher in _key_fetchers: - key_fetcher.shutdown() + for key_fetcher in self._key_fetchers: + key_fetcher.shutdown() async def verify_json_for_server( self, @@ -515,13 +521,8 @@ def __init__(self, hs: "HomeServer"): def shutdown(self) -> None: """ Prepares the KeyFetcher for garbage collection by shutting down it's queue. - - This needs to be robust enough to be called even if `__init__` failed partway - through. """ - _queue = getattr(self, "_queue", None) - if _queue: - _queue.shutdown() + self._queue.shutdown() async def get_keys( self, server_name: str, key_ids: list[str], minimum_valid_until_ts: int @@ -546,24 +547,20 @@ class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" def __init__(self, hs: "HomeServer"): - try: + with ExitStack() as exit: super().__init__(hs) + # `KeyFetcher` keeps a reference to `hs` which we need to clean up if + # something goes wrong so we can cleanly shutdown the homeserver. + exit.callback(super().shutdown) + # An error can be raised here if someone tried to create a `StoreKeyFetcher` + # before the homeserver is fully set up (`HomeServerNotSetupException: + # HomeServer.setup must be called before getting datastores`). self.store = hs.get_datastores().main - # `KeyFetcher` keeps a reference to `hs` which we need to clean up if something - # goes wrong so we can cleanly shutdown the homeserver. - # - # If something goes wrong while initializing the `KeyFetcher`, the caller won't - # be able to call shutdown on it because there won't be a reference to the - # `KeyFetcher`. - # - # An error can occur here if someone tries to create a `KeyFetcher` before the - # homeserver is fully set up (`HomeServerNotSetupException: HomeServer.setup - # must be called before getting datastores`). - except Exception: - self.shutdown() - raise + # We reached the end of the block whichs means everything was successful, so + # no exit handlers are needed (remove them all). + exit.pop_all() async def _fetch_keys( self, keys_to_fetch: list[_FetchKeyRequest] @@ -583,24 +580,20 @@ async def _fetch_keys( class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): - try: + with ExitStack() as exit: super().__init__(hs) + # `KeyFetcher` keeps a reference to `hs` which we need to clean up if + # something goes wrong so we can cleanly shutdown the homeserver. + exit.callback(super().shutdown) + # An error can be raised here if someone tried to create a `StoreKeyFetcher` + # before the homeserver is fully set up (`HomeServerNotSetupException: + # HomeServer.setup must be called before getting datastores`). self.store = hs.get_datastores().main - # `KeyFetcher` keeps a reference to `hs` which we need to clean up if something - # goes wrong so we can cleanly shutdown the homeserver. - # - # If something goes wrong while initializing the `KeyFetcher`, the caller won't - # be able to call shutdown on it because there won't be a reference to the - # `KeyFetcher`. - # - # An error can occur here if someone tries to create a `KeyFetcher` before the - # homeserver is fully set up (`HomeServerNotSetupException: HomeServer.setup - # must be called before getting datastores`). - except Exception: - self.shutdown() - raise + # We reached the end of the block whichs means everything was successful, so + # no exit handlers are needed (remove them all). + exit.pop_all() async def process_v2_response( self, from_server: str, response_json: JsonDict, time_added_ms: int From 25015f7ed4ec56eb521f3193bc5b3481aba82158 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 1 Dec 2025 15:57:42 -0600 Subject: [PATCH 09/12] `self._key_fetchers.clear()` See https://github.com/element-hq/synapse/pull/19187#discussion_r2571734949 --- synapse/crypto/keyring.py | 7 ++++--- tests/crypto/test_keyring.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index aad23ea3af8..4c4e7b3a0e8 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -22,7 +22,7 @@ import abc import logging from contextlib import ExitStack -from typing import TYPE_CHECKING, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Callable, Iterable import attr from signedjson.key import ( @@ -153,7 +153,7 @@ class Keyring: def __init__( self, hs: "HomeServer", - test_only_key_fetchers: "Sequence[KeyFetcher] | None" = None, + test_only_key_fetchers: "list[KeyFetcher] | None" = None, ): """ Args: @@ -164,7 +164,7 @@ def __init__( with ExitStack() as exit: self.server_name = hs.hostname - self._key_fetchers: Sequence[KeyFetcher] = [] + self._key_fetchers: list[KeyFetcher] = [] if test_only_key_fetchers is None: # Always fetch keys from the database. store_key_fetcher = StoreKeyFetcher(hs) @@ -227,6 +227,7 @@ def shutdown(self) -> None: for key_fetcher in self._key_fetchers: key_fetcher.shutdown() + self._key_fetchers.clear() async def verify_json_for_server( self, diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 17281d6ad56..3cc905f699a 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -95,7 +95,12 @@ def check_context( def test_verify_json_objects_for_server_awaits_previous_requests(self) -> None: mock_fetcher = Mock() mock_fetcher.get_keys = Mock() - kr = keyring.Keyring(self.hs, test_only_key_fetchers=(mock_fetcher,)) + kr = keyring.Keyring( + self.hs, + test_only_key_fetchers=[ + mock_fetcher, + ], + ) # a signed object that we are going to try to validate key1 = signedjson.key.generate_signing_key("1") @@ -286,7 +291,7 @@ async def get_keys( mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) kr = keyring.Keyring( - self.hs, test_only_key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) + self.hs, test_only_key_fetchers=[StoreKeyFetcher(self.hs), mock_fetcher] ) # sign the json @@ -313,7 +318,7 @@ async def get_keys( mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) - kr = keyring.Keyring(self.hs, test_only_key_fetchers=(mock_fetcher,)) + kr = keyring.Keyring(self.hs, test_only_key_fetchers=[mock_fetcher]) json1: JsonDict = {} signedjson.sign.sign_json(json1, "server1", key1) @@ -364,7 +369,7 @@ async def get_keys2( mock_fetcher2 = Mock() mock_fetcher2.get_keys = Mock(side_effect=get_keys2) kr = keyring.Keyring( - self.hs, test_only_key_fetchers=(mock_fetcher1, mock_fetcher2) + self.hs, test_only_key_fetchers=[mock_fetcher1, mock_fetcher2] ) json1: JsonDict = {} From 858c2bafe474a4d03e289e22eae8f1bb38c9a987 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 1 Dec 2025 15:58:34 -0600 Subject: [PATCH 10/12] Remove irrelevant comment --- synapse/crypto/keyring.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 4c4e7b3a0e8..d5d4942080a 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -219,9 +219,6 @@ def __init__( def shutdown(self) -> None: """ Prepares the KeyRing for garbage collection by shutting down it's queues. - - This needs to be robust enough to be called even if `__init__` failed partway - through. """ self._fetch_keys_queue.shutdown() From fe3a84a73f0054f0c4770df2ddfe7e5858a751e4 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 1 Dec 2025 15:59:29 -0600 Subject: [PATCH 11/12] Fix `which` typo --- synapse/crypto/keyring.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index d5d4942080a..712412a73c2 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -212,7 +212,7 @@ def __init__( valid_until_ts=2**63, # fake future timestamp ) - # We reached the end of the block whichs means everything was successful, so + # We reached the end of the block which means everything was successful, so # no exit handlers are needed (remove them all). exit.pop_all() @@ -556,7 +556,7 @@ def __init__(self, hs: "HomeServer"): # HomeServer.setup must be called before getting datastores`). self.store = hs.get_datastores().main - # We reached the end of the block whichs means everything was successful, so + # We reached the end of the block which means everything was successful, so # no exit handlers are needed (remove them all). exit.pop_all() @@ -589,7 +589,7 @@ def __init__(self, hs: "HomeServer"): # HomeServer.setup must be called before getting datastores`). self.store = hs.get_datastores().main - # We reached the end of the block whichs means everything was successful, so + # We reached the end of the block which means everything was successful, so # no exit handlers are needed (remove them all). exit.pop_all() From df97f113b2e7c4a791401a364d425039f21c38cd Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 1 Dec 2025 16:02:12 -0600 Subject: [PATCH 12/12] Explain `ExitStack` a little --- synapse/crypto/keyring.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 712412a73c2..883f682e776 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -161,6 +161,7 @@ def __init__( test_only_key_fetchers: Dependency injection for tests only. If provided, these key fetchers will be used instead of the default ones. """ + # Clean-up to avoid partial initialization leaving behind references. with ExitStack() as exit: self.server_name = hs.hostname @@ -545,6 +546,7 @@ class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" def __init__(self, hs: "HomeServer"): + # Clean-up to avoid partial initialization leaving behind references. with ExitStack() as exit: super().__init__(hs) # `KeyFetcher` keeps a reference to `hs` which we need to clean up if @@ -578,6 +580,7 @@ async def _fetch_keys( class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): + # Clean-up to avoid partial initialization leaving behind references. with ExitStack() as exit: super().__init__(hs) # `KeyFetcher` keeps a reference to `hs` which we need to clean up if