providers/oauth2: redirect back correctly with state on AuthorizationError

This commit is contained in:
Jens Langhammer 2020-12-27 15:22:53 +01:00
parent 55322995a1
commit bcd0686a33
2 changed files with 29 additions and 12 deletions

View File

@ -1,7 +1,9 @@
"""OAuth errors""" """OAuth errors"""
from typing import Optional
from urllib.parse import quote from urllib.parse import quote
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
from authentik.providers.oauth2.models import GrantTypes
class OAuth2Error(SentryIgnoredException): class OAuth2Error(SentryIgnoredException):
@ -98,27 +100,34 @@ class AuthorizeError(OAuth2Error):
"the registration parameter", "the registration parameter",
} }
def __init__(self, redirect_uri, error, grant_type): def __init__(
self,
redirect_uri: str,
error: str,
grant_type: str,
state: Optional[str] = None,
):
super().__init__() super().__init__()
self.error = error self.error = error
self.description = self._errors[error] self.description = self._errors[error]
self.redirect_uri = redirect_uri self.redirect_uri = redirect_uri
self.grant_type = grant_type self.grant_type = grant_type
self.state = state
def create_uri(self, redirect_uri: str, state: str) -> str: def create_uri(self, redirect_uri: str) -> str:
"""Get a redirect URI with the error message""" """Get a redirect URI with the error message"""
description = quote(str(self.description)) description = quote(str(self.description))
# See: # See:
# http://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthError # http://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthError
hash_or_question = "#" if self.grant_type == "implicit" else "?" hash_or_question = "#" if self.grant_type == GrantTypes.IMPLICIT else "?"
uri = "{0}{1}error={2}&error_description={3}".format( uri = "{0}{1}error={2}&error_description={3}".format(
redirect_uri, hash_or_question, self.error, description redirect_uri, hash_or_question, self.error, description
) )
# Add state if present. # Add state if present.
uri = uri + ("&state={0}".format(state) if state else "") uri = uri + ("&state={0}".format(self.state) if self.state else "")
return uri return uri

View File

@ -91,6 +91,8 @@ class OAuthAuthorizationParams:
# Because in this endpoint we handle both GET # Because in this endpoint we handle both GET
# and POST request. # and POST request.
query_dict = request.POST if request.method == "POST" else request.GET query_dict = request.POST if request.method == "POST" else request.GET
state = query_dict.get("state", "")
redirect_uri = query_dict.get("redirect_uri", "")
response_type = query_dict.get("response_type", "") response_type = query_dict.get("response_type", "")
grant_type = None grant_type = None
@ -113,20 +115,21 @@ class OAuthAuthorizationParams:
# Grant type validation. # Grant type validation.
if not grant_type: if not grant_type:
LOGGER.warning("Invalid response type", type=response_type) LOGGER.warning("Invalid response type", type=response_type)
raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state)
if "request" in query_dict:
raise AuthorizeError( raise AuthorizeError(
query_dict.get("redirect_uri", ""), redirect_uri, "request_not_supported", grant_type, state
"unsupported_response_type",
grant_type,
) )
max_age = query_dict.get("max_age") max_age = query_dict.get("max_age")
return OAuthAuthorizationParams( return OAuthAuthorizationParams(
client_id=query_dict.get("client_id", ""), client_id=query_dict.get("client_id", ""),
redirect_uri=query_dict.get("redirect_uri", ""), redirect_uri=redirect_uri,
response_type=response_type, response_type=response_type,
grant_type=grant_type, grant_type=grant_type,
scope=query_dict.get("scope", "").split(), scope=query_dict.get("scope", "").split(),
state=query_dict.get("state", ""), state=state,
nonce=query_dict.get("nonce", ""), nonce=query_dict.get("nonce", ""),
prompt=ALLOWED_PROMPT_PARAMS.intersection( prompt=ALLOWED_PROMPT_PARAMS.intersection(
set(query_dict.get("prompt", "").split()) set(query_dict.get("prompt", "").split())
@ -253,7 +256,7 @@ class OAuthFulfillmentStage(StageView):
return bad_request_message(request, error.description, title=error.error) return bad_request_message(request, error.description, title=error.error)
except AuthorizeError as error: except AuthorizeError as error:
self.executor.stage_invalid() self.executor.stage_invalid()
uri = error.create_uri(self.params.redirect_uri, self.params.state) uri = error.create_uri(self.params.redirect_uri)
return redirect(uri) return redirect(uri)
def create_response_uri(self) -> str: def create_response_uri(self) -> str:
@ -332,7 +335,10 @@ class OAuthFulfillmentStage(StageView):
except OAuth2Error as error: except OAuth2Error as error:
LOGGER.exception("Error when trying to create response uri", error=error) LOGGER.exception("Error when trying to create response uri", error=error)
raise AuthorizeError( raise AuthorizeError(
self.params.redirect_uri, "server_error", self.params.grant_type self.params.redirect_uri,
"server_error",
self.params.grant_type,
self.params.state,
) )
uri = uri._replace( uri = uri._replace(
@ -353,6 +359,8 @@ class AuthorizationFlowInitView(PolicyAccessView):
see https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.3.1.2.6""" see https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.3.1.2.6"""
try: try:
self.params = OAuthAuthorizationParams.from_request(self.request) self.params = OAuthAuthorizationParams.from_request(self.request)
except AuthorizeError as error:
raise RequestValidationError(redirect(error.create_uri(error.redirect_uri)))
except OAuth2Error as error: except OAuth2Error as error:
raise RequestValidationError( raise RequestValidationError(
bad_request_message(self.request, error.description, title=error.error) bad_request_message(self.request, error.description, title=error.error)
@ -365,7 +373,7 @@ class AuthorizationFlowInitView(PolicyAccessView):
self.params.redirect_uri, "login_required", self.params.grant_type self.params.redirect_uri, "login_required", self.params.grant_type
) )
raise RequestValidationError( raise RequestValidationError(
redirect(error.create_uri(self.params.redirect_uri, self.params.state)) redirect(error.create_uri(self.params.redirect_uri))
) )
def resolve_provider_application(self): def resolve_provider_application(self):