Skip to content

Commit 8b40aae

Browse files
authored
Expose methods for closing async credential transport sessions (#9090)
1 parent d8a9ffd commit 8b40aae

34 files changed

Lines changed: 544 additions & 105 deletions
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
from typing import Any
9+
from typing_extensions import Protocol
10+
from .credentials import AccessToken
11+
12+
class AsyncTokenCredential(Protocol):
13+
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
14+
pass
15+
16+
async def close(self) -> None:
17+
pass
18+
19+
async def __aenter__(self):
20+
pass
21+
22+
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
23+
pass

sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@ class _BearerTokenCredentialPolicyBase(object):
2929
:param str scopes: Lets you specify the type of access needed.
3030
"""
3131

32-
def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
33-
# type: (TokenCredential, *str, Mapping[str, Any]) -> None
32+
def __init__(self, *scopes, **kwargs): # pylint:disable=unused-argument
33+
# type: (*str, **Any) -> None
3434
super(_BearerTokenCredentialPolicyBase, self).__init__()
3535
self._scopes = scopes
36-
self._credential = credential
3736
self._token = None # type: Optional[AccessToken]
3837

3938
@staticmethod
@@ -69,6 +68,11 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPo
6968
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
7069
"""
7170

71+
def __init__(self, credential, *scopes, **kwargs):
72+
# type: (TokenCredential, *str, **Any) -> None
73+
self._credential = credential
74+
super(BearerTokenCredentialPolicy, self).__init__(*scopes, **kwargs)
75+
7276
def on_request(self, request):
7377
# type: (PipelineRequest) -> None
7478
"""Adds a bearer token Authorization header to request and sends request to next policy.

sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,33 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
import threading
7+
from typing import TYPE_CHECKING
78

8-
from azure.core.pipeline import PipelineRequest
99
from azure.core.pipeline.policies import SansIOHTTPPolicy
1010
from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase
1111

12+
if TYPE_CHECKING:
13+
# pylint:disable=unused-import
14+
from typing import Any
15+
from azure.core.credentials_async import AsyncTokenCredential
16+
from azure.core.pipeline import PipelineRequest
17+
1218

1319
class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
1420
# pylint:disable=too-few-public-methods
1521
"""Adds a bearer token Authorization header to requests.
1622
1723
:param credential: The credential.
18-
:type credential: ~azure.core.credentials.TokenCredential
24+
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
1925
:param str scopes: Lets you specify the type of access needed.
2026
"""
2127

22-
def __init__(self, credential, *scopes, **kwargs):
23-
super().__init__(credential, *scopes, **kwargs)
28+
def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any") -> None:
29+
self._credential = credential
2430
self._lock = threading.Lock()
31+
super().__init__(*scopes, **kwargs)
2532

26-
async def on_request(self, request: PipelineRequest):
33+
async def on_request(self, request: "PipelineRequest"):
2734
"""Adds a bearer token Authorization header to request and sends request to next policy.
2835
2936
:param request: The pipeline request object to be modified.

sdk/identity/azure-identity/HISTORY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
- All credential pipelines include `ProxyPolicy`
66
([#8945](https://github.com/Azure/azure-sdk-for-python/pull/8945))
7+
- Async credentials are async context managers and have an async `close` method
8+
([#9090](https://github.com/Azure/azure-sdk-for-python/pull/9090))
79

810

911
## 1.1.0 (2019-11-27)

sdk/identity/azure-identity/README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,24 @@ async transport, such as [aiohttp](https://pypi.org/project/aiohttp/). See
213213
[azure-core documentation](../../core/azure-core/README.md#transport)
214214
for more information.
215215

216+
Async credentials should be closed when they're no longer needed. Each async
217+
credential is an async context manager and defines an async `close` method. For
218+
example:
219+
220+
```py
221+
from azure.identity.aio import DefaultAzureCredential
222+
223+
# call close when the credential is no longer needed
224+
credential = DefaultAzureCredential()
225+
...
226+
await credential.close()
227+
228+
# alternatively, use the credential as an async context manager
229+
credential = DefaultAzureCredential()
230+
async with credential:
231+
...
232+
```
233+
216234
This example demonstrates authenticating the asynchronous `SecretClient` from
217235
[azure-keyvault-secrets][azure_keyvault_secrets] with an asynchronous
218236
credential.

sdk/identity/azure-identity/azure/identity/_credentials/chained.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@
1515
from azure.core.credentials import AccessToken, TokenCredential
1616

1717

18+
def _get_error_message(history):
19+
attempts = []
20+
for credential, error in history:
21+
if error:
22+
attempts.append("{}: {}".format(credential.__class__.__name__, error))
23+
else:
24+
attempts.append(credential.__class__.__name__)
25+
return """No credential in this chain provided a token.
26+
Attempted credentials:\n\t{}""".format(
27+
"\n\t".join(attempts)
28+
)
29+
30+
1831
class ChainedTokenCredential(object):
1932
"""A sequence of credentials that is itself a credential.
2033
@@ -48,16 +61,5 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
4861
history.append((credential, ex.message))
4962
except Exception as ex: # pylint: disable=broad-except
5063
history.append((credential, str(ex)))
51-
error_message = self._get_error_message(history)
64+
error_message = _get_error_message(history)
5265
raise ClientAuthenticationError(message=error_message)
53-
54-
@staticmethod
55-
def _get_error_message(history):
56-
attempts = []
57-
for credential, error in history:
58-
if error:
59-
attempts.append("{}: {}".format(credential.__class__.__name__, error))
60-
else:
61-
attempts.append(credential.__class__.__name__)
62-
return """No credential in this chain provided a token.
63-
Attempted credentials:\n\t{}""".format("\n\t".join(attempts))

sdk/identity/azure-identity/azure/identity/aio/_authn_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ def __init__(
5454
self._pipeline = AsyncPipeline(transport=transport, policies=policies)
5555
super().__init__(**kwargs)
5656

57+
async def __aenter__(self):
58+
await self._pipeline.__aenter__()
59+
return self
60+
61+
async def __aexit__(self, *args):
62+
await self.close()
63+
64+
async def close(self) -> None:
65+
await self._pipeline.__aexit__()
66+
5767
async def request_token(
5868
self,
5969
scopes: "Iterable[str]",
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import abc
6+
7+
8+
class AsyncCredentialBase(abc.ABC):
9+
@abc.abstractmethod
10+
async def close(self):
11+
pass
12+
13+
async def __aenter__(self):
14+
return self
15+
16+
async def __aexit__(self, *args):
17+
await self.close()
18+
19+
@abc.abstractmethod
20+
async def get_token(self, *scopes, **kwargs):
21+
pass

sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,40 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5+
import asyncio
56
from typing import TYPE_CHECKING
67

78
from azure.core.exceptions import ClientAuthenticationError
8-
from ... import ChainedTokenCredential as SyncChainedTokenCredential
9+
from .base import AsyncCredentialBase
10+
from ..._credentials.chained import _get_error_message
911

1012
if TYPE_CHECKING:
1113
from typing import Any
1214
from azure.core.credentials import AccessToken
15+
from azure.core.credentials_async import AsyncTokenCredential
1316

1417

15-
class ChainedTokenCredential(SyncChainedTokenCredential):
18+
class ChainedTokenCredential(AsyncCredentialBase):
1619
"""A sequence of credentials that is itself a credential.
1720
1821
Its :func:`get_token` method calls ``get_token`` on each credential in the sequence, in order, returning the first
1922
valid token received.
2023
2124
:param credentials: credential instances to form the chain
22-
:type credentials: :class:`azure.core.credentials.TokenCredential`
25+
:type credentials: :class:`azure.core.credentials.AsyncTokenCredential`
2326
"""
2427

25-
async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
28+
def __init__(self, *credentials: "AsyncTokenCredential") -> None:
29+
if not credentials:
30+
raise ValueError("at least one credential is required")
31+
self.credentials = credentials
32+
33+
async def close(self):
34+
"""Close the transport sessions of all credentials in the chain."""
35+
36+
await asyncio.gather(*(credential.close() for credential in self.credentials))
37+
38+
async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
2639
"""Asynchronously request a token from each credential, in order, returning the first token received.
2740
2841
If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError`
@@ -41,5 +54,5 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
4154
history.append((credential, ex.message))
4255
except Exception as ex: # pylint: disable=broad-except
4356
history.append((credential, str(ex)))
44-
error_message = self._get_error_message(history)
57+
error_message = _get_error_message(history)
4558
raise ClientAuthenticationError(message=error_message)

sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# ------------------------------------
55
from typing import TYPE_CHECKING
66

7+
from .base import AsyncCredentialBase
78
from .._authn_client import AsyncAuthnClient
89
from ..._base import ClientSecretCredentialBase, CertificateCredentialBase
910

@@ -12,7 +13,7 @@
1213
from azure.core.credentials import AccessToken
1314

1415

15-
class ClientSecretCredential(ClientSecretCredentialBase):
16+
class ClientSecretCredential(ClientSecretCredentialBase, AsyncCredentialBase):
1617
"""Authenticates as a service principal using a client ID and client secret.
1718
1819
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
@@ -28,6 +29,15 @@ def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs:
2829
super(ClientSecretCredential, self).__init__(tenant_id, client_id, client_secret, **kwargs)
2930
self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs)
3031

32+
async def __aenter__(self):
33+
await self._client.__aenter__()
34+
return self
35+
36+
async def close(self):
37+
"""Close the credential's transport session."""
38+
39+
await self._client.__aexit__()
40+
3141
async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
3242
"""Asynchronously request an access token for `scopes`.
3343
@@ -44,7 +54,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
4454
return token # type: ignore
4555

4656

47-
class CertificateCredential(CertificateCredentialBase):
57+
class CertificateCredential(CertificateCredentialBase, AsyncCredentialBase):
4858
"""Authenticates as a service principal using a certificate.
4959
5060
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
@@ -57,6 +67,15 @@ class CertificateCredential(CertificateCredentialBase):
5767
defines authorities for other clouds.
5868
"""
5969

70+
async def __aenter__(self):
71+
await self._client.__aenter__()
72+
return self
73+
74+
async def close(self):
75+
"""Close the credential's transport session."""
76+
77+
await self._client.__aexit__()
78+
6079
async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
6180
"""Asynchronously request an access token for `scopes`.
6281

0 commit comments

Comments
 (0)