|
| 1 | +from typing import Optional |
| 2 | + |
1 | 3 | from synapse.api.ratelimiting import LimitExceededError, Ratelimiter |
2 | 4 | from synapse.appservice import ApplicationService |
3 | 5 | from synapse.config.ratelimiting import RatelimitSettings |
| 6 | +from synapse.module_api.callbacks.ratelimit_callbacks import RatelimitModuleApiCallbacks |
| 7 | +from synapse.storage.databases.main.room import RatelimitOverride |
4 | 8 | from synapse.types import create_requester |
5 | 9 |
|
6 | 10 | from tests import unittest |
@@ -440,3 +444,49 @@ def test_record_action_which_overfills_bucket(self) -> None: |
440 | 444 | limiter.can_do_action(requester=None, key="a", _time_now_s=20.0) |
441 | 445 | ) |
442 | 446 | self.assertTrue(success) |
| 447 | + |
| 448 | + def test_get_ratelimit_override_for_user_callback(self) -> None: |
| 449 | + test_user_id = "@user:test" |
| 450 | + test_limiter_name = "name" |
| 451 | + callbacks = RatelimitModuleApiCallbacks(self.hs) |
| 452 | + requester = create_requester(test_user_id) |
| 453 | + limiter = Ratelimiter( |
| 454 | + store=self.hs.get_datastores().main, |
| 455 | + clock=self.clock, |
| 456 | + cfg=RatelimitSettings( |
| 457 | + test_limiter_name, |
| 458 | + per_second=0.1, |
| 459 | + burst_count=3, |
| 460 | + ), |
| 461 | + ratelimit_callbacks=callbacks, |
| 462 | + ) |
| 463 | + |
| 464 | + # Observe four actions, exceeding the burst_count. |
| 465 | + limiter.record_action(requester=requester, n_actions=4, _time_now_s=0.0) |
| 466 | + |
| 467 | + # We should be prevented from taking a new action now. |
| 468 | + success, _ = self.get_success_or_raise( |
| 469 | + limiter.can_do_action(requester=requester, _time_now_s=0.0) |
| 470 | + ) |
| 471 | + self.assertFalse(success) |
| 472 | + |
| 473 | + # Now register a callback that overrides the ratelimit for this user |
| 474 | + # and limiter name. |
| 475 | + async def get_ratelimit_override_for_user( |
| 476 | + user_id: str, limiter_name: str |
| 477 | + ) -> Optional[RatelimitOverride]: |
| 478 | + if user_id == test_user_id: |
| 479 | + return RatelimitOverride( |
| 480 | + messages_per_second=0.1, |
| 481 | + burst_count=10, |
| 482 | + ) |
| 483 | + return None |
| 484 | + |
| 485 | + callbacks.register_callbacks( |
| 486 | + get_ratelimit_override_for_user=get_ratelimit_override_for_user |
| 487 | + ) |
| 488 | + |
| 489 | + success, _ = self.get_success_or_raise( |
| 490 | + limiter.can_do_action(requester=requester, _time_now_s=0.0) |
| 491 | + ) |
| 492 | + self.assertTrue(success) |
0 commit comments