diff --git a/authentik/root/messages/storage.py b/authentik/root/messages/storage.py index 33710b823..0999c276c 100644 --- a/authentik/root/messages/storage.py +++ b/authentik/root/messages/storage.py @@ -18,8 +18,6 @@ class ChannelsStorage(FallbackStorage): def _store(self, messages: list[Message], response, *args, **kwargs): prefix = f"user_{self.request.session.session_key}_messages_" keys = cache.keys(f"{prefix}*") - if len(keys) < 1: - return super()._store(messages, response, *args, **kwargs) for key in keys: uid = key.replace(prefix, "") for message in messages: @@ -32,4 +30,3 @@ class ChannelsStorage(FallbackStorage): "message": message.message, }, ) - return None diff --git a/authentik/stages/authenticator_validate/challenge.py b/authentik/stages/authenticator_validate/challenge.py index 244589682..96a14a06e 100644 --- a/authentik/stages/authenticator_validate/challenge.py +++ b/authentik/stages/authenticator_validate/challenge.py @@ -17,7 +17,6 @@ from webauthn.webauthn import ( from authentik.core.models import User from authentik.lib.templatetags.authentik_utils import avatar -from authentik.stages.authenticator_validate.models import DeviceClasses from authentik.stages.authenticator_webauthn.models import WebAuthnDevice from authentik.stages.authenticator_webauthn.utils import generate_challenge @@ -70,35 +69,19 @@ def get_webauthn_challenge(request: HttpRequest, device: WebAuthnDevice) -> dict return webauthn_assertion_options.assertion_dict -def validate_challenge( - challenge: DeviceChallenge, request: HttpRequest, user: User -) -> DeviceChallenge: - """main entry point for challenge validation""" - if challenge.validated_data["device_class"] in ( - DeviceClasses.TOTP, - DeviceClasses.STATIC, - ): - return validate_challenge_code(challenge, request, user) - return validate_challenge_webauthn(challenge, request, user) - - -def validate_challenge_code( - challenge: DeviceChallenge, request: HttpRequest, user: User -) -> DeviceChallenge: +def validate_challenge_code(code: str, request: HttpRequest, user: User) -> str: """Validate code-based challenges. We test against every device, on purpose, as the user mustn't choose between totp and static devices.""" - device = match_token(user, challenge.validated_data["challenge"].get("code", None)) + device = match_token(user, code) if not device: raise ValidationError(_("Invalid Token")) - return challenge + return code -def validate_challenge_webauthn( - challenge: DeviceChallenge, request: HttpRequest, user: User -) -> DeviceChallenge: +def validate_challenge_webauthn(data: dict, request: HttpRequest, user: User) -> dict: """Validate WebAuthn Challenge""" challenge = request.session.get("challenge") - assertion_response = challenge.validated_data["challenge"] + assertion_response = data["challenge"] credential_id = assertion_response.get("id") device = WebAuthnDevice.objects.filter(credential_id=credential_id).first() @@ -134,4 +117,4 @@ def validate_challenge_webauthn( raise ValidationError("Assertion failed") from exc device.set_sign_count(sign_count) - return challenge + return data diff --git a/authentik/stages/authenticator_validate/stage.py b/authentik/stages/authenticator_validate/stage.py index 15c8aa434..f8ed7e842 100644 --- a/authentik/stages/authenticator_validate/stage.py +++ b/authentik/stages/authenticator_validate/stage.py @@ -1,11 +1,10 @@ """Authenticator Validation""" from django.http import HttpRequest, HttpResponse -from django.http.request import QueryDict from django_otp import devices_for_user -from rest_framework.fields import ListField +from rest_framework.fields import CharField, JSONField, ListField +from rest_framework.serializers import ValidationError from structlog.stdlib import get_logger -from authentik.core.models import User from authentik.flows.challenge import ( ChallengeResponse, ChallengeTypes, @@ -17,15 +16,17 @@ from authentik.flows.stage import ChallengeStageView from authentik.stages.authenticator_validate.challenge import ( DeviceChallenge, get_challenge_for_device, - validate_challenge, + validate_challenge_code, + validate_challenge_webauthn, +) +from authentik.stages.authenticator_validate.models import ( + AuthenticatorValidateStage, + DeviceClasses, ) -from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses LOGGER = get_logger() -PER_DEVICE_CLASSES = [ - DeviceClasses.WEBAUTHN -] +PER_DEVICE_CLASSES = [DeviceClasses.WEBAUTHN] class AuthenticatorChallenge(WithUserInfoChallenge): @@ -34,15 +35,46 @@ class AuthenticatorChallenge(WithUserInfoChallenge): device_challenges = ListField(child=DeviceChallenge()) -class AuthenticatorChallengeResponse(ChallengeResponse, DeviceChallenge): - """Challenge used for Code-based authenticators""" +class AuthenticatorChallengeResponse(ChallengeResponse): + """Challenge used for Code-based and WebAuthn authenticators""" - request: HttpRequest - user: User + code = CharField(required=False) + webauthn = JSONField(required=False) - def validate_challenge(self, value: dict): - """Validate response""" - return validate_challenge(value, self.request, self.user) + def validate_code(self, code: str) -> str: + """Validate code-based response, raise error if code isn't allowed""" + device_challenges: list[dict] = self.stage.request.session.get( + "device_challenges" + ) + if not any( + x["device_class"] in (DeviceClasses.TOTP, DeviceClasses.STATIC) + for x in device_challenges + ): + raise ValidationError("Got code but no compatible device class allowed") + return validate_challenge_code( + code, self.stage.request, self.stage.get_pending_user() + ) + + def validate_webauthn(self, webauthn: dict) -> dict: + """Validate webauthn response, raise error if webauthn wasn't allowed + or response is invalid""" + device_challenges: list[dict] = self.stage.request.session.get( + "device_challenges" + ) + if not any( + x["device_class"] in (DeviceClasses.WEBAUTHN) for x in device_challenges + ): + raise ValidationError("Got webauthn but no compatible device class allowed") + return validate_challenge_webauthn( + webauthn, self.stage.request, self.stage.get_pending_user() + ) + + def validate(self, data: dict): + # Checking if the given data is from a valid device class is done above + # Here we only check if the any data was sent at all + if "code" not in data and "webauthn" not in data: + raise ValidationError("Empty response") + return data class AuthenticatorValidateStageView(ChallengeStageView): @@ -51,6 +83,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): response_class = AuthenticatorChallengeResponse def get_device_challenges(self) -> list[dict]: + """Get a list of all device challenges applicable for the current stage""" challenges = [] user_devices = devices_for_user(self.get_pending_user()) @@ -112,14 +145,9 @@ class AuthenticatorValidateStageView(ChallengeStageView): } ) - def get_response_instance(self, data: QueryDict) -> ChallengeResponse: - response: AuthenticatorChallengeResponse = super().get_response_instance(data) - response.request = self.request - response.user = self.get_pending_user() - return response - + # pylint: disable=unused-argument def challenge_valid( self, challenge: AuthenticatorChallengeResponse ) -> HttpResponse: - print(challenge) - return HttpResponse() + # All validation is done by the serializer + return self.executor.stage_ok() diff --git a/authentik/stages/authenticator_webauthn/stage.py b/authentik/stages/authenticator_webauthn/stage.py index 305f8a106..3e273fa1f 100644 --- a/authentik/stages/authenticator_webauthn/stage.py +++ b/authentik/stages/authenticator_webauthn/stage.py @@ -18,7 +18,11 @@ from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import ChallengeStageView from authentik.lib.templatetags.authentik_utils import avatar from authentik.stages.authenticator_webauthn.models import WebAuthnDevice -from authentik.stages.authenticator_webauthn.utils import generate_challenge, get_origin, get_rp_id +from authentik.stages.authenticator_webauthn.utils import ( + generate_challenge, + get_origin, + get_rp_id, +) RP_NAME = "authentik" diff --git a/authentik/stages/authenticator_webauthn/utils.py b/authentik/stages/authenticator_webauthn/utils.py index 6c0509e07..881066ef8 100644 --- a/authentik/stages/authenticator_webauthn/utils.py +++ b/authentik/stages/authenticator_webauthn/utils.py @@ -1,6 +1,7 @@ """webauthn utils""" import base64 import os + from django.http import HttpRequest CHALLENGE_DEFAULT_BYTE_LEN = 32 @@ -23,6 +24,7 @@ def generate_challenge(challenge_len=CHALLENGE_DEFAULT_BYTE_LEN): challenge_base64 = challenge_base64.decode("utf-8") return challenge_base64 + def get_rp_id(request: HttpRequest) -> str: """Get hostname from http request, without port""" host = request.get_host() @@ -30,6 +32,7 @@ def get_rp_id(request: HttpRequest) -> str: return host.split(":")[0] return host + def get_origin(request: HttpRequest) -> str: """Return Origin by building an absolute URL and removing the trailing slash""" diff --git a/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStage.ts b/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStage.ts index 8cf305848..e9acacf76 100644 --- a/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStage.ts +++ b/web/src/elements/stages/authenticator_validate/AuthenticatorValidateStage.ts @@ -23,7 +23,8 @@ export interface AuthenticatorValidateStageChallenge extends WithUserInfoChallen } export interface AuthenticatorValidateStageChallengeResponse { - response: DeviceChallenge; + code: string; + webauthn: string; } @customElement("ak-stage-authenticator-validate") @@ -145,13 +146,15 @@ export class AuthenticatorValidateStage extends BaseStage implements StageHost { ${gettext("Select an identification method.")}
`} -