11# Copyright 2014-2016 OpenMarket Ltd
2+ # Copyright 2022 The Matrix.org Foundation C.I.C.
23#
34# Licensed under the Apache License, Version 2.0 (the "License");
45# you may not use this file except in compliance with the License.
1516from typing import (
1617 TYPE_CHECKING ,
1718 Awaitable ,
19+ Callable ,
1820 Collection ,
1921 Dict ,
2022 Iterable ,
@@ -530,6 +532,40 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter":
530532 new_all , new_excludes , new_wildcards , new_concrete_keys
531533 )
532534
535+ def must_await_full_state (self , is_mine_id : Callable [[str ], bool ]) -> bool :
536+ """Check if we need to wait for full state to complete to calculate this state
537+
538+ If we have a state filter which is completely satisfied even with partial
539+ state, then we don't need to await_full_state before we can return it.
540+
541+ Args:
542+ is_mine_id: a callable which confirms if a given state_key matches a mxid
543+ of a local user
544+ """
545+
546+ # XXX: can we be certain that the state at an event never changes (only gets
547+ # enlarged)?
548+
549+ # if we haven't requested membership events, then it depends on the value of
550+ # 'include_others'
551+ if EventTypes .Member not in self .types :
552+ return self .include_others
553+
554+ # if we're looking for *all* membership events, then we have to wait
555+ member_state_keys = self .types [EventTypes .Member ]
556+ if member_state_keys is None :
557+ return True
558+
559+ # otherwise, consider whose membership we are looking for. If it's entirely
560+ # local users, then we don't need to wait.
561+ for state_key in member_state_keys :
562+ if not is_mine_id (state_key ):
563+ # remote user
564+ return True
565+
566+ # local users only
567+ return False
568+
533569
534570_ALL_STATE_FILTER = StateFilter (types = frozendict (), include_others = True )
535571_ALL_NON_MEMBER_STATE_FILTER = StateFilter (
@@ -542,6 +578,7 @@ class StateGroupStorage:
542578 """High level interface to fetching state for event."""
543579
544580 def __init__ (self , hs : "HomeServer" , stores : "Databases" ):
581+ self ._is_mine_id = hs .is_mine_id
545582 self .stores = stores
546583 self ._partial_state_events_tracker = PartialStateEventsTracker (stores .main )
547584
@@ -673,7 +710,13 @@ async def get_state_for_events(
673710 RuntimeError if we don't have a state group for one or more of the events
674711 (ie they are outliers or unknown)
675712 """
676- event_to_groups = await self ._get_state_group_for_events (event_ids )
713+ await_full_state = True
714+ if state_filter and not state_filter .must_await_full_state (self ._is_mine_id ):
715+ await_full_state = False
716+
717+ event_to_groups = await self ._get_state_group_for_events (
718+ event_ids , await_full_state = await_full_state
719+ )
677720
678721 groups = set (event_to_groups .values ())
679722 group_to_state = await self .stores .state ._get_state_for_groups (
@@ -697,7 +740,9 @@ async def get_state_for_events(
697740 return {event : event_to_state [event ] for event in event_ids }
698741
699742 async def get_state_ids_for_events (
700- self , event_ids : Collection [str ], state_filter : Optional [StateFilter ] = None
743+ self ,
744+ event_ids : Collection [str ],
745+ state_filter : Optional [StateFilter ] = None ,
701746 ) -> Dict [str , StateMap [str ]]:
702747 """
703748 Get the state dicts corresponding to a list of events, containing the event_ids
@@ -714,7 +759,13 @@ async def get_state_ids_for_events(
714759 RuntimeError if we don't have a state group for one or more of the events
715760 (ie they are outliers or unknown)
716761 """
717- event_to_groups = await self ._get_state_group_for_events (event_ids )
762+ await_full_state = True
763+ if state_filter and not state_filter .must_await_full_state (self ._is_mine_id ):
764+ await_full_state = False
765+
766+ event_to_groups = await self ._get_state_group_for_events (
767+ event_ids , await_full_state = await_full_state
768+ )
718769
719770 groups = set (event_to_groups .values ())
720771 group_to_state = await self .stores .state ._get_state_for_groups (
@@ -800,7 +851,7 @@ async def _get_state_group_for_events(
800851 Args:
801852 event_ids: events to get state groups for
802853 await_full_state: if true, will block if we do not yet have complete
803- state at this event .
854+ state at these events .
804855 """
805856 if await_full_state :
806857 await self ._partial_state_events_tracker .await_full_state (event_ids )
0 commit comments