add importer wrapper that supports multiple YAML documents
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
parent 47e663f48c
commit 6b78190093
@@ -130,7 +130,7 @@ class TestBlueprintsV1(TransactionTestCase):
         ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete()
         Group.objects.filter(name="test").delete()
         environ["foo"] = generate_id()
-        importer = Importer(load_yaml_fixture("fixtures/tags.yaml"), {"bar": "baz"})
+        importer = Importer(load_yaml_fixture("fixtures/tags.yaml"), context={"bar": "baz"})
         self.assertTrue(importer.validate()[0])
         self.assertTrue(importer.apply())
         policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first()
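
Note: with the variadic constructor introduced below (`def __init__(self, *yaml_input: str, context: Optional[dict] = None)`), the context dict must be passed by keyword; a positional dict would be consumed as another YAML document. A minimal sketch, assuming the module path `authentik.blueprints.v1.importer` and hypothetical `doc_a`/`doc_b` fixtures that are not part of this commit:

    from authentik.blueprints.v1.importer import Importer

    doc_a = "version: 1\nentries: []"
    doc_b = "version: 1\nentries: []"
    # context is keyword-only in practice: a positional dict would land in *yaml_input
    importer = Importer(doc_a, doc_b, context={"bar": "baz"})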
@@ -1,6 +1,7 @@
 """Blueprint importer"""
 from contextlib import contextmanager
 from copy import deepcopy
+from dataclasses import asdict
 from typing import Any, Optional
 
 from dacite.config import Config
@@ -17,7 +18,7 @@ from rest_framework.serializers import BaseSerializer, Serializer
 from structlog.stdlib import BoundLogger, get_logger
 from structlog.testing import capture_logs
 from structlog.types import EventDict
-from yaml import load
+from yaml import load_all
 
 from authentik.blueprints.v1.common import (
     Blueprint,
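
Note: `yaml.load_all`, unlike `load`, yields one parsed object per `---`-separated document, which is what the multi-document wrapper further down relies on. A standalone sketch, using PyYAML's SafeLoader in place of authentik's BlueprintLoader:

    from yaml import SafeLoader, load_all

    multi_doc = "version: 1\nentries: []\n---\nversion: 1\nentries: []\n"
    docs = list(load_all(multi_doc, SafeLoader))  # one dict per document
    print(len(docs))  # 2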
@@ -77,31 +78,31 @@ def transaction_rollback():
     atomic.__exit__(IntegrityError, None, None)
 
 
-class Importer:
+class SingleDocumentImporter:
     """Import Blueprint from YAML"""
 
     logger: BoundLogger
     __import: Blueprint
 
-    def __init__(self, yaml_input: str, context: Optional[dict] = None):
+    def __init__(self, raw_blueprint: dict, context: Optional[dict] = None):
         self.__pk_map: dict[Any, Model] = {}
         self.logger = get_logger()
-        import_dict = load(yaml_input, BlueprintLoader)
         try:
-            self.__import = from_dict(
-                Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState])
+            self._import = from_dict(
+                Blueprint, raw_blueprint, config=Config(cast=[BlueprintEntryDesiredState])
             )
         except DaciteError as exc:
             raise EntryInvalidError from exc
         ctx = {}
-        always_merger.merge(ctx, self.__import.context)
+        always_merger.merge(ctx, self._import.context)
         if context:
             always_merger.merge(ctx, context)
-        self.__import.context = ctx
+        self._import.context = ctx
 
     @property
     def blueprint(self) -> Blueprint:
         """Get imported blueprint"""
-        return self.__import
+        return self._import
 
     def __update_pks_for_attrs(self, attrs: dict[str, Any]) -> dict[str, Any]:
         """Replace any value if it is a known primary key of an other object"""
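
Note: the rename from `self.__import` to `self._import` is load-bearing. Double-underscore attributes are name-mangled per class, so the wrapper class added further down could not otherwise read the parsed blueprint from outside. A minimal illustration (class and values are hypothetical, not from this commit):

    class Inner:
        def __init__(self):
            self.__secret = 1  # stored as _Inner__secret due to name mangling
            self._shared = 2   # reachable as obj._shared

    obj = Inner()
    print(obj._shared)         # 2
    print(obj._Inner__secret)  # 1; obj.__secret would raise AttributeError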
@@ -147,7 +148,7 @@ class Importer:
     # pylint: disable-msg=too-many-locals
     def _validate_single(self, entry: BlueprintEntry) -> Optional[BaseSerializer]:
         """Validate a single entry"""
-        if not entry.check_all_conditions_match(self.__import):
+        if not entry.check_all_conditions_match(self._import):
             self.logger.debug("One or more conditions of this entry are not fulfilled, skipping")
             return None
 
@@ -158,7 +159,7 @@ class Importer:
             raise EntryInvalidError(f"Model {model} not allowed")
         if issubclass(model, BaseMetaModel):
             serializer_class: type[Serializer] = model.serializer()
-            serializer = serializer_class(data=entry.get_attrs(self.__import))
+            serializer = serializer_class(data=entry.get_attrs(self._import))
             try:
                 serializer.is_valid(raise_exception=True)
             except ValidationError as exc:
@@ -172,7 +173,7 @@ class Importer:
         # the full serializer for later usage
         # Because a model might have multiple unique columns, we chain all identifiers together
         # to create an OR query.
-        updated_identifiers = self.__update_pks_for_attrs(entry.get_identifiers(self.__import))
+        updated_identifiers = self.__update_pks_for_attrs(entry.get_identifiers(self._import))
         for key, value in list(updated_identifiers.items()):
             if isinstance(value, dict) and "pk" in value:
                 del updated_identifiers[key]
@@ -211,7 +212,7 @@ class Importer:
             model_instance.pk = updated_identifiers["pk"]
             serializer_kwargs["instance"] = model_instance
         try:
-            full_data = self.__update_pks_for_attrs(entry.get_attrs(self.__import))
+            full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import))
         except ValueError as exc:
             raise EntryInvalidError(exc) from exc
         always_merger.merge(full_data, updated_identifiers)
@@ -282,8 +283,8 @@ class Importer:
         """Validate loaded blueprint export, ensure all models are allowed
         and serializers have no errors"""
         self.logger.debug("Starting blueprint import validation")
-        orig_import = deepcopy(self.__import)
-        if self.__import.version != 1:
+        orig_import = deepcopy(self._import)
+        if self._import.version != 1:
             self.logger.warning("Invalid blueprint version")
             return False, []
         with (
@@ -295,5 +296,42 @@ class Importer:
                 self.logger.debug("Blueprint validation failed")
             for log in logs:
                 getattr(self.logger, log.get("log_level"))(**log)
-        self.__import = orig_import
+        self._import = orig_import
         return successful, logs
+
+
+class Importer:
+    """Importer capable of importing multi-document YAML"""
+
+    _importers: list[SingleDocumentImporter]
+
+    def __init__(self, *yaml_input: str, context: Optional[dict] = None):
+        docs = []
+        for doc in yaml_input:
+            docs += load_all(doc, BlueprintLoader)
+        self._importers = []
+        for doc in docs:
+            self._importers.append(SingleDocumentImporter(doc, context))
+
+    @property
+    def metadata(self) -> dict:
+        """Get the merged metadata of all blueprints"""
+        metadata = {}
+        for importer in self._importers:
+            if importer._import.metadata:
+                always_merger.merge(metadata, asdict(importer._import.metadata))
+        return metadata
+
+    def apply(self) -> bool:
+        """Apply all importers"""
+        return all(x.apply() for x in self._importers)
+
+    def validate(self) -> tuple[bool, list[EventDict]]:
+        """Validate all importers"""
+        valid = []
+        events = []
+        for importer in self._importers:
+            _valid, _events = importer.validate()
+            valid.append(_valid)
+            events += _events
+        return all(valid), events
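
Note: a usage sketch of the new wrapper (document contents and names are illustrative assumptions, not from this commit, and the module path `authentik.blueprints.v1.importer` is assumed). Each `---`-separated document becomes its own SingleDocumentImporter, and validate()/apply() aggregate over all of them:

    from authentik.blueprints.v1.importer import Importer

    multi_doc = "version: 1\nentries: []\n---\nversion: 1\nentries: []\n"

    importer = Importer(multi_doc, context={"foo": "bar"})
    valid, logs = importer.validate()  # all-documents-valid flag, combined event logs
    if valid:
        importer.apply()  # True only if every document applies cleanly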
@@ -186,9 +186,8 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
             return
         blueprint_content = instance.retrieve()
         file_hash = sha512(blueprint_content.encode()).hexdigest()
-        importer = Importer(blueprint_content, instance.context)
-        if importer.blueprint.metadata:
-            instance.metadata = asdict(importer.blueprint.metadata)
+        importer = Importer(blueprint_content, context=instance.context)
+        instance.metadata = importer.metadata
         valid, logs = importer.validate()
         if not valid:
             instance.status = BlueprintInstanceStatus.ERROR
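
Note: `instance.metadata` is now set unconditionally and may be `{}` when no document carries metadata; when several documents do, their metadata dicts are deep-merged. A sketch of deepmerge's always_merger semantics, with made-up metadata values:

    from deepmerge import always_merger

    merged: dict = {}
    always_merger.merge(merged, {"name": "first", "labels": {"app": "demo"}})
    always_merger.merge(merged, {"labels": {"tier": "two"}})
    print(merged)  # {'name': 'first', 'labels': {'app': 'demo', 'tier': 'two'}}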