Skip to content

Commit 21408ef

Browse files
authored
Merge pull request galaxyproject#22449 from AustralianBioCommons/oidc-require-refresh
Require logging in again when OIDC tokens can't be refreshed
2 parents e4d4bcb + 0dbca9b commit 21408ef

5 files changed

Lines changed: 365 additions & 18 deletions

File tree

lib/galaxy/authnz/managers.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
1+
from __future__ import annotations
2+
13
import builtins
24
import logging
3-
from typing import TYPE_CHECKING
5+
from typing import (
6+
Optional,
7+
TYPE_CHECKING,
8+
TypedDict,
9+
)
410

511
import jwt as pyjwt
612
from social_core.exceptions import (
713
AuthAlreadyAssociated,
814
AuthCanceled,
15+
AuthForbidden,
916
AuthTokenError,
1017
)
1118

1219
from galaxy import (
1320
exceptions,
1421
model,
1522
)
23+
from galaxy.model import UserAuthnzToken
1624
from galaxy.util import (
1725
asbool,
1826
etree,
@@ -32,6 +40,7 @@
3240

3341
if TYPE_CHECKING:
3442
from galaxy.managers.context import ProvidesAppContext
43+
from galaxy.webapps.base.webapp import GalaxyWebTransaction
3544

3645
OIDC_BACKEND_SCHEMA = resource_path(__name__, "xsd/oidc_backends_config.xsd")
3746

@@ -45,6 +54,11 @@
4554
}
4655

4756

57+
class RefreshResult(TypedDict):
58+
refreshed: bool
59+
reauthentication_required: bool
60+
61+
4862
class AuthnzManager:
4963
def __init__(self, app, oidc_config_file, oidc_backends_config_file):
5064
"""
@@ -163,6 +177,8 @@ def _parse_idp_config(self, config_xml):
163177
rtv["label"] = config_xml.find("label").text
164178
if config_xml.find("require_create_confirmation") is not None:
165179
rtv["require_create_confirmation"] = asbool(config_xml.find("require_create_confirmation").text)
180+
if config_xml.find("require_session_refresh") is not None:
181+
rtv["require_session_refresh"] = asbool(config_xml.find("require_session_refresh").text)
166182
if config_xml.find("prompt") is not None:
167183
rtv["prompt"] = config_xml.find("prompt").text
168184
if config_xml.find("api_url") is not None:
@@ -226,17 +242,17 @@ def get_allowed_idps(self):
226242
# None, if no allowed idp list is set, and a list of EntityIDs if configured (in oidc_backend)
227243
return self.allowed_idps
228244

229-
def _unify_provider_name(self, provider):
245+
def _unify_provider_name(self, provider: str) -> str | None:
230246
if provider.lower() in self.oidc_backends_config:
231247
return provider.lower()
232248
for k, v in BACKENDS_NAME.items():
233249
if v == provider:
234250
return k.lower()
235251
return None
236252

237-
def _get_authnz_backend(self, provider: str, idphint=None):
253+
def _get_authnz_backend(self, provider: str, idphint: str | None = None) -> tuple[bool, str, PSAAuthnz | None]:
238254
unified_provider_name = self._unify_provider_name(provider)
239-
if unified_provider_name in self.oidc_backends_config:
255+
if unified_provider_name is not None and unified_provider_name in self.oidc_backends_config:
240256
provider = unified_provider_name
241257
identity_provider_class = self._get_identity_provider_factory(self.oidc_backends_implementation[provider])
242258
try:
@@ -281,29 +297,69 @@ def can_user_assume_authn(trans, authn_id):
281297
log.warning(msg)
282298
raise exceptions.ItemAccessibilityException(msg)
283299

284-
def refresh_expiring_oidc_tokens_for_provider(self, trans, auth):
300+
def refresh_expiring_oidc_tokens_for_provider(
301+
self, trans: GalaxyWebTransaction, auth: UserAuthnzToken
302+
) -> RefreshResult:
303+
"""
304+
Refresh expiring OIDC tokens for a specific provider.
305+
306+
Returns:
307+
RefreshResult: A dictionary containing a boolean indicating success, and a boolean
308+
indicating if reauthentication is required
309+
"""
285310
try:
311+
if auth.provider is None:
312+
raise exceptions.AuthenticationFailed("Provider is not set")
286313
success, message, backend = self._get_authnz_backend(auth.provider)
314+
if backend is None:
315+
msg = f"Provider `{auth.provider}` not found"
316+
log.error(msg)
317+
return {"refreshed": False, "reauthentication_required": False}
287318
if success is False:
288319
msg = f"An error occurred when refreshing user token on `{auth.provider}` identity provider: {message}"
289320
log.error(msg)
290-
return False
321+
return {"refreshed": False, "reauthentication_required": False}
291322
refreshed = backend.refresh(trans, auth)
292323
if refreshed:
293324
log.debug(f"Refreshed user token via `{auth.provider}` identity provider")
294-
return True
295-
except Exception:
296-
log.exception("An error occurred when refreshing user token")
297-
return False
325+
return {"refreshed": refreshed, "reauthentication_required": False}
326+
except (AuthTokenError, AuthCanceled, AuthForbidden):
327+
log.warning("Authentication session has expired or is invalid, reauth required.")
328+
return {"refreshed": False, "reauthentication_required": True}
329+
except Exception as e:
330+
log.warning(f"An error occurred when refreshing user token: {e}")
331+
return {"refreshed": False, "reauthentication_required": False}
332+
333+
def refresh_expiring_oidc_tokens(
334+
self, trans: GalaxyWebTransaction, user: Optional[model.User] = None
335+
) -> str | None:
336+
"""
337+
Refresh expiring OIDC tokens for all providers associated with a user.
298338
299-
def refresh_expiring_oidc_tokens(self, trans, user=None):
339+
Returns:
340+
str | None: The provider name if refresh fails and require_session_refresh is enabled, otherwise None
341+
"""
300342
user = trans.user or user
301343
if not isinstance(user, model.User):
302-
return
344+
return None
303345
for auth in user.social_auth or []:
304-
self.refresh_expiring_oidc_tokens_for_provider(trans, auth)
346+
result = self.refresh_expiring_oidc_tokens_for_provider(trans, auth)
347+
if auth.provider is None:
348+
continue
349+
provider = self._unify_provider_name(auth.provider)
350+
if provider is None:
351+
continue
352+
config = self.oidc_backends_config.get(provider, None)
353+
if config is None:
354+
continue
355+
# Redirect to OIDC login if refresh fails and require_session_refresh is enabled
356+
if config.get("require_session_refresh") and result["reauthentication_required"]:
357+
return provider
358+
return None
305359

306-
def authenticate(self, provider, trans, idphint=None):
360+
def authenticate(
361+
self, provider: str, trans: GalaxyWebTransaction, idphint: str | None = None
362+
) -> tuple[bool, str, str | None]:
307363
"""
308364
:type provider: string
309365
:param provider: set the name of the identity provider to be
@@ -314,6 +370,8 @@ def authenticate(self, provider, trans, idphint=None):
314370
"""
315371
try:
316372
success, message, backend = self._get_authnz_backend(provider, idphint=idphint)
373+
if backend is None:
374+
return False, f"Provider `{provider}` not found", None
317375
if success is False:
318376
return False, message, None
319377
# Check allowed IDPs for providers that support idphint (keycloak, cilogon)
@@ -365,9 +423,11 @@ def callback(self, provider, state_token, authz_code, trans, login_redirect_url,
365423
log.exception(msg)
366424
return False, msg, (None, None)
367425

368-
def create_user(self, provider: str, token: str, trans: "ProvidesAppContext", login_redirect_url: str):
426+
def create_user(self, provider: str, token: str, trans: ProvidesAppContext, login_redirect_url: str):
369427
try:
370428
success, message, backend = self._get_authnz_backend(provider)
429+
if backend is None:
430+
raise ValueError(f"Provider `{provider}` not found")
371431
if success is False:
372432
return False, message, (None, None)
373433
return success, message, backend.create_user(token, trans, login_redirect_url)

lib/galaxy/authnz/xsd/oidc_backends_config.xsd

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@
6565
</xs:documentation>
6666
</xs:annotation>
6767
</xs:element>
68+
<xs:element name="require_session_refresh" minOccurs="0" type="xs:boolean">
69+
<xs:annotation>
70+
<xs:documentation>
71+
Require the user to refresh their session (via refresh token)
72+
when the access token expires. Users will be required to reauthenticate
73+
if refreshing fails.
74+
</xs:documentation>
75+
</xs:annotation>
76+
</xs:element>
6877
<xs:element name="ca_bundle" minOccurs="0" type="xs:string">
6978
<xs:annotation>
7079
<xs:documentation>

lib/galaxy/config/sample/oidc_backends_config.xml.sample

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ _______________
2828
2929
- require_create_confirmation: A boolean value that decides whether a NewUserConfirmation page shows up.
3030
31+
- require_session_refresh: A boolean value that decides whether failed token refresh requires the
32+
user to reauthenticate with this provider.
33+
3134
3235
IMPORTANT NOTES
3336
_______________

lib/galaxy/webapps/base/webapp.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,12 @@ def __init__(
358358
self._ensure_valid_session(session_cookie)
359359

360360
if hasattr(self.app, "authnz_manager") and self.app.authnz_manager:
361-
self.app.authnz_manager.refresh_expiring_oidc_tokens(self)
361+
# Check for expiring tokens and refresh them. If configured (at the individual provider
362+
# level), require a reauthentication on failed refresh.
363+
reauth_provider = self.app.authnz_manager.refresh_expiring_oidc_tokens(self)
364+
if reauth_provider:
365+
self.handle_user_reauthentication(reauth_provider)
366+
return
362367

363368
if self.galaxy_session:
364369
# When we've authenticated by session, we have to check the
@@ -893,6 +898,21 @@ def handle_user_logout(self, logout_all=False):
893898
elif self.webapp.name == "tool_shed":
894899
self.__update_session_cookie(name="galaxycommunitysession")
895900

901+
def handle_user_reauthentication(self, reauth_provider: str) -> None:
902+
"""
903+
Handle user being required to log in again after failed OIDC refresh
904+
"""
905+
log.info("OIDC refresh failed terminally for provider `%s`, forcing re-login", reauth_provider)
906+
if self.galaxy_session:
907+
self.handle_user_logout()
908+
if self.environ.get("is_api_request", False):
909+
self.response.status = 401
910+
self.error_message = "Authentication session expired. Please log in again."
911+
self.user = None
912+
self.galaxy_session = None
913+
else:
914+
self.response.send_redirect(url_for(f"/authnz/{reauth_provider}/login", redirect="true", next="/"))
915+
896916
def get_galaxy_session(self):
897917
"""
898918
Return the current galaxy session

0 commit comments

Comments
 (0)