3333 List ,
3434 Mapping ,
3535 Optional ,
36+ Protocol ,
3637 Sequence ,
3738 Tuple ,
3839 Type ,
4344from weakref import WeakValueDictionary
4445
4546import attr
47+ from typing_extensions import Concatenate , ParamSpec
4648
4749from twisted .internet import defer
4850from twisted .python .failure import Failure
4951
5052from synapse .logging .context import make_deferred_yieldable , preserve_fn
5153from synapse .util import unwrapFirstError
5254from synapse .util .async_helpers import delay_cancellation
55+ from synapse .util .caches import CacheManager
5356from synapse .util .caches .deferred_cache import DeferredCache
5457from synapse .util .caches .lrucache import LruCache
5558
@@ -183,6 +186,7 @@ def foo(self, key, cache_context):
183186
184187 Args:
185188 orig:
189+ cache_manager: The cache manager to handle metrics
186190 max_entries:
187191 num_args: number of positional arguments (excluding ``self`` and
188192 ``cache_context``) to use as cache keys. Defaults to all named
@@ -196,11 +200,14 @@ def foo(self, key, cache_context):
196200 prune_unread_entries: If True, cache entries that haven't been read recently
197201 will be evicted from the cache in the background. Set to False to opt-out
198202 of this behaviour.
203+ name: Will default to the `__name__` of the `orig` function.
199204 """
200205
201206 def __init__ (
202207 self ,
203208 orig : Callable [..., Any ],
209+ * ,
210+ cache_manager : CacheManager ,
204211 max_entries : int = 1000 ,
205212 num_args : Optional [int ] = None ,
206213 uncached_args : Optional [Collection [str ]] = None ,
@@ -217,6 +224,7 @@ def __init__(
217224 cache_context = cache_context ,
218225 name = name ,
219226 )
227+ self .cache_manager = cache_manager
220228
221229 if tree and self .num_args < 2 :
222230 raise RuntimeError (
@@ -233,6 +241,7 @@ def __get__(
233241 ) -> Callable [..., "defer.Deferred[Any]" ]:
234242 cache : DeferredCache [CacheKey , Any ] = DeferredCache (
235243 name = self .name ,
244+ cache_manager = self .cache_manager ,
236245 max_entries = self .max_entries ,
237246 tree = self .tree ,
238247 iterable = self .iterable ,
@@ -487,10 +496,12 @@ class _CachedFunctionDescriptor:
487496 iterable : bool
488497 prune_unread_entries : bool
489498 name : Optional [str ]
499+ cache_manager : CacheManager
490500
491501 def __call__ (self , orig : F ) -> CachedFunction [F ]:
492502 d = DeferredCacheDescriptor (
493503 orig ,
504+ cache_manager = self .cache_manager ,
494505 max_entries = self .max_entries ,
495506 num_args = self .num_args ,
496507 uncached_args = self .uncached_args ,
@@ -503,6 +514,15 @@ def __call__(self, orig: F) -> CachedFunction[F]:
503514 return cast (CachedFunction [F ], d )
504515
505516
517+ P = ParamSpec ("P" )
518+ R = TypeVar ("R" )
519+
520+
521+ class HasCacheManager (Protocol ):
522+ # Used to handle registering the caches
523+ cache_manager : CacheManager
524+
525+
506526def cached (
507527 * ,
508528 max_entries : int = 1000 ,
@@ -513,17 +533,55 @@ def cached(
513533 iterable : bool = False ,
514534 prune_unread_entries : bool = True ,
515535 name : Optional [str ] = None ,
516- ) -> _CachedFunctionDescriptor :
517- return _CachedFunctionDescriptor (
518- max_entries = max_entries ,
519- num_args = num_args ,
520- uncached_args = uncached_args ,
521- tree = tree ,
522- cache_context = cache_context ,
523- iterable = iterable ,
524- prune_unread_entries = prune_unread_entries ,
525- name = name ,
526- )
536+ ) -> Callable [[Callable [P , Awaitable [R ]]], Callable [P , Awaitable [R ]]]:
537+ """Decorate an async method with a `Measure` context manager.
538+
539+ The Measure is created using `self.cache_manager`; it should only be used to decorate
540+ methods in classes defining an instance-level `clock` attribute.
541+
542+ Usage:
543+
544+ @measure_func()
545+ async def foo(...):
546+ ...
547+
548+ Which is analogous to:
549+
550+ async def foo(...):
551+ with Measure(...):
552+ ...
553+
554+ """
555+
556+ def wrapper (
557+ func : Callable [Concatenate [HasCacheManager , P ], Awaitable [R ]],
558+ ) -> Callable [P , Awaitable [R ]]:
559+ # block_name = func.__name__ if name is None else name
560+
561+ @functools .wraps (func )
562+ async def cached_func (
563+ self : HasCacheManager , * args : P .args , ** kwargs : P .kwargs
564+ ) -> R :
565+ return _CachedFunctionDescriptor (
566+ max_entries = max_entries ,
567+ num_args = num_args ,
568+ uncached_args = uncached_args ,
569+ tree = tree ,
570+ cache_context = cache_context ,
571+ iterable = iterable ,
572+ prune_unread_entries = prune_unread_entries ,
573+ name = name ,
574+ # Grab this attribute from the instance
575+ cache_manager = self .cache_manager ,
576+ )
577+
578+ # There are some shenanigans here, because we're decorating a method but
579+ # explicitly making use of the `self` parameter. The key thing here is that the
580+ # return type within the return type for `measure_func` itself describes how the
581+ # decorated function will be called.
582+ return cached_func # type: ignore[return-value]
583+
584+ return wrapper # type: ignore[return-value]
527585
528586
529587@attr .s (auto_attribs = True , slots = True , frozen = True )
0 commit comments