|
20 | 20 | # |
21 | 21 |
|
22 | 22 | import json |
| 23 | +import threading |
| 24 | +import time |
23 | 25 | from http import HTTPStatus |
| 26 | +from http.server import BaseHTTPRequestHandler, HTTPServer |
24 | 27 | from io import BytesIO |
25 | | -from typing import Any, Dict, Union |
| 28 | +from typing import Any, Coroutine, Dict, Generator, Optional, TypeVar, Union |
26 | 29 | from unittest.mock import ANY, AsyncMock, Mock |
27 | 30 | from urllib.parse import parse_qs |
28 | 31 |
|
|
34 | 37 | ) |
35 | 38 | from signedjson.sign import sign_json |
36 | 39 |
|
| 40 | +from twisted.internet.defer import Deferred, ensureDeferred |
37 | 41 | from twisted.internet.testing import MemoryReactor |
38 | 42 |
|
| 43 | +from synapse.api.auth.mas import MasDelegatedAuth |
39 | 44 | from synapse.api.errors import ( |
40 | 45 | AuthError, |
41 | 46 | Codes, |
@@ -747,6 +752,339 @@ async def mock_http_client_request( |
747 | 752 | self.assertEqual(conn_infos, []) |
748 | 753 |
|
749 | 754 |
|
| 755 | +class FakeMasHandler(BaseHTTPRequestHandler): |
| 756 | + server: "FakeMasServer" |
| 757 | + |
| 758 | + def do_POST(self) -> None: |
| 759 | + self.server.calls += 1 |
| 760 | + |
| 761 | + if self.path != "/oauth2/introspect": |
| 762 | + self.send_response(404) |
| 763 | + self.end_headers() |
| 764 | + self.wfile.close() |
| 765 | + return |
| 766 | + |
| 767 | + auth = self.headers.get("Authorization") |
| 768 | + if auth is None or auth != f"Bearer {self.server.secret}": |
| 769 | + self.send_response(401) |
| 770 | + self.end_headers() |
| 771 | + self.wfile.close() |
| 772 | + return |
| 773 | + |
| 774 | + content_length = self.headers.get("Content-Length") |
| 775 | + if content_length is None: |
| 776 | + self.send_response(400) |
| 777 | + self.end_headers() |
| 778 | + self.wfile.close() |
| 779 | + return |
| 780 | + |
| 781 | + raw_body = self.rfile.read(int(content_length)) |
| 782 | + body = parse_qs(raw_body) |
| 783 | + param = body.get(b"token") |
| 784 | + if param is None: |
| 785 | + self.send_response(400) |
| 786 | + self.end_headers() |
| 787 | + self.wfile.close() |
| 788 | + return |
| 789 | + |
| 790 | + self.server.last_token_seen = param[0].decode("utf-8") |
| 791 | + |
| 792 | + self.send_response(200) |
| 793 | + self.send_header("Content-Type", "application/json") |
| 794 | + self.end_headers() |
| 795 | + self.wfile.write(json.dumps(self.server.introspection_response).encode("utf-8")) |
| 796 | + |
| 797 | + def log_message(self, format: str, *args: Any) -> None: |
| 798 | + # Don't log anything; by default, the server logs to stderr |
| 799 | + pass |
| 800 | + |
| 801 | + |
| 802 | +class FakeMasServer(HTTPServer): |
| 803 | + """A fake MAS server for testing. |
| 804 | +
|
| 805 | + This opens a real HTTP server on a random port, on a separate thread. |
| 806 | + """ |
| 807 | + |
| 808 | + introspection_response: JsonDict = {} |
| 809 | + """Determines what the response to the introspection endpoint will be.""" |
| 810 | + |
| 811 | + secret: str = "verysecret" |
| 812 | + """The shared secret used to authenticate the introspection endpoint.""" |
| 813 | + |
| 814 | + last_token_seen: Optional[str] = None |
| 815 | + """What is the last access token seen by the introspection endpoint.""" |
| 816 | + |
| 817 | + calls: int = 0 |
| 818 | + """How many times has the introspection endpoint been called.""" |
| 819 | + |
| 820 | + _thread: threading.Thread |
| 821 | + |
| 822 | + def __init__(self) -> None: |
| 823 | + super().__init__(("127.0.0.1", 0), FakeMasHandler) |
| 824 | + |
| 825 | + self._thread = threading.Thread( |
| 826 | + target=self.serve_forever, |
| 827 | + name="FakeMasServer", |
| 828 | + kwargs={"poll_interval": 0.01}, |
| 829 | + daemon=True, |
| 830 | + ) |
| 831 | + self._thread.start() |
| 832 | + |
| 833 | + def shutdown(self) -> None: |
| 834 | + super().shutdown() |
| 835 | + self._thread.join() |
| 836 | + |
| 837 | + @property |
| 838 | + def endpoint(self) -> str: |
| 839 | + return f"http://127.0.0.1:{self.server_port}/" |
| 840 | + |
| 841 | + |
| 842 | +T = TypeVar("T") |
| 843 | + |
| 844 | + |
| 845 | +class MasAuthDelegation(HomeserverTestCase): |
| 846 | + server: FakeMasServer |
| 847 | + |
| 848 | + def till_deferred_has_result( |
| 849 | + self, |
| 850 | + awaitable: Union[ |
| 851 | + Coroutine[Deferred[Any], Any, T], |
| 852 | + Generator[Deferred[Any], Any, T], |
| 853 | + Deferred[T], |
| 854 | + ], |
| 855 | + ) -> Deferred[T]: |
| 856 | + """Wait until a deferred has a result. |
| 857 | +
|
| 858 | + This is useful because the Rust HTTP client will resolve the deferred |
| 859 | + using reactor.callFromThread, which are only run when we call |
| 860 | + reactor.advance. |
| 861 | + """ |
| 862 | + deferred = ensureDeferred(awaitable) |
| 863 | + tries = 0 |
| 864 | + while not deferred.called: |
| 865 | + time.sleep(0.1) |
| 866 | + self.reactor.advance(0) |
| 867 | + tries += 1 |
| 868 | + if tries > 100: |
| 869 | + raise Exception("Timed out waiting for deferred to resolve") |
| 870 | + |
| 871 | + return deferred |
| 872 | + |
| 873 | + def default_config(self) -> Dict[str, Any]: |
| 874 | + config = super().default_config() |
| 875 | + config["public_baseurl"] = BASE_URL |
| 876 | + config["disable_registration"] = True |
| 877 | + config["matrix_authentication_service"] = { |
| 878 | + "enabled": True, |
| 879 | + "endpoint": self.server.endpoint, |
| 880 | + "secret": self.server.secret, |
| 881 | + } |
| 882 | + return config |
| 883 | + |
| 884 | + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: |
| 885 | + self.server = FakeMasServer() |
| 886 | + hs = self.setup_test_homeserver() |
| 887 | + # This triggers the server startup hooks, which starts the Tokio thread pool |
| 888 | + reactor.run() |
| 889 | + self._auth = checked_cast(MasDelegatedAuth, hs.get_auth()) |
| 890 | + return hs |
| 891 | + |
| 892 | + def prepare( |
| 893 | + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer |
| 894 | + ) -> None: |
| 895 | + # Provision the user and the device we use in the tests. |
| 896 | + store = homeserver.get_datastores().main |
| 897 | + self.get_success(store.register_user(USER_ID)) |
| 898 | + self.get_success( |
| 899 | + store.store_device(USER_ID, DEVICE, initial_device_display_name=None) |
| 900 | + ) |
| 901 | + |
| 902 | + def tearDown(self) -> None: |
| 903 | + self.server.shutdown() |
| 904 | + # MemoryReactor doesn't trigger the shutdown phases, and we want the |
| 905 | + # Tokio thread pool to be stopped |
| 906 | + # XXX: This logic should probably get moved somewhere else |
| 907 | + shutdown_triggers = self.reactor.triggers.get("shutdown", {}) |
| 908 | + for phase in ["before", "during", "after"]: |
| 909 | + triggers = shutdown_triggers.get(phase, []) |
| 910 | + for callbable, args, kwargs in triggers: |
| 911 | + callbable(*args, **kwargs) |
| 912 | + |
| 913 | + def test_simple_introspection(self) -> None: |
| 914 | + self.server.introspection_response = { |
| 915 | + "active": True, |
| 916 | + "sub": SUBJECT, |
| 917 | + "scope": " ".join( |
| 918 | + [ |
| 919 | + MATRIX_USER_SCOPE, |
| 920 | + f"{MATRIX_DEVICE_SCOPE_PREFIX}{DEVICE}", |
| 921 | + ] |
| 922 | + ), |
| 923 | + "username": USERNAME, |
| 924 | + "expires_in": 60, |
| 925 | + } |
| 926 | + |
| 927 | + requester = self.get_success( |
| 928 | + self.till_deferred_has_result( |
| 929 | + self._auth.get_user_by_access_token("some_token") |
| 930 | + ) |
| 931 | + ) |
| 932 | + |
| 933 | + self.assertEquals(requester.user.to_string(), USER_ID) |
| 934 | + self.assertEquals(requester.device_id, DEVICE) |
| 935 | + self.assertFalse(self.get_success(self._auth.is_server_admin(requester))) |
| 936 | + |
| 937 | + self.assertEquals( |
| 938 | + self.server.last_token_seen, |
| 939 | + "some_token", |
| 940 | + ) |
| 941 | + |
| 942 | + def test_inexistent_device(self) -> None: |
| 943 | + self.server.introspection_response = { |
| 944 | + "active": True, |
| 945 | + "sub": SUBJECT, |
| 946 | + "scope": " ".join( |
| 947 | + [ |
| 948 | + MATRIX_USER_SCOPE, |
| 949 | + f"{MATRIX_DEVICE_SCOPE_PREFIX}ABCDEF", |
| 950 | + ] |
| 951 | + ), |
| 952 | + "username": USERNAME, |
| 953 | + "expires_in": 60, |
| 954 | + } |
| 955 | + |
| 956 | + failure = self.get_failure( |
| 957 | + self.till_deferred_has_result( |
| 958 | + self._auth.get_user_by_access_token("some_token") |
| 959 | + ), |
| 960 | + InvalidClientTokenError, |
| 961 | + ) |
| 962 | + self.assertEqual(failure.value.code, 401) |
| 963 | + |
| 964 | + def test_inexistent_user(self) -> None: |
| 965 | + self.server.introspection_response = { |
| 966 | + "active": True, |
| 967 | + "sub": SUBJECT, |
| 968 | + "scope": " ".join([MATRIX_USER_SCOPE]), |
| 969 | + "username": "inexistent_user", |
| 970 | + "expires_in": 60, |
| 971 | + } |
| 972 | + |
| 973 | + failure = self.get_failure( |
| 974 | + self.till_deferred_has_result( |
| 975 | + self._auth.get_user_by_access_token("some_token") |
| 976 | + ), |
| 977 | + AuthError, |
| 978 | + ) |
| 979 | + # This is a 500, it should never happen really |
| 980 | + self.assertEqual(failure.value.code, 500) |
| 981 | + |
| 982 | + def test_missing_scope(self) -> None: |
| 983 | + self.server.introspection_response = { |
| 984 | + "active": True, |
| 985 | + "sub": SUBJECT, |
| 986 | + "scope": "openid", |
| 987 | + "username": USERNAME, |
| 988 | + "expires_in": 60, |
| 989 | + } |
| 990 | + |
| 991 | + failure = self.get_failure( |
| 992 | + self.till_deferred_has_result( |
| 993 | + self._auth.get_user_by_access_token("some_token") |
| 994 | + ), |
| 995 | + InvalidClientTokenError, |
| 996 | + ) |
| 997 | + self.assertEqual(failure.value.code, 401) |
| 998 | + |
| 999 | + def test_invalid_response(self) -> None: |
| 1000 | + self.server.introspection_response = {} |
| 1001 | + |
| 1002 | + failure = self.get_failure( |
| 1003 | + self.till_deferred_has_result( |
| 1004 | + self._auth.get_user_by_access_token("some_token") |
| 1005 | + ), |
| 1006 | + SynapseError, |
| 1007 | + ) |
| 1008 | + self.assertEqual(failure.value.code, 503) |
| 1009 | + |
| 1010 | + def test_device_id_in_body(self) -> None: |
| 1011 | + self.server.introspection_response = { |
| 1012 | + "active": True, |
| 1013 | + "sub": SUBJECT, |
| 1014 | + "scope": MATRIX_USER_SCOPE, |
| 1015 | + "username": USERNAME, |
| 1016 | + "expires_in": 60, |
| 1017 | + "device_id": DEVICE, |
| 1018 | + } |
| 1019 | + |
| 1020 | + requester = self.get_success( |
| 1021 | + self.till_deferred_has_result( |
| 1022 | + self._auth.get_user_by_access_token("some_token") |
| 1023 | + ) |
| 1024 | + ) |
| 1025 | + |
| 1026 | + self.assertEqual(requester.device_id, DEVICE) |
| 1027 | + |
| 1028 | + def test_admin_scope(self) -> None: |
| 1029 | + self.server.introspection_response = { |
| 1030 | + "active": True, |
| 1031 | + "sub": SUBJECT, |
| 1032 | + "scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]), |
| 1033 | + "username": USERNAME, |
| 1034 | + "expires_in": 60, |
| 1035 | + } |
| 1036 | + |
| 1037 | + requester = self.get_success( |
| 1038 | + self.till_deferred_has_result( |
| 1039 | + self._auth.get_user_by_access_token("some_token") |
| 1040 | + ) |
| 1041 | + ) |
| 1042 | + |
| 1043 | + self.assertEqual(requester.user.to_string(), USER_ID) |
| 1044 | + self.assertTrue(self.get_success(self._auth.is_server_admin(requester))) |
| 1045 | + |
| 1046 | + def test_cached_expired_introspection(self) -> None: |
| 1047 | + """The handler should raise an error if the introspection response gives |
| 1048 | + an expiry time, the introspection response is cached and then the entry is |
| 1049 | + re-requested after it has expired.""" |
| 1050 | + |
| 1051 | + self.server.introspection_response = { |
| 1052 | + "active": True, |
| 1053 | + "sub": SUBJECT, |
| 1054 | + "scope": " ".join( |
| 1055 | + [ |
| 1056 | + MATRIX_USER_SCOPE, |
| 1057 | + f"{MATRIX_DEVICE_SCOPE_PREFIX}{DEVICE}", |
| 1058 | + ] |
| 1059 | + ), |
| 1060 | + "username": USERNAME, |
| 1061 | + "expires_in": 60, |
| 1062 | + } |
| 1063 | + |
| 1064 | + self.assertEqual(self.server.calls, 0) |
| 1065 | + |
| 1066 | + request = Mock(args={}) |
| 1067 | + request.args[b"access_token"] = [b"some_token"] |
| 1068 | + request.requestHeaders.getRawHeaders = mock_getRawHeaders() |
| 1069 | + |
| 1070 | + # The first CS-API request causes a successful introspection |
| 1071 | + self.get_success( |
| 1072 | + self.till_deferred_has_result(self._auth.get_user_by_req(request)) |
| 1073 | + ) |
| 1074 | + self.assertEqual(self.server.calls, 1) |
| 1075 | + |
| 1076 | + # Sleep for 60 seconds so the token expires. |
| 1077 | + self.reactor.advance(60.0) |
| 1078 | + |
| 1079 | + # Now the CS-API request fails because the token expired |
| 1080 | + self.assertFailure( |
| 1081 | + self.till_deferred_has_result(self._auth.get_user_by_req(request)), |
| 1082 | + InvalidClientTokenError, |
| 1083 | + ) |
| 1084 | + # Ensure another introspection request was not sent |
| 1085 | + self.assertEqual(self.server.calls, 1) |
| 1086 | + |
| 1087 | + |
750 | 1088 | @parameterized_class( |
751 | 1089 | ("config",), |
752 | 1090 | [ |
|
0 commit comments