sources/oauth: fix resolution of sources' provider type

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-04-16 11:29:23 +02:00
parent 5e67f68f2b
commit 2b48ba4103
2 changed files with 5 additions and 8 deletions

View File

@ -1,6 +1,6 @@
"""Source type manager""" """Source type manager"""
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Callable, Optional from typing import Callable, Optional
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -9,9 +9,6 @@ from authentik.sources.oauth.views.redirect import OAuthRedirect
LOGGER = get_logger() LOGGER = get_logger()
if TYPE_CHECKING:
from authentik.sources.oauth.models import OAuthSource
class RequestKind(Enum): class RequestKind(Enum):
"""Enum of OAuth Request types""" """Enum of OAuth Request types"""
@ -69,13 +66,13 @@ class SourceTypeManager:
LOGGER.warning( LOGGER.warning(
"no matching type found, using default", "no matching type found, using default",
wanted=type_name, wanted=type_name,
have=[x.name for x in self.__sources], have=[x.slug for x in self.__sources],
) )
return found_type return found_type
def find(self, source: "OAuthSource", kind: RequestKind) -> Callable: def find(self, type_name: str, kind: RequestKind) -> Callable:
"""Find fitting Source Type""" """Find fitting Source Type"""
found_type = self.find_type(source) found_type = self.find_type(type_name)
if kind == RequestKind.CALLBACK: if kind == RequestKind.CALLBACK:
return found_type.callback_view return found_type.callback_view
if kind == RequestKind.REDIRECT: if kind == RequestKind.REDIRECT:

View File

@ -21,6 +21,6 @@ class DispatcherView(View):
if not slug: if not slug:
raise Http404 raise Http404
source = get_object_or_404(OAuthSource, slug=slug) source = get_object_or_404(OAuthSource, slug=slug)
view = MANAGER.find(source, kind=RequestKind(self.kind)) view = MANAGER.find(source.provider_type, kind=RequestKind(self.kind))
LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind) LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind)
return view.as_view()(*args, **kwargs) return view.as_view()(*args, **kwargs)