Skip to content

Commit 3104e9b

Browse files
authored
[Cosmos] Typing part 3 (#33738)
* Started async typing * Enable mypy * pylint fixes * More typing fixes * pylint fixes * Fix import * Finished database typing * Mypy fixes * Typing for pipeline client * More typing * Some test fixes * Fix tests * Update logging policy
1 parent f206e92 commit 3104e9b

19 files changed

Lines changed: 1482 additions & 658 deletions

sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
55
# -------------------------------------------------------------------------
6-
from typing import TypeVar, Any, MutableMapping
6+
from typing import TypeVar, Any, MutableMapping, cast
77

88
from azure.core.pipeline import PipelineRequest
99
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
1010
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
1111
from azure.core.rest import HttpRequest
12+
from azure.core.credentials import AccessToken
1213

1314
from .http_constants import HttpHeaders
1415

@@ -34,7 +35,8 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
3435
:param ~azure.core.pipeline.PipelineRequest request: the request
3536
"""
3637
super().on_request(request)
37-
self._update_headers(request.http_request.headers, self._token.token)
38+
# The None-check for self._token is done in the parent on_request
39+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
3840

3941
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
4042
"""Acquire a token from the credential and authorize the request with it.
@@ -46,4 +48,5 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
4648
:param str scopes: required scopes of authentication
4749
"""
4850
super().authorize_request(request, *scopes, **kwargs)
49-
self._update_headers(request.http_request.headers, self._token.token)
51+
# The None-check for self._token is done in the parent authorize_request
52+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)

sdk/cosmos/azure-cosmos/azure/cosmos/_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
if TYPE_CHECKING:
4545
from ._cosmos_client_connection import CosmosClientConnection
46+
from .aio._cosmos_client_connection_async import CosmosClientConnection as AsyncClientConnection
4647

4748

4849
_COMMON_OPTIONS = {
@@ -107,7 +108,7 @@ def build_options(kwargs: Dict[str, Any]) -> Dict[str, Any]:
107108

108109

109110
def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
110-
cosmos_client_connection: "CosmosClientConnection",
111+
cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"],
111112
default_headers: Mapping[str, Any],
112113
verb: str,
113114
path: str,

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from azure.core.credentials import TokenCredential
3232
from azure.core.paging import ItemPaged
3333
from azure.core import PipelineClient
34+
from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=no-legacy-azure-core-http-response-import
3435
from azure.core.pipeline.policies import (
3536
HTTPPolicy,
3637
ContentDecodePolicy,
@@ -205,7 +206,11 @@ def __init__(
205206
]
206207

207208
transport = kwargs.pop("transport", None)
208-
self.pipeline_client = PipelineClient(base_url=url_connection, transport=transport, policies=policies)
209+
self.pipeline_client: PipelineClient[HttpRequest, HttpResponse] = PipelineClient(
210+
base_url=url_connection,
211+
transport=transport,
212+
policies=policies
213+
)
209214

210215
# Query compatibility mode.
211216
# Allows to specify compatibility mode used by client when making query requests. Should be removed when
@@ -404,7 +409,7 @@ def QueryDatabases(
404409
if options is None:
405410
options = {}
406411

407-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
412+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
408413
return (
409414
self.__QueryFeed(
410415
"/dbs", "dbs", "", lambda r: r["Databases"],
@@ -466,7 +471,7 @@ def QueryContainers(
466471
path = base.GetPathFromLink(database_link, "colls")
467472
database_id = base.GetResourceIdOrFullNameFromLink(database_link)
468473

469-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
474+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
470475
return (
471476
self.__QueryFeed(
472477
path, "colls", database_id, lambda r: r["DocumentCollections"],
@@ -699,7 +704,7 @@ def QueryUsers(
699704
path = base.GetPathFromLink(database_link, "users")
700705
database_id = base.GetResourceIdOrFullNameFromLink(database_link)
701706

702-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
707+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
703708
return (
704709
self.__QueryFeed(
705710
path, "users", database_id, lambda r: r["Users"],
@@ -835,7 +840,7 @@ def ReadPermissions(
835840
user_link: str,
836841
options: Optional[Mapping[str, Any]] = None,
837842
**kwargs: Any
838-
) -> Dict[str, Any]:
843+
) -> ItemPaged[Dict[str, Any]]:
839844
"""Reads all permissions for a user.
840845
841846
:param str user_link:
@@ -881,7 +886,7 @@ def QueryPermissions(
881886
path = base.GetPathFromLink(user_link, "permissions")
882887
user_id = base.GetResourceIdOrFullNameFromLink(user_link)
883888

884-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
889+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
885890
return (
886891
self.__QueryFeed(
887892
path, "permissions", user_id, lambda r: r["Permissions"], lambda _, b: b, query, options, **kwargs
@@ -1069,7 +1074,7 @@ def QueryItems(
10691074
path = base.GetPathFromLink(database_or_container_link, "docs")
10701075
collection_id = base.GetResourceIdOrFullNameFromLink(database_or_container_link)
10711076

1072-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
1077+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
10731078
return (
10741079
self.__QueryFeed(
10751080
path,
@@ -1163,7 +1168,7 @@ def _QueryChangeFeed(
11631168
path = base.GetPathFromLink(collection_link, resource_key)
11641169
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
11651170

1166-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
1171+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
11671172
return (
11681173
self.__QueryFeed(
11691174
path,
@@ -1235,7 +1240,7 @@ def _QueryPartitionKeyRanges(
12351240
path = base.GetPathFromLink(collection_link, "pkranges")
12361241
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
12371242

1238-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
1243+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
12391244
return (
12401245
self.__QueryFeed(
12411246
path, "pkranges", collection_id, lambda r: r["PartitionKeyRanges"],
@@ -1438,7 +1443,7 @@ def QueryTriggers(
14381443
path = base.GetPathFromLink(collection_link, "triggers")
14391444
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
14401445

1441-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
1446+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
14421447
return (
14431448
self.__QueryFeed(
14441449
path, "triggers", collection_id, lambda r: r["Triggers"], lambda _, b: b, query, options, **kwargs
@@ -1597,7 +1602,7 @@ def QueryUserDefinedFunctions(
15971602
path = base.GetPathFromLink(collection_link, "udfs")
15981603
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
15991604

1600-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
1605+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
16011606
return (
16021607
self.__QueryFeed(
16031608
path, "udfs", collection_id, lambda r: r["UserDefinedFunctions"],
@@ -1757,7 +1762,7 @@ def QueryStoredProcedures(
17571762
path = base.GetPathFromLink(collection_link, "sprocs")
17581763
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
17591764

1760-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
1765+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
17611766
return (
17621767
self.__QueryFeed(
17631768
path, "sprocs", collection_id, lambda r: r["StoredProcedures"],
@@ -1915,7 +1920,7 @@ def QueryConflicts(
19151920
path = base.GetPathFromLink(collection_link, "conflicts")
19161921
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
19171922

1918-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
1923+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
19191924
return (
19201925
self.__QueryFeed(
19211926
path, "conflicts", collection_id, lambda r: r["Conflicts"],
@@ -2528,7 +2533,7 @@ def QueryOffers(
25282533
if options is None:
25292534
options = {}
25302535

2531-
def fetch_fn(options: Mapping[str, Any]) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
2536+
def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
25322537
return (
25332538
self.__QueryFeed(
25342539
"/offers", "offers", "", lambda r: r["Offers"], lambda _, b: b, query, options, **kwargs
@@ -2779,7 +2784,7 @@ def __Get(
27792784
self,
27802785
path: str,
27812786
request_params: RequestObject,
2782-
req_headers: Mapping[str, Any],
2787+
req_headers: Dict[str, Any],
27832788
**kwargs: Any
27842789
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
27852790
"""Azure Cosmos 'GET' http request.
@@ -2807,7 +2812,7 @@ def __Post(
28072812
path: str,
28082813
request_params: RequestObject,
28092814
body: Optional[Union[str, List[Dict[str, Any]], Dict[str, Any]]],
2810-
req_headers: Mapping[str, Any],
2815+
req_headers: Dict[str, Any],
28112816
**kwargs: Any
28122817
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
28132818
"""Azure Cosmos 'POST' http request.
@@ -2836,7 +2841,7 @@ def __Put(
28362841
path: str,
28372842
request_params: RequestObject,
28382843
body: Dict[str, Any],
2839-
req_headers: Mapping[str, Any],
2844+
req_headers: Dict[str, Any],
28402845
**kwargs: Any
28412846
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
28422847
"""Azure Cosmos 'PUT' http request.
@@ -2865,7 +2870,7 @@ def __Patch(
28652870
path: str,
28662871
request_params: RequestObject,
28672872
request_data: Dict[str, Any],
2868-
req_headers: Mapping[str, Any],
2873+
req_headers: Dict[str, Any],
28692874
**kwargs: Any
28702875
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
28712876
"""Azure Cosmos 'PATCH' http request.
@@ -2893,7 +2898,7 @@ def __Delete(
28932898
self,
28942899
path: str,
28952900
request_params: RequestObject,
2896-
req_headers: Mapping[str, Any],
2901+
req_headers: Dict[str, Any],
28972902
**kwargs: Any
28982903
) -> Tuple[None, Dict[str, Any]]:
28992904
"""Azure Cosmos 'DELETE' http request.
@@ -2924,7 +2929,7 @@ def QueryFeed(
29242929
options: Mapping[str, Any],
29252930
partition_key_range_id: Optional[str] = None,
29262931
**kwargs: Any
2927-
) -> Tuple[ItemPaged[Dict[str, Any]], Dict[str, Any]]:
2932+
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
29282933
"""Query Feed for Document Collection resource.
29292934
29302935
:param str path: Path to the document collection.
@@ -2963,7 +2968,7 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements
29632968
response_hook: Optional[Callable[[Mapping[str, Any], Mapping[str, Any]], None]] = None,
29642969
is_query_plan: bool = False,
29652970
**kwargs: Any
2966-
) -> ItemPaged[Dict[str, Any]]:
2971+
) -> List[Dict[str, Any]]:
29672972
"""Query for more than one Azure Cosmos resources.
29682973
29692974
:param str path:
@@ -3114,7 +3119,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
31143119

31153120
return __GetBodiesFromQueryResult(result)
31163121

3117-
def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kwargs: Any) -> ItemPaged[Dict[str, Any]]:
3122+
def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kwargs: Any) -> List[Dict[str, Any]]:
31183123
supported_query_features = (documents._QueryFeature.Aggregate + "," +
31193124
documents._QueryFeature.CompositeAggregate + "," +
31203125
documents._QueryFeature.Distinct + "," +

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_http_logging_policy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,18 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
7676
def on_response(
7777
self,
7878
request: PipelineRequest[HTTPRequestType],
79-
response: PipelineResponse[HTTPRequestType, HTTPResponseType],
79+
response: PipelineResponse[HTTPRequestType, HTTPResponseType], # type: ignore[override]
8080
) -> None:
8181
super().on_response(request, response)
8282
if self._enable_diagnostics_logging:
8383
http_response = response.http_response
8484
options = response.context.options
8585
logger = request.context.setdefault("logger", options.pop("logger", self.logger))
8686
try:
87-
logger.info("Elapsed time in seconds: {}".format(time.time() - request.context.get("start_time")))
87+
if "start_time" in request.context:
88+
logger.info("Elapsed time in seconds: {}".format(time.time() - request.context["start_time"]))
89+
else:
90+
logger.info("Elapsed time in seconds: unknown")
8891
if http_response.status_code >= 400:
8992
logger.info("Response error message: %r", _format_error(http_response.text()))
9093
except Exception as err: # pylint: disable=broad-except

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
# license information.
55
# -------------------------------------------------------------------------
66

7-
from typing import Any, MutableMapping, TypeVar
7+
from typing import Any, MutableMapping, TypeVar, cast
88

99
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
1010
from azure.core.pipeline import PipelineRequest
1111
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
1212
from azure.core.rest import HttpRequest
13+
from azure.core.credentials import AccessToken
1314

1415
from ..http_constants import HttpHeaders
1516

@@ -35,7 +36,8 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
3536
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
3637
"""
3738
await super().on_request(request)
38-
self._update_headers(request.http_request.headers, self._token.token)
39+
# The None-check for self._token is done in the parent on_request
40+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
3941

4042
async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
4143
"""Acquire a token from the credential and authorize the request with it.
@@ -47,4 +49,5 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc
4749
:param str scopes: required scopes of authentication
4850
"""
4951
await super().authorize_request(request, *scopes, **kwargs)
50-
self._update_headers(request.http_request.headers, self._token.token)
52+
# The None-check for self._token is done in the parent authorize_request
53+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)

0 commit comments

Comments
 (0)