diff --git a/authentik/core/management/commands/bootstrap_tasks.py b/authentik/core/management/commands/bootstrap_tasks.py index e57f82258..89258eaf4 100644 --- a/authentik/core/management/commands/bootstrap_tasks.py +++ b/authentik/core/management/commands/bootstrap_tasks.py @@ -1,13 +1,20 @@ """Run bootstrap tasks""" from django.core.management.base import BaseCommand +from django_tenants.utils import get_public_schema_name -from authentik.root.celery import _get_startup_tasks +from authentik.root.celery import _get_startup_tasks_all_tenants, _get_startup_tasks_default_tenant +from authentik.tenants.models import Tenant class Command(BaseCommand): """Run bootstrap tasks to ensure certain objects are created""" def handle(self, **options): - tasks = _get_startup_tasks() - for task in tasks: - task() + for task in _get_startup_tasks_default_tenant(): + with Tenant.objects.get(schema_name=get_public_schema_name()): + task() + + for task in _get_startup_tasks_all_tenants(): + for tenant in Tenant.objects.filter(ready=True): + with tenant: + task() diff --git a/authentik/root/test_runner.py b/authentik/root/test_runner.py index b2bf7a3d7..6e99fde3e 100644 --- a/authentik/root/test_runner.py +++ b/authentik/root/test_runner.py @@ -31,7 +31,6 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover settings.TEST = True settings.CELERY["task_always_eager"] = True - CONFIG.set("avatars", "none") CONFIG.set("geoip", "tests/GeoLite2-City-Test.mmdb") CONFIG.set("blueprints_dir", "./blueprints") CONFIG.set( diff --git a/authentik/tenants/middleware.py b/authentik/tenants/middleware.py index af8d34e22..75143acf9 100644 --- a/authentik/tenants/middleware.py +++ b/authentik/tenants/middleware.py @@ -2,10 +2,9 @@ from typing import Callable from django.http import HttpRequest, HttpResponse +from django_tenants.utils import get_tenant from sentry_sdk.api import set_tag -from authentik.tenants.utils import get_tenant_for_request - class CurrentTenantMiddleware: """Add current tenant to http request""" @@ -17,8 +16,9 @@ class CurrentTenantMiddleware: def __call__(self, request: HttpRequest) -> HttpResponse: if not hasattr(request, "tenant"): - tenant = get_tenant_for_request(request) + tenant = get_tenant(request) setattr(request, "tenant", tenant) - set_tag("authentik.tenant_uuid", tenant.tenant_uuid.hex) - set_tag("authentik.tenant_domain_regex", tenant.domain_regex) + if tenant is not None: + set_tag("authentik.tenant_uuid", tenant.tenant_uuid.hex) + set_tag("authentik.tenant_domain_regex", tenant.domain_regex) return self.get_response(request)