sources/oauth: revamp types system, move default URLs to type
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
83fc22005c
commit
1daba5db87
|
@ -2,10 +2,11 @@
|
|||
from django.urls.base import reverse_lazy
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField, SerializerMethodField
|
||||
from rest_framework.fields import BooleanField, CharField, SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from drf_yasg.utils import swagger_serializer_method
|
||||
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
|
@ -13,6 +14,18 @@ from authentik.sources.oauth.models import OAuthSource
|
|||
from authentik.sources.oauth.types.manager import MANAGER
|
||||
|
||||
|
||||
class SourceTypeSerializer(PassiveSerializer):
|
||||
"""Serializer for SourceType"""
|
||||
|
||||
name = CharField(required=True)
|
||||
slug = CharField(required=True)
|
||||
urls_customizable = BooleanField()
|
||||
request_token_url = CharField(read_only=True, allow_null=True)
|
||||
authorization_url = CharField(read_only=True, allow_null=True)
|
||||
access_token_url = CharField(read_only=True, allow_null=True)
|
||||
profile_url = CharField(read_only=True, allow_null=True)
|
||||
|
||||
|
||||
class OAuthSourceSerializer(SourceSerializer):
|
||||
"""OAuth Source Serializer"""
|
||||
|
||||
|
@ -28,6 +41,13 @@ class OAuthSourceSerializer(SourceSerializer):
|
|||
return relative_url
|
||||
return self.context["request"].build_absolute_uri(relative_url)
|
||||
|
||||
type = SerializerMethodField()
|
||||
|
||||
@swagger_serializer_method(serializer_or_field=SourceTypeSerializer)
|
||||
def get_type(self, instace: OAuthSource) -> SourceTypeSerializer:
|
||||
"""Get source's type configuration"""
|
||||
return SourceTypeSerializer(instace.type).data
|
||||
|
||||
class Meta:
|
||||
model = OAuthSource
|
||||
fields = SourceSerializer.Meta.fields + [
|
||||
|
@ -39,17 +59,11 @@ class OAuthSourceSerializer(SourceSerializer):
|
|||
"consumer_key",
|
||||
"consumer_secret",
|
||||
"callback_url",
|
||||
"type",
|
||||
]
|
||||
extra_kwargs = {"consumer_secret": {"write_only": True}}
|
||||
|
||||
|
||||
class OAuthSourceProviderType(PassiveSerializer):
|
||||
"""OAuth Provider"""
|
||||
|
||||
name = CharField(required=True)
|
||||
value = CharField(required=True)
|
||||
|
||||
|
||||
class OAuthSourceViewSet(ModelViewSet):
|
||||
"""Source Viewset"""
|
||||
|
||||
|
@ -57,16 +71,11 @@ class OAuthSourceViewSet(ModelViewSet):
|
|||
serializer_class = OAuthSourceSerializer
|
||||
lookup_field = "slug"
|
||||
|
||||
@swagger_auto_schema(responses={200: OAuthSourceProviderType(many=True)})
|
||||
@swagger_auto_schema(responses={200: SourceTypeSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def provider_types(self, request: Request) -> Response:
|
||||
def source_types(self, request: Request) -> Response:
|
||||
"""Get all creatable source types"""
|
||||
data = []
|
||||
for key, value in MANAGER.get_name_tuple():
|
||||
data.append(
|
||||
{
|
||||
"name": value,
|
||||
"value": key,
|
||||
}
|
||||
)
|
||||
return Response(OAuthSourceProviderType(data, many=True).data)
|
||||
for source_type in MANAGER.get():
|
||||
data.append(SourceTypeSerializer(source_type).data)
|
||||
return Response(data)
|
||||
|
|
|
@ -1,138 +0,0 @@
|
|||
"""authentik oauth_client forms"""
|
||||
|
||||
from django import forms
|
||||
|
||||
from authentik.flows.models import Flow, FlowDesignation
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.manager import MANAGER
|
||||
|
||||
|
||||
class OAuthSourceForm(forms.ModelForm):
|
||||
"""OAuthSource Form"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["authentication_flow"].queryset = Flow.objects.filter(
|
||||
designation=FlowDesignation.AUTHENTICATION
|
||||
)
|
||||
self.fields["authentication_flow"].required = True
|
||||
self.fields["enrollment_flow"].queryset = Flow.objects.filter(
|
||||
designation=FlowDesignation.ENROLLMENT
|
||||
)
|
||||
self.fields["enrollment_flow"].required = True
|
||||
if hasattr(self.Meta, "overrides"):
|
||||
for overide_field, overide_value in getattr(self.Meta, "overrides").items():
|
||||
self.fields[overide_field].initial = overide_value
|
||||
self.fields[overide_field].widget.attrs["readonly"] = "readonly"
|
||||
|
||||
class Meta:
|
||||
|
||||
model = OAuthSource
|
||||
fields = [
|
||||
"name",
|
||||
"slug",
|
||||
"enabled",
|
||||
"policy_engine_mode",
|
||||
"authentication_flow",
|
||||
"enrollment_flow",
|
||||
"provider_type",
|
||||
"request_token_url",
|
||||
"authorization_url",
|
||||
"access_token_url",
|
||||
"profile_url",
|
||||
"consumer_key",
|
||||
"consumer_secret",
|
||||
]
|
||||
widgets = {
|
||||
"name": forms.TextInput(),
|
||||
"consumer_key": forms.TextInput(),
|
||||
"consumer_secret": forms.TextInput(),
|
||||
"provider_type": forms.Select(choices=MANAGER.get_name_tuple()),
|
||||
}
|
||||
|
||||
|
||||
class GitHubOAuthSourceForm(OAuthSourceForm):
|
||||
"""OAuth Source form with pre-determined URL for GitHub"""
|
||||
|
||||
class Meta(OAuthSourceForm.Meta):
|
||||
|
||||
overrides = {
|
||||
"provider_type": "github",
|
||||
"request_token_url": "",
|
||||
"authorization_url": "https://github.com/login/oauth/authorize",
|
||||
"access_token_url": "https://github.com/login/oauth/access_token",
|
||||
"profile_url": "https://api.github.com/user",
|
||||
}
|
||||
|
||||
|
||||
class TwitterOAuthSourceForm(OAuthSourceForm):
|
||||
"""OAuth Source form with pre-determined URL for Twitter"""
|
||||
|
||||
class Meta(OAuthSourceForm.Meta):
|
||||
|
||||
overrides = {
|
||||
"provider_type": "twitter",
|
||||
"request_token_url": "https://api.twitter.com/oauth/request_token",
|
||||
"authorization_url": "https://api.twitter.com/oauth/authenticate",
|
||||
"access_token_url": "https://api.twitter.com/oauth/access_token",
|
||||
"profile_url": (
|
||||
"https://api.twitter.com/1.1/account/"
|
||||
"verify_credentials.json?include_email=true"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class FacebookOAuthSourceForm(OAuthSourceForm):
|
||||
"""OAuth Source form with pre-determined URL for Facebook"""
|
||||
|
||||
class Meta(OAuthSourceForm.Meta):
|
||||
|
||||
overrides = {
|
||||
"provider_type": "facebook",
|
||||
"request_token_url": "",
|
||||
"authorization_url": "https://www.facebook.com/v7.0/dialog/oauth",
|
||||
"access_token_url": "https://graph.facebook.com/v7.0/oauth/access_token",
|
||||
"profile_url": "https://graph.facebook.com/v7.0/me?fields=id,name,email",
|
||||
}
|
||||
|
||||
|
||||
class DiscordOAuthSourceForm(OAuthSourceForm):
|
||||
"""OAuth Source form with pre-determined URL for Discord"""
|
||||
|
||||
class Meta(OAuthSourceForm.Meta):
|
||||
|
||||
overrides = {
|
||||
"provider_type": "discord",
|
||||
"request_token_url": "",
|
||||
"authorization_url": "https://discord.com/api/oauth2/authorize",
|
||||
"access_token_url": "https://discord.com/api/oauth2/token",
|
||||
"profile_url": "https://discord.com/api/users/@me",
|
||||
}
|
||||
|
||||
|
||||
class GoogleOAuthSourceForm(OAuthSourceForm):
|
||||
"""OAuth Source form with pre-determined URL for Google"""
|
||||
|
||||
class Meta(OAuthSourceForm.Meta):
|
||||
|
||||
overrides = {
|
||||
"provider_type": "google",
|
||||
"request_token_url": "",
|
||||
"authorization_url": "https://accounts.google.com/o/oauth2/auth",
|
||||
"access_token_url": "https://accounts.google.com/o/oauth2/token",
|
||||
"profile_url": "https://www.googleapis.com/oauth2/v1/userinfo",
|
||||
}
|
||||
|
||||
|
||||
class AzureADOAuthSourceForm(OAuthSourceForm):
|
||||
"""OAuth Source form with pre-determined URL for AzureAD"""
|
||||
|
||||
class Meta(OAuthSourceForm.Meta):
|
||||
|
||||
overrides = {
|
||||
"provider_type": "azure-ad",
|
||||
"request_token_url": "",
|
||||
"authorization_url": "https://login.microsoftonline.com/common/oauth2/authorize",
|
||||
"access_token_url": "https://login.microsoftonline.com/common/oauth2/token",
|
||||
"profile_url": "https://graph.windows.net/myorganization/me?api-version=1.6",
|
||||
}
|
|
@ -1,8 +1,7 @@
|
|||
"""OAuth Client models"""
|
||||
from typing import Optional, Type
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from django.db import models
|
||||
from django.forms import ModelForm
|
||||
from django.templatetags.static import static
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
@ -11,6 +10,9 @@ from rest_framework.serializers import Serializer
|
|||
from authentik.core.models import Source, UserSourceConnection
|
||||
from authentik.core.types import UILoginButton, UserSettingSerializer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.sources.oauth.types.manager import SourceType
|
||||
|
||||
|
||||
class OAuthSource(Source):
|
||||
"""Login using a Generic OAuth provider."""
|
||||
|
@ -43,10 +45,15 @@ class OAuthSource(Source):
|
|||
consumer_secret = models.TextField()
|
||||
|
||||
@property
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from authentik.sources.oauth.forms import OAuthSourceForm
|
||||
def type(self) -> "SourceType":
|
||||
"""Return the provider instance for this source"""
|
||||
from authentik.sources.oauth.types.manager import MANAGER
|
||||
|
||||
return OAuthSourceForm
|
||||
return MANAGER.find_type(self)
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-source-oauth-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> Type[Serializer]:
|
||||
|
@ -86,12 +93,6 @@ class OAuthSource(Source):
|
|||
class GitHubOAuthSource(OAuthSource):
|
||||
"""Social Login using GitHub.com or a GitHub-Enterprise Instance."""
|
||||
|
||||
@property
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from authentik.sources.oauth.forms import GitHubOAuthSourceForm
|
||||
|
||||
return GitHubOAuthSourceForm
|
||||
|
||||
class Meta:
|
||||
|
||||
abstract = True
|
||||
|
@ -102,12 +103,6 @@ class GitHubOAuthSource(OAuthSource):
|
|||
class TwitterOAuthSource(OAuthSource):
|
||||
"""Social Login using Twitter.com"""
|
||||
|
||||
@property
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from authentik.sources.oauth.forms import TwitterOAuthSourceForm
|
||||
|
||||
return TwitterOAuthSourceForm
|
||||
|
||||
class Meta:
|
||||
|
||||
abstract = True
|
||||
|
@ -118,12 +113,6 @@ class TwitterOAuthSource(OAuthSource):
|
|||
class FacebookOAuthSource(OAuthSource):
|
||||
"""Social Login using Facebook.com."""
|
||||
|
||||
@property
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from authentik.sources.oauth.forms import FacebookOAuthSourceForm
|
||||
|
||||
return FacebookOAuthSourceForm
|
||||
|
||||
class Meta:
|
||||
|
||||
abstract = True
|
||||
|
@ -134,12 +123,6 @@ class FacebookOAuthSource(OAuthSource):
|
|||
class DiscordOAuthSource(OAuthSource):
|
||||
"""Social Login using Discord."""
|
||||
|
||||
@property
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from authentik.sources.oauth.forms import DiscordOAuthSourceForm
|
||||
|
||||
return DiscordOAuthSourceForm
|
||||
|
||||
class Meta:
|
||||
|
||||
abstract = True
|
||||
|
@ -150,12 +133,6 @@ class DiscordOAuthSource(OAuthSource):
|
|||
class GoogleOAuthSource(OAuthSource):
|
||||
"""Social Login using Google or Gsuite."""
|
||||
|
||||
@property
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from authentik.sources.oauth.forms import GoogleOAuthSourceForm
|
||||
|
||||
return GoogleOAuthSourceForm
|
||||
|
||||
class Meta:
|
||||
|
||||
abstract = True
|
||||
|
@ -166,12 +143,6 @@ class GoogleOAuthSource(OAuthSource):
|
|||
class AzureADOAuthSource(OAuthSource):
|
||||
"""Social Login using Azure AD."""
|
||||
|
||||
@property
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from authentik.sources.oauth.forms import AzureADOAuthSourceForm
|
||||
|
||||
return AzureADOAuthSourceForm
|
||||
|
||||
class Meta:
|
||||
|
||||
abstract = True
|
||||
|
@ -182,12 +153,6 @@ class AzureADOAuthSource(OAuthSource):
|
|||
class OpenIDOAuthSource(OAuthSource):
|
||||
"""Login using a Generic OpenID-Connect compliant provider."""
|
||||
|
||||
@property
|
||||
def form(self) -> Type[ModelForm]:
|
||||
from authentik.sources.oauth.forms import OAuthSourceForm
|
||||
|
||||
return OAuthSourceForm
|
||||
|
||||
class Meta:
|
||||
|
||||
abstract = True
|
||||
|
|
|
@ -3,11 +3,10 @@ from typing import Any
|
|||
from uuid import UUID
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Azure AD")
|
||||
class AzureADOAuthCallback(OAuthCallback):
|
||||
"""AzureAD OAuth2 Callback"""
|
||||
|
||||
|
@ -26,3 +25,18 @@ class AzureADOAuthCallback(OAuthCallback):
|
|||
"email": mail,
|
||||
"name": info.get("displayName"),
|
||||
}
|
||||
|
||||
|
||||
@MANAGER.type()
|
||||
class AzureADType(SourceType):
|
||||
"""Azure AD Type definition"""
|
||||
|
||||
callback_view = AzureADOAuthCallback
|
||||
name = "Azure AD"
|
||||
slug = "azure-ad"
|
||||
|
||||
urls_customizable = True
|
||||
|
||||
authorization_url = "https://login.microsoftonline.com/common/oauth2/authorize"
|
||||
access_token_url = "https://login.microsoftonline.com/common/oauth2/token" # nosec
|
||||
profile_url = "https://graph.windows.net/myorganization/me?api-version=1.6"
|
||||
|
|
|
@ -2,12 +2,11 @@
|
|||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="Discord")
|
||||
class DiscordOAuthRedirect(OAuthRedirect):
|
||||
"""Discord OAuth2 Redirect"""
|
||||
|
||||
|
@ -17,7 +16,6 @@ class DiscordOAuthRedirect(OAuthRedirect):
|
|||
}
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Discord")
|
||||
class DiscordOAuth2Callback(OAuthCallback):
|
||||
"""Discord OAuth2 Callback"""
|
||||
|
||||
|
@ -32,3 +30,17 @@ class DiscordOAuth2Callback(OAuthCallback):
|
|||
"email": info.get("email", None),
|
||||
"name": info.get("username"),
|
||||
}
|
||||
|
||||
|
||||
@MANAGER.type()
|
||||
class DiscordType(SourceType):
|
||||
"""Discord Type definition"""
|
||||
|
||||
callback_view = DiscordOAuth2Callback
|
||||
redirect_view = DiscordOAuthRedirect
|
||||
name = "Discord"
|
||||
slug = "discord"
|
||||
|
||||
authorization_url = "https://discord.com/api/oauth2/authorize"
|
||||
access_token_url = "https://discord.com/api/oauth2/token" # nosec
|
||||
profile_url = "https://discord.com/api/users/@me"
|
||||
|
|
|
@ -5,12 +5,11 @@ from facebook import GraphAPI
|
|||
|
||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="Facebook")
|
||||
class FacebookOAuthRedirect(OAuthRedirect):
|
||||
"""Facebook OAuth2 Redirect"""
|
||||
|
||||
|
@ -28,7 +27,6 @@ class FacebookOAuth2Client(OAuth2Client):
|
|||
return api.get_object("me", fields="id,name,email")
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Facebook")
|
||||
class FacebookOAuth2Callback(OAuthCallback):
|
||||
"""Facebook OAuth2 Callback"""
|
||||
|
||||
|
@ -45,3 +43,17 @@ class FacebookOAuth2Callback(OAuthCallback):
|
|||
"email": info.get("email"),
|
||||
"name": info.get("name"),
|
||||
}
|
||||
|
||||
|
||||
@MANAGER.type()
|
||||
class FacebookType(SourceType):
|
||||
"""Facebook Type definition"""
|
||||
|
||||
callback_view = FacebookOAuth2Callback
|
||||
redirect_view = FacebookOAuthRedirect
|
||||
name = "Facebook"
|
||||
slug = "facebook"
|
||||
|
||||
authorization_url = "https://www.facebook.com/v7.0/dialog/oauth"
|
||||
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec
|
||||
profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email"
|
||||
|
|
|
@ -2,11 +2,10 @@
|
|||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="GitHub")
|
||||
class GitHubOAuth2Callback(OAuthCallback):
|
||||
"""GitHub OAuth2 Callback"""
|
||||
|
||||
|
@ -21,3 +20,18 @@ class GitHubOAuth2Callback(OAuthCallback):
|
|||
"email": info.get("email"),
|
||||
"name": info.get("name"),
|
||||
}
|
||||
|
||||
|
||||
@MANAGER.type()
|
||||
class GitHubType(SourceType):
|
||||
"""GitHub Type definition"""
|
||||
|
||||
callback_view = GitHubOAuth2Callback
|
||||
name = "GitHub"
|
||||
slug = "github"
|
||||
|
||||
urls_customizable = True
|
||||
|
||||
authorization_url = "https://github.com/login/oauth/authorize"
|
||||
access_token_url = "https://github.com/login/oauth/access_token" # nosec
|
||||
profile_url = "https://api.github.com/user"
|
||||
|
|
|
@ -2,12 +2,11 @@
|
|||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="Google")
|
||||
class GoogleOAuthRedirect(OAuthRedirect):
|
||||
"""Google OAuth2 Redirect"""
|
||||
|
||||
|
@ -17,7 +16,6 @@ class GoogleOAuthRedirect(OAuthRedirect):
|
|||
}
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Google")
|
||||
class GoogleOAuth2Callback(OAuthCallback):
|
||||
"""Google OAuth2 Callback"""
|
||||
|
||||
|
@ -32,3 +30,17 @@ class GoogleOAuth2Callback(OAuthCallback):
|
|||
"email": info.get("email"),
|
||||
"name": info.get("name"),
|
||||
}
|
||||
|
||||
|
||||
@MANAGER.type()
|
||||
class GoogleType(SourceType):
|
||||
"""Google Type definition"""
|
||||
|
||||
callback_view = GoogleOAuth2Callback
|
||||
redirect_view = GoogleOAuthRedirect
|
||||
name = "Google"
|
||||
slug = "google"
|
||||
|
||||
authorization_url = "https://accounts.google.com/o/oauth2/auth"
|
||||
access_token_url = "https://accounts.google.com/o/oauth2/token" # nosec
|
||||
profile_url = "https://www.googleapis.com/oauth2/v1/userinfo"
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
"""Source type manager"""
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
from django.utils.text import slugify
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
|
||||
|
||||
class RequestKind(Enum):
|
||||
"""Enum of OAuth Request types"""
|
||||
|
@ -19,46 +20,67 @@ class RequestKind(Enum):
|
|||
REDIRECT = "redirect"
|
||||
|
||||
|
||||
class SourceType:
|
||||
"""Source type, allows overriding of urls and views per type"""
|
||||
|
||||
callback_view = OAuthCallback
|
||||
redirect_view = OAuthRedirect
|
||||
name: str
|
||||
slug: str
|
||||
|
||||
urls_customizable = False
|
||||
|
||||
request_token_url: Optional[str] = None
|
||||
authorization_url: Optional[str] = None
|
||||
access_token_url: Optional[str] = None
|
||||
profile_url: Optional[str] = None
|
||||
|
||||
|
||||
class SourceTypeManager:
|
||||
"""Manager to hold all Source types."""
|
||||
|
||||
__source_types: dict[RequestKind, dict[str, Callable]] = {}
|
||||
__names: list[str] = []
|
||||
__sources: list[SourceType] = []
|
||||
|
||||
def source(self, kind: RequestKind, name: str):
|
||||
def type(self):
|
||||
"""Class decorator to register classes inline."""
|
||||
|
||||
def inner_wrapper(cls):
|
||||
if kind.value not in self.__source_types:
|
||||
self.__source_types[kind.value] = {}
|
||||
self.__source_types[kind.value][slugify(name)] = cls
|
||||
self.__names.append(name)
|
||||
self.__sources.append(cls)
|
||||
return cls
|
||||
|
||||
return inner_wrapper
|
||||
|
||||
def get(self):
|
||||
"""Get a list of all source types"""
|
||||
return self.__sources
|
||||
|
||||
def get_name_tuple(self):
|
||||
"""Get list of tuples of all registered names"""
|
||||
return [(slugify(x), x) for x in set(self.__names)]
|
||||
return [(x.slug, x.name) for x in self.__sources]
|
||||
|
||||
def find(self, source: OAuthSource, kind: RequestKind) -> Callable:
|
||||
"""Find fitting Source Type"""
|
||||
if kind.value in self.__source_types:
|
||||
if source.provider_type in self.__source_types[kind.value]:
|
||||
return self.__source_types[kind.value][source.provider_type]
|
||||
def find_type(self, source: "OAuthSource") -> SourceType:
|
||||
"""Find type based on source"""
|
||||
found_type = None
|
||||
for src_type in self.__sources:
|
||||
if src_type.slug == source.provider_type:
|
||||
return src_type
|
||||
if not found_type:
|
||||
found_type = SourceType()
|
||||
LOGGER.warning(
|
||||
"no matching type found, using default",
|
||||
wanted=source.provider_type,
|
||||
have=self.__source_types[kind.value].keys(),
|
||||
have=[x.name for x in self.__sources],
|
||||
)
|
||||
# Return defaults
|
||||
if kind == RequestKind.CALLBACK:
|
||||
return OAuthCallback
|
||||
if kind == RequestKind.REDIRECT:
|
||||
return OAuthRedirect
|
||||
raise KeyError(
|
||||
f"Provider Type {source.provider_type} (type {kind.value}) not found."
|
||||
)
|
||||
return found_type
|
||||
|
||||
def find(self, source: "OAuthSource", kind: RequestKind) -> Callable:
|
||||
"""Find fitting Source Type"""
|
||||
found_type = self.find_type(source)
|
||||
if kind == RequestKind.CALLBACK:
|
||||
return found_type.callback_view
|
||||
if kind == RequestKind.REDIRECT:
|
||||
return found_type.redirect_view
|
||||
raise ValueError
|
||||
|
||||
|
||||
MANAGER = SourceTypeManager()
|
||||
|
|
|
@ -2,12 +2,11 @@
|
|||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="OpenID Connect")
|
||||
class OpenIDConnectOAuthRedirect(OAuthRedirect):
|
||||
"""OpenIDConnect OAuth2 Redirect"""
|
||||
|
||||
|
@ -17,7 +16,6 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
|
|||
}
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="OpenID Connect")
|
||||
class OpenIDConnectOAuth2Callback(OAuthCallback):
|
||||
"""OpenIDConnect OAuth2 Callback"""
|
||||
|
||||
|
@ -35,3 +33,15 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
|
|||
"email": info.get("email"),
|
||||
"name": info.get("name"),
|
||||
}
|
||||
|
||||
|
||||
@MANAGER.type()
|
||||
class OpenIDConnectType(SourceType):
|
||||
"""OpenIDConnect Type definition"""
|
||||
|
||||
callback_view = OpenIDConnectOAuth2Callback
|
||||
redirect_view = OpenIDConnectOAuthRedirect
|
||||
name = "OpenID Connect"
|
||||
slug = "openid-connect"
|
||||
|
||||
urls_customizable = True
|
||||
|
|
|
@ -5,12 +5,11 @@ from requests.auth import HTTPBasicAuth
|
|||
|
||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="reddit")
|
||||
class RedditOAuthRedirect(OAuthRedirect):
|
||||
"""Reddit OAuth2 Redirect"""
|
||||
|
||||
|
@ -30,7 +29,6 @@ class RedditOAuth2Client(OAuth2Client):
|
|||
return super().get_access_token(auth=auth)
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="reddit")
|
||||
class RedditOAuth2Callback(OAuthCallback):
|
||||
"""Reddit OAuth2 Callback"""
|
||||
|
||||
|
@ -48,3 +46,17 @@ class RedditOAuth2Callback(OAuthCallback):
|
|||
"name": info.get("name"),
|
||||
"password": None,
|
||||
}
|
||||
|
||||
|
||||
@MANAGER.type()
|
||||
class RedditType(SourceType):
|
||||
"""Reddit Type definition"""
|
||||
|
||||
callback_view = RedditOAuth2Callback
|
||||
redirect_view = RedditOAuthRedirect
|
||||
name = "reddit"
|
||||
slug = "reddit"
|
||||
|
||||
authorization_url = "https://accounts.google.com/o/oauth2/auth"
|
||||
access_token_url = "https://accounts.google.com/o/oauth2/token" # nosec
|
||||
profile_url = "https://www.googleapis.com/oauth2/v1/userinfo"
|
||||
|
|
|
@ -2,11 +2,10 @@
|
|||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
|
||||
|
||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Twitter")
|
||||
class TwitterOAuthCallback(OAuthCallback):
|
||||
"""Twitter OAuth2 Callback"""
|
||||
|
||||
|
@ -21,3 +20,20 @@ class TwitterOAuthCallback(OAuthCallback):
|
|||
"email": info.get("email", None),
|
||||
"name": info.get("name"),
|
||||
}
|
||||
|
||||
|
||||
@MANAGER.type()
|
||||
class TwitterType(SourceType):
|
||||
"""Twitter Type definition"""
|
||||
|
||||
callback_view = TwitterOAuthCallback
|
||||
name = "Twitter"
|
||||
slug = "twitter"
|
||||
|
||||
request_token_url = "https://api.twitter.com/oauth/request_token" # nosec
|
||||
authorization_url = "https://api.twitter.com/oauth/authenticate"
|
||||
access_token_url = "https://api.twitter.com/oauth/access_token" # nosec
|
||||
profile_url = (
|
||||
"https://api.twitter.com/1.1/account/"
|
||||
"verify_credentials.json?include_email=true"
|
||||
)
|
||||
|
|
65
swagger.yaml
65
swagger.yaml
|
@ -9642,9 +9642,9 @@ paths:
|
|||
tags:
|
||||
- sources
|
||||
parameters: []
|
||||
/sources/oauth/provider_types/:
|
||||
/sources/oauth/source_types/:
|
||||
get:
|
||||
operationId: sources_oauth_provider_types
|
||||
operationId: sources_oauth_source_types
|
||||
description: Get all creatable source types
|
||||
parameters: []
|
||||
responses:
|
||||
|
@ -9653,7 +9653,7 @@ paths:
|
|||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/definitions/OAuthSourceProviderType'
|
||||
$ref: '#/definitions/SourceType'
|
||||
'403':
|
||||
description: Authentication credentials were invalid, absent or insufficient.
|
||||
schema:
|
||||
|
@ -16907,6 +16907,49 @@ definitions:
|
|||
type: string
|
||||
format: date-time
|
||||
readOnly: true
|
||||
SourceType:
|
||||
description: Get source's type configuration
|
||||
required:
|
||||
- name
|
||||
- slug
|
||||
- urls_customizable
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
title: Name
|
||||
type: string
|
||||
minLength: 1
|
||||
slug:
|
||||
title: Slug
|
||||
type: string
|
||||
minLength: 1
|
||||
urls_customizable:
|
||||
title: Urls customizable
|
||||
type: boolean
|
||||
request_token_url:
|
||||
title: Request token url
|
||||
type: string
|
||||
readOnly: true
|
||||
minLength: 1
|
||||
x-nullable: true
|
||||
authorization_url:
|
||||
title: Authorization url
|
||||
type: string
|
||||
readOnly: true
|
||||
minLength: 1
|
||||
x-nullable: true
|
||||
access_token_url:
|
||||
title: Access token url
|
||||
type: string
|
||||
readOnly: true
|
||||
minLength: 1
|
||||
x-nullable: true
|
||||
profile_url:
|
||||
title: Profile url
|
||||
type: string
|
||||
readOnly: true
|
||||
minLength: 1
|
||||
x-nullable: true
|
||||
OAuthSource:
|
||||
required:
|
||||
- name
|
||||
|
@ -17011,20 +17054,8 @@ definitions:
|
|||
title: Callback url
|
||||
type: string
|
||||
readOnly: true
|
||||
OAuthSourceProviderType:
|
||||
required:
|
||||
- name
|
||||
- value
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
title: Name
|
||||
type: string
|
||||
minLength: 1
|
||||
value:
|
||||
title: Value
|
||||
type: string
|
||||
minLength: 1
|
||||
type:
|
||||
$ref: '#/definitions/SourceType'
|
||||
UserOAuthSourceConnection:
|
||||
required:
|
||||
- user
|
||||
|
|
Reference in a new issue