Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#### Bugs Fixed

#### Other Changes
* Removed dual endpoint tracking from the sdk. See [PR 40451](https://github.com/Azure/azure-sdk-for-python/pull/40451).

### 4.14.0b4 (2025-09-11)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
try:
database_account = self._GetDatabaseAccountStub(self.DefaultEndpoint, **kwargs)
self._database_account_cache = database_account
self.location_cache.mark_endpoint_available(self.DefaultEndpoint)
return database_account, self.DefaultEndpoint
# If for any reason(non-globaldb related), we are not able to get the database
# account from the above call to GetDatabaseAccount, we would try to get this
Expand All @@ -152,9 +151,6 @@ def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
# until we get the database account and return None at the end, if we are not able
# to get that info from any endpoints
except (exceptions.CosmosHttpResponseError, AzureError):
# when atm is available, L: 145, 146 should be removed as the global endpoint shouldn't be used
# for dataplane operations anymore
self._mark_endpoint_unavailable(self.DefaultEndpoint)
for location_name in self.PreferredLocations:
locational_endpoint = LocationCache.GetLocationalEndpoint(self.DefaultEndpoint, location_name)
try:
Expand Down
128 changes: 23 additions & 105 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from . import documents, _base as base
from .http_constants import ResourceType
from .documents import _OperationType, ConnectionPolicy
from .documents import ConnectionPolicy
from ._request_object import RequestObject

# pylint: disable=protected-access
Expand All @@ -43,40 +43,27 @@ class EndpointOperationType(object):
WriteType = "Write"

class RegionalRoutingContext(object):
def __init__(self, primary_endpoint: str, alternate_endpoint: str):
def __init__(self, primary_endpoint: str):
self.primary_endpoint: str = primary_endpoint
self.alternate_endpoint: str = alternate_endpoint

def set_primary(self, endpoint: str):
self.primary_endpoint = endpoint

def set_alternate(self, endpoint: str):
self.alternate_endpoint = endpoint

def get_primary(self):
return self.primary_endpoint

def get_alternate(self):
return self.alternate_endpoint

def __eq__(self, other):
return (self.primary_endpoint == other.primary_endpoint
and self.alternate_endpoint == other.alternate_endpoint)
return self.primary_endpoint == other.primary_endpoint

def __str__(self):
return "Primary: " + self.primary_endpoint + ", Alternate: " + self.alternate_endpoint
return "Primary: " + self.primary_endpoint

def get_endpoints_by_location(new_locations: List[Dict[str, str]],
old_regional_routing_contexts_by_location: Dict[str, RegionalRoutingContext],
default_regional_endpoint: RegionalRoutingContext,
writes: bool,
use_multiple_write_locations: bool):
def get_regional_routing_contexts_by_loc(new_locations: List[Dict[str, str]]):
# construct from previous object
regional_routing_context_by_location: OrderedDict[str, RegionalRoutingContext] = collections.OrderedDict()
regional_routing_contexts_by_location: OrderedDict[str, RegionalRoutingContext] = collections.OrderedDict()
parsed_locations = []


for new_location in new_locations: # pylint: disable=too-many-nested-blocks
for new_location in new_locations:
# if name in new_location and same for database account endpoint
if "name" in new_location and "databaseAccountEndpoint" in new_location:
if not new_location["name"]:
Expand All @@ -85,44 +72,19 @@ def get_endpoints_by_location(new_locations: List[Dict[str, str]],
try:
region_uri = new_location["databaseAccountEndpoint"]
parsed_locations.append(new_location["name"])
if not writes or use_multiple_write_locations:
regional_object = RegionalRoutingContext(region_uri, region_uri)
elif new_location["name"] in old_regional_routing_contexts_by_location:
regional_object = old_regional_routing_contexts_by_location[new_location["name"]]
current = regional_object.get_primary()
# swap the previous with current and current with new region_uri received from the gateway
if current != region_uri:
regional_object.set_alternate(current)
regional_object.set_primary(region_uri)
# This is the bootstrapping condition
else:
regional_object = RegionalRoutingContext(region_uri, region_uri)
# if it is for writes, then we update the previous to default_endpoint
if writes:
# if region_uri is different than global endpoint set global endpoint
# as fallback
# else construct regional uri
if region_uri != default_regional_endpoint.get_primary():
regional_object.set_alternate(default_regional_endpoint.get_primary())
else:
constructed_region_uri = LocationCache.GetLocationalEndpoint(
default_regional_endpoint.get_primary(),
new_location["name"])
regional_object.set_alternate(constructed_region_uri)
regional_routing_context_by_location.update({new_location["name"]: regional_object})
regional_object = RegionalRoutingContext(region_uri)
regional_routing_contexts_by_location.update({new_location["name"]: regional_object})
except Exception as e:
raise e

# Also store a hash map of endpoints for each location
locations_by_endpoints = {value.get_primary(): key for key, value in regional_routing_context_by_location.items()}
locations_by_endpoints = {value.get_primary(): key for key, value in regional_routing_contexts_by_location.items()}

return regional_routing_context_by_location, locations_by_endpoints, parsed_locations
return regional_routing_contexts_by_location, locations_by_endpoints, parsed_locations

def _get_health_check_endpoints(regional_routing_contexts) -> Set[str]:
# should use the endpoints in the order returned from gateway and only the ones specified in preferred locations
preferred_endpoints = {context.get_primary() for context in regional_routing_contexts}.union(
{context.get_alternate() for context in regional_routing_contexts}
)
preferred_endpoints = {context.get_primary() for context in regional_routing_contexts}
return preferred_endpoints

def _get_applicable_regional_routing_contexts(regional_routing_contexts: List[RegionalRoutingContext],
Expand Down Expand Up @@ -157,8 +119,7 @@ def __init__(
default_endpoint: str,
connection_policy: ConnectionPolicy,
):
self.default_regional_routing_context: RegionalRoutingContext = RegionalRoutingContext(default_endpoint,
default_endpoint)
self.default_regional_routing_context: RegionalRoutingContext = RegionalRoutingContext(default_endpoint)
self.effective_preferred_locations: List[str] = []
self.enable_multiple_writable_locations: bool = False
self.write_regional_routing_contexts: List[RegionalRoutingContext] = [self.default_regional_routing_context]
Expand Down Expand Up @@ -205,9 +166,8 @@ def perform_on_database_account_read(self, database_account):

def get_all_write_endpoints(self) -> Set[str]:
return {
endpoint
context.get_primary()
for context in self.get_write_regional_routing_contexts()
for endpoint in (context.get_primary(), context.get_alternate())
}

def get_ordered_write_locations(self):
Expand Down Expand Up @@ -272,10 +232,6 @@ def resolve_service_endpoint(self, request):
request.use_preferred_locations if request.use_preferred_locations is not None else True
)

# whether to check for write or read unavailable
endpoint_operation_type = EndpointOperationType.WriteType if (
documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType

if not use_preferred_locations or (
documents._OperationType.IsWriteOperation(request.operation_type)
and not self.can_use_multiple_write_locations_for_request(request)
Expand All @@ -290,14 +246,6 @@ def resolve_service_endpoint(self, request):
and write_location in self.account_write_regional_routing_contexts_by_location):
write_regional_routing_context = (
self.account_write_regional_routing_contexts_by_location)[write_location]
if (
request.last_routed_location_endpoint_within_region is not None
and request.last_routed_location_endpoint_within_region
== write_regional_routing_context.get_primary()
or self.is_endpoint_unavailable_internal(write_regional_routing_context.get_primary(),
endpoint_operation_type)
):
return write_regional_routing_context.get_alternate()
return write_regional_routing_context.get_primary()
# if endpoint discovery is off for reads it should use passed in endpoint
return self.default_regional_routing_context.get_primary()
Expand All @@ -308,14 +256,6 @@ def resolve_service_endpoint(self, request):
else self._get_applicable_read_regional_routing_contexts(request)
)
regional_routing_context = regional_routing_contexts[location_index % len(regional_routing_contexts)]
if (
request.last_routed_location_endpoint_within_region is not None
and request.last_routed_location_endpoint_within_region
== regional_routing_context.get_primary()
or self.is_endpoint_unavailable_internal(regional_routing_context.get_primary(),
endpoint_operation_type)
):
return regional_routing_context.get_alternate()
return regional_routing_context.get_primary()

def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements
Expand All @@ -342,7 +282,7 @@ def should_refresh_endpoints(self): # pylint: disable=too-many-return-statement
return True

if not self.can_use_multiple_write_locations():
if self.is_location_unavailable(self.write_regional_routing_contexts[0],
if self.is_endpoint_unavailable(self.write_regional_routing_contexts[0].get_primary(),
EndpointOperationType.WriteType):
# same logic as other
# Since most preferred write endpoint is unavailable, we can only refresh in background if
Expand All @@ -360,16 +300,7 @@ def should_refresh_endpoints(self): # pylint: disable=too-many-return-statement
return should_refresh
return False

def is_location_unavailable(self, endpoint: RegionalRoutingContext, operation_type: str):
# For writes with single write region accounts only mark it unavailable if both are down
if not _OperationType.IsReadOnlyOperation(operation_type) and not self.can_use_multiple_write_locations():
return (self.is_endpoint_unavailable_internal(endpoint.get_primary(), operation_type)
and self.is_endpoint_unavailable_internal(endpoint.get_alternate(), operation_type))

# For reads mark the region as down if primary endpoint is unavailable
return self.is_endpoint_unavailable_internal(endpoint.get_primary(), operation_type)

def is_endpoint_unavailable_internal(self, endpoint: str, expected_available_operation: str):
def is_endpoint_unavailable(self, endpoint: str, expected_available_operation: str):
unavailability_info = (
self.location_unavailability_info_by_endpoint[endpoint]
if endpoint in self.location_unavailability_info_by_endpoint
Expand Down Expand Up @@ -420,24 +351,12 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
if read_locations:
(self.account_read_regional_routing_contexts_by_location,
self.account_locations_by_read_endpoints,
self.account_read_locations) = get_endpoints_by_location(
read_locations,
self.account_read_regional_routing_contexts_by_location,
self.default_regional_routing_context,
False,
self.connection_policy.UseMultipleWriteLocations
)
self.account_read_locations) = get_regional_routing_contexts_by_loc(read_locations)

if write_locations:
(self.account_write_regional_routing_contexts_by_location,
self.account_locations_by_write_endpoints,
self.account_write_locations) = get_endpoints_by_location(
write_locations,
self.account_write_regional_routing_contexts_by_location,
self.default_regional_routing_context,
True,
self.connection_policy.UseMultipleWriteLocations
)
self.account_write_locations) = get_regional_routing_contexts_by_loc(write_locations)

# if preferred locations is empty and the default endpoint is a global endpoint,
# we should use the read locations from gateway as effective preferred locations
Expand Down Expand Up @@ -482,7 +401,8 @@ def get_preferred_regional_routing_contexts(
regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \
else None
if regional_endpoint:
if self.is_location_unavailable(regional_endpoint, expected_available_operation):
if self.is_endpoint_unavailable(regional_endpoint.get_primary(),
expected_available_operation):
unavailable_endpoints.append(regional_endpoint)
else:
regional_endpoints.append(regional_endpoint)
Expand Down Expand Up @@ -525,12 +445,10 @@ def can_use_multiple_write_locations_for_request(self, request): # pylint: disa

def endpoints_to_health_check(self) -> Set[str]:
# add read endpoints from gateway and in preferred locations
health_check_endpoints = _get_health_check_endpoints(
self.read_regional_routing_contexts
)
health_check_endpoints = _get_health_check_endpoints(self.read_regional_routing_contexts)
# add first write endpoint in case that the write region is not in preferred locations
health_check_endpoints = health_check_endpoints.union(_get_health_check_endpoints(
self.write_regional_routing_contexts[:1]
health_check_endpoints = health_check_endpoints.union(
_get_health_check_endpoints(self.write_regional_routing_contexts[:1]
))

return health_check_endpoints
Expand Down
1 change: 0 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def __init__(
self.use_preferred_locations: Optional[bool] = None
self.location_index_to_route: Optional[int] = None
self.location_endpoint_to_route: Optional[str] = None
self.last_routed_location_endpoint_within_region: Optional[str] = None
self.excluded_locations: Optional[List[str]] = None
self.excluded_locations_circuit_breaker: List[str] = []
self.healthy_tentative_location: Optional[str] = None
Expand Down
Loading