Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#### Features Added
* Added read_items API to provide an efficient method for retrieving multiple items in a single request. See [PR 42167](https://github.com/Azure/azure-sdk-for-python/pull/42167).
* Added ability to replace a container's indexing policy if a vector embedding policy was present. See [PR 42810](https://github.com/Azure/azure-sdk-for-python/pull/42810).
* Added support for Semantic Reranking. See [PR 42991](https://github.com/Azure/azure-sdk-for-python/pull/42991).
Comment thread
aayush3011 marked this conversation as resolved.
Outdated

#### Bugs Fixed
* Improved the resilience of Database Account Read metadata operation against short-lived network issues by increasing number of retries. See [PR 42525](https://github.com/Azure/azure-sdk-for-python/pull/42525).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from ._request_object import RequestObject
from ._retry_utility import ConnectionRetryPolicy
from ._routing import routing_map_provider, routing_range
from ._inference_service import _InferenceService
from .documents import ConnectionPolicy, DatabaseAccount
from .partition_key import (
_Undefined,
Expand Down Expand Up @@ -236,6 +237,10 @@ def __init__( # pylint: disable=too-many-statements
policies=policies
)

self._inference_service: Optional[_InferenceService] = None
if self.aad_credentials:
self._inference_service = _InferenceService(self)

# Query compatibility mode.
# Allows to specify compatibility mode used by client when making query requests. Should be removed when
# application/sql is no longer supported.
Expand Down Expand Up @@ -302,6 +307,10 @@ def _set_client_consistency_level(
else:
self.session = None

def _get_inference_service(self) -> Optional[_InferenceService]:
    """Return the inference service held by this client connection.

    :returns: The stored inference service, or None if it was never set up
        (the constructor only creates one when AAD credentials are present).
    :rtype: Optional[_InferenceService]
    """
    inference_service = self._inference_service
    return inference_service

@property
def Session(self) -> Optional[_session.Session]:
"""Gets the session object from the client.
Expand Down
71 changes: 71 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_inference_auth_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# The MIT License (MIT)
# Copyright (c) 2014 Microsoft Corporation

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TypeVar, Any, MutableMapping, cast

from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
from azure.core.rest import HttpRequest
from azure.core.credentials import AccessToken

# Constrained type variable: a pipeline request handled by the policy below may wrap
# either the azure.core.rest HttpRequest or the legacy transport HttpRequest.
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)


class InferenceServiceBearerTokenPolicy(BearerTokenCredentialPolicy):
    """Pipeline policy that authenticates inference-service calls with a plain bearer token.

    The ``Authorization`` header is always written in the standard
    ``Bearer <token>`` form required by external inference services. This is in
    contrast to CosmosBearerTokenCredentialPolicy, which modifies tokens for
    Cosmos DB authentication.
    """

    @staticmethod
    def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
        """Write the token into the ``Authorization`` header in standard bearer format.

        :param MutableMapping[str, str] headers: The HTTP Request headers, modified in place.
        :param str token: The OAuth token.
        """
        authorization_value = f"Bearer {token}"
        headers["Authorization"] = authorization_value

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Authorize the request just before it is sent.

        The parent policy runs first; the cached token is then written to the
        request headers through :meth:`_update_headers`.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        super().on_request(request)
        # The parent on_request has already done the None-check for self._token,
        # so the cast below only narrows the type for the checker.
        cached_token = cast(AccessToken, self._token)
        outgoing_headers = request.http_request.headers
        self._update_headers(outgoing_headers, cached_token.token)

    def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token for the given scopes and authorize the request with it.

        Keyword arguments are forwarded to the credential's get_token method by the
        parent implementation. The token is cached and reused for later requests.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        super().authorize_request(request, *scopes, **kwargs)
        # The parent authorize_request has already done the None-check for self._token.
        cached_token = cast(AccessToken, self._token)
        outgoing_headers = request.http_request.headers
        self._update_headers(outgoing_headers, cached_token.token)
203 changes: 203 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# The MIT License (MIT)
# Copyright (c) 2014 Microsoft Corporation

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import json
import urllib
import urllib.parse
from typing import Any, cast, Dict, List, Optional

from azure.core import PipelineClient
from azure.core.exceptions import DecodeError
from azure.core.pipeline.policies import (ContentDecodePolicy, DistributedTracingPolicy, HeadersPolicy, HTTPPolicy,
                                          NetworkTraceLoggingPolicy, UserAgentPolicy)
from azure.core.pipeline.transport import HttpRequest
from azure.core.utils import CaseInsensitiveDict
from urllib3.util.retry import Retry

from . import exceptions
from ._cosmos_responses import CosmosDict
from ._inference_auth_policy import InferenceServiceBearerTokenPolicy
from ._retry_utility import ConnectionRetryPolicy
from .http_constants import HttpHeaders

# Token scope used for inference calls: stored as ``_InferenceService._token_scope``
# and handed to the bearer-token auth policy when the inference pipeline is built.
_DEFAULT_SCOPE = "https://dbinference.azure.com/.default"

# cspell:ignore rerank reranker reranking
# pylint: disable=protected-access


class _InferenceService:
    """Internal client for inference service."""

    # Option names copied from ``semantic_reranking_options`` into the request body.
    # The order here is the order in which the keys are serialized.
    _SUPPORTED_OPTIONS = ("return_documents", "top_k", "batch_size", "sort")

    def __init__(self, cosmos_client_connection):
        """Initialize the inference service from an existing Cosmos client connection.

        The connection is required: its AAD credentials, connection URL and connection
        policy are read here to derive the inference endpoint and build the pipeline.

        :param cosmos_client_connection: Cosmos client connection whose settings are reused.
        :type cosmos_client_connection: CosmosClientConnection
        """
        self._client_connection = cosmos_client_connection
        self._aad_credentials = self._client_connection.aad_credentials
        self._token_scope = _DEFAULT_SCOPE

        # The first label of the connection URL's host name is used as the account name;
        # an URL without a host name yields an empty account name.
        parsed = urllib.parse.urlparse(self._client_connection.url_connection)
        self._account_name = parsed.hostname.split('.')[0] if parsed.hostname else ""

        self._inference_endpoint = f"https://{self._account_name}.dbinference.azure.com/inference/semanticReranking"
        self._inference_pipeline_client = self._create_inference_pipeline_client()

    def _create_inference_pipeline_client(self) -> PipelineClient:
        """Create a pipeline for inference requests.

        The retry policy is derived from the connection policy's
        ``ConnectionRetryConfiguration``, which may be a pipeline policy, an int
        (total retries) or a urllib3 ``Retry`` object.

        :returns: A PipelineClient configured for inference calls.
        :rtype: ~azure.core.PipelineClient
        :raises TypeError: If the retry configuration is of an unsupported type.
        """
        auth_policy = InferenceServiceBearerTokenPolicy(self._aad_credentials, self._token_scope)

        retry_config = self._client_connection.connection_policy.ConnectionRetryConfiguration
        if isinstance(retry_config, HTTPPolicy):
            retry_policy = retry_config
        elif isinstance(retry_config, int):
            retry_policy = ConnectionRetryPolicy(total=retry_config)
        elif isinstance(retry_config, Retry):
            # Convert a urllib3 retry policy to a Pipeline policy
            retry_policy = ConnectionRetryPolicy(
                retry_total=retry_config.total,
                retry_connect=retry_config.connect,
                retry_read=retry_config.read,
                retry_status=retry_config.status,
                retry_backoff_max=retry_config.DEFAULT_BACKOFF_MAX,
                retry_on_status_codes=list(retry_config.status_forcelist),
                retry_backoff_factor=retry_config.backoff_factor
            )
        else:
            raise TypeError(
                "Unsupported retry policy. Must be an azure.cosmos.ConnectionRetryPolicy, int, or urllib3.Retry")

        policies = [
            HeadersPolicy(),
            UserAgentPolicy(base_user_agent=self._get_user_agent()),
            ContentDecodePolicy(),
            auth_policy,
            retry_policy,
            NetworkTraceLoggingPolicy(),
            DistributedTracingPolicy(),
        ]

        return PipelineClient(
            base_url=self._inference_endpoint,
            policies=policies
        )

    def _get_user_agent(self) -> str:
        """Return the user agent string for inference pipeline.

        Uses the client connection's ``_user_agent`` with an ``_inference`` suffix when
        available, otherwise a fixed fallback string.

        :returns: User agent string.
        :rtype: str
        """
        if self._client_connection and hasattr(self._client_connection, '_user_agent'):
            return self._client_connection._user_agent + "_inference"
        return "azure-cosmos-python-sdk-inference"

    def rerank(
            self,
            reranking_context: str,
            documents: List[str],
            semantic_reranking_options: Optional[Dict[str, Any]] = None,
    ) -> CosmosDict:
        """Rerank documents using the semantic reranking service.

        :param reranking_context: Query / context string used to score documents.
        :type reranking_context: str
        :param documents: List of document strings to rerank.
        :type documents: List[str]
        :param semantic_reranking_options: Optional dictionary of tuning parameters. Supported keys:
            * return_documents (bool): Include original document text in results. Default True.
            * top_k (int): Limit number of scored documents returned.
            * batch_size (int): Batch size for internal scoring operations.
            * sort (bool): If True (default) results are ordered by descending score.
            Keys other than these are ignored.
        :type semantic_reranking_options: Optional[Dict[str, Any]]
        :returns: Reranking result payload.
        :rtype: ~azure.cosmos.CosmosDict[str, Any]
        :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: On HTTP or service error. Any other
            failure (transport, authentication, response decoding) is wrapped in this type as well,
            with the original exception attached as ``__cause__``.
        """
        try:
            body: Dict[str, Any] = {
                "query": reranking_context,
                "documents": documents,
            }

            if semantic_reranking_options:
                for option in self._SUPPORTED_OPTIONS:
                    if option in semantic_reranking_options:
                        body[option] = semantic_reranking_options[option]

            request = HttpRequest(
                method="POST",
                url=self._inference_endpoint,
                headers={HttpHeaders.ContentType: "application/json"},
                data=json.dumps(body, separators=(",", ":"))
            )

            pipeline_response = self._inference_pipeline_client._pipeline.run(request)
            response = pipeline_response.http_response
            response_headers = cast(CaseInsensitiveDict, response.headers)

            data = response.body()
            if data:
                data = data.decode("utf-8")

            # Map well-known status codes to their specific Cosmos exception types first;
            # everything else >= 400 becomes the generic HTTP response error.
            if response.status_code == 404:
                raise exceptions.CosmosResourceNotFoundError(message=data, response=response)
            if response.status_code == 409:
                raise exceptions.CosmosResourceExistsError(message=data, response=response)
            if response.status_code == 412:
                raise exceptions.CosmosAccessConditionFailedError(message=data, response=response)
            if response.status_code >= 400:
                raise exceptions.CosmosHttpResponseError(message=data, response=response)

            result = None
            if data:
                try:
                    result = json.loads(data)
                except Exception as e:
                    raise DecodeError(
                        message=f"Failed to decode JSON data: {e}",
                        response=response,
                        error=e) from e

            return CosmosDict(result, response_headers=response_headers)

        except exceptions.CosmosHttpResponseError:
            # Already a Cosmos error (this also covers CosmosResourceNotFoundError,
            # which derives from CosmosHttpResponseError): propagate unchanged.
            raise
        except Exception as e:
            raise exceptions.CosmosHttpResponseError(
                message=f"Semantic reranking failed: {str(e)}",
                response=None
            ) from e
Loading
Loading