Compare commits
1 Commits
trustchain
...
root/confi
Author | SHA1 | Date |
---|---|---|
Jens Langhammer | 3fa987f443 |
|
@ -5,13 +5,20 @@ from contextlib import contextmanager
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from json import dumps, loads
|
from json import dumps, loads
|
||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
|
from pathlib import Path
|
||||||
from sys import argv, stderr
|
from sys import argv, stderr
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from django.conf import ImproperlyConfigured
|
from django.conf import ImproperlyConfigured
|
||||||
|
from watchdog.events import (
|
||||||
|
FileModifiedEvent,
|
||||||
|
FileSystemEvent,
|
||||||
|
FileSystemEventHandler,
|
||||||
|
)
|
||||||
|
from watchdog.observers import Observer
|
||||||
|
|
||||||
SEARCH_PATHS = ["authentik/lib/default.yml", "/etc/authentik/config.yml", ""] + glob(
|
SEARCH_PATHS = ["authentik/lib/default.yml", "/etc/authentik/config.yml", ""] + glob(
|
||||||
"/etc/authentik/config.d/*.yml", recursive=True
|
"/etc/authentik/config.d/*.yml", recursive=True
|
||||||
|
@ -38,9 +45,47 @@ class ConfigLoader:
|
||||||
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
|
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
|
||||||
|
|
||||||
loaded_file = []
|
loaded_file = []
|
||||||
|
observer: Observer
|
||||||
|
|
||||||
|
class FSObserver(FileSystemEventHandler):
|
||||||
|
"""File system observer"""
|
||||||
|
|
||||||
|
loader: "ConfigLoader"
|
||||||
|
path: str
|
||||||
|
container: Optional[dict] = None
|
||||||
|
key: Optional[str] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
loader: "ConfigLoader",
|
||||||
|
path: str,
|
||||||
|
container: Optional[dict] = None,
|
||||||
|
key: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.loader = loader
|
||||||
|
self.path = path
|
||||||
|
self.container = container
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
def on_any_event(self, event: FileSystemEvent):
|
||||||
|
if not isinstance(event, FileModifiedEvent):
|
||||||
|
return
|
||||||
|
if event.is_directory:
|
||||||
|
return
|
||||||
|
if event.src_path != self.path:
|
||||||
|
return
|
||||||
|
if self.container and self.key:
|
||||||
|
with open(self.path, "r", encoding="utf8") as _file:
|
||||||
|
self.container[self.key] = _file.read()
|
||||||
|
else:
|
||||||
|
self.loader.log("info", "Updating from changed file", file=self.path)
|
||||||
|
self.loader.update_from_file(self.path, watch=False)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.observer = Observer()
|
||||||
|
self.observer.start()
|
||||||
self.__config = {}
|
self.__config = {}
|
||||||
base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), "../.."))
|
base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||||
for path in SEARCH_PATHS:
|
for path in SEARCH_PATHS:
|
||||||
|
@ -81,11 +126,11 @@ class ConfigLoader:
|
||||||
root[key] = self.update(root.get(key, {}), value)
|
root[key] = self.update(root.get(key, {}), value)
|
||||||
else:
|
else:
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = self.parse_uri(value)
|
value = self.parse_uri(value, root, key)
|
||||||
root[key] = value
|
root[key] = value
|
||||||
return root
|
return root
|
||||||
|
|
||||||
def parse_uri(self, value: str) -> str:
|
def parse_uri(self, value: str, container: dict[str, Any], key: Optional[str] = None, ) -> str:
|
||||||
"""Parse string values which start with a URI"""
|
"""Parse string values which start with a URI"""
|
||||||
url = urlparse(value)
|
url = urlparse(value)
|
||||||
if url.scheme == "env":
|
if url.scheme == "env":
|
||||||
|
@ -93,13 +138,23 @@ class ConfigLoader:
|
||||||
if url.scheme == "file":
|
if url.scheme == "file":
|
||||||
try:
|
try:
|
||||||
with open(url.path, "r", encoding="utf8") as _file:
|
with open(url.path, "r", encoding="utf8") as _file:
|
||||||
value = _file.read().strip()
|
value = _file.read()
|
||||||
|
if key:
|
||||||
|
self.observer.schedule(
|
||||||
|
ConfigLoader.FSObserver(
|
||||||
|
self,
|
||||||
|
url.path,
|
||||||
|
container,
|
||||||
|
key,
|
||||||
|
),
|
||||||
|
Path(url.path).parent,
|
||||||
|
)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
self.log("error", f"Failed to read config value from {url.path}: {exc}")
|
self.log("error", f"Failed to read config value from {url.path}: {exc}")
|
||||||
value = url.query
|
value = url.query
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def update_from_file(self, path: str):
|
def update_from_file(self, path: str, watch=True):
|
||||||
"""Update config from file contents"""
|
"""Update config from file contents"""
|
||||||
try:
|
try:
|
||||||
with open(path, encoding="utf8") as file:
|
with open(path, encoding="utf8") as file:
|
||||||
|
@ -107,6 +162,8 @@ class ConfigLoader:
|
||||||
self.update(self.__config, yaml.safe_load(file))
|
self.update(self.__config, yaml.safe_load(file))
|
||||||
self.log("debug", "Loaded config", file=path)
|
self.log("debug", "Loaded config", file=path)
|
||||||
self.loaded_file.append(path)
|
self.loaded_file.append(path)
|
||||||
|
if watch:
|
||||||
|
self.observer.schedule(ConfigLoader.FSObserver(self, path), Path(path).parent)
|
||||||
except yaml.YAMLError as exc:
|
except yaml.YAMLError as exc:
|
||||||
raise ImproperlyConfigured from exc
|
raise ImproperlyConfigured from exc
|
||||||
except PermissionError as exc:
|
except PermissionError as exc:
|
||||||
|
@ -181,13 +238,12 @@ class ConfigLoader:
|
||||||
if comp not in root:
|
if comp not in root:
|
||||||
root[comp] = {}
|
root[comp] = {}
|
||||||
root = root.get(comp, {})
|
root = root.get(comp, {})
|
||||||
root[path_parts[-1]] = value
|
self.parse_uri(value, root, path_parts[-1])
|
||||||
|
|
||||||
def y_bool(self, path: str, default=False) -> bool:
|
def y_bool(self, path: str, default=False) -> bool:
|
||||||
"""Wrapper for y that converts value into boolean"""
|
"""Wrapper for y that converts value into boolean"""
|
||||||
return str(self.y(path, default)).lower() == "true"
|
return str(self.y(path, default)).lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
CONFIG = ConfigLoader()
|
CONFIG = ConfigLoader()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -5,7 +5,7 @@ from tempfile import mkstemp
|
||||||
from django.conf import ImproperlyConfigured
|
from django.conf import ImproperlyConfigured
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from authentik.lib.config import ENV_PREFIX, ConfigLoader
|
from authentik.lib.config import CONFIG, ENV_PREFIX, ConfigLoader
|
||||||
|
|
||||||
|
|
||||||
class TestConfig(TestCase):
|
class TestConfig(TestCase):
|
||||||
|
@ -31,8 +31,8 @@ class TestConfig(TestCase):
|
||||||
"""Test URI parsing (environment)"""
|
"""Test URI parsing (environment)"""
|
||||||
config = ConfigLoader()
|
config = ConfigLoader()
|
||||||
environ["foo"] = "bar"
|
environ["foo"] = "bar"
|
||||||
self.assertEqual(config.parse_uri("env://foo"), "bar")
|
self.assertEqual(config.parse_uri("env://foo", {}), "bar")
|
||||||
self.assertEqual(config.parse_uri("env://foo?bar"), "bar")
|
self.assertEqual(config.parse_uri("env://foo?bar", {}), "bar")
|
||||||
|
|
||||||
def test_uri_file(self):
|
def test_uri_file(self):
|
||||||
"""Test URI parsing (file load)"""
|
"""Test URI parsing (file load)"""
|
||||||
|
@ -41,8 +41,8 @@ class TestConfig(TestCase):
|
||||||
write(file, "foo".encode())
|
write(file, "foo".encode())
|
||||||
_, file2_name = mkstemp()
|
_, file2_name = mkstemp()
|
||||||
chmod(file2_name, 0o000) # Remove all permissions so we can't read the file
|
chmod(file2_name, 0o000) # Remove all permissions so we can't read the file
|
||||||
self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo")
|
self.assertEqual(config.parse_uri(f"file://{file_name}", {}), "foo")
|
||||||
self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def")
|
self.assertEqual(config.parse_uri(f"file://{file2_name}?def", {}), "def")
|
||||||
unlink(file_name)
|
unlink(file_name)
|
||||||
unlink(file2_name)
|
unlink(file2_name)
|
||||||
|
|
||||||
|
@ -59,3 +59,13 @@ class TestConfig(TestCase):
|
||||||
config.update_from_file(file2_name)
|
config.update_from_file(file2_name)
|
||||||
unlink(file_name)
|
unlink(file_name)
|
||||||
unlink(file2_name)
|
unlink(file2_name)
|
||||||
|
|
||||||
|
def test_update(self):
|
||||||
|
"""Test change to file"""
|
||||||
|
file, file_name = mkstemp()
|
||||||
|
write(file, b"test")
|
||||||
|
CONFIG.y_set("test.file", f"file://{file_name}")
|
||||||
|
self.assertEqual(CONFIG.y("test.file"), "test")
|
||||||
|
write(file, "test2")
|
||||||
|
self.assertEqual(CONFIG.y("test.file"), "test2")
|
||||||
|
unlink(file_name)
|
||||||
|
|
Reference in New Issue