providers/oauth2: optimise and cache signing key, prevent key being loaded multiple times

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-12-23 12:04:31 +01:00
parent 6a3a3e5f8d
commit 01da8e1792
No known key found for this signature in database
5 changed files with 11 additions and 9 deletions

View file

@ -4,12 +4,14 @@ import binascii
import json
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from functools import cached_property
from hashlib import sha256
from typing import Any, Optional
from urllib.parse import urlparse, urlunparse
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES
from dacite.core import from_dict
from django.db import models
from django.http import HttpRequest
@ -259,7 +261,8 @@ class OAuth2Provider(Provider):
token.access_token = token.create_access_token(user, request)
return token
def get_jwt_key(self) -> tuple[str, str]:
@cached_property
def jwt_key(self) -> tuple[str | PRIVATE_KEY_TYPES, str]:
"""Get either the configured certificate or the client secret"""
if not self.signing_key:
# No Certificate at all, assume HS256
@ -267,9 +270,9 @@ class OAuth2Provider(Provider):
key: CertificateKeyPair = self.signing_key
private_key = key.private_key
if isinstance(private_key, RSAPrivateKey):
return key.key_data, JWTAlgorithms.RS256
return private_key, JWTAlgorithms.RS256
if isinstance(private_key, EllipticCurvePrivateKey):
return key.key_data, JWTAlgorithms.ES256
return private_key, JWTAlgorithms.ES256
raise Exception(f"Invalid private key type: {type(private_key)}")
def get_issuer(self, request: HttpRequest) -> Optional[str]:
@ -312,10 +315,9 @@ class OAuth2Provider(Provider):
headers = {}
if self.signing_key:
headers["kid"] = self.signing_key.kid
key, alg = self.get_jwt_key()
key, alg = self.jwt_key
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
self.refresh_from_db()
# pyright: reportGeneralTypeIssues=false
return encode(payload, key, algorithm=alg, headers=headers)
class Meta:

View file

@ -143,7 +143,7 @@ class TestTokenClientCredentials(OAuthTestCase):
self.assertEqual(response.status_code, 200)
body = loads(response.content.decode())
self.assertEqual(body["token_type"], "bearer")
_, alg = self.provider.get_jwt_key()
_, alg = self.provider.jwt_key
jwt = decode(
body["access_token"],
key=self.provider.signing_key.public_key,

View file

@ -210,7 +210,7 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
self.assertEqual(response.status_code, 200)
body = loads(response.content.decode())
self.assertEqual(body["token_type"], "bearer")
_, alg = self.provider.get_jwt_key()
_, alg = self.provider.jwt_key
jwt = decode(
body["access_token"],
key=self.provider.signing_key.public_key,

View file

@ -29,7 +29,7 @@ class OAuthTestCase(TestCase):
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate that all required fields are set"""
key, alg = provider.get_jwt_key()
key, alg = provider.jwt_key
if alg != JWTAlgorithms.HS256:
key = provider.signing_key.public_key
jwt = decode(

View file

@ -38,7 +38,7 @@ class ProviderInfoView(View):
)
if SCOPE_OPENID not in scopes:
scopes.append(SCOPE_OPENID)
_, supported_alg = provider.get_jwt_key()
_, supported_alg = provider.jwt_key
return {
"issuer": provider.get_issuer(self.request),
"authorization_endpoint": self.request.build_absolute_uri(