ATH-01-014: save authenticator validation state in flow context

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

bugfixes

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer 2023-06-16 15:16:27 +02:00
parent ce77d82b24
commit f15cac39c8
No known key found for this signature in database
6 changed files with 80 additions and 63 deletions

View File

@ -82,8 +82,9 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
def retrieve_file(self) -> str:
"""Get blueprint from path"""
try:
full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path)).resolve()
if not str(full_path).startswith(CONFIG.y("blueprints_dir")):
base = Path(CONFIG.y("blueprints_dir"))
full_path = base.joinpath(Path(self.path)).resolve()
if not str(full_path).startswith(str(base.resolve())):
raise BlueprintRetrievalFailed("Invalid blueprint path")
with full_path.open("r", encoding="utf-8") as _file:
return _file.read()

View File

@ -204,12 +204,12 @@ class ChallengeStageView(StageView):
for field, errors in response.errors.items():
for error in errors:
full_errors.setdefault(field, [])
full_errors[field].append(
{
"string": str(error),
"code": error.code,
}
)
field_error = {
"string": str(error),
}
if hasattr(error, "code"):
field_error["code"] = error.code
full_errors[field].append(field_error)
challenge_response.initial_data["response_errors"] = full_errors
if not challenge_response.is_valid():
self.logger.error(

View File

@ -132,9 +132,9 @@ class TestPolicyProcess(TestCase):
)
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
http_request = self.factory.get(reverse("authentik_core:impersonate-end"))
http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
http_request.user = self.user
http_request.resolver_match = resolve(reverse("authentik_core:impersonate-end"))
http_request.resolver_match = resolve(reverse("authentik_api:user-impersonate-end"))
request = PolicyRequest(self.user)
request.set_http_request(http_request)

View File

@ -36,9 +36,9 @@ from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_ME
COOKIE_NAME_MFA = "authentik_mfa"
SESSION_KEY_STAGES = "authentik/stages/authenticator_validate/stages"
SESSION_KEY_SELECTED_STAGE = "authentik/stages/authenticator_validate/selected_stage"
SESSION_KEY_DEVICE_CHALLENGES = "authentik/stages/authenticator_validate/device_challenges"
PLAN_CONTEXT_STAGES = "goauthentik.io/stages/authenticator_validate/stages"
PLAN_CONTEXT_SELECTED_STAGE = "goauthentik.io/stages/authenticator_validate/selected_stage"
PLAN_CONTEXT_DEVICE_CHALLENGES = "goauthentik.io/stages/authenticator_validate/device_challenges"
class SelectableStageSerializer(PassiveSerializer):
@ -72,8 +72,8 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
component = CharField(default="ak-stage-authenticator-validate")
def _challenge_allowed(self, classes: list):
device_challenges: list[dict] = self.stage.request.session.get(
SESSION_KEY_DEVICE_CHALLENGES, []
device_challenges: list[dict] = self.stage.executor.plan.context.get(
PLAN_CONTEXT_DEVICE_CHALLENGES, []
)
if not any(x["device_class"] in classes for x in device_challenges):
raise ValidationError("No compatible device class allowed")
@ -103,7 +103,9 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
"""Check which challenge the user has selected. Actual logic only used for SMS stage."""
# First check if the challenge is valid
allowed = False
for device_challenge in self.stage.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, []):
for device_challenge in self.stage.executor.plan.context.get(
PLAN_CONTEXT_DEVICE_CHALLENGES, []
):
if device_challenge.get("device_class", "") == challenge.get(
"device_class", ""
) and device_challenge.get("device_uid", "") == challenge.get("device_uid", ""):
@ -121,11 +123,11 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
def validate_selected_stage(self, stage_pk: str) -> str:
"""Check that the selected stage is valid"""
stages = self.stage.request.session.get(SESSION_KEY_STAGES, [])
stages = self.stage.executor.plan.context.get(PLAN_CONTEXT_STAGES, [])
if not any(str(stage.pk) == stage_pk for stage in stages):
raise ValidationError("Selected stage is invalid")
self.stage.logger.debug("Setting selected stage to ", stage=stage_pk)
self.stage.request.session[SESSION_KEY_SELECTED_STAGE] = stage_pk
self.stage.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = stage_pk
return stage_pk
def validate(self, attrs: dict):
@ -230,7 +232,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
else:
self.logger.debug("No pending user, continuing")
return self.executor.stage_ok()
self.request.session[SESSION_KEY_DEVICE_CHALLENGES] = challenges
self.executor.plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = challenges
# No allowed devices
if len(challenges) < 1:
@ -263,23 +265,23 @@ class AuthenticatorValidateStageView(ChallengeStageView):
if stage.configuration_stages.count() == 1:
next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk)
self.logger.debug("Single stage configured, auto-selecting", stage=next_stage)
self.request.session[SESSION_KEY_SELECTED_STAGE] = next_stage
self.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = next_stage
# Because that normal execution only happens on post, we directly inject it here and
# return it
self.executor.plan.insert_stage(next_stage)
return self.executor.stage_ok()
stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses()
self.request.session[SESSION_KEY_STAGES] = stages
self.executor.plan.context[PLAN_CONTEXT_STAGES] = stages
return super().get(self.request, *args, **kwargs)
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
res = super().post(request, *args, **kwargs)
if (
SESSION_KEY_SELECTED_STAGE in self.request.session
PLAN_CONTEXT_SELECTED_STAGE in self.executor.plan.context
and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE
):
self.logger.debug("Got selected stage in session, running that")
stage_pk = self.request.session.get(SESSION_KEY_SELECTED_STAGE)
self.logger.debug("Got selected stage in context, running that")
stage_pk = self.executor.plan.context(PLAN_CONTEXT_SELECTED_STAGE)
# Because the foreign key to stage.configuration_stage points to
# a base stage class, we need to do another lookup
stage = Stage.objects.get_subclass(pk=stage_pk)
@ -290,8 +292,8 @@ class AuthenticatorValidateStageView(ChallengeStageView):
return res
def get_challenge(self) -> AuthenticatorValidationChallenge:
challenges = self.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, [])
stages = self.request.session.get(SESSION_KEY_STAGES, [])
challenges = self.executor.plan.context.get(PLAN_CONTEXT_DEVICE_CHALLENGES, [])
stages = self.executor.plan.context.get(PLAN_CONTEXT_STAGES, [])
stage_challenges = []
for stage in stages:
serializer = SelectableStageSerializer(
@ -306,6 +308,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
stage_challenges.append(serializer.data)
return AuthenticatorValidationChallenge(
data={
"component": "ak-stage-authenticator-validate",
"type": ChallengeTypes.NATIVE.value,
"device_challenges": challenges,
"configuration_stages": stage_challenges,
@ -385,8 +388,3 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device": webauthn_device,
}
return self.set_valid_mfa_cookie(response.device)
def cleanup(self):
self.request.session.pop(SESSION_KEY_STAGES, None)
self.request.session.pop(SESSION_KEY_SELECTED_STAGE, None)
self.request.session.pop(SESSION_KEY_DEVICE_CHALLENGES, None)

View File

@ -1,26 +1,19 @@
"""Test validator stage"""
from unittest.mock import MagicMock, patch
from django.contrib.sessions.middleware import SessionMiddleware
from django.test.client import RequestFactory
from django.urls.base import reverse
from rest_framework.exceptions import ValidationError
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.models import FlowDesignation, FlowStageBinding, NotConfiguredAction
from authentik.flows.planner import FlowPlan
from authentik.flows.stage import StageView
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import SESSION_KEY_PLAN, FlowExecutorView
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id, generate_key
from authentik.lib.tests.utils import dummy_get_response
from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice
from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageSerializer
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_validate.stage import (
SESSION_KEY_DEVICE_CHALLENGES,
AuthenticatorValidationChallengeResponse,
)
from authentik.stages.authenticator_validate.stage import PLAN_CONTEXT_DEVICE_CHALLENGES
from authentik.stages.identification.models import IdentificationStage, UserFields
@ -86,12 +79,17 @@ class AuthenticatorValidateStageTests(FlowTestCase):
def test_validate_selected_challenge(self):
"""Test validate_selected_challenge"""
# Prepare request with session
request = self.request_factory.get("/")
flow = create_test_flow()
stage = AuthenticatorValidateStage.objects.create(
name=generate_id(),
not_configured_action=NotConfiguredAction.CONFIGURE,
device_classes=[DeviceClasses.STATIC, DeviceClasses.TOTP],
)
middleware = SessionMiddleware(dummy_get_response)
middleware.process_request(request)
request.session[SESSION_KEY_DEVICE_CHALLENGES] = [
session = self.client.session
plan = FlowPlan(flow_pk=flow.pk.hex)
plan.append_stage(stage)
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
{
"device_class": "static",
"device_uid": "1",
@ -101,23 +99,43 @@ class AuthenticatorValidateStageTests(FlowTestCase):
"device_uid": "2",
},
]
request.session.save()
session[SESSION_KEY_PLAN] = plan
session.save()
res = AuthenticatorValidationChallengeResponse()
res.stage = StageView(FlowExecutorView())
res.stage.request = request
with self.assertRaises(ValidationError):
res.validate_selected_challenge(
{
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
data={
"selected_challenge": {
"device_class": "baz",
"device_uid": "quox",
"challenge": {},
}
)
res.validate_selected_challenge(
{
"device_class": "static",
"device_uid": "1",
}
},
)
self.assertStageResponse(
response,
flow,
response_errors={
"selected_challenge": [{"string": "invalid challenge selected", "code": "invalid"}]
},
component="ak-stage-authenticator-validate",
)
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
data={
"selected_challenge": {
"device_class": "static",
"device_uid": "1",
"challenge": {},
},
},
)
self.assertStageResponse(
response,
flow,
response_errors={"non_field_errors": [{"string": "Empty response", "code": "invalid"}]},
component="ak-stage-authenticator-validate",
)
@patch(

View File

@ -22,7 +22,7 @@ from authentik.stages.authenticator_validate.challenge import (
)
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_validate.stage import (
SESSION_KEY_DEVICE_CHALLENGES,
PLAN_CONTEXT_DEVICE_CHALLENGES,
AuthenticatorValidateStageView,
)
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
@ -211,14 +211,14 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
plan.append_stage(stage)
plan.append_stage(UserLoginStage(name=generate_id()))
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_DEVICE_CHALLENGES] = [
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
{
"device_class": device.__class__.__name__.lower().replace("device", ""),
"device_uid": device.pk,
"challenge": {},
}
]
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
@ -283,14 +283,14 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
plan = FlowPlan(flow_pk=flow.pk.hex)
plan.append_stage(stage)
plan.append_stage(UserLoginStage(name=generate_id()))
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_DEVICE_CHALLENGES] = [
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
{
"device_class": device.__class__.__name__.lower().replace("device", ""),
"device_uid": device.pk,
"challenge": {},
}
]
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)