blueprints: allow for adding remote blueprints (#3435)

* allow blueprints to be fetched from HTTP URLs

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* fix tests

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* remove os.path

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* add validation for blueprint path

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* fix tests

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens L 2022-08-17 22:00:47 +01:00 committed by GitHub
parent e87236b285
commit 1adc6948b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 97 additions and 35 deletions

View File

@ -1,6 +1,7 @@
"""Serializer mixin for managed models""" """Serializer mixin for managed models"""
from drf_spectacular.utils import extend_schema, inline_serializer from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import CharField, DateTimeField, JSONField from rest_framework.fields import CharField, DateTimeField, JSONField
from rest_framework.permissions import IsAdminUser from rest_framework.permissions import IsAdminUser
from rest_framework.request import Request from rest_framework.request import Request
@ -9,7 +10,7 @@ from rest_framework.serializers import ListSerializer, ModelSerializer
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from authentik.api.decorators import permission_required from authentik.api.decorators import permission_required
from authentik.blueprints.models import BlueprintInstance from authentik.blueprints.models import BlueprintInstance, BlueprintRetrievalFailed
from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
@ -31,6 +32,14 @@ class MetadataSerializer(PassiveSerializer):
class BlueprintInstanceSerializer(ModelSerializer): class BlueprintInstanceSerializer(ModelSerializer):
"""Info about a single blueprint instance file""" """Info about a single blueprint instance file"""
def validate_path(self, path: str) -> str:
"""Ensure the path specified is retrievable"""
try:
BlueprintInstance(path=path).retrieve()
except BlueprintRetrievalFailed as exc:
raise ValidationError(exc) from exc
return path
class Meta: class Meta:
model = BlueprintInstance model = BlueprintInstance

View File

@ -2,6 +2,7 @@
from django.core.management.base import BaseCommand, no_translations from django.core.management.base import BaseCommand, no_translations
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.blueprints.models import BlueprintInstance
from authentik.blueprints.v1.importer import Importer from authentik.blueprints.v1.importer import Importer
LOGGER = get_logger() LOGGER = get_logger()
@ -14,14 +15,14 @@ class Command(BaseCommand):
def handle(self, *args, **options): def handle(self, *args, **options):
"""Apply all blueprints in order, abort when one fails to import""" """Apply all blueprints in order, abort when one fails to import"""
for blueprint_path in options.get("blueprints", []): for blueprint_path in options.get("blueprints", []):
with open(blueprint_path, "r", encoding="utf8") as blueprint_file: content = BlueprintInstance(path=blueprint_path).retrieve()
importer = Importer(blueprint_file.read()) importer = Importer(content)
valid, logs = importer.validate() valid, logs = importer.validate()
if not valid: if not valid:
for log in logs: for log in logs:
LOGGER.debug(**log) LOGGER.debug(**log)
raise ValueError("blueprint invalid") raise ValueError("blueprint invalid")
importer.apply() importer.apply()
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("blueprints", nargs="+", type=str) parser.add_argument("blueprints", nargs="+", type=str)

View File

@ -1,12 +1,23 @@
"""Managed Object models""" """Managed Object models"""
from pathlib import Path
from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from requests import RequestException
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from authentik.lib.config import CONFIG
from authentik.lib.models import CreatedUpdatedModel, SerializerModel from authentik.lib.models import CreatedUpdatedModel, SerializerModel
from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.http import get_http_session
class BlueprintRetrievalFailed(SentryIgnoredException):
"""Error raised when we're unable to fetch the blueprint contents, whether it be HTTP files
not being accessible or local files not being readable"""
class ManagedModel(models.Model): class ManagedModel(models.Model):
@ -60,6 +71,19 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
enabled = models.BooleanField(default=True) enabled = models.BooleanField(default=True)
managed_models = ArrayField(models.TextField(), default=list) managed_models = ArrayField(models.TextField(), default=list)
def retrieve(self) -> str:
"""Retrieve blueprint contents"""
if urlparse(self.path).scheme != "":
try:
res = get_http_session().get(self.path, timeout=3, allow_redirects=True)
res.raise_for_status()
return res.text
except RequestException as exc:
raise BlueprintRetrievalFailed(exc) from exc
path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path))
with path.open("r", encoding="utf-8") as _file:
return _file.read()
@property @property
def serializer(self) -> Serializer: def serializer(self) -> Serializer:
from authentik.blueprints.api import BlueprintInstanceSerializer from authentik.blueprints.api import BlueprintInstanceSerializer

View File

@ -6,6 +6,7 @@ from typing import Callable
from django.apps import apps from django.apps import apps
from authentik.blueprints.manager import ManagedAppConfig from authentik.blueprints.manager import ManagedAppConfig
from authentik.blueprints.models import BlueprintInstance
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -19,11 +20,9 @@ def apply_blueprint(*files: str):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
base_path = Path(CONFIG.y("blueprints_dir"))
for file in files: for file in files:
full_path = Path(base_path, file) content = BlueprintInstance(path=file).retrieve()
with full_path.open("r", encoding="utf-8") as _file: Importer(content).apply()
Importer(_file.read()).apply()
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper

View File

@ -4,6 +4,7 @@ from typing import Callable
from django.test import TransactionTestCase from django.test import TransactionTestCase
from authentik.blueprints.models import BlueprintInstance
from authentik.blueprints.tests import apply_blueprint from authentik.blueprints.tests import apply_blueprint
from authentik.blueprints.v1.importer import Importer from authentik.blueprints.v1.importer import Importer
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
@ -18,12 +19,13 @@ class TestBundled(TransactionTestCase):
self.assertTrue(Tenant.objects.filter(domain="authentik-default").exists()) self.assertTrue(Tenant.objects.filter(domain="authentik-default").exists())
def blueprint_tester(file_name: str) -> Callable: def blueprint_tester(file_name: Path) -> Callable:
"""This is used instead of subTest for better visibility""" """This is used instead of subTest for better visibility"""
def tester(self: TestBundled): def tester(self: TestBundled):
with open(file_name, "r", encoding="utf8") as blueprint: base = Path("blueprints/")
importer = Importer(blueprint.read()) rel_path = Path(file_name).relative_to(base)
importer = Importer(BlueprintInstance(path=str(rel_path)).retrieve())
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())

View File

@ -1,4 +1,5 @@
"""Test blueprints v1 tasks""" """Test blueprints v1 tasks"""
from hashlib import sha512
from tempfile import NamedTemporaryFile, mkdtemp from tempfile import NamedTemporaryFile, mkdtemp
from django.test import TransactionTestCase from django.test import TransactionTestCase
@ -36,25 +37,32 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
@CONFIG.patch("blueprints_dir", TMP) @CONFIG.patch("blueprints_dir", TMP)
def test_valid(self): def test_valid(self):
"""Test valid file""" """Test valid file"""
blueprint_id = generate_id()
with NamedTemporaryFile(mode="w+", suffix=".yaml", dir=TMP) as file: with NamedTemporaryFile(mode="w+", suffix=".yaml", dir=TMP) as file:
file.write( file.write(
dump( dump(
{ {
"version": 1, "version": 1,
"entries": [], "entries": [],
"metadata": {
"name": blueprint_id,
},
} }
) )
) )
file.seek(0)
file_hash = sha512(file.read().encode()).hexdigest()
file.flush() file.flush()
blueprints_discover() # pylint: disable=no-value-for-parameter blueprints_discover() # pylint: disable=no-value-for-parameter
instance = BlueprintInstance.objects.filter(name=blueprint_id).first()
self.assertEqual(instance.last_applied_hash, file_hash)
self.assertEqual( self.assertEqual(
BlueprintInstance.objects.first().last_applied_hash, instance.metadata,
( {
"e52bb445b03cd36057258dc9f0ce0fbed8278498ee1470e45315293e5f026d1b" "name": blueprint_id,
"d1f9b3526871c0003f5c07be5c3316d9d4a08444bd8fed1b3f03294e51e44522" "labels": {},
), },
) )
self.assertEqual(BlueprintInstance.objects.first().metadata, {})
@CONFIG.patch("blueprints_dir", TMP) @CONFIG.patch("blueprints_dir", TMP)
def test_valid_updated(self): def test_valid_updated(self):

View File

@ -8,10 +8,15 @@ from dacite import from_dict
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from structlog.stdlib import get_logger
from yaml import load from yaml import load
from yaml.error import YAMLError from yaml.error import YAMLError
from authentik.blueprints.models import BlueprintInstance, BlueprintInstanceStatus from authentik.blueprints.models import (
BlueprintInstance,
BlueprintInstanceStatus,
BlueprintRetrievalFailed,
)
from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata
from authentik.blueprints.v1.importer import Importer from authentik.blueprints.v1.importer import Importer
from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
@ -25,6 +30,8 @@ from authentik.events.utils import sanitize_dict
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
LOGGER = get_logger()
@dataclass @dataclass
class BlueprintFile: class BlueprintFile:
@ -54,21 +61,29 @@ def blueprints_find():
root = Path(CONFIG.y("blueprints_dir")) root = Path(CONFIG.y("blueprints_dir"))
for file in root.glob("**/*.yaml"): for file in root.glob("**/*.yaml"):
path = Path(file) path = Path(file)
LOGGER.debug("found blueprint", path=str(path))
with open(path, "r", encoding="utf-8") as blueprint_file: with open(path, "r", encoding="utf-8") as blueprint_file:
try: try:
raw_blueprint = load(blueprint_file.read(), BlueprintLoader) raw_blueprint = load(blueprint_file.read(), BlueprintLoader)
except YAMLError: except YAMLError as exc:
raw_blueprint = None raw_blueprint = None
LOGGER.warning("failed to parse blueprint", exc=exc, path=str(path))
if not raw_blueprint: if not raw_blueprint:
continue continue
metadata = raw_blueprint.get("metadata", None) metadata = raw_blueprint.get("metadata", None)
version = raw_blueprint.get("version", 1) version = raw_blueprint.get("version", 1)
if version != 1: if version != 1:
LOGGER.warning("invalid blueprint version", version=version, path=str(path))
continue continue
file_hash = sha512(path.read_bytes()).hexdigest() file_hash = sha512(path.read_bytes()).hexdigest()
blueprint = BlueprintFile(path.relative_to(root), version, file_hash, path.stat().st_mtime) blueprint = BlueprintFile(path.relative_to(root), version, file_hash, path.stat().st_mtime)
blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None
blueprints.append(blueprint) blueprints.append(blueprint)
LOGGER.info(
"parsed & loaded blueprint",
hash=file_hash,
path=str(path),
)
return blueprints return blueprints
@ -127,10 +142,9 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
if not instance or not instance.enabled: if not instance or not instance.enabled:
return return
full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(instance.path)) blueprint_content = instance.retrieve()
file_hash = sha512(full_path.read_bytes()).hexdigest() file_hash = sha512(blueprint_content.encode()).hexdigest()
with open(full_path, "r", encoding="utf-8") as blueprint_file: importer = Importer(blueprint_content, instance.context)
importer = Importer(blueprint_file.read(), instance.context)
valid, logs = importer.validate() valid, logs = importer.validate()
if not valid: if not valid:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR
@ -148,7 +162,13 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
instance.last_applied = now() instance.last_applied = now()
instance.save() instance.save()
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL)) self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
except (DatabaseError, ProgrammingError, InternalError, IOError) as exc: except (
DatabaseError,
ProgrammingError,
InternalError,
IOError,
BlueprintRetrievalFailed,
) as exc:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR
instance.save() instance.save()
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc)) self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))

View File

@ -1,6 +1,5 @@
"""outpost tasks""" """outpost tasks"""
from os import R_OK, access from os import R_OK, access
from os.path import expanduser
from pathlib import Path from pathlib import Path
from socket import gethostname from socket import gethostname
from typing import Any, Optional from typing import Any, Optional
@ -252,13 +251,13 @@ def outpost_local_connection():
name="Local Kubernetes Cluster", local=True, kubeconfig={} name="Local Kubernetes Cluster", local=True, kubeconfig={}
) )
# For development, check for the existence of a kubeconfig file # For development, check for the existence of a kubeconfig file
kubeconfig_path = expanduser(KUBE_CONFIG_DEFAULT_LOCATION) kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser()
if Path(kubeconfig_path).exists(): if kubeconfig_path.exists():
LOGGER.debug("Detected kubeconfig") LOGGER.debug("Detected kubeconfig")
kubeconfig_local_name = f"k8s-{gethostname()}" kubeconfig_local_name = f"k8s-{gethostname()}"
if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
LOGGER.debug("Creating kubeconfig Service Connection") LOGGER.debug("Creating kubeconfig Service Connection")
with open(kubeconfig_path, "r", encoding="utf8") as _kubeconfig: with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
KubernetesServiceConnection.objects.create( KubernetesServiceConnection.objects.create(
name=kubeconfig_local_name, name=kubeconfig_local_name,
kubeconfig=yaml.safe_load(_kubeconfig), kubeconfig=yaml.safe_load(_kubeconfig),

View File

@ -1,5 +1,5 @@
"""test OAuth Source""" """test OAuth Source"""
from os.path import abspath from pathlib import Path
from sys import platform from sys import platform
from time import sleep from time import sleep
from typing import Any, Optional from typing import Any, Optional
@ -116,7 +116,7 @@ class TestSourceOAuth2(SeleniumTestCase):
interval=5 * 100 * 1000000, interval=5 * 100 * 1000000,
start_period=1 * 100 * 1000000, start_period=1 * 100 * 1000000,
), ),
"volumes": {abspath(CONFIG_PATH): {"bind": "/config.yml", "mode": "ro"}}, "volumes": {str(Path(CONFIG_PATH).absolute()): {"bind": "/config.yml", "mode": "ro"}},
} }
def create_objects(self): def create_objects(self):