providers/oauth2: only set auth_time in ID token when a login event is stored in the session

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer 2023-01-29 16:00:19 +01:00
parent c2b4d14af5
commit 96eeb91493
No known key found for this signature in database
3 changed files with 22 additions and 23 deletions

View File

@ -420,6 +420,8 @@ class IDToken:
id_dict.pop("c_hash") id_dict.pop("c_hash")
if not self.amr: if not self.amr:
id_dict.pop("amr") id_dict.pop("amr")
if not self.auth_time:
id_dict.pop("auth_time")
id_dict.pop("claims") id_dict.pop("claims")
id_dict.update(self.claims) id_dict.update(self.claims)
return id_dict return id_dict
@ -496,29 +498,10 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
f"selected: {self.provider.sub_mode}" f"selected: {self.provider.sub_mode}"
) )
) )
amr = []
# Convert datetimes into timestamps. # Convert datetimes into timestamps.
now = datetime.now() now = datetime.now()
iat_time = int(now.timestamp()) iat_time = int(now.timestamp())
exp_time = int(self.expires.timestamp()) exp_time = int(self.expires.timestamp())
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
# Fallback in case we can't find any login events
auth_time = now
auth_event = get_login_event(request)
if auth_event:
auth_time = auth_event.created
# Also check which method was used for authentication
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
method_args = auth_event.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
if method == "password":
amr.append(AMR_PASSWORD)
if method == "auth_webauthn_pwl":
amr.append(AMR_WEBAUTHN)
if "mfa_devices" in method_args:
if len(amr) > 0:
amr.append(AMR_MFA)
auth_timestamp = int(auth_time.timestamp())
token = IDToken( token = IDToken(
iss=self.provider.get_issuer(request), iss=self.provider.get_issuer(request),
@ -526,10 +509,27 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
aud=self.provider.client_id, aud=self.provider.client_id,
exp=exp_time, exp=exp_time,
iat=iat_time, iat=iat_time,
auth_time=auth_timestamp,
amr=amr if amr else None,
) )
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
auth_event = get_login_event(request)
if auth_event:
auth_time = auth_event.created
token.auth_time = int(auth_time.timestamp())
# Also check which method was used for authentication
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
method_args = auth_event.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
amr = []
if method == "password":
amr.append(AMR_PASSWORD)
if method == "auth_webauthn_pwl":
amr.append(AMR_WEBAUTHN)
if "mfa_devices" in method_args:
if len(amr) > 0:
amr.append(AMR_MFA)
if amr:
token.amr = amr
# Include (or not) user standard claims in the id_token. # Include (or not) user standard claims in the id_token.
if self.provider.include_claims_in_id_token: if self.provider.include_claims_in_id_token:
from authentik.providers.oauth2.views.userinfo import UserInfoView from authentik.providers.oauth2.views.userinfo import UserInfoView

View File

@ -59,7 +59,6 @@ class TesOAuth2Introspection(OAuthTestCase):
res.content.decode(), res.content.decode(),
{ {
"acr": ACR_AUTHENTIK_DEFAULT, "acr": ACR_AUTHENTIK_DEFAULT,
"auth_time": None,
"aud": None, "aud": None,
"sub": "bar", "sub": "bar",
"exp": None, "exp": None,

View File

@ -16,7 +16,6 @@ class OAuthTestCase(TestCase):
required_jwt_keys = [ required_jwt_keys = [
"exp", "exp",
"iat", "iat",
"auth_time",
"acr", "acr",
"sub", "sub",
"iss", "iss",
@ -48,6 +47,7 @@ class OAuthTestCase(TestCase):
self.assert_non_none_or_unset(id_token, "nonce") self.assert_non_none_or_unset(id_token, "nonce")
self.assert_non_none_or_unset(id_token, "c_hash") self.assert_non_none_or_unset(id_token, "c_hash")
self.assert_non_none_or_unset(id_token, "amr") self.assert_non_none_or_unset(id_token, "amr")
self.assert_non_none_or_unset(id_token, "auth_time")
for key in self.required_jwt_keys: for key in self.required_jwt_keys:
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token") self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token") self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")