Skip to content

Commit 53e6d8a

Browse files
authored
Fix PubSub timeout propagation to prevent indefinite hangs on socket read operations (#3982)
* Adding timeout optional arg to parse_response methods that will be propagated down to the socket read. This improves pubsub behaviour - now the timeout is applied when waiting for messages. * Fix linters * Fixing tests mocks * Fix linters after web conflict resolution
1 parent e045654 commit 53e6d8a

13 files changed

Lines changed: 897 additions & 38 deletions

File tree

redis/_parsers/hiredis.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,12 @@ def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
132132
if custom_timeout:
133133
sock.settimeout(self._socket_timeout)
134134

135-
def read_response(self, disable_decoding=False, push_request=False):
135+
def read_response(
136+
self,
137+
disable_decoding=False,
138+
push_request=False,
139+
timeout: Union[float, object] = SENTINEL,
140+
):
136141
if not self._reader:
137142
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
138143

@@ -152,6 +157,7 @@ def read_response(self, disable_decoding=False, push_request=False):
152157
return self.read_response(
153158
disable_decoding=disable_decoding,
154159
push_request=push_request,
160+
timeout=timeout,
155161
)
156162
return response
157163

@@ -161,7 +167,7 @@ def read_response(self, disable_decoding=False, push_request=False):
161167
response = self._reader.gets()
162168

163169
while response is NOT_ENOUGH_DATA:
164-
self.read_from_socket()
170+
self.read_from_socket(timeout=timeout)
165171
if disable_decoding:
166172
response = self._reader.gets(False)
167173
else:

redis/_parsers/resp2.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,20 @@
33
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
44
from ..typing import EncodableT
55
from .base import _AsyncRESPBase, _RESPBase
6-
from .socket import SERVER_CLOSED_CONNECTION_ERROR
6+
from .socket import SENTINEL, SERVER_CLOSED_CONNECTION_ERROR
77

88

99
class _RESP2Parser(_RESPBase):
1010
"""RESP2 protocol implementation"""
1111

12-
def read_response(self, disable_decoding=False):
12+
def read_response(
13+
self, disable_decoding=False, timeout: Union[float, object] = SENTINEL
14+
):
1315
pos = self._buffer.get_pos() if self._buffer else None
1416
try:
15-
result = self._read_response(disable_decoding=disable_decoding)
17+
result = self._read_response(
18+
disable_decoding=disable_decoding, timeout=timeout
19+
)
1620
except BaseException:
1721
if self._buffer:
1822
self._buffer.rewind(pos)
@@ -21,8 +25,10 @@ def read_response(self, disable_decoding=False):
2125
self._buffer.purge()
2226
return result
2327

24-
def _read_response(self, disable_decoding=False):
25-
raw = self._buffer.readline()
28+
def _read_response(
29+
self, disable_decoding=False, timeout: Union[float, object] = SENTINEL
30+
):
31+
raw = self._buffer.readline(timeout=timeout)
2632
if not raw:
2733
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
2834

@@ -51,13 +57,13 @@ def _read_response(self, disable_decoding=False):
5157
elif byte == b"$" and response == b"-1":
5258
return None
5359
elif byte == b"$":
54-
response = self._buffer.read(int(response))
60+
response = self._buffer.read(int(response), timeout=timeout)
5561
# multi-bulk response
5662
elif byte == b"*" and response == b"-1":
5763
return None
5864
elif byte == b"*":
5965
response = [
60-
self._read_response(disable_decoding=disable_decoding)
66+
self._read_response(disable_decoding=disable_decoding, timeout=timeout)
6167
for i in range(int(response))
6268
]
6369
else:

redis/_parsers/resp3.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
_AsyncRESPBase,
1010
_RESPBase,
1111
)
12-
from .socket import SERVER_CLOSED_CONNECTION_ERROR
12+
from .socket import SENTINEL, SERVER_CLOSED_CONNECTION_ERROR
1313

1414

1515
class _RESP3Parser(_RESPBase, PushNotificationsParser):
@@ -28,11 +28,18 @@ def handle_pubsub_push_response(self, response):
2828
logger.debug("Push response: " + str(response))
2929
return response
3030

31-
def read_response(self, disable_decoding=False, push_request=False):
31+
def read_response(
32+
self,
33+
disable_decoding=False,
34+
push_request=False,
35+
timeout: Union[float, object] = SENTINEL,
36+
):
3237
pos = self._buffer.get_pos() if self._buffer is not None else None
3338
try:
3439
result = self._read_response(
35-
disable_decoding=disable_decoding, push_request=push_request
40+
disable_decoding=disable_decoding,
41+
push_request=push_request,
42+
timeout=timeout,
3643
)
3744
except BaseException:
3845
if self._buffer is not None:
@@ -48,8 +55,13 @@ def read_response(self, disable_decoding=False, push_request=False):
4855
pass
4956
return result
5057

51-
def _read_response(self, disable_decoding=False, push_request=False):
52-
raw = self._buffer.readline()
58+
def _read_response(
59+
self,
60+
disable_decoding=False,
61+
push_request=False,
62+
timeout: Union[float, object] = SENTINEL,
63+
):
64+
raw = self._buffer.readline(timeout=timeout)
5365
if not raw:
5466
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
5567

@@ -58,7 +70,7 @@ def _read_response(self, disable_decoding=False, push_request=False):
5870
# server returned an error
5971
if byte in (b"-", b"!"):
6072
if byte == b"!":
61-
response = self._buffer.read(int(response))
73+
response = self._buffer.read(int(response), timeout=timeout)
6274
response = response.decode("utf-8", errors="replace")
6375
error = self.parse_error(response)
6476
# if the error is a ConnectionError, raise immediately so the user
@@ -87,22 +99,22 @@ def _read_response(self, disable_decoding=False, push_request=False):
8799
return response == b"t"
88100
# bulk response
89101
elif byte == b"$":
90-
response = self._buffer.read(int(response))
102+
response = self._buffer.read(int(response), timeout=timeout)
91103
# verbatim string response
92104
elif byte == b"=":
93-
response = self._buffer.read(int(response))[4:]
105+
response = self._buffer.read(int(response), timeout=timeout)[4:]
94106
# array response
95107
elif byte == b"*":
96108
response = [
97-
self._read_response(disable_decoding=disable_decoding)
109+
self._read_response(disable_decoding=disable_decoding, timeout=timeout)
98110
for _ in range(int(response))
99111
]
100112
# set response
101113
elif byte == b"~":
102114
# redis can return unhashable types (like dict) in a set,
103115
# so we return sets as list, all the time, for predictability
104116
response = [
105-
self._read_response(disable_decoding=disable_decoding)
117+
self._read_response(disable_decoding=disable_decoding, timeout=timeout)
106118
for _ in range(int(response))
107119
]
108120
# map response
@@ -112,16 +124,22 @@ def _read_response(self, disable_decoding=False, push_request=False):
112124
# became defined to be left-right in version 3.8
113125
resp_dict = {}
114126
for _ in range(int(response)):
115-
key = self._read_response(disable_decoding=disable_decoding)
127+
key = self._read_response(
128+
disable_decoding=disable_decoding, timeout=timeout
129+
)
116130
resp_dict[key] = self._read_response(
117-
disable_decoding=disable_decoding, push_request=push_request
131+
disable_decoding=disable_decoding,
132+
push_request=push_request,
133+
timeout=timeout,
118134
)
119135
response = resp_dict
120136
# push response
121137
elif byte == b">":
122138
response = [
123139
self._read_response(
124-
disable_decoding=disable_decoding, push_request=push_request
140+
disable_decoding=disable_decoding,
141+
push_request=push_request,
142+
timeout=timeout,
125143
)
126144
for _ in range(int(response))
127145
]

redis/_parsers/socket.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,23 @@ def can_read(self, timeout: float) -> bool:
9696
timeout=timeout, raise_on_timeout=False
9797
)
9898

99-
def read(self, length: int) -> bytes:
99+
def read(self, length: int, timeout: Union[float, object] = SENTINEL) -> bytes:
100100
length = length + 2 # make sure to read the \r\n terminator
101101
# BufferIO will return less than requested if buffer is short
102102
data = self._buffer.read(length)
103103
missing = length - len(data)
104104
if missing:
105105
# fill up the buffer and read the remainder
106-
self._read_from_socket(missing)
106+
self._read_from_socket(length=missing, timeout=timeout)
107107
data += self._buffer.read(missing)
108108
return data[:-2]
109109

110-
def readline(self) -> bytes:
110+
def readline(self, timeout: Union[float, object] = SENTINEL) -> bytes:
111111
buf = self._buffer
112112
data = buf.readline()
113113
while not data.endswith(SYM_CRLF):
114114
# there's more data in the socket that we need
115-
self._read_from_socket()
115+
self._read_from_socket(timeout=timeout)
116116
data += buf.readline()
117117

118118
return data[:-2]

redis/asyncio/client.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,44 @@ def failure_callback(error, failure_count):
11761176
raise
11771177

11781178
async def parse_response(self, block: bool = True, timeout: float = 0):
1179-
"""Parse the response from a publish/subscribe command"""
1179+
"""
1180+
Parse the response from a publish/subscribe command.
1181+
1182+
Args:
1183+
block: If True, block indefinitely until a message is available.
1184+
If False, return immediately if no message is available.
1185+
Default: True
1186+
timeout: The timeout in seconds for reading a response when block=False.
1187+
This parameter is ignored when block=True.
1188+
Default: 0 (return immediately if no data available)
1189+
1190+
Returns:
1191+
The parsed response from the server, or None if no message is available
1192+
within the timeout period (when block=False).
1193+
1194+
Important:
1195+
The block and timeout parameters work together:
1196+
- When block=True: timeout is IGNORED, method blocks indefinitely
1197+
- When block=False: timeout is USED, method returns after timeout expires
1198+
1199+
Typically, you should use get_message(timeout=X) instead of calling
1200+
parse_response() directly. The get_message() method automatically sets
1201+
block=False when a timeout is provided, and block=True when timeout=None.
1202+
1203+
Example:
1204+
# Block indefinitely (timeout is ignored)
1205+
response = await pubsub.parse_response(block=True, timeout=0.1)
1206+
1207+
# Non-blocking with 0.1 second timeout
1208+
response = await pubsub.parse_response(block=False, timeout=0.1)
1209+
1210+
# Non-blocking, return immediately
1211+
response = await pubsub.parse_response(block=False, timeout=0)
1212+
1213+
# Recommended: use get_message() instead
1214+
msg = await pubsub.get_message(timeout=0.1) # automatically sets block=False
1215+
msg = await pubsub.get_message(timeout=None) # automatically sets block=True
1216+
"""
11801217
conn = self.connection
11811218
if conn is None:
11821219
raise RuntimeError(

redis/client.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_RedisCallbacksRESP3,
2525
bool_ok,
2626
)
27+
from redis._parsers.socket import SENTINEL
2728
from redis.backoff import ExponentialWithJitterBackoff
2829
from redis.cache import CacheConfig, CacheInterface
2930
from redis.commands import (
@@ -1157,7 +1158,44 @@ def failure_callback(error, failure_count):
11571158
raise
11581159

11591160
def parse_response(self, block=True, timeout=0):
1160-
"""Parse the response from a publish/subscribe command"""
1161+
"""
1162+
Parse the response from a publish/subscribe command.
1163+
1164+
Args:
1165+
block: If True, block indefinitely until a message is available.
1166+
If False, return immediately if no message is available.
1167+
Default: True
1168+
timeout: The timeout in seconds for reading a response when block=False.
1169+
This parameter is ignored when block=True.
1170+
Default: 0 (return immediately if no data available)
1171+
1172+
Returns:
1173+
The parsed response from the server, or None if no message is available
1174+
within the timeout period (when block=False).
1175+
1176+
Important:
1177+
The block and timeout parameters work together:
1178+
- When block=True: timeout is IGNORED, method blocks indefinitely
1179+
- When block=False: timeout is USED, method returns after timeout expires
1180+
1181+
Typically, you should use get_message(timeout=X) instead of calling
1182+
parse_response() directly. The get_message() method automatically sets
1183+
block=False when a timeout is provided, and block=True when timeout=None.
1184+
1185+
Example:
1186+
# Block indefinitely (timeout is ignored)
1187+
response = pubsub.parse_response(block=True, timeout=0.1)
1188+
1189+
# Non-blocking with 0.1 second timeout
1190+
response = pubsub.parse_response(block=False, timeout=0.1)
1191+
1192+
# Non-blocking, return immediately
1193+
response = pubsub.parse_response(block=False, timeout=0)
1194+
1195+
# Recommended: use get_message() instead
1196+
msg = pubsub.get_message(timeout=0.1) # automatically sets block=False
1197+
msg = pubsub.get_message(timeout=None) # automatically sets block=True
1198+
"""
11611199
conn = self.connection
11621200
if conn is None:
11631201
raise RuntimeError(
@@ -1171,9 +1209,13 @@ def try_read():
11711209
if not block:
11721210
if not conn.can_read(timeout=timeout):
11731211
return None
1212+
read_timeout = timeout
11741213
else:
11751214
conn.connect()
1176-
return conn.read_response(disconnect_on_error=False, push_request=True)
1215+
read_timeout = SENTINEL # Use default socket timeout for blocking
1216+
return conn.read_response(
1217+
disconnect_on_error=False, push_request=True, timeout=read_timeout
1218+
)
11771219

11781220
response = self._execute(conn, try_read)
11791221

redis/cluster.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2669,10 +2669,14 @@ def _get_node_pubsub(self, node):
26692669
self.node_pubsub_mapping[node.name] = pubsub
26702670
return pubsub
26712671

2672-
def _sharded_message_generator(self):
2672+
def _sharded_message_generator(self, timeout=0.0):
26732673
for _ in range(len(self.node_pubsub_mapping)):
26742674
pubsub = next(self._pubsubs_generator)
2675-
message = pubsub.get_message()
2675+
# Don't pass ignore_subscribe_messages here - let get_sharded_message
2676+
# handle the filtering after processing subscription state changes
2677+
message = pubsub.get_message(
2678+
ignore_subscribe_messages=False, timeout=timeout
2679+
)
26762680
if message is not None:
26772681
return message
26782682
return None
@@ -2690,7 +2694,7 @@ def get_sharded_message(
26902694
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
26912695
)
26922696
else:
2693-
message = self._sharded_message_generator()
2697+
message = self._sharded_message_generator(timeout=timeout)
26942698
if message is None:
26952699
return None
26962700
elif str_if_bytes(message["type"]) == "sunsubscribe":

0 commit comments

Comments
 (0)