Skip to content

Commit 97b602e

Browse files
lmazuelxiangyan99
andauthored
Add bearer token provider (#32655)
* Add bearer token provider * Only creates the policy once * Bump azure-core for typing * black * Revert "black" This reverts commit 6454f84. * black * Feedback --------- Co-authored-by: xiangyan99 <xiangsjtu@gmail.com>
1 parent 5b48838 commit 97b602e

7 files changed

Lines changed: 142 additions & 1 deletion

File tree

sdk/identity/azure-identity/azure/identity/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
WorkloadIdentityCredential,
2929
)
3030
from ._persistent_cache import TokenCachePersistenceOptions
31+
from ._bearer_token_provider import get_bearer_token_provider
3132

3233

3334
__all__ = [
@@ -55,6 +56,7 @@
5556
"UsernamePasswordCredential",
5657
"VisualStudioCodeCredential",
5758
"WorkloadIdentityCredential",
59+
"get_bearer_token_provider",
5860
]
5961

6062
from ._version import VERSION
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
from typing import Callable
6+
7+
from azure.core.credentials import TokenCredential
8+
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
9+
from azure.core.pipeline import PipelineRequest, PipelineContext
10+
from azure.core.rest import HttpRequest
11+
12+
13+
def _make_request() -> PipelineRequest[HttpRequest]:
14+
return PipelineRequest(HttpRequest("CredentialWrapper", "https://fakeurl"), PipelineContext(None))
15+
16+
17+
def get_bearer_token_provider(credential: TokenCredential, *scopes: str) -> Callable[[], str]:
18+
"""Returns a callable that provides a bearer token.
19+
20+
It can be used for instance to write code like:
21+
22+
.. code-block:: python
23+
24+
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
25+
26+
credential = DefaultAzureCredential()
27+
bearer_token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")
28+
29+
# Usage
30+
request.headers["Authorization"] = "Bearer " + bearer_token_provider()
31+
32+
:param credential: The credential used to authenticate the request.
33+
:type credential: ~azure.core.credentials.TokenCredential
34+
:param str scopes: The scopes required for the bearer token.
35+
:rtype: callable
36+
:return: A callable that returns a bearer token.
37+
"""
38+
39+
policy = BearerTokenCredentialPolicy(credential, *scopes)
40+
41+
def wrapper() -> str:
42+
request = _make_request()
43+
policy.on_request(request)
44+
return request.http_request.headers["Authorization"][len("Bearer ") :]
45+
46+
return wrapper

sdk/identity/azure-identity/azure/identity/aio/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ClientAssertionCredential,
2222
WorkloadIdentityCredential,
2323
)
24+
from ._bearer_token_provider import get_bearer_token_provider
2425

2526

2627
__all__ = [
@@ -39,4 +40,5 @@
3940
"VisualStudioCodeCredential",
4041
"ClientAssertionCredential",
4142
"WorkloadIdentityCredential",
43+
"get_bearer_token_provider",
4244
]
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
from typing import Callable, Coroutine, Any
6+
7+
from azure.core.credentials_async import AsyncTokenCredential
8+
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
9+
from azure.core.pipeline import PipelineRequest, PipelineContext
10+
from azure.core.rest import HttpRequest
11+
12+
13+
def _make_request() -> PipelineRequest[HttpRequest]:
14+
return PipelineRequest(HttpRequest("CredentialWrapper", "https://fakeurl"), PipelineContext(None))
15+
16+
17+
def get_bearer_token_provider(credential: AsyncTokenCredential, *scopes: str) -> Callable[[], Coroutine[Any, Any, str]]:
18+
"""Returns a callable that provides a bearer token.
19+
20+
It can be used for instance to write code like:
21+
22+
.. code-block:: python
23+
24+
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
25+
26+
credential = DefaultAzureCredential()
27+
bearer_token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")
28+
29+
30+
# Usage
31+
request.headers["Authorization"] = "Bearer " + await bearer_token_provider()
32+
33+
:param credential: The credential used to authenticate the request.
34+
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
35+
:param str scopes: The scopes required for the bearer token.
36+
:rtype: coroutine
37+
:return: A coroutine that returns a bearer token.
38+
"""
39+
40+
policy = AsyncBearerTokenCredentialPolicy(credential, *scopes)
41+
42+
async def wrapper() -> str:
43+
request = _make_request()
44+
await policy.on_request(request)
45+
return request.http_request.headers["Authorization"][len("Bearer ") :]
46+
47+
return wrapper

sdk/identity/azure-identity/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"Programming Language :: Python :: 3.9",
4848
"Programming Language :: Python :: 3.10",
4949
"Programming Language :: Python :: 3.11",
50+
"Programming Language :: Python :: 3.12",
5051
"License :: OSI Approved :: MIT License",
5152
],
5253
zip_safe=False,
@@ -59,7 +60,7 @@
5960
),
6061
python_requires=">=3.7",
6162
install_requires=[
62-
"azure-core<2.0.0,>=1.11.0",
63+
"azure-core<2.0.0,>=1.23.0",
6364
"cryptography>=2.5",
6465
"msal<2.0.0,>=1.24.0",
6566
"msal-extensions<2.0.0,>=0.3.0",
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
6+
from azure.core.credentials import AccessToken
7+
from azure.identity import get_bearer_token_provider
8+
9+
10+
class MockCredential:
11+
def get_token(self, *scopes, **kwargs):
12+
assert len(scopes) == 1
13+
assert scopes[0] == "scope"
14+
return AccessToken("mock_token", 42)
15+
16+
17+
def test_get_bearer_token_provider():
18+
19+
func = get_bearer_token_provider(MockCredential(), "scope")
20+
assert func() == "mock_token"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
6+
from azure.core.credentials import AccessToken
7+
from azure.identity.aio import get_bearer_token_provider
8+
9+
import pytest
10+
11+
12+
class MockCredential:
13+
async def get_token(self, *scopes, **kwargs):
14+
assert len(scopes) == 1
15+
assert scopes[0] == "scope"
16+
return AccessToken("mock_token", 42)
17+
18+
19+
@pytest.mark.asyncio
20+
async def test_get_bearer_token_provider():
21+
22+
func = get_bearer_token_provider(MockCredential(), "scope")
23+
assert await func() == "mock_token"

0 commit comments

Comments
 (0)