stages/prompt: migrate to SPA

This commit is contained in:
Jens Langhammer 2021-02-21 18:13:47 +01:00
parent d35f524865
commit 27cd10e072
9 changed files with 360 additions and 209 deletions

View file

@ -28,9 +28,3 @@ class TestOverviewViews(TestCase):
self.assertEqual(
self.client.get(reverse("authentik_core:shell")).status_code, 200
)
def test_overview(self):
"""Test overview"""
self.assertEqual(
self.client.get(reverse("authentik_core:overview")).status_code, 200
)

View file

@ -1,20 +1,7 @@
"""Prompt forms"""
from email.policy import Policy
from types import MethodType
from typing import Any, Callable, Iterator
from django import forms
from django.db.models.query import QuerySet
from django.http import HttpRequest
from django.utils.translation import gettext_lazy as _
from guardian.shortcuts import get_anonymous_user
from authentik.core.models import User
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyBindingModel
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
from authentik.stages.prompt.signals import password_validate
from authentik.stages.prompt.models import Prompt, PromptStage
class PromptStageForm(forms.ModelForm):
@ -47,111 +34,3 @@ class PromptAdminForm(forms.ModelForm):
"label": forms.TextInput(),
"placeholder": forms.TextInput(),
}
class ListPolicyEngine(PolicyEngine):
"""Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
__list: list[Policy]
def __init__(
self, policies: list[Policy], user: User, request: HttpRequest = None
) -> None:
super().__init__(PolicyBindingModel(), user, request)
self.__list = policies
self.use_cache = False
def _iter_bindings(self) -> Iterator[PolicyBinding]:
for policy in self.__list:
yield PolicyBinding(
policy=policy,
)
class PromptForm(forms.Form):
"""Dynamically created form based on PromptStage"""
stage: PromptStage
plan: FlowPlan
def __init__(self, stage: PromptStage, plan: FlowPlan, *args, **kwargs):
self.stage = stage
self.plan = plan
super().__init__(*args, **kwargs)
# list() is called so we only load the fields once
fields = list(self.stage.fields.all())
for field in fields:
field: Prompt
self.fields[field.field_key] = field.field
# Special handling for fields with username type
# these check for existing users with the same username
if field.type == FieldTypes.USERNAME:
setattr(
self,
f"clean_{field.field_key}",
MethodType(username_field_cleaner_factory(field), self),
)
# Check if we have a password field, add a handler that sends a signal
# to validate it
if field.type == FieldTypes.PASSWORD:
setattr(
self,
f"clean_{field.field_key}",
MethodType(password_single_cleaner_factory(field), self),
)
self.field_order = sorted(fields, key=lambda x: x.order)
def _clean_password_fields(self, *field_names):
"""Check if the value of all password fields match by merging them into a set
and checking the length"""
all_passwords = {self.cleaned_data[x] for x in field_names}
if len(all_passwords) > 1:
raise forms.ValidationError(_("Passwords don't match."))
def clean(self):
cleaned_data = super().clean()
if cleaned_data == {}:
return {}
# Check if we have two password fields, and make sure they are the same
password_fields: QuerySet[Prompt] = self.stage.fields.filter(
type=FieldTypes.PASSWORD
)
if password_fields.exists() and password_fields.count() == 2:
self._clean_password_fields(*[field.field_key for field in password_fields])
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
engine = ListPolicyEngine(self.stage.validation_policies.all(), user)
engine.request.context = cleaned_data
engine.build()
result = engine.result
if not result.passing:
raise forms.ValidationError(list(result.messages))
return cleaned_data
def username_field_cleaner_factory(field: Prompt) -> Callable:
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
def username_field_cleaner(self: PromptForm) -> Any:
"""Check for duplicate usernames"""
username = self.cleaned_data.get(field.field_key)
if User.objects.filter(username=username).exists():
raise forms.ValidationError("Username is already taken.")
return username
return username_field_cleaner
def password_single_cleaner_factory(field: Prompt) -> Callable[[PromptForm], Any]:
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
def password_single_clean(self: PromptForm) -> Any:
"""Send password validation signals for e.g. LDAP Source"""
password = self.cleaned_data[field.field_key]
password_validate.send(
sender=self, password=password, plan_context=self.plan.context
)
return password
return password_single_clean

View file

@ -2,17 +2,23 @@
from typing import Type
from uuid import uuid4
from django import forms
from django.db import models
from django.forms import ModelForm
from django.utils.translation import gettext_lazy as _
from django.views import View
from rest_framework.fields import (
BooleanField,
CharField,
DateField,
DateTimeField,
EmailField,
IntegerField,
)
from rest_framework.serializers import BaseSerializer
from authentik.flows.models import Stage
from authentik.lib.models import SerializerModel
from authentik.policies.models import Policy
from authentik.stages.prompt.widgets import HorizontalRuleWidget, StaticTextWidget
class FieldTypes(models.TextChoices):
@ -43,8 +49,8 @@ class FieldTypes(models.TextChoices):
)
NUMBER = "number"
CHECKBOX = "checkbox"
DATE = "data"
DATE_TIME = "data-time"
DATE = "date"
DATE_TIME = "date-time"
SEPARATOR = "separator", _("Separator: Static Separator Line")
HIDDEN = "hidden", _("Hidden: Hidden field, can be used to insert data into form.")
@ -73,49 +79,34 @@ class Prompt(SerializerModel):
return PromptSerializer
@property
def field(self):
"""Return instantiated form input field"""
attrs = {"placeholder": _(self.placeholder)}
field_class = forms.CharField
widget = forms.TextInput(attrs=attrs)
def field(self) -> CharField:
"""Get field type for Challenge and response"""
field_class = CharField
kwargs = {
"label": _(self.label),
"required": self.required,
}
if self.type == FieldTypes.EMAIL:
field_class = forms.EmailField
if self.type == FieldTypes.USERNAME:
attrs["autocomplete"] = "username"
if self.type == FieldTypes.PASSWORD:
widget = forms.PasswordInput(attrs=attrs)
attrs["autocomplete"] = "new-password"
field_class = EmailField
if self.type == FieldTypes.NUMBER:
field_class = forms.IntegerField
widget = forms.NumberInput(attrs=attrs)
field_class = IntegerField
# TODO: Hidden?
if self.type == FieldTypes.HIDDEN:
widget = forms.HiddenInput(attrs=attrs)
kwargs["required"] = False
kwargs["initial"] = self.placeholder
if self.type == FieldTypes.CHECKBOX:
field_class = forms.BooleanField
field_class = BooleanField
kwargs["required"] = False
if self.type == FieldTypes.DATE:
attrs["type"] = "date"
widget = forms.DateInput(attrs=attrs)
field_class = DateField
if self.type == FieldTypes.DATE_TIME:
attrs["type"] = "datetime-local"
widget = forms.DateTimeInput(attrs=attrs)
field_class = DateTimeField
if self.type == FieldTypes.STATIC:
widget = StaticTextWidget(attrs=attrs)
kwargs["initial"] = self.placeholder
kwargs["required"] = False
kwargs["label"] = ""
if self.type == FieldTypes.SEPARATOR:
widget = HorizontalRuleWidget(attrs=attrs)
kwargs["required"] = False
kwargs["label"] = ""
kwargs["widget"] = widget
return field_class(**kwargs)
def save(self, *args, **kwargs):

View file

@ -1,36 +1,189 @@
"""Prompt Stage Logic"""
from django.http import HttpResponse
from email.policy import Policy
from types import MethodType
from typing import Any, Callable, Iterator
from django.db.models.base import Model
from django.db.models.query import QuerySet
from django.http import HttpRequest, HttpResponse
from django.http.request import QueryDict
from django.utils.translation import gettext_lazy as _
from django.views.generic import FormView
from guardian.shortcuts import get_anonymous_user
from rest_framework.fields import BooleanField, CharField, IntegerField
from rest_framework.serializers import Serializer, ValidationError
from structlog.stdlib import get_logger
from authentik.flows.stage import StageView
from authentik.stages.prompt.forms import PromptForm
from authentik.core.models import User
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.stage import ChallengeStageView
from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyBindingModel
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
from authentik.stages.prompt.signals import password_validate
LOGGER = get_logger()
PLAN_CONTEXT_PROMPT = "prompt_data"
class PromptStageView(FormView, StageView):
class PromptSerializer(Serializer):
"""Serializer for a single Prompt field"""
field_key = CharField()
label = CharField()
type = CharField()
required = BooleanField()
placeholder = CharField()
order = IntegerField()
def create(self, validated_data: dict) -> Model:
return Model()
def update(self, instance: Model, validated_data: dict) -> Model:
return Model()
class PromptChallenge(Challenge):
"""Initial challenge being sent, define fields"""
fields = PromptSerializer(many=True)
class PromptResponseChallenge(ChallengeResponse):
"""Validate response, fields are dynamically created based
on the stage"""
def __init__(self, *args, stage: PromptStage, plan: FlowPlan, **kwargs):
super().__init__(*args, **kwargs)
self.stage = stage
self.plan = plan
# list() is called so we only load the fields once
fields = list(self.stage.fields.all())
for field in fields:
field: Prompt
self.fields[field.field_key] = field.field
# Special handling for fields with username type
# these check for existing users with the same username
if field.type == FieldTypes.USERNAME:
setattr(
self,
f"validate_{field.field_key}",
MethodType(username_field_validator_factory(), self),
)
# Check if we have a password field, add a handler that sends a signal
# to validate it
if field.type == FieldTypes.PASSWORD:
setattr(
self,
f"validate_{field.field_key}",
MethodType(password_single_validator_factory(), self),
)
self.field_order = sorted(fields, key=lambda x: x.order)
def _validate_password_fields(self, *field_names):
"""Check if the value of all password fields match by merging them into a set
and checking the length"""
all_passwords = {self.initial_data[x] for x in field_names}
if len(all_passwords) > 1:
raise ValidationError(_("Passwords don't match."))
def validate(self, attrs):
if attrs == {}:
return {}
# Check if we have two password fields, and make sure they are the same
password_fields: QuerySet[Prompt] = self.stage.fields.filter(
type=FieldTypes.PASSWORD
)
if password_fields.exists() and password_fields.count() == 2:
self._validate_password_fields(
*[field.field_key for field in password_fields]
)
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
engine = ListPolicyEngine(self.stage.validation_policies.all(), user)
engine.request.context = attrs
engine.build()
result = engine.result
if not result.passing:
raise ValidationError(list(result.messages))
return attrs
def username_field_validator_factory() -> Callable[[PromptChallenge, str], Any]:
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
# pylint: disable=unused-argument
def username_field_validator(self: PromptChallenge, value: str) -> Any:
"""Check for duplicate usernames"""
if User.objects.filter(username=value).exists():
raise ValidationError("Username is already taken.")
return value
return username_field_validator
def password_single_validator_factory() -> Callable[[PromptChallenge, str], Any]:
"""Return a `clean_` method for `field`. Clean method checks if username is taken already."""
def password_single_clean(self: PromptChallenge, value: str) -> Any:
"""Send password validation signals for e.g. LDAP Source"""
password_validate.send(
sender=self, password=value, plan_context=self.plan.context
)
return value
return password_single_clean
class ListPolicyEngine(PolicyEngine):
"""Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
__list: list[Policy]
def __init__(
self, policies: list[Policy], user: User, request: HttpRequest = None
) -> None:
super().__init__(PolicyBindingModel(), user, request)
self.__list = policies
self.use_cache = False
def _iter_bindings(self) -> Iterator[PolicyBinding]:
for policy in self.__list:
yield PolicyBinding(
policy=policy,
)
class PromptStageView(ChallengeStageView):
"""Prompt Stage, save form data in plan context."""
template_name = "login/form.html"
form_class = PromptForm
response_class = PromptResponseChallenge
def get_context_data(self, **kwargs):
ctx = super().get_context_data(**kwargs)
ctx["title"] = _(self.executor.current_stage.name)
return ctx
def get_challenge(self, *args, **kwargs) -> Challenge:
fields = list(self.executor.current_stage.fields.all())
challenge = PromptChallenge(
data={
"type": ChallengeTypes.native,
"component": "ak-stage-prompt",
"fields": [PromptSerializer(field).data for field in fields],
},
)
return challenge
def get_form_kwargs(self):
kwargs = super().get_form_kwargs()
kwargs["stage"] = self.executor.current_stage
kwargs["plan"] = self.executor.plan
return kwargs
def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
if not self.executor.plan:
raise ValueError
return PromptResponseChallenge(
instance=None,
data=data,
stage=self.executor.current_stage,
plan=self.executor.plan,
)
def form_valid(self, form: PromptForm) -> HttpResponse:
"""Form data is valid"""
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
if PLAN_CONTEXT_PROMPT not in self.executor.plan.context:
self.executor.plan.context[PLAN_CONTEXT_PROMPT] = {}
self.executor.plan.context[PLAN_CONTEXT_PROMPT].update(form.cleaned_data)
self.executor.plan.context[PLAN_CONTEXT_PROMPT].update(response.validated_data)
print(self.executor.plan.context[PLAN_CONTEXT_PROMPT])
return self.executor.stage_ok()

View file

@ -11,9 +11,8 @@ from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding
from authentik.flows.planner import FlowPlan
from authentik.flows.views import SESSION_KEY_PLAN
from authentik.policies.expression.models import ExpressionPolicy
from authentik.stages.prompt.forms import PromptForm
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT, PromptResponseChallenge
class TestPromptStage(TestCase):
@ -112,8 +111,8 @@ class TestPromptStage(TestCase):
self.assertIn(prompt.label, force_str(response.content))
self.assertIn(prompt.placeholder, force_str(response.content))
def test_valid_form_with_policy(self) -> PromptForm:
"""Test form validation"""
def test_valid_challenge_with_policy(self) -> PromptResponseChallenge:
"""Test challenge_response validation"""
plan = FlowPlan(
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
)
@ -123,12 +122,14 @@ class TestPromptStage(TestCase):
)
self.stage.validation_policies.set([expr_policy])
self.stage.save()
form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
self.assertEqual(form.is_valid(), True)
return form
challenge_response = PromptResponseChallenge(
None, stage=self.stage, plan=plan, data=self.prompt_data
)
self.assertEqual(challenge_response.is_valid(), True)
return challenge_response
def test_invalid_form(self) -> PromptForm:
"""Test form validation"""
def test_invalid_challenge(self) -> PromptResponseChallenge:
"""Test challenge_response validation"""
plan = FlowPlan(
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
)
@ -138,12 +139,14 @@ class TestPromptStage(TestCase):
)
self.stage.validation_policies.set([expr_policy])
self.stage.save()
form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
self.assertEqual(form.is_valid(), False)
return form
challenge_response = PromptResponseChallenge(
None, stage=self.stage, plan=plan, data=self.prompt_data
)
self.assertEqual(challenge_response.is_valid(), False)
return challenge_response
def test_valid_form_request(self):
"""Test a request with valid form data"""
def test_valid_challenge_request(self):
"""Test a request with valid challenge_response data"""
plan = FlowPlan(
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
)
@ -151,7 +154,7 @@ class TestPromptStage(TestCase):
session[SESSION_KEY_PLAN] = plan
session.save()
form = self.test_valid_form_with_policy()
challenge_response = self.test_valid_challenge_with_policy()
with patch("authentik.flows.views.FlowExecutorView.cancel", MagicMock()):
response = self.client.post(
@ -159,7 +162,7 @@ class TestPromptStage(TestCase):
"authentik_api:flow-executor",
kwargs={"flow_slug": self.flow.slug},
),
form.cleaned_data,
challenge_response.validated_data,
)
self.assertEqual(response.status_code, 200)
self.assertJSONEqual(

View file

@ -1,17 +0,0 @@
"""Prompt Widgets"""
from django import forms
from django.utils.safestring import mark_safe
class StaticTextWidget(forms.widgets.Widget):
"""Widget to render static text"""
def render(self, name, value, attrs=None, renderer=None):
return mark_safe(f"<p>{value}</p>") # nosec
class HorizontalRuleWidget(forms.widgets.Widget):
"""Widget, which renders an <hr> element"""
def render(self, name, value, attrs=None, renderer=None):
return mark_safe("<hr>") # nosec

View file

@ -153,7 +153,7 @@ class SeleniumTestCase(StaticLiveServerTestCase):
ObjectManager().run()
def retry(max_retires=3, exceptions=None):
def retry(max_retires=1, exceptions=None):
"""Retry test multiple times. Default to catching Selenium Timeout Exception"""
if not exceptions:

View file

@ -0,0 +1,144 @@
import { gettext } from "django";
import { CSSResult, customElement, html, property, TemplateResult } from "lit-element";
import { Challenge } from "../../../api/Flows";
import { COMMON_STYLES } from "../../../common/styles";
import { BaseStage } from "../base";
export interface Prompt {
field_key: string;
label: string;
type: string;
required: boolean;
placeholder: string;
order: number;
}
export interface PromptChallenge extends Challenge {
fields: Prompt[];
}
@customElement("ak-stage-prompt")
export class PromptStage extends BaseStage {
@property({attribute: false})
challenge?: PromptChallenge;
static get styles(): CSSResult[] {
return COMMON_STYLES;
}
renderPromptInner(prompt: Prompt): TemplateResult {
switch (prompt.type) {
case "text":
return html`<input
type="text"
name="${prompt.field_key}"
placeholder="${prompt.placeholder}"
autocomplete="off"
class="pf-c-form-control"
?required=${prompt.required}
value="">`;
case "username":
return html`<input
type="text"
name="${prompt.field_key}"
placeholder="${prompt.placeholder}"
autocomplete="username"
class="pf-c-form-control"
?required=${prompt.required}
value="">`;
case "email":
return html`<input
type="email"
name="${prompt.field_key}"
placeholder="${prompt.placeholder}"
class="pf-c-form-control"
?required=${prompt.required}
value="">`;
case "password":
return html`<input
type="password"
name="${prompt.field_key}"
placeholder="${prompt.placeholder}"
autocomplete="new-password"
class="pf-c-form-control"
?required=${prompt.required}>`;
case "number":
return html`<input
type="number"
name="${prompt.field_key}"
placeholder="${prompt.placeholder}"
class="pf-c-form-control"
?required=${prompt.required}>`;
case "checkbox":
return html`<input
type="checkbox"
name="${prompt.field_key}"
placeholder="${prompt.placeholder}"
class="pf-c-form-control"
?required=${prompt.required}>`;
case "date":
return html`<input
type="date"
name="${prompt.field_key}"
placeholder="${prompt.placeholder}"
class="pf-c-form-control"
?required=${prompt.required}>`;
case "date-time":
return html`<input
type="datetime"
name="${prompt.field_key}"
placeholder="${prompt.placeholder}"
class="pf-c-form-control"
?required=${prompt.required}>`;
case "separator":
return html`<hr>`;
case "hidden":
return html`<input
type="hidden"
name="${prompt.field_key}"
value="${prompt.placeholder}"
class="pf-c-form-control"
?required=${prompt.required}>`;
case "static":
return html`<p
class="pf-c-form-control">${prompt.placeholder}
</p>`;
}
return html``;
}
render(): TemplateResult {
if (!this.challenge) {
return html`<ak-loading-state></ak-loading-state>`;
}
return html`<header class="pf-c-login__main-header">
<h1 class="pf-c-title pf-m-3xl">
${this.challenge.title}
</h1>
</header>
<div class="pf-c-login__main-body">
<form class="pf-c-form" @submit=${(e: Event) => {this.submit(e);}}>
${this.challenge.fields.map((prompt) => {
return html`<ak-form-element
label="${prompt.label}"
?required="${prompt.required}"
class="pf-c-form__group"
.errors=${(this.challenge?.response_errors || {})[prompt.field_key]}>
${this.renderPromptInner(prompt)}
</ak-form-element>`;
})}
<div class="pf-c-form__group pf-m-action">
<button type="submit" class="pf-c-button pf-m-primary pf-m-block">
${gettext("Continue")}
</button>
</div>
</form>
</div>
<footer class="pf-c-login__main-footer">
<ul class="pf-c-login__main-footer-links">
</ul>
</footer>`;
}
}

View file

@ -7,6 +7,7 @@ import "../../elements/stages/password/PasswordStage";
import "../../elements/stages/consent/ConsentStage";
import "../../elements/stages/email/EmailStage";
import "../../elements/stages/autosubmit/AutosubmitStage";
import "../../elements/stages/prompt/PromptStage";
import { ShellChallenge, Challenge, ChallengeTypes, Flow, RedirectChallenge } from "../../api/Flows";
import { DefaultClient } from "../../api/Client";
import { IdentificationChallenge } from "../../elements/stages/identification/IdentificationStage";
@ -14,6 +15,7 @@ import { PasswordChallenge } from "../../elements/stages/password/PasswordStage"
import { ConsentChallenge } from "../../elements/stages/consent/ConsentStage";
import { EmailChallenge } from "../../elements/stages/email/EmailStage";
import { AutosubmitChallenge } from "../../elements/stages/autosubmit/AutosubmitStage";
import { PromptChallenge } from "../../elements/stages/prompt/PromptStage";
@customElement("ak-flow-executor")
export class FlowExecutor extends LitElement {
@ -120,6 +122,8 @@ export class FlowExecutor extends LitElement {
return html`<ak-stage-email .host=${this} .challenge=${this.challenge as EmailChallenge}></ak-stage-email>`;
case "ak-stage-autosubmit":
return html`<ak-stage-autosubmit .host=${this} .challenge=${this.challenge as AutosubmitChallenge}></ak-stage-autosubmit>`;
case "ak-stage-prompt":
return html`<ak-stage-prompt .host=${this} .challenge=${this.challenge as PromptChallenge}></ak-stage-prompt>`;
default:
break;
}