11# Copyright 2014-2016 OpenMarket Ltd
2+ # Copyright 2022 The Matrix.org Foundation C.I.C.
23#
34# Licensed under the Apache License, Version 2.0 (the "License");
45# you may not use this file except in compliance with the License.
1516from typing import (
1617 TYPE_CHECKING ,
1718 Awaitable ,
19+ Callable ,
1820 Collection ,
1921 Dict ,
2022 Iterable ,
@@ -532,6 +534,40 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter":
532534 new_all , new_excludes , new_wildcards , new_concrete_keys
533535 )
534536
537+ def must_await_full_state (self , is_mine_id : Callable [[str ], bool ]) -> bool :
538+ """Check if we need to wait for full state to complete to calculate this state
539+
540+ If we have a state filter which is completely satisfied even with partial
541+ state, then we don't need to await_full_state before we can return it.
542+
543+ Args:
544+ is_mine_id: a callable which confirms if a given state_key matches a mxid
545+ of a local user
546+ """
547+
548+ # XXX: can we be certain that the state at an event never changes (only gets
549+ # enlarged)?
550+
551+ # if we haven't requested membership events, then it depends on the value of
552+ # 'include_others'
553+ if EventTypes .Member not in self .types :
554+ return self .include_others
555+
556+ # if we're looking for *all* membership events, then we have to wait
557+ member_state_keys = self .types [EventTypes .Member ]
558+ if member_state_keys is None :
559+ return True
560+
561+ # otherwise, consider whose membership we are looking for. If it's entirely
562+ # local users, then we don't need to wait.
563+ for state_key in member_state_keys :
564+ if not is_mine_id (state_key ):
565+ # remote user
566+ return True
567+
568+ # local users only
569+ return False
570+
535571
536572_ALL_STATE_FILTER = StateFilter (types = frozendict (), include_others = True )
537573_ALL_NON_MEMBER_STATE_FILTER = StateFilter (
@@ -544,6 +580,7 @@ class StateGroupStorage:
544580 """High level interface to fetching state for event."""
545581
546582 def __init__ (self , hs : "HomeServer" , stores : "Databases" ):
583+ self ._is_mine_id = hs .is_mine_id
547584 self .stores = stores
548585 self ._partial_state_events_tracker = PartialStateEventsTracker (stores .main )
549586
@@ -675,7 +712,13 @@ async def get_state_for_events(
675712 RuntimeError if we don't have a state group for one or more of the events
676713 (ie they are outliers or unknown)
677714 """
678- event_to_groups = await self ._get_state_group_for_events (event_ids )
715+ await_full_state = True
716+ if state_filter and not state_filter .must_await_full_state (self ._is_mine_id ):
717+ await_full_state = False
718+
719+ event_to_groups = await self ._get_state_group_for_events (
720+ event_ids , await_full_state = await_full_state
721+ )
679722
680723 groups = set (event_to_groups .values ())
681724 group_to_state = await self .stores .state ._get_state_for_groups (
@@ -699,7 +742,9 @@ async def get_state_for_events(
699742 return {event : event_to_state [event ] for event in event_ids }
700743
701744 async def get_state_ids_for_events (
702- self , event_ids : Collection [str ], state_filter : Optional [StateFilter ] = None
745+ self ,
746+ event_ids : Collection [str ],
747+ state_filter : Optional [StateFilter ] = None ,
703748 ) -> Dict [str , StateMap [str ]]:
704749 """
705750 Get the state dicts corresponding to a list of events, containing the event_ids
@@ -716,7 +761,13 @@ async def get_state_ids_for_events(
716761 RuntimeError if we don't have a state group for one or more of the events
717762 (ie they are outliers or unknown)
718763 """
719- event_to_groups = await self ._get_state_group_for_events (event_ids )
764+ await_full_state = True
765+ if state_filter and not state_filter .must_await_full_state (self ._is_mine_id ):
766+ await_full_state = False
767+
768+ event_to_groups = await self ._get_state_group_for_events (
769+ event_ids , await_full_state = await_full_state
770+ )
720771
721772 groups = set (event_to_groups .values ())
722773 group_to_state = await self .stores .state ._get_state_for_groups (
@@ -802,7 +853,7 @@ async def _get_state_group_for_events(
802853 Args:
803854 event_ids: events to get state groups for
804855 await_full_state: if true, will block if we do not yet have complete
805- state at this event .
856+ state at these events .
806857 """
807858 if await_full_state :
808859 await self ._partial_state_events_tracker .await_full_state (event_ids )
0 commit comments