providers/oauth2: improve error handling and event creation

This commit is contained in:
Jens Langhammer 2021-01-16 18:27:10 +01:00
parent 394ad6ade5
commit 2d2a404028
2 changed files with 44 additions and 13 deletions

View file

@ -1,6 +1,8 @@
"""OAuth errors""" """OAuth errors"""
from typing import Optional
from urllib.parse import quote from urllib.parse import quote
from authentik.events.models import Event, EventAction
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
from authentik.providers.oauth2.models import GrantTypes from authentik.providers.oauth2.models import GrantTypes
@ -21,6 +23,13 @@ class OAuth2Error(SentryIgnoredException):
def __repr__(self) -> str: def __repr__(self) -> str:
return self.error return self.error
def to_event(self, message: Optional[str] = None) -> Event:
"""Create configuration_error Event and save it."""
return Event.new(
EventAction.CONFIGURATION_ERROR,
message=message or self.description,
)
class RedirectUriError(OAuth2Error): class RedirectUriError(OAuth2Error):
"""The request fails due to a missing, invalid, or mismatching """The request fails due to a missing, invalid, or mismatching
@ -32,6 +41,20 @@ class RedirectUriError(OAuth2Error):
"redirection URI (redirect_uri)." "redirection URI (redirect_uri)."
) )
provided_uri: str
allowed_uris: list[str]
def __init__(self, provided_uri: str, allowed_uris: list[str]) -> None:
super().__init__()
self.provided_uri = provided_uri
self.allowed_uris = allowed_uris
def to_event(self) -> Event:
return super().to_event(
f"Invalid redirect URI was used. Client used '{self.provided_uri}'. "
f"Allowed redirect URIs are {','.join(self.allowed_uris)}"
)
class ClientIdError(OAuth2Error): class ClientIdError(OAuth2Error):
"""The client identifier (client_id) is missing or invalid.""" """The client identifier (client_id) is missing or invalid."""
@ -39,6 +62,15 @@ class ClientIdError(OAuth2Error):
error = "Client ID Error" error = "Client ID Error"
description = "The client identifier (client_id) is missing or invalid." description = "The client identifier (client_id) is missing or invalid."
client_id: str
def __init__(self, client_id: str) -> None:
super().__init__()
self.client_id = client_id
def to_event(self) -> Event:
return super().to_event(f"Invalid client identifier: {self.client_id}.")
class UserAuthError(OAuth2Error): class UserAuthError(OAuth2Error):
""" """

View file

@ -145,7 +145,7 @@ class OAuthAuthorizationParams:
) )
except OAuth2Provider.DoesNotExist: except OAuth2Provider.DoesNotExist:
LOGGER.warning("Invalid client identifier", client_id=self.client_id) LOGGER.warning("Invalid client identifier", client_id=self.client_id)
raise ClientIdError() raise ClientIdError(client_id=self.client_id)
self.check_redirect_uri() self.check_redirect_uri()
self.check_scope() self.check_scope()
self.check_nonce() self.check_nonce()
@ -155,24 +155,18 @@ class OAuthAuthorizationParams:
"""Redirect URI validation.""" """Redirect URI validation."""
if not self.redirect_uri: if not self.redirect_uri:
LOGGER.warning("Missing redirect uri.") LOGGER.warning("Missing redirect uri.")
raise RedirectUriError() raise RedirectUriError("", self.provider.redirect_uris.split())
if self.redirect_uri.lower() not in [ if self.redirect_uri.lower() not in [
x.lower() for x in self.provider.redirect_uris.split() x.lower() for x in self.provider.redirect_uris.split()
]: ]:
Event.new(
EventAction.CONFIGURATION_ERROR,
provider=self.provider,
message="Invalid redirect URI was used.",
client_used=self.redirect_uri,
configured=self.provider.redirect_uris.split(),
).save()
LOGGER.warning( LOGGER.warning(
"Invalid redirect uri", "Invalid redirect uri",
redirect_uri=self.redirect_uri, redirect_uri=self.redirect_uri,
excepted=self.provider.redirect_uris.split(), excepted=self.provider.redirect_uris.split(),
) )
raise RedirectUriError() raise RedirectUriError(
self.redirect_uri, self.provider.redirect_uris.split()
)
if self.request: if self.request:
raise AuthorizeError( raise AuthorizeError(
self.redirect_uri, "request_not_supported", self.grant_type, self.state self.redirect_uri, "request_not_supported", self.grant_type, self.state
@ -262,10 +256,12 @@ class OAuthFulfillmentStage(StageView):
).from_http(self.request) ).from_http(self.request)
return redirect(self.create_response_uri()) return redirect(self.create_response_uri())
except (ClientIdError, RedirectUriError) as error: except (ClientIdError, RedirectUriError) as error:
error.to_event().from_http(request)
self.executor.stage_invalid() self.executor.stage_invalid()
# pylint: disable=no-member # pylint: disable=no-member
return bad_request_message(request, error.description, title=error.error) return bad_request_message(request, error.description, title=error.error)
except AuthorizeError as error: except AuthorizeError as error:
error.to_event().from_http(request)
self.executor.stage_invalid() self.executor.stage_invalid()
return redirect(error.create_uri()) return redirect(error.create_uri())
@ -383,8 +379,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
try: try:
self.params = OAuthAuthorizationParams.from_request(self.request) self.params = OAuthAuthorizationParams.from_request(self.request)
except AuthorizeError as error: except AuthorizeError as error:
error.to_event().from_http(self.request)
raise RequestValidationError(redirect(error.create_uri())) raise RequestValidationError(redirect(error.create_uri()))
except OAuth2Error as error: except OAuth2Error as error:
error.to_event().from_http(self.request)
raise RequestValidationError( raise RequestValidationError(
bad_request_message(self.request, error.description, title=error.error) bad_request_message(self.request, error.description, title=error.error)
) )
@ -398,6 +396,7 @@ class AuthorizationFlowInitView(PolicyAccessView):
self.params.grant_type, self.params.grant_type,
self.params.state, self.params.state,
) )
error.to_event().from_http(self.request)
raise RequestValidationError(redirect(error.create_uri())) raise RequestValidationError(redirect(error.create_uri()))
def resolve_provider_application(self): def resolve_provider_application(self):