diff --git a/authentik/providers/oauth2/errors.py b/authentik/providers/oauth2/errors.py index cdc41217a..06b318cc3 100644 --- a/authentik/providers/oauth2/errors.py +++ b/authentik/providers/oauth2/errors.py @@ -1,7 +1,9 @@ """OAuth errors""" +from typing import Optional from urllib.parse import quote from authentik.lib.sentry import SentryIgnoredException +from authentik.providers.oauth2.models import GrantTypes class OAuth2Error(SentryIgnoredException): @@ -98,27 +100,34 @@ class AuthorizeError(OAuth2Error): "the registration parameter", } - def __init__(self, redirect_uri, error, grant_type): + def __init__( + self, + redirect_uri: str, + error: str, + grant_type: str, + state: Optional[str] = None, + ): super().__init__() self.error = error self.description = self._errors[error] self.redirect_uri = redirect_uri self.grant_type = grant_type + self.state = state - def create_uri(self, redirect_uri: str, state: str) -> str: + def create_uri(self, redirect_uri: str) -> str: """Get a redirect URI with the error message""" description = quote(str(self.description)) # See: # http://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthError - hash_or_question = "#" if self.grant_type == "implicit" else "?" + hash_or_question = "#" if self.grant_type == GrantTypes.IMPLICIT else "?" uri = "{0}{1}error={2}&error_description={3}".format( redirect_uri, hash_or_question, self.error, description ) # Add state if present. - uri = uri + ("&state={0}".format(state) if state else "") + uri = uri + ("&state={0}".format(self.state) if self.state else "") return uri diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index 8a2056c82..f59a698d5 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -91,6 +91,8 @@ class OAuthAuthorizationParams: # Because in this endpoint we handle both GET # and POST request. query_dict = request.POST if request.method == "POST" else request.GET + state = query_dict.get("state", "") + redirect_uri = query_dict.get("redirect_uri", "") response_type = query_dict.get("response_type", "") grant_type = None @@ -113,20 +115,21 @@ class OAuthAuthorizationParams: # Grant type validation. if not grant_type: LOGGER.warning("Invalid response type", type=response_type) + raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state) + + if "request" in query_dict: raise AuthorizeError( - query_dict.get("redirect_uri", ""), - "unsupported_response_type", - grant_type, + redirect_uri, "request_not_supported", grant_type, state ) max_age = query_dict.get("max_age") return OAuthAuthorizationParams( client_id=query_dict.get("client_id", ""), - redirect_uri=query_dict.get("redirect_uri", ""), + redirect_uri=redirect_uri, response_type=response_type, grant_type=grant_type, scope=query_dict.get("scope", "").split(), - state=query_dict.get("state", ""), + state=state, nonce=query_dict.get("nonce", ""), prompt=ALLOWED_PROMPT_PARAMS.intersection( set(query_dict.get("prompt", "").split()) @@ -253,7 +256,7 @@ class OAuthFulfillmentStage(StageView): return bad_request_message(request, error.description, title=error.error) except AuthorizeError as error: self.executor.stage_invalid() - uri = error.create_uri(self.params.redirect_uri, self.params.state) + uri = error.create_uri(self.params.redirect_uri) return redirect(uri) def create_response_uri(self) -> str: @@ -332,7 +335,10 @@ class OAuthFulfillmentStage(StageView): except OAuth2Error as error: LOGGER.exception("Error when trying to create response uri", error=error) raise AuthorizeError( - self.params.redirect_uri, "server_error", self.params.grant_type + self.params.redirect_uri, + "server_error", + self.params.grant_type, + self.params.state, ) uri = uri._replace( @@ -353,6 +359,8 @@ class AuthorizationFlowInitView(PolicyAccessView): see https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.3.1.2.6""" try: self.params = OAuthAuthorizationParams.from_request(self.request) + except AuthorizeError as error: + raise RequestValidationError(redirect(error.create_uri(error.redirect_uri))) except OAuth2Error as error: raise RequestValidationError( bad_request_message(self.request, error.description, title=error.error) @@ -365,7 +373,7 @@ class AuthorizationFlowInitView(PolicyAccessView): self.params.redirect_uri, "login_required", self.params.grant_type ) raise RequestValidationError( - redirect(error.create_uri(self.params.redirect_uri, self.params.state)) + redirect(error.create_uri(self.params.redirect_uri)) ) def resolve_provider_application(self):