From 9814d3be0305b2e5f3faf482605ed79e1ba0ac2e Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Mon, 11 May 2020 15:01:14 +0200 Subject: [PATCH] flows: add Planner and Executor unittests --- passbook/flows/tests/test_misc.py | 24 ++++ passbook/flows/tests/test_views.py | 146 ++++++++++++++++++++++ passbook/flows/tests/test_views_helper.py | 39 ++++++ passbook/flows/views.py | 22 ++-- passbook/lib/utils/urls.py | 17 ++- passbook/recovery/tests.py | 4 +- 6 files changed, 239 insertions(+), 13 deletions(-) create mode 100644 passbook/flows/tests/test_misc.py create mode 100644 passbook/flows/tests/test_views.py create mode 100644 passbook/flows/tests/test_views_helper.py diff --git a/passbook/flows/tests/test_misc.py b/passbook/flows/tests/test_misc.py new file mode 100644 index 000000000..2fb773ee1 --- /dev/null +++ b/passbook/flows/tests/test_misc.py @@ -0,0 +1,24 @@ +"""miscellaneous flow tests""" +from django.test import TestCase + +from passbook.flows.api import StageSerializer, StageViewSet +from passbook.flows.models import Stage +from passbook.stages.dummy.models import DummyStage + + +class TestFlowsMisc(TestCase): + """miscellaneous tests""" + + def test_models(self): + """Test that ui_user_settings returns none""" + self.assertIsNone(Stage().ui_user_settings) + + def test_api_serializer(self): + """Test that stage serializer returns the correct type""" + obj = DummyStage() + self.assertEqual(StageSerializer().get_type(obj), "dummy") + + def test_api_viewset(self): + """Test that stage serializer returns the correct type""" + dummy = DummyStage.objects.create() + self.assertIn(dummy, StageViewSet().get_queryset()) diff --git a/passbook/flows/tests/test_views.py b/passbook/flows/tests/test_views.py new file mode 100644 index 000000000..081ba0f00 --- /dev/null +++ b/passbook/flows/tests/test_views.py @@ -0,0 +1,146 @@ +"""flow views tests""" +from unittest.mock import MagicMock, patch + +from django.shortcuts import reverse +from django.test import Client, TestCase + +from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException +from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding +from passbook.flows.planner import FlowPlan +from passbook.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN +from passbook.lib.config import CONFIG +from passbook.stages.dummy.models import DummyStage + +POLICY_RESULT_MOCK = MagicMock(return_value=(False, [""],)) + + +class TestFlowExecutor(TestCase): + """Test views logic""" + + def setUp(self): + self.client = Client() + + def test_invalid_domain(self): + """Check that an invalid domain triggers the correct message""" + flow = Flow.objects.create( + name="test-empty", + slug="test-empty", + designation=FlowDesignation.AUTHENTICATION, + ) + wrong_domain = CONFIG.y("domain") + "-invalid:8000" + response = self.client.get( + reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}), + HTTP_HOST=wrong_domain, + ) + self.assertEqual(response.status_code, 400) + self.assertIn("match", response.rendered_content) + self.assertIn(CONFIG.y("domain"), response.rendered_content) + self.assertIn(wrong_domain.split(":")[0], response.rendered_content) + + def test_existing_plan_diff_flow(self): + """Check that a plan for a different flow cancels the current plan""" + flow = Flow.objects.create( + name="test-existing-plan-diff", + slug="test-existing-plan-diff", + designation=FlowDesignation.AUTHENTICATION, + ) + stage = DummyStage.objects.create(name="dummy") + plan = FlowPlan(flow_pk=flow.pk.hex + "a", stages=[stage]) + session = self.client.session + session[SESSION_KEY_PLAN] = plan + session.save() + + cancel_mock = MagicMock() + with patch("passbook.flows.views.FlowExecutorView.cancel", cancel_mock): + response = self.client.get( + reverse( + "passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug} + ), + ) + self.assertEqual(response.status_code, 400) + self.assertEqual(cancel_mock.call_count, 1) + + @patch( + "passbook.flows.planner.FlowPlanner._check_flow_root_policies", + POLICY_RESULT_MOCK, + ) + def test_invalid_non_applicable_flow(self): + """Tests that a non-applicable flow returns the correct error message""" + flow = Flow.objects.create( + name="test-non-applicable", + slug="test-non-applicable", + designation=FlowDesignation.AUTHENTICATION, + ) + + CONFIG.update_from_dict({"domain": "testserver"}) + response = self.client.get( + reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}), + ) + self.assertEqual(response.status_code, 400) + self.assertInHTML(FlowNonApplicableException.__doc__, response.rendered_content) + + def test_invalid_empty_flow(self): + """Tests that an empty flow returns the correct error message""" + flow = Flow.objects.create( + name="test-empty", + slug="test-empty", + designation=FlowDesignation.AUTHENTICATION, + ) + + CONFIG.update_from_dict({"domain": "testserver"}) + response = self.client.get( + reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}), + ) + self.assertEqual(response.status_code, 400) + self.assertInHTML(EmptyFlowException.__doc__, response.rendered_content) + + def test_invalid_flow_redirect(self): + """Tests that an invalid flow still redirects""" + flow = Flow.objects.create( + name="test-empty", + slug="test-empty", + designation=FlowDesignation.AUTHENTICATION, + ) + + CONFIG.update_from_dict({"domain": "testserver"}) + dest = "/unique-string" + response = self.client.get( + reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}) + + f"?{NEXT_ARG_NAME}={dest}" + ) + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url, dest) + + def test_multi_stage_flow(self): + """Test a full flow with multiple stages""" + flow = Flow.objects.create( + name="test-full", + slug="test-full", + designation=FlowDesignation.AUTHENTICATION, + ) + FlowStageBinding.objects.create( + flow=flow, stage=DummyStage.objects.create(name="dummy1"), order=0 + ) + FlowStageBinding.objects.create( + flow=flow, stage=DummyStage.objects.create(name="dummy2"), order=1 + ) + + exec_url = reverse( + "passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug} + ) + # First Request, start planning, renders form + response = self.client.get(exec_url) + self.assertEqual(response.status_code, 200) + # Check that two stages are in plan + session = self.client.session + plan: FlowPlan = session[SESSION_KEY_PLAN] + self.assertEqual(len(plan.stages), 2) + # Second request, submit form, one stage left + response = self.client.post(exec_url) + # Second request redirects to the same URL + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url, exec_url) + # Check that two stages are in plan + session = self.client.session + plan: FlowPlan = session[SESSION_KEY_PLAN] + self.assertEqual(len(plan.stages), 1) diff --git a/passbook/flows/tests/test_views_helper.py b/passbook/flows/tests/test_views_helper.py new file mode 100644 index 000000000..7336cfc07 --- /dev/null +++ b/passbook/flows/tests/test_views_helper.py @@ -0,0 +1,39 @@ +"""flow views tests""" +from django.shortcuts import reverse +from django.test import Client, TestCase + +from passbook.flows.models import Flow, FlowDesignation +from passbook.flows.planner import FlowPlan +from passbook.flows.views import SESSION_KEY_PLAN + + +class TestHelperView(TestCase): + """Test helper views logic""" + + def setUp(self): + self.client = Client() + + def test_default_view(self): + """Test that ToDefaultFlow returns the expected URL""" + flow = Flow.objects.filter(designation=FlowDesignation.INVALIDATION,).first() + response = self.client.get(reverse("passbook_flows:default-invalidation"),) + expected_url = reverse( + "passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug} + ) + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url, expected_url) + + def test_default_view_invalid_plan(self): + """Test that ToDefaultFlow returns the expected URL (with an invalid plan)""" + flow = Flow.objects.filter(designation=FlowDesignation.INVALIDATION,).first() + plan = FlowPlan(flow_pk=flow.pk.hex + "aa", stages=[]) + session = self.client.session + session[SESSION_KEY_PLAN] = plan + session.save() + + response = self.client.get(reverse("passbook_flows:default-invalidation"),) + expected_url = reverse( + "passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug} + ) + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url, expected_url) diff --git a/passbook/flows/views.py b/passbook/flows/views.py index d687c8800..43ca4e609 100644 --- a/passbook/flows/views.py +++ b/passbook/flows/views.py @@ -12,7 +12,7 @@ from passbook.flows.models import Flow, FlowDesignation, Stage from passbook.flows.planner import FlowPlan, FlowPlanner from passbook.lib.config import CONFIG from passbook.lib.utils.reflection import class_to_path, path_to_class -from passbook.lib.utils.urls import is_url_absolute, redirect_with_qs +from passbook.lib.utils.urls import redirect_with_qs from passbook.lib.views import bad_request_message LOGGER = get_logger() @@ -59,7 +59,8 @@ class FlowExecutorView(View): incorrect_domain_message = self._check_config_domain() if incorrect_domain_message: return incorrect_domain_message - return bad_request_message(self.request, str(exc)) + message = exc.__doc__ if exc.__doc__ else str(exc) + return bad_request_message(self.request, message) def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse: # Early check if theres an active Plan for the current session @@ -128,10 +129,8 @@ class FlowExecutorView(View): def _flow_done(self) -> HttpResponse: """User Successfully passed all stages""" self.cancel() - next_param = self.request.GET.get(NEXT_ARG_NAME, None) - if next_param and not is_url_absolute(next_param): - return redirect(next_param) - return redirect_with_qs("passbook_core:overview") + next_param = self.request.GET.get(NEXT_ARG_NAME, "passbook_core:overview") + return redirect_with_qs(next_param) def stage_ok(self) -> HttpResponse: """Callback called by stages upon successful completion. @@ -183,9 +182,16 @@ class ToDefaultFlow(View): designation: Optional[FlowDesignation] = None def dispatch(self, request: HttpRequest) -> HttpResponse: - if SESSION_KEY_PLAN in self.request.session: - del self.request.session[SESSION_KEY_PLAN] flow = get_object_or_404(Flow, designation=self.designation) + # If user already has a pending plan, clear it so we don't have to later. + if SESSION_KEY_PLAN in self.request.session: + plan: FlowPlan = self.request.session[SESSION_KEY_PLAN] + if plan.flow_pk != flow.pk.hex: + LOGGER.warning( + "f(def): Found existing plan for other flow, deleteing plan", + flow_slug=flow.slug, + ) + del self.request.session[SESSION_KEY_PLAN] # TODO: Get Flow depending on subdomain? return redirect_with_qs( "passbook_flows:flow-executor", request.GET, flow_slug=flow.slug diff --git a/passbook/lib/utils/urls.py b/passbook/lib/utils/urls.py index 30cb25603..29f0e9eaf 100644 --- a/passbook/lib/utils/urls.py +++ b/passbook/lib/utils/urls.py @@ -3,7 +3,11 @@ from urllib.parse import urlparse from django.http import HttpResponse from django.shortcuts import redirect, reverse +from django.urls import NoReverseMatch from django.utils.http import urlencode +from structlog import get_logger + +LOGGER = get_logger() def is_url_absolute(url): @@ -13,7 +17,12 @@ def is_url_absolute(url): def redirect_with_qs(view: str, get_query_set=None, **kwargs) -> HttpResponse: """Wrapper to redirect whilst keeping GET Parameters""" - target = reverse(view, kwargs=kwargs) - if get_query_set: - target += "?" + urlencode(get_query_set.items()) - return redirect(target) + try: + target = reverse(view, kwargs=kwargs) + except NoReverseMatch: + LOGGER.debug("redirect target is not a valid view", view=view) + raise + else: + if get_query_set: + target += "?" + urlencode(get_query_set.items()) + return redirect(target) diff --git a/passbook/recovery/tests.py b/passbook/recovery/tests.py index bb2c19b68..3080b176e 100644 --- a/passbook/recovery/tests.py +++ b/passbook/recovery/tests.py @@ -6,6 +6,7 @@ from django.shortcuts import reverse from django.test import TestCase from passbook.core.models import Nonce, User +from passbook.lib.config import CONFIG class TestRecovery(TestCase): @@ -16,10 +17,11 @@ class TestRecovery(TestCase): def test_create_key(self): """Test creation of a new key""" + CONFIG.update_from_dict({"domain": "testserver"}) out = StringIO() self.assertEqual(len(Nonce.objects.all()), 0) call_command("create_recovery_key", "1", self.user.username, stdout=out) - self.assertIn("https://localhost/recovery/use-nonce/", out.getvalue()) + self.assertIn("https://testserver/recovery/use-nonce/", out.getvalue()) self.assertEqual(len(Nonce.objects.all()), 1) def test_recovery_view(self):