diff --git a/authentik/providers/oauth2/utils.py b/authentik/providers/oauth2/utils.py index c5a000776..88c272bbf 100644 --- a/authentik/providers/oauth2/utils.py +++ b/authentik/providers/oauth2/utils.py @@ -6,7 +6,6 @@ from typing import List, Optional, Tuple from django.http import HttpRequest, HttpResponse, JsonResponse from django.utils.cache import patch_vary_headers -from jwkest.jwt import JWT from structlog import get_logger from authentik.providers.oauth2.errors import BearerTokenError @@ -140,17 +139,3 @@ def protected_resource_view(scopes: List[str]): return view_wrapper return wrapper - - -def client_id_from_id_token(id_token): - """ - Extracts the client id from a JSON Web Token (JWT). - Returns a string or None. - """ - payload = JWT().unpack(id_token).payload() - aud = payload.get("aud", None) - if aud is None: - return None - if isinstance(aud, list): - return aud[0] - return aud diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index ce84253b5..2cc05a2b7 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -62,6 +62,7 @@ ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSNET, PROMPT_LOGIN} @dataclass +# pylint: disable=too-many-instance-attributes class OAuthAuthorizationParams: """Parameteres required to authorize an OAuth Client""" @@ -76,6 +77,8 @@ class OAuthAuthorizationParams: provider: OAuth2Provider = field(default_factory=OAuth2Provider) + request: Optional[str] = None + max_age: Optional[int] = None code_challenge: Optional[str] = None @@ -118,11 +121,6 @@ class OAuthAuthorizationParams: LOGGER.warning("Invalid response type", type=response_type) raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state) - if "request" in query_dict: - raise AuthorizeError( - redirect_uri, "request_not_supported", grant_type, state - ) - max_age = query_dict.get("max_age") return OAuthAuthorizationParams( client_id=query_dict.get("client_id", ""), @@ -135,6 +133,7 @@ class OAuthAuthorizationParams: prompt=ALLOWED_PROMPT_PARAMS.intersection( set(query_dict.get("prompt", "").split()) ), + request=query_dict.get("request", None), max_age=int(max_age) if max_age else None, code_challenge=query_dict.get("code_challenge"), code_challenge_method=query_dict.get("code_challenge_method"), @@ -148,9 +147,14 @@ class OAuthAuthorizationParams: except OAuth2Provider.DoesNotExist: LOGGER.warning("Invalid client identifier", client_id=self.client_id) raise ClientIdError() - is_open_id = SCOPE_OPENID in self.scope + self.check_redirect_uri() + self.check_scope() + self.check_nonce() + self.check_response_type() + self.check_code_challenge() - # Redirect URI validation. + def check_redirect_uri(self): + """Redirect URI validation.""" if not self.redirect_uri: LOGGER.warning("Missing redirect uri.") raise RedirectUriError() @@ -171,7 +175,14 @@ class OAuthAuthorizationParams: ) raise RedirectUriError() - if not is_open_id and ( + if self.request: + raise AuthorizeError( + self.redirect_uri, "request_not_supported", self.grant_type, self.state + ) + + def check_scope(self): + """Ensure openid scope is set in Hybrid flows, or when requesting an id_token""" + if SCOPE_OPENID not in self.scope and ( self.grant_type == GrantTypes.HYBRID or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN] @@ -181,14 +192,20 @@ class OAuthAuthorizationParams: self.redirect_uri, "invalid_scope", self.grant_type, self.state ) - # Nonce parameter validation. - if is_open_id and self.grant_type == GrantTypes.IMPLICIT and not self.nonce: + def check_nonce(self): + """Nonce parameter validation.""" + if ( + SCOPE_OPENID in self.scope + and self.grant_type == GrantTypes.IMPLICIT + and not self.nonce + ): raise AuthorizeError( self.redirect_uri, "invalid_request", self.grant_type, self.state ) - # Response type parameter validation. - if is_open_id: + def check_response_type(self): + """Response type parameter validation.""" + if SCOPE_OPENID in self.scope: actual_response_type = self.provider.response_type if "#" in self.provider.response_type: hash_index = actual_response_type.index("#") @@ -198,7 +215,8 @@ class OAuthAuthorizationParams: self.redirect_uri, "invalid_request", self.grant_type, self.state ) - # PKCE validation of the transformation method. + def check_code_challenge(self): + """PKCE validation of the transformation method.""" if self.code_challenge: if not (self.code_challenge_method in ["plain", "S256"]): raise AuthorizeError(