Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from .auth.token import TokenInterface
from .backoff import NoBackoff
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent
from .event import AfterConnectionReleasedEvent, EventDispatcher, OnErrorEvent, OnMaintenanceNotificationEvent, \
Comment thread
vladvildanov marked this conversation as resolved.
AfterConnectionCreatedEvent
from .exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
Expand All @@ -53,6 +54,8 @@
MaintNotificationsConnectionHandler,
MaintNotificationsPoolHandler, MaintenanceNotification,
)
from .observability.attributes import AttributeBuilder, DB_CLIENT_CONNECTION_STATE, ConnectionState, \
DB_CLIENT_CONNECTION_POOL_NAME
from .retry import Retry
from .utils import (
CRYPTOGRAPHY_AVAILABLE,
Expand Down Expand Up @@ -2060,6 +2063,13 @@ def set_retry(self, retry: Retry):
def re_auth_callback(self, token: TokenInterface):
pass

@abstractmethod
def get_connection_count(self) -> list[tuple[int, dict]]:
Comment thread
vladvildanov marked this conversation as resolved.
"""
Returns a connection count (both idle and in use).
"""
pass


class MaintNotificationsAbstractConnectionPool:
"""
Expand Down Expand Up @@ -2088,8 +2098,12 @@ def __init__(
"Maintenance notifications handlers on connection are only supported with RESP version 3"
)

self._event_dispatcher = kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()

self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
self, maint_notifications_config
self, maint_notifications_config, self._event_dispatcher
)

self._update_connection_kwargs_for_maint_notifications(
Expand Down Expand Up @@ -2157,7 +2171,7 @@ def update_maint_notifications_config(
# first update pool settings
if not self._maint_notifications_pool_handler:
self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
self, maint_notifications_config
self, maint_notifications_config, self._event_dispatcher
)
else:
self._maint_notifications_pool_handler.config = maint_notifications_config
Expand Down Expand Up @@ -2635,11 +2649,17 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
"Get a connection from the pool"

self._checkpid()
is_created = False

with self._lock:
try:
connection = self._available_connections.pop()
except IndexError:
# Start timing for observability
start_time = time.monotonic()

connection = self.make_connection()
is_created = True
self._in_use_connections.add(connection)

try:
Expand All @@ -2666,6 +2686,14 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
# leak it
self.release(connection)
raise

if is_created:
self._event_dispatcher.dispatch(
AfterConnectionCreatedEvent(
connection_pool=self,
duration_seconds=time.monotonic() - start_time,
)
)
return connection

def get_encoder(self) -> Encoder:
Expand Down Expand Up @@ -2785,6 +2813,20 @@ async def _mock(self, error: RedisError):
"""
pass

def get_connection_count(self) -> List[tuple[int, dict]]:
attributes = AttributeBuilder.build_base_attributes()
attributes[DB_CLIENT_CONNECTION_POOL_NAME] = repr(self)
free_connections_attributes = attributes.copy()
in_use_connections_attributes = attributes.copy()

free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ConnectionState.IDLE.value
in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ConnectionState.USED.value

return [
(len(self._get_free_connections()), free_connections_attributes),
(len(self._get_in_use_connections()), in_use_connections_attributes),
]


class BlockingConnectionPool(ConnectionPool):
"""
Expand Down Expand Up @@ -2917,6 +2959,7 @@ def get_connection(self, command_name=None, *keys, **options):
"""
# Make sure we haven't changed process.
self._checkpid()
is_created = False

# Try and get a connection from the pool. If one isn't available within
# self.timeout then raise a ``ConnectionError``.
Expand All @@ -2935,7 +2978,10 @@ def get_connection(self, command_name=None, *keys, **options):
# If the ``connection`` is actually ``None`` then that's a cue to make
# a new connection to add to the pool.
if connection is None:
# Start timing for observability
start_time = time.monotonic()
connection = self.make_connection()
is_created = True
finally:
if self._locked:
try:
Expand Down Expand Up @@ -2964,6 +3010,14 @@ def get_connection(self, command_name=None, *keys, **options):
self.release(connection)
raise

if is_created:
self._event_dispatcher.dispatch(
AfterConnectionCreatedEvent(
connection_pool=self,
duration_seconds=time.monotonic() - start_time,
)
)

return connection

def release(self, connection):
Expand Down
81 changes: 77 additions & 4 deletions redis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Type, Union
from typing import Dict, List, Optional, Type, Union, Callable
Comment thread
vladvildanov marked this conversation as resolved.

from redis.auth.token import TokenInterface
from redis.credentials import CredentialProvider, StreamingCredentialProvider
from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count
from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count, \
record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff


class EventListenerInterface(ABC):
Expand Down Expand Up @@ -85,7 +86,8 @@ def __init__(
ReAuthConnectionListener(),
],
AfterPooledConnectionsInstantiationEvent: [
RegisterReAuthForPooledConnections()
RegisterReAuthForPooledConnections(),
InitializeConnectionCountObservability()
],
AfterSingleConnectionInstantiationEvent: [
RegisterReAuthForSingleConnection()
Expand All @@ -97,6 +99,16 @@ def __init__(
AsyncReAuthConnectionListener(),
],
OnErrorEvent: [ExportErrorCountMetric()],
OnMaintenanceNotificationEvent: [
ExportMaintenanceNotificationCountMetric(),
],
AfterConnectionCreatedEvent: [ExportConnectionCreateTimeMetric()],
AfterConnectionTimeoutUpdatedEvent: [
ExportConnectionRelaxedTimeoutMetric(),
],
AfterConnectionHandoffEvent: [
ExportConnectionHandoffMetric(),
],
}

self._lock = threading.Lock()
Expand Down Expand Up @@ -333,6 +345,30 @@ class OnMaintenanceNotificationEvent:
notification: "MaintenanceNotification"
connection: "MaintNotificationsAbstractConnection"

@dataclass
class AfterConnectionCreatedEvent:
"""
Event fired after connection is created in pool.
"""
connection_pool: "ConnectionPoolInterface"
duration_seconds: float

@dataclass
class AfterConnectionTimeoutUpdatedEvent:
"""
Event fired after connection timeout is updated.
"""
connection: "MaintNotificationsAbstractConnection"
notification: "MaintenanceNotification"
relaxed: bool

@dataclass
class AfterConnectionHandoffEvent:
"""
Event fired after connection is handed off.
"""
connection_pool: "ConnectionPoolInterface"

class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
pass

Expand Down Expand Up @@ -547,4 +583,41 @@ def listen(self, event: OnMaintenanceNotificationEvent):
network_peer_address=event.connection.host,
network_peer_port=event.connection.port,
maint_notification=repr(event.notification),
)
)

class ExportConnectionCreateTimeMetric(EventListenerInterface):
"""
Listener that exports connection create time metric.
"""
def listen(self, event: AfterConnectionCreatedEvent):
record_connection_create_time(
connection_pool=event.connection_pool,
duration_seconds=event.duration_seconds,
)

class InitializeConnectionCountObservability(EventListenerInterface):
"""
Listener that initializes connection count observability.
"""
def listen(self, event: AfterPooledConnectionsInstantiationEvent):
init_connection_count(event.connection_pools)

class ExportConnectionRelaxedTimeoutMetric(EventListenerInterface):
"""
Listener that exports connection relaxed timeout metric.
"""
def listen(self, event: AfterConnectionTimeoutUpdatedEvent):
record_connection_relaxed_timeout(
connection_name=repr(event.connection),
maint_notification=repr(event.notification),
relaxed=event.relaxed,
)

class ExportConnectionHandoffMetric(EventListenerInterface):
"""
Listener that exports connection handoff metric.
"""
def listen(self, event: AfterConnectionHandoffEvent):
record_connection_handoff(
pool_name=repr(event.connection_pool),
)
43 changes: 37 additions & 6 deletions redis/maint_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Optional, Union

from redis.event import OnMaintenanceNotificationEvent
from redis.event import OnMaintenanceNotificationEvent, EventDispatcherInterface, EventDispatcher, \
AfterConnectionTimeoutUpdatedEvent, AfterConnectionHandoffEvent
from redis.typing import Number


Expand Down Expand Up @@ -560,21 +561,27 @@ def __init__(
self,
pool: "MaintNotificationsAbstractConnectionPool",
config: MaintNotificationsConfig,
event_dispatcher: Optional[EventDispatcherInterface] = None,
Comment thread
vladvildanov marked this conversation as resolved.
) -> None:
self.pool = pool
self.config = config
self._processed_notifications = set()
self._lock = threading.RLock()
self.connection = None

if event_dispatcher is not None:
self.event_dispatcher = event_dispatcher
else:
self.event_dispatcher = EventDispatcher()

def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
self.connection = connection

def get_handler_for_connection(self):
# Copy all data that should be shared between connections
# but each connection should have its own pool handler
# since each connection can be in a different state
copy = MaintNotificationsPoolHandler(self.pool, self.config)
copy = MaintNotificationsPoolHandler(self.pool, self.config, self.event_dispatcher)
copy._processed_notifications = self._processed_notifications
copy._lock = self._lock
copy.connection = None
Expand Down Expand Up @@ -683,6 +690,12 @@ def handle_node_moving_notification(self, notification: NodeMovingNotification):
args=(notification,),
).start()

self.event_dispatcher.dispatch(
AfterConnectionHandoffEvent(
connection_pool=self.pool,
)
)

self._processed_notifications.add(notification)

def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
Expand Down Expand Up @@ -784,12 +797,12 @@ def handle_notification(self, notification: MaintenanceNotification):
return

if notification_type:
self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)
self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE, notification=notification)
else:
self.handle_maintenance_completed_notification()
self.handle_maintenance_completed_notification(notification=notification)

def handle_maintenance_start_notification(
self, maintenance_state: MaintenanceState
self, maintenance_state: MaintenanceState, **kwargs
):
if (
self.connection.maintenance_state == MaintenanceState.MOVING
Expand All @@ -804,7 +817,16 @@ def handle_maintenance_start_notification(
# extend the timeout for all created connections
self.connection.update_current_socket_timeout(self.config.relaxed_timeout)

def handle_maintenance_completed_notification(self):
if kwargs.get('notification'):
self.connection.event_dispatcher.dispatch(
AfterConnectionTimeoutUpdatedEvent(
connection=self.connection,
notification=kwargs.get('notification'),
relaxed=True,
)
)

def handle_maintenance_completed_notification(self, **kwargs):
# Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
if (
self.connection.maintenance_state == MaintenanceState.MOVING
Expand All @@ -816,3 +838,12 @@ def handle_maintenance_completed_notification(self):
# timeouts by providing -1 as the relaxed timeout
self.connection.update_current_socket_timeout(-1)
self.connection.maintenance_state = MaintenanceState.NONE

if kwargs.get('notification'):
self.connection.event_dispatcher.dispatch(
AfterConnectionTimeoutUpdatedEvent(
connection=self.connection,
notification=kwargs.get('notification'),
relaxed=False,
)
)
Loading