From 0f169f176d238ba711d48acf2d947cf3517c149c Mon Sep 17 00:00:00 2001
From: Jens Langhammer
Date: Thu, 25 Feb 2021 12:06:05 +0100
Subject: [PATCH] stages/authenticator_validate: implement validation, add
button to go back to device picker
---
authentik/root/messages/storage.py | 3 -
.../authenticator_validate/challenge.py | 29 ++-----
.../stages/authenticator_validate/stage.py | 74 +++++++++++-----
.../stages/authenticator_webauthn/stage.py | 6 +-
.../stages/authenticator_webauthn/utils.py | 3 +
.../AuthenticatorValidateStage.ts | 19 +++--
.../AuthenticatorValidateStageCode.ts | 84 +++++++++++--------
.../AuthenticatorValidateStageWebAuthn.ts | 31 +++++--
8 files changed, 146 insertions(+), 103 deletions(-)
diff --git a/authentik/root/messages/storage.py b/authentik/root/messages/storage.py
index 33710b823..0999c276c 100644
--- a/authentik/root/messages/storage.py
+++ b/authentik/root/messages/storage.py
@@ -18,8 +18,6 @@ class ChannelsStorage(FallbackStorage):
def _store(self, messages: list[Message], response, *args, **kwargs):
prefix = f"user_{self.request.session.session_key}_messages_"
keys = cache.keys(f"{prefix}*")
- if len(keys) < 1:
- return super()._store(messages, response, *args, **kwargs)
for key in keys:
uid = key.replace(prefix, "")
for message in messages:
@@ -32,4 +30,3 @@ class ChannelsStorage(FallbackStorage):
"message": message.message,
},
)
- return None
diff --git a/authentik/stages/authenticator_validate/challenge.py b/authentik/stages/authenticator_validate/challenge.py
index 244589682..96a14a06e 100644
--- a/authentik/stages/authenticator_validate/challenge.py
+++ b/authentik/stages/authenticator_validate/challenge.py
@@ -17,7 +17,6 @@ from webauthn.webauthn import (
from authentik.core.models import User
from authentik.lib.templatetags.authentik_utils import avatar
-from authentik.stages.authenticator_validate.models import DeviceClasses
from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
from authentik.stages.authenticator_webauthn.utils import generate_challenge
@@ -70,35 +69,19 @@ def get_webauthn_challenge(request: HttpRequest, device: WebAuthnDevice) -> dict
return webauthn_assertion_options.assertion_dict
-def validate_challenge(
- challenge: DeviceChallenge, request: HttpRequest, user: User
-) -> DeviceChallenge:
- """main entry point for challenge validation"""
- if challenge.validated_data["device_class"] in (
- DeviceClasses.TOTP,
- DeviceClasses.STATIC,
- ):
- return validate_challenge_code(challenge, request, user)
- return validate_challenge_webauthn(challenge, request, user)
-
-
-def validate_challenge_code(
- challenge: DeviceChallenge, request: HttpRequest, user: User
-) -> DeviceChallenge:
+def validate_challenge_code(code: str, request: HttpRequest, user: User) -> str:
"""Validate code-based challenges. We test against every device, on purpose, as
the user mustn't choose between totp and static devices."""
- device = match_token(user, challenge.validated_data["challenge"].get("code", None))
+ device = match_token(user, code)
if not device:
raise ValidationError(_("Invalid Token"))
- return challenge
+ return code
-def validate_challenge_webauthn(
- challenge: DeviceChallenge, request: HttpRequest, user: User
-) -> DeviceChallenge:
+def validate_challenge_webauthn(data: dict, request: HttpRequest, user: User) -> dict:
"""Validate WebAuthn Challenge"""
challenge = request.session.get("challenge")
- assertion_response = challenge.validated_data["challenge"]
+ assertion_response = data["challenge"]
credential_id = assertion_response.get("id")
device = WebAuthnDevice.objects.filter(credential_id=credential_id).first()
@@ -134,4 +117,4 @@ def validate_challenge_webauthn(
raise ValidationError("Assertion failed") from exc
device.set_sign_count(sign_count)
- return challenge
+ return data
diff --git a/authentik/stages/authenticator_validate/stage.py b/authentik/stages/authenticator_validate/stage.py
index 15c8aa434..f8ed7e842 100644
--- a/authentik/stages/authenticator_validate/stage.py
+++ b/authentik/stages/authenticator_validate/stage.py
@@ -1,11 +1,10 @@
"""Authenticator Validation"""
from django.http import HttpRequest, HttpResponse
-from django.http.request import QueryDict
from django_otp import devices_for_user
-from rest_framework.fields import ListField
+from rest_framework.fields import CharField, JSONField, ListField
+from rest_framework.serializers import ValidationError
from structlog.stdlib import get_logger
-from authentik.core.models import User
from authentik.flows.challenge import (
ChallengeResponse,
ChallengeTypes,
@@ -17,15 +16,17 @@ from authentik.flows.stage import ChallengeStageView
from authentik.stages.authenticator_validate.challenge import (
DeviceChallenge,
get_challenge_for_device,
- validate_challenge,
+ validate_challenge_code,
+ validate_challenge_webauthn,
+)
+from authentik.stages.authenticator_validate.models import (
+ AuthenticatorValidateStage,
+ DeviceClasses,
)
-from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
LOGGER = get_logger()
-PER_DEVICE_CLASSES = [
- DeviceClasses.WEBAUTHN
-]
+PER_DEVICE_CLASSES = [DeviceClasses.WEBAUTHN]
class AuthenticatorChallenge(WithUserInfoChallenge):
@@ -34,15 +35,46 @@ class AuthenticatorChallenge(WithUserInfoChallenge):
device_challenges = ListField(child=DeviceChallenge())
-class AuthenticatorChallengeResponse(ChallengeResponse, DeviceChallenge):
- """Challenge used for Code-based authenticators"""
+class AuthenticatorChallengeResponse(ChallengeResponse):
+ """Challenge used for Code-based and WebAuthn authenticators"""
- request: HttpRequest
- user: User
+ code = CharField(required=False)
+ webauthn = JSONField(required=False)
- def validate_challenge(self, value: dict):
- """Validate response"""
- return validate_challenge(value, self.request, self.user)
+ def validate_code(self, code: str) -> str:
+ """Validate code-based response, raise error if code isn't allowed"""
+ device_challenges: list[dict] = self.stage.request.session.get(
+ "device_challenges"
+ )
+ if not any(
+ x["device_class"] in (DeviceClasses.TOTP, DeviceClasses.STATIC)
+ for x in device_challenges
+ ):
+ raise ValidationError("Got code but no compatible device class allowed")
+ return validate_challenge_code(
+ code, self.stage.request, self.stage.get_pending_user()
+ )
+
+ def validate_webauthn(self, webauthn: dict) -> dict:
+ """Validate webauthn response, raise error if webauthn wasn't allowed
+ or response is invalid"""
+ device_challenges: list[dict] = self.stage.request.session.get(
+ "device_challenges"
+ )
+ if not any(
+ x["device_class"] in (DeviceClasses.WEBAUTHN) for x in device_challenges
+ ):
+ raise ValidationError("Got webauthn but no compatible device class allowed")
+ return validate_challenge_webauthn(
+ webauthn, self.stage.request, self.stage.get_pending_user()
+ )
+
+ def validate(self, data: dict):
+ # Checking if the given data is from a valid device class is done above
+ # Here we only check if the any data was sent at all
+ if "code" not in data and "webauthn" not in data:
+ raise ValidationError("Empty response")
+ return data
class AuthenticatorValidateStageView(ChallengeStageView):
@@ -51,6 +83,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
response_class = AuthenticatorChallengeResponse
def get_device_challenges(self) -> list[dict]:
+ """Get a list of all device challenges applicable for the current stage"""
challenges = []
user_devices = devices_for_user(self.get_pending_user())
@@ -112,14 +145,9 @@ class AuthenticatorValidateStageView(ChallengeStageView):
}
)
- def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
- response: AuthenticatorChallengeResponse = super().get_response_instance(data)
- response.request = self.request
- response.user = self.get_pending_user()
- return response
-
+ # pylint: disable=unused-argument
def challenge_valid(
self, challenge: AuthenticatorChallengeResponse
) -> HttpResponse:
- print(challenge)
- return HttpResponse()
+ # All validation is done by the serializer
+ return self.executor.stage_ok()
diff --git a/authentik/stages/authenticator_webauthn/stage.py b/authentik/stages/authenticator_webauthn/stage.py
index 305f8a106..3e273fa1f 100644
--- a/authentik/stages/authenticator_webauthn/stage.py
+++ b/authentik/stages/authenticator_webauthn/stage.py
@@ -18,7 +18,11 @@ from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import ChallengeStageView
from authentik.lib.templatetags.authentik_utils import avatar
from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
-from authentik.stages.authenticator_webauthn.utils import generate_challenge, get_origin, get_rp_id
+from authentik.stages.authenticator_webauthn.utils import (
+ generate_challenge,
+ get_origin,
+ get_rp_id,
+)
RP_NAME = "authentik"
diff --git a/authentik/stages/authenticator_webauthn/utils.py b/authentik/stages/authenticator_webauthn/utils.py
index 6c0509e07..881066ef8 100644
--- a/authentik/stages/authenticator_webauthn/utils.py
+++ b/authentik/stages/authenticator_webauthn/utils.py
@@ -1,6 +1,7 @@
"""webauthn utils"""
import base64
import os
+
from django.http import HttpRequest
CHALLENGE_DEFAULT_BYTE_LEN = 32
@@ -23,6 +24,7 @@ def generate_challenge(challenge_len=CHALLENGE_DEFAULT_BYTE_LEN):
challenge_base64 = challenge_base64.decode("utf-8")
return challenge_base64
+
def get_rp_id(request: HttpRequest) -> str:
"""Get hostname from http request, without port"""
host = request.get_host()
@@ -30,6 +32,7 @@ def get_rp_id(request: HttpRequest) -> str:
return host.split(":")[0]
return host
+
def get_origin(request: HttpRequest) -> str:
"""Return Origin by building an absolute URL and removing the
trailing slash"""
diff --git a/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStage.ts b/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStage.ts
index 8cf305848..e9acacf76 100644
--- a/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStage.ts
+++ b/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStage.ts
@@ -23,7 +23,8 @@ export interface AuthenticatorValidateStageChallenge extends WithUserInfoChallen
}
export interface AuthenticatorValidateStageChallengeResponse {
- response: DeviceChallenge;
+ code: string;
+ webauthn: string;
}
@customElement("ak-stage-authenticator-validate")
@@ -145,13 +146,15 @@ export class AuthenticatorValidateStage extends BaseStage implements StageHost {
${gettext("Select an identification method.")}
`}
-
- ${this.selectedDeviceChallenge ? this.renderDeviceChallenge() : this.renderDevicePicker()}
-
- `;
+ ${this.selectedDeviceChallenge ?
+ this.renderDeviceChallenge() :
+ html`
+ ${this.renderDevicePicker()}
+
+ `}`;
}
}
diff --git a/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStageCode.ts b/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStageCode.ts
index 3bc5101fc..712c9677b 100644
--- a/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStageCode.ts
+++ b/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStageCode.ts
@@ -2,7 +2,8 @@ import { gettext } from "django";
import { CSSResult, customElement, html, property, TemplateResult } from "lit-element";
import { COMMON_STYLES } from "../../../common/styles";
import { BaseStage } from "../base";
-import { AuthenticatorValidateStageChallenge, DeviceChallenge } from "./AuthenticatorValidateStage";
+import { AuthenticatorValidateStage, AuthenticatorValidateStageChallenge, DeviceChallenge } from "./AuthenticatorValidateStage";
+import "../form";
@customElement("ak-stage-authenticator-validate-code")
export class AuthenticatorValidateStageWebCode extends BaseStage {
@@ -21,44 +22,55 @@ export class AuthenticatorValidateStageWebCode extends BaseStage {
if (!this.challenge) {
return html``;
}
- return html`