diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index 76a9bd0ac..7776c5c75 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -420,6 +420,8 @@ class IDToken: id_dict.pop("c_hash") if not self.amr: id_dict.pop("amr") + if not self.auth_time: + id_dict.pop("auth_time") id_dict.pop("claims") id_dict.update(self.claims) return id_dict @@ -496,29 +498,10 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): f"selected: {self.provider.sub_mode}" ) ) - amr = [] # Convert datetimes into timestamps. now = datetime.now() iat_time = int(now.timestamp()) exp_time = int(self.expires.timestamp()) - # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time - # Fallback in case we can't find any login events - auth_time = now - auth_event = get_login_event(request) - if auth_event: - auth_time = auth_event.created - # Also check which method was used for authentication - method = auth_event.context.get(PLAN_CONTEXT_METHOD, "") - method_args = auth_event.context.get(PLAN_CONTEXT_METHOD_ARGS, {}) - if method == "password": - amr.append(AMR_PASSWORD) - if method == "auth_webauthn_pwl": - amr.append(AMR_WEBAUTHN) - if "mfa_devices" in method_args: - if len(amr) > 0: - amr.append(AMR_MFA) - - auth_timestamp = int(auth_time.timestamp()) token = IDToken( iss=self.provider.get_issuer(request), @@ -526,10 +509,27 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): aud=self.provider.client_id, exp=exp_time, iat=iat_time, - auth_time=auth_timestamp, - amr=amr if amr else None, ) + # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time + auth_event = get_login_event(request) + if auth_event: + auth_time = auth_event.created + token.auth_time = int(auth_time.timestamp()) + # Also check which method was used for authentication + method = auth_event.context.get(PLAN_CONTEXT_METHOD, "") + method_args = auth_event.context.get(PLAN_CONTEXT_METHOD_ARGS, {}) + amr = [] + if method == "password": + amr.append(AMR_PASSWORD) + if method == "auth_webauthn_pwl": + amr.append(AMR_WEBAUTHN) + if "mfa_devices" in method_args: + if len(amr) > 0: + amr.append(AMR_MFA) + if amr: + token.amr = amr + # Include (or not) user standard claims in the id_token. if self.provider.include_claims_in_id_token: from authentik.providers.oauth2.views.userinfo import UserInfoView diff --git a/authentik/providers/oauth2/tests/test_introspect.py b/authentik/providers/oauth2/tests/test_introspect.py index 83be7ff69..94209e5d6 100644 --- a/authentik/providers/oauth2/tests/test_introspect.py +++ b/authentik/providers/oauth2/tests/test_introspect.py @@ -59,7 +59,6 @@ class TesOAuth2Introspection(OAuthTestCase): res.content.decode(), { "acr": ACR_AUTHENTIK_DEFAULT, - "auth_time": None, "aud": None, "sub": "bar", "exp": None, diff --git a/authentik/providers/oauth2/tests/utils.py b/authentik/providers/oauth2/tests/utils.py index 7801ddbe6..8df5c1ed4 100644 --- a/authentik/providers/oauth2/tests/utils.py +++ b/authentik/providers/oauth2/tests/utils.py @@ -16,7 +16,6 @@ class OAuthTestCase(TestCase): required_jwt_keys = [ "exp", "iat", - "auth_time", "acr", "sub", "iss", @@ -48,6 +47,7 @@ class OAuthTestCase(TestCase): self.assert_non_none_or_unset(id_token, "nonce") self.assert_non_none_or_unset(id_token, "c_hash") self.assert_non_none_or_unset(id_token, "amr") + self.assert_non_none_or_unset(id_token, "auth_time") for key in self.required_jwt_keys: self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token") self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")