|
1 | 1 | import asyncio |
| 2 | +import inspect |
2 | 3 | import logging |
3 | 4 | from abc import ABC, abstractmethod |
4 | 5 | from enum import Enum |
5 | | -from typing import List, Optional, Tuple, Union |
| 6 | +from typing import List, Optional, Tuple, Type, Union |
6 | 7 |
|
7 | 8 | from redis.asyncio import Redis as AsyncRedis |
8 | 9 | from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster |
|
17 | 18 | # Type alias for async Redis clients (standalone or cluster) |
18 | 19 | AsyncRedisClientT = Union[AsyncRedis, AsyncRedisCluster] |
19 | 20 |
|
| 21 | + |
| 22 | +def _get_init_params(cls: Type) -> frozenset: |
| 23 | + """Extract parameter names from a class's __init__ method.""" |
| 24 | + sig = inspect.signature(cls.__init__) |
| 25 | + return frozenset( |
| 26 | + name |
| 27 | + for name, param in sig.parameters.items() |
| 28 | + if name != "self" |
| 29 | + and param.kind |
| 30 | + in ( |
| 31 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 32 | + inspect.Parameter.KEYWORD_ONLY, |
| 33 | + ) |
| 34 | + ) |
| 35 | + |
| 36 | + |
| 37 | +def _filter_kwargs(kwargs: dict, cls: Type) -> dict: |
| 38 | + """Filter kwargs to only include parameters accepted by the class's __init__.""" |
| 39 | + allowed = _get_init_params(cls) |
| 40 | + return {k: v for k, v in kwargs.items() if k in allowed} |
| 41 | + |
| 42 | + |
20 | 43 | DEFAULT_HEALTH_CHECK_PROBES = 3 |
21 | 44 | DEFAULT_HEALTH_CHECK_INTERVAL = 5 |
22 | 45 | DEFAULT_HEALTH_CHECK_TIMEOUT = 3 |
@@ -152,19 +175,21 @@ async def get_client(self, database) -> AsyncRedisClientT: |
152 | 175 | # Check for both sync and async standalone Redis clients |
153 | 176 | if isinstance(database.client, (AsyncRedis, SyncRedis)): |
154 | 177 | conn_kwargs = database.client.get_connection_kwargs() |
155 | | - client = AsyncRedis(**conn_kwargs) |
| 178 | + filtered_kwargs = _filter_kwargs(conn_kwargs, AsyncRedis) |
| 179 | + client = AsyncRedis(**filtered_kwargs) |
156 | 180 | elif isinstance(database.client, (AsyncRedisCluster, SyncRedisCluster)): |
157 | 181 | # Cluster client - create a single cluster client that handles |
158 | 182 | # topology changes internally |
159 | | - conn_kwargs = database.client.connection_kwargs.copy() |
| 183 | + conn_kwargs = database.client.get_connection_kwargs().copy() |
| 184 | + filtered_kwargs = _filter_kwargs(conn_kwargs, AsyncRedisCluster) |
160 | 185 | startup_nodes = database.client.startup_nodes |
161 | 186 | # Use the first node as the startup node |
162 | 187 | if startup_nodes: |
163 | 188 | first_node = startup_nodes[0] |
164 | 189 | client = AsyncRedisCluster( |
165 | 190 | host=first_node.host, |
166 | 191 | port=first_node.port, |
167 | | - **conn_kwargs, |
| 192 | + **filtered_kwargs, |
168 | 193 | ) |
169 | 194 | else: |
170 | 195 | raise ValueError( |
|
0 commit comments