fix tenants tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt 2024-01-12 03:54:33 +01:00
parent 7ce8af6192
commit 0a15cf7452
No known key found for this signature in database
GPG Key ID: 9C3FA22FABF1AA8D
4 changed files with 17 additions and 5 deletions

View File

@ -5,6 +5,7 @@ from hmac import compare_digest
from django.http import HttpResponseNotFound from django.http import HttpResponseNotFound
from django.http.request import urljoin from django.http.request import urljoin
from django.utils.timezone import now from django.utils.timezone import now
from django_tenants.utils import get_public_schema_name
from drf_spectacular.utils import OpenApiResponse, extend_schema from drf_spectacular.utils import OpenApiResponse, extend_schema
from rest_framework import permissions from rest_framework import permissions
from rest_framework.authentication import get_authorization_header from rest_framework.authentication import get_authorization_header
@ -217,3 +218,8 @@ class SettingsView(RetrieveUpdateAPIView):
obj = self.request.tenant obj = self.request.tenant
self.check_object_permissions(self.request, obj) self.check_object_permissions(self.request, obj)
return obj return obj
def perform_update(self, serializer):
# We need to be in the public schema to actually modify a tenant
with Tenant.objects.get(schema_name=get_public_schema_name()):
super().perform_update(serializer)

View File

@ -14,8 +14,8 @@ TENANTS_API_KEY = generate_id()
HEADERS = {"Authorization": f"Bearer {TENANTS_API_KEY}"} HEADERS = {"Authorization": f"Bearer {TENANTS_API_KEY}"}
class TestAPI(TenantAPITestCase): class TestRecovery(TenantAPITestCase):
"""Test api view""" """Test recovery endpoints"""
def setUp(self): def setUp(self):
super().setUp() super().setUp()

View File

@ -1,7 +1,6 @@
"""Test Settings API""" """Test Settings API"""
from django.urls import reverse from django.urls import reverse
from django_tenants.utils import get_public_schema_name
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
@ -13,9 +12,13 @@ HEADERS = {"Authorization": f"Bearer {TENANTS_API_KEY}"}
class TestSettingsAPI(TenantAPITestCase): class TestSettingsAPI(TenantAPITestCase):
"""Test settings API"""
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.tenant_1 = Tenant.objects.get(schema_name=get_public_schema_name()) self.tenant_1 = Tenant.objects.create(
name=generate_id(), schema_name="t_" + generate_id().lower()
)
Domain.objects.create(tenant=self.tenant_1, domain="tenant1.testserver")
with self.tenant_1: with self.tenant_1:
self.admin_1 = create_test_admin_user() self.admin_1 = create_test_admin_user()
self.tenant_2 = Tenant.objects.create( self.tenant_2 = Tenant.objects.create(
@ -37,6 +40,7 @@ class TestSettingsAPI(TenantAPITestCase):
data={ data={
"avatars": "tenant_1_mode", "avatars": "tenant_1_mode",
}, },
HTTP_HOST="tenant1.testserver",
) )
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
with self.tenant_1: with self.tenant_1:

View File

@ -1,13 +1,14 @@
from django.core.management import call_command from django.core.management import call_command
from django.db import connection, connections from django.db import connection, connections
from django_tenants.utils import get_public_schema_name
from rest_framework.test import APITransactionTestCase from rest_framework.test import APITransactionTestCase
class TenantAPITestCase(APITransactionTestCase): class TenantAPITestCase(APITransactionTestCase):
# Overridden to also remove additional schemas we may have created # Overridden to also remove additional schemas we may have created
def _fixture_teardown(self): def _fixture_teardown(self):
super()._fixture_teardown()
for db_name in self._databases_names(include_mirrors=False): for db_name in self._databases_names(include_mirrors=False):
connections[db_name].set_schema_to_public()
with connections[db_name].cursor() as cursor: with connections[db_name].cursor() as cursor:
cursor.execute( cursor.execute(
"SELECT nspname FROM pg_catalog.pg_namespace WHERE nspname ~ 't_.*'" "SELECT nspname FROM pg_catalog.pg_namespace WHERE nspname ~ 't_.*'"
@ -16,6 +17,7 @@ class TenantAPITestCase(APITransactionTestCase):
for row in schemas: for row in schemas:
schema = row[0] schema = row[0]
cursor.execute(f"DROP SCHEMA {schema} CASCADE") cursor.execute(f"DROP SCHEMA {schema} CASCADE")
super()._fixture_teardown()
def setUp(self): def setUp(self):
call_command("migrate_schemas", schema="template", tenant=True) call_command("migrate_schemas", schema="template", tenant=True)