diff --git a/sdk/identity/azure-identity/azure/identity/_authn_client.py b/sdk/identity/azure-identity/azure/identity/_authn_client.py index d07b0cdff082..04b43a2b0618 100644 --- a/sdk/identity/azure-identity/azure/identity/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/_authn_client.py @@ -18,10 +18,10 @@ NetworkTraceLoggingPolicy, ProxyPolicy, RetryPolicy, - DistributedTracingPolicy + DistributedTracingPolicy, ) from azure.core.pipeline.transport import RequestsTransport, HttpRequest -from azure.identity._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities +from ._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities try: ABC = abc.ABC @@ -98,6 +98,7 @@ def request_token(self, scopes, method, headers, form_data, params, **kwargs): @abc.abstractmethod def obtain_token_by_refresh_token(self, scopes, username): + # type: (Iterable[str], Optional[str]) -> AccessToken pass def _deserialize_and_cache_token(self, response, scopes, request_time): @@ -214,22 +215,45 @@ def request_token( token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time) return token - def obtain_token_by_refresh_token(self, scopes, username): - # type: (Iterable[str], str) -> Optional[AccessToken] - """Acquire an access token using a cached refresh token. Returns ``None`` when that fails, or the cache has no - refresh token. This is only used by SharedTokenCacheCredential and isn't robust enough for anything else.""" + def obtain_token_by_refresh_token(self, scopes, username=None): + # type: (Iterable[str], Optional[str]) -> AccessToken + """Acquire an access token using a cached refresh token. Raises ClientAuthenticationError if that fails. + This is only used by SharedTokenCacheCredential and isn't robust enough for anything else.""" + + # if an username is provided, restrict our search to accounts that have that username + query = {"username": username} if username else {} + accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query=query) + + # if more than one account was returned, ensure that that they all have the same home_account_id. If so, + # we'll treat them as equal, otherwise we can't know which one to pick, so we'll raise an error. + if len(accounts) > 1 and len({account.get("home_account_id") for account in accounts}) != 1: + if username: + message = ( + "Multiple entries found for user '{}' were found in the shared token cache. " + "This is not currently supported by SharedTokenCacheCredential." + ).format(username) + else: + # TODO: we could identify usernames associated with exactly one home account id + message = ( + "Multiple users were discovered in the shared token cache. If using DefaultAzureCredential, set " + "the AZURE_USERNAME environment variable to the preferred username. Otherwise, specify it when " + "constructing SharedTokenCacheCredential." + "\nDiscovered accounts: {}" + ).format(", ".join({account.get("username") for account in accounts})) + raise ClientAuthenticationError(message=message) - # find account matching username - accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query={"username": username}) for account in accounts: - # try each refresh token that might work, return the first access token acquired - for token in self.get_refresh_tokens(scopes, account): - # currently we only support login.microsoftonline.com, which has an alias login.windows.net - # TODO: this must change to support sovereign clouds - environment = account.get("environment") - if not environment or (environment not in self._auth_url and environment != "login.windows.net"): + # ensure the account is associated with the token authority we expect to use + # ('environment' is an authority e.g. 'login.microsoftonline.com') + environment = account.get("environment") + if not environment or environment not in self._auth_url: + # doubtful this account can get the access token we want but public cloud's a special case + # because its authority has an alias: for our purposes login.windows.net = login.microsoftonline.com + if not (environment == "login.windows.net" and KnownAuthorities.AZURE_PUBLIC_CLOUD in self._auth_url): continue + # try each refresh token, returning the first access token acquired + for token in self.get_refresh_tokens(scopes, account): request = self.get_refresh_token_grant_request(token, scopes) request_time = int(time.time()) response = self._pipeline.run(request, stream=False) @@ -240,7 +264,11 @@ def obtain_token_by_refresh_token(self, scopes, username): except ClientAuthenticationError: continue - return None + message = "No cached token found" + if username: + message += " for '{}'".format(username) + + raise ClientAuthenticationError(message=message) @staticmethod def _create_config(**kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index 784a08c66f51..5f0dfefce730 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -20,9 +20,9 @@ class DefaultAzureCredential(ChainedTokenCredential): 1. A service principal configured by environment variables. See :class:`~azure.identity.EnvironmentCredential` for more details. 2. An Azure managed identity. See :class:`~azure.identity.ManagedIdentityCredential` for more details. - 3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. This requires a - value for the environment variable ``AZURE_USERNAME``. See :class:`~azure.identity.SharedTokenCacheCredential` - for more details. + 3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. If multiple + identities are in the cache, then the value of the environment variable ``AZURE_USERNAME`` is used to select + which identity to use. See :class:`~azure.identity.SharedTokenCacheCredential` for more details. Keyword arguments - **authority** (str): Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', @@ -34,15 +34,8 @@ def __init__(self, **kwargs): authority = kwargs.pop("authority", None) credentials = [EnvironmentCredential(authority=authority, **kwargs), ManagedIdentityCredential(**kwargs)] - # SharedTokenCacheCredential is part of the default only on supported platforms, when $AZURE_USERNAME has a - # value (because the cache may contain tokens for multiple identities and we can only choose one arbitrarily - # without more information from the user), and when $AZURE_PASSWORD has no value (because when $AZURE_USERNAME - # and $AZURE_PASSWORD are set, EnvironmentCredential will be used instead) - if ( - SharedTokenCacheCredential.supported() - and EnvironmentVariables.AZURE_USERNAME in os.environ - and EnvironmentVariables.AZURE_PASSWORD not in os.environ - ): + # SharedTokenCacheCredential is part of the default only on supported platforms. + if SharedTokenCacheCredential.supported(): credentials.append( SharedTokenCacheCredential( username=os.environ.get(EnvironmentVariables.AZURE_USERNAME), authority=authority, **kwargs diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/user.py b/sdk/identity/azure-identity/azure/identity/_credentials/user.py index ac987fc6b242..a83c32dc889c 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/user.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/user.py @@ -123,8 +123,8 @@ class SharedTokenCacheCredential(object): defines authorities for other clouds. """ - def __init__(self, username, **kwargs): # pylint:disable=unused-argument - # type: (str, **Any) -> None + def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument + # type: (Optional[str], **Any) -> None self._username = username @@ -161,11 +161,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument if not self._client: raise ClientAuthenticationError(message="Shared token cache unavailable") - token = self._client.obtain_token_by_refresh_token(scopes, self._username) - if not token: - raise ClientAuthenticationError(message="No cached token found for '{}'".format(self._username)) - - return token + return self._client.obtain_token_by_refresh_token(scopes, self._username) @staticmethod def supported(): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py index 0670dae23296..10ee8ee94026 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py @@ -21,6 +21,7 @@ from azure.core.pipeline.transport import AioHttpTransport from .._authn_client import AuthnClientBase +from .._constants import KnownAuthorities if TYPE_CHECKING: from typing import Any, Dict, Iterable, Mapping, Optional @@ -67,21 +68,46 @@ async def request_token( token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time) return token - async def obtain_token_by_refresh_token(self, scopes: "Iterable[str]", username: str) -> "Optional[AccessToken]": - """Acquire an access token using a cached refresh token. Returns ``None`` when that fails, or the cache has no - refresh token. This is only used by SharedTokenCacheCredential and isn't robust enough for anything else.""" + async def obtain_token_by_refresh_token( + self, scopes: "Iterable[str]", username: "Optional[str]" = None + ) -> "AccessToken": + """Acquire an access token using a cached refresh token. Raises ClientAuthenticationError if that fails. + This is only used by SharedTokenCacheCredential and isn't robust enough for anything else.""" + + # if an username is provided, restrict our search to accounts that have that username + query = {"username": username} if username else {} + accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query=query) + + # if more than one account was returned, ensure that that they all have the same home_account_id. If so, + # we'll treat them as equal, otherwise we can't know which one to pick, so we'll raise an error. + if len(accounts) > 1 and len({account.get("home_account_id") for account in accounts}) != 1: + if username: + message = ( + "Multiple entries found for user '{}' were found in the shared token cache. " + "This is not currently supported by SharedTokenCacheCredential" + ).format(username) + else: + # TODO: we could identify usernames associated with exactly one home account id + message = ( + "Multiple users were discovered in the shared token cache. If using DefaultAzureCredential, set " + "the AZURE_USERNAME environment variable to the preferred username. Otherwise, specify it when " + "constructing SharedTokenCacheCredential." + "\nDiscovered accounts: {}" + ).format(", ".join({account.get("username") for account in accounts})) + raise ClientAuthenticationError(message=message) - # find account matching username - accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query={"username": username}) for account in accounts: - # try each refresh token that might work, return the first access token acquired - for token in self.get_refresh_tokens(scopes, account): - # currently we only support login.microsoftonline.com, which has an alias login.windows.net - # TODO: this must change to support sovereign clouds - environment = account.get("environment") - if not environment or (environment not in self._auth_url and environment != "login.windows.net"): + # ensure the account is associated with the token authority we expect to use + # ('environment' is an authority e.g. 'login.microsoftonline.com') + environment = account.get("environment") + if not environment or environment not in self._auth_url: + # doubtful this account can get the access token we want but public cloud's a special case + # because its authority has an alias: for our purposes login.windows.net = login.microsoftonline.com + if not (environment == "login.windows.net" and KnownAuthorities.AZURE_PUBLIC_CLOUD in self._auth_url): continue + # try each refresh token, returning the first access token acquired + for token in self.get_refresh_tokens(scopes, account): request = self.get_refresh_token_grant_request(token, scopes) request_time = int(time.time()) response = await self._pipeline.run(request, stream=False) @@ -92,7 +118,11 @@ async def obtain_token_by_refresh_token(self, scopes: "Iterable[str]", username: except ClientAuthenticationError: continue - return None + message = "No cached token found" + if username: + message += " for '{}'".format(username) + + raise ClientAuthenticationError(message=message) @staticmethod def _create_config(**kwargs: "Any") -> Configuration: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index 8e64efe6e7b9..f3ea1b039f51 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -20,9 +20,9 @@ class DefaultAzureCredential(ChainedTokenCredential): 1. A service principal configured by environment variables. See :class:`~azure.identity.aio.EnvironmentCredential` for more details. 2. An Azure managed identity. See :class:`~azure.identity.aio.ManagedIdentityCredential` for more details. - 3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. This requires a - value for the environment variable ``AZURE_USERNAME``. See - :class:`~azure.identity.aio.SharedTokenCacheCredential` for more details. + 3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. If multiple + identities are in the cache, then the value of the environment variable ``AZURE_USERNAME`` is used to select + which identity to use. See :class:`~azure.identity.aio.SharedTokenCacheCredential` for more details. Keyword arguments - **authority** (str): Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', @@ -34,15 +34,8 @@ def __init__(self, **kwargs): authority = kwargs.pop("authority", None) credentials = [EnvironmentCredential(authority=authority, **kwargs), ManagedIdentityCredential(**kwargs)] - # SharedTokenCacheCredential is part of the default only on supported platforms, when $AZURE_USERNAME has a - # value (because the cache may contain tokens for multiple identities and we can only choose one arbitrarily - # without more information from the user), and when $AZURE_PASSWORD has no value (because when $AZURE_USERNAME - # and $AZURE_PASSWORD are set, EnvironmentCredential will be used instead) - if ( - SharedTokenCacheCredential.supported() - and EnvironmentVariables.AZURE_USERNAME in os.environ - and EnvironmentVariables.AZURE_PASSWORD not in os.environ - ): + # SharedTokenCacheCredential is part of the default only on supported platforms. + if SharedTokenCacheCredential.supported(): credentials.append( SharedTokenCacheCredential( username=os.environ.get(EnvironmentVariables.AZURE_USERNAME), authority=authority, **kwargs diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/user.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/user.py index 7296b1ef424c..c28d4b501864 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/user.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/user.py @@ -45,11 +45,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py if not self._client: raise ClientAuthenticationError(message="Shared token cache unavailable") - token = await self._client.obtain_token_by_refresh_token(scopes, self._username) - if not token: - raise ClientAuthenticationError(message="No cached token found for '{}'".format(self._username)) - - return token + return await self._client.obtain_token_by_refresh_token(scopes, self._username) @staticmethod def _get_auth_client(cache: "msal_extensions.FileTokenCache") -> "AuthnClientBase": diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index 1f2057d1c21b..c1907d704967 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -247,37 +247,14 @@ def test_default_credential_shared_cache_use(mock_credential): assert mock_credential.supported.call_count == 1 mock_credential.supported.reset_mock() - # unsupported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential shouldn't use shared cache - credential = DefaultAzureCredential() - assert mock_credential.call_count == 0 - assert mock_credential.supported.call_count == 1 - mock_credential.supported = Mock(return_value=True) - # supported platform, $AZURE_USERNAME not set -> default credential shouldn't use shared cache + # supported platform -> default credential should use shared cache credential = DefaultAzureCredential() - assert mock_credential.call_count == 0 + assert mock_credential.call_count == 1 assert mock_credential.supported.call_count == 1 mock_credential.supported.reset_mock() - # supported platform, $AZURE_USERNAME and $AZURE_PASSWORD set -> default credential shouldn't use shared cache - # (EnvironmentCredential should be used when both variables are set) - with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com", "AZURE_PASSWORD": "***"}): - credential = DefaultAzureCredential() - assert mock_credential.call_count == 0 - - # supported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential should use shared cache - with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com"}): - expected_token = AccessToken("***", 42) - mock_credential.return_value = Mock(get_token=lambda *_: expected_token) - - credential = DefaultAzureCredential() - assert mock_credential.call_count == 1 - - token = credential.get_token("scope") - assert token == expected_token - - def test_device_code_credential(): expected_token = "access-token" user_code = "user-code" diff --git a/sdk/identity/azure-identity/tests/test_identity_async.py b/sdk/identity/azure-identity/tests/test_identity_async.py index 7223dcf0470f..4bf12c2d6d00 100644 --- a/sdk/identity/azure-identity/tests/test_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_identity_async.py @@ -296,32 +296,10 @@ async def test_default_credential_shared_cache_use(): assert mock_credential.supported.call_count == 1 mock_credential.supported.reset_mock() - # unsupported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential shouldn't use shared cache - credential = DefaultAzureCredential() - assert mock_credential.call_count == 0 - assert mock_credential.supported.call_count == 1 - mock_credential.supported = Mock(return_value=True) - # supported platform, $AZURE_USERNAME not set -> default credential shouldn't use shared cache + # supported platform -> default credential should use shared cache credential = DefaultAzureCredential() - assert mock_credential.call_count == 0 + assert mock_credential.call_count == 1 assert mock_credential.supported.call_count == 1 mock_credential.supported.reset_mock() - - # supported platform, $AZURE_USERNAME and $AZURE_PASSWORD set -> default credential shouldn't use shared cache - # (EnvironmentCredential should be used when both variables are set) - with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com", "AZURE_PASSWORD": "***"}): - credential = DefaultAzureCredential() - assert mock_credential.call_count == 0 - - # supported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential should use shared cache - with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com"}): - expected_token = AccessToken("***", 42) - mock_credential.return_value = Mock(get_token=asyncio.coroutine(lambda *_: expected_token)) - - credential = DefaultAzureCredential() - assert mock_credential.call_count == 1 - - token = await credential.get_token("scope") - assert token == expected_token