Skip to content

Commit 2c14943

Browse files
committed
Add tests for the new MAS auth delegation
1 parent 6c55c57 commit 2c14943

2 files changed

Lines changed: 344 additions & 4 deletions

File tree

synapse/synapse_rust/http_client.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,19 @@
1010
# See the GNU Affero General Public License for more details:
1111
# <https://www.gnu.org/licenses/agpl-3.0.html>.
1212

13-
from typing import Awaitable, Mapping
13+
from typing import Mapping
14+
15+
from twisted.internet.defer import Deferred
1416

1517
from synapse.types import ISynapseReactor
1618

1719
class HttpClient:
1820
def __init__(self, reactor: ISynapseReactor, user_agent: str) -> None: ...
19-
def get(self, url: str, response_limit: int) -> Awaitable[bytes]: ...
21+
def get(self, url: str, response_limit: int) -> Deferred[bytes]: ...
2022
def post(
2123
self,
2224
url: str,
2325
response_limit: int,
2426
headers: Mapping[str, str],
2527
request_body: str,
26-
) -> Awaitable[bytes]: ...
28+
) -> Deferred[bytes]: ...

tests/handlers/test_oauth_delegation.py

Lines changed: 339 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
#
2121

2222
import json
23+
import threading
24+
import time
2325
from http import HTTPStatus
26+
from http.server import BaseHTTPRequestHandler, HTTPServer
2427
from io import BytesIO
25-
from typing import Any, Dict, Union
28+
from typing import Any, Coroutine, Dict, Generator, Optional, TypeVar, Union
2629
from unittest.mock import ANY, AsyncMock, Mock
2730
from urllib.parse import parse_qs
2831

@@ -34,8 +37,10 @@
3437
)
3538
from signedjson.sign import sign_json
3639

40+
from twisted.internet.defer import Deferred, ensureDeferred
3741
from twisted.internet.testing import MemoryReactor
3842

43+
from synapse.api.auth.mas import MasDelegatedAuth
3944
from synapse.api.errors import (
4045
AuthError,
4146
Codes,
@@ -747,6 +752,339 @@ async def mock_http_client_request(
747752
self.assertEqual(conn_infos, [])
748753

749754

755+
class FakeMasHandler(BaseHTTPRequestHandler):
756+
server: "FakeMasServer"
757+
758+
def do_POST(self) -> None:
759+
self.server.calls += 1
760+
761+
if self.path != "/oauth2/introspect":
762+
self.send_response(404)
763+
self.end_headers()
764+
self.wfile.close()
765+
return
766+
767+
auth = self.headers.get("Authorization")
768+
if auth is None or auth != f"Bearer {self.server.secret}":
769+
self.send_response(401)
770+
self.end_headers()
771+
self.wfile.close()
772+
return
773+
774+
content_length = self.headers.get("Content-Length")
775+
if content_length is None:
776+
self.send_response(400)
777+
self.end_headers()
778+
self.wfile.close()
779+
return
780+
781+
raw_body = self.rfile.read(int(content_length))
782+
body = parse_qs(raw_body)
783+
param = body.get(b"token")
784+
if param is None:
785+
self.send_response(400)
786+
self.end_headers()
787+
self.wfile.close()
788+
return
789+
790+
self.server.last_token_seen = param[0].decode("utf-8")
791+
792+
self.send_response(200)
793+
self.send_header("Content-Type", "application/json")
794+
self.end_headers()
795+
self.wfile.write(json.dumps(self.server.introspection_response).encode("utf-8"))
796+
797+
def log_message(self, format: str, *args: Any) -> None:
798+
# Don't log anything; by default, the server logs to stderr
799+
pass
800+
801+
802+
class FakeMasServer(HTTPServer):
803+
"""A fake MAS server for testing.
804+
805+
This opens a real HTTP server on a random port, on a separate thread.
806+
"""
807+
808+
introspection_response: JsonDict = {}
809+
"""Determines what the response to the introspection endpoint will be."""
810+
811+
secret: str = "verysecret"
812+
"""The shared secret used to authenticate the introspection endpoint."""
813+
814+
last_token_seen: Optional[str] = None
815+
"""What is the last access token seen by the introspection endpoint."""
816+
817+
calls: int = 0
818+
"""How many times has the introspection endpoint been called."""
819+
820+
_thread: threading.Thread
821+
822+
def __init__(self) -> None:
823+
super().__init__(("127.0.0.1", 0), FakeMasHandler)
824+
825+
self._thread = threading.Thread(
826+
target=self.serve_forever,
827+
name="FakeMasServer",
828+
kwargs={"poll_interval": 0.01},
829+
daemon=True,
830+
)
831+
self._thread.start()
832+
833+
def shutdown(self) -> None:
834+
super().shutdown()
835+
self._thread.join()
836+
837+
@property
838+
def endpoint(self) -> str:
839+
return f"http://127.0.0.1:{self.server_port}/"
840+
841+
842+
T = TypeVar("T")
843+
844+
845+
class MasAuthDelegation(HomeserverTestCase):
846+
server: FakeMasServer
847+
848+
def till_deferred_has_result(
849+
self,
850+
awaitable: Union[
851+
Coroutine[Deferred[Any], Any, T],
852+
Generator[Deferred[Any], Any, T],
853+
Deferred[T],
854+
],
855+
) -> Deferred[T]:
856+
"""Wait until a deferred has a result.
857+
858+
This is useful because the Rust HTTP client will resolve the deferred
859+
using reactor.callFromThread, which are only run when we call
860+
reactor.advance.
861+
"""
862+
deferred = ensureDeferred(awaitable)
863+
tries = 0
864+
while not deferred.called:
865+
time.sleep(0.1)
866+
self.reactor.advance(0)
867+
tries += 1
868+
if tries > 100:
869+
raise Exception("Timed out waiting for deferred to resolve")
870+
871+
return deferred
872+
873+
def default_config(self) -> Dict[str, Any]:
874+
config = super().default_config()
875+
config["public_baseurl"] = BASE_URL
876+
config["disable_registration"] = True
877+
config["matrix_authentication_service"] = {
878+
"enabled": True,
879+
"endpoint": self.server.endpoint,
880+
"secret": self.server.secret,
881+
}
882+
return config
883+
884+
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
885+
self.server = FakeMasServer()
886+
hs = self.setup_test_homeserver()
887+
# This triggers the server startup hooks, which starts the Tokio thread pool
888+
reactor.run()
889+
self._auth = checked_cast(MasDelegatedAuth, hs.get_auth())
890+
return hs
891+
892+
def prepare(
893+
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
894+
) -> None:
895+
# Provision the user and the device we use in the tests.
896+
store = homeserver.get_datastores().main
897+
self.get_success(store.register_user(USER_ID))
898+
self.get_success(
899+
store.store_device(USER_ID, DEVICE, initial_device_display_name=None)
900+
)
901+
902+
def tearDown(self) -> None:
903+
self.server.shutdown()
904+
# MemoryReactor doesn't trigger the shutdown phases, and we want the
905+
# Tokio thread pool to be stopped
906+
# XXX: This logic should probably get moved somewhere else
907+
shutdown_triggers = self.reactor.triggers.get("shutdown", {})
908+
for phase in ["before", "during", "after"]:
909+
triggers = shutdown_triggers.get(phase, [])
910+
for callbable, args, kwargs in triggers:
911+
callbable(*args, **kwargs)
912+
913+
def test_simple_introspection(self) -> None:
914+
self.server.introspection_response = {
915+
"active": True,
916+
"sub": SUBJECT,
917+
"scope": " ".join(
918+
[
919+
MATRIX_USER_SCOPE,
920+
f"{MATRIX_DEVICE_SCOPE_PREFIX}{DEVICE}",
921+
]
922+
),
923+
"username": USERNAME,
924+
"expires_in": 60,
925+
}
926+
927+
requester = self.get_success(
928+
self.till_deferred_has_result(
929+
self._auth.get_user_by_access_token("some_token")
930+
)
931+
)
932+
933+
self.assertEquals(requester.user.to_string(), USER_ID)
934+
self.assertEquals(requester.device_id, DEVICE)
935+
self.assertFalse(self.get_success(self._auth.is_server_admin(requester)))
936+
937+
self.assertEquals(
938+
self.server.last_token_seen,
939+
"some_token",
940+
)
941+
942+
def test_inexistent_device(self) -> None:
943+
self.server.introspection_response = {
944+
"active": True,
945+
"sub": SUBJECT,
946+
"scope": " ".join(
947+
[
948+
MATRIX_USER_SCOPE,
949+
f"{MATRIX_DEVICE_SCOPE_PREFIX}ABCDEF",
950+
]
951+
),
952+
"username": USERNAME,
953+
"expires_in": 60,
954+
}
955+
956+
failure = self.get_failure(
957+
self.till_deferred_has_result(
958+
self._auth.get_user_by_access_token("some_token")
959+
),
960+
InvalidClientTokenError,
961+
)
962+
self.assertEqual(failure.value.code, 401)
963+
964+
def test_inexistent_user(self) -> None:
965+
self.server.introspection_response = {
966+
"active": True,
967+
"sub": SUBJECT,
968+
"scope": " ".join([MATRIX_USER_SCOPE]),
969+
"username": "inexistent_user",
970+
"expires_in": 60,
971+
}
972+
973+
failure = self.get_failure(
974+
self.till_deferred_has_result(
975+
self._auth.get_user_by_access_token("some_token")
976+
),
977+
AuthError,
978+
)
979+
# This is a 500, it should never happen really
980+
self.assertEqual(failure.value.code, 500)
981+
982+
def test_missing_scope(self) -> None:
983+
self.server.introspection_response = {
984+
"active": True,
985+
"sub": SUBJECT,
986+
"scope": "openid",
987+
"username": USERNAME,
988+
"expires_in": 60,
989+
}
990+
991+
failure = self.get_failure(
992+
self.till_deferred_has_result(
993+
self._auth.get_user_by_access_token("some_token")
994+
),
995+
InvalidClientTokenError,
996+
)
997+
self.assertEqual(failure.value.code, 401)
998+
999+
def test_invalid_response(self) -> None:
1000+
self.server.introspection_response = {}
1001+
1002+
failure = self.get_failure(
1003+
self.till_deferred_has_result(
1004+
self._auth.get_user_by_access_token("some_token")
1005+
),
1006+
SynapseError,
1007+
)
1008+
self.assertEqual(failure.value.code, 503)
1009+
1010+
def test_device_id_in_body(self) -> None:
1011+
self.server.introspection_response = {
1012+
"active": True,
1013+
"sub": SUBJECT,
1014+
"scope": MATRIX_USER_SCOPE,
1015+
"username": USERNAME,
1016+
"expires_in": 60,
1017+
"device_id": DEVICE,
1018+
}
1019+
1020+
requester = self.get_success(
1021+
self.till_deferred_has_result(
1022+
self._auth.get_user_by_access_token("some_token")
1023+
)
1024+
)
1025+
1026+
self.assertEqual(requester.device_id, DEVICE)
1027+
1028+
def test_admin_scope(self) -> None:
1029+
self.server.introspection_response = {
1030+
"active": True,
1031+
"sub": SUBJECT,
1032+
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
1033+
"username": USERNAME,
1034+
"expires_in": 60,
1035+
}
1036+
1037+
requester = self.get_success(
1038+
self.till_deferred_has_result(
1039+
self._auth.get_user_by_access_token("some_token")
1040+
)
1041+
)
1042+
1043+
self.assertEqual(requester.user.to_string(), USER_ID)
1044+
self.assertTrue(self.get_success(self._auth.is_server_admin(requester)))
1045+
1046+
def test_cached_expired_introspection(self) -> None:
1047+
"""The handler should raise an error if the introspection response gives
1048+
an expiry time, the introspection response is cached and then the entry is
1049+
re-requested after it has expired."""
1050+
1051+
self.server.introspection_response = {
1052+
"active": True,
1053+
"sub": SUBJECT,
1054+
"scope": " ".join(
1055+
[
1056+
MATRIX_USER_SCOPE,
1057+
f"{MATRIX_DEVICE_SCOPE_PREFIX}{DEVICE}",
1058+
]
1059+
),
1060+
"username": USERNAME,
1061+
"expires_in": 60,
1062+
}
1063+
1064+
self.assertEqual(self.server.calls, 0)
1065+
1066+
request = Mock(args={})
1067+
request.args[b"access_token"] = [b"some_token"]
1068+
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
1069+
1070+
# The first CS-API request causes a successful introspection
1071+
self.get_success(
1072+
self.till_deferred_has_result(self._auth.get_user_by_req(request))
1073+
)
1074+
self.assertEqual(self.server.calls, 1)
1075+
1076+
# Sleep for 60 seconds so the token expires.
1077+
self.reactor.advance(60.0)
1078+
1079+
# Now the CS-API request fails because the token expired
1080+
self.assertFailure(
1081+
self.till_deferred_has_result(self._auth.get_user_by_req(request)),
1082+
InvalidClientTokenError,
1083+
)
1084+
# Ensure another introspection request was not sent
1085+
self.assertEqual(self.server.calls, 1)
1086+
1087+
7501088
@parameterized_class(
7511089
("config",),
7521090
[

0 commit comments

Comments
 (0)