Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7427.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for running replication over Redis when using workers.
51 changes: 36 additions & 15 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@ def __init__(self, hs):
self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name()

# Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str]

self._streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
Expand All @@ -99,9 +96,13 @@ def __init__(self, hs):
# The factory used to create connections.
self._factory = None # type: Optional[ReconnectingClientFactory]

# The currently connected connections.
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]

# For each connection, the incoming streams that are coming from that connection
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]

LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
Expand Down Expand Up @@ -257,9 +258,11 @@ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
with await self._position_linearizer.queue(cmd.stream_name):
if stream_name not in self._streams_connected:
# If the stream isn't marked as connected then we haven't seen a
# `POSITION` command yet, and so we may have missed some rows.
# make sure that we've processed a POSITION for this stream *on this
# connection*. (A POSITION on another connection is no good, as there
# is no guarantee that we have seen all the intermediate updates.)
sbc = self._streams_by_connection.get(conn)
if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
logger.debug(
Expand Down Expand Up @@ -302,21 +305,25 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
# Ignore POSITION that are just our own echoes
return

stream = self._streams.get(cmd.stream_name)
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())

stream_name = cmd.stream_name
stream = self._streams.get(stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
logger.error("Got POSITION for unknown stream: %s", stream_name)
return

# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
with await self._position_linearizer.queue(cmd.stream_name):
with await self._position_linearizer.queue(stream_name):
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
self._streams_connected.discard(cmd.stream_name)
for streams in self._streams_by_connection.values():
streams.discard(stream_name)

# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
self._pending_batches.pop(cmd.stream_name, [])
self._pending_batches.pop(stream_name, [])

# Find where we previously streamed up to.
current_token = stream.current_token()
Expand All @@ -326,6 +333,12 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
cmd.token,
)
(
updates,
current_token,
Expand All @@ -341,16 +354,18 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):

for token, rows in _batch_updates(updates):
await self.on_rdata(
cmd.stream_name,
stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)

logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)

# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
await self._replication_data_handler.on_position(stream_name, cmd.token)

self._streams_connected.add(cmd.stream_name)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)

async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
Expand Down Expand Up @@ -408,6 +423,12 @@ def new_connection(self, connection: AbstractConnection):
def lost_connection(self, connection: AbstractConnection):
"""Called when a connection is closed/lost.
"""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
if streams:
logger.info(
"Lost replication connection; streams now disconnected: %s", streams
)
try:
self._connections.remove(connection)
except ValueError:
Expand Down
52 changes: 33 additions & 19 deletions synapse/replication/tcp/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import txredisapi

from synapse.logging.context import PreserveLoggingContext
from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
Command,
Expand All @@ -41,17 +41,23 @@
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
"""Connection to redis subscribed to replication stream.

Parses incoming messages from redis into replication commands, and passes
them to `ReplicationCommandHandler`
This class fulfils two functions:

(a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`

(b) it implements the AbstractConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.

Due to the vagaries of `txredisapi` we don't want to have a custom
constructor, so instead we expect the defined attributes below to be set
immediately after initialisation.

Attributes:
handler: The command handler to handle incoming commands.
stream_name: The *redis* stream name to subscribe to (not anything to
do with Synapse replication streams).
stream_name: The *redis* stream name to subscribe to and publish from
(not anything to do with Synapse replication streams).
outbound_redis_connection: The connection to redis to use to send
commands.
"""
Expand All @@ -61,13 +67,23 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
outbound_redis_connection = None # type: txredisapi.RedisProtocol

def connectionMade(self):
logger.info("Connected to redis")
super().connectionMade()
logger.info("Connected to redis instance")
self.subscribe(self.stream_name)
self.send_command(ReplicateCommand())

run_as_background_process("subscribe-replication", self._send_subscribe)
self.handler.new_connection(self)

async def _send_subscribe(self):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
await make_deferred_yieldable(self.subscribe(self.stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")

def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
Expand Down Expand Up @@ -120,8 +136,8 @@ async def handle_command(self, cmd: Command):
logger.warning("Unhandled command: %r", cmd)

def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
logger.info("Lost connection to redis instance")
self.handler.lost_connection(self)

def send_command(self, cmd: Command):
Expand All @@ -130,6 +146,10 @@ def send_command(self, cmd: Command):
Args:
cmd (Command)
"""
run_as_background_process("send-cmd", self._async_send_command, cmd)

async def _async_send_command(self, cmd: Command):
"""Encode a replication command and send it over our outbound connection"""
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
Expand All @@ -140,15 +160,9 @@ def send_command(self, cmd: Command):
# remote instances.
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()

async def _send():
with PreserveLoggingContext():
# Note that we use the other connection as we can't send
# commands using the subscription connection.
await self.outbound_redis_connection.publish(
self.stream_name, encoded_string
)

run_as_background_process("send-cmd", _send)
await make_deferred_yieldable(
self.outbound_redis_connection.publish(self.stream_name, encoded_string)
)


class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
Expand Down