fix tests
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
4a6f956c28
commit
c3e9d23190
|
@ -7,8 +7,6 @@ from django.urls import reverse
|
||||||
from authentik import __version__
|
from authentik import __version__
|
||||||
from authentik.blueprints.tests import reconcile_app
|
from authentik.blueprints.tests import reconcile_app
|
||||||
from authentik.core.models import Group, User
|
from authentik.core.models import Group, User
|
||||||
from authentik.core.tasks import clean_expired_models
|
|
||||||
from authentik.events.monitored_tasks import TaskStatus
|
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,53 +21,6 @@ class TestAdminAPI(TestCase):
|
||||||
self.group.save()
|
self.group.save()
|
||||||
self.client.force_login(self.user)
|
self.client.force_login(self.user)
|
||||||
|
|
||||||
def test_tasks(self):
|
|
||||||
"""Test Task API"""
|
|
||||||
clean_expired_models.delay()
|
|
||||||
response = self.client.get(reverse("authentik_api:admin_system_tasks-list"))
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
body = loads(response.content)
|
|
||||||
self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body))
|
|
||||||
|
|
||||||
def test_tasks_single(self):
|
|
||||||
"""Test Task API (read single)"""
|
|
||||||
clean_expired_models.delay()
|
|
||||||
response = self.client.get(
|
|
||||||
reverse(
|
|
||||||
"authentik_api:admin_system_tasks-detail",
|
|
||||||
kwargs={"pk": "clean_expired_models"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
body = loads(response.content)
|
|
||||||
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.name)
|
|
||||||
self.assertEqual(body["task_name"], "clean_expired_models")
|
|
||||||
response = self.client.get(
|
|
||||||
reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"})
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 404)
|
|
||||||
|
|
||||||
def test_tasks_retry(self):
|
|
||||||
"""Test Task API (retry)"""
|
|
||||||
clean_expired_models.delay()
|
|
||||||
response = self.client.post(
|
|
||||||
reverse(
|
|
||||||
"authentik_api:admin_system_tasks-retry",
|
|
||||||
kwargs={"pk": "clean_expired_models"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 204)
|
|
||||||
|
|
||||||
def test_tasks_retry_404(self):
|
|
||||||
"""Test Task API (retry, 404)"""
|
|
||||||
response = self.client.post(
|
|
||||||
reverse(
|
|
||||||
"authentik_api:admin_system_tasks-retry",
|
|
||||||
kwargs={"pk": "qwerqewrqrqewrqewr"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 404)
|
|
||||||
|
|
||||||
def test_version(self):
|
def test_version(self):
|
||||||
"""Test Version API"""
|
"""Test Version API"""
|
||||||
response = self.client.get(reverse("authentik_api:admin_version"))
|
response = self.client.get(reverse("authentik_api:admin_version"))
|
||||||
|
|
|
@ -1,15 +1,25 @@
|
||||||
"""Test Monitored tasks"""
|
"""Test Monitored tasks"""
|
||||||
from django.test import TestCase
|
from json import loads
|
||||||
|
|
||||||
|
from django.urls import reverse
|
||||||
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
|
from authentik.core.tasks import clean_expired_models
|
||||||
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
from authentik.events.models import SystemTask, TaskStatus
|
from authentik.events.models import SystemTask, TaskStatus
|
||||||
from authentik.events.monitored_tasks import MonitoredTask
|
from authentik.events.monitored_tasks import MonitoredTask
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.root.celery import CELERY_APP
|
from authentik.root.celery import CELERY_APP
|
||||||
|
|
||||||
|
|
||||||
class TestMonitoredTasks(TestCase):
|
class TestSystemTasks(APITestCase):
|
||||||
"""Test Monitored tasks"""
|
"""Test Monitored tasks"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.user = create_test_admin_user()
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
|
||||||
def test_failed_successful_remove_state(self):
|
def test_failed_successful_remove_state(self):
|
||||||
"""Test that a task with `save_on_success` set to `False` that failed saves
|
"""Test that a task with `save_on_success` set to `False` that failed saves
|
||||||
a state, and upon successful completion will delete the state"""
|
a state, and upon successful completion will delete the state"""
|
||||||
|
@ -28,15 +38,64 @@ class TestMonitoredTasks(TestCase):
|
||||||
# First test successful run
|
# First test successful run
|
||||||
should_fail = False
|
should_fail = False
|
||||||
test_task.delay().get()
|
test_task.delay().get()
|
||||||
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid))
|
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid).first())
|
||||||
|
|
||||||
# Then test failed
|
# Then test failed
|
||||||
should_fail = True
|
should_fail = True
|
||||||
test_task.delay().get()
|
test_task.delay().get()
|
||||||
info = SystemTask.objects.filter(name="test_task", uid=uid)
|
task = SystemTask.objects.filter(name="test_task", uid=uid).first()
|
||||||
self.assertEqual(info.status, TaskStatus.ERROR)
|
self.assertEqual(task.status, TaskStatus.ERROR)
|
||||||
|
|
||||||
# Then after that, the state should be removed
|
# Then after that, the state should be removed
|
||||||
should_fail = False
|
should_fail = False
|
||||||
test_task.delay().get()
|
test_task.delay().get()
|
||||||
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid))
|
self.assertIsNone(SystemTask.objects.filter(name="test_task", uid=uid).first())
|
||||||
|
|
||||||
|
def test_tasks(self):
|
||||||
|
"""Test Task API"""
|
||||||
|
clean_expired_models.delay().get()
|
||||||
|
response = self.client.get(reverse("authentik_api:systemtask-list"))
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
body = loads(response.content)
|
||||||
|
self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"]))
|
||||||
|
|
||||||
|
def test_tasks_single(self):
|
||||||
|
"""Test Task API (read single)"""
|
||||||
|
clean_expired_models.delay().get()
|
||||||
|
task = SystemTask.objects.filter(name="clean_expired_models").first()
|
||||||
|
response = self.client.get(
|
||||||
|
reverse(
|
||||||
|
"authentik_api:systemtask-detail",
|
||||||
|
kwargs={"pk": str(task.pk)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
body = loads(response.content)
|
||||||
|
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value)
|
||||||
|
self.assertEqual(body["name"], "clean_expired_models")
|
||||||
|
response = self.client.get(
|
||||||
|
reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"})
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 404)
|
||||||
|
|
||||||
|
def test_tasks_run(self):
|
||||||
|
"""Test Task API (run)"""
|
||||||
|
clean_expired_models.delay().get()
|
||||||
|
task = SystemTask.objects.filter(name="clean_expired_models").first()
|
||||||
|
response = self.client.post(
|
||||||
|
reverse(
|
||||||
|
"authentik_api:systemtask-run",
|
||||||
|
kwargs={"pk": str(task.pk)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 204)
|
||||||
|
|
||||||
|
def test_tasks_run_404(self):
|
||||||
|
"""Test Task API (run, 404)"""
|
||||||
|
response = self.client.post(
|
||||||
|
reverse(
|
||||||
|
"authentik_api:systemtask-run",
|
||||||
|
kwargs={"pk": "qwerqewrqrqewrqewr"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 404)
|
||||||
|
|
|
@ -34,7 +34,7 @@ CACHE_KEY_STATUS = "goauthentik.io/sources/ldap/status/"
|
||||||
def ldap_sync_all():
|
def ldap_sync_all():
|
||||||
"""Sync all sources"""
|
"""Sync all sources"""
|
||||||
for source in LDAPSource.objects.filter(enabled=True):
|
for source in LDAPSource.objects.filter(enabled=True):
|
||||||
ldap_sync_single.apply_async(args=[source.pk])
|
ldap_sync_single.apply_async(args=[str(source.pk)])
|
||||||
|
|
||||||
|
|
||||||
@CELERY_APP.task()
|
@CELERY_APP.task()
|
||||||
|
@ -95,7 +95,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
||||||
for page in sync_inst.get_objects():
|
for page in sync_inst.get_objects():
|
||||||
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
|
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
|
||||||
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
|
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
|
||||||
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
|
page_sync = ldap_sync.si(str(source.pk), class_to_path(sync), page_cache_key)
|
||||||
signatures.append(page_sync)
|
signatures.append(page_sync)
|
||||||
return signatures
|
return signatures
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class LDAPSyncTests(TestCase):
|
||||||
"""Test sync with missing page"""
|
"""Test sync with missing page"""
|
||||||
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
|
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
|
||||||
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
|
||||||
ldap_sync.delay(self.source.pk, class_to_path(UserLDAPSynchronizer), "foo").get()
|
ldap_sync.delay(str(self.source.pk), class_to_path(UserLDAPSynchronizer), "foo").get()
|
||||||
task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first()
|
task = SystemTask.objects.filter(name="ldap_sync", uid="ldap:users:foo").first()
|
||||||
self.assertEqual(task.status, TaskStatus.ERROR)
|
self.assertEqual(task.status, TaskStatus.ERROR)
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ def send_mails(stage: EmailStage, *messages: list[EmailMultiAlternatives]):
|
||||||
"""Wrapper to convert EmailMessage to dict and send it from worker"""
|
"""Wrapper to convert EmailMessage to dict and send it from worker"""
|
||||||
tasks = []
|
tasks = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
tasks.append(send_mail.s(message.__dict__, stage.pk))
|
tasks.append(send_mail.s(message.__dict__, str(stage.pk)))
|
||||||
lazy_group = group(*tasks)
|
lazy_group = group(*tasks)
|
||||||
promise = lazy_group()
|
promise = lazy_group()
|
||||||
return promise
|
return promise
|
||||||
|
@ -46,7 +46,7 @@ def get_email_body(email: EmailMultiAlternatives) -> str:
|
||||||
retry_backoff=True,
|
retry_backoff=True,
|
||||||
base=MonitoredTask,
|
base=MonitoredTask,
|
||||||
)
|
)
|
||||||
def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[int] = None):
|
def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[str] = None):
|
||||||
"""Send Email for Email Stage. Retries are scheduled automatically."""
|
"""Send Email for Email Stage. Retries are scheduled automatically."""
|
||||||
self.save_on_success = False
|
self.save_on_success = False
|
||||||
message_id = make_msgid(domain=DNS_NAME)
|
message_id = make_msgid(domain=DNS_NAME)
|
||||||
|
|
Reference in New Issue