1717from frozendict import frozendict
1818from typing_extensions import Literal
1919
20- from twisted .internet .defer import Deferred
21-
2220from synapse .appservice import ApplicationService
2321from synapse .events import EventBase
24- from synapse .logging .context import make_deferred_yieldable , run_in_background
2522from synapse .types import JsonDict , StateMap
2623
2724if TYPE_CHECKING :
@@ -61,6 +58,9 @@ class EventContext:
6158 If ``state_group`` is None (ie, the event is an outlier),
6259 ``state_group_before_event`` will always also be ``None``.
6360
61+ state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
62+ then this is the delta of the state between the two groups.
63+
6464 prev_group: If it is known, ``state_group``'s prev_group. Note that this being
6565 None does not necessarily mean that ``state_group`` does not have
6666 a prev_group!
@@ -79,73 +79,47 @@ class EventContext:
7979 app_service: If this event is being sent by a (local) application service, that
8080 app service.
8181
82- _current_state_ids: The room state map, including this event - ie, the state
83- in ``state_group``.
84-
85- (type, state_key) -> event_id
86-
87- For an outlier, this is {}
88-
89- Note that this is a private attribute: it should be accessed via
90- ``get_current_state_ids``. _AsyncEventContext impl calculates this
91- on-demand: it will be None until that happens.
92-
93- _prev_state_ids: The room state map, excluding this event - ie, the state
94- in ``state_group_before_event``. For a non-state
95- event, this will be the same as _current_state_events.
96-
97- Note that it is a completely different thing to prev_group!
98-
99- (type, state_key) -> event_id
100-
101- For an outlier, this is {}
102-
103- As with _current_state_ids, this is a private attribute. It should be
104- accessed via get_prev_state_ids.
105-
10682 partial_state: if True, we may be storing this event with a temporary,
10783 incomplete state.
10884 """
10985
86+ _storage : "Storage"
11087 rejected : Union [Literal [False ], str ] = False
11188 _state_group : Optional [int ] = None
11289 state_group_before_event : Optional [int ] = None
90+ _state_delta_due_to_event : Optional [StateMap [str ]] = None
11391 prev_group : Optional [int ] = None
11492 delta_ids : Optional [StateMap [str ]] = None
11593 app_service : Optional [ApplicationService ] = None
11694
117- _current_state_ids : Optional [StateMap [str ]] = None
118- _prev_state_ids : Optional [StateMap [str ]] = None
119-
12095 partial_state : bool = False
12196
12297 @staticmethod
12398 def with_state (
99+ storage : "Storage" ,
124100 state_group : Optional [int ],
125101 state_group_before_event : Optional [int ],
126- current_state_ids : Optional [StateMap [str ]],
127- prev_state_ids : Optional [StateMap [str ]],
102+ state_delta_due_to_event : Optional [StateMap [str ]],
128103 partial_state : bool ,
129104 prev_group : Optional [int ] = None ,
130105 delta_ids : Optional [StateMap [str ]] = None ,
131106 ) -> "EventContext" :
132107 return EventContext (
133- current_state_ids = current_state_ids ,
134- prev_state_ids = prev_state_ids ,
108+ storage = storage ,
135109 state_group = state_group ,
136110 state_group_before_event = state_group_before_event ,
111+ state_delta_due_to_event = state_delta_due_to_event ,
137112 prev_group = prev_group ,
138113 delta_ids = delta_ids ,
139114 partial_state = partial_state ,
140115 )
141116
142117 @staticmethod
143- def for_outlier () -> "EventContext" :
118+ def for_outlier (
119+ storage : "Storage" ,
120+ ) -> "EventContext" :
144121 """Return an EventContext instance suitable for persisting an outlier event"""
145- return EventContext (
146- current_state_ids = {},
147- prev_state_ids = {},
148- )
122+ return EventContext (storage = storage )
149123
150124 async def serialize (self , event : EventBase , store : "DataStore" ) -> JsonDict :
151125 """Converts self to a type that can be serialized as JSON, and then
@@ -158,24 +132,14 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
158132 The serialized event.
159133 """
160134
161- # We don't serialize the full state dicts, instead they get pulled out
162- # of the DB on the other side. However, the other side can't figure out
163- # the prev_state_ids, so if we're a state event we include the event
164- # id that we replaced in the state.
165- if event .is_state ():
166- prev_state_ids = await self .get_prev_state_ids ()
167- prev_state_id = prev_state_ids .get ((event .type , event .state_key ))
168- else :
169- prev_state_id = None
170-
171135 return {
172- "prev_state_id" : prev_state_id ,
173- "event_type" : event .type ,
174- "event_state_key" : event .get_state_key (),
175136 "state_group" : self ._state_group ,
176137 "state_group_before_event" : self .state_group_before_event ,
177138 "rejected" : self .rejected ,
178139 "prev_group" : self .prev_group ,
140+ "state_delta_due_to_event" : _encode_state_dict (
141+ self ._state_delta_due_to_event
142+ ),
179143 "delta_ids" : _encode_state_dict (self .delta_ids ),
180144 "app_service_id" : self .app_service .id if self .app_service else None ,
181145 "partial_state" : self .partial_state ,
@@ -193,16 +157,16 @@ def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
193157 Returns:
194158 The event context.
195159 """
196- context = _AsyncEventContextImpl (
160+ context = EventContext (
197161 # We use the state_group and prev_state_id stuff to pull the
198162 # current_state_ids out of the DB and construct prev_state_ids.
199163 storage = storage ,
200- prev_state_id = input ["prev_state_id" ],
201- event_type = input ["event_type" ],
202- event_state_key = input ["event_state_key" ],
203164 state_group = input ["state_group" ],
204165 state_group_before_event = input ["state_group_before_event" ],
205166 prev_group = input ["prev_group" ],
167+ state_delta_due_to_event = _decode_state_dict (
168+ input ["state_delta_due_to_event" ]
169+ ),
206170 delta_ids = _decode_state_dict (input ["delta_ids" ]),
207171 rejected = input ["rejected" ],
208172 partial_state = input .get ("partial_state" , False ),
@@ -250,8 +214,15 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
250214 if self .rejected :
251215 raise RuntimeError ("Attempt to access state_ids of rejected event" )
252216
253- await self ._ensure_fetched ()
254- return self ._current_state_ids
217+ assert self ._state_delta_due_to_event is not None
218+
219+ prev_state_ids = await self .get_prev_state_ids ()
220+
221+ if self ._state_delta_due_to_event :
222+ prev_state_ids = dict (prev_state_ids )
223+ prev_state_ids .update (self ._state_delta_due_to_event )
224+
225+ return prev_state_ids
255226
256227 async def get_prev_state_ids (self ) -> StateMap [str ]:
257228 """
@@ -266,94 +237,10 @@ async def get_prev_state_ids(self) -> StateMap[str]:
266237 Maps a (type, state_key) to the event ID of the state event matching
267238 this tuple.
268239 """
269- await self ._ensure_fetched ()
270- # There *should* be previous state IDs now.
271- assert self ._prev_state_ids is not None
272- return self ._prev_state_ids
273-
274- def get_cached_current_state_ids (self ) -> Optional [StateMap [str ]]:
275- """Gets the current state IDs if we have them already cached.
276-
277- It is an error to access this for a rejected event, since rejected state should
278- not make it into the room state. This method will raise an exception if
279- ``rejected`` is set.
280-
281- Returns:
282- Returns None if we haven't cached the state or if state_group is None
283- (which happens when the associated event is an outlier).
284-
285- Otherwise, returns the the current state IDs.
286- """
287- if self .rejected :
288- raise RuntimeError ("Attempt to access state_ids of rejected event" )
289-
290- return self ._current_state_ids
291-
292- async def _ensure_fetched (self ) -> None :
293- return None
294-
295-
296- @attr .s (slots = True )
297- class _AsyncEventContextImpl (EventContext ):
298- """
299- An implementation of EventContext which fetches _current_state_ids and
300- _prev_state_ids from the database on demand.
301-
302- Attributes:
303-
304- _storage
305-
306- _fetching_state_deferred: Resolves when *_state_ids have been calculated.
307- None if we haven't started calculating yet
308-
309- _event_type: The type of the event the context is associated with.
310-
311- _event_state_key: The state_key of the event the context is associated with.
312-
313- _prev_state_id: If the event associated with the context is a state event,
314- then `_prev_state_id` is the event_id of the state that was replaced.
315- """
316-
317- # This needs to have a default as we're inheriting
318- _storage : "Storage" = attr .ib (default = None )
319- _prev_state_id : Optional [str ] = attr .ib (default = None )
320- _event_type : str = attr .ib (default = None )
321- _event_state_key : Optional [str ] = attr .ib (default = None )
322- _fetching_state_deferred : Optional ["Deferred[None]" ] = attr .ib (default = None )
323-
324- async def _ensure_fetched (self ) -> None :
325- if not self ._fetching_state_deferred :
326- self ._fetching_state_deferred = run_in_background (self ._fill_out_state )
327-
328- await make_deferred_yieldable (self ._fetching_state_deferred )
329-
330- async def _fill_out_state (self ) -> None :
331- """Called to populate the _current_state_ids and _prev_state_ids
332- attributes by loading from the database.
333- """
334- if self .state_group is None :
335- # No state group means the event is an outlier. Usually the state_ids dicts are also
336- # pre-set to empty dicts, but they get reset when the context is serialized, so set
337- # them to empty dicts again here.
338- self ._current_state_ids = {}
339- self ._prev_state_ids = {}
340- return
341-
342- current_state_ids = await self ._storage .state .get_state_ids_for_group (
343- self .state_group
240+ assert self .state_group_before_event is not None
241+ return await self ._storage .state .get_state_ids_for_group (
242+ self .state_group_before_event
344243 )
345- # Set this separately so mypy knows current_state_ids is not None.
346- self ._current_state_ids = current_state_ids
347- if self ._event_state_key is not None :
348- self ._prev_state_ids = dict (current_state_ids )
349-
350- key = (self ._event_type , self ._event_state_key )
351- if self ._prev_state_id :
352- self ._prev_state_ids [key ] = self ._prev_state_id
353- else :
354- self ._prev_state_ids .pop (key , None )
355- else :
356- self ._prev_state_ids = current_state_ids
357244
358245
359246def _encode_state_dict (
0 commit comments