@@ -73,6 +73,14 @@ def login_id_thirdparty_from_phone(identifier):
7373 return {"type" : "m.id.thirdparty" , "medium" : "msisdn" , "address" : msisdn }
7474
7575
76+ def build_service_param (cas_service_url , client_redirect_url ):
77+ return "%s%s?redirectUrl=%s" % (
78+ cas_service_url ,
79+ "/_matrix/client/r0/login/cas/ticket" ,
80+ urllib .parse .quote (client_redirect_url , safe = "" ),
81+ )
82+
83+
7684class LoginRestServlet (RestServlet ):
7785 PATTERNS = client_patterns ("/login$" , v1 = True )
7886 CAS_TYPE = "m.login.cas"
@@ -428,18 +436,15 @@ def get_sso_url(self, client_redirect_url):
428436class CasRedirectServlet (BaseSSORedirectServlet ):
429437 def __init__ (self , hs ):
430438 super (CasRedirectServlet , self ).__init__ ()
431- self .cas_server_url = hs .config .cas_server_url . encode ( "ascii" )
432- self .cas_service_url = hs .config .cas_service_url . encode ( "ascii" )
439+ self .cas_server_url = hs .config .cas_server_url
440+ self .cas_service_url = hs .config .cas_service_url
433441
434442 def get_sso_url (self , client_redirect_url ):
435- client_redirect_url_param = urllib .parse .urlencode (
436- {b"redirectUrl" : client_redirect_url }
437- ).encode ("ascii" )
438- hs_redirect_url = self .cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
439- service_param = urllib .parse .urlencode (
440- {b"service" : b"%s?%s" % (hs_redirect_url , client_redirect_url_param )}
441- ).encode ("ascii" )
442- return b"%s/login?%s" % (self .cas_server_url , service_param )
443+ args = urllib .parse .urlencode (
444+ {"service" : build_service_param (self .cas_service_url , client_redirect_url )}
445+ )
446+
447+ return "%s/login?%s" % (self .cas_server_url , args )
443448
444449
445450class CasTicketServlet (RestServlet ):
@@ -448,10 +453,7 @@ class CasTicketServlet(RestServlet):
448453 def __init__ (self , hs ):
449454 super (CasTicketServlet , self ).__init__ ()
450455 self .cas_server_url = hs .config .cas_server_url
451- self .cas_service_url = (
452- hs .config .cas_service_url .encode ("ascii" )
453- + b"/_matrix/client/r0/login/cas/ticket?redirectUrl="
454- )
456+ self .cas_service_url = hs .config .cas_service_url
455457 self .cas_displayname_attribute = hs .config .cas_displayname_attribute
456458 self .cas_required_attributes = hs .config .cas_required_attributes
457459 self ._sso_auth_handler = SSOAuthHandler (hs )
@@ -460,12 +462,9 @@ def __init__(self, hs):
460462 async def on_GET (self , request ):
461463 client_redirect_url = parse_string (request , "redirectUrl" , required = True )
462464 uri = self .cas_server_url + "/proxyValidate"
463- service_url = self .cas_service_url + urllib .parse .quote (
464- client_redirect_url , safe = ""
465- ).encode ("ascii" )
466465 args = {
467466 "ticket" : parse_string (request , "ticket" , required = True ),
468- "service" : service_url ,
467+ "service" : build_service_param ( self . cas_service_url , client_redirect_url ) ,
469468 }
470469 try :
471470 body = await self ._http_client .get_raw (uri , args )
0 commit comments