diff --git a/authentik/stages/authenticator_static/stage.py b/authentik/stages/authenticator_static/stage.py index 13056de9e..3cb082a28 100644 --- a/authentik/stages/authenticator_static/stage.py +++ b/authentik/stages/authenticator_static/stage.py @@ -5,13 +5,10 @@ from rest_framework.fields import CharField, ListField from structlog.stdlib import get_logger from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge -from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import ChallengeStageView from authentik.stages.authenticator_static.models import AuthenticatorStaticStage LOGGER = get_logger() -SESSION_STATIC_DEVICE = "static_device" -SESSION_STATIC_TOKENS = "static_device_tokens" class AuthenticatorStaticChallenge(WithUserInfoChallenge): @@ -33,7 +30,8 @@ class AuthenticatorStaticStageView(ChallengeStageView): response_class = AuthenticatorStaticChallengeResponse def get_challenge(self, *args, **kwargs) -> AuthenticatorStaticChallenge: - tokens: list[StaticToken] = self.request.session[SESSION_STATIC_TOKENS] + user = self.get_pending_user() + tokens: list[StaticToken] = StaticToken.objects.filter(device__user=user) return AuthenticatorStaticChallenge( data={ "type": ChallengeTypes.NATIVE.value, @@ -42,34 +40,32 @@ class AuthenticatorStaticStageView(ChallengeStageView): ) def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: - user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER) - if not user: + user = self.get_pending_user() + if not user.is_authenticated: LOGGER.debug("No pending user, continuing") return self.executor.stage_ok() - # Currently, this stage only supports one device per user. If the user already - # has a device, just skip to the next stage - if StaticDevice.objects.filter(user=user).exists(): - return self.executor.stage_ok() - stage: AuthenticatorStaticStage = self.executor.current_stage - if SESSION_STATIC_DEVICE not in self.request.session: - device = StaticDevice(user=user, confirmed=False, name="Static Token") - tokens = [] - for _ in range(0, stage.token_count): - tokens.append(StaticToken(device=device, token=StaticToken.random_token())) - self.request.session[SESSION_STATIC_DEVICE] = device - self.request.session[SESSION_STATIC_TOKENS] = tokens + devices = StaticDevice.objects.filter(user=user) + # Currently, this stage only supports one device per user. If the user already + # has a device, just skip to the next stage + if devices.exists(): + if not any(x.confirmed for x in devices): + return super().get(request, *args, **kwargs) + return self.executor.stage_ok() + + device = StaticDevice.objects.create(user=user, confirmed=False, name="Static Token") + for _ in range(0, stage.token_count): + StaticToken.objects.create(device=device, token=StaticToken.random_token()) return super().get(request, *args, **kwargs) def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: """Verify OTP Token""" - device: StaticDevice = self.request.session[SESSION_STATIC_DEVICE] + user = self.get_pending_user() + device: StaticDevice = StaticDevice.objects.filter(user=user).first() + if not device: + return self.executor.stage_invalid() device.confirmed = True device.save() - for token in self.request.session[SESSION_STATIC_TOKENS]: - token.save() - del self.request.session[SESSION_STATIC_DEVICE] - del self.request.session[SESSION_STATIC_TOKENS] return self.executor.stage_ok() diff --git a/authentik/stages/authenticator_totp/stage.py b/authentik/stages/authenticator_totp/stage.py index f21bee407..454ff013d 100644 --- a/authentik/stages/authenticator_totp/stage.py +++ b/authentik/stages/authenticator_totp/stage.py @@ -14,13 +14,11 @@ from authentik.flows.challenge import ( ChallengeTypes, WithUserInfoChallenge, ) -from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import ChallengeStageView from authentik.stages.authenticator_totp.models import AuthenticatorTOTPStage from authentik.stages.authenticator_totp.settings import OTP_TOTP_ISSUER LOGGER = get_logger() -SESSION_TOTP_DEVICE = "totp_device" class AuthenticatorTOTPChallenge(WithUserInfoChallenge): @@ -54,7 +52,8 @@ class AuthenticatorTOTPStageView(ChallengeStageView): response_class = AuthenticatorTOTPChallengeResponse def get_challenge(self, *args, **kwargs) -> Challenge: - device: TOTPDevice = self.request.session[SESSION_TOTP_DEVICE] + user = self.get_pending_user() + device: TOTPDevice = TOTPDevice.objects.filter(user=user).first() return AuthenticatorTOTPChallenge( data={ "type": ChallengeTypes.NATIVE.value, @@ -66,34 +65,37 @@ class AuthenticatorTOTPStageView(ChallengeStageView): def get_response_instance(self, data: QueryDict) -> ChallengeResponse: response = super().get_response_instance(data) - response.device = self.request.session.get(SESSION_TOTP_DEVICE) + user = self.get_pending_user() + response.device = TOTPDevice.objects.filter(user=user).first() return response def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: - user = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER) - if not user: + user = self.get_pending_user() + if not user.is_authenticated: LOGGER.debug("No pending user, continuing") return self.executor.stage_ok() - # Currently, this stage only supports one device per user. If the user already - # has a device, just skip to the next stage - if TOTPDevice.objects.filter(user=user).exists(): - return self.executor.stage_ok() - stage: AuthenticatorTOTPStage = self.executor.current_stage - if SESSION_TOTP_DEVICE not in self.request.session: - device = TOTPDevice( - user=user, confirmed=False, digits=stage.digits, name="TOTP Authenticator" - ) + devices = TOTPDevice.objects.filter(user=user) + # Currently, this stage only supports one device per user. If the user already + # has a device, just skip to the next stage + if devices.exists(): + if not any(x.confirmed for x in devices): + return super().get(request, *args, **kwargs) + return self.executor.stage_ok() - self.request.session[SESSION_TOTP_DEVICE] = device + TOTPDevice.objects.create( + user=user, confirmed=False, digits=stage.digits, name="TOTP Authenticator" + ) return super().get(request, *args, **kwargs) def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: """TOTP Token is validated by challenge""" - device: TOTPDevice = self.request.session[SESSION_TOTP_DEVICE] + user = self.get_pending_user() + device: TOTPDevice = TOTPDevice.objects.filter(user=user).first() + if not device: + return self.executor.stage_invalid() device.confirmed = True device.save() - del self.request.session[SESSION_TOTP_DEVICE] return self.executor.stage_ok()