Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit d3ed935

Browse files
authored
Create a PasswordProvider wrapper object (#8849)
The idea here is to abstract out all the conditional code which tests which methods a given password provider has, to provide a consistent interface.
1 parent edb3d3f commit d3ed935

3 files changed

Lines changed: 152 additions & 57 deletions

File tree

changelog.d/8849.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Refactor `password_auth_provider` support code.

synapse/handlers/auth.py

Lines changed: 148 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
# Copyright 2014 - 2016 OpenMarket Ltd
33
# Copyright 2017 Vector Creations Ltd
4+
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
45
#
56
# Licensed under the Apache License, Version 2.0 (the "License");
67
# you may not use this file except in compliance with the License.
@@ -25,6 +26,7 @@
2526
Dict,
2627
Iterable,
2728
List,
29+
Mapping,
2830
Optional,
2931
Tuple,
3032
Union,
@@ -181,17 +183,12 @@ def __init__(self, hs: "HomeServer"):
181183
# better way to break the loop
182184
account_handler = ModuleApi(hs, self)
183185

184-
self.password_providers = []
185-
for module, config in hs.config.password_providers:
186-
try:
187-
self.password_providers.append(
188-
module(config=config, account_handler=account_handler)
189-
)
190-
except Exception as e:
191-
logger.error("Error while initializing %r: %s", module, e)
192-
raise
186+
self.password_providers = [
187+
PasswordProvider.load(module, config, account_handler)
188+
for module, config in hs.config.password_providers
189+
]
193190

194-
logger.info("Extra password_providers: %r", self.password_providers)
191+
logger.info("Extra password_providers: %s", self.password_providers)
195192

196193
self.hs = hs # FIXME better possibility to access registrationHandler later?
197194
self.macaroon_gen = hs.get_macaroon_generator()
@@ -853,6 +850,8 @@ async def validate_login(
853850
LoginError if there was an authentication problem.
854851
"""
855852
login_type = login_submission.get("type")
853+
if not isinstance(login_type, str):
854+
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
856855

857856
# ideally, we wouldn't be checking the identifier unless we know we have a login
858857
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
@@ -998,24 +997,12 @@ async def _validate_userid_login(
998997
qualified_user_id = UserID(username, self.hs.hostname).to_string()
999998

1000999
login_type = login_submission.get("type")
1000+
# we already checked that we have a valid login type
1001+
assert isinstance(login_type, str)
1002+
10011003
known_login_type = False
10021004

10031005
for provider in self.password_providers:
1004-
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
1005-
known_login_type = True
1006-
# we've already checked that there is a (valid) password field
1007-
is_valid = await provider.check_password(
1008-
qualified_user_id, login_submission["password"]
1009-
)
1010-
if is_valid:
1011-
return qualified_user_id, None
1012-
1013-
if not hasattr(provider, "get_supported_login_types") or not hasattr(
1014-
provider, "check_auth"
1015-
):
1016-
# this password provider doesn't understand custom login types
1017-
continue
1018-
10191006
supported_login_types = provider.get_supported_login_types()
10201007
if login_type not in supported_login_types:
10211008
# this password provider doesn't understand this login type
@@ -1040,8 +1027,6 @@ async def _validate_userid_login(
10401027

10411028
result = await provider.check_auth(username, login_type, login_dict)
10421029
if result:
1043-
if isinstance(result, str):
1044-
result = (result, None)
10451030
return result
10461031

10471032
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
@@ -1083,19 +1068,9 @@ async def check_password_provider_3pid(
10831068
unsuccessful, `user_id` and `callback` are both `None`.
10841069
"""
10851070
for provider in self.password_providers:
1086-
if hasattr(provider, "check_3pid_auth"):
1087-
# This function is able to return a deferred that either
1088-
# resolves None, meaning authentication failure, or upon
1089-
# success, to a str (which is the user_id) or a tuple of
1090-
# (user_id, callback_func), where callback_func should be run
1091-
# after we've finished everything else
1092-
result = await provider.check_3pid_auth(medium, address, password)
1093-
if result:
1094-
# Check if the return value is a str or a tuple
1095-
if isinstance(result, str):
1096-
# If it's a str, set callback function to None
1097-
result = (result, None)
1098-
return result
1071+
result = await provider.check_3pid_auth(medium, address, password)
1072+
if result:
1073+
return result
10991074

11001075
return None, None
11011076

@@ -1153,16 +1128,11 @@ async def delete_access_token(self, access_token: str):
11531128

11541129
# see if any of our auth providers want to know about this
11551130
for provider in self.password_providers:
1156-
if hasattr(provider, "on_logged_out"):
1157-
# This might return an awaitable, if it does block the log out
1158-
# until it completes.
1159-
result = provider.on_logged_out(
1160-
user_id=user_info.user_id,
1161-
device_id=user_info.device_id,
1162-
access_token=access_token,
1163-
)
1164-
if inspect.isawaitable(result):
1165-
await result
1131+
await provider.on_logged_out(
1132+
user_id=user_info.user_id,
1133+
device_id=user_info.device_id,
1134+
access_token=access_token,
1135+
)
11661136

11671137
# delete pushers associated with this access token
11681138
if user_info.token_id is not None:
@@ -1191,11 +1161,10 @@ async def delete_access_tokens_for_user(
11911161

11921162
# see if any of our auth providers want to know about this
11931163
for provider in self.password_providers:
1194-
if hasattr(provider, "on_logged_out"):
1195-
for token, token_id, device_id in tokens_and_devices:
1196-
await provider.on_logged_out(
1197-
user_id=user_id, device_id=device_id, access_token=token
1198-
)
1164+
for token, token_id, device_id in tokens_and_devices:
1165+
await provider.on_logged_out(
1166+
user_id=user_id, device_id=device_id, access_token=token
1167+
)
11991168

12001169
# delete pushers associated with the access tokens
12011170
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1519,3 +1488,127 @@ def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
15191488
macaroon.add_first_party_caveat("gen = 1")
15201489
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
15211490
return macaroon
1491+
1492+
1493+
class PasswordProvider:
1494+
"""Wrapper for a password auth provider module
1495+
1496+
This class abstracts out all of the backwards-compatibility hacks for
1497+
password providers, to provide a consistent interface.
1498+
"""
1499+
1500+
@classmethod
1501+
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
1502+
try:
1503+
pp = module(config=config, account_handler=module_api)
1504+
except Exception as e:
1505+
logger.error("Error while initializing %r: %s", module, e)
1506+
raise
1507+
return cls(pp, module_api)
1508+
1509+
def __init__(self, pp, module_api: ModuleApi):
1510+
self._pp = pp
1511+
self._module_api = module_api
1512+
1513+
self._supported_login_types = {}
1514+
1515+
# grandfather in check_password support
1516+
if hasattr(self._pp, "check_password"):
1517+
self._supported_login_types[LoginType.PASSWORD] = ("password",)
1518+
1519+
g = getattr(self._pp, "get_supported_login_types", None)
1520+
if g:
1521+
self._supported_login_types.update(g())
1522+
1523+
def __str__(self):
1524+
return str(self._pp)
1525+
1526+
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
1527+
"""Get the login types supported by this password provider
1528+
1529+
Returns a map from a login type identifier (such as m.login.password) to an
1530+
iterable giving the fields which must be provided by the user in the submission
1531+
to the /login API.
1532+
1533+
This wrapper adds m.login.password to the list if the underlying password
1534+
provider supports the check_password() api.
1535+
"""
1536+
return self._supported_login_types
1537+
1538+
async def check_auth(
1539+
self, username: str, login_type: str, login_dict: JsonDict
1540+
) -> Optional[Tuple[str, Optional[Callable]]]:
1541+
"""Check if the user has presented valid login credentials
1542+
1543+
This wrapper also calls check_password() if the underlying password provider
1544+
supports the check_password() api and the login type is m.login.password.
1545+
1546+
Args:
1547+
username: user id presented by the client. Either an MXID or an unqualified
1548+
username.
1549+
1550+
login_type: the login type being attempted - one of the types returned by
1551+
get_supported_login_types()
1552+
1553+
login_dict: the dictionary of login secrets passed by the client.
1554+
1555+
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
1556+
user, and `callback` is an optional callback which will be called with the
1557+
result from the /login call (including access_token, device_id, etc.)
1558+
"""
1559+
# first grandfather in a call to check_password
1560+
if login_type == LoginType.PASSWORD:
1561+
g = getattr(self._pp, "check_password", None)
1562+
if g:
1563+
qualified_user_id = self._module_api.get_qualified_user_id(username)
1564+
is_valid = await self._pp.check_password(
1565+
qualified_user_id, login_dict["password"]
1566+
)
1567+
if is_valid:
1568+
return qualified_user_id, None
1569+
1570+
g = getattr(self._pp, "check_auth", None)
1571+
if not g:
1572+
return None
1573+
result = await g(username, login_type, login_dict)
1574+
1575+
# Check if the return value is a str or a tuple
1576+
if isinstance(result, str):
1577+
# If it's a str, set callback function to None
1578+
return result, None
1579+
1580+
return result
1581+
1582+
async def check_3pid_auth(
1583+
self, medium: str, address: str, password: str
1584+
) -> Optional[Tuple[str, Optional[Callable]]]:
1585+
g = getattr(self._pp, "check_3pid_auth", None)
1586+
if not g:
1587+
return None
1588+
1589+
# This function is able to return a deferred that either
1590+
# resolves None, meaning authentication failure, or upon
1591+
# success, to a str (which is the user_id) or a tuple of
1592+
# (user_id, callback_func), where callback_func should be run
1593+
# after we've finished everything else
1594+
result = await g(medium, address, password)
1595+
1596+
# Check if the return value is a str or a tuple
1597+
if isinstance(result, str):
1598+
# If it's a str, set callback function to None
1599+
return result, None
1600+
1601+
return result
1602+
1603+
async def on_logged_out(
1604+
self, user_id: str, device_id: Optional[str], access_token: str
1605+
) -> None:
1606+
g = getattr(self._pp, "on_logged_out", None)
1607+
if not g:
1608+
return
1609+
1610+
# This might return an awaitable, if it does block the log out
1611+
# until it completes.
1612+
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
1613+
if inspect.isawaitable(result):
1614+
await result

tests/handlers/test_password_providers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,9 @@ def test_no_local_user_fallback_ui_auth(self):
266266
# first delete should give a 401
267267
channel = self._delete_device(tok1, "dev2")
268268
self.assertEqual(channel.code, 401)
269-
# there are no valid flows here!
270-
self.assertEqual(channel.json_body["flows"], [])
269+
# m.login.password UIA is permitted because the auth provider allows it,
270+
# even though the localdb does not.
271+
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
271272
session = channel.json_body["session"]
272273
mock_password_provider.check_password.assert_not_called()
273274

0 commit comments

Comments
 (0)