Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt 2023-11-16 15:17:04 +01:00
parent d28b488e91
commit b6efa3bde3
No known key found for this signature in database
GPG key ID: 9C3FA22FABF1AA8D
7 changed files with 16 additions and 3 deletions

View file

@ -69,6 +69,7 @@ class Themes(models.TextChoices):
def get_default_ui_footer_links(): def get_default_ui_footer_links():
"""Get default UI footer links based on current tenant settings"""
return get_current_tenant().footer_links return get_current_tenant().footer_links

View file

@ -117,6 +117,7 @@ def add_process_id(logger: Logger, method_name: str, event_dict):
def add_tenant_information(logger: Logger, method_name: str, event_dict): def add_tenant_information(logger: Logger, method_name: str, event_dict):
"""Add the current tenant"""
tenant = getattr(connection, "tenant", None) tenant = getattr(connection, "tenant", None)
if tenant is not None: if tenant is not None:
event_dict["schema_name"] = tenant.schema_name event_dict["schema_name"] = tenant.schema_name

View file

@ -8,6 +8,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"""database backend which supports rotating credentials""" """database backend which supports rotating credentials"""
def get_connection_params(self): def get_connection_params(self):
"""Refresh DB credentials before getting connection params"""
CONFIG.refresh("postgresql.password") CONFIG.refresh("postgresql.password")
conn_params = super().get_connection_params() conn_params = super().get_connection_params()
conn_params["user"] = CONFIG.get("postgresql.user") conn_params["user"] = CONFIG.get("postgresql.user")

View file

@ -19,6 +19,8 @@ from authentik.tenants.models import Domain, Tenant
class TenantManagementKeyPermission(permissions.BasePermission): class TenantManagementKeyPermission(permissions.BasePermission):
"""Authentication based on tenant_management_key"""
def has_permission(self, request: Request, view: View) -> bool: def has_permission(self, request: Request, view: View) -> bool:
token = validate_auth(get_authorization_header(request)) token = validate_auth(get_authorization_header(request))
tenant_management_key = CONFIG.get("tenant_management_key") tenant_management_key = CONFIG.get("tenant_management_key")
@ -110,5 +112,5 @@ class SettingsView(RetrieveUpdateAPIView):
def get_object(self): def get_object(self):
obj = get_tenant(self.request) obj = get_tenant(self.request)
self.check_object_permissions(obj) self.check_object_permissions(self.request, obj)
return obj return obj

View file

@ -3,6 +3,7 @@ from uuid import uuid4
from django.apps import apps from django.apps import apps
from django.db import models from django.db import models
from django.db.utils import IntegrityError
from django.dispatch import receiver from django.dispatch import receiver
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_tenants.models import DomainMixin, TenantMixin, post_schema_sync from django_tenants.models import DomainMixin, TenantMixin, post_schema_sync
@ -60,12 +61,12 @@ class Tenant(TenantMixin, SerializerModel):
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if self.schema_name == "template": if self.schema_name == "template":
raise Exception("Cannot create schema named template") raise IntegrityError("Cannot create schema named template")
super().save(*args, **kwargs) super().save(*args, **kwargs)
def delete(self, *args, **kwargs): def delete(self, *args, **kwargs):
if self.schema_name in ("public", "template"): if self.schema_name in ("public", "template"):
raise Exception("Cannot delete schema public or template") raise IntegrityError("Cannot delete schema public or template")
super().delete(*args, **kwargs) super().delete(*args, **kwargs)
@property @property
@ -83,6 +84,8 @@ class Tenant(TenantMixin, SerializerModel):
class Domain(DomainMixin, SerializerModel): class Domain(DomainMixin, SerializerModel):
"""Tenant domain"""
def __str__(self) -> str: def __str__(self) -> str:
return f"Domain {self.domain}" return f"Domain {self.domain}"
@ -99,6 +102,7 @@ class Domain(DomainMixin, SerializerModel):
@receiver(post_schema_sync, sender=TenantMixin) @receiver(post_schema_sync, sender=TenantMixin)
def tenant_needs_sync(sender, tenant, **kwargs): def tenant_needs_sync(sender, tenant, **kwargs):
"""Reconcile apps for a specific tenant on creation"""
if tenant.ready: if tenant.ready:
return return

View file

@ -1,9 +1,12 @@
"""Tenant-aware Celery beat scheduler"""
from tenant_schemas_celery.scheduler import ( from tenant_schemas_celery.scheduler import (
TenantAwarePersistentScheduler as BaseTenantAwarePersistentScheduler, TenantAwarePersistentScheduler as BaseTenantAwarePersistentScheduler,
) )
class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler): class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler):
"""Tenant-aware Celery beat scheduler"""
@classmethod @classmethod
def get_queryset(cls): def get_queryset(cls):
return super().get_queryset().filter(ready=True) return super().get_queryset().filter(ready=True)

View file

@ -1,3 +1,4 @@
"""Tenant utils"""
from django.db import connection from django.db import connection
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant