Skip to content

Commit f02c66b

Browse files
praboud, petyaslavova, Copilot
committed
Fix issues with ClusterPipeline connection management (#3804)
* Fix connection leak & dirty connection reuse * Add tests for connection leak and dirty connection reuse bugs * Add comment * Update redis/cluster.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fixing tests --------- Co-authored-by: petyaslavova <petya.slavova@redis.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 1958065 commit f02c66b

2 files changed

Lines changed: 138 additions & 39 deletions

File tree

redis/cluster.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3068,7 +3068,9 @@ def __init__(self, args, options=None, position=None):
30683068
class NodeCommands:
30693069
""" """
30703070

3071-
def __init__(self, parse_response, connection_pool, connection):
3071+
def __init__(
3072+
self, parse_response, connection_pool: ConnectionPool, connection: Connection
3073+
):
30723074
""" """
30733075
self.parse_response = parse_response
30743076
self.connection_pool = connection_pool
@@ -3423,15 +3425,18 @@ def _send_cluster_commands(
34233425
attempt = sorted(stack, key=lambda x: x.position)
34243426
is_default_node = False
34253427
# build a list of node objects based on node names we need to
3426-
nodes = {}
3427-
3428-
# as we move through each command that still needs to be processed,
3429-
# we figure out the slot number that command maps to, then from
3430-
# the slot determine the node.
3431-
for c in attempt:
3432-
command_policies = self._pipe._policy_resolver.resolve(c.args[0].lower())
3428+
nodes: dict[str, NodeCommands] = {}
3429+
nodes_written = 0
3430+
nodes_read = 0
34333431

3434-
while True:
3432+
try:
3433+
# as we move through each command that still needs to be processed,
3434+
# we figure out the slot number that command maps to, then from
3435+
# the slot determine the node.
3436+
for c in attempt:
3437+
command_policies = self._pipe._policy_resolver.resolve(
3438+
c.args[0].lower()
3439+
)
34353440
# refer to our internal node -> slot table that
34363441
# tells us where a given command should route to.
34373442
# (it might be possible we have a cached node that no longer
@@ -3506,37 +3511,38 @@ def _send_cluster_commands(
35063511
try:
35073512
connection = get_connection(redis_node)
35083513
except (ConnectionError, TimeoutError):
3514+
# Release any connections we've already acquired before clearing nodes
35093515
for n in nodes.values():
35103516
n.connection_pool.release(n.connection)
35113517
# Connection retries are being handled in the node's
35123518
# Retry object. Reinitialize the node -> slot table.
35133519
self._nodes_manager.initialize()
35143520
if is_default_node:
35153521
self._pipe.replace_default_node()
3522+
nodes = {}
35163523
raise
35173524
nodes[node_name] = NodeCommands(
35183525
redis_node.parse_response,
35193526
redis_node.connection_pool,
35203527
connection,
35213528
)
35223529
nodes[node_name].append(c)
3523-
break
35243530

3525-
# send the commands in sequence.
3526-
# we write to all the open sockets for each node first,
3527-
# before reading anything
3528-
# this allows us to flush all the requests out across the
3529-
# network
3530-
# so that we can read them from different sockets as they come back.
3531-
# we dont' multiplex on the sockets as they come available,
3532-
# but that shouldn't make too much difference.
3531+
# send the commands in sequence.
3532+
# we write to all the open sockets for each node first,
3533+
# before reading anything
3534+
# this allows us to flush all the requests out across the
3535+
# network
3536+
# so that we can read them from different sockets as they come back.
3537+
# we don't multiplex on the sockets as they come available,
3538+
# but that shouldn't make too much difference.
35333539

3534-
# Start timing for observability
3535-
start_time = time.monotonic()
3540+
# Start timing for observability
3541+
start_time = time.monotonic()
35363542

3537-
try:
35383543
node_commands = nodes.values()
35393544
for n in node_commands:
3545+
nodes_written += 1
35403546
n.write()
35413547

35423548
for n in node_commands:
@@ -3550,26 +3556,24 @@ def _send_cluster_commands(
35503556
db_namespace=str(n.connection.db),
35513557
batch_size=len(n.commands),
35523558
)
3559+
nodes_read += 1
35533560
finally:
3554-
# release all of the redis connections we allocated earlier
3561+
# release all the redis connections we allocated earlier
35553562
# back into the connection pool.
3556-
# we used to do this step as part of a try/finally block,
3557-
# but it is really dangerous to
3558-
# release connections back into the pool if for some
3559-
# reason the socket has data still left in it
3560-
# from a previous operation. The write and
3561-
# read operations already have try/catch around them for
3562-
# all known types of errors including connection
3563-
# and socket level errors.
3564-
# So if we hit an exception, something really bad
3565-
# happened and putting any oF
3566-
# these connections back into the pool is a very bad idea.
3567-
# the socket might have unread buffer still sitting in it,
3568-
# and then the next time we read from it we pass the
3569-
# buffered result back from a previous command and
3570-
# every single request after to that connection will always get
3571-
# a mismatched result.
3572-
for n in nodes.values():
3563+
# if the connection is dirty (that is: we've written
3564+
# commands to it, but haven't read the responses), we need
3565+
# to close the connection before returning it to the pool.
3566+
# otherwise, the next caller to use this connection will
3567+
# read the response from _this_ request, not its own request.
3568+
# disconnecting discards the dirty state & forces the next
3569+
# caller to reconnect.
3570+
# NOTE: dicts have a consistent ordering; we're iterating
3571+
# through nodes.values() in the same order as we are when
3572+
# reading / writing to the connections above, which is critical
3573+
# for how we're using the nodes_written/nodes_read offsets.
3574+
for i, n in enumerate(nodes.values()):
3575+
if i < nodes_written and i >= nodes_read:
3576+
n.connection.disconnect()
35733577
n.connection_pool.release(n.connection)
35743578

35753579
# if the response isn't an exception it is a

tests/test_cluster.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3844,6 +3844,101 @@ def test_pipeline_discard(self, r):
38443844
assert response[0]
38453845
assert r.get(f"{hashkey}:foo") == b"bar"
38463846

3847+
def test_connection_leak_on_non_timeout_error_during_connect(self, r):
3848+
"""
3849+
Test that connections are not leaked when a non-TimeoutError/ConnectionError
3850+
is raised during get_connection(). The bugfix ensures that if an error
3851+
occurs that isn't explicitly handled, we don't leak connections.
3852+
"""
3853+
# Ensure keys map to different nodes
3854+
assert r.keyslot("a") != r.keyslot("b")
3855+
3856+
orig_func = redis.cluster.get_connection
3857+
with patch("redis.cluster.get_connection") as get_connection:
3858+
3859+
def raise_custom_error(target_node, *args, **kwargs):
3860+
# Raise a RuntimeError (not ConnectionError or TimeoutError)
3861+
# on the second call (when getting second connection)
3862+
if get_connection.call_count == 2:
3863+
raise RuntimeError("Some unexpected error during connection")
3864+
else:
3865+
return orig_func(target_node, *args, **kwargs)
3866+
3867+
get_connection.side_effect = raise_custom_error
3868+
3869+
with pytest.raises(RuntimeError):
3870+
r.pipeline().get("a").get("b").execute()
3871+
3872+
# Verify that all connections were returned to the pool
3873+
# (not leaked) even though a non-standard error was raised
3874+
for cluster_node in r.nodes_manager.nodes_cache.values():
3875+
connection_pool = cluster_node.redis_connection.connection_pool
3876+
num_of_conns = len(connection_pool._available_connections)
3877+
assert num_of_conns == connection_pool._created_connections, (
3878+
f"Connection leaked: expected {connection_pool._created_connections} "
3879+
f"available, got {num_of_conns}"
3880+
)
3881+
3882+
def test_dirty_connection_not_reused(self, r):
3883+
"""
3884+
Test that dirty connections (with unread responses) are not reused.
3885+
A dirty connection is one where we've written commands but haven't
3886+
read all responses. If such a connection is returned to the pool,
3887+
the next caller will read responses from the previous request.
3888+
"""
3889+
# Ensure we're using multiple nodes to test the dirty connection scenario
3890+
assert r.keyslot("a") != r.keyslot("b")
3891+
3892+
# Mock the write method to raise an error after writing to only some nodes
3893+
orig_write = redis.cluster.NodeCommands.write
3894+
3895+
write_count = 0
3896+
3897+
def mock_write(self):
3898+
nonlocal write_count
3899+
write_count += 1
3900+
# Allow the first write to succeed
3901+
if write_count == 1:
3902+
return orig_write(self)
3903+
# Simulate a failure after the first write (leaving connection dirty)
3904+
else:
3905+
raise RuntimeError("Simulated write error")
3906+
3907+
# Patch Connection.disconnect so we can assert that at least one
3908+
# connection was disconnected when the write error occurred.
3909+
original_disconnect = Connection.disconnect
3910+
disconnect_called = []
3911+
3912+
def track_disconnect(self, *args):
3913+
disconnect_called.append(True)
3914+
return original_disconnect(self, *args)
3915+
3916+
with patch.object(Connection, "disconnect", track_disconnect):
3917+
with patch.object(redis.cluster.NodeCommands, "write", mock_write):
3918+
with pytest.raises(RuntimeError):
3919+
r.pipeline().get("a").get("b").execute()
3920+
3921+
# Ensure that at least one connection was disconnected as part of
3922+
# handling the dirty connection created by the write failure.
3923+
assert disconnect_called, (
3924+
"Expected at least one connection to be disconnected when "
3925+
"handling a dirty connection, but disconnect() was not called."
3926+
)
3927+
# After the error, verify that no connections are in the available pool
3928+
# with dirty state (unread responses). If a connection is dirty, it should
3929+
# have been disconnected before being returned to the pool.
3930+
# We verify this by checking the connections can be reused successfully.
3931+
try:
3932+
# Try to execute a command on each connection to verify
3933+
# they're clean (not holding responses from previous requests)
3934+
result = r.ping()
3935+
assert result is True
3936+
except Exception as e:
3937+
pytest.fail(
3938+
f"Connection reuse after dirty state failed: {e}. "
3939+
f"This indicates a dirty connection was returned to the pool."
3940+
)
3941+
38473942

38483943
@pytest.mark.onlycluster
38493944
class TestReadOnlyPipeline:

0 commit comments

Comments
 (0)