diff --git a/authentik/core/tests/utils.py b/authentik/core/tests/utils.py index 59b72e6cd..3da47f87b 100644 --- a/authentik/core/tests/utils.py +++ b/authentik/core/tests/utils.py @@ -44,9 +44,11 @@ def create_test_tenant() -> Tenant: return Tenant.objects.create(domain=uid, default=True) -def create_test_cert() -> CertificateKeyPair: +def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair: """Generate a certificate for testing""" - builder = CertificateBuilder() + builder = CertificateBuilder( + use_ec_private_key=use_ec_private_key, + ) builder.common_name = "goauthentik.io" builder.build( subject_alt_names=["goauthentik.io"], diff --git a/authentik/crypto/builder.py b/authentik/crypto/builder.py index 5542e6e5c..98d391412 100644 --- a/authentik/crypto/builder.py +++ b/authentik/crypto/builder.py @@ -6,7 +6,8 @@ from typing import Optional from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric import ec, rsa +from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES from cryptography.x509.oid import NameOID from authentik import __version__ @@ -18,7 +19,10 @@ class CertificateBuilder: common_name: str - def __init__(self): + _use_ec_private_key: bool + + def __init__(self, use_ec_private_key=False): + self._use_ec_private_key = use_ec_private_key self.__public_key = None self.__private_key = None self.__builder = None @@ -36,6 +40,14 @@ class CertificateBuilder: self.cert.save() return self.cert + def generate_private_key(self) -> PRIVATE_KEY_TYPES: + """Generate private key""" + if self._use_ec_private_key: + return ec.generate_private_key(curve=ec.SECP256R1) + return rsa.generate_private_key( + public_exponent=65537, key_size=4096, backend=default_backend() + ) + def build( self, validity_days: int = 365, @@ -43,9 +55,7 @@ class CertificateBuilder: ): """Build self-signed certificate""" one_day = datetime.timedelta(1, 0, 0) - self.__private_key = rsa.generate_private_key( - public_exponent=65537, key_size=4096, backend=default_backend() - ) + self.__private_key = self.generate_private_key() self.__public_key = self.__private_key.public_key() alt_names: list[x509.GeneralName] = [x509.DNSName(x) for x in subject_alt_names or []] self.__builder = ( diff --git a/authentik/providers/oauth2/tests/test_jwks.py b/authentik/providers/oauth2/tests/test_jwks.py index 209a0215b..d42124167 100644 --- a/authentik/providers/oauth2/tests/test_jwks.py +++ b/authentik/providers/oauth2/tests/test_jwks.py @@ -3,6 +3,7 @@ import json from django.test import RequestFactory from django.urls.base import reverse +from jwt import PyJWKSet from authentik.core.models import Application from authentik.core.tests.utils import create_test_cert, create_test_flow @@ -32,6 +33,7 @@ class TestJWKS(OAuthTestCase): ) body = json.loads(response.content.decode()) self.assertEqual(len(body["keys"]), 1) + PyJWKSet.from_dict(body) def test_hs256(self): """Test JWKS request with HS256""" @@ -46,3 +48,20 @@ class TestJWKS(OAuthTestCase): reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug}) ) self.assertJSONEqual(response.content.decode(), {}) + + def test_es256(self): + """Test JWKS request with ES256""" + provider = OAuth2Provider.objects.create( + name="test", + client_id="test", + authorization_flow=create_test_flow(), + redirect_uris="http://local.invalid", + signing_key=create_test_cert(use_ec_private_key=True), + ) + app = Application.objects.create(name="test", slug="test", provider=provider) + response = self.client.get( + reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug}) + ) + body = json.loads(response.content.decode()) + self.assertEqual(len(body["keys"]), 1) + PyJWKSet.from_dict(body) diff --git a/authentik/providers/oauth2/views/jwks.py b/authentik/providers/oauth2/views/jwks.py index e996bc8fe..ee9a57ef7 100644 --- a/authentik/providers/oauth2/views/jwks.py +++ b/authentik/providers/oauth2/views/jwks.py @@ -4,6 +4,9 @@ from typing import Optional from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric.ec import ( + SECP256R1, + SECP384R1, + SECP521R1, EllipticCurvePrivateKey, EllipticCurvePublicKey, ) @@ -26,6 +29,15 @@ def b64_enc(number: int) -> str: return final.decode("ascii") +# See https://notes.salrahman.com/generate-es256-es384-es512-private-keys/ +# and _CURVE_TYPES in the same file as the below curve files +ec_crv_map = { + SECP256R1: "P-256", + SECP384R1: "P-384", + SECP521R1: "P-512", +} + + class JWKSView(View): """Show RSA Key data for Provider""" @@ -54,8 +66,9 @@ class JWKSView(View): "alg": JWTAlgorithms.ES256, "use": "sig", "kid": key.kid, - "n": b64_enc(public_numbers.n), - "e": b64_enc(public_numbers.e), + "x": b64_enc(public_numbers.x), + "y": b64_enc(public_numbers.y), + "crv": ec_crv_map.get(type(public_key.curve), public_key.curve.name), } else: return key_data