providers/oauth2: fix inconsistent expiry encoded in JWT
- access token validity is used for JWTs issues in implicit flows - general cleanup of how times are set closes #2581 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
bdf50a35cd
commit
3306003f0e
|
@ -2,9 +2,8 @@
|
|||
import base64
|
||||
import binascii
|
||||
import json
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
@ -14,7 +13,7 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
|||
from dacite.core import from_dict
|
||||
from django.db import models
|
||||
from django.http import HttpRequest
|
||||
from django.utils import dateformat, timezone
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from jwt import encode
|
||||
from rest_framework.serializers import Serializer
|
||||
|
@ -25,7 +24,7 @@ from authentik.events.models import Event, EventAction
|
|||
from authentik.events.utils import get_user
|
||||
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
|
||||
from authentik.lib.utils.time import timedelta_string_validator
|
||||
from authentik.providers.oauth2.apps import AuthentikProviderOAuth2Config
|
||||
from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
|
@ -237,14 +236,18 @@ class OAuth2Provider(Provider):
|
|||
)
|
||||
|
||||
def create_refresh_token(
|
||||
self, user: User, scope: list[str], request: HttpRequest
|
||||
self,
|
||||
user: User,
|
||||
scope: list[str],
|
||||
request: HttpRequest,
|
||||
expiry: timedelta,
|
||||
) -> "RefreshToken":
|
||||
"""Create and populate a RefreshToken object."""
|
||||
token = RefreshToken(
|
||||
user=user,
|
||||
provider=self,
|
||||
refresh_token=base64.urlsafe_b64encode(generate_key().encode()).decode(),
|
||||
expires=timezone.now() + timedelta_from_string(self.token_validity),
|
||||
expires=timezone.now() + expiry,
|
||||
scope=scope,
|
||||
)
|
||||
token.access_token = token.create_access_token(user, request)
|
||||
|
@ -484,18 +487,21 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
|
|||
)
|
||||
|
||||
# Convert datetimes into timestamps.
|
||||
now = int(time.time())
|
||||
iat_time = now
|
||||
exp_time = int(dateformat.format(self.expires, "U"))
|
||||
now = datetime.now()
|
||||
iat_time = int(now.timestamp())
|
||||
exp_time = int(self.expires.timestamp())
|
||||
# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
|
||||
auth_events = Event.objects.filter(action=EventAction.LOGIN, user=get_user(user)).order_by(
|
||||
"-created"
|
||||
auth_event = (
|
||||
Event.objects.filter(action=EventAction.LOGIN, user=get_user(user))
|
||||
.order_by("-created")
|
||||
.first()
|
||||
)
|
||||
# Fallback in case we can't find any login events
|
||||
auth_time = datetime.now()
|
||||
if auth_events.exists():
|
||||
auth_time = auth_events.first().created
|
||||
auth_time = int(dateformat.format(auth_time, "U"))
|
||||
auth_time = now
|
||||
if auth_event:
|
||||
auth_time = auth_event.created
|
||||
|
||||
auth_timestamp = int(auth_time.timestamp())
|
||||
|
||||
token = IDToken(
|
||||
iss=self.provider.get_issuer(request),
|
||||
|
@ -503,7 +509,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
|
|||
aud=self.provider.client_id,
|
||||
exp=exp_time,
|
||||
iat=iat_time,
|
||||
auth_time=auth_time,
|
||||
auth_time=auth_timestamp,
|
||||
)
|
||||
|
||||
# Include (or not) user standard claims in the id_token.
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
"""Test authorize view"""
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.flows.challenge import ChallengeTypes
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
|
||||
from authentik.providers.oauth2.models import (
|
||||
AuthorizationCode,
|
||||
|
@ -250,6 +252,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
client_id="test",
|
||||
authorization_flow=flow,
|
||||
redirect_uris="foo://localhost",
|
||||
access_code_validity="seconds=100",
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
|
@ -277,6 +280,11 @@ class TestAuthorize(OAuthTestCase):
|
|||
"to": f"foo://localhost?code={code.code}&state={state}",
|
||||
},
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
code.expires.timestamp() - now().timestamp(),
|
||||
timedelta_from_string(provider.access_code_validity).total_seconds(),
|
||||
delta=5,
|
||||
)
|
||||
|
||||
def test_full_implicit(self):
|
||||
"""Test full authorization"""
|
||||
|
@ -288,6 +296,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
authorization_flow=flow,
|
||||
redirect_uris="http://localhost",
|
||||
signing_key=self.keypair,
|
||||
access_code_validity="seconds=100",
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_id()
|
||||
|
@ -308,6 +317,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
token: RefreshToken = RefreshToken.objects.filter(user=user).first()
|
||||
expires = timedelta_from_string(provider.access_code_validity).total_seconds()
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
|
@ -316,11 +326,16 @@ class TestAuthorize(OAuthTestCase):
|
|||
"to": (
|
||||
f"http://localhost#access_token={token.access_token}"
|
||||
f"&id_token={provider.encode(token.id_token.to_dict())}&token_type=bearer"
|
||||
f"&expires_in=60&state={state}"
|
||||
f"&expires_in={int(expires)}&state={state}"
|
||||
),
|
||||
},
|
||||
)
|
||||
self.validate_jwt(token, provider)
|
||||
jwt = self.validate_jwt(token, provider)
|
||||
self.assertAlmostEqual(
|
||||
jwt["exp"] - now().timestamp(),
|
||||
expires,
|
||||
delta=5,
|
||||
)
|
||||
|
||||
def test_full_form_post_id_token(self):
|
||||
"""Test full authorization (form_post response)"""
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""OAuth test helpers"""
|
||||
from typing import Any
|
||||
|
||||
from django.test import TestCase
|
||||
from jwt import decode
|
||||
|
||||
|
@ -25,7 +27,7 @@ class OAuthTestCase(TestCase):
|
|||
cls.keypair = create_test_cert()
|
||||
super().setUpClass()
|
||||
|
||||
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
|
||||
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
|
||||
"""Validate that all required fields are set"""
|
||||
key, alg = provider.get_jwt_key()
|
||||
if alg != JWTAlgorithms.HS256:
|
||||
|
@ -40,3 +42,4 @@ class OAuthTestCase(TestCase):
|
|||
for key in self.required_jwt_keys:
|
||||
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
|
||||
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")
|
||||
return jwt
|
||||
|
|
|
@ -261,7 +261,7 @@ class OAuthAuthorizationParams:
|
|||
code.code_challenge = self.code_challenge
|
||||
code.code_challenge_method = self.code_challenge_method
|
||||
|
||||
code.expires_at = timezone.now() + timedelta_from_string(self.provider.access_code_validity)
|
||||
code.expires = timezone.now() + timedelta_from_string(self.provider.access_code_validity)
|
||||
code.scope = self.scope
|
||||
code.nonce = self.nonce
|
||||
code.is_open_id = SCOPE_OPENID in self.scope
|
||||
|
@ -525,6 +525,7 @@ class OAuthFulfillmentStage(StageView):
|
|||
user=self.request.user,
|
||||
scope=self.params.scope,
|
||||
request=self.request,
|
||||
expiry=timedelta_from_string(self.provider.access_code_validity),
|
||||
)
|
||||
|
||||
# Check if response_type must include access_token in the response.
|
||||
|
|
|
@ -443,6 +443,7 @@ class TokenView(View):
|
|||
user=self.params.authorization_code.user,
|
||||
scope=self.params.authorization_code.scope,
|
||||
request=self.request,
|
||||
expiry=timedelta_from_string(self.provider.token_validity),
|
||||
)
|
||||
|
||||
if self.params.authorization_code.is_open_id:
|
||||
|
@ -478,6 +479,7 @@ class TokenView(View):
|
|||
user=self.params.refresh_token.user,
|
||||
scope=self.params.scope,
|
||||
request=self.request,
|
||||
expiry=timedelta_from_string(self.provider.token_validity),
|
||||
)
|
||||
|
||||
# If the Token has an id_token it's an Authentication request.
|
||||
|
@ -509,6 +511,7 @@ class TokenView(View):
|
|||
user=self.params.user,
|
||||
scope=self.params.scope,
|
||||
request=self.request,
|
||||
expiry=timedelta_from_string(self.provider.token_validity),
|
||||
)
|
||||
refresh_token.id_token = refresh_token.create_id_token(
|
||||
user=self.params.user,
|
||||
|
@ -535,6 +538,7 @@ class TokenView(View):
|
|||
user=self.params.device_code.user,
|
||||
scope=self.params.device_code.scope,
|
||||
request=self.request,
|
||||
expiry=timedelta_from_string(self.provider.token_validity),
|
||||
)
|
||||
refresh_token.id_token = refresh_token.create_id_token(
|
||||
user=self.params.device_code.user,
|
||||
|
|
Reference in New Issue