From dd017e719060c6b77fd2695e8a5bdc5413b3e89f Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Sun, 6 Sep 2020 01:07:06 +0200 Subject: [PATCH] flows: fix exporting and importing for models with multiple unique fields --- docs/flow/examples/login-2fa.json | 111 ++++++++++++++++ passbook/flows/forms.py | 2 +- passbook/flows/management/__init__.py | 0 .../flows/management/commands/__init__.py | 0 .../flows/management/commands/apply_flow.py | 22 ++++ passbook/flows/migrations/0011_flow_title.py | 8 +- .../migrations/0013_auto_20200905_2142.py | 16 +++ passbook/flows/models.py | 4 +- passbook/flows/tests/test_transfer.py | 35 +++-- passbook/flows/transfer/common.py | 25 ++-- passbook/flows/transfer/exporter.py | 6 +- passbook/flows/transfer/importer.py | 122 ++++++++++-------- passbook/outposts/signals.py | 2 +- passbook/root/settings.py | 1 + passbook/stages/prompt/api.py | 5 +- 15 files changed, 280 insertions(+), 79 deletions(-) create mode 100644 docs/flow/examples/login-2fa.json create mode 100644 passbook/flows/management/__init__.py create mode 100644 passbook/flows/management/commands/__init__.py create mode 100644 passbook/flows/management/commands/apply_flow.py create mode 100644 passbook/flows/migrations/0013_auto_20200905_2142.py diff --git a/docs/flow/examples/login-2fa.json b/docs/flow/examples/login-2fa.json new file mode 100644 index 000000000..a1b9637a5 --- /dev/null +++ b/docs/flow/examples/login-2fa.json @@ -0,0 +1,111 @@ +{ + "version": 1, + "entries": [ + { + "identifiers": { + "slug": "default-authentication-flow", + "pk": "563ece21-e9a4-47e5-a264-23ffd923e393" + }, + "model": "passbook_flows.flow", + "attrs": { + "name": "Default Authentication Flow", + "title": "Welcome to passbook!", + "designation": "authentication" + } + }, + { + "identifiers": { + "pk": "69d41125-3987-499b-8d74-ef27b54b88c8", + "name": "default-authentication-login" + }, + "model": "passbook_stages_user_login.userloginstage", + "attrs": { + "session_duration": 0 + } + }, + { + "identifiers": { + "pk": "5f594f27-0def-488d-9855-fe604eb13de5", + "name": "default-authentication-identification" + }, + "model": "passbook_stages_identification.identificationstage", + "attrs": { + "user_fields": [ + "email", + "username" + ], + "template": "stages/identification/login.html", + "enrollment_flow": null, + "recovery_flow": null + } + }, + { + "identifiers": { + "pk": "37f709c3-8817-45e8-9a93-80a925d293c2", + "name": "default-authentication-flow-totp" + }, + "model": "passbook_stages_otp_validate.otpvalidatestage", + "attrs": {} + }, + { + "identifiers": { + "pk": "d8affa62-500c-4c5c-a01f-5835e1ffdf40", + "name": "default-authentication-password" + }, + "model": "passbook_stages_password.passwordstage", + "attrs": { + "backends": [ + "django.contrib.auth.backends.ModelBackend" + ] + } + }, + { + "identifiers": { + "pk": "a3056482-b692-4e3a-93f1-7351c6a351c7", + "target": "563ece21-e9a4-47e5-a264-23ffd923e393", + "stage": "5f594f27-0def-488d-9855-fe604eb13de5", + "order": 0 + }, + "model": "passbook_flows.flowstagebinding", + "attrs": { + "re_evaluate_policies": false + } + }, + { + "identifiers": { + "pk": "4e8538cf-3e18-4a68-82ae-6df6725fa2e6", + "target": "563ece21-e9a4-47e5-a264-23ffd923e393", + "stage": "d8affa62-500c-4c5c-a01f-5835e1ffdf40", + "order": 1 + }, + "model": "passbook_flows.flowstagebinding", + "attrs": { + "re_evaluate_policies": false + } + }, + { + "identifiers": { + "pk": "688aec6f-5622-42c6-83a5-d22072d7e798", + "target": "563ece21-e9a4-47e5-a264-23ffd923e393", + "stage": "37f709c3-8817-45e8-9a93-80a925d293c2", + "order": 2 + }, + "model": "passbook_flows.flowstagebinding", + "attrs": { + "re_evaluate_policies": false + } + }, + { + "identifiers": { + "pk": "f3fede3a-a9b5-4232-9ec7-be7ff4194b27", + "target": "563ece21-e9a4-47e5-a264-23ffd923e393", + "stage": "69d41125-3987-499b-8d74-ef27b54b88c8", + "order": 3 + }, + "model": "passbook_flows.flowstagebinding", + "attrs": { + "re_evaluate_policies": false + } + } + ] +} diff --git a/passbook/flows/forms.py b/passbook/flows/forms.py index 8345a6987..b15c1fb8d 100644 --- a/passbook/flows/forms.py +++ b/passbook/flows/forms.py @@ -39,7 +39,7 @@ class FlowForm(forms.ModelForm): class FlowStageBindingForm(forms.ModelForm): """FlowStageBinding Form""" - stage = GroupedModelChoiceField(queryset=Stage.objects.all().select_subclasses(),) + stage = GroupedModelChoiceField(queryset=Stage.objects.all().select_subclasses()) class Meta: diff --git a/passbook/flows/management/__init__.py b/passbook/flows/management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/passbook/flows/management/commands/__init__.py b/passbook/flows/management/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/passbook/flows/management/commands/apply_flow.py b/passbook/flows/management/commands/apply_flow.py new file mode 100644 index 000000000..36e59c037 --- /dev/null +++ b/passbook/flows/management/commands/apply_flow.py @@ -0,0 +1,22 @@ +"""Apply flow from commandline""" +from django.core.management.base import BaseCommand, no_translations + +from passbook.flows.transfer.importer import FlowImporter + + +class Command(BaseCommand): + """Apply flow from commandline""" + + @no_translations + def handle(self, *args, **options): + """Apply all flows in order, abort when one fails to import""" + for flow_path in options.get("flows", []): + with open(flow_path, "r") as flow_file: + importer = FlowImporter(flow_file.read()) + valid = importer.validate() + if not valid: + raise ValueError("Flow invalid") + importer.apply() + + def add_arguments(self, parser): + parser.add_argument("flows", nargs="+", type=str) diff --git a/passbook/flows/migrations/0011_flow_title.py b/passbook/flows/migrations/0011_flow_title.py index c9b4c181a..3cd88e465 100644 --- a/passbook/flows/migrations/0011_flow_title.py +++ b/passbook/flows/migrations/0011_flow_title.py @@ -6,13 +6,13 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor def add_title_for_defaults(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): slug_title_map = { - "default-authentication-flow": "Default Authentication Flow", + "default-authentication-flow": "Welcome to passbook!", "default-invalidation-flow": "Default Invalidation Flow", - "default-source-enrollment": "Default Source Enrollment Flow", - "default-source-authentication": "Default Source Authentication Flow", + "default-source-enrollment": "Welcome to passbook!", + "default-source-authentication": "Welcome to passbook!", "default-provider-authorization-implicit-consent": "Default Provider Authorization Flow (implicit consent)", "default-provider-authorization-explicit-consent": "Default Provider Authorization Flow (explicit consent)", - "default-password-change": "Default Password Change Flow", + "default-password-change": "Change password", } db_alias = schema_editor.connection.alias Flow = apps.get_model("passbook_flows", "Flow") diff --git a/passbook/flows/migrations/0013_auto_20200905_2142.py b/passbook/flows/migrations/0013_auto_20200905_2142.py new file mode 100644 index 000000000..1f97c0581 --- /dev/null +++ b/passbook/flows/migrations/0013_auto_20200905_2142.py @@ -0,0 +1,16 @@ +# Generated by Django 3.1.1 on 2020-09-05 21:42 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("passbook_flows", "0012_auto_20200830_1056"), + ] + + operations = [ + migrations.AlterField( + model_name="stage", name="name", field=models.TextField(unique=True), + ), + ] diff --git a/passbook/flows/models.py b/passbook/flows/models.py index 8120adc31..75bf63a48 100644 --- a/passbook/flows/models.py +++ b/passbook/flows/models.py @@ -46,7 +46,7 @@ class Stage(SerializerModel): stage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) - name = models.TextField() + name = models.TextField(unique=True) objects = InheritanceManager() @@ -170,7 +170,7 @@ class FlowStageBinding(SerializerModel, PolicyBindingModel): return FlowStageBindingSerializer def __str__(self) -> str: - return f"Flow Binding {self.target} -> {self.stage}" + return f"'{self.target}' -> '{self.stage}' # {self.order}" class Meta: diff --git a/passbook/flows/tests/test_transfer.py b/passbook/flows/tests/test_transfer.py index 6913fef0a..576a6ed95 100644 --- a/passbook/flows/tests/test_transfer.py +++ b/passbook/flows/tests/test_transfer.py @@ -1,6 +1,7 @@ """Test flow transfer""" from json import dumps +from django.db import transaction from django.test import TransactionTestCase from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding @@ -21,12 +22,13 @@ class TestFlowTransfer(TransactionTestCase): importer = FlowImporter('{"version": 3}') self.assertFalse(importer.validate()) importer = FlowImporter( - '{"version": 1,"entries":[{"identifier":"","attrs":{},"model": "passbook_core.User"}]}' + '{"version": 1,"entries":[{"identifiers":{},"attrs":{},"model": "passbook_core.User"}]}' ) self.assertFalse(importer.validate()) def test_export_validate_import(self): """Test export and validate it""" + sid = transaction.savepoint() login_stage = UserLoginStage.objects.create(name="default-authentication-login") flow = Flow.objects.create( @@ -40,42 +42,55 @@ class TestFlowTransfer(TransactionTestCase): exporter = FlowExporter(flow) export = exporter.export() + + transaction.savepoint_rollback(sid) + self.assertEqual(len(export.entries), 3) export_json = dumps(export, cls=DataclassEncoder) importer = FlowImporter(export_json) self.assertTrue(importer.validate()) - flow.delete() - login_stage.delete() self.assertTrue(importer.apply()) self.assertTrue(Flow.objects.filter(slug="test").exists()) def test_export_validate_import_policies(self): """Test export and validate it""" + sid = transaction.savepoint() + flow_policy = ExpressionPolicy.objects.create( name="default-source-authentication-if-sso", expression="return True", ) flow = Flow.objects.create( - slug="default-source-authentication", + slug="default-source-authentication-test", designation=FlowDesignation.AUTHENTICATION, name="Welcome to passbook!", ) PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0) user_login = UserLoginStage.objects.create( - name="default-source-authentication-login" + name="default-source-authentication-login-test" ) FlowStageBinding.objects.create(target=flow, stage=user_login, order=0) exporter = FlowExporter(flow) export = exporter.export() + + transaction.savepoint_rollback(sid) + export_json = dumps(export, cls=DataclassEncoder) importer = FlowImporter(export_json) self.assertTrue(importer.validate()) self.assertTrue(importer.apply()) + self.assertTrue( + UserLoginStage.objects.filter( + name="default-source-authentication-login-test" + ).exists() + ) def test_export_validate_import_prompt(self): """Test export and validate it""" + sid = transaction.savepoint() + # First stage fields username_prompt = Prompt.objects.create( field_key="username", label="Username", order=0, type=FieldTypes.TEXT @@ -90,13 +105,13 @@ class TestFlowTransfer(TransactionTestCase): type=FieldTypes.PASSWORD, ) # Stages - first_stage = PromptStage.objects.create(name="prompt-stage-first") + first_stage = PromptStage.objects.create(name="prompt-stage-first-test") first_stage.fields.set([username_prompt, password, password_repeat]) first_stage.save() # Password checking policy password_policy = ExpressionPolicy.objects.create( - name="policy-enrollment-password-equals", + name="policy-enrollment-password-equals-test", expression="return request.context['password'] == request.context['password_repeat']", ) PolicyBinding.objects.create( @@ -105,7 +120,7 @@ class TestFlowTransfer(TransactionTestCase): flow = Flow.objects.create( name="default-enrollment-flow", - slug="default-enrollment-flow", + slug="default-enrollment-flow-test", designation=FlowDesignation.ENROLLMENT, ) @@ -114,6 +129,10 @@ class TestFlowTransfer(TransactionTestCase): exporter = FlowExporter(flow) export = exporter.export() export_json = dumps(export, cls=DataclassEncoder) + + transaction.savepoint_rollback(sid) + importer = FlowImporter(export_json) + self.assertTrue(importer.validate()) self.assertTrue(importer.apply()) diff --git a/passbook/flows/transfer/common.py b/passbook/flows/transfer/common.py index 01b9ae57c..f0288b582 100644 --- a/passbook/flows/transfer/common.py +++ b/passbook/flows/transfer/common.py @@ -11,10 +11,10 @@ from passbook.lib.sentry import SentryIgnoredException def get_attrs(obj: SerializerModel) -> Dict[str, Any]: """Get object's attributes via their serializer, and covert it to a normal dict""" data = dict(obj.serializer(obj).data) - if "policies" in data: - data.pop("policies") - if "stages" in data: - data.pop("stages") + to_remove = ("policies", "stages", "pk") + for to_remove_name in to_remove: + if to_remove_name in data: + data.pop(to_remove_name) return data @@ -22,17 +22,26 @@ def get_attrs(obj: SerializerModel) -> Dict[str, Any]: class FlowBundleEntry: """Single entry of a bundle""" - identifier: str + identifiers: Dict[str, Any] model: str attrs: Dict[str, Any] @staticmethod - def from_model(model: SerializerModel) -> "FlowBundleEntry": + def from_model( + model: SerializerModel, *extra_identifier_names: str + ) -> "FlowBundleEntry": """Convert a SerializerModel instance to a Bundle Entry""" + identifiers = { + "pk": model.pk, + } + all_attrs = get_attrs(model) + + for extra_identifier_name in extra_identifier_names: + identifiers[extra_identifier_name] = all_attrs.pop(extra_identifier_name) return FlowBundleEntry( - identifier=model.pk, + identifiers=identifiers, model=f"{model._meta.app_label}.{model._meta.model_name}", - attrs=get_attrs(model), + attrs=all_attrs, ) diff --git a/passbook/flows/transfer/exporter.py b/passbook/flows/transfer/exporter.py index 9ccd23e1b..f0fa3bad3 100644 --- a/passbook/flows/transfer/exporter.py +++ b/passbook/flows/transfer/exporter.py @@ -28,13 +28,13 @@ class FlowExporter: for stage in stages: if isinstance(stage, PromptStage): pass - yield FlowBundleEntry.from_model(stage) + yield FlowBundleEntry.from_model(stage, "name") def walk_stage_bindings(self) -> Iterator[FlowBundleEntry]: """Convert all bindings attached to self.flow into FlowBundleEntry objects""" bindings = FlowStageBinding.objects.filter(target=self.flow).select_related() for binding in bindings: - yield FlowBundleEntry.from_model(binding) + yield FlowBundleEntry.from_model(binding, "target", "stage", "order") def walk_policies(self) -> Iterator[FlowBundleEntry]: """Walk over all policies and their respective bindings""" @@ -64,7 +64,7 @@ class FlowExporter: def export(self) -> FlowBundle: """Create a list of all objects including the flow""" bundle = FlowBundle() - bundle.entries.append(FlowBundleEntry.from_model(self.flow)) + bundle.entries.append(FlowBundleEntry.from_model(self.flow, "slug")) if self.with_stage_prompts: bundle.entries.extend(self.walk_stage_prompts()) bundle.entries.extend(self.walk_stages()) diff --git a/passbook/flows/transfer/importer.py b/passbook/flows/transfer/importer.py index e8d4f74e0..5b79084d0 100644 --- a/passbook/flows/transfer/importer.py +++ b/passbook/flows/transfer/importer.py @@ -1,12 +1,14 @@ """Flow importer""" from json import loads -from typing import Type +from typing import Any, Dict from dacite import from_dict from dacite.exceptions import DaciteError from django.apps import apps from django.db import transaction from django.db.models import Model +from django.db.models.query_utils import Q +from rest_framework.exceptions import ValidationError from rest_framework.serializers import BaseSerializer, Serializer from structlog import BoundLogger, get_logger @@ -17,7 +19,7 @@ from passbook.flows.transfer.common import ( FlowBundleEntry, ) from passbook.lib.models import SerializerModel -from passbook.policies.models import Policy, PolicyBinding, PolicyBindingModel +from passbook.policies.models import Policy, PolicyBinding from passbook.stages.prompt.models import Prompt ALLOWED_MODELS = (Flow, FlowStageBinding, Stage, Policy, PolicyBinding, Prompt) @@ -28,49 +30,42 @@ class FlowImporter: __import: FlowBundle + __pk_map: Dict[Any, Model] + logger: BoundLogger def __init__(self, json_input: str): self.logger = get_logger() + self.__pk_map = {} import_dict = loads(json_input) try: self.__import = from_dict(FlowBundle, import_dict) except DaciteError as exc: raise EntryInvalidError from exc - def validate(self) -> bool: - """Validate loaded flow export, ensure all models are allowed - and serializers have no errors""" - if self.__import.version != 1: - self.logger.warning("Invalid bundle version") - return False - for entry in self.__import.entries: - try: - self._validate_single(entry) - except EntryInvalidError as exc: - self.logger.warning(exc) - return False - return True + def __update_pks_for_attrs(self, attrs: Dict[str, Any]) -> Dict[str, Any]: + """Replace any value if it is a known primary key of an other object""" + for key, value in attrs.items(): + if isinstance(value, (list, dict)): + continue + if value in self.__pk_map: + attrs[key] = self.__pk_map[value] + self.logger.debug( + "updating reference in entry", key=key, new_value=attrs[key] + ) + return attrs - def __get_pk_filed(self, model_class: Type[Model]) -> str: - fields = model_class._meta.get_fields() - pks = [] - for field in fields: - # Ignore base PK from pbm as that isn't the same pk we exported - if field.model in [PolicyBindingModel]: + def __query_from_identifier(self, attrs: Dict[str, Any]) -> Q: + """Generate an or'd query from all identifiers in an entry""" + # Since identifiers can also be pk-references to other objects (see FlowStageBinding) + # we have to ensure those references are also replaced + main_query = Q(pk=attrs["pk"]) + sub_query = Q() + for identifier, value in attrs.items(): + if identifier == "pk": continue - # Ignore primary keys with _ptr suffix as those are surrogate and not what we exported - if field.name.endswith("_ptr"): - continue - if hasattr(field, "primary_key"): - if field.primary_key: - pks.append(field.name) - if len(pks) > 1: - self.logger.debug( - "Found more than one fields with primary_key=True, using pk", pks=pks - ) - return "pk" - return pks[0] + sub_query &= Q(**{identifier: value}) + return main_query | sub_query def _validate_single(self, entry: FlowBundleEntry) -> BaseSerializer: """Validate a single entry""" @@ -82,43 +77,53 @@ class FlowImporter: # If we try to validate without referencing a possible instance # we'll get a duplicate error, hence we load the model here and return # the full serializer for later usage - existing_models = model.objects.filter(pk=entry.identifier) - serializer_kwargs = {"data": entry.attrs} + # Because a model might have multiple unique columns, we chain all identifiers together + # to create an OR query. + updated_identifiers = self.__update_pks_for_attrs(entry.identifiers) + existing_models = model.objects.filter( + self.__query_from_identifier(updated_identifiers) + ) + + serializer_kwargs = {} if existing_models.exists(): + model_instance = existing_models.first() self.logger.debug( - "initialise serializer with instance", instance=existing_models.first() + "initialise serializer with instance", + model=model, + instance=model_instance, + pk=model_instance.pk, ) - serializer_kwargs["instance"] = existing_models.first() + serializer_kwargs["instance"] = model_instance else: - self.logger.debug("initialise new instance", pk=entry.identifier) + self.logger.debug( + "initialise new instance", model=model, **updated_identifiers + ) + full_data = self.__update_pks_for_attrs(entry.attrs) + full_data.update(updated_identifiers) + serializer_kwargs["data"] = full_data serializer: Serializer = model().serializer(**serializer_kwargs) - is_valid = serializer.is_valid() - if not is_valid: - raise EntryInvalidError(f"Serializer errors {serializer.errors}") - if not existing_models.exists(): - # only insert the PK if we're creating a new model, otherwise we get - # an integrity error - model_pk = self.__get_pk_filed(model) - serializer.validated_data[model_pk] = entry.identifier + try: + serializer.is_valid(raise_exception=True) + except ValidationError as exc: + raise EntryInvalidError(f"Serializer errors {serializer.errors}") from exc return serializer def apply(self) -> bool: """Apply (create/update) flow json, in database transaction""" - transaction.set_autocommit(False) + sid = transaction.savepoint() successful = self._apply_models() if not successful: self.logger.debug("Reverting changes due to error") - transaction.rollback() - transaction.set_autocommit(True) + transaction.savepoint_rollback(sid) return False self.logger.debug("Committing changes") - transaction.commit() - transaction.set_autocommit(True) + transaction.savepoint_commit(sid) return True def _apply_models(self) -> bool: """Apply (create/update) flow json""" + self.__pk_map = {} for entry in self.__import.entries: model_app_label, model_name = entry.model.split(".") model: SerializerModel = apps.get_model(model_app_label, model_name) @@ -130,5 +135,20 @@ class FlowImporter: return False model = serializer.save() + self.__pk_map[entry.identifiers["pk"]] = model.pk self.logger.debug("updated model", model=model, pk=model.pk) return True + + def validate(self) -> bool: + """Validate loaded flow export, ensure all models are allowed + and serializers have no errors""" + self.logger.debug("Starting flow import validaton") + if self.__import.version != 1: + self.logger.warning("Invalid bundle version") + return False + sid = transaction.savepoint() + successful = self._apply_models() + if not successful: + self.logger.debug("Flow validation failed") + transaction.savepoint_rollback(sid) + return successful diff --git a/passbook/outposts/signals.py b/passbook/outposts/signals.py index 0f97be49e..880aa4b53 100644 --- a/passbook/outposts/signals.py +++ b/passbook/outposts/signals.py @@ -54,5 +54,5 @@ def _send_update(outpost_model: Model): for outpost in outpost_model.outpost_set.all(): channel_layer = get_channel_layer() for channel in outpost.channels: - print(f"sending update to channel {channel}") + LOGGER.debug("sending update", channel=channel) async_to_sync(channel_layer.send)(channel, {"type": "event.update"}) diff --git a/passbook/root/settings.py b/passbook/root/settings.py index 04feaaccd..43006443e 100644 --- a/passbook/root/settings.py +++ b/passbook/root/settings.py @@ -376,6 +376,7 @@ _LOGGING_HANDLER_MAP = { "docker": "WARNING", "urllib3": "WARNING", "websockets": "WARNING", + "daphne": "WARNING", } for handler_name, level in _LOGGING_HANDLER_MAP.items(): # pyright: reportGeneralTypeIssues=false diff --git a/passbook/stages/prompt/api.py b/passbook/stages/prompt/api.py index 1c4d4a01b..b54d94d4c 100644 --- a/passbook/stages/prompt/api.py +++ b/passbook/stages/prompt/api.py @@ -1,5 +1,6 @@ """Prompt Stage API Views""" -from rest_framework.serializers import ModelSerializer +from rest_framework.serializers import CharField, ModelSerializer +from rest_framework.validators import UniqueValidator from rest_framework.viewsets import ModelViewSet from passbook.stages.prompt.models import Prompt, PromptStage @@ -8,6 +9,8 @@ from passbook.stages.prompt.models import Prompt, PromptStage class PromptStageSerializer(ModelSerializer): """PromptStage Serializer""" + name = CharField(validators=[UniqueValidator(queryset=PromptStage.objects.all())]) + class Meta: model = PromptStage