11# -*- coding: utf-8 -*-
22# Copyright 2014 - 2016 OpenMarket Ltd
33# Copyright 2017 Vector Creations Ltd
4+ # Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
45#
56# Licensed under the Apache License, Version 2.0 (the "License");
67# you may not use this file except in compliance with the License.
2526 Dict ,
2627 Iterable ,
2728 List ,
29+ Mapping ,
2830 Optional ,
2931 Tuple ,
3032 Union ,
@@ -181,17 +183,12 @@ def __init__(self, hs: "HomeServer"):
181183 # better way to break the loop
182184 account_handler = ModuleApi (hs , self )
183185
184- self .password_providers = []
185- for module , config in hs .config .password_providers :
186- try :
187- self .password_providers .append (
188- module (config = config , account_handler = account_handler )
189- )
190- except Exception as e :
191- logger .error ("Error while initializing %r: %s" , module , e )
192- raise
186+ self .password_providers = [
187+ PasswordProvider .load (module , config , account_handler )
188+ for module , config in hs .config .password_providers
189+ ]
193190
194- logger .info ("Extra password_providers: %r " , self .password_providers )
191+ logger .info ("Extra password_providers: %s " , self .password_providers )
195192
196193 self .hs = hs # FIXME better possibility to access registrationHandler later?
197194 self .macaroon_gen = hs .get_macaroon_generator ()
@@ -853,6 +850,8 @@ async def validate_login(
853850 LoginError if there was an authentication problem.
854851 """
855852 login_type = login_submission .get ("type" )
853+ if not isinstance (login_type , str ):
854+ raise SynapseError (400 , "Bad parameter: type" , Codes .INVALID_PARAM )
856855
857856 # ideally, we wouldn't be checking the identifier unless we know we have a login
858857 # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
@@ -998,24 +997,12 @@ async def _validate_userid_login(
998997 qualified_user_id = UserID (username , self .hs .hostname ).to_string ()
999998
1000999 login_type = login_submission .get ("type" )
1000+ # we already checked that we have a valid login type
1001+ assert isinstance (login_type , str )
1002+
10011003 known_login_type = False
10021004
10031005 for provider in self .password_providers :
1004- if hasattr (provider , "check_password" ) and login_type == LoginType .PASSWORD :
1005- known_login_type = True
1006- # we've already checked that there is a (valid) password field
1007- is_valid = await provider .check_password (
1008- qualified_user_id , login_submission ["password" ]
1009- )
1010- if is_valid :
1011- return qualified_user_id , None
1012-
1013- if not hasattr (provider , "get_supported_login_types" ) or not hasattr (
1014- provider , "check_auth"
1015- ):
1016- # this password provider doesn't understand custom login types
1017- continue
1018-
10191006 supported_login_types = provider .get_supported_login_types ()
10201007 if login_type not in supported_login_types :
10211008 # this password provider doesn't understand this login type
@@ -1040,8 +1027,6 @@ async def _validate_userid_login(
10401027
10411028 result = await provider .check_auth (username , login_type , login_dict )
10421029 if result :
1043- if isinstance (result , str ):
1044- result = (result , None )
10451030 return result
10461031
10471032 if login_type == LoginType .PASSWORD and self .hs .config .password_localdb_enabled :
@@ -1083,19 +1068,9 @@ async def check_password_provider_3pid(
10831068 unsuccessful, `user_id` and `callback` are both `None`.
10841069 """
10851070 for provider in self .password_providers :
1086- if hasattr (provider , "check_3pid_auth" ):
1087- # This function is able to return a deferred that either
1088- # resolves None, meaning authentication failure, or upon
1089- # success, to a str (which is the user_id) or a tuple of
1090- # (user_id, callback_func), where callback_func should be run
1091- # after we've finished everything else
1092- result = await provider .check_3pid_auth (medium , address , password )
1093- if result :
1094- # Check if the return value is a str or a tuple
1095- if isinstance (result , str ):
1096- # If it's a str, set callback function to None
1097- result = (result , None )
1098- return result
1071+ result = await provider .check_3pid_auth (medium , address , password )
1072+ if result :
1073+ return result
10991074
11001075 return None , None
11011076
@@ -1153,16 +1128,11 @@ async def delete_access_token(self, access_token: str):
11531128
11541129 # see if any of our auth providers want to know about this
11551130 for provider in self .password_providers :
1156- if hasattr (provider , "on_logged_out" ):
1157- # This might return an awaitable, if it does block the log out
1158- # until it completes.
1159- result = provider .on_logged_out (
1160- user_id = user_info .user_id ,
1161- device_id = user_info .device_id ,
1162- access_token = access_token ,
1163- )
1164- if inspect .isawaitable (result ):
1165- await result
1131+ await provider .on_logged_out (
1132+ user_id = user_info .user_id ,
1133+ device_id = user_info .device_id ,
1134+ access_token = access_token ,
1135+ )
11661136
11671137 # delete pushers associated with this access token
11681138 if user_info .token_id is not None :
@@ -1191,11 +1161,10 @@ async def delete_access_tokens_for_user(
11911161
11921162 # see if any of our auth providers want to know about this
11931163 for provider in self .password_providers :
1194- if hasattr (provider , "on_logged_out" ):
1195- for token , token_id , device_id in tokens_and_devices :
1196- await provider .on_logged_out (
1197- user_id = user_id , device_id = device_id , access_token = token
1198- )
1164+ for token , token_id , device_id in tokens_and_devices :
1165+ await provider .on_logged_out (
1166+ user_id = user_id , device_id = device_id , access_token = token
1167+ )
11991168
12001169 # delete pushers associated with the access tokens
12011170 await self .hs .get_pusherpool ().remove_pushers_by_access_token (
@@ -1519,3 +1488,127 @@ def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
15191488 macaroon .add_first_party_caveat ("gen = 1" )
15201489 macaroon .add_first_party_caveat ("user_id = %s" % (user_id ,))
15211490 return macaroon
1491+
1492+
1493+ class PasswordProvider :
1494+ """Wrapper for a password auth provider module
1495+
1496+ This class abstracts out all of the backwards-compatibility hacks for
1497+ password providers, to provide a consistent interface.
1498+ """
1499+
1500+ @classmethod
1501+ def load (cls , module , config , module_api : ModuleApi ) -> "PasswordProvider" :
1502+ try :
1503+ pp = module (config = config , account_handler = module_api )
1504+ except Exception as e :
1505+ logger .error ("Error while initializing %r: %s" , module , e )
1506+ raise
1507+ return cls (pp , module_api )
1508+
1509+ def __init__ (self , pp , module_api : ModuleApi ):
1510+ self ._pp = pp
1511+ self ._module_api = module_api
1512+
1513+ self ._supported_login_types = {}
1514+
1515+ # grandfather in check_password support
1516+ if hasattr (self ._pp , "check_password" ):
1517+ self ._supported_login_types [LoginType .PASSWORD ] = ("password" ,)
1518+
1519+ g = getattr (self ._pp , "get_supported_login_types" , None )
1520+ if g :
1521+ self ._supported_login_types .update (g ())
1522+
1523+ def __str__ (self ):
1524+ return str (self ._pp )
1525+
1526+ def get_supported_login_types (self ) -> Mapping [str , Iterable [str ]]:
1527+ """Get the login types supported by this password provider
1528+
1529+ Returns a map from a login type identifier (such as m.login.password) to an
1530+ iterable giving the fields which must be provided by the user in the submission
1531+ to the /login API.
1532+
1533+ This wrapper adds m.login.password to the list if the underlying password
1534+ provider supports the check_password() api.
1535+ """
1536+ return self ._supported_login_types
1537+
1538+ async def check_auth (
1539+ self , username : str , login_type : str , login_dict : JsonDict
1540+ ) -> Optional [Tuple [str , Optional [Callable ]]]:
1541+ """Check if the user has presented valid login credentials
1542+
1543+ This wrapper also calls check_password() if the underlying password provider
1544+ supports the check_password() api and the login type is m.login.password.
1545+
1546+ Args:
1547+ username: user id presented by the client. Either an MXID or an unqualified
1548+ username.
1549+
1550+ login_type: the login type being attempted - one of the types returned by
1551+ get_supported_login_types()
1552+
1553+ login_dict: the dictionary of login secrets passed by the client.
1554+
1555+ Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
1556+ user, and `callback` is an optional callback which will be called with the
1557+ result from the /login call (including access_token, device_id, etc.)
1558+ """
1559+ # first grandfather in a call to check_password
1560+ if login_type == LoginType .PASSWORD :
1561+ g = getattr (self ._pp , "check_password" , None )
1562+ if g :
1563+ qualified_user_id = self ._module_api .get_qualified_user_id (username )
1564+ is_valid = await self ._pp .check_password (
1565+ qualified_user_id , login_dict ["password" ]
1566+ )
1567+ if is_valid :
1568+ return qualified_user_id , None
1569+
1570+ g = getattr (self ._pp , "check_auth" , None )
1571+ if not g :
1572+ return None
1573+ result = await g (username , login_type , login_dict )
1574+
1575+ # Check if the return value is a str or a tuple
1576+ if isinstance (result , str ):
1577+ # If it's a str, set callback function to None
1578+ return result , None
1579+
1580+ return result
1581+
1582+ async def check_3pid_auth (
1583+ self , medium : str , address : str , password : str
1584+ ) -> Optional [Tuple [str , Optional [Callable ]]]:
1585+ g = getattr (self ._pp , "check_3pid_auth" , None )
1586+ if not g :
1587+ return None
1588+
1589+ # This function is able to return a deferred that either
1590+ # resolves None, meaning authentication failure, or upon
1591+ # success, to a str (which is the user_id) or a tuple of
1592+ # (user_id, callback_func), where callback_func should be run
1593+ # after we've finished everything else
1594+ result = await g (medium , address , password )
1595+
1596+ # Check if the return value is a str or a tuple
1597+ if isinstance (result , str ):
1598+ # If it's a str, set callback function to None
1599+ return result , None
1600+
1601+ return result
1602+
1603+ async def on_logged_out (
1604+ self , user_id : str , device_id : Optional [str ], access_token : str
1605+ ) -> None :
1606+ g = getattr (self ._pp , "on_logged_out" , None )
1607+ if not g :
1608+ return
1609+
1610+ # This might return an awaitable, if it does block the log out
1611+ # until it completes.
1612+ result = g (user_id = user_id , device_id = device_id , access_token = access_token ,)
1613+ if inspect .isawaitable (result ):
1614+ await result
0 commit comments