providers/oauth2: optimise and cache signing key, prevent key being loaded multiple times
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
6a3a3e5f8d
commit
01da8e1792
|
@ -4,12 +4,14 @@ import binascii
|
|||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from functools import cached_property
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
||||
from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES
|
||||
from dacite.core import from_dict
|
||||
from django.db import models
|
||||
from django.http import HttpRequest
|
||||
|
@ -259,7 +261,8 @@ class OAuth2Provider(Provider):
|
|||
token.access_token = token.create_access_token(user, request)
|
||||
return token
|
||||
|
||||
def get_jwt_key(self) -> tuple[str, str]:
|
||||
@cached_property
|
||||
def jwt_key(self) -> tuple[str | PRIVATE_KEY_TYPES, str]:
|
||||
"""Get either the configured certificate or the client secret"""
|
||||
if not self.signing_key:
|
||||
# No Certificate at all, assume HS256
|
||||
|
@ -267,9 +270,9 @@ class OAuth2Provider(Provider):
|
|||
key: CertificateKeyPair = self.signing_key
|
||||
private_key = key.private_key
|
||||
if isinstance(private_key, RSAPrivateKey):
|
||||
return key.key_data, JWTAlgorithms.RS256
|
||||
return private_key, JWTAlgorithms.RS256
|
||||
if isinstance(private_key, EllipticCurvePrivateKey):
|
||||
return key.key_data, JWTAlgorithms.ES256
|
||||
return private_key, JWTAlgorithms.ES256
|
||||
raise Exception(f"Invalid private key type: {type(private_key)}")
|
||||
|
||||
def get_issuer(self, request: HttpRequest) -> Optional[str]:
|
||||
|
@ -312,10 +315,9 @@ class OAuth2Provider(Provider):
|
|||
headers = {}
|
||||
if self.signing_key:
|
||||
headers["kid"] = self.signing_key.kid
|
||||
key, alg = self.get_jwt_key()
|
||||
key, alg = self.jwt_key
|
||||
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
|
||||
self.refresh_from_db()
|
||||
# pyright: reportGeneralTypeIssues=false
|
||||
return encode(payload, key, algorithm=alg, headers=headers)
|
||||
|
||||
class Meta:
|
||||
|
|
|
@ -143,7 +143,7 @@ class TestTokenClientCredentials(OAuthTestCase):
|
|||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content.decode())
|
||||
self.assertEqual(body["token_type"], "bearer")
|
||||
_, alg = self.provider.get_jwt_key()
|
||||
_, alg = self.provider.jwt_key
|
||||
jwt = decode(
|
||||
body["access_token"],
|
||||
key=self.provider.signing_key.public_key,
|
||||
|
|
|
@ -210,7 +210,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
|
|||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content.decode())
|
||||
self.assertEqual(body["token_type"], "bearer")
|
||||
_, alg = self.provider.get_jwt_key()
|
||||
_, alg = self.provider.jwt_key
|
||||
jwt = decode(
|
||||
body["access_token"],
|
||||
key=self.provider.signing_key.public_key,
|
||||
|
|
|
@ -29,7 +29,7 @@ class OAuthTestCase(TestCase):
|
|||
|
||||
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
|
||||
"""Validate that all required fields are set"""
|
||||
key, alg = provider.get_jwt_key()
|
||||
key, alg = provider.jwt_key
|
||||
if alg != JWTAlgorithms.HS256:
|
||||
key = provider.signing_key.public_key
|
||||
jwt = decode(
|
||||
|
|
|
@ -38,7 +38,7 @@ class ProviderInfoView(View):
|
|||
)
|
||||
if SCOPE_OPENID not in scopes:
|
||||
scopes.append(SCOPE_OPENID)
|
||||
_, supported_alg = provider.get_jwt_key()
|
||||
_, supported_alg = provider.jwt_key
|
||||
return {
|
||||
"issuer": provider.get_issuer(self.request),
|
||||
"authorization_endpoint": self.request.build_absolute_uri(
|
||||
|
|
Reference in a new issue