-
Notifications
You must be signed in to change notification settings - Fork 3.3k
[Cosmos]: Adding Semantic Reranker API #42991
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
aayush3011
merged 26 commits into
Azure:main
from
aayush3011:users/akataria/semanticReranking
Oct 2, 2025
Merged
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
a473cf9
Adding semantic reranker api
aayush3011 969f615
updating changelog
aayush3011 e48b95e
Update sdk/cosmos/azure-cosmos/azure/cosmos/container.py
aayush3011 a9269d2
updating docstring
aayush3011 dd6a690
Adding async api changes
aayush3011 4b27de9
Merge branch 'main' into users/akataria/semanticReranking
aayush3011 227303f
Resolving comments
aayush3011 23e0c2a
Fixing build
aayush3011 9812e55
Fixing build
aayush3011 b7bbc89
Fixing build
aayush3011 dbdb4eb
Fixing build
aayush3011 28b32d6
Resolving comments
aayush3011 4024af5
Updating changelog
aayush3011 3ee7b60
Merge branch 'main' into users/akataria/semanticReranking
aayush3011 5303ebc
Adding env variable for the inference service endpoint
aayush3011 033e6a7
Adding env variable for the inference service endpoint
aayush3011 930dbf5
Fixing build
aayush3011 f5e25a7
Fixing build
aayush3011 a3e7357
Fixing build
aayush3011 e31c579
Fixing build
aayush3011 0276afd
Fixing build
aayush3011 b59c50b
Fixing build
aayush3011 94cabdf
Fixing build
aayush3011 867d1f0
Resolving comments
aayush3011 6e057ac
Resolving comments
aayush3011 c03859d
Resolving comments
aayush3011 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
sdk/cosmos/azure-cosmos/azure/cosmos/_inference_auth_policy.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| # The MIT License (MIT) | ||
| # Copyright (c) 2014 Microsoft Corporation | ||
|
|
||
| # Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| # of this software and associated documentation files (the "Software"), to deal | ||
| # in the Software without restriction, including without limitation the rights | ||
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| # copies of the Software, and to permit persons to whom the Software is | ||
| # furnished to do so, subject to the following conditions: | ||
|
|
||
| # The above copyright notice and this permission notice shall be included in all | ||
| # copies or substantial portions of the Software. | ||
|
|
||
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| # SOFTWARE. | ||
| from typing import TypeVar, Any, MutableMapping, cast | ||
|
|
||
| from azure.core.pipeline import PipelineRequest | ||
| from azure.core.pipeline.policies import BearerTokenCredentialPolicy | ||
| from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest | ||
| from azure.core.rest import HttpRequest | ||
| from azure.core.credentials import AccessToken | ||
|
|
||
| HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) | ||
|
|
||
|
|
||
| class InferenceServiceBearerTokenPolicy(BearerTokenCredentialPolicy): | ||
| """Bearer token authentication policy for inference service. | ||
|
|
||
| This policy preserves the standard JWT Bearer token format required by | ||
| external inference services, unlike CosmosBearerTokenCredentialPolicy which | ||
| modifies tokens for Cosmos DB authentication. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def _update_headers(headers: MutableMapping[str, str], token: str) -> None: | ||
| """Updates the Authorization header with the standard-bearer token format. | ||
|
|
||
| :param MutableMapping[str, str] headers: The HTTP Request headers | ||
| :param str token: The OAuth token. | ||
| """ | ||
| headers["Authorization"] = f"Bearer {token}" | ||
|
|
||
| def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: | ||
| """Called before the policy sends a request. | ||
|
|
||
| The base implementation authorizes the request with a bearer token. | ||
|
|
||
| :param ~azure.core.pipeline.PipelineRequest request: the request | ||
| """ | ||
| super().on_request(request) | ||
| # The None-check for self._token is done in the parent on_request | ||
| self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) | ||
|
|
||
| def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: | ||
| """Acquire a token from the credential and authorize the request with it. | ||
|
|
||
| Keyword arguments are passed to the credential's get_token method. The token will be cached and used to | ||
| authorize future requests. | ||
|
|
||
| :param ~azure.core.pipeline.PipelineRequest request: the request | ||
| :param str scopes: required scopes of authentication | ||
| """ | ||
| super().authorize_request(request, *scopes, **kwargs) | ||
| # The None-check for self._token is done in the parent authorize_request | ||
| self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) |
203 changes: 203 additions & 0 deletions
203
sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,203 @@ | ||
| # The MIT License (MIT) | ||
| # Copyright (c) 2014 Microsoft Corporation | ||
|
|
||
| # Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| # of this software and associated documentation files (the "Software"), to deal | ||
| # in the Software without restriction, including without limitation the rights | ||
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| # copies of the Software, and to permit persons to whom the Software is | ||
| # furnished to do so, subject to the following conditions: | ||
|
|
||
| # The above copyright notice and this permission notice shall be included in all | ||
| # copies or substantial portions of the Software. | ||
|
|
||
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| # SOFTWARE. | ||
|
|
||
| import json | ||
| import urllib | ||
| from typing import Any, cast, Dict, List, Optional | ||
| from urllib3.util.retry import Retry | ||
|
|
||
| from azure.core import PipelineClient | ||
| from azure.core.exceptions import DecodeError | ||
| from azure.core.pipeline.policies import (ContentDecodePolicy, DistributedTracingPolicy, HeadersPolicy, HTTPPolicy, | ||
| NetworkTraceLoggingPolicy, UserAgentPolicy) | ||
| from azure.core.pipeline.transport import HttpRequest | ||
| from azure.core.utils import CaseInsensitiveDict | ||
|
|
||
| from . import exceptions | ||
| from ._cosmos_responses import CosmosDict | ||
| from ._inference_auth_policy import InferenceServiceBearerTokenPolicy | ||
| from ._retry_utility import ConnectionRetryPolicy | ||
| from .http_constants import HttpHeaders | ||
|
|
||
| _DEFAULT_SCOPE = "https://dbinference.azure.com/.default" | ||
|
|
||
| # cspell:ignore rerank reranker reranking | ||
| # pylint: disable=protected-access | ||
|
|
||
|
|
||
| class _InferenceService: | ||
| """Internal client for inference service.""" | ||
|
|
||
| def __init__(self, cosmos_client_connection): | ||
| """Initialize semantic reranker with credentials and endpoint information. | ||
|
aayush3011 marked this conversation as resolved.
Outdated
|
||
|
|
||
| :param cosmos_client_connection: Optional reference to cosmos client connection for accessing settings | ||
| :type cosmos_client_connection: Optional[CosmosClientConnection] | ||
| """ | ||
| self._client_connection = cosmos_client_connection | ||
| self._aad_credentials = self._client_connection.aad_credentials | ||
| self._token_scope = _DEFAULT_SCOPE | ||
|
|
||
| parsed = urllib.parse.urlparse(self._client_connection.url_connection) | ||
| self._account_name = parsed.hostname.split('.')[0] if parsed.hostname else "" | ||
|
aayush3011 marked this conversation as resolved.
Outdated
|
||
|
|
||
| self._inference_endpoint = f"https://{self._account_name}.dbinference.azure.com/inference/semanticReranking" | ||
| self._inference_pipeline_client = self._create_inference_pipeline_client() | ||
|
|
||
| def _create_inference_pipeline_client(self) -> PipelineClient: | ||
| """Create a pipeline for inference requests. | ||
|
|
||
| :returns: A PipelineClient configured for inference calls. | ||
| :rtype: ~azure.core.PipelineClient | ||
| """ | ||
| access_token = self._aad_credentials | ||
| auth_policy = InferenceServiceBearerTokenPolicy(access_token, self._token_scope) | ||
|
|
||
| connection_policy = self._client_connection.connection_policy | ||
| retry_policy = None | ||
| if isinstance(connection_policy.ConnectionRetryConfiguration, HTTPPolicy): | ||
| retry_policy = connection_policy.ConnectionRetryConfiguration | ||
| elif isinstance(connection_policy.ConnectionRetryConfiguration, int): | ||
| retry_policy = ConnectionRetryPolicy(total=connection_policy.ConnectionRetryConfiguration) | ||
| elif isinstance(connection_policy.ConnectionRetryConfiguration, Retry): | ||
| # Convert a urllib3 retry policy to a Pipeline policy | ||
| retry_policy = ConnectionRetryPolicy( | ||
|
aayush3011 marked this conversation as resolved.
|
||
| retry_total=connection_policy.ConnectionRetryConfiguration.total, | ||
| retry_connect=connection_policy.ConnectionRetryConfiguration.connect, | ||
| retry_read=connection_policy.ConnectionRetryConfiguration.read, | ||
| retry_status=connection_policy.ConnectionRetryConfiguration.status, | ||
| retry_backoff_max=connection_policy.ConnectionRetryConfiguration.DEFAULT_BACKOFF_MAX, | ||
| retry_on_status_codes=list(connection_policy.ConnectionRetryConfiguration.status_forcelist), | ||
| retry_backoff_factor=connection_policy.ConnectionRetryConfiguration.backoff_factor | ||
| ) | ||
| else: | ||
| raise TypeError( | ||
| "Unsupported retry policy. Must be an azure.cosmos.ConnectionRetryPolicy, int, or urllib3.Retry") | ||
|
aayush3011 marked this conversation as resolved.
|
||
| policies = [ | ||
| HeadersPolicy(), | ||
| UserAgentPolicy(base_user_agent=self._get_user_agent()), | ||
| ContentDecodePolicy(), | ||
| auth_policy, | ||
| retry_policy, | ||
| NetworkTraceLoggingPolicy(), | ||
| DistributedTracingPolicy(), | ||
| ] | ||
|
aayush3011 marked this conversation as resolved.
|
||
|
|
||
| return PipelineClient( | ||
|
aayush3011 marked this conversation as resolved.
|
||
| base_url=self._inference_endpoint, | ||
| policies=policies | ||
| ) | ||
|
|
||
| def _get_user_agent(self) -> str: | ||
| """Return the user agent string for inference pipeline. | ||
|
|
||
| :returns: User agent string. | ||
| :rtype: str | ||
| """ | ||
| if self._client_connection and hasattr(self._client_connection, '_user_agent'): | ||
| return self._client_connection._user_agent + "_inference" | ||
| return "azure-cosmos-python-sdk-inference" | ||
|
|
||
| def rerank( | ||
| self, | ||
| reranking_context: str, | ||
| documents: List[str], | ||
| semantic_reranking_options: Optional[Dict[str, Any]] = None, | ||
| ) -> CosmosDict: | ||
| """Rerank documents using the semantic reranking service. | ||
|
|
||
| :param reranking_context: Query / context string used to score documents. | ||
| :type reranking_context: str | ||
| :param documents: List of document strings to rerank. | ||
| :type documents: List[str] | ||
| :param semantic_reranking_options: Optional dictionary of tuning parameters. Supported keys: | ||
| * return_documents (bool): Include original document text in results. Default True. | ||
| * top_k (int): Limit number of scored documents returned. | ||
| * batch_size (int): Batch size for internal scoring operations. | ||
| * sort (bool): If True (default) results are ordered by descending score. | ||
| :type semantic_reranking_options: Optional[Dict[str, Any]] | ||
| :returns: Reranking result payload. | ||
| :rtype: ~azure.cosmos.CosmosDict[str, Any] | ||
| :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: On HTTP or service error. | ||
| """ | ||
| try: | ||
| body = { | ||
| "query": reranking_context, | ||
| "documents": documents, | ||
| } | ||
|
|
||
| if semantic_reranking_options: | ||
| if "return_documents" in semantic_reranking_options: | ||
| body["return_documents"] = semantic_reranking_options["return_documents"] | ||
| if "top_k" in semantic_reranking_options: | ||
| body["top_k"] = semantic_reranking_options["top_k"] | ||
| if "batch_size" in semantic_reranking_options: | ||
| body["batch_size"] = semantic_reranking_options["batch_size"] | ||
| if "sort" in semantic_reranking_options: | ||
| body["sort"] = semantic_reranking_options["sort"] | ||
|
|
||
| headers = { | ||
| HttpHeaders.ContentType: "application/json" | ||
| } | ||
|
|
||
| request = HttpRequest( | ||
| method="POST", | ||
| url=self._inference_endpoint, | ||
| headers=headers, | ||
| data=json.dumps(body, separators=(",", ":")) | ||
| ) | ||
|
|
||
| pipeline_response = self._inference_pipeline_client._pipeline.run(request) | ||
| response = pipeline_response.http_response | ||
| response_headers = cast(CaseInsensitiveDict, response.headers) | ||
|
|
||
| data = response.body() | ||
| if data: | ||
| data = data.decode("utf-8") | ||
|
|
||
| if response.status_code == 404: | ||
| raise exceptions.CosmosResourceNotFoundError(message=data, response=response) | ||
| if response.status_code == 409: | ||
| raise exceptions.CosmosResourceExistsError(message=data, response=response) | ||
| if response.status_code == 412: | ||
| raise exceptions.CosmosAccessConditionFailedError(message=data, response=response) | ||
| if response.status_code >= 400: | ||
| raise exceptions.CosmosHttpResponseError(message=data, response=response) | ||
|
aayush3011 marked this conversation as resolved.
Outdated
|
||
|
|
||
| result = None | ||
| if data: | ||
| try: | ||
| result = json.loads(data) | ||
| except Exception as e: | ||
| raise DecodeError( | ||
| message="Failed to decode JSON data: {}".format(e), | ||
| response=response, | ||
| error=e) from e | ||
|
|
||
| return CosmosDict(result, response_headers=response_headers) | ||
|
|
||
| except Exception as e: | ||
| if isinstance(e, (exceptions.CosmosHttpResponseError, exceptions.CosmosResourceNotFoundError)): | ||
| raise | ||
| raise exceptions.CosmosHttpResponseError( | ||
| message=f"Semantic reranking failed: {str(e)}", | ||
| response=None | ||
| ) from e | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.