Skip to content

Commit 87c996c

Browse files
committed
POC
1 parent cab5cfc commit 87c996c

9 files changed

Lines changed: 162 additions & 5 deletions

File tree

redis/asyncio/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Dict,
1414
Iterable,
1515
List,
16+
Literal,
1617
Mapping,
1718
MutableMapping,
1819
Optional,
@@ -129,6 +130,9 @@ class Redis(
129130
Connection object to talk to redis.
130131
"""
131132

133+
# Type discrimination marker for @overload self-type pattern
134+
_is_async_client: Literal[True] = True
135+
132136
response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]
133137

134138
@classmethod

redis/asyncio/cluster.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Dict,
1717
Generator,
1818
List,
19+
Literal,
1920
Mapping,
2021
Optional,
2122
Set,
@@ -238,6 +239,9 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
238239
kwargs["ssl"] = True
239240
return cls(**kwargs)
240241

242+
# Type discrimination marker for @overload self-type pattern
243+
_is_async_client: Literal[True] = True
244+
241245
__slots__ = (
242246
"_initialize",
243247
"_lock",

redis/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Callable,
1010
Dict,
1111
List,
12+
Literal,
1213
Mapping,
1314
Optional,
1415
Set,
@@ -135,6 +136,9 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
135136
It is not safe to pass PubSub or Pipeline objects between threads.
136137
"""
137138

139+
# Type discrimination marker for @overload self-type pattern
140+
_is_async_client: Literal[False] = False
141+
138142
@classmethod
139143
def from_url(cls, url: str, **kwargs) -> "Redis":
140144
"""

redis/cluster.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,9 @@ def replace_default_node(self, target_node: "ClusterNode" = None) -> None:
554554
class RedisCluster(
555555
AbstractRedisCluster, MaintNotificationsAbstractRedisCluster, RedisClusterCommands
556556
):
557+
# Type discrimination marker for @overload self-type pattern
558+
_is_async_client: Literal[False] = False
559+
557560
@classmethod
558561
def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
559562
"""

redis/commands/core.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Set,
3232
Tuple,
3333
Union,
34+
overload,
3435
)
3536

3637
from redis.asyncio.observability.recorder import (
@@ -40,6 +41,7 @@
4041
from redis.typing import (
4142
AbsExpiryT,
4243
AnyKeyT,
44+
AsyncClientProtocol,
4345
BitfieldOffsetT,
4446
ChannelT,
4547
CommandsProtocol,
@@ -55,6 +57,7 @@
5557
ResponseT,
5658
ScriptTextT,
5759
StreamIdT,
60+
SyncClientProtocol,
5861
TimeoutSecT,
5962
ZScoreBoundT,
6063
)
@@ -2064,6 +2067,15 @@ def digest(self, name: KeyT) -> Union[str, bytes, None]:
20642067
# Bulk string response is already handled (bytes/str based on decode_responses)
20652068
return self.execute_command("DIGEST", name)
20662069

2070+
# --- @overload pattern for get() ---
2071+
# Sync client returns bytes | None directly
2072+
@overload
2073+
def get(self: SyncClientProtocol, name: KeyT) -> Optional[bytes]: ...
2074+
2075+
# Async client returns Awaitable[bytes | None]
2076+
@overload
2077+
def get(self: AsyncClientProtocol, name: KeyT) -> Awaitable[Optional[bytes]]: ...
2078+
20672079
def get(self, name: KeyT) -> ResponseT:
20682080
"""
20692081
Return the value at key ``name``, or None if the key doesn't exist
@@ -2554,6 +2566,47 @@ def restore(
25542566

25552567
return self.execute_command("RESTORE", *params)
25562568

2569+
# --- @overload pattern for set() ---
2570+
# Sync client returns bool | None directly
2571+
@overload
2572+
def set(
2573+
self: SyncClientProtocol,
2574+
name: KeyT,
2575+
value: EncodableT,
2576+
ex: Optional[ExpiryT] = ...,
2577+
px: Optional[ExpiryT] = ...,
2578+
nx: bool = ...,
2579+
xx: bool = ...,
2580+
keepttl: bool = ...,
2581+
get: bool = ...,
2582+
exat: Optional[AbsExpiryT] = ...,
2583+
pxat: Optional[AbsExpiryT] = ...,
2584+
ifeq: Optional[Union[bytes, str]] = ...,
2585+
ifne: Optional[Union[bytes, str]] = ...,
2586+
ifdeq: Optional[str] = ...,
2587+
ifdne: Optional[str] = ...,
2588+
) -> Optional[bool]: ...
2589+
2590+
# Async client returns Awaitable[bool | None]
2591+
@overload
2592+
def set(
2593+
self: AsyncClientProtocol,
2594+
name: KeyT,
2595+
value: EncodableT,
2596+
ex: Optional[ExpiryT] = ...,
2597+
px: Optional[ExpiryT] = ...,
2598+
nx: bool = ...,
2599+
xx: bool = ...,
2600+
keepttl: bool = ...,
2601+
get: bool = ...,
2602+
exat: Optional[AbsExpiryT] = ...,
2603+
pxat: Optional[AbsExpiryT] = ...,
2604+
ifeq: Optional[Union[bytes, str]] = ...,
2605+
ifne: Optional[Union[bytes, str]] = ...,
2606+
ifdeq: Optional[str] = ...,
2607+
ifdne: Optional[str] = ...,
2608+
) -> Awaitable[Optional[bool]]: ...
2609+
25572610
@experimental_args(["ifeq", "ifne", "ifdeq", "ifdne"])
25582611
def set(
25592612
self,

redis/commands/redismodules.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .json import JSON
99
from .search import AsyncSearch, Search
1010
from .timeseries import TimeSeries
11-
from .vectorset import VectorSet
11+
from .vectorset import AsyncVectorSet, VectorSet
1212

1313

1414
class RedisModuleCommands:
@@ -99,3 +99,11 @@ def ft(self, index_name="idx") -> AsyncSearch:
9999

100100
s = AsyncSearch(client=self, index_name=index_name)
101101
return s
102+
103+
def vset(self) -> AsyncVectorSet:
104+
"""Access the VectorSet commands namespace."""
105+
106+
from .vectorset import AsyncVectorSet
107+
108+
vset = AsyncVectorSet(client=self)
109+
return vset

redis/commands/vectorset/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from typing import Literal
23

34
from redis._parsers.helpers import pairs_to_dict
45
from redis.commands.vectorset.utils import (
@@ -18,9 +19,11 @@
1819
)
1920

2021

21-
class VectorSet(VectorSetCommands):
22+
class _VectorSetBase(VectorSetCommands):
23+
"""Base class with shared initialization logic for VectorSet clients."""
24+
2225
def __init__(self, client, **kwargs):
23-
"""Create a new VectorSet client."""
26+
"""Initialize VectorSet client with callbacks."""
2427
# Set the module commands' callbacks
2528
self._MODULE_CALLBACKS = {
2629
VEMB_CMD: parse_vemb_result,
@@ -44,3 +47,15 @@ def __init__(self, client, **kwargs):
4447

4548
for k, v in self._MODULE_CALLBACKS.items():
4649
self.client.set_response_callback(k, v)
50+
51+
52+
class VectorSet(_VectorSetBase):
53+
"""Sync VectorSet client."""
54+
55+
_is_async_client: Literal[False] = False
56+
57+
58+
class AsyncVectorSet(_VectorSetBase):
59+
"""Async VectorSet client."""
60+
61+
_is_async_client: Literal[True] = True

redis/commands/vectorset/commands.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
import json
22
from enum import Enum
3-
from typing import Any, Awaitable, Dict, List, Optional, Union
3+
from typing import Any, Awaitable, Dict, List, Optional, Union, overload
44

55
from redis.client import NEVER_DECODE
66
from redis.commands.helpers import get_protocol_version
77
from redis.exceptions import DataError
8-
from redis.typing import CommandsProtocol, EncodableT, KeyT, Number
8+
from redis.typing import (
9+
AsyncClientProtocol,
10+
CommandsProtocol,
11+
EncodableT,
12+
KeyT,
13+
Number,
14+
ResponseT,
15+
SyncClientProtocol,
16+
)
917

1018
VADD_CMD = "VADD"
1119
VSIM_CMD = "VSIM"
@@ -129,6 +137,38 @@ def vadd(
129137

130138
return self.execute_command(VADD_CMD, key, *pieces)
131139

140+
@overload
141+
def vsim(
142+
self: SyncClientProtocol,
143+
key: KeyT,
144+
input: Union[List[float], bytes, str],
145+
with_scores: Optional[bool] = ...,
146+
with_attribs: Optional[bool] = ...,
147+
count: Optional[int] = ...,
148+
ef: Optional[Number] = ...,
149+
filter: Optional[str] = ...,
150+
filter_ef: Optional[str] = ...,
151+
truth: Optional[bool] = ...,
152+
no_thread: Optional[bool] = ...,
153+
epsilon: Optional[Number] = ...,
154+
) -> VSimResult: ...
155+
156+
@overload
157+
def vsim(
158+
self: AsyncClientProtocol,
159+
key: KeyT,
160+
input: Union[List[float], bytes, str],
161+
with_scores: Optional[bool] = ...,
162+
with_attribs: Optional[bool] = ...,
163+
count: Optional[int] = ...,
164+
ef: Optional[Number] = ...,
165+
filter: Optional[str] = ...,
166+
filter_ef: Optional[str] = ...,
167+
truth: Optional[bool] = ...,
168+
no_thread: Optional[bool] = ...,
169+
epsilon: Optional[Number] = ...,
170+
) -> Awaitable[VSimResult]: ...
171+
132172
def vsim(
133173
self,
134174
key: KeyT,

redis/typing.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
Any,
77
Awaitable,
88
Iterable,
9+
Literal,
910
Mapping,
1011
Protocol,
1112
Type,
1213
TypeVar,
1314
Union,
15+
runtime_checkable,
1416
)
1517

1618
if TYPE_CHECKING:
@@ -19,6 +21,30 @@
1921

2022

2123
Number = Union[int, float]
24+
25+
26+
@runtime_checkable
27+
class AsyncClientProtocol(Protocol):
28+
"""Protocol for asynchronous Redis clients (redis.asyncio.client.Redis).
29+
30+
This protocol uses a Literal marker to identify async clients.
31+
Used in @overload to provide correct return types for async clients.
32+
"""
33+
34+
_is_async_client: Literal[True]
35+
36+
37+
@runtime_checkable
38+
class SyncClientProtocol(Protocol):
39+
"""Protocol for synchronous Redis clients (redis.client.Redis).
40+
41+
This protocol uses a Literal marker to identify sync clients.
42+
Used in @overload to provide correct return types for sync clients.
43+
"""
44+
45+
_is_async_client: Literal[False]
46+
47+
2248
EncodedT = Union[bytes, bytearray, memoryview]
2349
DecodedT = Union[str, int, float]
2450
EncodableT = Union[EncodedT, DecodedT]

0 commit comments

Comments
 (0)