Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
1 change: 1 addition & 0 deletions changelog.d/12672.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic.
63 changes: 54 additions & 9 deletions synapse/replication/tcp/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast

import attr
import txredisapi
Expand All @@ -24,6 +24,7 @@
from twisted.internet.interfaces import IAddress, IConnector
from twisted.python.failure import Failure

from synapse.config.homeserver import HomeServerConfig
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
Expand Down Expand Up @@ -85,14 +86,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):

Attributes:
synapse_handler: The command handler to handle incoming commands.
synapse_stream_name: The *redis* stream name to subscribe to and publish
synapse_stream_prefix: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams).
synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""

synapse_handler: "ReplicationCommandHandler"
synapse_stream_name: str
synapse_stream_prefix: str
synapse_subscribed_channels: List[str]
synapse_outbound_redis_connection: txredisapi.ConnectionHandler

def __init__(self, *args: Any, **kwargs: Any):
Expand All @@ -117,8 +119,13 @@ async def _send_subscribe(self) -> None:
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
fully_qualified_stream_names = [
f"{self.synapse_stream_prefix}/{stream_suffix}"
for stream_suffix in self.synapse_subscribed_channels
] + [self.synapse_stream_prefix]
logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))

logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
Expand Down Expand Up @@ -217,7 +224,7 @@ async def _async_send_command(self, cmd: Command) -> None:

await make_deferred_yieldable(
self.synapse_outbound_redis_connection.publish(
self.synapse_stream_name, encoded_string
self.synapse_stream_prefix, encoded_string
)
)

Expand Down Expand Up @@ -300,7 +307,7 @@ def format_address(address: IAddress) -> str:

class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
subscribes to some streams.

Args:
hs
Expand All @@ -326,10 +333,47 @@ def __init__(
)

self.synapse_handler = hs.get_replication_command_handler()
self.synapse_stream_name = hs.hostname
self.synapse_stream_prefix = hs.hostname
self.synapse_subscribed_channels = (
RedisDirectTcpReplicationClientFactory.channels_to_subscribe_to_for_config(
hs.config
)
)

self.synapse_outbound_redis_connection = outbound_redis_connection

@staticmethod
def channels_to_subscribe_to_for_config(config: HomeServerConfig) -> List[str]:
subscribe_to = []

if config.worker.run_background_tasks or config.worker.worker_app is None:
# If we're the main process or the background worker, we want to process
# User IP addresses
subscribe_to.append("USER_IP")

# Subscribe to the following RDATA channels.
# We may be able to reduce this in the future.
subscribe_to += [
"RDATA/account_data",
"RDATA/backfill",
"RDATA/caches",
"RDATA/device_lists",
"RDATA/events",
"RDATA/federation",
"RDATA/groups",
"RDATA/presence",
"RDATA/presence_federation",
"RDATA/push_rules",
"RDATA/pushers",
"RDATA/receipts",
"RDATA/tag_account_data",
"RDATA/to_device",
"RDATA/typing",
"RDATA/user_signature",
]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a huge fan of hardcoding the list of streams here. Can we either:

  1. Leave RDATA on the main channel for now; or
  2. Subscribe to RDATA/* and use PSUBSCRIBE

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leave RDATA on the main channel for now; or

I think I'd prefer this (mind you, it's still transmitting on the main channel). How do you suggest this work without specifying a list of streams here?
I guess we could have them be subscribed on-demand somewhere else, but I was sort of keen to subscribe them all in one place to be sure we're not sending REPLICATE before we are listening on all the desired channels.

Subscribe to RDATA/* and use PSUBSCRIBE

I am somewhat reluctant to do this because it's not really sorting us out to be able to unsubscribe from select streams in the future (and not to mention that it seems like it's more work for Redis to do pattern matches on each receiver rather than having a concrete set of listeners per channel; not sure how it's implemented in practice though).

Copy link
Copy Markdown
Member

@erikjohnston erikjohnston May 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sort of depends on what we want the final code to look like. One option that I've been considering is to have it so that every handler that wants to listen to a stream has to call something like ReplicationSubscriber.subscribe_to_stream(Stream.NAME, func) in the handler's __init__. Then, when we come to connect to Redis, the ReplicationSubscriber has the list of streams the worker is interested in.

What I really want to avoid is having to manually list these stream names: it's a recipe for the list getting out of date, and it's very non-obvious how it all fits together.

Using PSUBSCRIBE would allow us to set up the channels and write to those channels before putting all the logic in for doing ReplicationSubscriber. OTOH, you could implement the ReplicationSubscriber logic before adding the streams

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll give that a go. Otherwise it may be worth cutting out any mention of RDATA in this PR and leaving that rewiring for another PR.

its very non-obvious how it all fits together.

yeees. That can be said about much of the replication stuff, I don't want to make it much worse though, so will give this a try

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. I think it's worth time-boxing the RDATA stuff, as just getting the UserIP stuff split out will provide a bunch of value.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've de-scoped the RDATA stuff a bit for now (mostly because I'd appreciate a few thoughts on where exactly to hook it in; the code processing RDATA commands is a bit convoluted and I thought it was going to take a while to untangle), but I've restructured things in a way that I think makes it a bit more approachable to do cleanly.


return subscribe_to

def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
p = super().buildProtocol(addr)
p = cast(RedisSubscriber, p)
Expand All @@ -340,7 +384,8 @@ def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
# protocol.
p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
p.synapse_stream_name = self.synapse_stream_name
p.synapse_stream_prefix = self.synapse_stream_prefix
p.synapse_subscribed_channels = self.synapse_subscribed_channels

return p

Expand Down
29 changes: 17 additions & 12 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple

from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
Expand Down Expand Up @@ -475,22 +476,25 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""

def __init__(self):
self._subscribers = set()
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)

def add_subscriber(self, conn):
def add_subscriber(self, conn, channel: bytes):
"""A connection has called SUBSCRIBE"""
self._subscribers.add(conn)
self._subscribers_by_channel[channel].add(conn)

def remove_subscriber(self, conn):
"""A connection has called UNSUBSCRIBE"""
self._subscribers.discard(conn)
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)

def publish(self, conn, channel: bytes, msg) -> int:
    """Deliver `msg` to every connection subscribed to `channel`.

    Args:
        conn: The connection issuing the PUBLISH (not used for routing).
        channel: The name of the channel being published to.
        msg: The payload forwarded to each subscriber of `channel`.

    Returns:
        The number of subscribers the message was sent to.
    """
    subscribers = self._subscribers_by_channel[channel]
    for sub in subscribers:
        sub.send(["message", channel, msg])

    # Count the receivers on *this* channel: the length of the mapping itself
    # would be the number of channels, not the number of subscribers.
    return len(subscribers)

def buildProtocol(self, addr):
    """Return a new FakeRedisPubSubProtocol bound to this fake server.

    `addr` is accepted for interface compatibility but is not used.
    """
    protocol = FakeRedisPubSubProtocol(self)
    return protocol
Expand Down Expand Up @@ -531,9 +535,10 @@ def handle_command(self, command, *args):
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
for idx, channel in enumerate(args):
num_channels = idx + 1
self._server.add_subscriber(self, channel)
self.send(["subscribe", channel, num_channels])

# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
Expand Down
103 changes: 103 additions & 0 deletions tests/replication/tcp/test_redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
Comment thread
reivilibre marked this conversation as resolved.
Outdated
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Probe for the optional txredisapi dependency. The import exists purely to
# detect whether the package is installed, hence the suppressed 'unused
# import' warning.
HAVE_TXREDISAPI = True
try:
    import txredisapi  # noqa: F401
except ImportError:
    HAVE_TXREDISAPI = False

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.unittest import HomeserverTestCase

# Every RDATA channel suffix that the main-process test in this module expects
# `channels_to_subscribe_to_for_config` to return (alongside "USER_IP").
# NOTE(review): this duplicates the hard-coded list in
# synapse/replication/tcp/redis.py and must be kept in sync with it by hand.
ALL_RDATA_CHANNELS = [
    "RDATA/account_data",
    "RDATA/backfill",
    "RDATA/caches",
    "RDATA/device_lists",
    "RDATA/events",
    "RDATA/federation",
    "RDATA/groups",
    "RDATA/presence",
    "RDATA/presence_federation",
    "RDATA/push_rules",
    "RDATA/pushers",
    "RDATA/receipts",
    "RDATA/tag_account_data",
    "RDATA/to_device",
    "RDATA/typing",
    "RDATA/user_signature",
]


class RedisTestCase(HomeserverTestCase):
    """Checks the Redis channel selection for the default test homeserver."""

    if not HAVE_TXREDISAPI:
        skip = "Redis extras not installed"

    def test_subscribed_to_enough_redis_channels(self) -> None:
        # Imported inside the test rather than at module level; presumably so
        # this module can still be loaded (and the test skipped) when the
        # Redis extras are not installed.
        from synapse.replication.tcp.redis import (
            RedisDirectTcpReplicationClientFactory as ClientFactory,
        )

        # The default main process is subscribed to USER_IP and all RDATA
        # channels.
        expected_channels = ["USER_IP", *ALL_RDATA_CHANNELS]
        actual_channels = ClientFactory.channels_to_subscribe_to_for_config(
            self.hs.config
        )

        self.assertCountEqual(actual_channels, expected_channels)


class RedisWorkerTestCase(BaseMultiWorkerStreamTestCase):
    """Checks which Redis channels generic workers choose to subscribe to."""

    if not HAVE_TXREDISAPI:
        skip = "Redis extras not installed"

    def _channels_for_worker(self, worker_name: str, background_worker: str) -> list:
        """Build a generic worker and return the channels it would subscribe to.

        Args:
            worker_name: The `worker_name` to give the new worker.
            background_worker: The value for `run_background_tasks_on`, i.e.
                the name of the worker designated to run background tasks.

        Returns:
            The result of `channels_to_subscribe_to_for_config` for the new
            worker's config.
        """
        # Imported inside the helper rather than at module level; presumably so
        # this module can still be loaded (and the tests skipped) when the
        # Redis extras are not installed.
        from synapse.replication.tcp.redis import RedisDirectTcpReplicationClientFactory

        worker_hs = self.make_worker_hs(
            "synapse.app.generic_worker",
            extra_config={
                "worker_name": worker_name,
                "run_background_tasks_on": background_worker,
            },
        )
        return RedisDirectTcpReplicationClientFactory.channels_to_subscribe_to_for_config(
            worker_hs.config
        )

    def test_background_worker_subscribed_to_user_ip(self) -> None:
        # worker1 is itself the designated background-tasks worker, so it must
        # be subscribed to USER_IP.
        self.assertIn("USER_IP", self._channels_for_worker("worker1", "worker1"))

    def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
        # worker2 is an ordinary worker (background tasks run on worker1), so
        # it must not be subscribed to USER_IP.
        self.assertNotIn("USER_IP", self._channels_for_worker("worker2", "worker1"))