Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 0e719f2

Browse files
authored
Thread through instance name to replication client. (#7369)
For in-memory streams, when fetching updates on workers we need to query the source of the stream, which is currently hard-coded to be the master. This PR threads through the source instance we received via `POSITION` to the update function in each stream, which can then be passed to the replication client for in-memory streams.
1 parent 3085cde commit 0e719f2

12 files changed

Lines changed: 101 additions & 41 deletions

File tree

changelog.d/7369.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Thread through instance name to replication client.

synapse/app/generic_worker.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -646,13 +646,11 @@ def __init__(self, hs):
646646
else:
647647
self.send_handler = None
648648

649-
async def on_rdata(self, stream_name, token, rows):
650-
await super(GenericWorkerReplicationHandler, self).on_rdata(
651-
stream_name, token, rows
652-
)
653-
await self.process_and_notify(stream_name, token, rows)
649+
async def on_rdata(self, stream_name, instance_name, token, rows):
650+
await super().on_rdata(stream_name, instance_name, token, rows)
651+
await self._process_and_notify(stream_name, instance_name, token, rows)
654652

655-
async def process_and_notify(self, stream_name, token, rows):
653+
async def _process_and_notify(self, stream_name, instance_name, token, rows):
656654
try:
657655
if self.send_handler:
658656
await self.send_handler.process_replication_rows(

synapse/replication/http/_base.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import abc
1717
import logging
1818
import re
19+
from inspect import signature
1920
from typing import Dict, List, Tuple
2021

2122
from six import raise_from
@@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
6061
must call `register` to register the path with the HTTP server.
6162
6263
Requests can be sent by calling the client returned by `make_client`.
64+
Requests are sent to master process by default, but can be sent to other
65+
named processes by specifying an `instance_name` keyword argument.
6366
6467
Attributes:
6568
NAME (str): A name for the endpoint, added to the path as well as used
@@ -91,6 +94,16 @@ def __init__(self, hs):
9194
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
9295
)
9396

97+
# We reserve `instance_name` as a parameter to sending requests, so we
98+
# assert here that sub classes don't try and use the name.
99+
assert (
100+
"instance_name" not in self.PATH_ARGS
101+
), "`instance_name` is a reserved parameter name"
102+
assert (
103+
"instance_name"
104+
not in signature(self.__class__._serialize_payload).parameters
105+
], "`instance_name` is a reserved parameter name"
106+
94107
assert self.METHOD in ("PUT", "POST", "GET")
95108

96109
@abc.abstractmethod
@@ -135,7 +148,11 @@ def make_client(cls, hs):
135148

136149
@trace(opname="outgoing_replication_request")
137150
@defer.inlineCallbacks
138-
def send_request(**kwargs):
151+
def send_request(instance_name="master", **kwargs):
152+
# Currently we only support sending requests to master process.
153+
if instance_name != "master":
154+
raise Exception("Unknown instance")
155+
139156
data = yield cls._serialize_payload(**kwargs)
140157

141158
url_args = [

synapse/replication/http/streams.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
5050
def __init__(self, hs):
5151
super().__init__(hs)
5252

53+
self._instance_name = hs.get_instance_name()
54+
5355
# We pull the streams from the replication streamer (if we try and make
5456
# them ourselves we end up in an import loop).
5557
self.streams = hs.get_replication_streamer().get_streams()
@@ -67,7 +69,7 @@ async def _handle_request(self, request, stream_name):
6769
upto_token = parse_integer(request, "upto_token", required=True)
6870

6971
updates, upto_token, limited = await stream.get_updates_since(
70-
from_token, upto_token
72+
self._instance_name, from_token, upto_token
7173
)
7274

7375
return (

synapse/replication/tcp/client.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,19 @@ class ReplicationDataHandler:
8686
def __init__(self, store: BaseSlavedStore):
8787
self.store = store
8888

89-
async def on_rdata(self, stream_name: str, token: int, rows: list):
89+
async def on_rdata(
90+
self, stream_name: str, instance_name: str, token: int, rows: list
91+
):
9092
"""Called to handle a batch of replication data with a given stream token.
9193
9294
By default this just pokes the slave store. Can be overridden in subclasses to
9395
handle more.
9496
9597
Args:
96-
stream_name (str): name of the replication stream for this batch of rows
97-
token (int): stream token for this batch of rows
98-
rows (list): a list of Stream.ROW_TYPE objects as returned by
99-
Stream.parse_row.
98+
stream_name: name of the replication stream for this batch of rows
99+
instance_name: the instance that wrote the rows.
100+
token: stream token for this batch of rows
101+
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
100102
"""
101103
self.store.process_replication_rows(stream_name, token, rows)
102104

synapse/replication/tcp/handler.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,19 +278,24 @@ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
278278
# Check if this is the last of a batch of updates
279279
rows = self._pending_batches.pop(stream_name, [])
280280
rows.append(row)
281-
await self.on_rdata(stream_name, cmd.token, rows)
281+
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
282282

283-
async def on_rdata(self, stream_name: str, token: int, rows: list):
283+
async def on_rdata(
284+
self, stream_name: str, instance_name: str, token: int, rows: list
285+
):
284286
"""Called to handle a batch of replication data with a given stream token.
285287
286288
Args:
287289
stream_name: name of the replication stream for this batch of rows
290+
instance_name: the instance that wrote the rows.
288291
token: stream token for this batch of rows
289292
rows: a list of Stream.ROW_TYPE objects as returned by
290293
Stream.parse_row.
291294
"""
292295
logger.debug("Received rdata %s -> %s", stream_name, token)
293-
await self._replication_data_handler.on_rdata(stream_name, token, rows)
296+
await self._replication_data_handler.on_rdata(
297+
stream_name, instance_name, token, rows
298+
)
294299

295300
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
296301
if cmd.instance_name == self._instance_name:
@@ -325,7 +330,9 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
325330
updates,
326331
current_token,
327332
missing_updates,
328-
) = await stream.get_updates_since(current_token, cmd.token)
333+
) = await stream.get_updates_since(
334+
cmd.instance_name, current_token, cmd.token
335+
)
329336

330337
# TODO: add some tests for this
331338

@@ -334,7 +341,10 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
334341

335342
for token, rows in _batch_updates(updates):
336343
await self.on_rdata(
337-
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
344+
cmd.stream_name,
345+
cmd.instance_name,
346+
token,
347+
[stream.parse_row(row) for row in rows],
338348
)
339349

340350
# We've now caught up to position sent to us, notify handler.

synapse/replication/tcp/streams/_base.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import logging
1818
from collections import namedtuple
19-
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
19+
from typing import Any, Awaitable, Callable, List, Optional, Tuple
2020

2121
import attr
2222

@@ -53,6 +53,7 @@
5353
#
5454
# The arguments are:
5555
#
56+
# * instance_name: the writer of the stream
5657
# * from_token: the previous stream token: the starting point for fetching the
5758
# updates
5859
# * to_token: the new stream token: the point to get updates up to
@@ -62,7 +63,7 @@
6263
# If there are more updates available, it should set `limited` in the result, and
6364
# it will be called again to get the next batch.
6465
#
65-
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
66+
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
6667

6768

6869
class Stream(object):
@@ -93,6 +94,7 @@ def parse_row(cls, row: StreamRow):
9394

9495
def __init__(
9596
self,
97+
local_instance_name: str,
9698
current_token_function: Callable[[], Token],
9799
update_function: UpdateFunction,
98100
):
@@ -108,9 +110,11 @@ def __init__(
108110
stream tokens. See the UpdateFunction type definition for more info.
109111
110112
Args:
113+
local_instance_name: The instance name of the current process
111114
current_token_function: callback to get the current token, as above
112115
update_function: callback go get stream updates, as above
113116
"""
117+
self.local_instance_name = local_instance_name
114118
self.current_token = current_token_function
115119
self.update_function = update_function
116120

@@ -135,14 +139,14 @@ async def get_updates(self) -> StreamUpdateResult:
135139
"""
136140
current_token = self.current_token()
137141
updates, current_token, limited = await self.get_updates_since(
138-
self.last_token, current_token
142+
self.local_instance_name, self.last_token, current_token
139143
)
140144
self.last_token = current_token
141145

142146
return updates, current_token, limited
143147

144148
async def get_updates_since(
145-
self, from_token: Token, upto_token: Token
149+
self, instance_name: str, from_token: Token, upto_token: Token
146150
) -> StreamUpdateResult:
147151
"""Like get_updates except allows specifying from when we should
148152
stream updates
@@ -160,19 +164,19 @@ async def get_updates_since(
160164
return [], upto_token, False
161165

162166
updates, upto_token, limited = await self.update_function(
163-
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
167+
instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
164168
)
165169
return updates, upto_token, limited
166170

167171

168172
def db_query_to_update_function(
169-
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
173+
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
170174
) -> UpdateFunction:
171175
"""Wraps a db query function which returns a list of rows to make it
172176
suitable for use as an `update_function` for the Stream class
173177
"""
174178

175-
async def update_function(from_token, upto_token, limit):
179+
async def update_function(instance_name, from_token, upto_token, limit):
176180
rows = await query_function(from_token, upto_token, limit)
177181
updates = [(row[0], row[1:]) for row in rows]
178182
limited = False
@@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
193197
client = ReplicationGetStreamUpdates.make_client(hs)
194198

195199
async def update_function(
196-
from_token: int, upto_token: int, limit: int
200+
instance_name: str, from_token: int, upto_token: int, limit: int
197201
) -> StreamUpdateResult:
198202
result = await client(
199-
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
203+
instance_name=instance_name,
204+
stream_name=stream_name,
205+
from_token=from_token,
206+
upto_token=upto_token,
200207
)
201208
return result["updates"], result["upto_token"], result["limited"]
202209

@@ -226,6 +233,7 @@ class BackfillStream(Stream):
226233
def __init__(self, hs):
227234
store = hs.get_datastore()
228235
super().__init__(
236+
hs.get_instance_name(),
229237
store.get_current_backfill_token,
230238
db_query_to_update_function(store.get_all_new_backfill_event_rows),
231239
)
@@ -261,7 +269,9 @@ def __init__(self, hs):
261269
# Query master process
262270
update_function = make_http_update_function(hs, self.NAME)
263271

264-
super().__init__(store.get_current_presence_token, update_function)
272+
super().__init__(
273+
hs.get_instance_name(), store.get_current_presence_token, update_function
274+
)
265275

266276

267277
class TypingStream(Stream):
@@ -284,7 +294,9 @@ def __init__(self, hs):
284294
# Query master process
285295
update_function = make_http_update_function(hs, self.NAME)
286296

287-
super().__init__(typing_handler.get_current_token, update_function)
297+
super().__init__(
298+
hs.get_instance_name(), typing_handler.get_current_token, update_function
299+
)
288300

289301

290302
class ReceiptsStream(Stream):
@@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
305317
def __init__(self, hs):
306318
store = hs.get_datastore()
307319
super().__init__(
320+
hs.get_instance_name(),
308321
store.get_max_receipt_stream_id,
309322
db_query_to_update_function(store.get_all_updated_receipts),
310323
)
@@ -322,14 +335,16 @@ class PushRulesStream(Stream):
322335
def __init__(self, hs):
323336
self.store = hs.get_datastore()
324337
super(PushRulesStream, self).__init__(
325-
self._current_token, self._update_function
338+
hs.get_instance_name(), self._current_token, self._update_function
326339
)
327340

328341
def _current_token(self) -> int:
329342
push_rules_token, _ = self.store.get_push_rules_stream_token()
330343
return push_rules_token
331344

332-
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
345+
async def _update_function(
346+
self, instance_name: str, from_token: Token, to_token: Token, limit: int
347+
):
333348
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
334349

335350
limited = False
@@ -356,6 +371,7 @@ def __init__(self, hs):
356371
store = hs.get_datastore()
357372

358373
super().__init__(
374+
hs.get_instance_name(),
359375
store.get_pushers_stream_token,
360376
db_query_to_update_function(store.get_all_updated_pushers_rows),
361377
)
@@ -387,6 +403,7 @@ class CachesStreamRow:
387403
def __init__(self, hs):
388404
store = hs.get_datastore()
389405
super().__init__(
406+
hs.get_instance_name(),
390407
store.get_cache_stream_token,
391408
db_query_to_update_function(store.get_all_updated_caches),
392409
)
@@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
412429
def __init__(self, hs):
413430
store = hs.get_datastore()
414431
super().__init__(
432+
hs.get_instance_name(),
415433
store.get_current_public_room_stream_id,
416434
db_query_to_update_function(store.get_all_new_public_rooms),
417435
)
@@ -432,6 +450,7 @@ class DeviceListsStreamRow:
432450
def __init__(self, hs):
433451
store = hs.get_datastore()
434452
super().__init__(
453+
hs.get_instance_name(),
435454
store.get_device_stream_token,
436455
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
437456
)
@@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
449468
def __init__(self, hs):
450469
store = hs.get_datastore()
451470
super().__init__(
471+
hs.get_instance_name(),
452472
store.get_to_device_stream_token,
453473
db_query_to_update_function(store.get_all_new_device_messages),
454474
)
@@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
468488
def __init__(self, hs):
469489
store = hs.get_datastore()
470490
super().__init__(
491+
hs.get_instance_name(),
471492
store.get_max_account_data_stream_id,
472493
db_query_to_update_function(store.get_all_updated_tags),
473494
)
@@ -487,6 +508,7 @@ class AccountDataStream(Stream):
487508
def __init__(self, hs):
488509
self.store = hs.get_datastore()
489510
super().__init__(
511+
hs.get_instance_name(),
490512
self.store.get_max_account_data_stream_id,
491513
db_query_to_update_function(self._update_function),
492514
)
@@ -517,6 +539,7 @@ class GroupServerStream(Stream):
517539
def __init__(self, hs):
518540
store = hs.get_datastore()
519541
super().__init__(
542+
hs.get_instance_name(),
520543
store.get_group_stream_token,
521544
db_query_to_update_function(store.get_all_groups_changes),
522545
)
@@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
534557
def __init__(self, hs):
535558
store = hs.get_datastore()
536559
super().__init__(
560+
hs.get_instance_name(),
537561
store.get_device_stream_token,
538562
db_query_to_update_function(
539563
store.get_all_user_signature_changes_for_remotes

0 commit comments

Comments
 (0)