Skip to content
Merged
21 changes: 21 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import logging
import re
import threading
import time
Expand Down Expand Up @@ -89,6 +90,20 @@
NEVER_DECODE = "NEVER_DECODE"


logger = logging.getLogger(__name__)


def is_debug_log_enabled():
    """Report whether this module's logger would emit DEBUG records.

    Returns:
        bool: True when DEBUG-level messages are enabled for ``logger``.
    """
    debug_level = logging.DEBUG
    return logger.isEnabledFor(debug_level)


def add_debug_log_for_operation_failure(connection: "AbstractConnection"):
    """Emit a DEBUG record describing the connection an operation failed on.

    Args:
        connection: The connection in use when the operation failed. A falsy
            value (e.g. ``None``) is reported as "no connection".
    """
    # Bail out early so the connection details are not computed when DEBUG
    # logging is off, even if the caller forgot to guard the call.
    if not logger.isEnabledFor(logging.DEBUG):
        return

    details = (
        connection.extract_connection_details() if connection else "no connection"
    )
    # Pass values as lazy %-style arguments instead of pre-formatting the
    # message; the rendered text is the same as before.
    logger.debug(
        "Operation failed, with connection: %s, details: %s",
        connection,
        details,
    )


class CaseInsensitiveDict(dict):
"Case insensitive dict implementation. Assumes string keys only."

Expand Down Expand Up @@ -727,6 +742,8 @@ def _execute_command(self, *args, **options):
actual_retry_attempts = [0]

def failure_callback(error, failure_count):
if is_debug_log_enabled():
add_debug_log_for_operation_failure(conn)
Comment thread
cursor[bot] marked this conversation as resolved.
actual_retry_attempts[0] = failure_count
self._close_connection(conn, error, failure_count, start_time, command_name)

Expand Down Expand Up @@ -1709,6 +1726,8 @@ def immediate_execute_command(self, *args, **options):
actual_retry_attempts = [0]

def failure_callback(error, failure_count):
if is_debug_log_enabled():
add_debug_log_for_operation_failure(conn)
actual_retry_attempts[0] = failure_count
self._disconnect_reset_raise_on_watching(
conn, error, failure_count, start_time, command_name
Expand Down Expand Up @@ -1946,6 +1965,8 @@ def execute(self, raise_on_error: bool = True) -> List[Any]:
actual_retry_attempts = [0]

def failure_callback(error, failure_count):
if is_debug_log_enabled():
add_debug_log_for_operation_failure(conn)
actual_retry_attempts[0] = failure_count
self._disconnect_raise_on_watching(
conn, error, failure_count, start_time, operation_name
Expand Down
6 changes: 6 additions & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3951,6 +3951,12 @@ def _reinitialize_on_error(self, error, failure_count):
or type(error) in self.CONNECTION_ERRORS
):
if self._transaction_connection:
if is_debug_log_enabled():
logger.debug(
f"Operation failed, "
f"with connection: {self._transaction_connection}, "
f"details: {self._transaction_connection.extract_connection_details()}",
)
Comment thread
cursor[bot] marked this conversation as resolved.
# Disconnect and release back to pool
self._transaction_connection.disconnect()
node = self._nodes_manager.find_connection_owner(
Expand Down
19 changes: 19 additions & 0 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ def reset_should_reconnect(self):
"""
pass

@abstractmethod
def extract_connection_details(self) -> str:
    """Return a textual description of this connection.

    Declared abstract here; implementations are expected to supply the
    actual description.
    """
    ...


class MaintNotificationsAbstractConnection:
"""
Expand Down Expand Up @@ -1442,6 +1446,18 @@ def socket_connect_timeout(self) -> Optional[Union[float, int]]:
def socket_connect_timeout(self, value: Optional[Union[float, int]]):
self._socket_connect_timeout = value

def extract_connection_details(self) -> str:
    """Describe the current socket of this connection.

    Returns:
        ``"not connected"`` when no socket is set. Otherwise a string with
        the value of ``get_resolved_ip()`` and the local port of the socket;
        the port is reported as ``None`` when it cannot be determined.
    """
    sock = self._sock
    if sock is None:
        return "not connected"

    try:
        local_address = sock.getsockname()
    except (AttributeError, OSError):
        # Best effort only: this feeds a log message, so a socket that
        # cannot report its address must not raise from here.
        local_address = None

    # Only tuple-shaped addresses (host, port, ...) carry a port. Other
    # address forms, such as a Unix-socket path string, have none, and
    # indexing them would produce a bogus value or raise IndexError.
    local_port = None
    if isinstance(local_address, tuple) and len(local_address) > 1:
        local_port = local_address[1]

    return f"connected to ip {self.get_resolved_ip()}, local socket port: {local_port}"


class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"
Expand Down Expand Up @@ -1899,6 +1915,9 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]
reason=CSCReason.INVALIDATION,
)

def extract_connection_details(self) -> str:
    """Return the connection details reported by the wrapped ``_conn``."""
    wrapped = self._conn
    details = wrapped.extract_connection_details()
    return details


class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
Expand Down
22 changes: 18 additions & 4 deletions redis/maint_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,12 @@ def handle_maintenance_completed_notification(self, **kwargs):
or not self.config.is_relaxed_timeouts_enabled()
):
return
add_debug_log_for_notification(self.connection, "MAINTENANCE_COMPLETED")
notification = None
if kwargs.get("notification"):
notification = kwargs["notification"]
add_debug_log_for_notification(
self.connection, notification if notification else "MAINTENANCE_COMPLETED"
)
self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
# Maintenance completed - reset the connection
# timeouts by providing -1 as the relaxed timeout
Expand All @@ -1016,8 +1021,7 @@ def handle_maintenance_completed_notification(self, **kwargs):
# notifications and skipped end maint notifications
self.connection.reset_received_notifications()

if kwargs.get("notification"):
notification = kwargs["notification"]
if notification:
Comment thread
cursor[bot] marked this conversation as resolved.
record_connection_relaxed_timeout(
connection_name=repr(self.connection),
maint_notification=notification.__class__.__name__,
Expand Down Expand Up @@ -1076,7 +1080,8 @@ def handle_oss_maintenance_completed_notification(
# process the same notification twice
return

logger.debug(f"Handling SMIGRATED notification: {notification}")
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Handling SMIGRATED notification: {notification}")
self._in_progress.add(notification)

# Extract the information about the src and destination nodes that are affected
Expand Down Expand Up @@ -1130,7 +1135,16 @@ def handle_oss_maintenance_completed_notification(
# Some of them might be used by sub sub and we don't know which ones - so we disconnect
# all in flight connections after they are done with current command execution
for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
add_debug_log_for_notification(
conn, "SMIGRATED - mark for reconnect"
)
conn.mark_for_reconnect()
else:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
f"SMIGRATED: Node {current_node.name} not affected by maintenance, "
f"skipping mark for reconnect"
)

if (
current_node
Expand Down
46 changes: 42 additions & 4 deletions tests/test_scenario/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,39 @@ def extract_cluster_fqdn(url):
return f"https://{cleaned_hostname}"


def _prepare_ssl_certificates(cert_chain: bool) -> dict:
"""
Prepare SSL certificates for Redis cluster connection.

Args:
cert_chain: PEM-encoded certificate chain containing client cert + intermediate + CA cert.
This is the full certificate chain that will be used to validate the server.

Returns:
dict: SSL configuration kwargs for RedisCluster
"""
certs_config_path = os.environ.get("MTLS_CONFIG_PATH", None)

if not cert_chain:
return {
"ssl_cert_reqs": "none",
"ssl_check_hostname": False,
}

if not certs_config_path:
raise ValueError(
"MTLS enabled test is triggered but MTLS_CONFIG_PATH environment variable not set"
)

# The cert_chain contains the full chain (client cert + intermediate + root CA)
# Use it as CA data for validating the server's certificate
return {
"ssl_cert_reqs": "none",
"ssl_keyfile": os.path.join(certs_config_path, "client.key"),
"ssl_certfile": os.path.join(certs_config_path, "client.crt"),
}


@pytest.fixture()
def client_maint_notifications(endpoints_config):
    """Fixture providing the client built by ``_get_client_maint_notifications``."""
    client = _get_client_maint_notifications(endpoints_config)
    return client
Expand Down Expand Up @@ -307,8 +340,8 @@ def get_cluster_client_maint_notifications(
enable_relaxed_timeout: bool = True,
enable_proactive_reconnect: bool = True,
disable_retries: bool = False,
auth_ssl_client_certs: bool = False,
socket_timeout: Optional[float] = None,
host_config: Optional[str] = None,
):
"""Create Redis cluster client with maintenance notifications enabled."""
# Get credentials from the configuration
Expand Down Expand Up @@ -337,6 +370,13 @@ def get_cluster_client_maint_notifications(
tls_enabled = True if parsed.scheme == "rediss" else False
logging.info(f"TLS enabled: {tls_enabled}")

tls_kwargs = {"ssl": tls_enabled}

if tls_enabled:
# Prepare SSL certificate configuration
ssl_config = _prepare_ssl_certificates(auth_ssl_client_certs)
tls_kwargs.update(ssl_config)

# Configure maintenance notifications
maintenance_config = MaintNotificationsConfig(
enabled=enable_maintenance_notifications,
Expand All @@ -352,12 +392,10 @@ def get_cluster_client_maint_notifications(
socket_timeout=CLIENT_TIMEOUT if socket_timeout is None else socket_timeout,
username=username,
password=password,
ssl=tls_enabled,
ssl_cert_reqs="none",
ssl_check_hostname=False,
protocol=protocol, # RESP3 required for push notifications
maint_notifications_config=maintenance_config,
retry=retry,
**tls_kwargs,
)
logging.info("Redis cluster client created with maintenance notifications enabled")
logging.info(
Expand Down
27 changes: 19 additions & 8 deletions tests/test_scenario/test_maint_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,16 +1239,10 @@ def execute_commands(duration: int, errors: Queue):
assert errors.empty(), f"Errors occurred in threads: {errors.queue}"


# 5 minutes timeout for this test
# @pytest.mark.skipif(
# use_mock_proxy(),
# reason="Mock proxy doesn't support sending notifications to new connections.",
# )


def generate_params(
fault_injector_client: FaultInjectorClient,
effect_names: list[SlotMigrateEffects],
skip_combinations: list[tuple[SlotMigrateEffects, str]] = [],
):
# params should produce list of tuples: (effect_name, trigger_name, bdb_config, bdb_name)
params = []
Expand All @@ -1261,6 +1255,8 @@ def generate_params(

for trigger_info in triggers_data["triggers"]:
trigger = trigger_info["name"]
if (effect_name, trigger) in skip_combinations:
continue
if trigger == "maintenance_mode":
continue
trigger_requirements = trigger_info["requirements"]
Expand Down Expand Up @@ -1332,11 +1328,23 @@ def setup_env(
self._bdb_name = db_config["name"]
socket_timeout = DEFAULT_OSS_API_CLIENT_SOCKET_TIMEOUT

auth_ssl_client_certs_config_info = db_config.get(
"authentication_ssl_client_certs", None
)

auth_ssl_client_certs = (
True
if auth_ssl_client_certs_config_info
and auth_ssl_client_certs_config_info[0]["client_cert"] is not None
else False
)
Comment thread
petyaslavova marked this conversation as resolved.

cluster_client_maint_notifications = get_cluster_client_maint_notifications(
endpoints_config=cluster_endpoint_config,
disable_retries=True,
socket_timeout=socket_timeout,
enable_maintenance_notifications=True,
auth_ssl_client_certs=auth_ssl_client_certs,
)
return cluster_client_maint_notifications, cluster_endpoint_config

Expand Down Expand Up @@ -1741,9 +1749,12 @@ def test_notification_handling_with_node_remove(
SlotMigrateEffects.REMOVE,
SlotMigrateEffects.ADD,
],
skip_combinations=[
(SlotMigrateEffects.SLOT_SHUFFLE, "failover"),
], # maintenance ends too fast for the test to be reliable
),
)
def test_new_connections_receive_last_notification_with_migrating(
def test_new_connections_receive_last_smigrating_smigrated_notification(
self,
fault_injector_client_oss_api: FaultInjectorClient,
effect_name: SlotMigrateEffects,
Expand Down