*: migrate ui_* properties to functions to allow context being passed
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
4f05dcec89
commit
e4841d54a1
|
@ -104,14 +104,14 @@ class SourceViewSet(
|
||||||
)
|
)
|
||||||
matching_sources: list[UserSettingSerializer] = []
|
matching_sources: list[UserSettingSerializer] = []
|
||||||
for source in _all_sources:
|
for source in _all_sources:
|
||||||
user_settings = source.ui_user_settings
|
user_settings = source.ui_user_settings()
|
||||||
if not user_settings:
|
if not user_settings:
|
||||||
continue
|
continue
|
||||||
policy_engine = PolicyEngine(source, request.user, request)
|
policy_engine = PolicyEngine(source, request.user, request)
|
||||||
policy_engine.build()
|
policy_engine.build()
|
||||||
if not policy_engine.passing:
|
if not policy_engine.passing:
|
||||||
continue
|
continue
|
||||||
source_settings = source.ui_user_settings
|
source_settings = source.ui_user_settings()
|
||||||
source_settings.initial_data["object_uid"] = source.slug
|
source_settings.initial_data["object_uid"] = source.slug
|
||||||
if not source_settings.is_valid():
|
if not source_settings.is_valid():
|
||||||
LOGGER.warning(source_settings.errors)
|
LOGGER.warning(source_settings.errors)
|
||||||
|
|
|
@ -359,13 +359,11 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
||||||
"""Return component used to edit this object"""
|
"""Return component used to edit this object"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
def ui_login_button(self, request: HttpRequest) -> Optional[UILoginButton]:
|
||||||
def ui_login_button(self) -> Optional[UILoginButton]:
|
|
||||||
"""If source uses a http-based flow, return UI Information about the login
|
"""If source uses a http-based flow, return UI Information about the login
|
||||||
button. If source doesn't use http-based flow, return None."""
|
button. If source doesn't use http-based flow, return None."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
"""Entrypoint to integrate with User settings. Can either return None if no
|
"""Entrypoint to integrate with User settings. Can either return None if no
|
||||||
user settings are available, or UserSettingSerializer."""
|
user settings are available, or UserSettingSerializer."""
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Callable, Type
|
from typing import Callable, Type
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import RequestFactory, TestCase
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from guardian.shortcuts import get_anonymous_user
|
from guardian.shortcuts import get_anonymous_user
|
||||||
|
|
||||||
|
@ -30,6 +30,9 @@ class TestModels(TestCase):
|
||||||
def source_tester_factory(test_model: Type[Stage]) -> Callable:
|
def source_tester_factory(test_model: Type[Stage]) -> Callable:
|
||||||
"""Test source"""
|
"""Test source"""
|
||||||
|
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get("/")
|
||||||
|
|
||||||
def tester(self: TestModels):
|
def tester(self: TestModels):
|
||||||
model_class = None
|
model_class = None
|
||||||
if test_model._meta.abstract:
|
if test_model._meta.abstract:
|
||||||
|
@ -38,8 +41,8 @@ def source_tester_factory(test_model: Type[Stage]) -> Callable:
|
||||||
model_class = test_model()
|
model_class = test_model()
|
||||||
model_class.slug = "test"
|
model_class.slug = "test"
|
||||||
self.assertIsNotNone(model_class.component)
|
self.assertIsNotNone(model_class.component)
|
||||||
_ = model_class.ui_login_button
|
_ = model_class.ui_login_button(request)
|
||||||
_ = model_class.ui_user_settings
|
_ = model_class.ui_user_settings()
|
||||||
|
|
||||||
return tester
|
return tester
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,7 @@ class StageViewSet(
|
||||||
stages += list(configurable_stage.objects.all().order_by("name"))
|
stages += list(configurable_stage.objects.all().order_by("name"))
|
||||||
matching_stages: list[dict] = []
|
matching_stages: list[dict] = []
|
||||||
for stage in stages:
|
for stage in stages:
|
||||||
user_settings = stage.ui_user_settings
|
user_settings = stage.ui_user_settings()
|
||||||
if not user_settings:
|
if not user_settings:
|
||||||
continue
|
continue
|
||||||
user_settings.initial_data["object_uid"] = str(stage.pk)
|
user_settings.initial_data["object_uid"] = str(stage.pk)
|
||||||
|
|
|
@ -75,7 +75,6 @@ class Stage(SerializerModel):
|
||||||
"""Return component used to edit this object"""
|
"""Return component used to edit this object"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
"""Entrypoint to integrate with User settings. Can either return None if no
|
"""Entrypoint to integrate with User settings. Can either return None if no
|
||||||
user settings are available, or a challenge."""
|
user settings are available, or a challenge."""
|
||||||
|
|
|
@ -32,7 +32,7 @@ class TestFlowsAPI(APITestCase):
|
||||||
|
|
||||||
def test_models(self):
|
def test_models(self):
|
||||||
"""Test that ui_user_settings returns none"""
|
"""Test that ui_user_settings returns none"""
|
||||||
self.assertIsNone(Stage().ui_user_settings)
|
self.assertIsNone(Stage().ui_user_settings())
|
||||||
|
|
||||||
def test_api_serializer(self):
|
def test_api_serializer(self):
|
||||||
"""Test that stage serializer returns the correct type"""
|
"""Test that stage serializer returns the correct type"""
|
||||||
|
|
|
@ -23,7 +23,7 @@ def model_tester_factory(test_model: Type[Stage]) -> Callable:
|
||||||
model_class = test_model()
|
model_class = test_model()
|
||||||
self.assertTrue(issubclass(model_class.type, StageView))
|
self.assertTrue(issubclass(model_class.type, StageView))
|
||||||
self.assertIsNotNone(test_model.component)
|
self.assertIsNotNone(test_model.component)
|
||||||
_ = model_class.ui_user_settings
|
_ = model_class.ui_user_settings()
|
||||||
|
|
||||||
return tester
|
return tester
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
from typing import TYPE_CHECKING, Optional, Type
|
from typing import TYPE_CHECKING, Optional, Type
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.http.request import HttpRequest
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
|
@ -63,11 +64,15 @@ class OAuthSource(Source):
|
||||||
|
|
||||||
return OAuthSourceSerializer
|
return OAuthSourceSerializer
|
||||||
|
|
||||||
@property
|
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
|
||||||
def ui_login_button(self) -> UILoginButton:
|
provider_type = self.type
|
||||||
return self.type().ui_login_button()
|
provider = provider_type()
|
||||||
|
return UILoginButton(
|
||||||
|
name=self.name,
|
||||||
|
icon_url=provider.icon_url(),
|
||||||
|
challenge=provider.login_challenge(self, request),
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
return UserSettingSerializer(
|
return UserSettingSerializer(
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -2,12 +2,13 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, Optional, Type
|
from typing import Callable, Optional, Type
|
||||||
|
|
||||||
|
from django.http.request import HttpRequest
|
||||||
from django.templatetags.static import static
|
from django.templatetags.static import static
|
||||||
from django.urls.base import reverse
|
from django.urls.base import reverse
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.types import UILoginButton
|
from authentik.flows.challenge import Challenge, ChallengeTypes, RedirectChallenge
|
||||||
from authentik.flows.challenge import ChallengeTypes, RedirectChallenge
|
from authentik.sources.oauth.models import OAuthSource
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
@ -40,20 +41,17 @@ class SourceType:
|
||||||
"""Get Icon URL for login"""
|
"""Get Icon URL for login"""
|
||||||
return static(f"authentik/sources/{self.slug}.svg")
|
return static(f"authentik/sources/{self.slug}.svg")
|
||||||
|
|
||||||
def ui_login_button(self) -> UILoginButton:
|
# pylint: disable=unused-argument
|
||||||
|
def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
|
||||||
"""Allow types to return custom challenges"""
|
"""Allow types to return custom challenges"""
|
||||||
return UILoginButton(
|
return RedirectChallenge(
|
||||||
challenge=RedirectChallenge(
|
instance={
|
||||||
instance={
|
"type": ChallengeTypes.REDIRECT.value,
|
||||||
"type": ChallengeTypes.REDIRECT.value,
|
"to": reverse(
|
||||||
"to": reverse(
|
"authentik_sources_oauth:oauth-client-login",
|
||||||
"authentik_sources_oauth:oauth-client-login",
|
kwargs={"source_slug": self.slug},
|
||||||
kwargs={"source_slug": self.slug},
|
),
|
||||||
),
|
}
|
||||||
}
|
|
||||||
),
|
|
||||||
icon_url=self.icon_url(),
|
|
||||||
name=self.name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ from typing import Optional
|
||||||
|
|
||||||
from django.contrib.postgres.fields import ArrayField
|
from django.contrib.postgres.fields import ArrayField
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.http.request import HttpRequest
|
||||||
from django.templatetags.static import static
|
from django.templatetags.static import static
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from rest_framework.fields import CharField
|
from rest_framework.fields import CharField
|
||||||
|
@ -62,8 +63,7 @@ class PlexSource(Source):
|
||||||
|
|
||||||
return PlexSourceSerializer
|
return PlexSourceSerializer
|
||||||
|
|
||||||
@property
|
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
|
||||||
def ui_login_button(self) -> UILoginButton:
|
|
||||||
return UILoginButton(
|
return UILoginButton(
|
||||||
challenge=PlexAuthenticationChallenge(
|
challenge=PlexAuthenticationChallenge(
|
||||||
{
|
{
|
||||||
|
@ -77,7 +77,6 @@ class PlexSource(Source):
|
||||||
name=self.name,
|
name=self.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
return UserSettingSerializer(
|
return UserSettingSerializer(
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -167,8 +167,7 @@ class SAMLSource(Source):
|
||||||
reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug})
|
reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug})
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
|
||||||
def ui_login_button(self) -> UILoginButton:
|
|
||||||
return UILoginButton(
|
return UILoginButton(
|
||||||
challenge=RedirectChallenge(
|
challenge=RedirectChallenge(
|
||||||
instance={
|
instance={
|
||||||
|
|
|
@ -48,7 +48,6 @@ class AuthenticatorDuoStage(ConfigurableStage, Stage):
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-stage-authenticator-duo-form"
|
return "ak-stage-authenticator-duo-form"
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
return UserSettingSerializer(
|
return UserSettingSerializer(
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -141,7 +141,6 @@ class AuthenticatorSMSStage(ConfigurableStage, Stage):
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-stage-authenticator-sms-form"
|
return "ak-stage-authenticator-sms-form"
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
return UserSettingSerializer(
|
return UserSettingSerializer(
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -31,7 +31,6 @@ class AuthenticatorStaticStage(ConfigurableStage, Stage):
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-stage-authenticator-static-form"
|
return "ak-stage-authenticator-static-form"
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
return UserSettingSerializer(
|
return UserSettingSerializer(
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -38,7 +38,6 @@ class AuthenticatorTOTPStage(ConfigurableStage, Stage):
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-stage-authenticator-totp-form"
|
return "ak-stage-authenticator-totp-form"
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
return UserSettingSerializer(
|
return UserSettingSerializer(
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -34,7 +34,6 @@ class AuthenticateWebAuthnStage(ConfigurableStage, Stage):
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-stage-authenticator-webauthn-form"
|
return "ak-stage-authenticator-webauthn-form"
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
return UserSettingSerializer(
|
return UserSettingSerializer(
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -191,7 +191,7 @@ class IdentificationStageView(ChallengeStageView):
|
||||||
current_stage.sources.filter(enabled=True).order_by("name").select_subclasses()
|
current_stage.sources.filter(enabled=True).order_by("name").select_subclasses()
|
||||||
)
|
)
|
||||||
for source in sources:
|
for source in sources:
|
||||||
ui_login_button = source.ui_login_button
|
ui_login_button = source.ui_login_button(self.request)
|
||||||
if ui_login_button:
|
if ui_login_button:
|
||||||
button = asdict(ui_login_button)
|
button = asdict(ui_login_button)
|
||||||
button["challenge"] = ui_login_button.challenge.data
|
button["challenge"] = ui_login_button.challenge.data
|
||||||
|
|
|
@ -63,7 +63,6 @@ class PasswordStage(ConfigurableStage, Stage):
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-stage-password-form"
|
return "ak-stage-password-form"
|
||||||
|
|
||||||
@property
|
|
||||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||||
if not self.configure_flow:
|
if not self.configure_flow:
|
||||||
return None
|
return None
|
||||||
|
|
Reference in New Issue