providers/oauth2: add missing kid header to JWT Tokens
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
a265dd54cc
commit
6600da7d98
|
@ -6,11 +6,10 @@ import time
|
|||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional, Type, Union
|
||||
from typing import Any, Optional, Type
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
||||
from dacite import from_dict
|
||||
from django.db import models
|
||||
from django.http import HttpRequest
|
||||
|
@ -238,7 +237,7 @@ class OAuth2Provider(Provider):
|
|||
token.access_token = token.create_access_token(user, request)
|
||||
return token
|
||||
|
||||
def get_jwt_keys(self) -> Union[RSAPrivateKey, str]:
|
||||
def get_jwt_key(self) -> str:
|
||||
"""
|
||||
Takes a provider and returns the set of keys associated with it.
|
||||
Returns a list of keys.
|
||||
|
@ -255,7 +254,7 @@ class OAuth2Provider(Provider):
|
|||
self.jwt_alg = JWTAlgorithms.HS256
|
||||
self.save()
|
||||
else:
|
||||
return self.rsa_key.private_key
|
||||
return self.rsa_key.key_data
|
||||
|
||||
if self.jwt_alg == JWTAlgorithms.HS256:
|
||||
return self.client_secret
|
||||
|
@ -299,11 +298,14 @@ class OAuth2Provider(Provider):
|
|||
|
||||
def encode(self, payload: dict[str, Any]) -> str:
|
||||
"""Represent the ID Token as a JSON Web Token (JWT)."""
|
||||
key = self.get_jwt_keys()
|
||||
headers = {}
|
||||
if self.rsa_key:
|
||||
headers["kid"] = self.rsa_key.kid
|
||||
key = self.get_jwt_key()
|
||||
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
|
||||
self.refresh_from_db()
|
||||
# pyright: reportGeneralTypeIssues=false
|
||||
return encode(payload, key, algorithm=self.jwt_alg)
|
||||
return encode(payload, key, algorithm=self.jwt_alg, headers=headers)
|
||||
|
||||
class Meta:
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ from django.urls import reverse
|
|||
from django.utils.encoding import force_str
|
||||
|
||||
from authentik.core.models import Application, User
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.flows.challenge import ChallengeTypes
|
||||
from authentik.flows.models import Flow
|
||||
from authentik.providers.oauth2.errors import (
|
||||
|
@ -207,6 +208,7 @@ class TestAuthorize(OAuthTestCase):
|
|||
client_secret=generate_client_secret(),
|
||||
authorization_flow=flow,
|
||||
redirect_uris="http://localhost",
|
||||
rsa_key=CertificateKeyPair.objects.first(),
|
||||
)
|
||||
Application.objects.create(name="app", slug="app", provider=provider)
|
||||
state = generate_client_id()
|
||||
|
|
|
@ -2,7 +2,11 @@
|
|||
from django.test import TestCase
|
||||
from jwt import decode
|
||||
|
||||
from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken
|
||||
from authentik.providers.oauth2.models import (
|
||||
JWTAlgorithms,
|
||||
OAuth2Provider,
|
||||
RefreshToken,
|
||||
)
|
||||
|
||||
|
||||
class OAuthTestCase(TestCase):
|
||||
|
@ -19,9 +23,12 @@ class OAuthTestCase(TestCase):
|
|||
|
||||
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
|
||||
"""Validate that all required fields are set"""
|
||||
key = provider.client_secret
|
||||
if provider.jwt_alg == JWTAlgorithms.RS256:
|
||||
key = provider.rsa_key.public_key
|
||||
jwt = decode(
|
||||
token.access_token,
|
||||
provider.client_secret,
|
||||
key,
|
||||
algorithms=[provider.jwt_alg],
|
||||
audience=provider.client_id,
|
||||
)
|
||||
|
|
Reference in a new issue