11import asyncio
22import functools
3+ import sys
34import typing
5+ from types import TracebackType
6+
7+ if sys .version_info < (3 , 8 ): # pragma: no cover
8+ from typing_extensions import Protocol
9+ else : # pragma: no cover
10+ from typing import Protocol
411
512
613def is_async_callable (obj : typing .Any ) -> bool :
@@ -10,3 +17,58 @@ def is_async_callable(obj: typing.Any) -> bool:
1017 return asyncio .iscoroutinefunction (obj ) or (
1118 callable (obj ) and asyncio .iscoroutinefunction (obj .__call__ )
1219 )
20+
21+
22+ T_co = typing .TypeVar ("T_co" , covariant = True )
23+
24+
25+ # TODO: once 3.8 is the minimum supported version (27 Jun 2023)
26+ # this can just become
27+ # class AwaitableOrContextManager(
28+ # typing.Awaitable[T_co],
29+ # typing.AsyncContextManager[T_co],
30+ # typing.Protocol[T_co],
31+ # ):
32+ # pass
33+ class AwaitableOrContextManager (Protocol [T_co ]):
34+ def __await__ (self ) -> typing .Generator [typing .Any , None , T_co ]:
35+ ... # pragma: no cover
36+
37+ async def __aenter__ (self ) -> T_co :
38+ ... # pragma: no cover
39+
40+ async def __aexit__ (
41+ self ,
42+ __exc_type : typing .Optional [typing .Type [BaseException ]],
43+ __exc_value : typing .Optional [BaseException ],
44+ __traceback : typing .Optional [TracebackType ],
45+ ) -> typing .Union [bool , None ]:
46+ ... # pragma: no cover
47+
48+
49+ class SupportsAsyncClose (Protocol ):
50+ async def close (self ) -> None :
51+ ... # pragma: no cover
52+
53+
54+ SupportsAsyncCloseType = typing .TypeVar (
55+ "SupportsAsyncCloseType" , bound = SupportsAsyncClose , covariant = False
56+ )
57+
58+
59+ class AwaitableOrContextManagerWrapper (typing .Generic [SupportsAsyncCloseType ]):
60+ __slots__ = ("aw" , "entered" )
61+
62+ def __init__ (self , aw : typing .Awaitable [SupportsAsyncCloseType ]) -> None :
63+ self .aw = aw
64+
65+ def __await__ (self ) -> typing .Generator [typing .Any , None , SupportsAsyncCloseType ]:
66+ return self .aw .__await__ ()
67+
68+ async def __aenter__ (self ) -> SupportsAsyncCloseType :
69+ self .entered = await self .aw
70+ return self .entered
71+
72+ async def __aexit__ (self , * args : typing .Any ) -> typing .Union [None , bool ]:
73+ await self .entered .close ()
74+ return None
0 commit comments