stages/prompt: migrate to SPA
This commit is contained in:
parent
d35f524865
commit
27cd10e072
|
@ -28,9 +28,3 @@ class TestOverviewViews(TestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.client.get(reverse("authentik_core:shell")).status_code, 200
|
self.client.get(reverse("authentik_core:shell")).status_code, 200
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_overview(self):
|
|
||||||
"""Test overview"""
|
|
||||||
self.assertEqual(
|
|
||||||
self.client.get(reverse("authentik_core:overview")).status_code, 200
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,20 +1,7 @@
|
||||||
"""Prompt forms"""
|
"""Prompt forms"""
|
||||||
from email.policy import Policy
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Any, Callable, Iterator
|
|
||||||
|
|
||||||
from django import forms
|
from django import forms
|
||||||
from django.db.models.query import QuerySet
|
|
||||||
from django.http import HttpRequest
|
|
||||||
from django.utils.translation import gettext_lazy as _
|
|
||||||
from guardian.shortcuts import get_anonymous_user
|
|
||||||
|
|
||||||
from authentik.core.models import User
|
from authentik.stages.prompt.models import Prompt, PromptStage
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
|
||||||
from authentik.policies.engine import PolicyEngine
|
|
||||||
from authentik.policies.models import PolicyBinding, PolicyBindingModel
|
|
||||||
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
|
|
||||||
from authentik.stages.prompt.signals import password_validate
|
|
||||||
|
|
||||||
|
|
||||||
class PromptStageForm(forms.ModelForm):
|
class PromptStageForm(forms.ModelForm):
|
||||||
|
@ -47,111 +34,3 @@ class PromptAdminForm(forms.ModelForm):
|
||||||
"label": forms.TextInput(),
|
"label": forms.TextInput(),
|
||||||
"placeholder": forms.TextInput(),
|
"placeholder": forms.TextInput(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ListPolicyEngine(PolicyEngine):
|
|
||||||
"""Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
|
|
||||||
|
|
||||||
__list: list[Policy]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, policies: list[Policy], user: User, request: HttpRequest = None
|
|
||||||
) -> None:
|
|
||||||
super().__init__(PolicyBindingModel(), user, request)
|
|
||||||
self.__list = policies
|
|
||||||
self.use_cache = False
|
|
||||||
|
|
||||||
def _iter_bindings(self) -> Iterator[PolicyBinding]:
|
|
||||||
for policy in self.__list:
|
|
||||||
yield PolicyBinding(
|
|
||||||
policy=policy,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PromptForm(forms.Form):
|
|
||||||
"""Dynamically created form based on PromptStage"""
|
|
||||||
|
|
||||||
stage: PromptStage
|
|
||||||
plan: FlowPlan
|
|
||||||
|
|
||||||
def __init__(self, stage: PromptStage, plan: FlowPlan, *args, **kwargs):
|
|
||||||
self.stage = stage
|
|
||||||
self.plan = plan
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
# list() is called so we only load the fields once
|
|
||||||
fields = list(self.stage.fields.all())
|
|
||||||
for field in fields:
|
|
||||||
field: Prompt
|
|
||||||
self.fields[field.field_key] = field.field
|
|
||||||
# Special handling for fields with username type
|
|
||||||
# these check for existing users with the same username
|
|
||||||
if field.type == FieldTypes.USERNAME:
|
|
||||||
setattr(
|
|
||||||
self,
|
|
||||||
f"clean_{field.field_key}",
|
|
||||||
MethodType(username_field_cleaner_factory(field), self),
|
|
||||||
)
|
|
||||||
# Check if we have a password field, add a handler that sends a signal
|
|
||||||
# to validate it
|
|
||||||
if field.type == FieldTypes.PASSWORD:
|
|
||||||
setattr(
|
|
||||||
self,
|
|
||||||
f"clean_{field.field_key}",
|
|
||||||
MethodType(password_single_cleaner_factory(field), self),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.field_order = sorted(fields, key=lambda x: x.order)
|
|
||||||
|
|
||||||
def _clean_password_fields(self, *field_names):
|
|
||||||
"""Check if the value of all password fields match by merging them into a set
|
|
||||||
and checking the length"""
|
|
||||||
all_passwords = {self.cleaned_data[x] for x in field_names}
|
|
||||||
if len(all_passwords) > 1:
|
|
||||||
raise forms.ValidationError(_("Passwords don't match."))
|
|
||||||
|
|
||||||
def clean(self):
|
|
||||||
cleaned_data = super().clean()
|
|
||||||
if cleaned_data == {}:
|
|
||||||
return {}
|
|
||||||
# Check if we have two password fields, and make sure they are the same
|
|
||||||
password_fields: QuerySet[Prompt] = self.stage.fields.filter(
|
|
||||||
type=FieldTypes.PASSWORD
|
|
||||||
)
|
|
||||||
if password_fields.exists() and password_fields.count() == 2:
|
|
||||||
self._clean_password_fields(*[field.field_key for field in password_fields])
|
|
||||||
|
|
||||||
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
|
|
||||||
engine = ListPolicyEngine(self.stage.validation_policies.all(), user)
|
|
||||||
engine.request.context = cleaned_data
|
|
||||||
engine.build()
|
|
||||||
result = engine.result
|
|
||||||
if not result.passing:
|
|
||||||
raise forms.ValidationError(list(result.messages))
|
|
||||||
return cleaned_data
|
|
||||||
|
|
||||||
|
|
||||||
def username_field_cleaner_factory(field: Prompt) -> Callable:
|
|
||||||
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
|
|
||||||
|
|
||||||
def username_field_cleaner(self: PromptForm) -> Any:
|
|
||||||
"""Check for duplicate usernames"""
|
|
||||||
username = self.cleaned_data.get(field.field_key)
|
|
||||||
if User.objects.filter(username=username).exists():
|
|
||||||
raise forms.ValidationError("Username is already taken.")
|
|
||||||
return username
|
|
||||||
|
|
||||||
return username_field_cleaner
|
|
||||||
|
|
||||||
|
|
||||||
def password_single_cleaner_factory(field: Prompt) -> Callable[[PromptForm], Any]:
|
|
||||||
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
|
|
||||||
|
|
||||||
def password_single_clean(self: PromptForm) -> Any:
|
|
||||||
"""Send password validation signals for e.g. LDAP Source"""
|
|
||||||
password = self.cleaned_data[field.field_key]
|
|
||||||
password_validate.send(
|
|
||||||
sender=self, password=password, plan_context=self.plan.context
|
|
||||||
)
|
|
||||||
return password
|
|
||||||
|
|
||||||
return password_single_clean
|
|
||||||
|
|
|
@ -2,17 +2,23 @@
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from django import forms
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.forms import ModelForm
|
from django.forms import ModelForm
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django.views import View
|
from django.views import View
|
||||||
|
from rest_framework.fields import (
|
||||||
|
BooleanField,
|
||||||
|
CharField,
|
||||||
|
DateField,
|
||||||
|
DateTimeField,
|
||||||
|
EmailField,
|
||||||
|
IntegerField,
|
||||||
|
)
|
||||||
from rest_framework.serializers import BaseSerializer
|
from rest_framework.serializers import BaseSerializer
|
||||||
|
|
||||||
from authentik.flows.models import Stage
|
from authentik.flows.models import Stage
|
||||||
from authentik.lib.models import SerializerModel
|
from authentik.lib.models import SerializerModel
|
||||||
from authentik.policies.models import Policy
|
from authentik.policies.models import Policy
|
||||||
from authentik.stages.prompt.widgets import HorizontalRuleWidget, StaticTextWidget
|
|
||||||
|
|
||||||
|
|
||||||
class FieldTypes(models.TextChoices):
|
class FieldTypes(models.TextChoices):
|
||||||
|
@ -43,8 +49,8 @@ class FieldTypes(models.TextChoices):
|
||||||
)
|
)
|
||||||
NUMBER = "number"
|
NUMBER = "number"
|
||||||
CHECKBOX = "checkbox"
|
CHECKBOX = "checkbox"
|
||||||
DATE = "data"
|
DATE = "date"
|
||||||
DATE_TIME = "data-time"
|
DATE_TIME = "date-time"
|
||||||
|
|
||||||
SEPARATOR = "separator", _("Separator: Static Separator Line")
|
SEPARATOR = "separator", _("Separator: Static Separator Line")
|
||||||
HIDDEN = "hidden", _("Hidden: Hidden field, can be used to insert data into form.")
|
HIDDEN = "hidden", _("Hidden: Hidden field, can be used to insert data into form.")
|
||||||
|
@ -73,49 +79,34 @@ class Prompt(SerializerModel):
|
||||||
return PromptSerializer
|
return PromptSerializer
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def field(self):
|
def field(self) -> CharField:
|
||||||
"""Return instantiated form input field"""
|
"""Get field type for Challenge and response"""
|
||||||
attrs = {"placeholder": _(self.placeholder)}
|
field_class = CharField
|
||||||
field_class = forms.CharField
|
|
||||||
widget = forms.TextInput(attrs=attrs)
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"label": _(self.label),
|
|
||||||
"required": self.required,
|
"required": self.required,
|
||||||
}
|
}
|
||||||
if self.type == FieldTypes.EMAIL:
|
if self.type == FieldTypes.EMAIL:
|
||||||
field_class = forms.EmailField
|
field_class = EmailField
|
||||||
if self.type == FieldTypes.USERNAME:
|
|
||||||
attrs["autocomplete"] = "username"
|
|
||||||
if self.type == FieldTypes.PASSWORD:
|
|
||||||
widget = forms.PasswordInput(attrs=attrs)
|
|
||||||
attrs["autocomplete"] = "new-password"
|
|
||||||
if self.type == FieldTypes.NUMBER:
|
if self.type == FieldTypes.NUMBER:
|
||||||
field_class = forms.IntegerField
|
field_class = IntegerField
|
||||||
widget = forms.NumberInput(attrs=attrs)
|
# TODO: Hidden?
|
||||||
if self.type == FieldTypes.HIDDEN:
|
if self.type == FieldTypes.HIDDEN:
|
||||||
widget = forms.HiddenInput(attrs=attrs)
|
|
||||||
kwargs["required"] = False
|
kwargs["required"] = False
|
||||||
kwargs["initial"] = self.placeholder
|
kwargs["initial"] = self.placeholder
|
||||||
if self.type == FieldTypes.CHECKBOX:
|
if self.type == FieldTypes.CHECKBOX:
|
||||||
field_class = forms.BooleanField
|
field_class = BooleanField
|
||||||
kwargs["required"] = False
|
kwargs["required"] = False
|
||||||
if self.type == FieldTypes.DATE:
|
if self.type == FieldTypes.DATE:
|
||||||
attrs["type"] = "date"
|
field_class = DateField
|
||||||
widget = forms.DateInput(attrs=attrs)
|
|
||||||
if self.type == FieldTypes.DATE_TIME:
|
if self.type == FieldTypes.DATE_TIME:
|
||||||
attrs["type"] = "datetime-local"
|
field_class = DateTimeField
|
||||||
widget = forms.DateTimeInput(attrs=attrs)
|
|
||||||
if self.type == FieldTypes.STATIC:
|
if self.type == FieldTypes.STATIC:
|
||||||
widget = StaticTextWidget(attrs=attrs)
|
|
||||||
kwargs["initial"] = self.placeholder
|
kwargs["initial"] = self.placeholder
|
||||||
kwargs["required"] = False
|
kwargs["required"] = False
|
||||||
kwargs["label"] = ""
|
kwargs["label"] = ""
|
||||||
if self.type == FieldTypes.SEPARATOR:
|
if self.type == FieldTypes.SEPARATOR:
|
||||||
widget = HorizontalRuleWidget(attrs=attrs)
|
|
||||||
kwargs["required"] = False
|
kwargs["required"] = False
|
||||||
kwargs["label"] = ""
|
kwargs["label"] = ""
|
||||||
|
|
||||||
kwargs["widget"] = widget
|
|
||||||
return field_class(**kwargs)
|
return field_class(**kwargs)
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
|
|
|
@ -1,36 +1,189 @@
|
||||||
"""Prompt Stage Logic"""
|
"""Prompt Stage Logic"""
|
||||||
from django.http import HttpResponse
|
from email.policy import Policy
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Any, Callable, Iterator
|
||||||
|
|
||||||
|
from django.db.models.base import Model
|
||||||
|
from django.db.models.query import QuerySet
|
||||||
|
from django.http import HttpRequest, HttpResponse
|
||||||
|
from django.http.request import QueryDict
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django.views.generic import FormView
|
from guardian.shortcuts import get_anonymous_user
|
||||||
|
from rest_framework.fields import BooleanField, CharField, IntegerField
|
||||||
|
from rest_framework.serializers import Serializer, ValidationError
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.flows.stage import StageView
|
from authentik.core.models import User
|
||||||
from authentik.stages.prompt.forms import PromptForm
|
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
|
||||||
|
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||||
|
from authentik.flows.stage import ChallengeStageView
|
||||||
|
from authentik.policies.engine import PolicyEngine
|
||||||
|
from authentik.policies.models import PolicyBinding, PolicyBindingModel
|
||||||
|
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
|
||||||
|
from authentik.stages.prompt.signals import password_validate
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
PLAN_CONTEXT_PROMPT = "prompt_data"
|
PLAN_CONTEXT_PROMPT = "prompt_data"
|
||||||
|
|
||||||
|
|
||||||
class PromptStageView(FormView, StageView):
|
class PromptSerializer(Serializer):
|
||||||
|
"""Serializer for a single Prompt field"""
|
||||||
|
|
||||||
|
field_key = CharField()
|
||||||
|
label = CharField()
|
||||||
|
type = CharField()
|
||||||
|
required = BooleanField()
|
||||||
|
placeholder = CharField()
|
||||||
|
order = IntegerField()
|
||||||
|
|
||||||
|
def create(self, validated_data: dict) -> Model:
|
||||||
|
return Model()
|
||||||
|
|
||||||
|
def update(self, instance: Model, validated_data: dict) -> Model:
|
||||||
|
return Model()
|
||||||
|
|
||||||
|
|
||||||
|
class PromptChallenge(Challenge):
|
||||||
|
"""Initial challenge being sent, define fields"""
|
||||||
|
|
||||||
|
fields = PromptSerializer(many=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptResponseChallenge(ChallengeResponse):
|
||||||
|
"""Validate response, fields are dynamically created based
|
||||||
|
on the stage"""
|
||||||
|
|
||||||
|
def __init__(self, *args, stage: PromptStage, plan: FlowPlan, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.stage = stage
|
||||||
|
self.plan = plan
|
||||||
|
# list() is called so we only load the fields once
|
||||||
|
fields = list(self.stage.fields.all())
|
||||||
|
for field in fields:
|
||||||
|
field: Prompt
|
||||||
|
self.fields[field.field_key] = field.field
|
||||||
|
# Special handling for fields with username type
|
||||||
|
# these check for existing users with the same username
|
||||||
|
if field.type == FieldTypes.USERNAME:
|
||||||
|
setattr(
|
||||||
|
self,
|
||||||
|
f"validate_{field.field_key}",
|
||||||
|
MethodType(username_field_validator_factory(), self),
|
||||||
|
)
|
||||||
|
# Check if we have a password field, add a handler that sends a signal
|
||||||
|
# to validate it
|
||||||
|
if field.type == FieldTypes.PASSWORD:
|
||||||
|
setattr(
|
||||||
|
self,
|
||||||
|
f"validate_{field.field_key}",
|
||||||
|
MethodType(password_single_validator_factory(), self),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.field_order = sorted(fields, key=lambda x: x.order)
|
||||||
|
|
||||||
|
def _validate_password_fields(self, *field_names):
|
||||||
|
"""Check if the value of all password fields match by merging them into a set
|
||||||
|
and checking the length"""
|
||||||
|
all_passwords = {self.initial_data[x] for x in field_names}
|
||||||
|
if len(all_passwords) > 1:
|
||||||
|
raise ValidationError(_("Passwords don't match."))
|
||||||
|
|
||||||
|
def validate(self, attrs):
|
||||||
|
if attrs == {}:
|
||||||
|
return {}
|
||||||
|
# Check if we have two password fields, and make sure they are the same
|
||||||
|
password_fields: QuerySet[Prompt] = self.stage.fields.filter(
|
||||||
|
type=FieldTypes.PASSWORD
|
||||||
|
)
|
||||||
|
if password_fields.exists() and password_fields.count() == 2:
|
||||||
|
self._validate_password_fields(
|
||||||
|
*[field.field_key for field in password_fields]
|
||||||
|
)
|
||||||
|
|
||||||
|
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
|
||||||
|
engine = ListPolicyEngine(self.stage.validation_policies.all(), user)
|
||||||
|
engine.request.context = attrs
|
||||||
|
engine.build()
|
||||||
|
result = engine.result
|
||||||
|
if not result.passing:
|
||||||
|
raise ValidationError(list(result.messages))
|
||||||
|
return attrs
|
||||||
|
|
||||||
|
|
||||||
|
def username_field_validator_factory() -> Callable[[PromptChallenge, str], Any]:
|
||||||
|
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def username_field_validator(self: PromptChallenge, value: str) -> Any:
|
||||||
|
"""Check for duplicate usernames"""
|
||||||
|
if User.objects.filter(username=value).exists():
|
||||||
|
raise ValidationError("Username is already taken.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
return username_field_validator
|
||||||
|
|
||||||
|
|
||||||
|
def password_single_validator_factory() -> Callable[[PromptChallenge, str], Any]:
|
||||||
|
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
|
||||||
|
|
||||||
|
def password_single_clean(self: PromptChallenge, value: str) -> Any:
|
||||||
|
"""Send password validation signals for e.g. LDAP Source"""
|
||||||
|
password_validate.send(
|
||||||
|
sender=self, password=value, plan_context=self.plan.context
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
return password_single_clean
|
||||||
|
|
||||||
|
|
||||||
|
class ListPolicyEngine(PolicyEngine):
|
||||||
|
"""Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
|
||||||
|
|
||||||
|
__list: list[Policy]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, policies: list[Policy], user: User, request: HttpRequest = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__(PolicyBindingModel(), user, request)
|
||||||
|
self.__list = policies
|
||||||
|
self.use_cache = False
|
||||||
|
|
||||||
|
def _iter_bindings(self) -> Iterator[PolicyBinding]:
|
||||||
|
for policy in self.__list:
|
||||||
|
yield PolicyBinding(
|
||||||
|
policy=policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptStageView(ChallengeStageView):
|
||||||
"""Prompt Stage, save form data in plan context."""
|
"""Prompt Stage, save form data in plan context."""
|
||||||
|
|
||||||
template_name = "login/form.html"
|
response_class = PromptResponseChallenge
|
||||||
form_class = PromptForm
|
|
||||||
|
|
||||||
def get_context_data(self, **kwargs):
|
def get_challenge(self, *args, **kwargs) -> Challenge:
|
||||||
ctx = super().get_context_data(**kwargs)
|
fields = list(self.executor.current_stage.fields.all())
|
||||||
ctx["title"] = _(self.executor.current_stage.name)
|
challenge = PromptChallenge(
|
||||||
return ctx
|
data={
|
||||||
|
"type": ChallengeTypes.native,
|
||||||
|
"component": "ak-stage-prompt",
|
||||||
|
"fields": [PromptSerializer(field).data for field in fields],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return challenge
|
||||||
|
|
||||||
def get_form_kwargs(self):
|
def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
|
||||||
kwargs = super().get_form_kwargs()
|
if not self.executor.plan:
|
||||||
kwargs["stage"] = self.executor.current_stage
|
raise ValueError
|
||||||
kwargs["plan"] = self.executor.plan
|
return PromptResponseChallenge(
|
||||||
return kwargs
|
instance=None,
|
||||||
|
data=data,
|
||||||
|
stage=self.executor.current_stage,
|
||||||
|
plan=self.executor.plan,
|
||||||
|
)
|
||||||
|
|
||||||
def form_valid(self, form: PromptForm) -> HttpResponse:
|
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
|
||||||
"""Form data is valid"""
|
|
||||||
if PLAN_CONTEXT_PROMPT not in self.executor.plan.context:
|
if PLAN_CONTEXT_PROMPT not in self.executor.plan.context:
|
||||||
self.executor.plan.context[PLAN_CONTEXT_PROMPT] = {}
|
self.executor.plan.context[PLAN_CONTEXT_PROMPT] = {}
|
||||||
self.executor.plan.context[PLAN_CONTEXT_PROMPT].update(form.cleaned_data)
|
self.executor.plan.context[PLAN_CONTEXT_PROMPT].update(response.validated_data)
|
||||||
|
print(self.executor.plan.context[PLAN_CONTEXT_PROMPT])
|
||||||
return self.executor.stage_ok()
|
return self.executor.stage_ok()
|
||||||
|
|
|
@ -11,9 +11,8 @@ from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding
|
||||||
from authentik.flows.planner import FlowPlan
|
from authentik.flows.planner import FlowPlan
|
||||||
from authentik.flows.views import SESSION_KEY_PLAN
|
from authentik.flows.views import SESSION_KEY_PLAN
|
||||||
from authentik.policies.expression.models import ExpressionPolicy
|
from authentik.policies.expression.models import ExpressionPolicy
|
||||||
from authentik.stages.prompt.forms import PromptForm
|
|
||||||
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
|
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
|
||||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT, PromptResponseChallenge
|
||||||
|
|
||||||
|
|
||||||
class TestPromptStage(TestCase):
|
class TestPromptStage(TestCase):
|
||||||
|
@ -112,8 +111,8 @@ class TestPromptStage(TestCase):
|
||||||
self.assertIn(prompt.label, force_str(response.content))
|
self.assertIn(prompt.label, force_str(response.content))
|
||||||
self.assertIn(prompt.placeholder, force_str(response.content))
|
self.assertIn(prompt.placeholder, force_str(response.content))
|
||||||
|
|
||||||
def test_valid_form_with_policy(self) -> PromptForm:
|
def test_valid_challenge_with_policy(self) -> PromptResponseChallenge:
|
||||||
"""Test form validation"""
|
"""Test challenge_response validation"""
|
||||||
plan = FlowPlan(
|
plan = FlowPlan(
|
||||||
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
|
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
|
||||||
)
|
)
|
||||||
|
@ -123,12 +122,14 @@ class TestPromptStage(TestCase):
|
||||||
)
|
)
|
||||||
self.stage.validation_policies.set([expr_policy])
|
self.stage.validation_policies.set([expr_policy])
|
||||||
self.stage.save()
|
self.stage.save()
|
||||||
form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
|
challenge_response = PromptResponseChallenge(
|
||||||
self.assertEqual(form.is_valid(), True)
|
None, stage=self.stage, plan=plan, data=self.prompt_data
|
||||||
return form
|
)
|
||||||
|
self.assertEqual(challenge_response.is_valid(), True)
|
||||||
|
return challenge_response
|
||||||
|
|
||||||
def test_invalid_form(self) -> PromptForm:
|
def test_invalid_challenge(self) -> PromptResponseChallenge:
|
||||||
"""Test form validation"""
|
"""Test challenge_response validation"""
|
||||||
plan = FlowPlan(
|
plan = FlowPlan(
|
||||||
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
|
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
|
||||||
)
|
)
|
||||||
|
@ -138,12 +139,14 @@ class TestPromptStage(TestCase):
|
||||||
)
|
)
|
||||||
self.stage.validation_policies.set([expr_policy])
|
self.stage.validation_policies.set([expr_policy])
|
||||||
self.stage.save()
|
self.stage.save()
|
||||||
form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
|
challenge_response = PromptResponseChallenge(
|
||||||
self.assertEqual(form.is_valid(), False)
|
None, stage=self.stage, plan=plan, data=self.prompt_data
|
||||||
return form
|
)
|
||||||
|
self.assertEqual(challenge_response.is_valid(), False)
|
||||||
|
return challenge_response
|
||||||
|
|
||||||
def test_valid_form_request(self):
|
def test_valid_challenge_request(self):
|
||||||
"""Test a request with valid form data"""
|
"""Test a request with valid challenge_response data"""
|
||||||
plan = FlowPlan(
|
plan = FlowPlan(
|
||||||
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
|
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
|
||||||
)
|
)
|
||||||
|
@ -151,7 +154,7 @@ class TestPromptStage(TestCase):
|
||||||
session[SESSION_KEY_PLAN] = plan
|
session[SESSION_KEY_PLAN] = plan
|
||||||
session.save()
|
session.save()
|
||||||
|
|
||||||
form = self.test_valid_form_with_policy()
|
challenge_response = self.test_valid_challenge_with_policy()
|
||||||
|
|
||||||
with patch("authentik.flows.views.FlowExecutorView.cancel", MagicMock()):
|
with patch("authentik.flows.views.FlowExecutorView.cancel", MagicMock()):
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
|
@ -159,7 +162,7 @@ class TestPromptStage(TestCase):
|
||||||
"authentik_api:flow-executor",
|
"authentik_api:flow-executor",
|
||||||
kwargs={"flow_slug": self.flow.slug},
|
kwargs={"flow_slug": self.flow.slug},
|
||||||
),
|
),
|
||||||
form.cleaned_data,
|
challenge_response.validated_data,
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
|
|
|
@ -1,17 +0,0 @@
|
||||||
"""Prompt Widgets"""
|
|
||||||
from django import forms
|
|
||||||
from django.utils.safestring import mark_safe
|
|
||||||
|
|
||||||
|
|
||||||
class StaticTextWidget(forms.widgets.Widget):
|
|
||||||
"""Widget to render static text"""
|
|
||||||
|
|
||||||
def render(self, name, value, attrs=None, renderer=None):
|
|
||||||
return mark_safe(f"<p>{value}</p>") # nosec
|
|
||||||
|
|
||||||
|
|
||||||
class HorizontalRuleWidget(forms.widgets.Widget):
|
|
||||||
"""Widget, which renders an <hr> element"""
|
|
||||||
|
|
||||||
def render(self, name, value, attrs=None, renderer=None):
|
|
||||||
return mark_safe("<hr>") # nosec
|
|
|
@ -153,7 +153,7 @@ class SeleniumTestCase(StaticLiveServerTestCase):
|
||||||
ObjectManager().run()
|
ObjectManager().run()
|
||||||
|
|
||||||
|
|
||||||
def retry(max_retires=3, exceptions=None):
|
def retry(max_retires=1, exceptions=None):
|
||||||
"""Retry test multiple times. Default to catching Selenium Timeout Exception"""
|
"""Retry test multiple times. Default to catching Selenium Timeout Exception"""
|
||||||
|
|
||||||
if not exceptions:
|
if not exceptions:
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
import { gettext } from "django";
|
||||||
|
import { CSSResult, customElement, html, property, TemplateResult } from "lit-element";
|
||||||
|
import { Challenge } from "../../../api/Flows";
|
||||||
|
import { COMMON_STYLES } from "../../../common/styles";
|
||||||
|
import { BaseStage } from "../base";
|
||||||
|
|
||||||
|
export interface Prompt {
|
||||||
|
field_key: string;
|
||||||
|
label: string;
|
||||||
|
type: string;
|
||||||
|
required: boolean;
|
||||||
|
placeholder: string;
|
||||||
|
order: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PromptChallenge extends Challenge {
|
||||||
|
fields: Prompt[];
|
||||||
|
}
|
||||||
|
|
||||||
|
@customElement("ak-stage-prompt")
|
||||||
|
export class PromptStage extends BaseStage {
|
||||||
|
|
||||||
|
@property({attribute: false})
|
||||||
|
challenge?: PromptChallenge;
|
||||||
|
|
||||||
|
static get styles(): CSSResult[] {
|
||||||
|
return COMMON_STYLES;
|
||||||
|
}
|
||||||
|
|
||||||
|
renderPromptInner(prompt: Prompt): TemplateResult {
|
||||||
|
switch (prompt.type) {
|
||||||
|
case "text":
|
||||||
|
return html`<input
|
||||||
|
type="text"
|
||||||
|
name="${prompt.field_key}"
|
||||||
|
placeholder="${prompt.placeholder}"
|
||||||
|
autocomplete="off"
|
||||||
|
class="pf-c-form-control"
|
||||||
|
?required=${prompt.required}
|
||||||
|
value="">`;
|
||||||
|
case "username":
|
||||||
|
return html`<input
|
||||||
|
type="text"
|
||||||
|
name="${prompt.field_key}"
|
||||||
|
placeholder="${prompt.placeholder}"
|
||||||
|
autocomplete="username"
|
||||||
|
class="pf-c-form-control"
|
||||||
|
?required=${prompt.required}
|
||||||
|
value="">`;
|
||||||
|
case "email":
|
||||||
|
return html`<input
|
||||||
|
type="email"
|
||||||
|
name="${prompt.field_key}"
|
||||||
|
placeholder="${prompt.placeholder}"
|
||||||
|
class="pf-c-form-control"
|
||||||
|
?required=${prompt.required}
|
||||||
|
value="">`;
|
||||||
|
case "password":
|
||||||
|
return html`<input
|
||||||
|
type="password"
|
||||||
|
name="${prompt.field_key}"
|
||||||
|
placeholder="${prompt.placeholder}"
|
||||||
|
autocomplete="new-password"
|
||||||
|
class="pf-c-form-control"
|
||||||
|
?required=${prompt.required}>`;
|
||||||
|
case "number":
|
||||||
|
return html`<input
|
||||||
|
type="number"
|
||||||
|
name="${prompt.field_key}"
|
||||||
|
placeholder="${prompt.placeholder}"
|
||||||
|
class="pf-c-form-control"
|
||||||
|
?required=${prompt.required}>`;
|
||||||
|
case "checkbox":
|
||||||
|
return html`<input
|
||||||
|
type="checkbox"
|
||||||
|
name="${prompt.field_key}"
|
||||||
|
placeholder="${prompt.placeholder}"
|
||||||
|
class="pf-c-form-control"
|
||||||
|
?required=${prompt.required}>`;
|
||||||
|
case "date":
|
||||||
|
return html`<input
|
||||||
|
type="date"
|
||||||
|
name="${prompt.field_key}"
|
||||||
|
placeholder="${prompt.placeholder}"
|
||||||
|
class="pf-c-form-control"
|
||||||
|
?required=${prompt.required}>`;
|
||||||
|
case "date-time":
|
||||||
|
return html`<input
|
||||||
|
type="datetime"
|
||||||
|
name="${prompt.field_key}"
|
||||||
|
placeholder="${prompt.placeholder}"
|
||||||
|
class="pf-c-form-control"
|
||||||
|
?required=${prompt.required}>`;
|
||||||
|
case "separator":
|
||||||
|
return html`<hr>`;
|
||||||
|
case "hidden":
|
||||||
|
return html`<input
|
||||||
|
type="hidden"
|
||||||
|
name="${prompt.field_key}"
|
||||||
|
value="${prompt.placeholder}"
|
||||||
|
class="pf-c-form-control"
|
||||||
|
?required=${prompt.required}>`;
|
||||||
|
case "static":
|
||||||
|
return html`<p
|
||||||
|
class="pf-c-form-control">${prompt.placeholder}
|
||||||
|
</p>`;
|
||||||
|
}
|
||||||
|
return html``;
|
||||||
|
}
|
||||||
|
|
||||||
|
render(): TemplateResult {
|
||||||
|
if (!this.challenge) {
|
||||||
|
return html`<ak-loading-state></ak-loading-state>`;
|
||||||
|
}
|
||||||
|
return html`<header class="pf-c-login__main-header">
|
||||||
|
<h1 class="pf-c-title pf-m-3xl">
|
||||||
|
${this.challenge.title}
|
||||||
|
</h1>
|
||||||
|
</header>
|
||||||
|
<div class="pf-c-login__main-body">
|
||||||
|
<form class="pf-c-form" @submit=${(e: Event) => {this.submit(e);}}>
|
||||||
|
${this.challenge.fields.map((prompt) => {
|
||||||
|
return html`<ak-form-element
|
||||||
|
label="${prompt.label}"
|
||||||
|
?required="${prompt.required}"
|
||||||
|
class="pf-c-form__group"
|
||||||
|
.errors=${(this.challenge?.response_errors || {})[prompt.field_key]}>
|
||||||
|
${this.renderPromptInner(prompt)}
|
||||||
|
</ak-form-element>`;
|
||||||
|
})}
|
||||||
|
<div class="pf-c-form__group pf-m-action">
|
||||||
|
<button type="submit" class="pf-c-button pf-m-primary pf-m-block">
|
||||||
|
${gettext("Continue")}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
<footer class="pf-c-login__main-footer">
|
||||||
|
<ul class="pf-c-login__main-footer-links">
|
||||||
|
</ul>
|
||||||
|
</footer>`;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -7,6 +7,7 @@ import "../../elements/stages/password/PasswordStage";
|
||||||
import "../../elements/stages/consent/ConsentStage";
|
import "../../elements/stages/consent/ConsentStage";
|
||||||
import "../../elements/stages/email/EmailStage";
|
import "../../elements/stages/email/EmailStage";
|
||||||
import "../../elements/stages/autosubmit/AutosubmitStage";
|
import "../../elements/stages/autosubmit/AutosubmitStage";
|
||||||
|
import "../../elements/stages/prompt/PromptStage";
|
||||||
import { ShellChallenge, Challenge, ChallengeTypes, Flow, RedirectChallenge } from "../../api/Flows";
|
import { ShellChallenge, Challenge, ChallengeTypes, Flow, RedirectChallenge } from "../../api/Flows";
|
||||||
import { DefaultClient } from "../../api/Client";
|
import { DefaultClient } from "../../api/Client";
|
||||||
import { IdentificationChallenge } from "../../elements/stages/identification/IdentificationStage";
|
import { IdentificationChallenge } from "../../elements/stages/identification/IdentificationStage";
|
||||||
|
@ -14,6 +15,7 @@ import { PasswordChallenge } from "../../elements/stages/password/PasswordStage"
|
||||||
import { ConsentChallenge } from "../../elements/stages/consent/ConsentStage";
|
import { ConsentChallenge } from "../../elements/stages/consent/ConsentStage";
|
||||||
import { EmailChallenge } from "../../elements/stages/email/EmailStage";
|
import { EmailChallenge } from "../../elements/stages/email/EmailStage";
|
||||||
import { AutosubmitChallenge } from "../../elements/stages/autosubmit/AutosubmitStage";
|
import { AutosubmitChallenge } from "../../elements/stages/autosubmit/AutosubmitStage";
|
||||||
|
import { PromptChallenge } from "../../elements/stages/prompt/PromptStage";
|
||||||
|
|
||||||
@customElement("ak-flow-executor")
|
@customElement("ak-flow-executor")
|
||||||
export class FlowExecutor extends LitElement {
|
export class FlowExecutor extends LitElement {
|
||||||
|
@ -120,6 +122,8 @@ export class FlowExecutor extends LitElement {
|
||||||
return html`<ak-stage-email .host=${this} .challenge=${this.challenge as EmailChallenge}></ak-stage-email>`;
|
return html`<ak-stage-email .host=${this} .challenge=${this.challenge as EmailChallenge}></ak-stage-email>`;
|
||||||
case "ak-stage-autosubmit":
|
case "ak-stage-autosubmit":
|
||||||
return html`<ak-stage-autosubmit .host=${this} .challenge=${this.challenge as AutosubmitChallenge}></ak-stage-autosubmit>`;
|
return html`<ak-stage-autosubmit .host=${this} .challenge=${this.challenge as AutosubmitChallenge}></ak-stage-autosubmit>`;
|
||||||
|
case "ak-stage-prompt":
|
||||||
|
return html`<ak-stage-prompt .host=${this} .challenge=${this.challenge as PromptChallenge}></ak-stage-prompt>`;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
Reference in New Issue