*: migrate ui_* properties to functions to allow context being passed

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-12-13 23:56:01 +01:00
parent 4f05dcec89
commit e4841d54a1
18 changed files with 38 additions and 43 deletions

View File

@ -104,14 +104,14 @@ class SourceViewSet(
) )
matching_sources: list[UserSettingSerializer] = [] matching_sources: list[UserSettingSerializer] = []
for source in _all_sources: for source in _all_sources:
user_settings = source.ui_user_settings user_settings = source.ui_user_settings()
if not user_settings: if not user_settings:
continue continue
policy_engine = PolicyEngine(source, request.user, request) policy_engine = PolicyEngine(source, request.user, request)
policy_engine.build() policy_engine.build()
if not policy_engine.passing: if not policy_engine.passing:
continue continue
source_settings = source.ui_user_settings source_settings = source.ui_user_settings()
source_settings.initial_data["object_uid"] = source.slug source_settings.initial_data["object_uid"] = source.slug
if not source_settings.is_valid(): if not source_settings.is_valid():
LOGGER.warning(source_settings.errors) LOGGER.warning(source_settings.errors)

View File

@ -359,13 +359,11 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"""Return component used to edit this object""" """Return component used to edit this object"""
raise NotImplementedError raise NotImplementedError
@property def ui_login_button(self, request: HttpRequest) -> Optional[UILoginButton]:
def ui_login_button(self) -> Optional[UILoginButton]:
"""If source uses a http-based flow, return UI Information about the login """If source uses a http-based flow, return UI Information about the login
button. If source doesn't use http-based flow, return None.""" button. If source doesn't use http-based flow, return None."""
return None return None
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
"""Entrypoint to integrate with User settings. Can either return None if no """Entrypoint to integrate with User settings. Can either return None if no
user settings are available, or UserSettingSerializer.""" user settings are available, or UserSettingSerializer."""

View File

@ -2,7 +2,7 @@
from time import sleep from time import sleep
from typing import Callable, Type from typing import Callable, Type
from django.test import TestCase from django.test import RequestFactory, TestCase
from django.utils.timezone import now from django.utils.timezone import now
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
@ -30,6 +30,9 @@ class TestModels(TestCase):
def source_tester_factory(test_model: Type[Stage]) -> Callable: def source_tester_factory(test_model: Type[Stage]) -> Callable:
"""Test source""" """Test source"""
factory = RequestFactory()
request = factory.get("/")
def tester(self: TestModels): def tester(self: TestModels):
model_class = None model_class = None
if test_model._meta.abstract: if test_model._meta.abstract:
@ -38,8 +41,8 @@ def source_tester_factory(test_model: Type[Stage]) -> Callable:
model_class = test_model() model_class = test_model()
model_class.slug = "test" model_class.slug = "test"
self.assertIsNotNone(model_class.component) self.assertIsNotNone(model_class.component)
_ = model_class.ui_login_button _ = model_class.ui_login_button(request)
_ = model_class.ui_user_settings _ = model_class.ui_user_settings()
return tester return tester

View File

@ -90,7 +90,7 @@ class StageViewSet(
stages += list(configurable_stage.objects.all().order_by("name")) stages += list(configurable_stage.objects.all().order_by("name"))
matching_stages: list[dict] = [] matching_stages: list[dict] = []
for stage in stages: for stage in stages:
user_settings = stage.ui_user_settings user_settings = stage.ui_user_settings()
if not user_settings: if not user_settings:
continue continue
user_settings.initial_data["object_uid"] = str(stage.pk) user_settings.initial_data["object_uid"] = str(stage.pk)

View File

@ -75,7 +75,6 @@ class Stage(SerializerModel):
"""Return component used to edit this object""" """Return component used to edit this object"""
raise NotImplementedError raise NotImplementedError
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
"""Entrypoint to integrate with User settings. Can either return None if no """Entrypoint to integrate with User settings. Can either return None if no
user settings are available, or a challenge.""" user settings are available, or a challenge."""

View File

@ -32,7 +32,7 @@ class TestFlowsAPI(APITestCase):
def test_models(self): def test_models(self):
"""Test that ui_user_settings returns none""" """Test that ui_user_settings returns none"""
self.assertIsNone(Stage().ui_user_settings) self.assertIsNone(Stage().ui_user_settings())
def test_api_serializer(self): def test_api_serializer(self):
"""Test that stage serializer returns the correct type""" """Test that stage serializer returns the correct type"""

View File

@ -23,7 +23,7 @@ def model_tester_factory(test_model: Type[Stage]) -> Callable:
model_class = test_model() model_class = test_model()
self.assertTrue(issubclass(model_class.type, StageView)) self.assertTrue(issubclass(model_class.type, StageView))
self.assertIsNotNone(test_model.component) self.assertIsNotNone(test_model.component)
_ = model_class.ui_user_settings _ = model_class.ui_user_settings()
return tester return tester

View File

@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Optional, Type from typing import TYPE_CHECKING, Optional, Type
from django.db import models from django.db import models
from django.http.request import HttpRequest
from django.urls import reverse from django.urls import reverse
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
@ -63,11 +64,15 @@ class OAuthSource(Source):
return OAuthSourceSerializer return OAuthSourceSerializer
@property def ui_login_button(self, request: HttpRequest) -> UILoginButton:
def ui_login_button(self) -> UILoginButton: provider_type = self.type
return self.type().ui_login_button() provider = provider_type()
return UILoginButton(
name=self.name,
icon_url=provider.icon_url(),
challenge=provider.login_challenge(self, request),
)
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
return UserSettingSerializer( return UserSettingSerializer(
data={ data={

View File

@ -2,12 +2,13 @@
from enum import Enum from enum import Enum
from typing import Callable, Optional, Type from typing import Callable, Optional, Type
from django.http.request import HttpRequest
from django.templatetags.static import static from django.templatetags.static import static
from django.urls.base import reverse from django.urls.base import reverse
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.types import UILoginButton from authentik.flows.challenge import Challenge, ChallengeTypes, RedirectChallenge
from authentik.flows.challenge import ChallengeTypes, RedirectChallenge from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -40,10 +41,10 @@ class SourceType:
"""Get Icon URL for login""" """Get Icon URL for login"""
return static(f"authentik/sources/{self.slug}.svg") return static(f"authentik/sources/{self.slug}.svg")
def ui_login_button(self) -> UILoginButton: # pylint: disable=unused-argument
def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
"""Allow types to return custom challenges""" """Allow types to return custom challenges"""
return UILoginButton( return RedirectChallenge(
challenge=RedirectChallenge(
instance={ instance={
"type": ChallengeTypes.REDIRECT.value, "type": ChallengeTypes.REDIRECT.value,
"to": reverse( "to": reverse(
@ -51,9 +52,6 @@ class SourceType:
kwargs={"source_slug": self.slug}, kwargs={"source_slug": self.slug},
), ),
} }
),
icon_url=self.icon_url(),
name=self.name,
) )

View File

@ -3,6 +3,7 @@ from typing import Optional
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
from django.db import models from django.db import models
from django.http.request import HttpRequest
from django.templatetags.static import static from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.fields import CharField from rest_framework.fields import CharField
@ -62,8 +63,7 @@ class PlexSource(Source):
return PlexSourceSerializer return PlexSourceSerializer
@property def ui_login_button(self, request: HttpRequest) -> UILoginButton:
def ui_login_button(self) -> UILoginButton:
return UILoginButton( return UILoginButton(
challenge=PlexAuthenticationChallenge( challenge=PlexAuthenticationChallenge(
{ {
@ -77,7 +77,6 @@ class PlexSource(Source):
name=self.name, name=self.name,
) )
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
return UserSettingSerializer( return UserSettingSerializer(
data={ data={

View File

@ -167,8 +167,7 @@ class SAMLSource(Source):
reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug}) reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug})
) )
@property def ui_login_button(self, request: HttpRequest) -> UILoginButton:
def ui_login_button(self) -> UILoginButton:
return UILoginButton( return UILoginButton(
challenge=RedirectChallenge( challenge=RedirectChallenge(
instance={ instance={

View File

@ -48,7 +48,6 @@ class AuthenticatorDuoStage(ConfigurableStage, Stage):
def component(self) -> str: def component(self) -> str:
return "ak-stage-authenticator-duo-form" return "ak-stage-authenticator-duo-form"
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
return UserSettingSerializer( return UserSettingSerializer(
data={ data={

View File

@ -141,7 +141,6 @@ class AuthenticatorSMSStage(ConfigurableStage, Stage):
def component(self) -> str: def component(self) -> str:
return "ak-stage-authenticator-sms-form" return "ak-stage-authenticator-sms-form"
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
return UserSettingSerializer( return UserSettingSerializer(
data={ data={

View File

@ -31,7 +31,6 @@ class AuthenticatorStaticStage(ConfigurableStage, Stage):
def component(self) -> str: def component(self) -> str:
return "ak-stage-authenticator-static-form" return "ak-stage-authenticator-static-form"
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
return UserSettingSerializer( return UserSettingSerializer(
data={ data={

View File

@ -38,7 +38,6 @@ class AuthenticatorTOTPStage(ConfigurableStage, Stage):
def component(self) -> str: def component(self) -> str:
return "ak-stage-authenticator-totp-form" return "ak-stage-authenticator-totp-form"
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
return UserSettingSerializer( return UserSettingSerializer(
data={ data={

View File

@ -34,7 +34,6 @@ class AuthenticateWebAuthnStage(ConfigurableStage, Stage):
def component(self) -> str: def component(self) -> str:
return "ak-stage-authenticator-webauthn-form" return "ak-stage-authenticator-webauthn-form"
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
return UserSettingSerializer( return UserSettingSerializer(
data={ data={

View File

@ -191,7 +191,7 @@ class IdentificationStageView(ChallengeStageView):
current_stage.sources.filter(enabled=True).order_by("name").select_subclasses() current_stage.sources.filter(enabled=True).order_by("name").select_subclasses()
) )
for source in sources: for source in sources:
ui_login_button = source.ui_login_button ui_login_button = source.ui_login_button(self.request)
if ui_login_button: if ui_login_button:
button = asdict(ui_login_button) button = asdict(ui_login_button)
button["challenge"] = ui_login_button.challenge.data button["challenge"] = ui_login_button.challenge.data

View File

@ -63,7 +63,6 @@ class PasswordStage(ConfigurableStage, Stage):
def component(self) -> str: def component(self) -> str:
return "ak-stage-password-form" return "ak-stage-password-form"
@property
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
if not self.configure_flow: if not self.configure_flow:
return None return None