Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from azure.core.credentials import AccessToken
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpRequest
from ..._internal import AadClientCertificate

Policy = Union[AsyncHTTPPolicy, SansIOHTTPPolicy]
Expand Down Expand Up @@ -44,41 +45,31 @@ async def obtain_token_by_authorization_code(
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs
)
now = int(time.time())
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return await self._run_pipeline(request, **kwargs)

async def obtain_token_by_client_certificate(
self, scopes: "Iterable[str]", certificate: "AadClientCertificate", **kwargs: "Any"
) -> "AccessToken":
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return await self._run_pipeline(request, stream=False, **kwargs)

async def obtain_token_by_client_secret(
self, scopes: "Iterable[str]", secret: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_client_secret_request(scopes, secret, **kwargs)
now = int(time.time())
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return await self._run_pipeline(request, **kwargs)

async def obtain_token_by_jwt_assertion(
self, scopes: "Iterable[str]", assertion: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_jwt_assertion_request(scopes, assertion)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return await self._run_pipeline(request, stream=False, **kwargs)

async def obtain_token_by_refresh_token(
self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
now = int(time.time())
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return await self._run_pipeline(request, **kwargs)

async def obtain_token_on_behalf_of(
self,
Expand All @@ -90,10 +81,16 @@ async def obtain_token_on_behalf_of(
request = self._get_on_behalf_of_request(
scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs
)
now = int(time.time())
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return await self._run_pipeline(request, **kwargs)

# pylint:disable=no-self-use
def _build_pipeline(self, **kwargs: "Any") -> "AsyncPipeline":
return build_async_pipeline(**kwargs)

async def _run_pipeline(self, request: "HttpRequest", **kwargs: "Any") -> "AccessToken":
# remove tenant_id kwarg that could have been passed from credential's get_token method
# tenant_id is already part of `request` at this point
kwargs.pop("tenant_id", None)
now = int(time.time())
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity import ClientSecretCredential, TokenCachePersistenceOptions
from azure.identity._enums import RegionalAuthority
Expand Down Expand Up @@ -207,7 +206,10 @@ def test_multitenant_authentication():
second_tenant = "second-tenant"
second_token = first_token * 2

def send(request, **_):
def send(request, **kwargs):
with pytest.raises(KeyError):
kwargs["tenant_id"]
Comment thread
mccoyp marked this conversation as resolved.
Outdated

parsed = urlparse(request.url)
tenant = parsed.path.split("/")[1]
assert tenant in (first_tenant, second_tenant, "common"), 'unexpected tenant "{}"'.format(tenant)
Expand All @@ -233,6 +235,18 @@ def send(request, **_):
token = credential.get_token("scope")
assert token.token == first_token


def test_live_multitenant_authentication(live_service_principal):
# first create a credential with a non-existent tenant
credential = ClientSecretCredential(
"...", live_service_principal["client_id"], live_service_principal["client_secret"]
)
# then get a valid token for an actual tenant
token = credential.get_token("https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"])
assert token.token
assert token.expires_on


def test_multitenant_authentication_not_allowed():
expected_tenant = "expected-tenant"
expected_token = "***"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from urllib.parse import urlparse

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity import TokenCachePersistenceOptions
from azure.identity._constants import EnvironmentVariables
Expand Down Expand Up @@ -257,10 +256,14 @@ async def test_multitenant_authentication():
second_tenant = "second-tenant"
second_token = first_token * 2

async def send(request, **_):
async def send(request, **kwargs):
with pytest.raises(KeyError):
kwargs["tenant_id"]

parsed = urlparse(request.url)
tenant = parsed.path.split("/")[1]
assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant)

token = first_token if tenant == first_tenant else second_token
return mock_response(json_payload=build_aad_response(access_token=token))

Expand All @@ -280,6 +283,21 @@ async def send(request, **_):
token = await credential.get_token("scope")
assert token.token == first_token


@pytest.mark.asyncio
async def test_live_multitenant_authentication(live_service_principal):
Comment thread
mccoyp marked this conversation as resolved.
# first create a credential with a non-existent tenant
credential = ClientSecretCredential(
"...", live_service_principal["client_id"], live_service_principal["client_secret"]
)
# then get a valid token for an actual tenant
token = await credential.get_token(
"https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"]
)
assert token.token
assert token.expires_on


@pytest.mark.asyncio
async def test_multitenant_authentication_not_allowed():
expected_tenant = "expected-tenant"
Expand Down