|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# Copyright 2020 The Matrix.org Foundation C.I.C. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +import logging |
| 17 | + |
| 18 | +import txredisapi |
| 19 | + |
| 20 | +from synapse.logging.context import PreserveLoggingContext |
| 21 | +from synapse.metrics.background_process_metrics import run_as_background_process |
| 22 | +from synapse.replication.tcp.commands import ( |
| 23 | + COMMAND_MAP, |
| 24 | + Command, |
| 25 | + RdataCommand, |
| 26 | + ReplicateCommand, |
| 27 | +) |
| 28 | +from synapse.util.stringutils import random_string |
| 29 | + |
| 30 | +logger = logging.getLogger(__name__) |
| 31 | + |
| 32 | + |
| 33 | +class RedisSubscriber(txredisapi.SubscriberProtocol): |
| 34 | + """Connection to redis subscribed to replication stream. |
| 35 | + """ |
| 36 | + |
| 37 | + def connectionMade(self): |
| 38 | + logger.info("Connected to redis instance") |
| 39 | + self.subscribe(self.stream_name) |
| 40 | + self.send_command(ReplicateCommand()) |
| 41 | + |
| 42 | + self.handler.new_connection(self) |
| 43 | + |
| 44 | + def messageReceived(self, pattern: str, channel: str, message: str): |
| 45 | + """Received a message from redis. |
| 46 | + """ |
| 47 | + |
| 48 | + if message.strip() == "": |
| 49 | + # Ignore blank lines |
| 50 | + return |
| 51 | + |
| 52 | + line = message |
| 53 | + cmd_name, rest_of_line = line.split(" ", 1) |
| 54 | + |
| 55 | + cmd_cls = COMMAND_MAP[cmd_name] |
| 56 | + try: |
| 57 | + cmd = cmd_cls.from_line(rest_of_line) |
| 58 | + except Exception as e: |
| 59 | + logger.exception( |
| 60 | + "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line |
| 61 | + ) |
| 62 | + self.send_error( |
| 63 | + "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) |
| 64 | + ) |
| 65 | + return |
| 66 | + |
| 67 | + # Now lets try and call on_<CMD_NAME> function |
| 68 | + run_as_background_process( |
| 69 | + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd |
| 70 | + ) |
| 71 | + |
| 72 | + async def handle_command(self, cmd: Command): |
| 73 | + """Handle a command we have received over the replication stream. |
| 74 | +
|
| 75 | + By default delegates to on_<COMMAND>, which should return an awaitable. |
| 76 | +
|
| 77 | + Args: |
| 78 | + cmd: received command |
| 79 | + """ |
| 80 | + handled = False |
| 81 | + |
| 82 | + # First call any command handlers on this instance. These are for redis |
| 83 | + # specific handling. |
| 84 | + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) |
| 85 | + if cmd_func: |
| 86 | + await cmd_func(cmd) |
| 87 | + handled = True |
| 88 | + |
| 89 | + # Then call out to the handler. |
| 90 | + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) |
| 91 | + if cmd_func: |
| 92 | + await cmd_func(cmd) |
| 93 | + handled = True |
| 94 | + |
| 95 | + if not handled: |
| 96 | + logger.warning("Unhandled command: %r", cmd) |
| 97 | + |
| 98 | + def connectionLost(self, reason): |
| 99 | + logger.info("Lost connection to redis instance") |
| 100 | + self.handler.lost_connection(self) |
| 101 | + |
| 102 | + def send_command(self, cmd): |
| 103 | + """Send a command if connection has been established. |
| 104 | +
|
| 105 | + Args: |
| 106 | + cmd (Command) |
| 107 | + """ |
| 108 | + string = "%s %s" % (cmd.NAME, cmd.to_line()) |
| 109 | + if "\n" in string: |
| 110 | + raise Exception("Unexpected newline in command: %r", string) |
| 111 | + |
| 112 | + encoded_string = string.encode("utf-8") |
| 113 | + |
| 114 | + async def _send(): |
| 115 | + with PreserveLoggingContext(): |
| 116 | + await self.redis_connection.publish(self.stream_name, encoded_string) |
| 117 | + |
| 118 | + run_as_background_process("send-cmd", _send) |
| 119 | + |
| 120 | + def stream_update(self, stream_name, token, data): |
| 121 | + """Called when a new update is available to stream to clients. |
| 122 | +
|
| 123 | + We need to check if the client is interested in the stream or not |
| 124 | + """ |
| 125 | + self.send_command(RdataCommand(stream_name, token, data)) |
| 126 | + |
| 127 | + |
| 128 | +class RedisFactory(txredisapi.SubscriberFactory): |
| 129 | + """This is a reconnecting factory that connects to redis and immediately |
| 130 | + subscribes to a stream. |
| 131 | + """ |
| 132 | + |
| 133 | + maxDelay = 5 |
| 134 | + continueTrying = True |
| 135 | + protocol = RedisSubscriber |
| 136 | + |
| 137 | + def __init__(self, hs): |
| 138 | + super(RedisFactory, self).__init__() |
| 139 | + |
| 140 | + self.password = hs.config.redis.redis_password |
| 141 | + |
| 142 | + self.handler = hs.get_tcp_replication() |
| 143 | + self.stream_name = hs.hostname |
| 144 | + |
| 145 | + self.redis_connection = txredisapi.lazyConnection( |
| 146 | + host=hs.config.redis_host, |
| 147 | + port=hs.config.redis_port, |
| 148 | + dbid=hs.config.redis_dbid, |
| 149 | + password=hs.config.redis.redis_password, |
| 150 | + reconnect=True, |
| 151 | + ) |
| 152 | + |
| 153 | + self.conn_id = random_string(5) |
| 154 | + |
| 155 | + def buildProtocol(self, addr): |
| 156 | + p = super(RedisFactory, self).buildProtocol(addr) |
| 157 | + p.handler = self.handler |
| 158 | + p.redis_connection = self.redis_connection |
| 159 | + p.conn_id = self.conn_id |
| 160 | + p.stream_name = self.stream_name |
| 161 | + return p |
0 commit comments