|
15 | 15 |
|
16 | 16 | from mock import Mock |
17 | 17 |
|
| 18 | +from synapse.app.generic_worker import GenericWorkerServer |
| 19 | +from synapse.replication.tcp.client import ReplicationDataHandler |
18 | 20 | from synapse.replication.tcp.handler import ReplicationCommandHandler |
19 | 21 | from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol |
20 | 22 | from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory |
|
26 | 28 | class BaseStreamTestCase(unittest.HomeserverTestCase): |
27 | 29 | """Base class for tests of the replication streams""" |
28 | 30 |
|
29 | | - def make_homeserver(self, reactor, clock): |
30 | | - self.test_handler = Mock(wraps=TestReplicationDataHandler()) |
31 | | - return self.setup_test_homeserver(replication_data_handler=self.test_handler) |
32 | | - |
33 | 31 | def prepare(self, reactor, clock, hs): |
34 | 32 | # build a replication server |
35 | 33 | server_factory = ReplicationStreamProtocolFactory(hs) |
36 | 34 | self.streamer = hs.get_replication_streamer() |
37 | 35 | self.server = server_factory.buildProtocol(None) |
38 | 36 |
|
39 | | - repl_handler = ReplicationCommandHandler(hs) |
40 | | - repl_handler.handler = self.test_handler |
| 37 | + # Make a new HomeServer object for the worker |
| 38 | + config = self.default_config() |
| 39 | + config["worker_app"] = "synapse.app.generic_worker" |
| 40 | + |
| 41 | + self.worker_hs = self.setup_test_homeserver( |
| 42 | + http_client=None, |
| 43 | + homeserverToUse=GenericWorkerServer, |
| 44 | + config=config, |
| 45 | + reactor=self.reactor, |
| 46 | + ) |
| 47 | + |
| 48 | + self.test_handler = Mock( |
| 49 | + wraps=TestReplicationDataHandler(self.worker_hs.get_datastore()) |
| 50 | + ) |
| 51 | + self.worker_hs.replication_data_handler = self.test_handler |
| 52 | + |
| 53 | + # Since we use sqlite in memory databases we need to make sure the |
| 54 | + # databases objects are the same. |
| 55 | + self.worker_hs.get_datastore().db = hs.get_datastore().db |
| 56 | + |
| 57 | + repl_handler = ReplicationCommandHandler(self.worker_hs) |
| 58 | + |
41 | 59 | self.client = ClientReplicationStreamProtocol( |
42 | 60 | hs, "client", "test", clock, repl_handler, |
43 | 61 | ) |
@@ -75,16 +93,15 @@ def replicate(self): |
75 | 93 | self.pump(0.1) |
76 | 94 |
|
77 | 95 |
|
78 | | -class TestReplicationDataHandler: |
| 96 | +class TestReplicationDataHandler(ReplicationDataHandler): |
79 | 97 | """Drop-in for ReplicationDataHandler which just collects RDATA rows""" |
80 | 98 |
|
81 | | - def __init__(self): |
| 99 | + def __init__(self, hs): |
| 100 | + super().__init__(hs) |
82 | 101 | self.streams = set() |
83 | 102 | self._received_rdata_rows = [] |
84 | 103 |
|
85 | 104 | async def on_rdata(self, stream_name, token, rows): |
| 105 | + await super().on_rdata(stream_name, token, rows) |
86 | 106 | for r in rows: |
87 | 107 | self._received_rdata_rows.append((stream_name, token, r)) |
88 | | - |
89 | | - async def on_position(self, stream_name, token): |
90 | | - pass |
|
0 commit comments