1+ from __future__ import annotations
2+
13import builtins
24import logging
3- from typing import TYPE_CHECKING
5+ from typing import (
6+ Optional ,
7+ TYPE_CHECKING ,
8+ TypedDict ,
9+ )
410
511import jwt as pyjwt
612from social_core .exceptions import (
713 AuthAlreadyAssociated ,
814 AuthCanceled ,
15+ AuthForbidden ,
916 AuthTokenError ,
1017)
1118
1219from galaxy import (
1320 exceptions ,
1421 model ,
1522)
23+ from galaxy .model import UserAuthnzToken
1624from galaxy .util import (
1725 asbool ,
1826 etree ,
3240
3341if TYPE_CHECKING :
3442 from galaxy .managers .context import ProvidesAppContext
43+ from galaxy .webapps .base .webapp import GalaxyWebTransaction
3544
3645OIDC_BACKEND_SCHEMA = resource_path (__name__ , "xsd/oidc_backends_config.xsd" )
3746
4554}
4655
4756
57+ class RefreshResult (TypedDict ):
58+ refreshed : bool
59+ reauthentication_required : bool
60+
61+
4862class AuthnzManager :
4963 def __init__ (self , app , oidc_config_file , oidc_backends_config_file ):
5064 """
@@ -163,6 +177,8 @@ def _parse_idp_config(self, config_xml):
163177 rtv ["label" ] = config_xml .find ("label" ).text
164178 if config_xml .find ("require_create_confirmation" ) is not None :
165179 rtv ["require_create_confirmation" ] = asbool (config_xml .find ("require_create_confirmation" ).text )
180+ if config_xml .find ("require_session_refresh" ) is not None :
181+ rtv ["require_session_refresh" ] = asbool (config_xml .find ("require_session_refresh" ).text )
166182 if config_xml .find ("prompt" ) is not None :
167183 rtv ["prompt" ] = config_xml .find ("prompt" ).text
168184 if config_xml .find ("api_url" ) is not None :
@@ -226,17 +242,17 @@ def get_allowed_idps(self):
226242 # None, if no allowed idp list is set, and a list of EntityIDs if configured (in oidc_backend)
227243 return self .allowed_idps
228244
229- def _unify_provider_name (self , provider ) :
245+ def _unify_provider_name (self , provider : str ) -> str | None :
230246 if provider .lower () in self .oidc_backends_config :
231247 return provider .lower ()
232248 for k , v in BACKENDS_NAME .items ():
233249 if v == provider :
234250 return k .lower ()
235251 return None
236252
237- def _get_authnz_backend (self , provider : str , idphint = None ):
253+ def _get_authnz_backend (self , provider : str , idphint : str | None = None ) -> tuple [ bool , str , PSAAuthnz | None ] :
238254 unified_provider_name = self ._unify_provider_name (provider )
239- if unified_provider_name in self .oidc_backends_config :
255+ if unified_provider_name is not None and unified_provider_name in self .oidc_backends_config :
240256 provider = unified_provider_name
241257 identity_provider_class = self ._get_identity_provider_factory (self .oidc_backends_implementation [provider ])
242258 try :
@@ -281,29 +297,69 @@ def can_user_assume_authn(trans, authn_id):
281297 log .warning (msg )
282298 raise exceptions .ItemAccessibilityException (msg )
283299
284- def refresh_expiring_oidc_tokens_for_provider (self , trans , auth ):
300+ def refresh_expiring_oidc_tokens_for_provider (
301+ self , trans : GalaxyWebTransaction , auth : UserAuthnzToken
302+ ) -> RefreshResult :
303+ """
304+ Refresh expiring OIDC tokens for a specific provider.
305+
306+ Returns:
307+ RefreshResult: A dictionary containing a boolean indicating success, and a boolean
308+ indicating if reauthentication is required
309+ """
285310 try :
311+ if auth .provider is None :
312+ raise exceptions .AuthenticationFailed ("Provider is not set" )
286313 success , message , backend = self ._get_authnz_backend (auth .provider )
314+ if backend is None :
315+ msg = f"Provider `{ auth .provider } ` not found"
316+ log .error (msg )
317+ return {"refreshed" : False , "reauthentication_required" : False }
287318 if success is False :
288319 msg = f"An error occurred when refreshing user token on `{ auth .provider } ` identity provider: { message } "
289320 log .error (msg )
290- return False
321+ return { "refreshed" : False , "reauthentication_required" : False }
291322 refreshed = backend .refresh (trans , auth )
292323 if refreshed :
293324 log .debug (f"Refreshed user token via `{ auth .provider } ` identity provider" )
294- return True
295- except Exception :
296- log .exception ("An error occurred when refreshing user token" )
297- return False
325+ return {"refreshed" : refreshed , "reauthentication_required" : False }
326+ except (AuthTokenError , AuthCanceled , AuthForbidden ):
327+ log .warning ("Authentication session has expired or is invalid, reauth required." )
328+ return {"refreshed" : False , "reauthentication_required" : True }
329+ except Exception as e :
330+ log .warning (f"An error occurred when refreshing user token: { e } " )
331+ return {"refreshed" : False , "reauthentication_required" : False }
332+
333+ def refresh_expiring_oidc_tokens (
334+ self , trans : GalaxyWebTransaction , user : Optional [model .User ] = None
335+ ) -> str | None :
336+ """
337+ Refresh expiring OIDC tokens for all providers associated with a user.
298338
299- def refresh_expiring_oidc_tokens (self , trans , user = None ):
339+ Returns:
340+ str | None: The provider name if refresh fails and require_session_refresh is enabled, otherwise None
341+ """
300342 user = trans .user or user
301343 if not isinstance (user , model .User ):
302- return
344+ return None
303345 for auth in user .social_auth or []:
304- self .refresh_expiring_oidc_tokens_for_provider (trans , auth )
346+ result = self .refresh_expiring_oidc_tokens_for_provider (trans , auth )
347+ if auth .provider is None :
348+ continue
349+ provider = self ._unify_provider_name (auth .provider )
350+ if provider is None :
351+ continue
352+ config = self .oidc_backends_config .get (provider , None )
353+ if config is None :
354+ continue
355+ # Redirect to OIDC login if refresh fails and require_session_refresh is enabled
356+ if config .get ("require_session_refresh" ) and result ["reauthentication_required" ]:
357+ return provider
358+ return None
305359
306- def authenticate (self , provider , trans , idphint = None ):
360+ def authenticate (
361+ self , provider : str , trans : GalaxyWebTransaction , idphint : str | None = None
362+ ) -> tuple [bool , str , str | None ]:
307363 """
308364 :type provider: string
309365 :param provider: set the name of the identity provider to be
@@ -314,6 +370,8 @@ def authenticate(self, provider, trans, idphint=None):
314370 """
315371 try :
316372 success , message , backend = self ._get_authnz_backend (provider , idphint = idphint )
373+ if backend is None :
374+ return False , f"Provider `{ provider } ` not found" , None
317375 if success is False :
318376 return False , message , None
319377 # Check allowed IDPs for providers that support idphint (keycloak, cilogon)
@@ -365,9 +423,11 @@ def callback(self, provider, state_token, authz_code, trans, login_redirect_url,
365423 log .exception (msg )
366424 return False , msg , (None , None )
367425
368- def create_user (self , provider : str , token : str , trans : " ProvidesAppContext" , login_redirect_url : str ):
426+ def create_user (self , provider : str , token : str , trans : ProvidesAppContext , login_redirect_url : str ):
369427 try :
370428 success , message , backend = self ._get_authnz_backend (provider )
429+ if backend is None :
430+ raise ValueError (f"Provider `{ provider } ` not found" )
371431 if success is False :
372432 return False , message , (None , None )
373433 return success , message , backend .create_user (token , trans , login_redirect_url )
0 commit comments