From 13c5a3fae198c07055a15636dafe77f1b702e73f Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Mon, 23 Oct 2023 20:15:12 +0000 Subject: [PATCH 1/7] Add bearer token provider --- .../azure-identity/azure/identity/__init__.py | 2 + .../azure/identity/_bearer_token_provider.py | 45 ++++++++++++++++++ .../azure/identity/aio/__init__.py | 2 + .../identity/aio/_bearer_token_provider.py | 46 +++++++++++++++++++ .../tests/test_bearer_token_provider.py | 20 ++++++++ .../tests/test_bearer_token_provider_async.py | 23 ++++++++++ 6 files changed, 138 insertions(+) create mode 100644 sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py create mode 100644 sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py create mode 100644 sdk/identity/azure-identity/tests/test_bearer_token_provider.py create mode 100644 sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index 8030b55ee033..53f9e45c89ba 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -28,6 +28,7 @@ WorkloadIdentityCredential, ) from ._persistent_cache import TokenCachePersistenceOptions +from ._bearer_token_provider import get_bearer_token_provider __all__ = [ @@ -55,6 +56,7 @@ "UsernamePasswordCredential", "VisualStudioCodeCredential", "WorkloadIdentityCredential", + "get_bearer_token_provider", ] from ._version import VERSION diff --git a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py new file mode 100644 index 000000000000..cd2c86e15061 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py @@ -0,0 +1,45 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from typing import Callable + +from azure.core.credentials import TokenCredential +from azure.core.pipeline.policies import BearerTokenCredentialPolicy +from azure.core.pipeline import PipelineRequest, PipelineContext +from azure.core.rest import HttpRequest + + +def _make_request() -> PipelineRequest[HttpRequest]: + return PipelineRequest(HttpRequest("CredentialWrapper", "https://fakeurl"), PipelineContext(None)) + + +def get_bearer_token_provider(credential: TokenCredential, *scopes: str) -> Callable[[], str]: + """Returns a callable that provides a bearer token. + + It can be used for instance to write code like: + + .. code-block:: python + + from azure.identity import DefaultAzureCredential, get_bearer_token_provider + + credential = DefaultAzureCredential() + bearer_token_provider = get_bearer_token_provider(credential, "https://storage.azure.com/.default") + + # Usage + request.headers["Authorization"] = "Bearer " + bearer_token_provider() + + :param credential: The credential used to authenticate the request. + :type credential: ~azure.core.credentials.TokenCredential + :param str scopes: The scopes required for the bearer token. + :rtype: callable + :return: A callable that returns a bearer token. + """ + + def wrapper() -> str: + policy = BearerTokenCredentialPolicy(credential, *scopes) + request = _make_request() + policy.on_request(request) + return request.http_request.headers["Authorization"][len("Bearer ") :] + + return wrapper diff --git a/sdk/identity/azure-identity/azure/identity/aio/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/__init__.py index 3c891665e83e..c6d9763b0263 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/__init__.py @@ -21,6 +21,7 @@ ClientAssertionCredential, WorkloadIdentityCredential, ) +from ._bearer_token_provider import get_bearer_token_provider __all__ = [ @@ -39,4 +40,5 @@ "VisualStudioCodeCredential", "ClientAssertionCredential", "WorkloadIdentityCredential", + "get_bearer_token_provider", ] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py new file mode 100644 index 000000000000..9817ff56467c --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py @@ -0,0 +1,46 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from typing import Callable, Coroutine, Any + +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy +from azure.core.pipeline import PipelineRequest, PipelineContext +from azure.core.rest import HttpRequest + + +def _make_request() -> PipelineRequest[HttpRequest]: + return PipelineRequest(HttpRequest("CredentialWrapper", "https://fakeurl"), PipelineContext(None)) + + +def get_bearer_token_provider(credential: AsyncTokenCredential, *scopes: str) -> Callable[[], Coroutine[Any, Any, str]]: + """Returns a callable that provides a bearer token. + + It can be used for instance to write code like: + + .. code-block:: python + + from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider + + credential = DefaultAzureCredential() + bearer_token_provider = get_bearer_token_provider(credential, "https://storage.azure.com/.default") + + + # Usage + request.headers["Authorization"] = "Bearer " + await bearer_token_provider() + + :param credential: The credential used to authenticate the request. + :type credential: ~azure.core.credentials.TokenCredential + :param str scopes: The scopes required for the bearer token. + :rtype: coroutine + :return: A coroutine that returns a bearer token. + """ + + async def wrapper() -> str: + policy = AsyncBearerTokenCredentialPolicy(credential, *scopes) + request = _make_request() + await policy.on_request(request) + return request.http_request.headers["Authorization"][len("Bearer ") :] + + return wrapper diff --git a/sdk/identity/azure-identity/tests/test_bearer_token_provider.py b/sdk/identity/azure-identity/tests/test_bearer_token_provider.py new file mode 100644 index 000000000000..f20ce7ac1d88 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_bearer_token_provider.py @@ -0,0 +1,20 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +from azure.core.credentials import AccessToken +from azure.identity import get_bearer_token_provider + + +class MockCredential: + def get_token(self, *scopes, **kwargs): + assert len(scopes) == 1 + assert scopes[0] == "scope" + return AccessToken("mock_token", 42) + + +def test_get_bearer_token_provider(): + + func = get_bearer_token_provider(MockCredential(), "scope") + assert func() == "mock_token" diff --git a/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py b/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py new file mode 100644 index 000000000000..35a8db46457e --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py @@ -0,0 +1,23 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +from azure.core.credentials import AccessToken +from azure.identity.aio import get_bearer_token_provider + +import pytest + + +class MockCredential: + async def get_token(self, *scopes, **kwargs): + assert len(scopes) == 1 + assert scopes[0] == "scope" + return AccessToken("mock_token", 42) + + +@pytest.mark.asyncio +async def test_get_bearer_token_provider(): + + func = get_bearer_token_provider(MockCredential(), "scope") + assert await func() == "mock_token" From 9cf84305aa7a2ded01dae86ee68856547200104e Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Tue, 24 Oct 2023 16:38:47 +0000 Subject: [PATCH 2/7] Only creates the policy once --- .../azure-identity/azure/identity/_bearer_token_provider.py | 2 +- .../azure-identity/azure/identity/aio/_bearer_token_provider.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py index cd2c86e15061..728880f6fc77 100644 --- a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py +++ b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py @@ -36,8 +36,8 @@ def get_bearer_token_provider(credential: TokenCredential, *scopes: str) -> Call :return: A callable that returns a bearer token. """ + policy = BearerTokenCredentialPolicy(credential, *scopes) def wrapper() -> str: - policy = BearerTokenCredentialPolicy(credential, *scopes) request = _make_request() policy.on_request(request) return request.http_request.headers["Authorization"][len("Bearer ") :] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py index 9817ff56467c..a7ae5c97c153 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py @@ -37,8 +37,8 @@ def get_bearer_token_provider(credential: AsyncTokenCredential, *scopes: str) -> :return: A coroutine that returns a bearer token. """ + policy = AsyncBearerTokenCredentialPolicy(credential, *scopes) async def wrapper() -> str: - policy = AsyncBearerTokenCredentialPolicy(credential, *scopes) request = _make_request() await policy.on_request(request) return request.http_request.headers["Authorization"][len("Bearer ") :] From a092c0560bf2bf10163d7861505e28b87cc20068 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Tue, 24 Oct 2023 16:45:43 +0000 Subject: [PATCH 3/7] Bump azure-core for typing --- sdk/identity/azure-identity/setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/identity/azure-identity/setup.py b/sdk/identity/azure-identity/setup.py index defcbe08fbf9..e9f5eb9d3266 100644 --- a/sdk/identity/azure-identity/setup.py +++ b/sdk/identity/azure-identity/setup.py @@ -47,6 +47,7 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", ], zip_safe=False, @@ -59,7 +60,7 @@ ), python_requires=">=3.7", install_requires=[ - "azure-core<2.0.0,>=1.11.0", + "azure-core<2.0.0,>=1.23.0", "cryptography>=2.5", "msal<2.0.0,>=1.24.0", "msal-extensions<2.0.0,>=0.3.0", From 6454f841e7bb742f42478b72c31e875e5e770a4c Mon Sep 17 00:00:00 2001 From: xiangyan99 Date: Tue, 24 Oct 2023 14:38:00 -0700 Subject: [PATCH 4/7] black --- .../documents/aio/_search_client_async.py | 74 +++++-------------- 1 file changed, 17 insertions(+), 57 deletions(-) diff --git a/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py b/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py index e513d2cdd3d4..b3f4d487af30 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py +++ b/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py @@ -66,11 +66,7 @@ class SearchClient(HeadersMixin): _client: SearchIndexClient def __init__( - self, - endpoint: str, - index_name: str, - credential: Union[AzureKeyCredential, AsyncTokenCredential], - **kwargs: Any + self, endpoint: str, index_name: str, credential: Union[AzureKeyCredential, AsyncTokenCredential], **kwargs: Any ) -> None: self._api_version = kwargs.pop("api_version", DEFAULT_VERSION) self._index_documents_batch = IndexDocumentsBatch() @@ -89,9 +85,7 @@ def __init__( ) else: self._aad = True - authentication_policy = get_authentication_policy( - credential, audience=audience, is_async=True - ) + authentication_policy = get_authentication_policy(credential, audience=audience, is_async=True) self._client = SearchIndexClient( endpoint=endpoint, index_name=index_name, @@ -102,9 +96,7 @@ def __init__( ) def __repr__(self) -> str: - return "".format( - repr(self._endpoint), repr(self._index_name) - )[:1024] + return "".format(repr(self._endpoint), repr(self._index_name))[:1024] async def close(self) -> None: """Close the :class:`~azure.search.documents.aio.SearchClient` session. @@ -125,9 +117,7 @@ async def get_document_count(self, **kwargs: Any) -> int: return int(await self._client.documents.count(**kwargs)) @distributed_trace_async - async def get_document( - self, key: str, selected_fields: Optional[List[str]] = None, **kwargs: Any - ) -> Dict: + async def get_document(self, key: str, selected_fields: Optional[List[str]] = None, **kwargs: Any) -> Dict: """Retrieve a document from the Azure search index by its key. :param key: The primary key value for the document to retrieve @@ -147,9 +137,7 @@ async def get_document( :caption: Get a specific document from the search index. """ kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) - result = await self._client.documents.get( - key=key, selected_fields=selected_fields, **kwargs - ) + result = await self._client.documents.get(key=key, selected_fields=selected_fields, **kwargs) return cast(dict, result) @distributed_trace_async @@ -340,16 +328,8 @@ async def search( include_total_result_count = include_total_count filter_arg = filter search_fields_str = ",".join(search_fields) if search_fields else None - answers = ( - query_answer - if not query_answer_count - else "{}|count-{}".format(query_answer, query_answer_count) - ) - answers = ( - answers - if not query_answer_threshold - else "{}|threshold-{}".format(answers, query_answer_threshold) - ) + answers = query_answer if not query_answer_count else "{}|count-{}".format(query_answer, query_answer_count) + answers = answers if not query_answer_threshold else "{}|threshold-{}".format(answers, query_answer_threshold) captions = ( query_caption if not query_caption_highlight @@ -396,9 +376,7 @@ async def search( query.order_by(order_by) kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) kwargs["api_version"] = self._api_version - return AsyncSearchItemPaged( - self._client, query, kwargs, page_iterator_class=AsyncSearchPageIterator - ) + return AsyncSearchItemPaged(self._client, query, kwargs, page_iterator_class=AsyncSearchPageIterator) @distributed_trace_async async def suggest( @@ -481,9 +459,7 @@ async def suggest( query.order_by(order_by) kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) request = cast(SuggestRequest, query.request) - response = await self._client.documents.suggest_post( - suggest_request=request, **kwargs - ) + response = await self._client.documents.suggest_post(suggest_request=request, **kwargs) assert response.results is not None # Hint for mypy results = [r.as_dict() for r in response.results] return results @@ -560,17 +536,13 @@ async def autocomplete( kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) request = cast(AutocompleteRequest, query.request) - response = await self._client.documents.autocomplete_post( - autocomplete_request=request, **kwargs - ) + response = await self._client.documents.autocomplete_post(autocomplete_request=request, **kwargs) assert response.results is not None # Hint for mypy results = [r.as_dict() for r in response.results] return results # pylint:disable=client-method-missing-tracing-decorator-async - async def upload_documents( - self, documents: List[Dict], **kwargs: Any - ) -> List[IndexingResult]: + async def upload_documents(self, documents: List[Dict], **kwargs: Any) -> List[IndexingResult]: """Upload documents to the Azure search index. An upload action is similar to an "upsert" where the document will be @@ -599,9 +571,7 @@ async def upload_documents( return cast(List[IndexingResult], results) # pylint:disable=client-method-missing-tracing-decorator-async, delete-operation-wrong-return-type - async def delete_documents( - self, documents: List[Dict], **kwargs: Any - ) -> List[IndexingResult]: + async def delete_documents(self, documents: List[Dict], **kwargs: Any) -> List[IndexingResult]: """Delete documents from the Azure search index Delete removes the specified document from the index. Any field you @@ -635,9 +605,7 @@ async def delete_documents( return cast(List[IndexingResult], results) # pylint:disable=client-method-missing-tracing-decorator-async - async def merge_documents( - self, documents: List[Dict], **kwargs: Any - ) -> List[IndexingResult]: + async def merge_documents(self, documents: List[Dict], **kwargs: Any) -> List[IndexingResult]: """Merge documents in to existing documents in the Azure search index. Merge updates an existing document with the specified fields. If the @@ -667,9 +635,7 @@ async def merge_documents( return cast(List[IndexingResult], results) # pylint:disable=client-method-missing-tracing-decorator-async - async def merge_or_upload_documents( - self, documents: List[Dict], **kwargs: Any - ) -> List[IndexingResult]: + async def merge_or_upload_documents(self, documents: List[Dict], **kwargs: Any) -> List[IndexingResult]: """Merge documents in to existing documents in the Azure search index, or upload them if they do not yet exist. @@ -690,9 +656,7 @@ async def merge_or_upload_documents( return cast(List[IndexingResult], results) @distributed_trace_async - async def index_documents( - self, batch: IndexDocumentsBatch, **kwargs: Any - ) -> List[IndexingResult]: + async def index_documents(self, batch: IndexDocumentsBatch, **kwargs: Any) -> List[IndexingResult]: """Specify a document operations to perform as a batch. :param batch: A batch of document operations to perform. @@ -703,17 +667,13 @@ async def index_documents( """ return await self._index_documents_actions(actions=batch.actions, **kwargs) - async def _index_documents_actions( - self, actions: List[IndexAction], **kwargs: Any - ) -> List[IndexingResult]: + async def _index_documents_actions(self, actions: List[IndexAction], **kwargs: Any) -> List[IndexingResult]: error_map = {413: RequestEntityTooLargeError} kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) batch = IndexBatch(actions=actions) try: - batch_response = await self._client.documents.index( - batch=batch, error_map=error_map, **kwargs - ) + batch_response = await self._client.documents.index(batch=batch, error_map=error_map, **kwargs) return cast(List[IndexingResult], batch_response.results) except RequestEntityTooLargeError: if len(actions) == 1: From 68092342273505eb8ade2c2933b08d2d2695d577 Mon Sep 17 00:00:00 2001 From: xiangyan99 Date: Tue, 24 Oct 2023 14:42:17 -0700 Subject: [PATCH 5/7] Revert "black" This reverts commit 6454f841e7bb742f42478b72c31e875e5e770a4c. --- .../documents/aio/_search_client_async.py | 74 ++++++++++++++----- 1 file changed, 57 insertions(+), 17 deletions(-) diff --git a/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py b/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py index b3f4d487af30..e513d2cdd3d4 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py +++ b/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py @@ -66,7 +66,11 @@ class SearchClient(HeadersMixin): _client: SearchIndexClient def __init__( - self, endpoint: str, index_name: str, credential: Union[AzureKeyCredential, AsyncTokenCredential], **kwargs: Any + self, + endpoint: str, + index_name: str, + credential: Union[AzureKeyCredential, AsyncTokenCredential], + **kwargs: Any ) -> None: self._api_version = kwargs.pop("api_version", DEFAULT_VERSION) self._index_documents_batch = IndexDocumentsBatch() @@ -85,7 +89,9 @@ def __init__( ) else: self._aad = True - authentication_policy = get_authentication_policy(credential, audience=audience, is_async=True) + authentication_policy = get_authentication_policy( + credential, audience=audience, is_async=True + ) self._client = SearchIndexClient( endpoint=endpoint, index_name=index_name, @@ -96,7 +102,9 @@ def __init__( ) def __repr__(self) -> str: - return "".format(repr(self._endpoint), repr(self._index_name))[:1024] + return "".format( + repr(self._endpoint), repr(self._index_name) + )[:1024] async def close(self) -> None: """Close the :class:`~azure.search.documents.aio.SearchClient` session. @@ -117,7 +125,9 @@ async def get_document_count(self, **kwargs: Any) -> int: return int(await self._client.documents.count(**kwargs)) @distributed_trace_async - async def get_document(self, key: str, selected_fields: Optional[List[str]] = None, **kwargs: Any) -> Dict: + async def get_document( + self, key: str, selected_fields: Optional[List[str]] = None, **kwargs: Any + ) -> Dict: """Retrieve a document from the Azure search index by its key. :param key: The primary key value for the document to retrieve @@ -137,7 +147,9 @@ async def get_document(self, key: str, selected_fields: Optional[List[str]] = No :caption: Get a specific document from the search index. """ kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) - result = await self._client.documents.get(key=key, selected_fields=selected_fields, **kwargs) + result = await self._client.documents.get( + key=key, selected_fields=selected_fields, **kwargs + ) return cast(dict, result) @distributed_trace_async @@ -328,8 +340,16 @@ async def search( include_total_result_count = include_total_count filter_arg = filter search_fields_str = ",".join(search_fields) if search_fields else None - answers = query_answer if not query_answer_count else "{}|count-{}".format(query_answer, query_answer_count) - answers = answers if not query_answer_threshold else "{}|threshold-{}".format(answers, query_answer_threshold) + answers = ( + query_answer + if not query_answer_count + else "{}|count-{}".format(query_answer, query_answer_count) + ) + answers = ( + answers + if not query_answer_threshold + else "{}|threshold-{}".format(answers, query_answer_threshold) + ) captions = ( query_caption if not query_caption_highlight @@ -376,7 +396,9 @@ async def search( query.order_by(order_by) kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) kwargs["api_version"] = self._api_version - return AsyncSearchItemPaged(self._client, query, kwargs, page_iterator_class=AsyncSearchPageIterator) + return AsyncSearchItemPaged( + self._client, query, kwargs, page_iterator_class=AsyncSearchPageIterator + ) @distributed_trace_async async def suggest( @@ -459,7 +481,9 @@ async def suggest( query.order_by(order_by) kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) request = cast(SuggestRequest, query.request) - response = await self._client.documents.suggest_post(suggest_request=request, **kwargs) + response = await self._client.documents.suggest_post( + suggest_request=request, **kwargs + ) assert response.results is not None # Hint for mypy results = [r.as_dict() for r in response.results] return results @@ -536,13 +560,17 @@ async def autocomplete( kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) request = cast(AutocompleteRequest, query.request) - response = await self._client.documents.autocomplete_post(autocomplete_request=request, **kwargs) + response = await self._client.documents.autocomplete_post( + autocomplete_request=request, **kwargs + ) assert response.results is not None # Hint for mypy results = [r.as_dict() for r in response.results] return results # pylint:disable=client-method-missing-tracing-decorator-async - async def upload_documents(self, documents: List[Dict], **kwargs: Any) -> List[IndexingResult]: + async def upload_documents( + self, documents: List[Dict], **kwargs: Any + ) -> List[IndexingResult]: """Upload documents to the Azure search index. An upload action is similar to an "upsert" where the document will be @@ -571,7 +599,9 @@ async def upload_documents(self, documents: List[Dict], **kwargs: Any) -> List[I return cast(List[IndexingResult], results) # pylint:disable=client-method-missing-tracing-decorator-async, delete-operation-wrong-return-type - async def delete_documents(self, documents: List[Dict], **kwargs: Any) -> List[IndexingResult]: + async def delete_documents( + self, documents: List[Dict], **kwargs: Any + ) -> List[IndexingResult]: """Delete documents from the Azure search index Delete removes the specified document from the index. Any field you @@ -605,7 +635,9 @@ async def delete_documents(self, documents: List[Dict], **kwargs: Any) -> List[I return cast(List[IndexingResult], results) # pylint:disable=client-method-missing-tracing-decorator-async - async def merge_documents(self, documents: List[Dict], **kwargs: Any) -> List[IndexingResult]: + async def merge_documents( + self, documents: List[Dict], **kwargs: Any + ) -> List[IndexingResult]: """Merge documents in to existing documents in the Azure search index. Merge updates an existing document with the specified fields. If the @@ -635,7 +667,9 @@ async def merge_documents(self, documents: List[Dict], **kwargs: Any) -> List[In return cast(List[IndexingResult], results) # pylint:disable=client-method-missing-tracing-decorator-async - async def merge_or_upload_documents(self, documents: List[Dict], **kwargs: Any) -> List[IndexingResult]: + async def merge_or_upload_documents( + self, documents: List[Dict], **kwargs: Any + ) -> List[IndexingResult]: """Merge documents in to existing documents in the Azure search index, or upload them if they do not yet exist. @@ -656,7 +690,9 @@ async def merge_or_upload_documents(self, documents: List[Dict], **kwargs: Any) return cast(List[IndexingResult], results) @distributed_trace_async - async def index_documents(self, batch: IndexDocumentsBatch, **kwargs: Any) -> List[IndexingResult]: + async def index_documents( + self, batch: IndexDocumentsBatch, **kwargs: Any + ) -> List[IndexingResult]: """Specify a document operations to perform as a batch. :param batch: A batch of document operations to perform. @@ -667,13 +703,17 @@ async def index_documents(self, batch: IndexDocumentsBatch, **kwargs: Any) -> Li """ return await self._index_documents_actions(actions=batch.actions, **kwargs) - async def _index_documents_actions(self, actions: List[IndexAction], **kwargs: Any) -> List[IndexingResult]: + async def _index_documents_actions( + self, actions: List[IndexAction], **kwargs: Any + ) -> List[IndexingResult]: error_map = {413: RequestEntityTooLargeError} kwargs["headers"] = self._merge_client_headers(kwargs.get("headers")) batch = IndexBatch(actions=actions) try: - batch_response = await self._client.documents.index(batch=batch, error_map=error_map, **kwargs) + batch_response = await self._client.documents.index( + batch=batch, error_map=error_map, **kwargs + ) return cast(List[IndexingResult], batch_response.results) except RequestEntityTooLargeError: if len(actions) == 1: From 253e748030ebb5a2ace8f4aacde03ed93aebbac7 Mon Sep 17 00:00:00 2001 From: xiangyan99 Date: Tue, 24 Oct 2023 14:42:57 -0700 Subject: [PATCH 6/7] black --- .../azure-identity/azure/identity/_bearer_token_provider.py | 1 + .../azure-identity/azure/identity/aio/_bearer_token_provider.py | 1 + 2 files changed, 2 insertions(+) diff --git a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py index 728880f6fc77..a0132e9cc1a1 100644 --- a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py +++ b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py @@ -37,6 +37,7 @@ def get_bearer_token_provider(credential: TokenCredential, *scopes: str) -> Call """ policy = BearerTokenCredentialPolicy(credential, *scopes) + def wrapper() -> str: request = _make_request() policy.on_request(request) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py index a7ae5c97c153..8609f26074a9 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py @@ -38,6 +38,7 @@ def get_bearer_token_provider(credential: AsyncTokenCredential, *scopes: str) -> """ policy = AsyncBearerTokenCredentialPolicy(credential, *scopes) + async def wrapper() -> str: request = _make_request() await policy.on_request(request) From 37584eeb190f399651f9333937448a7b3e6a4463 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Wed, 25 Oct 2023 19:59:51 +0000 Subject: [PATCH 7/7] Feedback --- .../azure-identity/azure/identity/_bearer_token_provider.py | 2 +- .../azure/identity/aio/_bearer_token_provider.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py index a0132e9cc1a1..209f46d46ef7 100644 --- a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py +++ b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py @@ -24,7 +24,7 @@ def get_bearer_token_provider(credential: TokenCredential, *scopes: str) -> Call from azure.identity import DefaultAzureCredential, get_bearer_token_provider credential = DefaultAzureCredential() - bearer_token_provider = get_bearer_token_provider(credential, "https://storage.azure.com/.default") + bearer_token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default") # Usage request.headers["Authorization"] = "Bearer " + bearer_token_provider() diff --git a/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py index 8609f26074a9..bde068e10558 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_bearer_token_provider.py @@ -24,14 +24,14 @@ def get_bearer_token_provider(credential: AsyncTokenCredential, *scopes: str) -> from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider credential = DefaultAzureCredential() - bearer_token_provider = get_bearer_token_provider(credential, "https://storage.azure.com/.default") + bearer_token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default") # Usage request.headers["Authorization"] = "Bearer " + await bearer_token_provider() :param credential: The credential used to authenticate the request. - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenCredential :param str scopes: The scopes required for the bearer token. :rtype: coroutine :return: A coroutine that returns a bearer token.