11import asyncio
22import logging
3- from typing import Any , Awaitable , Callable , Coroutine , List , Optional , Union
3+ from typing import Any , Awaitable , Callable , List , Optional , Union
44
55from redis .asyncio .client import PubSubHandler
66from redis .asyncio .multidb .command_executor import DefaultCommandExecutor
7- from redis .asyncio .multidb .config import DEFAULT_GRACE_PERIOD , MultiDbConfig
8- from redis .asyncio .multidb .database import AsyncDatabase , Databases
7+ from redis .asyncio .multidb .config import (
8+ DEFAULT_GRACE_PERIOD ,
9+ DatabaseConfig ,
10+ InitialHealthCheck ,
11+ MultiDbConfig ,
12+ )
13+ from redis .asyncio .multidb .database import AsyncDatabase , Database , Databases
914from redis .asyncio .multidb .failure_detector import AsyncFailureDetector
1015from redis .asyncio .multidb .healthcheck import HealthCheck , HealthCheckPolicy
16+ from redis .asyncio .retry import Retry
1117from redis .background import BackgroundScheduler
18+ from redis .backoff import NoBackoff
1219from redis .commands import AsyncCoreCommands , AsyncRedisModuleCommands
1320from redis .multidb .circuit import CircuitBreaker
1421from redis .multidb .circuit import State as CBState
15- from redis .multidb .exception import NoValidDatabaseException , UnhealthyDatabaseException
22+ from redis .multidb .exception import (
23+ InitialHealthCheckFailedError ,
24+ NoValidDatabaseException ,
25+ UnhealthyDatabaseException ,
26+ )
1627from redis .typing import ChannelT , EncodableT , KeyT
1728from redis .utils import experimental
1829
@@ -90,11 +101,8 @@ async def initialize(self):
90101 Perform initialization of databases to define their initial state.
91102 """
92103
93- async def raise_exception_on_failed_hc (error ):
94- raise error
95-
96104 # Initial databases check to define initial state
97- await self ._check_databases_health ( on_error = raise_exception_on_failed_hc )
105+ await self ._perform_initial_health_check ( )
98106
99107 # Starts recurring health checks on the background.
100108 self ._recurring_hc_task = asyncio .create_task (
@@ -153,15 +161,44 @@ async def set_active_database(self, database: AsyncDatabase) -> None:
153161 "Cannot set active database, database is unhealthy"
154162 )
155163
156- async def add_database (self , database : AsyncDatabase ):
164+ async def add_database (self , config : DatabaseConfig , skip_unhealthy : bool = True ):
157165 """
158166 Adds a new database to the database list.
159167 """
160- for existing_db , _ in self . _databases :
161- if existing_db == database :
162- raise ValueError ( "Given database already exists" )
168+ # The retry object is not used in the lower level clients, so we can safely remove it.
169+ # We rely on command_retry in terms of global retries.
170+ config . client_kwargs . update ({ "retry" : Retry ( retries = 0 , backoff = NoBackoff ())} )
163171
164- await self ._check_db_health (database )
172+ if config .from_url :
173+ client = self ._config .client_class .from_url (
174+ config .from_url , ** config .client_kwargs
175+ )
176+ elif config .from_pool :
177+ config .from_pool .set_retry (Retry (retries = 0 , backoff = NoBackoff ()))
178+ client = self ._config .client_class .from_pool (
179+ connection_pool = config .from_pool
180+ )
181+ else :
182+ client = self ._config .client_class (** config .client_kwargs )
183+
184+ circuit = (
185+ config .default_circuit_breaker ()
186+ if config .circuit is None
187+ else config .circuit
188+ )
189+
190+ database = Database (
191+ client = client ,
192+ circuit = circuit ,
193+ weight = config .weight ,
194+ health_check_url = config .health_check_url ,
195+ )
196+
197+ try :
198+ await self ._check_db_health (database )
199+ except UnhealthyDatabaseException :
200+ if not skip_unhealthy :
201+ raise
165202
166203 highest_weighted_db , highest_weight = self ._databases .get_top_n (1 )[0 ]
167204 self ._databases .add (database , database .weight )
@@ -269,32 +306,35 @@ async def pubsub(self, **kwargs):
269306
270307 return PubSub (self , ** kwargs )
271308
272- async def _check_databases_health (
273- self ,
274- on_error : Optional [Callable [[Exception ], Coroutine [Any , Any , None ]]] = None ,
275- ):
309+ async def _check_databases_health (self ) -> dict [Database , bool ]:
276310 """
277311 Runs health checks as a recurring task.
278312 Runs health checks against all databases.
279313 """
280314 try :
281- self ._hc_tasks = [
282- asyncio .create_task (self ._check_db_health (database ))
283- for database , _ in self ._databases
284- ]
315+ task_to_db : dict [asyncio .Task , Database ] = {}
316+
317+ self ._hc_tasks = []
318+ for database , _ in self ._databases :
319+ task = asyncio .create_task (self ._check_db_health (database ))
320+ task_to_db [task ] = database
321+ self ._hc_tasks .append (task )
322+
285323 results = await asyncio .wait_for (
286- asyncio .gather (
287- * self ._hc_tasks ,
288- return_exceptions = True ,
289- ),
324+ asyncio .gather (* self ._hc_tasks , return_exceptions = True ),
290325 timeout = self ._health_check_interval ,
291326 )
292327 except asyncio .TimeoutError :
293328 raise asyncio .TimeoutError (
294329 "Health check execution exceeds health_check_interval"
295330 )
296331
297- for result in results :
332+ # Map end results to databases
333+ db_results = {
334+ task_to_db [task ]: result for task , result in zip (self ._hc_tasks , results )
335+ }
336+
337+ for database , result in db_results .items ():
298338 if isinstance (result , UnhealthyDatabaseException ):
299339 unhealthy_db = result .database
300340 unhealthy_db .circuit .state = CBState .OPEN
@@ -304,8 +344,33 @@ async def _check_databases_health(
304344 exc_info = result .original_exception ,
305345 )
306346
307- if on_error :
308- on_error (result .original_exception )
347+ db_results [unhealthy_db ] = False
348+ elif isinstance (result , Exception ):
349+ db_results [database ] = False
350+
351+ return db_results
352+
async def _perform_initial_health_check(self):
    """
    Runs the initial health check and evaluates overall healthiness
    against the configured initial_health_check_policy.

    The per-database results returned by ``_check_databases_health()`` are
    interpreted as follows:

    - ``ALL_HEALTHY``: every checked database must be healthy.
    - ``MAJORITY_HEALTHY``: strictly more than half must be healthy.
    - ``ANY_HEALTHY``: at least one database must be healthy.

    Any other policy value is treated as satisfied.

    Raises:
        InitialHealthCheckFailedError: if the configured policy is not met.
    """
    results = await self._check_databases_health()
    policy = self._config.initial_health_check_policy

    # Count explicitly instead of relying on bool arithmetic: the health
    # checks are gathered with return_exceptions=True, so a value may be an
    # exception object (e.g. asyncio.CancelledError, which is not an
    # ``Exception`` subclass) rather than a bool. Such entries must never be
    # counted as healthy, and must not break the majority computation.
    total = len(results)
    healthy = sum(
        1
        for outcome in results.values()
        if outcome and not isinstance(outcome, BaseException)
    )

    if policy == InitialHealthCheck.ALL_HEALTHY:
        is_healthy = healthy == total
    elif policy == InitialHealthCheck.MAJORITY_HEALTHY:
        # Integer form of ``healthy > total / 2``.
        is_healthy = healthy * 2 > total
    elif policy == InitialHealthCheck.ANY_HEALTHY:
        is_healthy = healthy > 0
    else:
        # Unknown policies impose no requirement.
        is_healthy = True

    if not is_healthy:
        raise InitialHealthCheckFailedError(
            f"Initial health check failed. Initial health check policy: {policy}"
        )
309374
310375 async def _check_db_health (self , database : AsyncDatabase ) -> bool :
311376 """
@@ -337,7 +402,7 @@ def _on_circuit_state_change_callback(
337402 return
338403
339404 if old_state == CBState .CLOSED and new_state == CBState .OPEN :
340- logger .error (
405+ logger .warning (
341406 f"Database { circuit .database } is unreachable. Failover has been initiated."
342407 )
343408 loop .call_later (DEFAULT_GRACE_PERIOD , _half_open_circuit , circuit )
0 commit comments