1414
1515import logging
1616import threading
17+ import weakref
1718from enum import Enum , auto
1819from typing import (
1920 TYPE_CHECKING ,
2324 Dict ,
2425 Iterable ,
2526 List ,
27+ MutableMapping ,
2628 Optional ,
2729 Set ,
2830 Tuple ,
@@ -248,6 +250,12 @@ def __init__(
248250 str , ObservableDeferred [Dict [str , EventCacheEntry ]]
249251 ] = {}
250252
253+ # We keep track of the events we have currently loaded in memory so that
254+ # we can reuse them even if they've been evicted from the cache. We only
255+ # track events that don't need redacting in here (as then we don't need
256+ # to track redaction status).
257+ self ._event_ref : MutableMapping [str , EventBase ] = weakref .WeakValueDictionary ()
258+
251259 self ._event_fetch_lock = threading .Condition ()
252260 self ._event_fetch_list : List [
253261 Tuple [Iterable [str ], "defer.Deferred[Dict[str, _EventRow]]" ]
@@ -723,6 +731,8 @@ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
723731
724732 def _invalidate_get_event_cache (self , event_id : str ) -> None :
725733 self ._get_event_cache .invalidate ((event_id ,))
734+ self ._event_ref .pop (event_id , None )
735+ self ._current_event_fetches .pop (event_id , None )
726736
727737 def _get_events_from_cache (
728738 self , events : Iterable [str ], update_metrics : bool = True
@@ -738,13 +748,30 @@ def _get_events_from_cache(
738748 event_map = {}
739749
740750 for event_id in events :
751+ # First check if it's in the event cache
741752 ret = self ._get_event_cache .get (
742753 (event_id ,), None , update_metrics = update_metrics
743754 )
744- if not ret :
755+ if ret :
756+ event_map [event_id ] = ret
745757 continue
746758
747- event_map [event_id ] = ret
759+ # Otherwise check if we still have the event in memory.
760+ event = self ._event_ref .get (event_id )
761+ if event :
762+ # Reconstruct an event cache entry
763+
764+ cache_entry = EventCacheEntry (
765+ event = event ,
766+ # We don't cache weakrefs to redacted events, so we know
767+ # this is None.
768+ redacted_event = None ,
769+ )
770+ event_map [event_id ] = cache_entry
771+
772+ # We add the entry back into the cache as we want to keep
773+ # recently queried events in the cache.
774+ self ._get_event_cache .set ((event_id ,), cache_entry )
748775
749776 return event_map
750777
@@ -1124,6 +1151,10 @@ async def _get_events_from_db(
11241151 self ._get_event_cache .set ((event_id ,), cache_entry )
11251152 result_map [event_id ] = cache_entry
11261153
1154+ if not redacted_event :
1155+ # We only cache references to unredacted events.
1156+ self ._event_ref [event_id ] = original_ev
1157+
11271158 return result_map
11281159
11291160 async def _enqueue_events (self , events : Collection [str ]) -> Dict [str , _EventRow ]:
0 commit comments