add tenant api tests
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
parent
5c56dab82f
commit
cdfbb48cb6
|
@ -1,7 +1,7 @@
|
|||
"""Serializer for tenants models"""
|
||||
from hmac import compare_digest
|
||||
|
||||
from django.http import Http404
|
||||
from django.http import HttpResponseNotFound
|
||||
from django_tenants.utils import get_tenant
|
||||
from rest_framework import permissions
|
||||
from rest_framework.authentication import get_authorization_header
|
||||
|
@ -19,13 +19,15 @@ from authentik.lib.config import CONFIG
|
|||
from authentik.tenants.models import Domain, Tenant
|
||||
|
||||
|
||||
class TenantManagementKeyPermission(permissions.BasePermission):
|
||||
"""Authentication based on tenant_management_key"""
|
||||
class TenantApiKeyPermission(permissions.BasePermission):
|
||||
"""Authentication based on tenants.api_key"""
|
||||
|
||||
def has_permission(self, request: Request, view: View) -> bool:
|
||||
key = CONFIG.get("tenants.api_key", "")
|
||||
if not key:
|
||||
return False
|
||||
token = validate_auth(get_authorization_header(request))
|
||||
key = CONFIG.get("tenants.api_key")
|
||||
if compare_digest("", key):
|
||||
if token is None:
|
||||
return False
|
||||
return compare_digest(token, key)
|
||||
|
||||
|
@ -53,12 +55,13 @@ class TenantViewSet(ModelViewSet):
|
|||
"domains__domain",
|
||||
]
|
||||
ordering = ["schema_name"]
|
||||
permission_classes = [TenantManagementKeyPermission]
|
||||
authentication_classes = []
|
||||
permission_classes = [TenantApiKeyPermission]
|
||||
filter_backends = [OrderingFilter, SearchFilter]
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
if not CONFIG.get_bool("tenants.enabled", True):
|
||||
return Http404()
|
||||
return HttpResponseNotFound()
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
|
@ -81,9 +84,15 @@ class DomainViewSet(ModelViewSet):
|
|||
"tenant__schema_name",
|
||||
]
|
||||
ordering = ["domain"]
|
||||
permission_classes = [TenantManagementKeyPermission]
|
||||
authentication_classes = []
|
||||
permission_classes = [TenantApiKeyPermission]
|
||||
filter_backends = [OrderingFilter, SearchFilter]
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
if not CONFIG.get_bool("tenants.enabled", True):
|
||||
return HttpResponseNotFound()
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
class SettingsSerializer(ModelSerializer):
|
||||
"""Settings Serializer"""
|
||||
|
|
|
@ -6,6 +6,7 @@ import django.db.models.deletion
|
|||
import django_tenants.postgresql_backend.base
|
||||
from django.db import migrations, models
|
||||
|
||||
import authentik.tenants.models
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
|
||||
|
@ -42,7 +43,7 @@ class Migration(migrations.Migration):
|
|||
db_index=True,
|
||||
max_length=63,
|
||||
unique=True,
|
||||
validators=[django_tenants.postgresql_backend.base._check_schema_name],
|
||||
validators=[authentik.tenants.models._validate_schema_name],
|
||||
),
|
||||
),
|
||||
(
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
"""Tenant models"""
|
||||
import re
|
||||
from uuid import uuid4
|
||||
|
||||
from django.apps import apps
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.db.utils import IntegrityError
|
||||
from django.dispatch import receiver
|
||||
|
@ -16,10 +18,25 @@ from authentik.lib.models import SerializerModel
|
|||
LOGGER = get_logger()
|
||||
|
||||
|
||||
VALID_SCHEMA_NAME = re.compile(r"^t_[a-z0-9]{1,61}$")
|
||||
|
||||
|
||||
def _validate_schema_name(name):
|
||||
if not VALID_SCHEMA_NAME.match(name):
|
||||
raise ValidationError(
|
||||
_(
|
||||
"Schema name must start with t_, only contain lowercase letters and numbers and be less than 63 characters."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Tenant(TenantMixin, SerializerModel):
|
||||
"""Tenant"""
|
||||
|
||||
tenant_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
schema_name = models.CharField(
|
||||
max_length=63, unique=True, db_index=True, validators=[_validate_schema_name]
|
||||
)
|
||||
name = models.TextField()
|
||||
|
||||
auto_create_schema = True
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
"""Test Tenant API"""
|
||||
from json import loads
|
||||
|
||||
from django.core.management import call_command
|
||||
from django.db import connection
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APILiveServerTestCase, APITestCase, APITransactionTestCase
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
TENANTS_API_KEY = generate_id()
|
||||
HEADERS = {"Authorization": f"Bearer {TENANTS_API_KEY}"}
|
||||
|
||||
|
||||
class TestAPI(APITransactionTestCase):
|
||||
"""Test api view"""
|
||||
|
||||
def _fixture_teardown(self):
|
||||
for db_name in self._databases_names(include_mirrors=False):
|
||||
call_command(
|
||||
"flush",
|
||||
verbosity=0,
|
||||
interactive=False,
|
||||
database=db_name,
|
||||
reset_sequences=False,
|
||||
allow_cascade=True,
|
||||
inhibit_post_migrate=False,
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
call_command("migrate_schemas", schema="template", tenant=True)
|
||||
|
||||
def assertSchemaExists(self, schema_name):
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
f"SELECT * FROM information_schema.schemata WHERE schema_name = '{schema_name}';"
|
||||
)
|
||||
self.assertEqual(cursor.rowcount, 1)
|
||||
|
||||
cursor.execute(
|
||||
"SELECT * FROM information_schema.tables WHERE table_schema = 'template';"
|
||||
)
|
||||
expected_tables = cursor.rowcount
|
||||
cursor.execute(
|
||||
f"SELECT * FROM information_schema.tables WHERE table_schema = '{schema_name}';"
|
||||
)
|
||||
self.assertEqual(cursor.rowcount, expected_tables)
|
||||
|
||||
def assertSchemaDoesntExist(self, schema_name):
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
f"SELECT * FROM information_schema.schemata WHERE schema_name = '{schema_name}';"
|
||||
)
|
||||
self.assertEqual(cursor.rowcount, 0)
|
||||
|
||||
@CONFIG.patch("outposts.disable_embedded_outpost", True)
|
||||
@CONFIG.patch("tenants.enabled", True)
|
||||
@CONFIG.patch("tenants.api_key", TENANTS_API_KEY)
|
||||
def test_tenant_create_delete(self):
|
||||
"""Test Tenant creation API Endpoint"""
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:tenant-list",
|
||||
),
|
||||
data={"name": generate_id(), "schema_name": "t_" + generate_id(length=63 - 2).lower()},
|
||||
headers=HEADERS,
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
body = loads(response.content.decode())
|
||||
|
||||
self.assertSchemaExists(body["schema_name"])
|
||||
|
||||
response = self.client.delete(
|
||||
reverse(
|
||||
"authentik_api:tenant-detail",
|
||||
kwargs={"pk": body["tenant_uuid"]},
|
||||
),
|
||||
headers=HEADERS,
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
self.assertSchemaDoesntExist(body["schema_name"])
|
||||
|
||||
@CONFIG.patch("outposts.disable_embedded_outpost", True)
|
||||
@CONFIG.patch("tenants.enabled", True)
|
||||
@CONFIG.patch("tenants.api_key", TENANTS_API_KEY)
|
||||
def test_unauthenticated(self):
|
||||
"""Test Tenant creation API Endpoint"""
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:tenant-list",
|
||||
),
|
||||
)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
@CONFIG.patch("outposts.disable_embedded_outpost", True)
|
||||
@CONFIG.patch("tenants.enabled", True)
|
||||
@CONFIG.patch("tenants.api_key", "")
|
||||
def test_no_api_key_configured(self):
|
||||
"""Test Tenant creation API Endpoint"""
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:tenant-list",
|
||||
),
|
||||
)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
|
||||
@CONFIG.patch("tenants.enabled", False)
|
||||
@CONFIG.patch("tenants.api_key", TENANTS_API_KEY)
|
||||
def test_api_disabled(self):
|
||||
"""Test Tenant creation API Endpoint"""
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:tenant-list",
|
||||
),
|
||||
headers=HEADERS,
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
|
@ -13,7 +13,7 @@ COMMIT;
|
|||
class Migration(BaseMigration):
|
||||
def needs_migration(self) -> bool:
|
||||
self.cur.execute(
|
||||
"select * from information_schema.tables where table_name =" " 'django_migrations';"
|
||||
"select * from information_schema.tables where table_name = 'django_migrations';"
|
||||
)
|
||||
# No migration table, assume new installation
|
||||
if not bool(self.cur.rowcount):
|
||||
|
|
Reference in New Issue