Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 19d79b6

Browse files
authored
Refactor resolve_state_groups_for_events to not pull out full state when no state resolution happens. (#12775)
1 parent 3d8839c commit 19d79b6

5 files changed

Lines changed: 40 additions & 23 deletions

File tree

changelog.d/12775.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens.

synapse/state/__init__.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,6 @@ async def compute_event_context(
288288
#
289289
# first of all, figure out the state before the event
290290
#
291-
292291
if old_state:
293292
# if we're given the state before the event, then we use that
294293
state_ids_before_event: StateMap[str] = {
@@ -419,33 +418,37 @@ async def resolve_state_groups_for_events(
419418
"""
420419
logger.debug("resolve_state_groups event_ids %s", event_ids)
421420

422-
# map from state group id to the state in that state group (where
423-
# 'state' is a map from state key to event id)
424-
# dict[int, dict[(str, str), str]]
425-
state_groups_ids = await self.state_store.get_state_groups_ids(
426-
room_id, event_ids
427-
)
428-
429-
if len(state_groups_ids) == 0:
430-
return _StateCacheEntry(state={}, state_group=None)
431-
elif len(state_groups_ids) == 1:
432-
name, state_list = list(state_groups_ids.items()).pop()
421+
state_groups = await self.state_store.get_state_group_for_events(event_ids)
433422

434-
prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
423+
state_group_ids = state_groups.values()
435424

425+
# check if each event has same state group id, if so there's no state to resolve
426+
state_group_ids_set = set(state_group_ids)
427+
if len(state_group_ids_set) == 1:
428+
(state_group_id,) = state_group_ids_set
429+
state = await self.state_store.get_state_for_groups(state_group_ids_set)
430+
prev_group, delta_ids = await self.state_store.get_state_group_delta(
431+
state_group_id
432+
)
436433
return _StateCacheEntry(
437-
state=state_list,
438-
state_group=name,
434+
state=state[state_group_id],
435+
state_group=state_group_id,
439436
prev_group=prev_group,
440437
delta_ids=delta_ids,
441438
)
439+
elif len(state_group_ids_set) == 0:
440+
return _StateCacheEntry(state={}, state_group=None)
442441

443442
room_version = await self.store.get_room_version_id(room_id)
444443

444+
state_to_resolve = await self.state_store.get_state_for_groups(
445+
state_group_ids_set
446+
)
447+
445448
result = await self._state_resolution_handler.resolve_state_groups(
446449
room_id,
447450
room_version,
448-
state_groups_ids,
451+
state_to_resolve,
449452
None,
450453
state_res_store=StateResolutionStore(self.store),
451454
)

synapse/storage/databases/state/store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _get_state_for_group_using_cache(
189189
group: int,
190190
state_filter: StateFilter,
191191
) -> Tuple[MutableStateMap[str], bool]:
192-
"""Checks if group is in cache. See `_get_state_for_groups`
192+
"""Checks if group is in cache. See `get_state_for_groups`
193193
194194
Args:
195195
cache: the state group cache to use

synapse/storage/state.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ async def get_state_groups_ids(
586586
if not event_ids:
587587
return {}
588588

589-
event_to_groups = await self._get_state_group_for_events(event_ids)
589+
event_to_groups = await self.get_state_group_for_events(event_ids)
590590

591591
groups = set(event_to_groups.values())
592592
group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -602,7 +602,7 @@ async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
602602
Returns:
603603
Resolves to a map of (type, state_key) -> event_id
604604
"""
605-
group_to_state = await self._get_state_for_groups((state_group,))
605+
group_to_state = await self.get_state_for_groups((state_group,))
606606

607607
return group_to_state[state_group]
608608

@@ -675,7 +675,7 @@ async def get_state_for_events(
675675
RuntimeError if we don't have a state group for one or more of the events
676676
(ie they are outliers or unknown)
677677
"""
678-
event_to_groups = await self._get_state_group_for_events(event_ids)
678+
event_to_groups = await self.get_state_group_for_events(event_ids)
679679

680680
groups = set(event_to_groups.values())
681681
group_to_state = await self.stores.state._get_state_for_groups(
@@ -716,7 +716,7 @@ async def get_state_ids_for_events(
716716
RuntimeError if we don't have a state group for one or more of the events
717717
(ie they are outliers or unknown)
718718
"""
719-
event_to_groups = await self._get_state_group_for_events(event_ids)
719+
event_to_groups = await self.get_state_group_for_events(event_ids)
720720

721721
groups = set(event_to_groups.values())
722722
group_to_state = await self.stores.state._get_state_for_groups(
@@ -774,7 +774,7 @@ async def get_state_ids_for_event(
774774
)
775775
return state_map[event_id]
776776

777-
def _get_state_for_groups(
777+
def get_state_for_groups(
778778
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
779779
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
780780
"""Gets the state at each of a list of state groups, optionally
@@ -792,7 +792,7 @@ def _get_state_for_groups(
792792
groups, state_filter or StateFilter.all()
793793
)
794794

795-
async def _get_state_group_for_events(
795+
async def get_state_group_for_events(
796796
self,
797797
event_ids: Collection[str],
798798
await_full_state: bool = True,

tests/test_state.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,19 @@ def register_event_id_state_group(self, event_id, state_group):
129129
async def get_room_version_id(self, room_id):
130130
return RoomVersions.V1.identifier
131131

132+
async def get_state_group_for_events(self, event_ids):
133+
res = {}
134+
for event in event_ids:
135+
res[event] = self._event_to_state_group[event]
136+
return res
137+
138+
async def get_state_for_groups(self, groups):
139+
res = {}
140+
for group in groups:
141+
state = self._group_to_state[group]
142+
res[group] = state
143+
return res
144+
132145

133146
class DictObj(dict):
134147
def __init__(self, **kwargs):

0 commit comments

Comments
 (0)