Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit ee61e78

Browse files
committed
Always make the EventContext hit the DB
1 parent cca3711 commit ee61e78

2 files changed

Lines changed: 11 additions & 93 deletions

File tree

synapse/events/snapshot.py

Lines changed: 8 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,8 @@
1616
import attr
1717
from frozendict import frozendict
1818

19-
from twisted.internet.defer import Deferred
20-
2119
from synapse.appservice import ApplicationService
2220
from synapse.events import EventBase
23-
from synapse.logging.context import make_deferred_yieldable, run_in_background
2421
from synapse.types import JsonDict, StateMap
2522

2623
if TYPE_CHECKING:
@@ -132,8 +129,6 @@ def with_state(
132129
) -> "EventContext":
133130
return EventContext(
134131
storage=storage,
135-
current_state_ids=current_state_ids,
136-
prev_state_ids=prev_state_ids,
137132
state_group=state_group,
138133
state_group_before_event=state_group_before_event,
139134
prev_group=prev_group,
@@ -163,20 +158,7 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
163158
The serialized event.
164159
"""
165160

166-
# We don't serialize the full state dicts, instead they get pulled out
167-
# of the DB on the other side. However, the other side can't figure out
168-
# the prev_state_ids, so if we're a state event we include the event
169-
# id that we replaced in the state.
170-
if event.is_state():
171-
prev_state_ids = await self.get_prev_state_ids()
172-
prev_state_id = prev_state_ids.get((event.type, event.state_key))
173-
else:
174-
prev_state_id = None
175-
176161
return {
177-
"prev_state_id": prev_state_id,
178-
"event_type": event.type,
179-
"event_state_key": event.get_state_key(),
180162
"state_group": self._state_group,
181163
"state_group_before_event": self.state_group_before_event,
182164
"rejected": self.rejected,
@@ -198,13 +180,10 @@ def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
198180
Returns:
199181
The event context.
200182
"""
201-
context = _AsyncEventContextImpl(
183+
context = EventContext(
202184
# We use the state_group and prev_state_id stuff to pull the
203185
# current_state_ids out of the DB and construct prev_state_ids.
204186
storage=storage,
205-
prev_state_id=input["prev_state_id"],
206-
event_type=input["event_type"],
207-
event_state_key=input["event_state_key"],
208187
state_group=input["state_group"],
209188
state_group_before_event=input["state_group_before_event"],
210189
prev_group=input["prev_group"],
@@ -255,8 +234,10 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
255234
if self.rejected:
256235
raise RuntimeError("Attempt to access state_ids of rejected event")
257236

258-
await self._ensure_fetched()
259-
return self._current_state_ids
237+
if self._state_group is None:
238+
return None
239+
240+
return await self._storage.state.get_state_ids_for_group(self._state_group)
260241

261242
async def get_prev_state_ids(self) -> StateMap[str]:
262243
"""
@@ -271,76 +252,10 @@ async def get_prev_state_ids(self) -> StateMap[str]:
271252
Maps a (type, state_key) to the event ID of the state event matching
272253
this tuple.
273254
"""
274-
await self._ensure_fetched()
275-
# There *should* be previous state IDs now.
276-
assert self._prev_state_ids is not None
277-
return self._prev_state_ids
278-
279-
async def _ensure_fetched(self) -> None:
280-
return None
281-
282-
283-
@attr.s(slots=True)
284-
class _AsyncEventContextImpl(EventContext):
285-
"""
286-
An implementation of EventContext which fetches _current_state_ids and
287-
_prev_state_ids from the database on demand.
288-
289-
Attributes:
290-
291-
_storage
292-
293-
_fetching_state_deferred: Resolves when *_state_ids have been calculated.
294-
None if we haven't started calculating yet
295-
296-
_event_type: The type of the event the context is associated with.
297-
298-
_event_state_key: The state_key of the event the context is associated with.
299-
300-
_prev_state_id: If the event associated with the context is a state event,
301-
then `_prev_state_id` is the event_id of the state that was replaced.
302-
"""
303-
304-
# This needs to have a default as we're inheriting
305-
_storage: "Storage" = attr.ib(default=None)
306-
_prev_state_id: Optional[str] = attr.ib(default=None)
307-
_event_type: str = attr.ib(default=None)
308-
_event_state_key: Optional[str] = attr.ib(default=None)
309-
_fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
310-
311-
async def _ensure_fetched(self) -> None:
312-
if not self._fetching_state_deferred:
313-
self._fetching_state_deferred = run_in_background(self._fill_out_state)
314-
315-
await make_deferred_yieldable(self._fetching_state_deferred)
316-
317-
async def _fill_out_state(self) -> None:
318-
"""Called to populate the _current_state_ids and _prev_state_ids
319-
attributes by loading from the database.
320-
"""
321-
if self.state_group is None:
322-
# No state group means the event is an outlier. Usually the state_ids dicts are also
323-
# pre-set to empty dicts, but they get reset when the context is serialized, so set
324-
# them to empty dicts again here.
325-
self._current_state_ids = {}
326-
self._prev_state_ids = {}
327-
return
328-
329-
current_state_ids = await self._storage.state.get_state_ids_for_group(
330-
self.state_group
255+
assert self.state_group_before_event
256+
return await self._storage.state.get_state_ids_for_group(
257+
self.state_group_before_event
331258
)
332-
# Set this separately so mypy knows current_state_ids is not None.
333-
self._current_state_ids = current_state_ids
334-
if self._event_state_key is not None:
335-
self._prev_state_ids = dict(current_state_ids)
336-
337-
key = (self._event_type, self._event_state_key)
338-
if self._prev_state_id:
339-
self._prev_state_ids[key] = self._prev_state_id
340-
else:
341-
self._prev_state_ids.pop(key, None)
342-
else:
343-
self._prev_state_ids = current_state_ids
344259

345260

346261
def _encode_state_dict(

tests/test_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ async def get_state_groups_ids(self, room_id, event_ids):
8888

8989
return groups
9090

91+
async def get_state_ids_for_group(self, state_group):
92+
return self._group_to_state[state_group]
93+
9194
async def store_state_group(
9295
self, event_id, room_id, prev_group, delta_ids, current_state_ids
9396
):

0 commit comments

Comments
 (0)