|
20 | 20 |
|
21 | 21 | from .exception_wrapper import wrap_exceptions |
22 | 22 | from .msal_transport_adapter import MsalTransportAdapter |
23 | | -from .._exceptions import AuthenticationRequiredError |
| 23 | +from .._constants import KnownAuthorities |
| 24 | +from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError |
24 | 25 | from .._internal import get_default_authority, normalize_authority |
25 | 26 | from .._auth_record import AuthenticationRecord |
26 | 27 |
|
|
41 | 42 |
|
42 | 43 | _LOGGER = logging.getLogger(__name__) |
43 | 44 |
|
| 45 | +_DEFAULT_AUTHENTICATE_SCOPES = { |
| 46 | + "https://" + KnownAuthorities.AZURE_CHINA: ("https://management.core.chinacloudapi.cn//.default",), |
| 47 | + "https://" + KnownAuthorities.AZURE_GERMANY: ("https://management.core.cloudapi.de//.default",), |
| 48 | + "https://" + KnownAuthorities.AZURE_GOVERNMENT: ("https://management.core.usgovcloudapi.net//.default",), |
| 49 | + "https://" + KnownAuthorities.AZURE_PUBLIC_CLOUD: ("https://management.core.windows.net//.default",), |
| 50 | +} |
| 51 | + |
44 | 52 |
|
45 | 53 | def _decode_client_info(raw): |
46 | 54 | """Taken from msal.oauth2cli.oidc""" |
@@ -91,11 +99,9 @@ class MsalCredential(ABC): |
91 | 99 |
|
92 | 100 | def __init__(self, client_id, client_credential=None, **kwargs): |
93 | 101 | # type: (str, Optional[Union[str, Mapping[str, str]]], **Any) -> None |
94 | | - tenant_id = kwargs.pop("tenant_id", None) or "organizations" |
95 | 102 | authority = kwargs.pop("authority", None) |
96 | | - authority = normalize_authority(authority) if authority else get_default_authority() |
97 | | - |
98 | | - self._base_url = "/".join((authority, tenant_id.strip("/"))) |
| 103 | + self._authority = normalize_authority(authority) if authority else get_default_authority() |
| 104 | + self._tenant_id = kwargs.pop("tenant_id", None) or "organizations" |
99 | 105 |
|
100 | 106 | self._client_credential = client_credential |
101 | 107 | self._client_id = client_id |
@@ -130,7 +136,7 @@ def _create_app(self, cls): |
130 | 136 | app = cls( |
131 | 137 | client_id=self._client_id, |
132 | 138 | client_credential=self._client_credential, |
133 | | - authority=self._base_url, |
| 139 | + authority="{}/{}".format(self._authority, self._tenant_id), |
134 | 140 | token_cache=self._cache, |
135 | 141 | ) |
136 | 142 |
|
@@ -242,15 +248,24 @@ def authenticate(self, **kwargs): |
242 | 248 | # type: (**Any) -> AuthenticationRecord |
243 | 249 | """Interactively authenticate a user. |
244 | 250 |
|
245 | | - :keyword Sequence[str] scopes: optional scopes to request during authentication, such as those provided by |
| 251 | + :keyword Sequence[str] scopes: scopes to request during authentication, such as those provided by |
246 | 252 | :func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token |
247 | 253 | for these scopes. |
248 | 254 | :rtype: ~azure.identity.AuthenticationRecord |
249 | 255 | :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` |
250 | 256 | attribute gives a reason. |
251 | 257 | """ |
252 | 258 |
|
253 | | - scopes = kwargs.pop("scopes", None) or ("https://management.azure.com/.default",) |
| 259 | + scopes = kwargs.pop("scopes", None) |
| 260 | + if not scopes: |
| 261 | + if self._authority not in _DEFAULT_AUTHENTICATE_SCOPES: |
| 262 | + # the credential is configured to use a cloud whose ARM scope we can't determine |
| 263 | + raise CredentialUnavailableError( |
| 264 | + message="Authenticating in this environment requires a value for the 'scopes' keyword argument." |
| 265 | + ) |
| 266 | + |
| 267 | + scopes = _DEFAULT_AUTHENTICATE_SCOPES[self._authority] |
| 268 | + |
254 | 269 | _ = self.get_token(*scopes, _allow_prompt=True, **kwargs) |
255 | 270 | return self.authentication_record # type: ignore |
256 | 271 |
|
|
0 commit comments