providers/oauth2: add missing kid header to JWT Tokens

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-05-21 23:40:00 +02:00
parent a265dd54cc
commit 6600da7d98
3 changed files with 19 additions and 8 deletions

View file

@ -6,11 +6,10 @@ import time
from dataclasses import asdict, dataclass, field
from datetime import datetime
from hashlib import sha256
from typing import Any, Optional, Type, Union
from typing import Any, Optional, Type
from urllib.parse import urlparse
from uuid import uuid4
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from dacite import from_dict
from django.db import models
from django.http import HttpRequest
@ -238,7 +237,7 @@ class OAuth2Provider(Provider):
token.access_token = token.create_access_token(user, request)
return token
def get_jwt_keys(self) -> Union[RSAPrivateKey, str]:
def get_jwt_key(self) -> str:
"""
Takes a provider and returns the set of keys associated with it.
Returns a list of keys.
@ -255,7 +254,7 @@ class OAuth2Provider(Provider):
self.jwt_alg = JWTAlgorithms.HS256
self.save()
else:
return self.rsa_key.private_key
return self.rsa_key.key_data
if self.jwt_alg == JWTAlgorithms.HS256:
return self.client_secret
@ -299,11 +298,14 @@ class OAuth2Provider(Provider):
def encode(self, payload: dict[str, Any]) -> str:
"""Represent the ID Token as a JSON Web Token (JWT)."""
key = self.get_jwt_keys()
headers = {}
if self.rsa_key:
headers["kid"] = self.rsa_key.kid
key = self.get_jwt_key()
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
self.refresh_from_db()
# pyright: reportGeneralTypeIssues=false
return encode(payload, key, algorithm=self.jwt_alg)
return encode(payload, key, algorithm=self.jwt_alg, headers=headers)
class Meta:

View file

@ -4,6 +4,7 @@ from django.urls import reverse
from django.utils.encoding import force_str
from authentik.core.models import Application, User
from authentik.crypto.models import CertificateKeyPair
from authentik.flows.challenge import ChallengeTypes
from authentik.flows.models import Flow
from authentik.providers.oauth2.errors import (
@ -207,6 +208,7 @@ class TestAuthorize(OAuthTestCase):
client_secret=generate_client_secret(),
authorization_flow=flow,
redirect_uris="http://localhost",
rsa_key=CertificateKeyPair.objects.first(),
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_client_id()

View file

@ -2,7 +2,11 @@
from django.test import TestCase
from jwt import decode
from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken
from authentik.providers.oauth2.models import (
JWTAlgorithms,
OAuth2Provider,
RefreshToken,
)
class OAuthTestCase(TestCase):
@ -19,9 +23,12 @@ class OAuthTestCase(TestCase):
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
"""Validate that all required fields are set"""
key = provider.client_secret
if provider.jwt_alg == JWTAlgorithms.RS256:
key = provider.rsa_key.public_key
jwt = decode(
token.access_token,
provider.client_secret,
key,
algorithms=[provider.jwt_alg],
audience=provider.client_id,
)