sources/oauth: fix handling of sources with spaces in their name
This commit is contained in:
parent
4d45dc31a9
commit
f58ee7fb52
|
@ -21,5 +21,6 @@ class PassbookSourceOAuthConfig(AppConfig):
|
||||||
for source_type in settings.PASSBOOK_SOURCES_OAUTH_TYPES:
|
for source_type in settings.PASSBOOK_SOURCES_OAUTH_TYPES:
|
||||||
try:
|
try:
|
||||||
import_module(source_type)
|
import_module(source_type)
|
||||||
|
LOGGER.debug("Loaded OAuth Source Type", type=source_type)
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
LOGGER.debug(exc)
|
LOGGER.debug(exc)
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
"""Source type manager"""
|
"""Source type manager"""
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Callable, Dict, List
|
||||||
|
|
||||||
from django.utils.text import slugify
|
from django.utils.text import slugify
|
||||||
from structlog import get_logger
|
from structlog import get_logger
|
||||||
|
|
||||||
|
from passbook.sources.oauth.models import OAuthSource
|
||||||
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
|
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
@ -19,18 +21,20 @@ class RequestKind(Enum):
|
||||||
class SourceTypeManager:
|
class SourceTypeManager:
|
||||||
"""Manager to hold all Source types."""
|
"""Manager to hold all Source types."""
|
||||||
|
|
||||||
__source_types = {}
|
__source_types: Dict[RequestKind, Dict[str, Callable]] = {}
|
||||||
__names = []
|
__names: List[str] = []
|
||||||
|
|
||||||
def source(self, kind, name):
|
def source(self, kind: RequestKind, name: str):
|
||||||
"""Class decorator to register classes inline."""
|
"""Class decorator to register classes inline."""
|
||||||
|
|
||||||
def inner_wrapper(cls):
|
def inner_wrapper(cls):
|
||||||
if kind not in self.__source_types:
|
if kind.value not in self.__source_types:
|
||||||
self.__source_types[kind] = {}
|
self.__source_types[kind.value] = {}
|
||||||
self.__source_types[kind][name.lower()] = cls
|
self.__source_types[kind.value][slugify(name)] = cls
|
||||||
self.__names.append(name)
|
self.__names.append(name)
|
||||||
LOGGER.debug("Registered source", source_class=cls.__name__, kind=kind)
|
LOGGER.debug(
|
||||||
|
"Registered source", source_class=cls.__name__, kind=kind.value
|
||||||
|
)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return inner_wrapper
|
return inner_wrapper
|
||||||
|
@ -39,15 +43,16 @@ class SourceTypeManager:
|
||||||
"""Get list of tuples of all registered names"""
|
"""Get list of tuples of all registered names"""
|
||||||
return [(slugify(x), x) for x in set(self.__names)]
|
return [(slugify(x), x) for x in set(self.__names)]
|
||||||
|
|
||||||
def find(self, source, kind):
|
def find(self, source: OAuthSource, kind: RequestKind) -> Callable:
|
||||||
"""Find fitting Source Type"""
|
"""Find fitting Source Type"""
|
||||||
if kind in self.__source_types:
|
if kind.value in self.__source_types:
|
||||||
if source.provider_type in self.__source_types[kind]:
|
if source.provider_type in self.__source_types[kind.value]:
|
||||||
return self.__source_types[kind][source.provider_type]
|
return self.__source_types[kind.value][source.provider_type]
|
||||||
|
LOGGER.warning("no matching type found, using default")
|
||||||
# Return defaults
|
# Return defaults
|
||||||
if kind == RequestKind.callback:
|
if kind.value == RequestKind.callback:
|
||||||
return OAuthCallback
|
return OAuthCallback
|
||||||
if kind == RequestKind.redirect:
|
if kind.value == RequestKind.redirect:
|
||||||
return OAuthRedirect
|
return OAuthRedirect
|
||||||
raise KeyError
|
raise KeyError
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
"""OAuth Client User Creation Utils"""
|
"""OAuth Client User Creation Utils"""
|
||||||
|
|
||||||
from django.db.utils import IntegrityError
|
from django.db.utils import IntegrityError
|
||||||
|
|
||||||
from passbook.core.models import User
|
from passbook.core.models import User
|
||||||
|
|
||||||
|
|
||||||
def user_get_or_create(**kwargs):
|
def user_get_or_create(**kwargs: str) -> User:
|
||||||
"""Create user or return existing user"""
|
"""Create user or return existing user"""
|
||||||
try:
|
try:
|
||||||
new_user = User.objects.create_user(**kwargs)
|
new_user = User.objects.create_user(**kwargs)
|
||||||
|
|
|
@ -2,10 +2,13 @@
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.views import View
|
from django.views import View
|
||||||
|
from structlog import get_logger
|
||||||
|
|
||||||
from passbook.sources.oauth.models import OAuthSource
|
from passbook.sources.oauth.models import OAuthSource
|
||||||
from passbook.sources.oauth.types.manager import MANAGER, RequestKind
|
from passbook.sources.oauth.types.manager import MANAGER, RequestKind
|
||||||
|
|
||||||
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class DispatcherView(View):
|
class DispatcherView(View):
|
||||||
"""Dispatch OAuth Redirect/Callback views to their proper class based on URL parameters"""
|
"""Dispatch OAuth Redirect/Callback views to their proper class based on URL parameters"""
|
||||||
|
@ -19,4 +22,5 @@ class DispatcherView(View):
|
||||||
raise Http404
|
raise Http404
|
||||||
source = get_object_or_404(OAuthSource, slug=slug)
|
source = get_object_or_404(OAuthSource, slug=slug)
|
||||||
view = MANAGER.find(source, kind=RequestKind(self.kind))
|
view = MANAGER.find(source, kind=RequestKind(self.kind))
|
||||||
|
LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind)
|
||||||
return view.as_view()(*args, **kwargs)
|
return view.as_view()(*args, **kwargs)
|
||||||
|
|
Reference in a new issue