|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import logging |
15 | | -from typing import Any, Dict, List, Optional, Tuple |
| 15 | +from collections import defaultdict |
| 16 | +from typing import Any, Dict, List, Optional, Set, Tuple |
16 | 17 |
|
17 | 18 | from twisted.internet.address import IPv4Address |
18 | 19 | from twisted.internet.protocol import Protocol |
|
32 | 33 |
|
33 | 34 | from tests import unittest |
34 | 35 | from tests.server import FakeTransport |
| 36 | +from tests.utils import USE_POSTGRES_FOR_TESTS |
35 | 37 |
|
36 | 38 | try: |
37 | 39 | import hiredis |
@@ -475,22 +477,25 @@ class FakeRedisPubSubServer: |
475 | 477 | """A fake Redis server for pub/sub.""" |
476 | 478 |
|
477 | 479 | def __init__(self): |
478 | | - self._subscribers = set() |
| 480 | + self._subscribers_by_channel: Dict[ |
| 481 | + bytes, Set["FakeRedisPubSubProtocol"] |
| 482 | + ] = defaultdict(set) |
479 | 483 |
|
480 | | - def add_subscriber(self, conn): |
| 484 | + def add_subscriber(self, conn, channel: bytes): |
481 | 485 | """A connection has called SUBSCRIBE""" |
482 | | - self._subscribers.add(conn) |
| 486 | + self._subscribers_by_channel[channel].add(conn) |
483 | 487 |
|
484 | 488 | def remove_subscriber(self, conn): |
485 | | - """A connection has called UNSUBSCRIBE""" |
486 | | - self._subscribers.discard(conn) |
| 489 | + """A connection has lost connection""" |
| 490 | + for subscribers in self._subscribers_by_channel.values(): |
| 491 | + subscribers.discard(conn) |
487 | 492 |
|
488 | | - def publish(self, conn, channel, msg) -> int: |
| 493 | + def publish(self, conn, channel: bytes, msg) -> int: |
489 | 494 | """A connection want to publish a message to subscribers.""" |
490 | | - for sub in self._subscribers: |
| 495 | + for sub in self._subscribers_by_channel[channel]: |
491 | 496 | sub.send(["message", channel, msg]) |
492 | 497 |
|
493 | | - return len(self._subscribers) |
| 498 | + return len(self._subscribers_by_channel) |
494 | 499 |
|
495 | 500 | def buildProtocol(self, addr): |
496 | 501 | return FakeRedisPubSubProtocol(self) |
@@ -531,9 +536,10 @@ def handle_command(self, command, *args): |
531 | 536 | num_subscribers = self._server.publish(self, channel, message) |
532 | 537 | self.send(num_subscribers) |
533 | 538 | elif command == b"SUBSCRIBE": |
534 | | - (channel,) = args |
535 | | - self._server.add_subscriber(self) |
536 | | - self.send(["subscribe", channel, 1]) |
| 539 | + for idx, channel in enumerate(args): |
| 540 | + num_channels = idx + 1 |
| 541 | + self._server.add_subscriber(self, channel) |
| 542 | + self.send(["subscribe", channel, num_channels]) |
537 | 543 |
|
538 | 544 | # Since we use SET/GET to cache things we can safely no-op them. |
539 | 545 | elif command == b"SET": |
@@ -576,3 +582,27 @@ def encode(self, obj): |
576 | 582 |
|
577 | 583 | def connectionLost(self, reason): |
578 | 584 | self._server.remove_subscriber(self) |
| 585 | + |
| 586 | + |
| 587 | +class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase): |
| 588 | + """ |
| 589 | + A test case that enables Redis, providing a fake Redis server. |
| 590 | + """ |
| 591 | + |
| 592 | + if not hiredis: |
| 593 | + skip = "Requires hiredis" |
| 594 | + |
| 595 | + if not USE_POSTGRES_FOR_TESTS: |
| 596 | + # Redis replication only takes place on Postgres |
| 597 | + skip = "Requires Postgres" |
| 598 | + |
| 599 | + def default_config(self) -> Dict[str, Any]: |
| 600 | + """ |
| 601 | + Overrides the default config to enable Redis. |
| 602 | + Even if the test only uses make_worker_hs, the main process needs Redis |
| 603 | + enabled otherwise it won't create a Fake Redis server to listen on the |
| 604 | + Redis port and accept fake TCP connections. |
| 605 | + """ |
| 606 | + base = super().default_config() |
| 607 | + base["redis"] = {"enabled": True} |
| 608 | + return base |
0 commit comments