From 6b781900938531ae1a7d4a2eeedab48f9a26f7bf Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Thu, 22 Dec 2022 21:39:25 +0100 Subject: [PATCH] add importer wrapper that supports multiple yaml documents Signed-off-by: Jens Langhammer --- authentik/blueprints/tests/test_v1.py | 2 +- authentik/blueprints/v1/importer.py | 70 +++++++++++++++++++++------ authentik/blueprints/v1/tasks.py | 5 +- 3 files changed, 57 insertions(+), 20 deletions(-) diff --git a/authentik/blueprints/tests/test_v1.py b/authentik/blueprints/tests/test_v1.py index b5f6a252a..2ca811fc0 100644 --- a/authentik/blueprints/tests/test_v1.py +++ b/authentik/blueprints/tests/test_v1.py @@ -130,7 +130,7 @@ class TestBlueprintsV1(TransactionTestCase): ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete() Group.objects.filter(name="test").delete() environ["foo"] = generate_id() - importer = Importer(load_yaml_fixture("fixtures/tags.yaml"), {"bar": "baz"}) + importer = Importer(load_yaml_fixture("fixtures/tags.yaml"), context={"bar": "baz"}) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first() diff --git a/authentik/blueprints/v1/importer.py b/authentik/blueprints/v1/importer.py index 1983d87c1..a4af6e690 100644 --- a/authentik/blueprints/v1/importer.py +++ b/authentik/blueprints/v1/importer.py @@ -1,6 +1,7 @@ """Blueprint importer""" from contextlib import contextmanager from copy import deepcopy +from dataclasses import asdict from typing import Any, Optional from dacite.config import Config @@ -17,7 +18,7 @@ from rest_framework.serializers import BaseSerializer, Serializer from structlog.stdlib import BoundLogger, get_logger from structlog.testing import capture_logs from structlog.types import EventDict -from yaml import load +from yaml import load_all from authentik.blueprints.v1.common import ( Blueprint, @@ -77,31 +78,31 @@ def transaction_rollback(): atomic.__exit__(IntegrityError, None, None) -class Importer: +class SingleDocumentImporter: """Import Blueprint from YAML""" logger: BoundLogger + __import: Blueprint - def __init__(self, yaml_input: str, context: Optional[dict] = None): + def __init__(self, raw_blueprint: dict, context: Optional[dict] = None): self.__pk_map: dict[Any, Model] = {} self.logger = get_logger() - import_dict = load(yaml_input, BlueprintLoader) try: - self.__import = from_dict( - Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState]) + self._import = from_dict( + Blueprint, raw_blueprint, config=Config(cast=[BlueprintEntryDesiredState]) ) except DaciteError as exc: raise EntryInvalidError from exc ctx = {} - always_merger.merge(ctx, self.__import.context) + always_merger.merge(ctx, self._import.context) if context: always_merger.merge(ctx, context) - self.__import.context = ctx + self._import.context = ctx @property def blueprint(self) -> Blueprint: """Get imported blueprint""" - return self.__import + return self._import def __update_pks_for_attrs(self, attrs: dict[str, Any]) -> dict[str, Any]: """Replace any value if it is a known primary key of an other object""" @@ -147,7 +148,7 @@ class Importer: # pylint: disable-msg=too-many-locals def _validate_single(self, entry: BlueprintEntry) -> Optional[BaseSerializer]: """Validate a single entry""" - if not entry.check_all_conditions_match(self.__import): + if not entry.check_all_conditions_match(self._import): self.logger.debug("One or more conditions of this entry are not fulfilled, skipping") return None @@ -158,7 +159,7 @@ class Importer: raise EntryInvalidError(f"Model {model} not allowed") if issubclass(model, BaseMetaModel): serializer_class: type[Serializer] = model.serializer() - serializer = serializer_class(data=entry.get_attrs(self.__import)) + serializer = serializer_class(data=entry.get_attrs(self._import)) try: serializer.is_valid(raise_exception=True) except ValidationError as exc: @@ -172,7 +173,7 @@ class Importer: # the full serializer for later usage # Because a model might have multiple unique columns, we chain all identifiers together # to create an OR query. - updated_identifiers = self.__update_pks_for_attrs(entry.get_identifiers(self.__import)) + updated_identifiers = self.__update_pks_for_attrs(entry.get_identifiers(self._import)) for key, value in list(updated_identifiers.items()): if isinstance(value, dict) and "pk" in value: del updated_identifiers[key] @@ -211,7 +212,7 @@ class Importer: model_instance.pk = updated_identifiers["pk"] serializer_kwargs["instance"] = model_instance try: - full_data = self.__update_pks_for_attrs(entry.get_attrs(self.__import)) + full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import)) except ValueError as exc: raise EntryInvalidError(exc) from exc always_merger.merge(full_data, updated_identifiers) @@ -282,8 +283,8 @@ class Importer: """Validate loaded blueprint export, ensure all models are allowed and serializers have no errors""" self.logger.debug("Starting blueprint import validation") - orig_import = deepcopy(self.__import) - if self.__import.version != 1: + orig_import = deepcopy(self._import) + if self._import.version != 1: self.logger.warning("Invalid blueprint version") return False, [] with ( @@ -295,5 +296,42 @@ class Importer: self.logger.debug("Blueprint validation failed") for log in logs: getattr(self.logger, log.get("log_level"))(**log) - self.__import = orig_import + self._import = orig_import return successful, logs + + +class Importer: + """Importer capable of importing multi-document YAML""" + + _importers: list[SingleDocumentImporter] + + def __init__(self, *yaml_input: str, context: Optional[dict] = None): + docs = [] + for doc in yaml_input: + docs += load_all(doc, BlueprintLoader) + self._importers = [] + for doc in docs: + self._importers.append(SingleDocumentImporter(doc, context)) + + @property + def metadata(self) -> dict: + """Get the merged metadata of all blueprints""" + metadata = {} + for importer in self._importers: + if importer._import.metadata: + always_merger.merge(metadata, asdict(importer._import.metadata)) + return metadata + + def apply(self) -> bool: + """Apply all importers""" + return all(x.apply() for x in self._importers) + + def validate(self) -> tuple[bool, list[EventDict]]: + """Validate all importers""" + valid = [] + events = [] + for importer in self._importers: + _valid, _events = importer.validate() + valid.append(_valid) + events += _events + return all(valid), events diff --git a/authentik/blueprints/v1/tasks.py b/authentik/blueprints/v1/tasks.py index 6f74694e3..792811b14 100644 --- a/authentik/blueprints/v1/tasks.py +++ b/authentik/blueprints/v1/tasks.py @@ -186,9 +186,8 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str): return blueprint_content = instance.retrieve() file_hash = sha512(blueprint_content.encode()).hexdigest() - importer = Importer(blueprint_content, instance.context) - if importer.blueprint.metadata: - instance.metadata = asdict(importer.blueprint.metadata) + importer = Importer(blueprint_content, context=instance.context) + instance.metadata = importer.metadata valid, logs = importer.validate() if not valid: instance.status = BlueprintInstanceStatus.ERROR