stages/*: update tests for new response

This commit is contained in:
Jens Langhammer 2021-02-20 19:41:32 +01:00
parent bdb86d7119
commit e0ae92ccc7
19 changed files with 101 additions and 48 deletions

View file

@ -1,6 +1,7 @@
"""Challenge helpers""" """Challenge helpers"""
from enum import Enum from enum import Enum
from django.db.models.base import Model
from django.http import JsonResponse from django.http import JsonResponse
from rest_framework.fields import ChoiceField, JSONField from rest_framework.fields import ChoiceField, JSONField
from rest_framework.serializers import CharField, Serializer from rest_framework.serializers import CharField, Serializer
@ -23,11 +24,24 @@ class Challenge(Serializer):
type = ChoiceField(choices=list(ChallengeTypes)) type = ChoiceField(choices=list(ChallengeTypes))
component = CharField(required=False) component = CharField(required=False)
args = JSONField() args = JSONField()
title = CharField(required=False)
def create(self, validated_data: dict) -> Model:
return Model()
def update(self, instance: Model, validated_data: dict) -> Model:
return Model()
class ChallengeResponse(Serializer): class ChallengeResponse(Serializer):
"""Base class for all challenge responses""" """Base class for all challenge responses"""
def create(self, validated_data: dict) -> Model:
return Model()
def update(self, instance: Model, validated_data: dict) -> Model:
return Model()
class HttpChallengeResponse(JsonResponse): class HttpChallengeResponse(JsonResponse):
"""Subclass of JsonResponse that uses the `DataclassEncoder`""" """Subclass of JsonResponse that uses the `DataclassEncoder`"""

View file

@ -1,6 +1,6 @@
"""authentik stage Base view""" """authentik stage Base view"""
from collections import namedtuple from collections import namedtuple
from typing import Any from typing import Any, Type
from django.http import HttpRequest from django.http import HttpRequest
from django.http.response import HttpResponse, JsonResponse from django.http.response import HttpResponse, JsonResponse
@ -52,25 +52,36 @@ class StageView(TemplateView):
class ChallengeStageView(StageView): class ChallengeStageView(StageView):
"""Stage view which response with a challenge"""
response_class = ChallengeResponse response_class = ChallengeResponse
def get_response_class(self) -> Type[ChallengeResponse]:
"""Return the response class type"""
return self.response_class
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
challenge = self.get_challenge() challenge = self.get_challenge()
challenge.title = self.executor.flow.title
challenge.is_valid() challenge.is_valid()
return HttpChallengeResponse(challenge) return HttpChallengeResponse(challenge)
# pylint: disable=unused-argument
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
challenge: ChallengeResponse = self.response_class(data=request.POST) """Handle challenge response"""
challenge: ChallengeResponse = self.get_response_class()(data=request.POST)
if not challenge.is_valid(): if not challenge.is_valid():
return self.challenge_invalid(challenge) return self.challenge_invalid(challenge)
return self.challenge_valid(challenge) return self.challenge_valid(challenge)
def get_challenge(self) -> Challenge: def get_challenge(self) -> Challenge:
"""Return the challenge that the client should solve"""
raise NotImplementedError raise NotImplementedError
def challenge_valid(self, challenge: ChallengeResponse) -> HttpResponse: def challenge_valid(self, challenge: ChallengeResponse) -> HttpResponse:
"""Callback when the challenge has the correct format"""
raise NotImplementedError raise NotImplementedError
def challenge_invalid(self, challenge: ChallengeResponse) -> HttpResponse: def challenge_invalid(self, challenge: ChallengeResponse) -> HttpResponse:
"""Callback when the challenge has the incorrect format"""
return JsonResponse(challenge.errors) return JsonResponse(challenge.errors)

View file

@ -282,7 +282,7 @@ class TestFlowExecutor(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
def test_reevaluate_keep(self): def test_reevaluate_keep(self):
@ -435,7 +435,7 @@ class TestFlowExecutor(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
def test_stageview_user_identifier(self): def test_stageview_user_identifier(self):

View file

@ -3,13 +3,7 @@ from traceback import format_tb
from typing import Any, Optional from typing import Any, Optional
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import ( from django.http import Http404, HttpRequest, HttpResponse, HttpResponseRedirect
Http404,
HttpRequest,
HttpResponse,
HttpResponseRedirect,
JsonResponse,
)
from django.shortcuts import get_object_or_404, redirect, reverse from django.shortcuts import get_object_or_404, redirect, reverse
from django.template.response import TemplateResponse from django.template.response import TemplateResponse
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator

View file

@ -1,8 +1,5 @@
"""OTP Validation""" """OTP Validation"""
from typing import Any
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.views.generic import FormView
from django_otp import user_has_device from django_otp import user_has_device
from rest_framework.fields import IntegerField from rest_framework.fields import IntegerField
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -10,7 +7,7 @@ from structlog.stdlib import get_logger
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
from authentik.flows.models import NotConfiguredAction from authentik.flows.models import NotConfiguredAction
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import ChallengeStageView, StageView from authentik.flows.stage import ChallengeStageView
from authentik.stages.authenticator_validate.forms import ValidationForm from authentik.stages.authenticator_validate.forms import ValidationForm
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage
@ -18,13 +15,13 @@ LOGGER = get_logger()
class CodeChallengeResponse(ChallengeResponse): class CodeChallengeResponse(ChallengeResponse):
"""Challenge used for Code-based authenticators"""
code = IntegerField(min_value=0) code = IntegerField(min_value=0)
class WebAuthnChallengeResponse(ChallengeResponse): class WebAuthnChallengeResponse(ChallengeResponse):
"""Challenge used for WebAuthn authenticators"""
pass
class AuthenticatorValidateStageView(ChallengeStageView): class AuthenticatorValidateStageView(ChallengeStageView):
@ -32,10 +29,10 @@ class AuthenticatorValidateStageView(ChallengeStageView):
form_class = ValidationForm form_class = ValidationForm
def get_form_kwargs(self, **kwargs) -> dict[str, Any]: # def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
kwargs = super().get_form_kwargs(**kwargs) # kwargs = super().get_form_kwargs(**kwargs)
kwargs["user"] = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER) # kwargs["user"] = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
return kwargs # return kwargs
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Check if a user is set, and check if the user has any devices """Check if a user is set, and check if the user has any devices
@ -68,9 +65,9 @@ class AuthenticatorValidateStageView(ChallengeStageView):
} }
) )
def post_challenge(self, challenge: Challenge) -> HttpResponse: def challenge_valid(self, challenge: ChallengeResponse) -> HttpResponse:
print(challenge) print(challenge)
return super().post_challenge(challenge) return HttpResponse()
# def form_valid(self, form: ValidationForm) -> HttpResponse: # def form_valid(self, form: ValidationForm) -> HttpResponse:
# """Verify OTP Token""" # """Verify OTP Token"""

View file

@ -51,5 +51,5 @@ class TestCaptchaStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )

View file

@ -51,7 +51,7 @@ class TestConsentStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
self.assertFalse(UserConsent.objects.filter(user=self.user).exists()) self.assertFalse(UserConsent.objects.filter(user=self.user).exists())
@ -82,7 +82,7 @@ class TestConsentStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
self.assertTrue( self.assertTrue(
UserConsent.objects.filter( UserConsent.objects.filter(
@ -119,7 +119,7 @@ class TestConsentStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
self.assertTrue( self.assertTrue(
UserConsent.objects.filter( UserConsent.objects.filter(

View file

@ -47,7 +47,7 @@ class TestDummyStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
def test_form(self): def test_form(self):

View file

@ -126,7 +126,7 @@ class TestEmailStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
session = self.client.session session = self.client.session

View file

@ -6,7 +6,6 @@ from django.db.models import Q
from django.http import HttpResponse from django.http import HttpResponse
from django.urls import reverse from django.urls import reverse
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.views.generic import FormView
from rest_framework.fields import CharField from rest_framework.fields import CharField
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -17,19 +16,20 @@ from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import ( from authentik.flows.stage import (
PLAN_CONTEXT_PENDING_USER_IDENTIFIER, PLAN_CONTEXT_PENDING_USER_IDENTIFIER,
ChallengeStageView, ChallengeStageView,
StageView,
) )
from authentik.flows.views import SESSION_KEY_APPLICATION_PRE from authentik.flows.views import SESSION_KEY_APPLICATION_PRE
from authentik.stages.identification.forms import IdentificationForm
from authentik.stages.identification.models import IdentificationStage, UserFields from authentik.stages.identification.models import IdentificationStage, UserFields
LOGGER = get_logger() LOGGER = get_logger()
class IdentificationChallengeResponse(ChallengeResponse): class IdentificationChallengeResponse(ChallengeResponse):
"""Identification challenge"""
uid_field = CharField() uid_field = CharField()
# TODO: Validate here instead of challenge_valid()
class IdentificationStageView(ChallengeStageView): class IdentificationStageView(ChallengeStageView):
"""Form to identify the user""" """Form to identify the user"""
@ -66,12 +66,12 @@ class IdentificationStageView(ChallengeStageView):
if current_stage.enrollment_flow: if current_stage.enrollment_flow:
args["enroll_url"] = reverse( args["enroll_url"] = reverse(
"authentik_flows:flow-executor-shell", "authentik_flows:flow-executor-shell",
args={"flow_slug": current_stage.enrollment_flow.slug}, kwargs={"flow_slug": current_stage.enrollment_flow.slug},
) )
if current_stage.recovery_flow: if current_stage.recovery_flow:
args["recovery_url"] = reverse( args["recovery_url"] = reverse(
"authentik_flows:flow-executor-shell", "authentik_flows:flow-executor-shell",
args={"flow_slug": current_stage.recovery_flow.slug}, kwargs={"flow_slug": current_stage.recovery_flow.slug},
) )
args["primary_action"] = _("Log in") args["primary_action"] = _("Log in")

View file

@ -57,7 +57,7 @@ class TestIdentificationStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
def test_invalid_with_username(self): def test_invalid_with_username(self):
@ -87,6 +87,7 @@ class TestIdentificationStage(TestCase):
flow = Flow.objects.create( flow = Flow.objects.create(
name="enroll-test", name="enroll-test",
slug="unique-enrollment-string", slug="unique-enrollment-string",
title="unique-enrollment-string",
designation=FlowDesignation.ENROLLMENT, designation=FlowDesignation.ENROLLMENT,
) )
self.stage.enrollment_flow = flow self.stage.enrollment_flow = flow
@ -103,7 +104,25 @@ class TestIdentificationStage(TestCase):
), ),
) )
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertIn(flow.slug, force_str(response.content)) self.assertJSONEqual(
force_str(response.content),
{
"type": "native",
"component": "ak-stage-identification",
"args": {
"input_type": "email",
"enroll_url": "/flows/unique-enrollment-string/",
"primary_action": "Log in",
"sources": [
{
"icon_url": "/static/authentik/sources/.svg",
"name": "test",
"url": "/source/oauth/login/test/",
}
],
},
},
)
def test_recovery_flow(self): def test_recovery_flow(self):
"""Test that recovery flow is linked correctly""" """Test that recovery flow is linked correctly"""
@ -119,11 +138,28 @@ class TestIdentificationStage(TestCase):
stage=self.stage, stage=self.stage,
order=0, order=0,
) )
response = self.client.get( response = self.client.get(
reverse( reverse(
"authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}
), ),
) )
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertIn(flow.slug, force_str(response.content)) self.assertJSONEqual(
force_str(response.content),
{
"type": "native",
"component": "ak-stage-identification",
"args": {
"input_type": "email",
"recovery_url": "/flows/unique-recovery-string/",
"primary_action": "Log in",
"sources": [
{
"icon_url": "/static/authentik/sources/.svg",
"name": "test",
"url": "/source/oauth/login/test/",
}
],
},
},
)

View file

@ -85,7 +85,7 @@ class TestUserLoginStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
self.stage.continue_flow_without_invitation = False self.stage.continue_flow_without_invitation = False
@ -124,5 +124,5 @@ class TestUserLoginStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )

View file

@ -110,7 +110,7 @@ class TestPasswordStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
def test_invalid_password(self): def test_invalid_password(self):

View file

@ -164,7 +164,7 @@ class TestPromptStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
# Check that valid data has been saved # Check that valid data has been saved

View file

@ -85,7 +85,7 @@ class TestUserDeleteStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
self.assertFalse(User.objects.filter(username=self.username).exists()) self.assertFalse(User.objects.filter(username=self.username).exists())

View file

@ -53,7 +53,7 @@ class TestUserLoginStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
@patch( @patch(

View file

@ -49,7 +49,7 @@ class TestUserLogoutStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
def test_form(self): def test_form(self):

View file

@ -61,7 +61,7 @@ class TestUserWriteStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
user_qs = User.objects.filter( user_qs = User.objects.filter(
username=plan.context[PLAN_CONTEXT_PROMPT]["username"] username=plan.context[PLAN_CONTEXT_PROMPT]["username"]
@ -98,7 +98,7 @@ class TestUserWriteStage(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{"type": "redirect", "to": reverse("authentik_core:shell")}, {"args": {"to": reverse("authentik_core:shell")}, "type": "redirect"},
) )
user_qs = User.objects.filter( user_qs = User.objects.filter(
username=plan.context[PLAN_CONTEXT_PROMPT]["username"] username=plan.context[PLAN_CONTEXT_PROMPT]["username"]

View file

@ -15,7 +15,8 @@ enum ChallengeTypes {
interface Challenge { interface Challenge {
type: ChallengeTypes; type: ChallengeTypes;
args: { [key: string]: string }; args: { [key: string]: string };
component: string; component?: string;
title?: string;
} }
@customElement("ak-flow-executor") @customElement("ak-flow-executor")