Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 603b781

Browse files
committed
Skip waiting for full state if a StateFilter does not require it
If `StateFilter` specifies a state set which we will have regardless of state-syncing, then we may as well return it immediately.
1 parent f5668f0 commit 603b781

2 files changed

Lines changed: 56 additions & 4 deletions

File tree

changelog.d/12498.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Preparation for faster-room-join work: return subsets of room state which we already have, immediately.

synapse/storage/state.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright 2014-2016 OpenMarket Ltd
2+
# Copyright 2022 The Matrix.org Foundation C.I.C.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");
45
# you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
1516
from typing import (
1617
TYPE_CHECKING,
1718
Awaitable,
19+
Callable,
1820
Collection,
1921
Dict,
2022
Iterable,
@@ -530,6 +532,40 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter":
530532
new_all, new_excludes, new_wildcards, new_concrete_keys
531533
)
532534

535+
def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
536+
"""Check if we need to wait for full state to complete to calculate this state
537+
538+
If we have a state filter which is completely satisfied even with partial
539+
state, then we don't need to await_full_state before we can return it.
540+
541+
Args:
542+
is_mine_id: a callable which confirms if a given state_key matches a mxid
543+
of a local user
544+
"""
545+
546+
# XXX: can we be certain that the state at an event never changes (only gets
547+
# enlarged)?
548+
549+
# if we haven't requested membership events, then it depends on the value of
550+
# 'include_others'
551+
if EventTypes.Member not in self.types:
552+
return self.include_others
553+
554+
# if we're looking for *all* membership events, then we have to wait
555+
member_state_keys = self.types[EventTypes.Member]
556+
if member_state_keys is None:
557+
return True
558+
559+
# otherwise, consider whose membership we are looking for. If it's entirely
560+
# local users, then we don't need to wait.
561+
for state_key in member_state_keys:
562+
if not is_mine_id(state_key):
563+
# remote user
564+
return True
565+
566+
# local users only
567+
return False
568+
533569

534570
_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
535571
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
@@ -542,6 +578,7 @@ class StateGroupStorage:
542578
"""High level interface to fetching state for event."""
543579

544580
def __init__(self, hs: "HomeServer", stores: "Databases"):
581+
self._is_mine_id = hs.is_mine_id
545582
self.stores = stores
546583
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
547584

@@ -673,7 +710,13 @@ async def get_state_for_events(
673710
RuntimeError if we don't have a state group for one or more of the events
674711
(ie they are outliers or unknown)
675712
"""
676-
event_to_groups = await self._get_state_group_for_events(event_ids)
713+
await_full_state = True
714+
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
715+
await_full_state = False
716+
717+
event_to_groups = await self._get_state_group_for_events(
718+
event_ids, await_full_state=await_full_state
719+
)
677720

678721
groups = set(event_to_groups.values())
679722
group_to_state = await self.stores.state._get_state_for_groups(
@@ -697,7 +740,9 @@ async def get_state_for_events(
697740
return {event: event_to_state[event] for event in event_ids}
698741

699742
async def get_state_ids_for_events(
700-
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
743+
self,
744+
event_ids: Collection[str],
745+
state_filter: Optional[StateFilter] = None,
701746
) -> Dict[str, StateMap[str]]:
702747
"""
703748
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -714,7 +759,13 @@ async def get_state_ids_for_events(
714759
RuntimeError if we don't have a state group for one or more of the events
715760
(ie they are outliers or unknown)
716761
"""
717-
event_to_groups = await self._get_state_group_for_events(event_ids)
762+
await_full_state = True
763+
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
764+
await_full_state = False
765+
766+
event_to_groups = await self._get_state_group_for_events(
767+
event_ids, await_full_state=await_full_state
768+
)
718769

719770
groups = set(event_to_groups.values())
720771
group_to_state = await self.stores.state._get_state_for_groups(
@@ -800,7 +851,7 @@ async def _get_state_group_for_events(
800851
Args:
801852
event_ids: events to get state groups for
802853
await_full_state: if true, will block if we do not yet have complete
803-
state at this event.
854+
state at these events.
804855
"""
805856
if await_full_state:
806857
await self._partial_state_events_tracker.await_full_state(event_ids)

0 commit comments

Comments
 (0)