1616import attr
1717from frozendict import frozendict
1818
19- from twisted .internet .defer import Deferred
20-
2119from synapse .appservice import ApplicationService
2220from synapse .events import EventBase
23- from synapse .logging .context import make_deferred_yieldable , run_in_background
2421from synapse .types import JsonDict , StateMap
2522
2623if TYPE_CHECKING :
@@ -132,8 +129,6 @@ def with_state(
132129 ) -> "EventContext" :
133130 return EventContext (
134131 storage = storage ,
135- current_state_ids = current_state_ids ,
136- prev_state_ids = prev_state_ids ,
137132 state_group = state_group ,
138133 state_group_before_event = state_group_before_event ,
139134 prev_group = prev_group ,
@@ -163,20 +158,7 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
163158 The serialized event.
164159 """
165160
166- # We don't serialize the full state dicts, instead they get pulled out
167- # of the DB on the other side. However, the other side can't figure out
168- # the prev_state_ids, so if we're a state event we include the event
169- # id that we replaced in the state.
170- if event .is_state ():
171- prev_state_ids = await self .get_prev_state_ids ()
172- prev_state_id = prev_state_ids .get ((event .type , event .state_key ))
173- else :
174- prev_state_id = None
175-
176161 return {
177- "prev_state_id" : prev_state_id ,
178- "event_type" : event .type ,
179- "event_state_key" : event .get_state_key (),
180162 "state_group" : self ._state_group ,
181163 "state_group_before_event" : self .state_group_before_event ,
182164 "rejected" : self .rejected ,
@@ -198,13 +180,10 @@ def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
198180 Returns:
199181 The event context.
200182 """
201- context = _AsyncEventContextImpl (
183+ context = EventContext (
202184 # We use the state_group and prev_state_id stuff to pull the
203185 # current_state_ids out of the DB and construct prev_state_ids.
204186 storage = storage ,
205- prev_state_id = input ["prev_state_id" ],
206- event_type = input ["event_type" ],
207- event_state_key = input ["event_state_key" ],
208187 state_group = input ["state_group" ],
209188 state_group_before_event = input ["state_group_before_event" ],
210189 prev_group = input ["prev_group" ],
@@ -255,8 +234,10 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
255234 if self .rejected :
256235 raise RuntimeError ("Attempt to access state_ids of rejected event" )
257236
258- await self ._ensure_fetched ()
259- return self ._current_state_ids
237+ if self ._state_group is None :
238+ return None
239+
240+ return await self ._storage .state .get_state_ids_for_group (self ._state_group )
260241
261242 async def get_prev_state_ids (self ) -> StateMap [str ]:
262243 """
@@ -271,76 +252,10 @@ async def get_prev_state_ids(self) -> StateMap[str]:
271252 Maps a (type, state_key) to the event ID of the state event matching
272253 this tuple.
273254 """
274- await self ._ensure_fetched ()
275- # There *should* be previous state IDs now.
276- assert self ._prev_state_ids is not None
277- return self ._prev_state_ids
278-
279- async def _ensure_fetched (self ) -> None :
280- return None
281-
282-
283- @attr .s (slots = True )
284- class _AsyncEventContextImpl (EventContext ):
285- """
286- An implementation of EventContext which fetches _current_state_ids and
287- _prev_state_ids from the database on demand.
288-
289- Attributes:
290-
291- _storage
292-
293- _fetching_state_deferred: Resolves when *_state_ids have been calculated.
294- None if we haven't started calculating yet
295-
296- _event_type: The type of the event the context is associated with.
297-
298- _event_state_key: The state_key of the event the context is associated with.
299-
300- _prev_state_id: If the event associated with the context is a state event,
301- then `_prev_state_id` is the event_id of the state that was replaced.
302- """
303-
304- # This needs to have a default as we're inheriting
305- _storage : "Storage" = attr .ib (default = None )
306- _prev_state_id : Optional [str ] = attr .ib (default = None )
307- _event_type : str = attr .ib (default = None )
308- _event_state_key : Optional [str ] = attr .ib (default = None )
309- _fetching_state_deferred : Optional ["Deferred[None]" ] = attr .ib (default = None )
310-
311- async def _ensure_fetched (self ) -> None :
312- if not self ._fetching_state_deferred :
313- self ._fetching_state_deferred = run_in_background (self ._fill_out_state )
314-
315- await make_deferred_yieldable (self ._fetching_state_deferred )
316-
317- async def _fill_out_state (self ) -> None :
318- """Called to populate the _current_state_ids and _prev_state_ids
319- attributes by loading from the database.
320- """
321- if self .state_group is None :
322- # No state group means the event is an outlier. Usually the state_ids dicts are also
323- # pre-set to empty dicts, but they get reset when the context is serialized, so set
324- # them to empty dicts again here.
325- self ._current_state_ids = {}
326- self ._prev_state_ids = {}
327- return
328-
329- current_state_ids = await self ._storage .state .get_state_ids_for_group (
330- self .state_group
255+ assert self .state_group_before_event
256+ return await self ._storage .state .get_state_ids_for_group (
257+ self .state_group_before_event
331258 )
332- # Set this separately so mypy knows current_state_ids is not None.
333- self ._current_state_ids = current_state_ids
334- if self ._event_state_key is not None :
335- self ._prev_state_ids = dict (current_state_ids )
336-
337- key = (self ._event_type , self ._event_state_key )
338- if self ._prev_state_id :
339- self ._prev_state_ids [key ] = self ._prev_state_id
340- else :
341- self ._prev_state_ids .pop (key , None )
342- else :
343- self ._prev_state_ids = current_state_ids
344259
345260
346261def _encode_state_dict (
0 commit comments