Skip to content
This repository was archived by the owner on May 24, 2022. It is now read-only.

Commit d7d0b1f

Browse files
Add dependable caching
1 parent 1e9e07f commit d7d0b1f

6 files changed

Lines changed: 190 additions & 24 deletions

File tree

aiodine/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
from .dependencies import call_resolved, depends
1+
from .dependencies import CACHE, call_resolved, depends
22

33
__version__ = "1.2.8"
44

5-
__all__ = ["__version__", "depends", "call_resolved"]
5+
cached = CACHE.cached
6+
7+
__all__ = ["__version__", "depends", "CACHE", "cached", "call_resolved"]

aiodine/dependencies.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,14 @@
88
asyncnullcontext,
99
is_async_context_manager,
1010
)
11+
from .models import Dependable, DependableFunc, DependablesCache, T
1112

12-
T = typing.TypeVar("T")
13-
DependableFunc = typing.Union[
14-
typing.Callable[..., typing.Awaitable[T]],
15-
typing.Callable[..., typing.AsyncContextManager[T]],
16-
]
1713

14+
def depends(func: DependableFunc[T], *, cached: bool = None) -> T:
15+
return typing.cast(T, Dependable(func, cached=cached))
1816

19-
def depends(func: DependableFunc[T]) -> T:
20-
return typing.cast(T, Dependable(func))
2117

22-
23-
class Dependable(typing.Generic[T]):
24-
def __init__(self, func: DependableFunc[T]):
25-
self.func = func
26-
27-
def __repr__(self) -> str:
28-
return f"{self.__class__.__name__}(func={self.func!r})"
18+
CACHE = DependablesCache()
2919

3020

3121
async def call_resolved(
@@ -46,13 +36,18 @@ async def call_resolved(
4636
bound = signature.bind_partial(*args, **kwargs)
4737
bound.apply_defaults()
4838

49-
for name, value in bound.arguments.items():
50-
if isinstance(value, Dependable):
51-
bound.arguments[name] = await call_resolved(
52-
value.func, __exit_stack__=exit_stack
53-
)
39+
for name, val in bound.arguments.items():
40+
if isinstance(val, Dependable):
41+
dependable = val
5442

55-
raw = func(*bound.args, **bound.kwargs)
43+
try:
44+
value = CACHE[dependable]
45+
except KeyError:
46+
value = await call_resolved(dependable.func, __exit_stack__=exit_stack)
47+
if CACHE.should_cache(dependable):
48+
CACHE[dependable] = value
49+
50+
bound.arguments[name] = value
5651

5752
ctx = (
5853
exit_stack
@@ -61,6 +56,8 @@ async def call_resolved(
6156
)
6257

6358
async with ctx:
59+
raw = func(*bound.args, **bound.kwargs)
60+
6461
if isinstance(raw, types.CoroutineType):
6562
return await raw
6663

aiodine/models.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import typing
2+
3+
T = typing.TypeVar("T")
4+
DependableFunc = typing.Union[
5+
typing.Callable[..., typing.Awaitable[T]],
6+
typing.Callable[..., typing.AsyncContextManager[T]],
7+
]
8+
F = typing.TypeVar("F", bound=DependableFunc)
9+
10+
11+
class Dependable(typing.Generic[T]):
12+
__slots__ = ("func", "cached")
13+
14+
def __init__(self, func: DependableFunc[T], cached: bool = None) -> None:
15+
self.func = func
16+
self.cached = cached
17+
18+
def __eq__(self, other: typing.Any) -> bool:
19+
if not isinstance(other, Dependable):
20+
return False
21+
return self.func == other.func and self.cached == other.cached
22+
23+
def __hash__(self) -> int:
24+
return hash((self.func, self.cached))
25+
26+
def __repr__(self) -> str:
27+
attrs = [f"func={self.func!r}"]
28+
if self.cached:
29+
attrs.append("cached")
30+
return f"{self.__class__.__name__}({', '.join(attrs)})"
31+
32+
33+
class DependablesCache(typing.Mapping[Dependable, typing.Any]):
34+
def __init__(self) -> None:
35+
self._last_id = 0
36+
self._cached_dependables: typing.Dict[Dependable, typing.Any] = {}
37+
self._cachable_funcs: typing.Set[DependableFunc] = set()
38+
39+
def __getitem__(self, dep: Dependable[T]) -> T:
40+
return self._cached_dependables[dep]
41+
42+
def __setitem__(self, dep: Dependable[T], value: T) -> None:
43+
self._cached_dependables[dep] = value
44+
45+
def __len__(self) -> int:
46+
return len(self._cached_dependables)
47+
48+
def __iter__(self) -> typing.Iterator[Dependable]:
49+
return iter(self._cached_dependables)
50+
51+
def should_cache(self, dep: Dependable[T]) -> bool:
52+
if dep.cached:
53+
return dep not in self
54+
if dep.cached is None:
55+
# 'cached=...' was not passed to 'depends()'.
56+
# => Should cache if the dependable function was decorated with '@cached'
57+
return dep.func in self._cachable_funcs
58+
return False
59+
60+
def cached(self, func: F) -> F:
61+
self._cachable_funcs.add(func)
62+
return func
63+
64+
def clear(self) -> None:
65+
self._cached_dependables = {}
66+
self._cachable_funcs = set()
67+
68+
69+
cache = DependablesCache()

tests/models/test_dependable.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@ def __repr__(self) -> str:
2525
(CowSay(), "Dependable(func=Cow says moo!)"),
2626
],
2727
)
28-
def test_dependable_repr(func: typing.Callable, output: str) -> None:
29-
dependable = aiodine.depends(func)
28+
@pytest.mark.parametrize("cached", (None, False, True))
29+
def test_dependable_repr(func: typing.Callable, cached: bool, output: str) -> None:
30+
kwargs = {"cached": cached} if cached is not None else {}
31+
dependable = aiodine.depends(func, **kwargs)
32+
if cached:
33+
output = f"{output.rstrip(')')}, cached)"
3034
assert repr(dependable) == output
35+
36+
37+
def test_dependable_equal() -> None:
38+
dependable = aiodine.depends(cowsay)
39+
assert dependable == dependable
40+
assert dependable == aiodine.depends(cowsay)
41+
assert dependable != "not a Dependable"
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
3+
import aiodine
4+
from aiodine.models import DependablesCache
5+
6+
7+
@pytest.mark.anyio
8+
async def test_dependables_cache() -> None:
9+
cache = DependablesCache()
10+
11+
@cache.cached
12+
async def moo() -> str:
13+
return "moo"
14+
15+
moo_dependable = aiodine.dependencies.Dependable(moo)
16+
17+
assert len(cache) == 0
18+
assert moo_dependable not in cache
19+
with pytest.raises(KeyError):
20+
cache[moo_dependable]
21+
22+
cache[moo_dependable] = await moo()
23+
24+
assert len(cache) == 1
25+
assert moo_dependable in cache
26+
assert moo not in cache
27+
assert list(cache) == [moo_dependable]
28+
assert cache[moo_dependable] == "moo"
29+
30+
cache.clear()
31+
assert len(cache) == 0
32+
assert moo_dependable not in cache

tests/test_caching.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import typing
2+
3+
import anyio
4+
import pytest
5+
6+
import aiodine
7+
from aiodine.compat import asynccontextmanager
8+
9+
10+
@pytest.mark.anyio
11+
@pytest.mark.parametrize("cached", (False, True))
12+
@pytest.mark.parametrize("forced_cached", (None, False, True))
13+
@pytest.mark.parametrize("use_context_manager", (False, True))
14+
async def test_cache_function(
15+
cached: bool, forced_cached: typing.Optional[bool], use_context_manager: bool
16+
) -> None:
17+
aiodine.CACHE.clear()
18+
count = 0
19+
20+
moo: aiodine.models.DependableFunc[str]
21+
22+
if use_context_manager:
23+
24+
@asynccontextmanager
25+
async def moo() -> typing.AsyncIterator[str]:
26+
nonlocal count
27+
count += 1
28+
yield "moo"
29+
30+
else:
31+
32+
async def moo() -> str:
33+
nonlocal count
34+
count += 1
35+
return "moo"
36+
37+
if cached:
38+
moo = aiodine.cached(moo)
39+
40+
kwargs = {"cached": forced_cached} if forced_cached is not None else {}
41+
42+
async def cowsay(moo: str = aiodine.depends(moo, **kwargs)) -> str:
43+
return f"Cow says {moo}"
44+
45+
assert await aiodine.call_resolved(cowsay) == "Cow says moo"
46+
assert count == 1
47+
48+
assert await aiodine.call_resolved(cowsay) == "Cow says moo"
49+
50+
if (cached or forced_cached) and forced_cached is not False:
51+
assert count == 1
52+
assert list(aiodine.CACHE) == [aiodine.depends(moo, **kwargs)]
53+
else:
54+
assert count == 2
55+
assert list(aiodine.CACHE) == []

0 commit comments

Comments
 (0)