fix tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer 2024-01-14 14:43:30 +01:00
parent 4a6f956c28
commit c3e9d23190
No known key found for this signature in database
5 changed files with 70 additions and 60 deletions

View File

@ -7,8 +7,6 @@ from django.urls import reverse
from authentik import __version__ from authentik import __version__
from authentik.blueprints.tests import reconcile_app from authentik.blueprints.tests import reconcile_app
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.core.tasks import clean_expired_models
from authentik.events.monitored_tasks import TaskStatus
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
@ -23,53 +21,6 @@ class TestAdminAPI(TestCase):
self.group.save() self.group.save()
self.client.force_login(self.user) self.client.force_login(self.user)
def test_tasks(self):
"""Test Task API"""
clean_expired_models.delay()
response = self.client.get(reverse("authentik_api:admin_system_tasks-list"))
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body))
def test_tasks_single(self):
"""Test Task API (read single)"""
clean_expired_models.delay()
response = self.client.get(
reverse(
"authentik_api:admin_system_tasks-detail",
kwargs={"pk": "clean_expired_models"},
)
)
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.name)
self.assertEqual(body["task_name"], "clean_expired_models")
response = self.client.get(
reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"})
)
self.assertEqual(response.status_code, 404)
def test_tasks_retry(self):
"""Test Task API (retry)"""
clean_expired_models.delay()
response = self.client.post(
reverse(
"authentik_api:admin_system_tasks-retry",
kwargs={"pk": "clean_expired_models"},
)
)
self.assertEqual(response.status_code, 204)
def test_tasks_retry_404(self):
"""Test Task API (retry, 404)"""
response = self.client.post(
reverse(
"authentik_api:admin_system_tasks-retry",
kwargs={"pk": "qwerqewrqrqewrqewr"},
)
)
self.assertEqual(response.status_code, 404)
def test_version(self): def test_version(self):
"""Test Version API""" """Test Version API"""
response = self.client.get(reverse("authentik_api:admin_version")) response = self.client.get(reverse("authentik_api:admin_version"))

View File

@ -1,15 +1,25 @@
"""Test Monitored tasks""" """Test Monitored tasks"""
from django.test import TestCase from json import loads
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tasks import clean_expired_models
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import SystemTask, TaskStatus from authentik.events.models import SystemTask, TaskStatus
from authentik.events.monitored_tasks import MonitoredTask from authentik.events.monitored_tasks import MonitoredTask
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
class TestMonitoredTasks(TestCase): class TestSystemTasks(APITestCase):
"""Test Monitored tasks""" """Test Monitored tasks"""
def setUp(self):
super().setUp()
self.user = create_test_admin_user()
self.client.force_login(self.user)
def test_failed_successful_remove_state(self): def test_failed_successful_remove_state(self):
"""Test that a task with `save_on_success` set to `False` that failed saves """Test that a task with `save_on_success` set to `False` that failed saves
a state, and upon successful completion will delete the state""" a state, and upon successful completion will delete the state"""
@ -28,15 +38,64 @@ class TestMonitoredTasks(TestCase):
# First test successful run # First test successful run
should_fail = False should_fail = False
test_task.delay().get() test_task.delay().get()
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid)) self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid).first())
# Then test failed # Then test failed
should_fail = True should_fail = True
test_task.delay().get() test_task.delay().get()
info = SystemTask.objects.filter(name="test_task", uid=uid) task = SystemTask.objects.filter(name="test_task", uid=uid).first()
self.assertEqual(info.status, TaskStatus.ERROR) self.assertEqual(task.status, TaskStatus.ERROR)
# Then after that, the state should be removed # Then after that, the state should be removed
should_fail = False should_fail = False
test_task.delay().get() test_task.delay().get()
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid)) self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid).first())
def test_tasks(self):
"""Test Task API"""
clean_expired_models.delay().get()
response = self.client.get(reverse("authentik_api:systemtask-list"))
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"]))
def test_tasks_single(self):
"""Test Task API (read single)"""
clean_expired_models.delay().get()
task = SystemTask.objects.filter(name="clean_expired_models").first()
response = self.client.get(
reverse(
"authentik_api:systemtask-detail",
kwargs={"pk": str(task.pk)},
)
)
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value)
self.assertEqual(body["name"], "clean_expired_models")
response = self.client.get(
reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"})
)
self.assertEqual(response.status_code, 404)
def test_tasks_run(self):
"""Test Task API (run)"""
clean_expired_models.delay().get()
task = SystemTask.objects.filter(name="clean_expired_models").first()
response = self.client.post(
reverse(
"authentik_api:systemtask-run",
kwargs={"pk": str(task.pk)},
)
)
self.assertEqual(response.status_code, 204)
def test_tasks_run_404(self):
"""Test Task API (run, 404)"""
response = self.client.post(
reverse(
"authentik_api:systemtask-run",
kwargs={"pk": "qwerqewrqrqewrqewr"},
)
)
self.assertEqual(response.status_code, 404)

View File

@ -34,7 +34,7 @@ CACHE_KEY_STATUS = "goauthentik.io/sources/ldap/status/"
def ldap_sync_all(): def ldap_sync_all():
"""Sync all sources""" """Sync all sources"""
for source in LDAPSource.objects.filter(enabled=True): for source in LDAPSource.objects.filter(enabled=True):
ldap_sync_single.apply_async(args=[source.pk]) ldap_sync_single.apply_async(args=[str(source.pk)])
@CELERY_APP.task() @CELERY_APP.task()
@ -95,7 +95,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
for page in sync_inst.get_objects(): for page in sync_inst.get_objects():
page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) page_sync = ldap_sync.si(str(source.pk), class_to_path(sync), page_cache_key)
signatures.append(page_sync) signatures.append(page_sync)
return signatures return signatures

View File

@ -40,7 +40,7 @@ class LDAPSyncTests(TestCase):
"""Test sync with missing page""" """Test sync with missing page"""
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
ldap_sync.delay(self.source.pk, class_to_path(UserLDAPSynchronizer), "foo").get() ldap_sync.delay(str(self.source.pk), class_to_path(UserLDAPSynchronizer), "foo").get()
task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first() task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first()
self.assertEqual(task.status, TaskStatus.ERROR) self.assertEqual(task.status, TaskStatus.ERROR)

View File

@ -22,7 +22,7 @@ def send_mails(stage: EmailStage, *messages: list[EmailMultiAlternatives]):
"""Wrapper to convert EmailMessage to dict and send it from worker""" """Wrapper to convert EmailMessage to dict and send it from worker"""
tasks = [] tasks = []
for message in messages: for message in messages:
tasks.append(send_mail.s(message.__dict__, stage.pk)) tasks.append(send_mail.s(message.__dict__, str(stage.pk)))
lazy_group = group(*tasks) lazy_group = group(*tasks)
promise = lazy_group() promise = lazy_group()
return promise return promise
@ -46,7 +46,7 @@ def get_email_body(email: EmailMultiAlternatives) -> str:
retry_backoff=True, retry_backoff=True,
base=MonitoredTask, base=MonitoredTask,
) )
def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[int] = None): def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[str] = None):
"""Send Email for Email Stage. Retries are scheduled automatically.""" """Send Email for Email Stage. Retries are scheduled automatically."""
self.save_on_success = False self.save_on_success = False
message_id = make_msgid(domain=DNS_NAME) message_id = make_msgid(domain=DNS_NAME)