providers/oauth2: fix issues with es256 and add tests (#3808)

fix issues with es256 and add tests

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens L 2022-10-18 22:01:29 +02:00 committed by GitHub
parent bb43c49b1e
commit b85be12567
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 9 deletions

View file

@ -44,9 +44,11 @@ def create_test_tenant() -> Tenant:
return Tenant.objects.create(domain=uid, default=True) return Tenant.objects.create(domain=uid, default=True)
def create_test_cert() -> CertificateKeyPair: def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair:
"""Generate a certificate for testing""" """Generate a certificate for testing"""
builder = CertificateBuilder() builder = CertificateBuilder(
use_ec_private_key=use_ec_private_key,
)
builder.common_name = "goauthentik.io" builder.common_name = "goauthentik.io"
builder.build( builder.build(
subject_alt_names=["goauthentik.io"], subject_alt_names=["goauthentik.io"],

View file

@ -6,7 +6,8 @@ from typing import Optional
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES
from cryptography.x509.oid import NameOID from cryptography.x509.oid import NameOID
from authentik import __version__ from authentik import __version__
@ -18,7 +19,10 @@ class CertificateBuilder:
common_name: str common_name: str
def __init__(self): _use_ec_private_key: bool
def __init__(self, use_ec_private_key=False):
self._use_ec_private_key = use_ec_private_key
self.__public_key = None self.__public_key = None
self.__private_key = None self.__private_key = None
self.__builder = None self.__builder = None
@ -36,6 +40,14 @@ class CertificateBuilder:
self.cert.save() self.cert.save()
return self.cert return self.cert
def generate_private_key(self) -> PRIVATE_KEY_TYPES:
"""Generate private key"""
if self._use_ec_private_key:
return ec.generate_private_key(curve=ec.SECP256R1)
return rsa.generate_private_key(
public_exponent=65537, key_size=4096, backend=default_backend()
)
def build( def build(
self, self,
validity_days: int = 365, validity_days: int = 365,
@ -43,9 +55,7 @@ class CertificateBuilder:
): ):
"""Build self-signed certificate""" """Build self-signed certificate"""
one_day = datetime.timedelta(1, 0, 0) one_day = datetime.timedelta(1, 0, 0)
self.__private_key = rsa.generate_private_key( self.__private_key = self.generate_private_key()
public_exponent=65537, key_size=4096, backend=default_backend()
)
self.__public_key = self.__private_key.public_key() self.__public_key = self.__private_key.public_key()
alt_names: list[x509.GeneralName] = [x509.DNSName(x) for x in subject_alt_names or []] alt_names: list[x509.GeneralName] = [x509.DNSName(x) for x in subject_alt_names or []]
self.__builder = ( self.__builder = (

View file

@ -3,6 +3,7 @@ import json
from django.test import RequestFactory from django.test import RequestFactory
from django.urls.base import reverse from django.urls.base import reverse
from jwt import PyJWKSet
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_cert, create_test_flow from authentik.core.tests.utils import create_test_cert, create_test_flow
@ -32,6 +33,7 @@ class TestJWKS(OAuthTestCase):
) )
body = json.loads(response.content.decode()) body = json.loads(response.content.decode())
self.assertEqual(len(body["keys"]), 1) self.assertEqual(len(body["keys"]), 1)
PyJWKSet.from_dict(body)
def test_hs256(self): def test_hs256(self):
"""Test JWKS request with HS256""" """Test JWKS request with HS256"""
@ -46,3 +48,20 @@ class TestJWKS(OAuthTestCase):
reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug}) reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug})
) )
self.assertJSONEqual(response.content.decode(), {}) self.assertJSONEqual(response.content.decode(), {})
def test_es256(self):
"""Test JWKS request with ES256"""
provider = OAuth2Provider.objects.create(
name="test",
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=create_test_cert(use_ec_private_key=True),
)
app = Application.objects.create(name="test", slug="test", provider=provider)
response = self.client.get(
reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug})
)
body = json.loads(response.content.decode())
self.assertEqual(len(body["keys"]), 1)
PyJWKSet.from_dict(body)

View file

@ -4,6 +4,9 @@ from typing import Optional
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.ec import ( from cryptography.hazmat.primitives.asymmetric.ec import (
SECP256R1,
SECP384R1,
SECP521R1,
EllipticCurvePrivateKey, EllipticCurvePrivateKey,
EllipticCurvePublicKey, EllipticCurvePublicKey,
) )
@ -26,6 +29,15 @@ def b64_enc(number: int) -> str:
return final.decode("ascii") return final.decode("ascii")
# See https://notes.salrahman.com/generate-es256-es384-es512-private-keys/
# and _CURVE_TYPES in the same file as the below curve files
ec_crv_map = {
SECP256R1: "P-256",
SECP384R1: "P-384",
SECP521R1: "P-512",
}
class JWKSView(View): class JWKSView(View):
"""Show RSA Key data for Provider""" """Show RSA Key data for Provider"""
@ -54,8 +66,9 @@ class JWKSView(View):
"alg": JWTAlgorithms.ES256, "alg": JWTAlgorithms.ES256,
"use": "sig", "use": "sig",
"kid": key.kid, "kid": key.kid,
"n": b64_enc(public_numbers.n), "x": b64_enc(public_numbers.x),
"e": b64_enc(public_numbers.e), "y": b64_enc(public_numbers.y),
"crv": ec_crv_map.get(type(public_key.curve), public_key.curve.name),
} }
else: else:
return key_data return key_data