diff --git a/passbook/flows/tests/test_planner.py b/passbook/flows/tests/test_planner.py index 25f2bd1a7..df0dc5a7e 100644 --- a/passbook/flows/tests/test_planner.py +++ b/passbook/flows/tests/test_planner.py @@ -1,13 +1,15 @@ """flow planner tests""" from unittest.mock import MagicMock, patch +from django.core.cache import cache from django.shortcuts import reverse from django.test import RequestFactory, TestCase from guardian.shortcuts import get_anonymous_user +from passbook.core.models import User from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding -from passbook.flows.planner import FlowPlanner +from passbook.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key from passbook.policies.types import PolicyResult from passbook.stages.dummy.models import DummyStage @@ -81,3 +83,24 @@ class TestFlowPlanner(TestCase): self.assertEqual( TIME_NOW_MOCK.call_count, 2 ) # When taking from cache, time is not measured + + def test_planner_default_context(self): + """Test planner with default_context""" + flow = Flow.objects.create( + name="test-default-context", + slug="test-default-context", + designation=FlowDesignation.AUTHENTICATION, + ) + FlowStageBinding.objects.create( + flow=flow, stage=DummyStage.objects.create(name="dummy"), order=0 + ) + + user = User.objects.create(username="test-user") + request = self.request_factory.get( + reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}), + ) + request.user = user + planner = FlowPlanner(flow) + planner.plan(request, default_context={PLAN_CONTEXT_PENDING_USER: user}) + key = cache_key(flow, user) + self.assertTrue(cache.get(key) is not None) diff --git a/passbook/flows/views.py b/passbook/flows/views.py index 37380a90c..f955106db 100644 --- a/passbook/flows/views.py +++ b/passbook/flows/views.py @@ -34,7 +34,7 @@ class FlowExecutorView(View): def setup(self, request: HttpRequest, flow_slug: str): super().setup(request, flow_slug=flow_slug) - self.flow = get_object_or_404(Flow, slug=flow_slug) + self.flow = get_object_or_404(Flow.objects.select_related(), slug=flow_slug) def handle_invalid_flow(self, exc: BaseException) -> HttpResponse: """When a flow is non-applicable check if user is on the correct domain""" diff --git a/passbook/stages/identification/api.py b/passbook/stages/identification/api.py index bd56a0a61..f40c329e3 100644 --- a/passbook/stages/identification/api.py +++ b/passbook/stages/identification/api.py @@ -16,6 +16,8 @@ class IdentificationStageSerializer(ModelSerializer): "name", "user_fields", "template", + "enrollment_flow", + "recovery_flow", ] diff --git a/passbook/stages/identification/migrations/0002_auto_20200530_2204.py b/passbook/stages/identification/migrations/0002_auto_20200530_2204.py new file mode 100644 index 000000000..8a554f1ee --- /dev/null +++ b/passbook/stages/identification/migrations/0002_auto_20200530_2204.py @@ -0,0 +1,41 @@ +# Generated by Django 3.0.6 on 2020-05-30 22:04 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("passbook_flows", "0002_default_flows"), + ("passbook_stages_identification", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="identificationstage", + name="enrollment_flow", + field=models.ForeignKey( + blank=True, + default=None, + help_text="Optional enrollment flow, which is linked at the bottom of the page.", + null=True, + on_delete=django.db.models.deletion.SET_DEFAULT, + related_name="+", + to="passbook_flows.Flow", + ), + ), + migrations.AddField( + model_name="identificationstage", + name="recovery_flow", + field=models.ForeignKey( + blank=True, + default=None, + help_text="Optional enrollment flow, which is linked at the bottom of the page.", + null=True, + on_delete=django.db.models.deletion.SET_DEFAULT, + related_name="+", + to="passbook_flows.Flow", + ), + ), + ] diff --git a/passbook/stages/identification/models.py b/passbook/stages/identification/models.py index d43b6b15c..94bf0ec48 100644 --- a/passbook/stages/identification/models.py +++ b/passbook/stages/identification/models.py @@ -3,7 +3,7 @@ from django.contrib.postgres.fields import ArrayField from django.db import models from django.utils.translation import gettext_lazy as _ -from passbook.flows.models import Stage +from passbook.flows.models import Flow, Stage class UserFields(models.TextChoices): @@ -29,6 +29,29 @@ class IdentificationStage(Stage): ) template = models.TextField(choices=Templates.choices) + enrollment_flow = models.ForeignKey( + Flow, + on_delete=models.SET_DEFAULT, + null=True, + blank=True, + related_name="+", + default=None, + help_text=_( + "Optional enrollment flow, which is linked at the bottom of the page." + ), + ) + recovery_flow = models.ForeignKey( + Flow, + on_delete=models.SET_DEFAULT, + null=True, + blank=True, + related_name="+", + default=None, + help_text=_( + "Optional enrollment flow, which is linked at the bottom of the page." + ), + ) + type = "passbook.stages.identification.stage.IdentificationStageView" form = "passbook.stages.identification.forms.IdentificationStageForm" diff --git a/passbook/stages/identification/stage.py b/passbook/stages/identification/stage.py index c6abe62b6..38189bf1b 100644 --- a/passbook/stages/identification/stage.py +++ b/passbook/stages/identification/stage.py @@ -10,7 +10,6 @@ from django.views.generic import FormView from structlog import get_logger from passbook.core.models import Source, User -from passbook.flows.models import FlowDesignation from passbook.flows.planner import PLAN_CONTEXT_PENDING_USER from passbook.flows.stage import StageView from passbook.stages.identification.forms import IdentificationForm @@ -34,18 +33,17 @@ class IdentificationStageView(FormView, StageView): return [current_stage.template] def get_context_data(self, **kwargs): + current_stage: IdentificationStage = self.executor.current_stage # Check for related enrollment and recovery flow, add URL to view - enrollment_flow = self.executor.flow.related_flow(FlowDesignation.ENROLLMENT) - if enrollment_flow: + if current_stage.enrollment_flow: kwargs["enroll_url"] = reverse( "passbook_flows:flow-executor-shell", - kwargs={"flow_slug": enrollment_flow.slug}, + kwargs={"flow_slug": current_stage.enrollment_flow.slug}, ) - recovery_flow = self.executor.flow.related_flow(FlowDesignation.RECOVERY) - if recovery_flow: + if current_stage.recovery_flow: kwargs["recovery_url"] = reverse( "passbook_flows:flow-executor-shell", - kwargs={"flow_slug": recovery_flow.slug}, + kwargs={"flow_slug": current_stage.recovery_flow.slug}, ) kwargs["primary_action"] = _("Log in") diff --git a/passbook/stages/identification/tests.py b/passbook/stages/identification/tests.py index 4ecf255e9..fa8c3919d 100644 --- a/passbook/stages/identification/tests.py +++ b/passbook/stages/identification/tests.py @@ -85,15 +85,19 @@ class TestIdentificationStage(TestCase): slug="unique-enrollment-string", designation=FlowDesignation.ENROLLMENT, ) + self.stage.enrollment_flow = flow + self.stage.save() FlowStageBinding.objects.create( flow=flow, stage=self.stage, order=0, ) response = self.client.get( - reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}), + reverse( + "passbook_flows:flow-executor", kwargs={"flow_slug": self.flow.slug} + ), ) self.assertEqual(response.status_code, 200) - self.assertIn(flow.name, response.rendered_content) + self.assertIn(flow.slug, response.rendered_content) def test_recovery_flow(self): """Test that recovery flow is linked correctly""" @@ -102,12 +106,16 @@ class TestIdentificationStage(TestCase): slug="unique-recovery-string", designation=FlowDesignation.RECOVERY, ) + self.stage.recovery_flow = flow + self.stage.save() FlowStageBinding.objects.create( flow=flow, stage=self.stage, order=0, ) response = self.client.get( - reverse("passbook_flows:flow-executor", kwargs={"flow_slug": flow.slug}), + reverse( + "passbook_flows:flow-executor", kwargs={"flow_slug": self.flow.slug} + ), ) self.assertEqual(response.status_code, 200) - self.assertIn(flow.name, response.rendered_content) + self.assertIn(flow.slug, response.rendered_content) diff --git a/swagger.yaml b/swagger.yaml index d999299c8..ac38448f0 100755 --- a/swagger.yaml +++ b/swagger.yaml @@ -5919,6 +5919,20 @@ definitions: enum: - stages/identification/login.html - stages/identification/recovery.html + enrollment_flow: + title: Enrollment flow + description: Optional enrollment flow, which is linked at the bottom of the + page. + type: string + format: uuid + x-nullable: true + recovery_flow: + title: Recovery flow + description: Optional enrollment flow, which is linked at the bottom of the + page. + type: string + format: uuid + x-nullable: true InvitationStage: required: - name