44# ------------------------------------
55from typing import TYPE_CHECKING
66
7+ from .base import AsyncCredentialBase
78from .._authn_client import AsyncAuthnClient
89from ..._base import ClientSecretCredentialBase , CertificateCredentialBase
910
1213 from azure .core .credentials import AccessToken
1314
1415
15- class ClientSecretCredential (ClientSecretCredentialBase ):
16+ class ClientSecretCredential (ClientSecretCredentialBase , AsyncCredentialBase ):
1617 """Authenticates as a service principal using a client ID and client secret.
1718
1819 :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
@@ -28,6 +29,15 @@ def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs:
2829 super (ClientSecretCredential , self ).__init__ (tenant_id , client_id , client_secret , ** kwargs )
2930 self ._client = AsyncAuthnClient (tenant = tenant_id , ** kwargs )
3031
32+ async def __aenter__ (self ):
33+ await self ._client .__aenter__ ()
34+ return self
35+
36+ async def close (self ):
37+ """Close the credential's transport session."""
38+
39+ await self ._client .__aexit__ ()
40+
3141 async def get_token (self , * scopes : str , ** kwargs : "Any" ) -> "AccessToken" : # pylint:disable=unused-argument
3242 """Asynchronously request an access token for `scopes`.
3343
@@ -44,7 +54,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
4454 return token # type: ignore
4555
4656
47- class CertificateCredential (CertificateCredentialBase ):
57+ class CertificateCredential (CertificateCredentialBase , AsyncCredentialBase ):
4858 """Authenticates as a service principal using a certificate.
4959
5060 :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
@@ -57,6 +67,15 @@ class CertificateCredential(CertificateCredentialBase):
5767 defines authorities for other clouds.
5868 """
5969
70+ async def __aenter__ (self ):
71+ await self ._client .__aenter__ ()
72+ return self
73+
74+ async def close (self ):
75+ """Close the credential's transport session."""
76+
77+ await self ._client .__aexit__ ()
78+
6079 async def get_token (self , * scopes : str , ** kwargs : "Any" ) -> "AccessToken" : # pylint:disable=unused-argument
6180 """Asynchronously request an access token for `scopes`.
6281
0 commit comments