get enterprise token
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
938f6fe439
commit
cf93445b3f
|
@ -88,7 +88,7 @@ class LicenseKey:
|
|||
@staticmethod
|
||||
def get_total() -> "LicenseKey":
|
||||
"""Get a summarized version of all (not expired) licenses"""
|
||||
active_licenses = License.objects.filter(expiry__gte=now())
|
||||
active_licenses = License.non_expired()
|
||||
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
|
||||
for lic in active_licenses:
|
||||
total.internal_users += lic.internal_users
|
||||
|
@ -167,6 +167,10 @@ class License(SerializerModel):
|
|||
internal_users = models.BigIntegerField()
|
||||
external_users = models.BigIntegerField()
|
||||
|
||||
@classmethod
|
||||
def non_expired(cls) -> QuerySet["License"]:
|
||||
return License.objects.filter(expiry__gte=now())
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[BaseSerializer]:
|
||||
from authentik.enterprise.api import LicenseSerializer
|
||||
|
|
|
@ -8,6 +8,8 @@ from grpc import (
|
|||
UnaryUnaryClientInterceptor,
|
||||
insecure_channel,
|
||||
intercept_channel,
|
||||
ssl_channel_credentials,
|
||||
secure_channel,
|
||||
)
|
||||
from grpc._interceptor import _ClientCallDetails
|
||||
|
||||
|
@ -48,12 +50,28 @@ class AuthInterceptor(UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor)
|
|||
return continuation(self._intercept_client_call_details(client_call_details), request)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_enterprise_token() -> str:
|
||||
"""Get enterprise license key, if a license is installed, otherwise use the install ID"""
|
||||
from authentik.root.install_id import get_install_id
|
||||
|
||||
try:
|
||||
from authentik.enterprise.models import License
|
||||
|
||||
license = License.non_expired().order_by("-expiry").first()
|
||||
if not license:
|
||||
return get_install_id()
|
||||
return license.key
|
||||
except ImportError:
|
||||
return get_install_id()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_client(addr: str):
|
||||
"""get a cached client to a cloud-gateway"""
|
||||
target = addr
|
||||
channel = secure_channel(addr, ssl_channel_credentials)
|
||||
if settings.DEBUG:
|
||||
target = insecure_channel(target)
|
||||
channel = intercept_channel(target, AuthInterceptor("foo"))
|
||||
channel = insecure_channel(addr)
|
||||
channel = intercept_channel(addr, AuthInterceptor(get_enterprise_token()))
|
||||
client = AuthenticationPushStub(channel)
|
||||
return client
|
||||
|
|
Reference in New Issue