allow fetching of blueprint instance content to return multiple contents

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-12-22 22:49:27 +01:00 committed by Jens Langhammer
parent 6b78190093
commit 9dfe06fb18
No known key found for this signature in database
7 changed files with 24 additions and 16 deletions

View File

@ -18,7 +18,7 @@ class Command(BaseCommand):
"""Apply all blueprints in order, abort when one fails to import"""
for blueprint_path in options.get("blueprints", []):
content = BlueprintInstance(path=blueprint_path).retrieve()
importer = Importer(content)
importer = Importer(*content)
valid, logs = importer.validate()
if not valid:
for log in logs:

View File

@ -70,7 +70,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
enabled = models.BooleanField(default=True)
managed_models = ArrayField(models.TextField(), default=list)
def retrieve_oci(self) -> str:
def retrieve_oci(self) -> list[str]:
"""Get blueprint from an OCI registry"""
client = BlueprintOCIClient(self.path.replace("oci://", "https://"))
try:
@ -79,16 +79,16 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
except OCIException as exc:
raise BlueprintRetrievalFailed(exc) from exc
def retrieve_file(self) -> str:
def retrieve_file(self) -> list[str]:
"""Get blueprint from path"""
try:
full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path))
with full_path.open("r", encoding="utf-8") as _file:
return _file.read()
return [_file.read()]
except (IOError, OSError) as exc:
raise BlueprintRetrievalFailed(exc) from exc
def retrieve(self) -> str:
def retrieve(self) -> list[str]:
"""Retrieve blueprint contents"""
if self.path.startswith("oci://"):
return self.retrieve_oci()

View File

@ -21,7 +21,7 @@ def apply_blueprint(*files: str):
def wrapper(*args, **kwargs):
for file in files:
content = BlueprintInstance(path=file).retrieve()
Importer(content).apply()
Importer(*content).apply()
return func(*args, **kwargs)
return wrapper

View File

@ -29,7 +29,7 @@ class TestBlueprintOCI(TransactionTestCase):
BlueprintInstance(
path="oci://ghcr.io/goauthentik/blueprints/test:latest"
).retrieve(),
"foo",
["foo"],
)
def test_manifests_error(self):

View File

@ -25,7 +25,8 @@ def blueprint_tester(file_name: Path) -> Callable:
def tester(self: TestPackaged):
base = Path("blueprints/")
rel_path = Path(file_name).relative_to(base)
importer = Importer(BlueprintInstance(path=str(rel_path)).retrieve())
contents = BlueprintInstance(path=str(rel_path)).retrieve()
importer = Importer(*contents)
self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply())

View File

@ -75,22 +75,29 @@ class BlueprintOCIClient:
raise OCIException(manifest["errors"])
return manifest
def fetch_blobs(self, manifest: dict[str, Any]):
def fetch_blobs(self, manifest: dict[str, Any]) -> list[str]:
"""Fetch blob based on manifest info"""
blob = None
blob_digests = []
for layer in manifest.get("layers", []):
if layer.get("mediaType", "") == OCI_MEDIA_TYPE:
blob = layer.get("digest")
self.logger.debug("Found layer with matching media type", blob=blob)
if not blob:
blob_digests.append(layer.get("digest"))
if not blob_digests:
raise OCIException("Blob not found")
bodies = []
for blob in blob_digests:
bodies.append(self.fetch_blob(blob))
self.logger.debug("Fetched blobs", count=len(bodies))
return bodies
def fetch_blob(self, digest: str) -> str:
"""Fetch blob based on manifest info"""
blob_request = self.client.NewRequest(
"GET",
"/v2/<name>/blobs/<digest>",
WithDigest(blob),
WithDigest(digest),
)
try:
self.logger.debug("Fetching blob", digest=digest)
blob_response = self.client.Do(blob_request)
blob_response.raise_for_status()
return blob_response.text

View File

@ -185,8 +185,8 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
if not instance or not instance.enabled:
return
blueprint_content = instance.retrieve()
file_hash = sha512(blueprint_content.encode()).hexdigest()
importer = Importer(blueprint_content, context=instance.context)
file_hash = sha512("".join(blueprint_content).encode()).hexdigest()
importer = Importer(*blueprint_content, context=instance.context)
instance.metadata = importer.metadata
valid, logs = importer.validate()
if not valid: