|
18 | 18 | # [This file includes modifications made by New Vector Limited] |
19 | 19 | # |
20 | 20 | # |
| 21 | +from __future__ import annotations |
21 | 22 |
|
22 | 23 | import urllib.parse |
23 | | -from typing import cast |
| 24 | +from typing import Any, cast |
| 25 | +from unittest.mock import Mock |
24 | 26 |
|
25 | 27 | from parameterized import parameterized |
26 | 28 |
|
| 29 | +from twisted.internet.defer import Deferred |
27 | 30 | from twisted.internet.testing import MemoryReactor |
28 | 31 | from twisted.web.resource import Resource |
29 | 32 |
|
@@ -70,6 +73,24 @@ def create_resource_dict(self) -> dict[str, Resource]: |
70 | 73 | resources["/_matrix/media"] = self.hs.get_media_repository_resource() |
71 | 74 | return resources |
72 | 75 |
|
| 76 | + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: |
| 77 | + self.fetches: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] |
| 78 | + |
| 79 | + # A remote fetch of media that was not intentional. |
| 80 | + # Used to check that remote media fetches do NOT happen. |
| 81 | + def unexpected_remote_fetch(*args: Any, **kwargs: Any) -> Deferred[Any]: |
| 82 | + self.fetches.append((args, kwargs)) |
| 83 | + return Deferred() |
| 84 | + |
| 85 | + client = Mock() |
| 86 | + client.federation_get_file = unexpected_remote_fetch |
| 87 | + client.get_file = unexpected_remote_fetch |
| 88 | + |
| 89 | + return self.setup_test_homeserver( |
| 90 | + clock=clock, |
| 91 | + federation_http_client=client, |
| 92 | + ) |
| 93 | + |
73 | 94 | def _ensure_quarantined( |
74 | 95 | self, |
75 | 96 | user_tok: str, |
@@ -176,6 +197,28 @@ def test_admin_can_bypass_quarantine(self) -> None: |
176 | 197 | ), |
177 | 198 | ) |
178 | 199 |
|
| 200 | + def test_non_admin_bypass_does_not_fetch_remote_media(self) -> None: |
| 201 | + self.register_user("nonadmin", "pass", admin=False) |
| 202 | + non_admin_user_tok = self.login("nonadmin", "pass") |
| 203 | + |
| 204 | + channel = self.make_request( |
| 205 | + "GET", |
| 206 | + "/_matrix/client/v1/media/download/example.com/remote_media" |
| 207 | + "?admin_unsafely_bypass_quarantine=true", |
| 208 | + shorthand=False, |
| 209 | + access_token=non_admin_user_tok, |
| 210 | + await_result=False, |
| 211 | + ) |
| 212 | + self.pump() |
| 213 | + |
| 214 | + self.assertEqual(400, channel.code, msg=channel.json_body) |
| 215 | + self.assertEqual( |
| 216 | + channel.json_body["error"], |
| 217 | + "Must be a server admin to bypass quarantine", |
| 218 | + ) |
| 219 | + # Check that a remote fetch attempt did not occur. |
| 220 | + self.assertEqual(self.fetches, []) |
| 221 | + |
179 | 222 | @parameterized.expand( |
180 | 223 | [ |
181 | 224 | # Attempt quarantine media APIs as non-admin |
|
0 commit comments