From 2b48ba41037d097d538fab5020543725b51e9de1 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Fri, 16 Apr 2021 11:29:23 +0200 Subject: [PATCH] sources/oauth: fix resolution of sources' provider type Signed-off-by: Jens Langhammer --- authentik/sources/oauth/types/manager.py | 11 ++++------- authentik/sources/oauth/views/dispatcher.py | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/authentik/sources/oauth/types/manager.py b/authentik/sources/oauth/types/manager.py index d58cd21c6..8d27a1d27 100644 --- a/authentik/sources/oauth/types/manager.py +++ b/authentik/sources/oauth/types/manager.py @@ -1,6 +1,6 @@ """Source type manager""" from enum import Enum -from typing import TYPE_CHECKING, Callable, Optional +from typing import Callable, Optional from structlog.stdlib import get_logger @@ -9,9 +9,6 @@ from authentik.sources.oauth.views.redirect import OAuthRedirect LOGGER = get_logger() -if TYPE_CHECKING: - from authentik.sources.oauth.models import OAuthSource - class RequestKind(Enum): """Enum of OAuth Request types""" @@ -69,13 +66,13 @@ class SourceTypeManager: LOGGER.warning( "no matching type found, using default", wanted=type_name, - have=[x.name for x in self.__sources], + have=[x.slug for x in self.__sources], ) return found_type - def find(self, source: "OAuthSource", kind: RequestKind) -> Callable: + def find(self, type_name: str, kind: RequestKind) -> Callable: """Find fitting Source Type""" - found_type = self.find_type(source) + found_type = self.find_type(type_name) if kind == RequestKind.CALLBACK: return found_type.callback_view if kind == RequestKind.REDIRECT: diff --git a/authentik/sources/oauth/views/dispatcher.py b/authentik/sources/oauth/views/dispatcher.py index d89dc2fc1..63a0769b5 100644 --- a/authentik/sources/oauth/views/dispatcher.py +++ b/authentik/sources/oauth/views/dispatcher.py @@ -21,6 +21,6 @@ class DispatcherView(View): if not slug: raise Http404 source = get_object_or_404(OAuthSource, slug=slug) - view = MANAGER.find(source, kind=RequestKind(self.kind)) + view = MANAGER.find(source.provider_type, kind=RequestKind(self.kind)) LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind) return view.as_view()(*args, **kwargs)