Skip to content

Commit 0438784

Browse files
vladvildanovpetyaslavova
authored andcommitted
Added initial health check policies, refactored add_database method (#3906)
* Added initial health check policies, refactored add_database method * Codestyle fixes
1 parent 76befb4 commit 0438784

11 files changed

Lines changed: 812 additions & 155 deletions

File tree

redis/asyncio/multidb/client.py

Lines changed: 94 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
11
import asyncio
22
import logging
3-
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union
3+
from typing import Any, Awaitable, Callable, List, Optional, Union
44

55
from redis.asyncio.client import PubSubHandler
66
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
7-
from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
8-
from redis.asyncio.multidb.database import AsyncDatabase, Databases
7+
from redis.asyncio.multidb.config import (
8+
DEFAULT_GRACE_PERIOD,
9+
DatabaseConfig,
10+
InitialHealthCheck,
11+
MultiDbConfig,
12+
)
13+
from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases
914
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
1015
from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy
16+
from redis.asyncio.retry import Retry
1117
from redis.background import BackgroundScheduler
18+
from redis.backoff import NoBackoff
1219
from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands
1320
from redis.multidb.circuit import CircuitBreaker
1421
from redis.multidb.circuit import State as CBState
15-
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
22+
from redis.multidb.exception import (
23+
InitialHealthCheckFailedError,
24+
NoValidDatabaseException,
25+
UnhealthyDatabaseException,
26+
)
1627
from redis.typing import ChannelT, EncodableT, KeyT
1728
from redis.utils import experimental
1829

@@ -90,11 +101,8 @@ async def initialize(self):
90101
Perform initialization of databases to define their initial state.
91102
"""
92103

93-
async def raise_exception_on_failed_hc(error):
94-
raise error
95-
96104
# Initial databases check to define initial state
97-
await self._check_databases_health(on_error=raise_exception_on_failed_hc)
105+
await self._perform_initial_health_check()
98106

99107
# Starts recurring health checks on the background.
100108
self._recurring_hc_task = asyncio.create_task(
@@ -153,15 +161,44 @@ async def set_active_database(self, database: AsyncDatabase) -> None:
153161
"Cannot set active database, database is unhealthy"
154162
)
155163

156-
async def add_database(self, database: AsyncDatabase):
164+
async def add_database(self, config: DatabaseConfig, skip_unhealthy: bool = True):
157165
"""
158166
Adds a new database to the database list.
159167
"""
160-
for existing_db, _ in self._databases:
161-
if existing_db == database:
162-
raise ValueError("Given database already exists")
168+
# The retry object is not used in the lower level clients, so we can safely remove it.
169+
# We rely on command_retry in terms of global retries.
170+
config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())})
163171

164-
await self._check_db_health(database)
172+
if config.from_url:
173+
client = self._config.client_class.from_url(
174+
config.from_url, **config.client_kwargs
175+
)
176+
elif config.from_pool:
177+
config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff()))
178+
client = self._config.client_class.from_pool(
179+
connection_pool=config.from_pool
180+
)
181+
else:
182+
client = self._config.client_class(**config.client_kwargs)
183+
184+
circuit = (
185+
config.default_circuit_breaker()
186+
if config.circuit is None
187+
else config.circuit
188+
)
189+
190+
database = Database(
191+
client=client,
192+
circuit=circuit,
193+
weight=config.weight,
194+
health_check_url=config.health_check_url,
195+
)
196+
197+
try:
198+
await self._check_db_health(database)
199+
except UnhealthyDatabaseException:
200+
if not skip_unhealthy:
201+
raise
165202

166203
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
167204
self._databases.add(database, database.weight)
@@ -269,32 +306,35 @@ async def pubsub(self, **kwargs):
269306

270307
return PubSub(self, **kwargs)
271308

272-
async def _check_databases_health(
273-
self,
274-
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
275-
):
309+
async def _check_databases_health(self) -> dict[Database, bool]:
276310
"""
277311
Runs health checks as a recurring task.
278312
Runs health checks against all databases.
279313
"""
280314
try:
281-
self._hc_tasks = [
282-
asyncio.create_task(self._check_db_health(database))
283-
for database, _ in self._databases
284-
]
315+
task_to_db: dict[asyncio.Task, Database] = {}
316+
317+
self._hc_tasks = []
318+
for database, _ in self._databases:
319+
task = asyncio.create_task(self._check_db_health(database))
320+
task_to_db[task] = database
321+
self._hc_tasks.append(task)
322+
285323
results = await asyncio.wait_for(
286-
asyncio.gather(
287-
*self._hc_tasks,
288-
return_exceptions=True,
289-
),
324+
asyncio.gather(*self._hc_tasks, return_exceptions=True),
290325
timeout=self._health_check_interval,
291326
)
292327
except asyncio.TimeoutError:
293328
raise asyncio.TimeoutError(
294329
"Health check execution exceeds health_check_interval"
295330
)
296331

297-
for result in results:
332+
# Map end results to databases
333+
db_results = {
334+
task_to_db[task]: result for task, result in zip(self._hc_tasks, results)
335+
}
336+
337+
for database, result in db_results.items():
298338
if isinstance(result, UnhealthyDatabaseException):
299339
unhealthy_db = result.database
300340
unhealthy_db.circuit.state = CBState.OPEN
@@ -304,8 +344,33 @@ async def _check_databases_health(
304344
exc_info=result.original_exception,
305345
)
306346

307-
if on_error:
308-
on_error(result.original_exception)
347+
db_results[unhealthy_db] = False
348+
elif isinstance(result, Exception):
349+
db_results[database] = False
350+
351+
return db_results
352+
353+
async def _perform_initial_health_check(self):
354+
"""
355+
Runs initial health check and evaluate healthiness based on initial_health_check_policy.
356+
"""
357+
results = await self._check_databases_health()
358+
is_healthy = True
359+
360+
if self._config.initial_health_check_policy == InitialHealthCheck.ALL_HEALTHY:
361+
is_healthy = False not in results.values()
362+
elif (
363+
self._config.initial_health_check_policy
364+
== InitialHealthCheck.MAJORITY_HEALTHY
365+
):
366+
is_healthy = sum(results.values()) > len(results) / 2
367+
elif self._config.initial_health_check_policy == InitialHealthCheck.ANY_HEALTHY:
368+
is_healthy = True in results.values()
369+
370+
if not is_healthy:
371+
raise InitialHealthCheckFailedError(
372+
f"Initial health check failed. Initial health check policy: {self._config.initial_health_check_policy}"
373+
)
309374

310375
async def _check_db_health(self, database: AsyncDatabase) -> bool:
311376
"""
@@ -337,7 +402,7 @@ def _on_circuit_state_change_callback(
337402
return
338403

339404
if old_state == CBState.CLOSED and new_state == CBState.OPEN:
340-
logger.error(
405+
logger.warning(
341406
f"Database {circuit.database} is unreachable. Failover has been initiated."
342407
)
343408
loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)

redis/asyncio/multidb/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass, field
2+
from enum import Enum
23
from typing import List, Optional, Type, Union
34

45
import pybreaker
@@ -43,6 +44,12 @@
4344
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
4445

4546

47+
class InitialHealthCheck(Enum):
48+
ALL_HEALTHY = "all_healthy"
49+
MAJORITY_HEALTHY = "majority_healthy"
50+
ANY_HEALTHY = "any_healthy"
51+
52+
4653
def default_event_dispatcher() -> EventDispatcherInterface:
4754
return EventDispatcher()
4855

@@ -108,6 +115,8 @@ class MultiDbConfig:
108115
failover_delay: Delay between failover attempts.
109116
auto_fallback_interval: Time interval to trigger automatic fallback.
110117
event_dispatcher: Interface for dispatching events related to database operations.
118+
initial_health_check_policy: Defines the policy used to determine whether the databases setup is
119+
healthy during the initial health check.
111120
112121
Methods:
113122
databases:
@@ -148,6 +157,7 @@ class MultiDbConfig:
148157
event_dispatcher: EventDispatcherInterface = field(
149158
default_factory=default_event_dispatcher
150159
)
160+
initial_health_check_policy: InitialHealthCheck = InitialHealthCheck.ALL_HEALTHY
151161

152162
def databases(self) -> Databases:
153163
databases = WeightedList()

redis/multidb/client.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,27 @@
55
from typing import Any, Callable, List, Optional
66

77
from redis.background import BackgroundScheduler
8+
from redis.backoff import NoBackoff
89
from redis.client import PubSubWorkerThread
910
from redis.commands import CoreCommands, RedisModuleCommands
1011
from redis.multidb.circuit import CircuitBreaker
1112
from redis.multidb.circuit import State as CBState
1213
from redis.multidb.command_executor import DefaultCommandExecutor
13-
from redis.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
14+
from redis.multidb.config import (
15+
DEFAULT_GRACE_PERIOD,
16+
DatabaseConfig,
17+
InitialHealthCheck,
18+
MultiDbConfig,
19+
)
1420
from redis.multidb.database import Database, Databases, SyncDatabase
15-
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
21+
from redis.multidb.exception import (
22+
InitialHealthCheckFailedError,
23+
NoValidDatabaseException,
24+
UnhealthyDatabaseException,
25+
)
1626
from redis.multidb.failure_detector import FailureDetector
1727
from redis.multidb.healthcheck import HealthCheck, HealthCheckPolicy
28+
from redis.retry import Retry
1829
from redis.utils import experimental
1930

2031
logger = logging.getLogger(__name__)
@@ -74,11 +85,8 @@ def initialize(self):
7485
Perform initialization of databases to define their initial state.
7586
"""
7687

77-
def raise_exception_on_failed_hc(error):
78-
raise error
79-
8088
# Initial databases check to define initial state
81-
self._check_databases_health(on_error=raise_exception_on_failed_hc)
89+
self._perform_initial_health_check()
8290

8391
# Starts recurring health checks on the background.
8492
self._bg_scheduler.run_recurring(
@@ -135,15 +143,44 @@ def set_active_database(self, database: SyncDatabase) -> None:
135143
"Cannot set active database, database is unhealthy"
136144
)
137145

138-
def add_database(self, database: SyncDatabase):
146+
def add_database(self, config: DatabaseConfig, skip_unhealthy: bool = True):
139147
"""
140148
Adds a new database to the database list.
141149
"""
142-
for existing_db, _ in self._databases:
143-
if existing_db == database:
144-
raise ValueError("Given database already exists")
150+
# The retry object is not used in the lower level clients, so we can safely remove it.
151+
# We rely on command_retry in terms of global retries.
152+
config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())})
145153

146-
self._check_db_health(database)
154+
if config.from_url:
155+
client = self._config.client_class.from_url(
156+
config.from_url, **config.client_kwargs
157+
)
158+
elif config.from_pool:
159+
config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff()))
160+
client = self._config.client_class.from_pool(
161+
connection_pool=config.from_pool
162+
)
163+
else:
164+
client = self._config.client_class(**config.client_kwargs)
165+
166+
circuit = (
167+
config.default_circuit_breaker()
168+
if config.circuit is None
169+
else config.circuit
170+
)
171+
172+
database = Database(
173+
client=client,
174+
circuit=circuit,
175+
weight=config.weight,
176+
health_check_url=config.health_check_url,
177+
)
178+
179+
try:
180+
self._check_db_health(database)
181+
except UnhealthyDatabaseException:
182+
if not skip_unhealthy:
183+
raise
147184

148185
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
149186
self._databases.add(database, database.weight)
@@ -254,24 +291,27 @@ def _check_db_health(self, database: SyncDatabase) -> bool:
254291

255292
return is_healthy
256293

257-
def _check_databases_health(self, on_error: Callable[[Exception], None] = None):
294+
def _check_databases_health(self) -> dict[Database, bool]:
258295
"""
259296
Runs health checks as a recurring task.
260297
Runs health checks against all databases.
261298
"""
262299
with ThreadPoolExecutor(max_workers=len(self._databases)) as executor:
263300
# Submit all health checks
264301
futures = {
265-
executor.submit(self._check_db_health, database)
302+
executor.submit(self._check_db_health, database): database
266303
for database, _ in self._databases
267304
}
268305

306+
results = {}
307+
269308
try:
270309
for future in as_completed(
271310
futures, timeout=self._health_check_interval
272311
):
273312
try:
274-
future.result()
313+
database = futures[future]
314+
results[database] = future.result()
275315
except UnhealthyDatabaseException as e:
276316
unhealthy_db = e.database
277317
unhealthy_db.circuit.state = CBState.OPEN
@@ -281,12 +321,34 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None):
281321
exc_info=e.original_exception,
282322
)
283323

284-
if on_error:
285-
on_error(e.original_exception)
324+
results[unhealthy_db] = False
286325
except TimeoutError:
287326
raise TimeoutError(
288327
"Health check execution exceeds health_check_interval"
289328
)
329+
return results
330+
331+
def _perform_initial_health_check(self):
332+
"""
333+
Runs initial health check and evaluate healthiness based on initial_health_check_policy.
334+
"""
335+
results = self._check_databases_health()
336+
is_healthy = True
337+
338+
if self._config.initial_health_check_policy == InitialHealthCheck.ALL_HEALTHY:
339+
is_healthy = False not in results.values()
340+
elif (
341+
self._config.initial_health_check_policy
342+
== InitialHealthCheck.MAJORITY_HEALTHY
343+
):
344+
is_healthy = sum(results.values()) > len(results) / 2
345+
elif self._config.initial_health_check_policy == InitialHealthCheck.ANY_HEALTHY:
346+
is_healthy = True in results.values()
347+
348+
if not is_healthy:
349+
raise InitialHealthCheckFailedError(
350+
f"Initial health check failed. Initial health check policy: {self._config.initial_health_check_policy}"
351+
)
290352

291353
def _on_circuit_state_change_callback(
292354
self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
@@ -296,7 +358,7 @@ def _on_circuit_state_change_callback(
296358
return
297359

298360
if old_state == CBState.CLOSED and new_state == CBState.OPEN:
299-
logger.error(
361+
logger.warning(
300362
f"Database {circuit.database} is unreachable. Failover has been initiated."
301363
)
302364

0 commit comments

Comments
 (0)