sources/oauth: allow oauth types to override their login button challenge

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-12-13 23:45:11 +01:00
parent ede6bcd31e
commit 4f05dcec89
2 changed files with 20 additions and 15 deletions

View File

@ -8,7 +8,6 @@ from rest_framework.serializers import Serializer
from authentik.core.models import Source, UserSourceConnection
from authentik.core.types import UILoginButton, UserSettingSerializer
from authentik.flows.challenge import ChallengeTypes, RedirectChallenge
if TYPE_CHECKING:
from authentik.sources.oauth.types.manager import SourceType
@ -66,20 +65,7 @@ class OAuthSource(Source):
@property
def ui_login_button(self) -> UILoginButton:
provider_type = self.type
return UILoginButton(
challenge=RedirectChallenge(
instance={
"type": ChallengeTypes.REDIRECT.value,
"to": reverse(
"authentik_sources_oauth:oauth-client-login",
kwargs={"source_slug": self.slug},
),
}
),
icon_url=provider_type().icon_url(),
name=self.name,
)
return self.type().ui_login_button()
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]:

View File

@ -3,8 +3,11 @@ from enum import Enum
from typing import Callable, Optional, Type
from django.templatetags.static import static
from django.urls.base import reverse
from structlog.stdlib import get_logger
from authentik.core.types import UILoginButton
from authentik.flows.challenge import ChallengeTypes, RedirectChallenge
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -37,6 +40,22 @@ class SourceType:
"""Get Icon URL for login"""
return static(f"authentik/sources/{self.slug}.svg")
def ui_login_button(self) -> UILoginButton:
"""Allow types to return custom challenges"""
return UILoginButton(
challenge=RedirectChallenge(
instance={
"type": ChallengeTypes.REDIRECT.value,
"to": reverse(
"authentik_sources_oauth:oauth-client-login",
kwargs={"source_slug": self.slug},
),
}
),
icon_url=self.icon_url(),
name=self.name,
)
class SourceTypeManager:
"""Manager to hold all Source types."""