root: partial Live-updating config (#5959)

* stages/email: directly use email credentials from config

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use custom database backend that supports dynamic credentials

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add crude config reloader

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make method names for CONFIG clearer

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* replace config.set with environ

Not sure if this is the cleanest way, but it persists through a config reload

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* re-add set for @patch

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* even more crudeness

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* clean up some old stuff?

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* somewhat rewrite config loader to keep track of a source of an attribute so we can refresh it

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* cleanup old things

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix flow e2e

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-07-19 23:13:22 +02:00 committed by GitHub
parent fb4e4dc8db
commit 2f469d2709
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
44 changed files with 260 additions and 184 deletions

View file

@ -58,7 +58,7 @@ def clear_update_notifications():
@prefill_task @prefill_task
def update_latest_version(self: MonitoredTask): def update_latest_version(self: MonitoredTask):
"""Update latest version info""" """Update latest version info"""
if CONFIG.y_bool("disable_update_check"): if CONFIG.get_bool("disable_update_check"):
cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."])) self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."]))
return return

View file

@ -70,7 +70,7 @@ class ConfigView(APIView):
caps.append(Capabilities.CAN_SAVE_MEDIA) caps.append(Capabilities.CAN_SAVE_MEDIA)
if GEOIP_READER.enabled: if GEOIP_READER.enabled:
caps.append(Capabilities.CAN_GEO_IP) caps.append(Capabilities.CAN_GEO_IP)
if CONFIG.y_bool("impersonation"): if CONFIG.get_bool("impersonation"):
caps.append(Capabilities.CAN_IMPERSONATE) caps.append(Capabilities.CAN_IMPERSONATE)
if settings.DEBUG: # pragma: no cover if settings.DEBUG: # pragma: no cover
caps.append(Capabilities.CAN_DEBUG) caps.append(Capabilities.CAN_DEBUG)
@ -86,17 +86,17 @@ class ConfigView(APIView):
return ConfigSerializer( return ConfigSerializer(
{ {
"error_reporting": { "error_reporting": {
"enabled": CONFIG.y("error_reporting.enabled"), "enabled": CONFIG.get("error_reporting.enabled"),
"sentry_dsn": CONFIG.y("error_reporting.sentry_dsn"), "sentry_dsn": CONFIG.get("error_reporting.sentry_dsn"),
"environment": CONFIG.y("error_reporting.environment"), "environment": CONFIG.get("error_reporting.environment"),
"send_pii": CONFIG.y("error_reporting.send_pii"), "send_pii": CONFIG.get("error_reporting.send_pii"),
"traces_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.4)), "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
}, },
"capabilities": self.get_capabilities(), "capabilities": self.get_capabilities(),
"cache_timeout": int(CONFIG.y("redis.cache_timeout")), "cache_timeout": int(CONFIG.get("redis.cache_timeout")),
"cache_timeout_flows": int(CONFIG.y("redis.cache_timeout_flows")), "cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
"cache_timeout_policies": int(CONFIG.y("redis.cache_timeout_policies")), "cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
"cache_timeout_reputation": int(CONFIG.y("redis.cache_timeout_reputation")), "cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
} }
) )

View file

@ -30,7 +30,7 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
return return
blueprint_file.seek(0) blueprint_file.seek(0)
instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first() instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first()
rel_path = path.relative_to(Path(CONFIG.y("blueprints_dir"))) rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir")))
meta = None meta = None
if metadata: if metadata:
meta = from_dict(BlueprintMetadata, metadata) meta = from_dict(BlueprintMetadata, metadata)
@ -55,7 +55,7 @@ def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
Flow = apps.get_model("authentik_flows", "Flow") Flow = apps.get_model("authentik_flows", "Flow")
db_alias = schema_editor.connection.alias db_alias = schema_editor.connection.alias
for file in glob(f"{CONFIG.y('blueprints_dir')}/**/*.yaml", recursive=True): for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True):
check_blueprint_v1_file(BlueprintInstance, Path(file)) check_blueprint_v1_file(BlueprintInstance, Path(file))
for blueprint in BlueprintInstance.objects.using(db_alias).all(): for blueprint in BlueprintInstance.objects.using(db_alias).all():

View file

@ -82,7 +82,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
def retrieve_file(self) -> str: def retrieve_file(self) -> str:
"""Get blueprint from path""" """Get blueprint from path"""
try: try:
base = Path(CONFIG.y("blueprints_dir")) base = Path(CONFIG.get("blueprints_dir"))
full_path = base.joinpath(Path(self.path)).resolve() full_path = base.joinpath(Path(self.path)).resolve()
if not str(full_path).startswith(str(base.resolve())): if not str(full_path).startswith(str(base.resolve())):
raise BlueprintRetrievalFailed("Invalid blueprint path") raise BlueprintRetrievalFailed("Invalid blueprint path")

View file

@ -62,7 +62,7 @@ def start_blueprint_watcher():
if _file_watcher_started: if _file_watcher_started:
return return
observer = Observer() observer = Observer()
observer.schedule(BlueprintEventHandler(), CONFIG.y("blueprints_dir"), recursive=True) observer.schedule(BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True)
observer.start() observer.start()
_file_watcher_started = True _file_watcher_started = True
@ -80,7 +80,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
blueprints_discovery.delay() blueprints_discovery.delay()
if isinstance(event, FileModifiedEvent): if isinstance(event, FileModifiedEvent):
path = Path(event.src_path) path = Path(event.src_path)
root = Path(CONFIG.y("blueprints_dir")).absolute() root = Path(CONFIG.get("blueprints_dir")).absolute()
rel_path = str(path.relative_to(root)) rel_path = str(path.relative_to(root))
for instance in BlueprintInstance.objects.filter(path=rel_path): for instance in BlueprintInstance.objects.filter(path=rel_path):
LOGGER.debug("modified blueprint file, starting apply", instance=instance) LOGGER.debug("modified blueprint file, starting apply", instance=instance)
@ -101,7 +101,7 @@ def blueprints_find_dict():
def blueprints_find(): def blueprints_find():
"""Find blueprints and return valid ones""" """Find blueprints and return valid ones"""
blueprints = [] blueprints = []
root = Path(CONFIG.y("blueprints_dir")) root = Path(CONFIG.get("blueprints_dir"))
for path in root.rglob("**/*.yaml"): for path in root.rglob("**/*.yaml"):
# Check if any part in the path starts with a dot and assume a hidden file # Check if any part in the path starts with a dot and assume a hidden file
if any(part for part in path.parts if part.startswith(".")): if any(part for part in path.parts if part.startswith(".")):

View file

@ -596,7 +596,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
@action(detail=True, methods=["POST"]) @action(detail=True, methods=["POST"])
def impersonate(self, request: Request, pk: int) -> Response: def impersonate(self, request: Request, pk: int) -> Response:
"""Impersonate a user""" """Impersonate a user"""
if not CONFIG.y_bool("impersonation"): if not CONFIG.get_bool("impersonation"):
LOGGER.debug("User attempted to impersonate", user=request.user) LOGGER.debug("User attempted to impersonate", user=request.user)
return Response(status=401) return Response(status=401)
if not request.user.has_perm("impersonate"): if not request.user.has_perm("impersonate"):

View file

@ -18,7 +18,7 @@ class Command(BaseCommand):
def handle(self, **options): def handle(self, **options):
close_old_connections() close_old_connections()
if CONFIG.y_bool("remote_debug"): if CONFIG.get_bool("remote_debug"):
import debugpy import debugpy
debugpy.listen(("0.0.0.0", 6900)) # nosec debugpy.listen(("0.0.0.0", 6900)) # nosec

View file

@ -60,7 +60,7 @@ def default_token_key():
"""Default token key""" """Default token key"""
# We use generate_id since the chars in the key should be easy # We use generate_id since the chars in the key should be easy
# to use in Emails (for verification) and URLs (for recovery) # to use in Emails (for verification) and URLs (for recovery)
return generate_id(int(CONFIG.y("default_token_length"))) return generate_id(int(CONFIG.get("default_token_length")))
class UserTypes(models.TextChoices): class UserTypes(models.TextChoices):

View file

@ -46,7 +46,7 @@ def certificate_discovery(self: MonitoredTask):
certs = {} certs = {}
private_keys = {} private_keys = {}
discovered = 0 discovered = 0
for file in glob(CONFIG.y("cert_discovery_dir") + "/**", recursive=True): for file in glob(CONFIG.get("cert_discovery_dir") + "/**", recursive=True):
path = Path(file) path = Path(file)
if not path.exists(): if not path.exists():
continue continue

View file

@ -33,7 +33,7 @@ class GeoIPReader:
def __open(self): def __open(self):
"""Get GeoIP Reader, if configured, otherwise none""" """Get GeoIP Reader, if configured, otherwise none"""
path = CONFIG.y("geoip") path = CONFIG.get("geoip")
if path == "" or not path: if path == "" or not path:
return return
try: try:
@ -46,7 +46,7 @@ class GeoIPReader:
def __check_expired(self): def __check_expired(self):
"""Check if the modification date of the GeoIP database has """Check if the modification date of the GeoIP database has
changed, and reload it if so""" changed, and reload it if so"""
path = CONFIG.y("geoip") path = CONFIG.get("geoip")
try: try:
mtime = stat(path).st_mtime mtime = stat(path).st_mtime
diff = self.__last_mtime < mtime diff = self.__last_mtime < mtime

View file

@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
# was restored. # was restored.
PLAN_CONTEXT_IS_RESTORED = "is_restored" PLAN_CONTEXT_IS_RESTORED = "is_restored"
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_flows")) CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
CACHE_PREFIX = "goauthentik.io/flows/planner/" CACHE_PREFIX = "goauthentik.io/flows/planner/"

View file

@ -18,7 +18,6 @@ from authentik.flows.planner import FlowPlan, FlowPlanner
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
from authentik.flows.tests import FlowTestCase from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
from authentik.lib.config import CONFIG
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding from authentik.policies.models import PolicyBinding
@ -85,7 +84,6 @@ class TestFlowExecutor(FlowTestCase):
FlowDesignation.AUTHENTICATION, FlowDesignation.AUTHENTICATION,
) )
CONFIG.update_from_dict({"domain": "testserver"})
response = self.client.get( response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
) )
@ -111,7 +109,6 @@ class TestFlowExecutor(FlowTestCase):
denied_action=FlowDeniedAction.CONTINUE, denied_action=FlowDeniedAction.CONTINUE,
) )
CONFIG.update_from_dict({"domain": "testserver"})
response = self.client.get( response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
) )
@ -128,7 +125,6 @@ class TestFlowExecutor(FlowTestCase):
FlowDesignation.AUTHENTICATION, FlowDesignation.AUTHENTICATION,
) )
CONFIG.update_from_dict({"domain": "testserver"})
dest = "/unique-string" dest = "/unique-string"
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}") response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}")
@ -145,7 +141,6 @@ class TestFlowExecutor(FlowTestCase):
FlowDesignation.AUTHENTICATION, FlowDesignation.AUTHENTICATION,
) )
CONFIG.update_from_dict({"domain": "testserver"})
response = self.client.get( response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
) )

View file

@ -175,7 +175,7 @@ def get_avatar(user: "User") -> str:
"initials": avatar_mode_generated, "initials": avatar_mode_generated,
"gravatar": avatar_mode_gravatar, "gravatar": avatar_mode_gravatar,
} }
modes: str = CONFIG.y("avatars", "none") modes: str = CONFIG.get("avatars", "none")
for mode in modes.split(","): for mode in modes.split(","):
avatar = None avatar = None
if mode in mode_map: if mode in mode_map:

View file

@ -2,13 +2,15 @@
import os import os
from collections.abc import Mapping from collections.abc import Mapping
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from glob import glob from glob import glob
from json import dumps, loads from json import JSONEncoder, dumps, loads
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from pathlib import Path from pathlib import Path
from sys import argv, stderr from sys import argv, stderr
from time import time from time import time
from typing import Any from typing import Any, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import yaml import yaml
@ -32,15 +34,44 @@ def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
return root return root
@dataclass
class Attr:
"""Single configuration attribute"""
class Source(Enum):
"""Sources a configuration attribute can come from, determines what should be done with
Attr.source (and if it's set at all)"""
UNSPECIFIED = "unspecified"
ENV = "env"
CONFIG_FILE = "config_file"
URI = "uri"
value: Any
source_type: Source = field(default=Source.UNSPECIFIED)
# depending on source_type, might contain the environment variable or the path
# to the config file containing this change or the file containing this value
source: Optional[str] = field(default=None)
class AttrEncoder(JSONEncoder):
"""JSON encoder that can deal with `Attr` classes"""
def default(self, o: Any) -> Any:
if isinstance(o, Attr):
return o.value
return super().default(o)
class ConfigLoader: class ConfigLoader:
"""Search through SEARCH_PATHS and load configuration. Environment variables starting with """Search through SEARCH_PATHS and load configuration. Environment variables starting with
`ENV_PREFIX` are also applied. `ENV_PREFIX` are also applied.
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host""" A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
loaded_file = [] def __init__(self, **kwargs):
def __init__(self):
super().__init__() super().__init__()
self.__config = {} self.__config = {}
base_dir = Path(__file__).parent.joinpath(Path("../..")).resolve() base_dir = Path(__file__).parent.joinpath(Path("../..")).resolve()
@ -65,6 +96,7 @@ class ConfigLoader:
# Update config with env file # Update config with env file
self.update_from_file(env_file) self.update_from_file(env_file)
self.update_from_env() self.update_from_env()
self.update(self.__config, kwargs)
def log(self, level: str, message: str, **kwargs): def log(self, level: str, message: str, **kwargs):
"""Custom Log method, we want to ensure ConfigLoader always logs JSON even when """Custom Log method, we want to ensure ConfigLoader always logs JSON even when
@ -86,22 +118,32 @@ class ConfigLoader:
else: else:
if isinstance(value, str): if isinstance(value, str):
value = self.parse_uri(value) value = self.parse_uri(value)
elif not isinstance(value, Attr):
value = Attr(value)
root[key] = value root[key] = value
return root return root
def parse_uri(self, value: str) -> str: def refresh(self, key: str):
"""Update a single value"""
attr: Attr = get_path_from_dict(self.raw, key)
if attr.source_type != Attr.Source.URI:
return
attr.value = self.parse_uri(attr.source).value
def parse_uri(self, value: str) -> Attr:
"""Parse string values which start with a URI""" """Parse string values which start with a URI"""
url = urlparse(value) url = urlparse(value)
parsed_value = value
if url.scheme == "env": if url.scheme == "env":
value = os.getenv(url.netloc, url.query) parsed_value = os.getenv(url.netloc, url.query)
if url.scheme == "file": if url.scheme == "file":
try: try:
with open(url.path, "r", encoding="utf8") as _file: with open(url.path, "r", encoding="utf8") as _file:
value = _file.read().strip() parsed_value = _file.read().strip()
except OSError as exc: except OSError as exc:
self.log("error", f"Failed to read config value from {url.path}: {exc}") self.log("error", f"Failed to read config value from {url.path}: {exc}")
value = url.query parsed_value = url.query
return value return Attr(parsed_value, Attr.Source.URI, value)
def update_from_file(self, path: Path): def update_from_file(self, path: Path):
"""Update config from file contents""" """Update config from file contents"""
@ -110,7 +152,6 @@ class ConfigLoader:
try: try:
self.update(self.__config, yaml.safe_load(file)) self.update(self.__config, yaml.safe_load(file))
self.log("debug", "Loaded config", file=str(path)) self.log("debug", "Loaded config", file=str(path))
self.loaded_file.append(path)
except yaml.YAMLError as exc: except yaml.YAMLError as exc:
raise ImproperlyConfigured from exc raise ImproperlyConfigured from exc
except PermissionError as exc: except PermissionError as exc:
@ -121,10 +162,6 @@ class ConfigLoader:
error=str(exc), error=str(exc),
) )
def update_from_dict(self, update: dict):
"""Update config from dict"""
self.__config.update(update)
def update_from_env(self): def update_from_env(self):
"""Check environment variables""" """Check environment variables"""
outer = {} outer = {}
@ -145,7 +182,7 @@ class ConfigLoader:
value = loads(value) value = loads(value)
except JSONDecodeError: except JSONDecodeError:
pass pass
current_obj[dot_parts[-1]] = value current_obj[dot_parts[-1]] = Attr(value, Attr.Source.ENV, key)
idx += 1 idx += 1
if idx > 0: if idx > 0:
self.log("debug", "Loaded environment variables", count=idx) self.log("debug", "Loaded environment variables", count=idx)
@ -154,28 +191,32 @@ class ConfigLoader:
@contextmanager @contextmanager
def patch(self, path: str, value: Any): def patch(self, path: str, value: Any):
"""Context manager for unittests to patch a value""" """Context manager for unittests to patch a value"""
original_value = self.y(path) original_value = self.get(path)
self.y_set(path, value) self.set(path, value)
try: try:
yield yield
finally: finally:
self.y_set(path, original_value) self.set(path, original_value)
@property @property
def raw(self) -> dict: def raw(self) -> dict:
"""Get raw config dictionary""" """Get raw config dictionary"""
return self.__config return self.__config
# pylint: disable=invalid-name def get(self, path: str, default=None, sep=".") -> Any:
def y(self, path: str, default=None, sep=".") -> Any:
"""Access attribute by using yaml path""" """Access attribute by using yaml path"""
# Walk sub_dicts before parsing path # Walk sub_dicts before parsing path
root = self.raw root = self.raw
# Walk each component of the path # Walk each component of the path
return get_path_from_dict(root, path, sep=sep, default=default) attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
return attr.value
def y_set(self, path: str, value: Any, sep="."): def get_bool(self, path: str, default=False) -> bool:
"""Set value using same syntax as y()""" """Wrapper for get that converts value into boolean"""
return str(self.get(path, default)).lower() == "true"
def set(self, path: str, value: Any, sep="."):
"""Set value using same syntax as get()"""
# Walk sub_dicts before parsing path # Walk sub_dicts before parsing path
root = self.raw root = self.raw
# Walk each component of the path # Walk each component of the path
@ -184,17 +225,14 @@ class ConfigLoader:
if comp not in root: if comp not in root:
root[comp] = {} root[comp] = {}
root = root.get(comp, {}) root = root.get(comp, {})
root[path_parts[-1]] = value root[path_parts[-1]] = Attr(value)
def y_bool(self, path: str, default=False) -> bool:
"""Wrapper for y that converts value into boolean"""
return str(self.y(path, default)).lower() == "true"
CONFIG = ConfigLoader() CONFIG = ConfigLoader()
if __name__ == "__main__": if __name__ == "__main__":
if len(argv) < 2: if len(argv) < 2:
print(dumps(CONFIG.raw, indent=4)) print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder))
else: else:
print(CONFIG.y(argv[1])) print(CONFIG.get(argv[1]))

View file

@ -51,18 +51,18 @@ class SentryTransport(HttpTransport):
def sentry_init(**sentry_init_kwargs): def sentry_init(**sentry_init_kwargs):
"""Configure sentry SDK""" """Configure sentry SDK"""
sentry_env = CONFIG.y("error_reporting.environment", "customer") sentry_env = CONFIG.get("error_reporting.environment", "customer")
kwargs = { kwargs = {
"environment": sentry_env, "environment": sentry_env,
"send_default_pii": CONFIG.y_bool("error_reporting.send_pii", False), "send_default_pii": CONFIG.get_bool("error_reporting.send_pii", False),
"_experiments": { "_experiments": {
"profiles_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.1)), "profiles_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.1)),
}, },
} }
kwargs.update(**sentry_init_kwargs) kwargs.update(**sentry_init_kwargs)
# pylint: disable=abstract-class-instantiated # pylint: disable=abstract-class-instantiated
sentry_sdk_init( sentry_sdk_init(
dsn=CONFIG.y("error_reporting.sentry_dsn"), dsn=CONFIG.get("error_reporting.sentry_dsn"),
integrations=[ integrations=[
ArgvIntegration(), ArgvIntegration(),
StdlibIntegration(), StdlibIntegration(),
@ -92,7 +92,7 @@ def traces_sampler(sampling_context: dict) -> float:
return 0 return 0
if _type == "websocket": if _type == "websocket":
return 0 return 0
return float(CONFIG.y("error_reporting.sample_rate", 0.1)) return float(CONFIG.get("error_reporting.sample_rate", 0.1))
def before_send(event: dict, hint: dict) -> Optional[dict]: def before_send(event: dict, hint: dict) -> Optional[dict]:

View file

@ -16,23 +16,23 @@ class TestConfig(TestCase):
config = ConfigLoader() config = ConfigLoader()
environ[ENV_PREFIX + "_test__test"] = "bar" environ[ENV_PREFIX + "_test__test"] = "bar"
config.update_from_env() config.update_from_env()
self.assertEqual(config.y("test.test"), "bar") self.assertEqual(config.get("test.test"), "bar")
def test_patch(self): def test_patch(self):
"""Test patch decorator""" """Test patch decorator"""
config = ConfigLoader() config = ConfigLoader()
config.y_set("foo.bar", "bar") config.set("foo.bar", "bar")
self.assertEqual(config.y("foo.bar"), "bar") self.assertEqual(config.get("foo.bar"), "bar")
with config.patch("foo.bar", "baz"): with config.patch("foo.bar", "baz"):
self.assertEqual(config.y("foo.bar"), "baz") self.assertEqual(config.get("foo.bar"), "baz")
self.assertEqual(config.y("foo.bar"), "bar") self.assertEqual(config.get("foo.bar"), "bar")
def test_uri_env(self): def test_uri_env(self):
"""Test URI parsing (environment)""" """Test URI parsing (environment)"""
config = ConfigLoader() config = ConfigLoader()
environ["foo"] = "bar" environ["foo"] = "bar"
self.assertEqual(config.parse_uri("env://foo"), "bar") self.assertEqual(config.parse_uri("env://foo").value, "bar")
self.assertEqual(config.parse_uri("env://foo?bar"), "bar") self.assertEqual(config.parse_uri("env://foo?bar").value, "bar")
def test_uri_file(self): def test_uri_file(self):
"""Test URI parsing (file load)""" """Test URI parsing (file load)"""
@ -41,11 +41,25 @@ class TestConfig(TestCase):
write(file, "foo".encode()) write(file, "foo".encode())
_, file2_name = mkstemp() _, file2_name = mkstemp()
chmod(file2_name, 0o000) # Remove all permissions so we can't read the file chmod(file2_name, 0o000) # Remove all permissions so we can't read the file
self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo") self.assertEqual(config.parse_uri(f"file://{file_name}").value, "foo")
self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def") self.assertEqual(config.parse_uri(f"file://{file2_name}?def").value, "def")
unlink(file_name) unlink(file_name)
unlink(file2_name) unlink(file2_name)
def test_uri_file_update(self):
"""Test URI parsing (file load and update)"""
file, file_name = mkstemp()
write(file, "foo".encode())
config = ConfigLoader(file_test=f"file://{file_name}")
self.assertEqual(config.get("file_test"), "foo")
# Update config file
write(file, "bar".encode())
config.refresh("file_test")
self.assertEqual(config.get("file_test"), "foobar")
unlink(file_name)
def test_file_update(self): def test_file_update(self):
"""Test update_from_file""" """Test update_from_file"""
config = ConfigLoader() config = ConfigLoader()

View file

@ -50,7 +50,7 @@ def get_env() -> str:
"""Get environment in which authentik is currently running""" """Get environment in which authentik is currently running"""
if "CI" in os.environ: if "CI" in os.environ:
return "ci" return "ci"
if CONFIG.y_bool("debug"): if CONFIG.get_bool("debug"):
return "dev" return "dev"
if SERVICE_HOST_ENV_NAME in os.environ: if SERVICE_HOST_ENV_NAME in os.environ:
return "kubernetes" return "kubernetes"

View file

@ -97,7 +97,7 @@ class BaseController:
if self.outpost.config.container_image is not None: if self.outpost.config.container_image is not None:
return self.outpost.config.container_image return self.outpost.config.container_image
image_name_template: str = CONFIG.y("outposts.container_image_base") image_name_template: str = CONFIG.get("outposts.container_image_base")
return image_name_template % { return image_name_template % {
"type": self.outpost.type, "type": self.outpost.type,
"version": __version__, "version": __version__,

View file

@ -58,7 +58,7 @@ class OutpostConfig:
authentik_host_insecure: bool = False authentik_host_insecure: bool = False
authentik_host_browser: str = "" authentik_host_browser: str = ""
log_level: str = CONFIG.y("log_level") log_level: str = CONFIG.get("log_level")
object_naming_template: str = field(default="ak-outpost-%(name)s") object_naming_template: str = field(default="ak-outpost-%(name)s")
container_image: Optional[str] = field(default=None) container_image: Optional[str] = field(default=None)

View file

@ -256,7 +256,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
def outpost_connection_discovery(self: MonitoredTask): def outpost_connection_discovery(self: MonitoredTask):
"""Checks the local environment and create Service connections.""" """Checks the local environment and create Service connections."""
status = TaskResult(TaskResultStatus.SUCCESSFUL) status = TaskResult(TaskResultStatus.SUCCESSFUL)
if not CONFIG.y_bool("outposts.discover"): if not CONFIG.get_bool("outposts.discover"):
status.messages.append("Outpost integration discovery is disabled") status.messages.append("Outpost integration discovery is disabled")
self.set_status(status) self.set_status(status)
return return

View file

@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
FORK_CTX = get_context("fork") FORK_CTX = get_context("fork")
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_policies")) CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
PROCESS_CLASS = FORK_CTX.Process PROCESS_CLASS = FORK_CTX.Process

View file

@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
from authentik.stages.identification.signals import identification_failed from authentik.stages.identification.signals import identification_failed
LOGGER = get_logger() LOGGER = get_logger()
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_reputation")) CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
def update_score(request: HttpRequest, identifier: str, amount: int): def update_score(request: HttpRequest, identifier: str, amount: int):

View file

@ -46,7 +46,7 @@ class DeviceView(View):
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
throttle = AnonRateThrottle() throttle = AnonRateThrottle()
throttle.rate = CONFIG.y("throttle.providers.oauth2.device", "20/hour") throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate) throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
if not throttle.allow_request(request, self): if not throttle.allow_request(request, self):
return HttpResponse(status=429) return HttpResponse(status=429)

View file

15
authentik/root/db/base.py Normal file
View file

@ -0,0 +1,15 @@
"""authentik database backend"""
from django_prometheus.db.backends.postgresql.base import DatabaseWrapper as BaseDatabaseWrapper
from authentik.lib.config import CONFIG
class DatabaseWrapper(BaseDatabaseWrapper):
"""database backend which supports rotating credentials"""
def get_connection_params(self):
CONFIG.refresh("postgresql.password")
conn_params = super().get_connection_params()
conn_params["user"] = CONFIG.get("postgresql.user")
conn_params["password"] = CONFIG.get("postgresql.password")
return conn_params

View file

@ -26,15 +26,15 @@ def get_install_id_raw():
"""Get install_id without django loaded, this is required for the startup when we get """Get install_id without django loaded, this is required for the startup when we get
the install_id but django isn't loaded yet and we can't use the function above.""" the install_id but django isn't loaded yet and we can't use the function above."""
conn = connect( conn = connect(
dbname=CONFIG.y("postgresql.name"), dbname=CONFIG.get("postgresql.name"),
user=CONFIG.y("postgresql.user"), user=CONFIG.get("postgresql.user"),
password=CONFIG.y("postgresql.password"), password=CONFIG.get("postgresql.password"),
host=CONFIG.y("postgresql.host"), host=CONFIG.get("postgresql.host"),
port=int(CONFIG.y("postgresql.port")), port=int(CONFIG.get("postgresql.port")),
sslmode=CONFIG.y("postgresql.sslmode"), sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.y("postgresql.sslrootcert"), sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.y("postgresql.sslcert"), sslcert=CONFIG.get("postgresql.sslcert"),
sslkey=CONFIG.y("postgresql.sslkey"), sslkey=CONFIG.get("postgresql.sslkey"),
) )
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;") cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;")

View file

@ -24,8 +24,8 @@ BASE_DIR = Path(__file__).absolute().parent.parent.parent
STATICFILES_DIRS = [BASE_DIR / Path("web")] STATICFILES_DIRS = [BASE_DIR / Path("web")]
MEDIA_ROOT = BASE_DIR / Path("media") MEDIA_ROOT = BASE_DIR / Path("media")
DEBUG = CONFIG.y_bool("debug") DEBUG = CONFIG.get_bool("debug")
SECRET_KEY = CONFIG.y("secret_key") SECRET_KEY = CONFIG.get("secret_key")
INTERNAL_IPS = ["127.0.0.1"] INTERNAL_IPS = ["127.0.0.1"]
ALLOWED_HOSTS = ["*"] ALLOWED_HOSTS = ["*"]
@ -40,7 +40,7 @@ CSRF_COOKIE_NAME = "authentik_csrf"
CSRF_HEADER_NAME = "HTTP_X_AUTHENTIK_CSRF" CSRF_HEADER_NAME = "HTTP_X_AUTHENTIK_CSRF"
LANGUAGE_COOKIE_NAME = "authentik_language" LANGUAGE_COOKIE_NAME = "authentik_language"
SESSION_COOKIE_NAME = "authentik_session" SESSION_COOKIE_NAME = "authentik_session"
SESSION_COOKIE_DOMAIN = CONFIG.y("cookie_domain", None) SESSION_COOKIE_DOMAIN = CONFIG.get("cookie_domain", None)
AUTHENTICATION_BACKENDS = [ AUTHENTICATION_BACKENDS = [
"django.contrib.auth.backends.ModelBackend", "django.contrib.auth.backends.ModelBackend",
@ -179,26 +179,26 @@ REST_FRAMEWORK = {
"TEST_REQUEST_DEFAULT_FORMAT": "json", "TEST_REQUEST_DEFAULT_FORMAT": "json",
"DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.AnonRateThrottle"], "DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.AnonRateThrottle"],
"DEFAULT_THROTTLE_RATES": { "DEFAULT_THROTTLE_RATES": {
"anon": CONFIG.y("throttle.default"), "anon": CONFIG.get("throttle.default"),
}, },
} }
_redis_protocol_prefix = "redis://" _redis_protocol_prefix = "redis://"
_redis_celery_tls_requirements = "" _redis_celery_tls_requirements = ""
if CONFIG.y_bool("redis.tls", False): if CONFIG.get_bool("redis.tls", False):
_redis_protocol_prefix = "rediss://" _redis_protocol_prefix = "rediss://"
_redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.y('redis.tls_reqs')}" _redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}"
_redis_url = ( _redis_url = (
f"{_redis_protocol_prefix}:" f"{_redis_protocol_prefix}:"
f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:" f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{int(CONFIG.y('redis.port'))}" f"{int(CONFIG.get('redis.port'))}"
) )
CACHES = { CACHES = {
"default": { "default": {
"BACKEND": "django_redis.cache.RedisCache", "BACKEND": "django_redis.cache.RedisCache",
"LOCATION": f"{_redis_url}/{CONFIG.y('redis.db')}", "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
"TIMEOUT": int(CONFIG.y("redis.cache_timeout", 300)), "TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
"KEY_PREFIX": "authentik_cache", "KEY_PREFIX": "authentik_cache",
} }
@ -238,7 +238,7 @@ ROOT_URLCONF = "authentik.root.urls"
TEMPLATES = [ TEMPLATES = [
{ {
"BACKEND": "django.template.backends.django.DjangoTemplates", "BACKEND": "django.template.backends.django.DjangoTemplates",
"DIRS": [CONFIG.y("email.template_dir")], "DIRS": [CONFIG.get("email.template_dir")],
"APP_DIRS": True, "APP_DIRS": True,
"OPTIONS": { "OPTIONS": {
"context_processors": [ "context_processors": [
@ -258,7 +258,7 @@ CHANNEL_LAYERS = {
"default": { "default": {
"BACKEND": "channels_redis.core.RedisChannelLayer", "BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": { "CONFIG": {
"hosts": [f"{_redis_url}/{CONFIG.y('redis.db')}"], "hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
"prefix": "authentik_channels", "prefix": "authentik_channels",
}, },
}, },
@ -270,34 +270,37 @@ CHANNEL_LAYERS = {
DATABASES = { DATABASES = {
"default": { "default": {
"ENGINE": "django_prometheus.db.backends.postgresql", "ENGINE": "authentik.root.db",
"HOST": CONFIG.y("postgresql.host"), "HOST": CONFIG.get("postgresql.host"),
"NAME": CONFIG.y("postgresql.name"), "NAME": CONFIG.get("postgresql.name"),
"USER": CONFIG.y("postgresql.user"), "USER": CONFIG.get("postgresql.user"),
"PASSWORD": CONFIG.y("postgresql.password"), "PASSWORD": CONFIG.get("postgresql.password"),
"PORT": int(CONFIG.y("postgresql.port")), "PORT": int(CONFIG.get("postgresql.port")),
"SSLMODE": CONFIG.y("postgresql.sslmode"), "SSLMODE": CONFIG.get("postgresql.sslmode"),
"SSLROOTCERT": CONFIG.y("postgresql.sslrootcert"), "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
"SSLCERT": CONFIG.y("postgresql.sslcert"), "SSLCERT": CONFIG.get("postgresql.sslcert"),
"SSLKEY": CONFIG.y("postgresql.sslkey"), "SSLKEY": CONFIG.get("postgresql.sslkey"),
} }
} }
if CONFIG.y_bool("postgresql.use_pgbouncer", False): if CONFIG.get_bool("postgresql.use_pgbouncer", False):
# https://docs.djangoproject.com/en/4.0/ref/databases/#transaction-pooling-server-side-cursors # https://docs.djangoproject.com/en/4.0/ref/databases/#transaction-pooling-server-side-cursors
DATABASES["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True DATABASES["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True
# https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections # https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections
DATABASES["default"]["CONN_MAX_AGE"] = None # persistent DATABASES["default"]["CONN_MAX_AGE"] = None # persistent
# Email # Email
EMAIL_HOST = CONFIG.y("email.host") # These values should never actually be used, emails are only sent from email stages, which
EMAIL_PORT = int(CONFIG.y("email.port")) # loads the config directly from CONFIG
EMAIL_HOST_USER = CONFIG.y("email.username") # See authentik/stages/email/models.py, line 105
EMAIL_HOST_PASSWORD = CONFIG.y("email.password") EMAIL_HOST = CONFIG.get("email.host")
EMAIL_USE_TLS = CONFIG.y_bool("email.use_tls", False) EMAIL_PORT = int(CONFIG.get("email.port"))
EMAIL_USE_SSL = CONFIG.y_bool("email.use_ssl", False) EMAIL_HOST_USER = CONFIG.get("email.username")
EMAIL_TIMEOUT = int(CONFIG.y("email.timeout")) EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
DEFAULT_FROM_EMAIL = CONFIG.y("email.from") EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
SERVER_EMAIL = DEFAULT_FROM_EMAIL SERVER_EMAIL = DEFAULT_FROM_EMAIL
EMAIL_SUBJECT_PREFIX = "[authentik] " EMAIL_SUBJECT_PREFIX = "[authentik] "
@ -345,15 +348,15 @@ CELERY = {
}, },
"task_create_missing_queues": True, "task_create_missing_queues": True,
"task_default_queue": "authentik", "task_default_queue": "authentik",
"broker_url": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}", "broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
"result_backend": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}", "result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
} }
# Sentry integration # Sentry integration
env = get_env() env = get_env()
_ERROR_REPORTING = CONFIG.y_bool("error_reporting.enabled", False) _ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False)
if _ERROR_REPORTING: if _ERROR_REPORTING:
sentry_env = CONFIG.y("error_reporting.environment", "customer") sentry_env = CONFIG.get("error_reporting.environment", "customer")
sentry_init() sentry_init()
set_tag("authentik.uuid", sha512(str(SECRET_KEY).encode("ascii")).hexdigest()[:16]) set_tag("authentik.uuid", sha512(str(SECRET_KEY).encode("ascii")).hexdigest()[:16])
@ -367,7 +370,7 @@ MEDIA_URL = "/media/"
TEST = False TEST = False
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
# We can't check TEST here as its set later by the test runner # We can't check TEST here as its set later by the test runner
LOG_LEVEL = CONFIG.y("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG" LOG_LEVEL = CONFIG.get("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean # We could add a custom level to stdlib logging and structlog, but it's not easy or clean
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog # https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
# Additionally, the entire code uses debug as highest level so that would have to be re-written too # Additionally, the entire code uses debug as highest level so that would have to be re-written too

View file

@ -31,14 +31,14 @@ class PytestTestRunner: # pragma: no cover
settings.TEST = True settings.TEST = True
settings.CELERY["task_always_eager"] = True settings.CELERY["task_always_eager"] = True
CONFIG.y_set("avatars", "none") CONFIG.set("avatars", "none")
CONFIG.y_set("geoip", "tests/GeoLite2-City-Test.mmdb") CONFIG.set("geoip", "tests/GeoLite2-City-Test.mmdb")
CONFIG.y_set("blueprints_dir", "./blueprints") CONFIG.set("blueprints_dir", "./blueprints")
CONFIG.y_set( CONFIG.set(
"outposts.container_image_base", "outposts.container_image_base",
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}", f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
) )
CONFIG.y_set("error_reporting.sample_rate", 0) CONFIG.set("error_reporting.sample_rate", 0)
sentry_init( sentry_init(
environment="testing", environment="testing",
send_default_pii=True, send_default_pii=True,

View file

@ -136,7 +136,7 @@ class LDAPSource(Source):
chmod(private_key_file, 0o600) chmod(private_key_file, 0o600)
tls_kwargs["local_private_key_file"] = private_key_file tls_kwargs["local_private_key_file"] = private_key_file
tls_kwargs["local_certificate_file"] = certificate_file tls_kwargs["local_certificate_file"] = certificate_file
if ciphers := CONFIG.y("ldap.tls.ciphers", None): if ciphers := CONFIG.get("ldap.tls.ciphers", None):
tls_kwargs["ciphers"] = ciphers.strip() tls_kwargs["ciphers"] = ciphers.strip()
if self.sni: if self.sni:
tls_kwargs["sni"] = self.server_uri.split(",", maxsplit=1)[0].strip() tls_kwargs["sni"] = self.server_uri.split(",", maxsplit=1)[0].strip()

View file

@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
types_only=False, types_only=False,
get_operational_attributes=False, get_operational_attributes=False,
controls=None, controls=None,
paged_size=int(CONFIG.y("ldap.page_size", 50)), paged_size=int(CONFIG.get("ldap.page_size", 50)),
paged_criticality=False, paged_criticality=False,
): ):
"""Search in pages, returns each page""" """Search in pages, returns each page"""

View file

@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
@CELERY_APP.task( @CELERY_APP.task(
bind=True, bind=True,
base=MonitoredTask, base=MonitoredTask,
soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
) )
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source""" """Synchronization of an LDAP Source"""
self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours")) self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
if not source: if not source:
# Because the source couldn't be found, we don't have a UID # Because the source couldn't be found, we don't have a UID

View file

@ -13,6 +13,7 @@ from rest_framework.serializers import BaseSerializer
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.flows.models import Stage from authentik.flows.models import Stage
from authentik.lib.config import CONFIG
LOGGER = get_logger() LOGGER = get_logger()
@ -104,7 +105,16 @@ class EmailStage(Stage):
def backend(self) -> BaseEmailBackend: def backend(self) -> BaseEmailBackend:
"""Get fully configured Email Backend instance""" """Get fully configured Email Backend instance"""
if self.use_global_settings: if self.use_global_settings:
return self.backend_class() CONFIG.refresh("email.password")
return self.backend_class(
host=CONFIG.get("email.host"),
port=int(CONFIG.get("email.port")),
username=CONFIG.get("email.username"),
password=CONFIG.get("email.password"),
use_tls=CONFIG.get_bool("email.use_tls", False),
use_ssl=CONFIG.get_bool("email.use_ssl", False),
timeout=int(CONFIG.get("email.timeout")),
)
return self.backend_class( return self.backend_class(
host=self.host, host=self.host,
port=self.port, port=self.port,

View file

@ -13,6 +13,7 @@ from authentik.flows.models import FlowDesignation, FlowStageBinding, FlowToken
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.tests import FlowTestCase from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import QS_KEY_TOKEN, SESSION_KEY_PLAN from authentik.flows.views.executor import QS_KEY_TOKEN, SESSION_KEY_PLAN
from authentik.lib.config import CONFIG
from authentik.stages.email.models import EmailStage from authentik.stages.email.models import EmailStage
from authentik.stages.email.stage import PLAN_CONTEXT_EMAIL_OVERRIDE from authentik.stages.email.stage import PLAN_CONTEXT_EMAIL_OVERRIDE
@ -120,7 +121,7 @@ class TestEmailStage(FlowTestCase):
def test_use_global_settings(self): def test_use_global_settings(self):
"""Test use_global_settings""" """Test use_global_settings"""
host = "some-unique-string" host = "some-unique-string"
with self.settings(EMAIL_HOST=host): with CONFIG.patch("email.host", host):
self.assertEqual(EmailStage(use_global_settings=True).backend.host, host) self.assertEqual(EmailStage(use_global_settings=True).backend.host, host)
def test_token(self): def test_token(self):

View file

@ -78,7 +78,7 @@ class CurrentTenantSerializer(PassiveSerializer):
ui_footer_links = ListField( ui_footer_links = ListField(
child=FooterLinkSerializer(), child=FooterLinkSerializer(),
read_only=True, read_only=True,
default=CONFIG.y("footer_links", []), default=CONFIG.get("footer_links", []),
) )
ui_theme = ChoiceField( ui_theme = ChoiceField(
choices=Themes.choices, choices=Themes.choices,

View file

@ -24,7 +24,7 @@ class TestTenants(APITestCase):
"branding_favicon": "/static/dist/assets/icons/icon.png", "branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "authentik", "branding_title": "authentik",
"matched_domain": tenant.domain, "matched_domain": tenant.domain,
"ui_footer_links": CONFIG.y("footer_links"), "ui_footer_links": CONFIG.get("footer_links"),
"ui_theme": Themes.AUTOMATIC, "ui_theme": Themes.AUTOMATIC,
"default_locale": "", "default_locale": "",
}, },
@ -43,7 +43,7 @@ class TestTenants(APITestCase):
"branding_favicon": "/static/dist/assets/icons/icon.png", "branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "custom", "branding_title": "custom",
"matched_domain": "bar.baz", "matched_domain": "bar.baz",
"ui_footer_links": CONFIG.y("footer_links"), "ui_footer_links": CONFIG.get("footer_links"),
"ui_theme": Themes.AUTOMATIC, "ui_theme": Themes.AUTOMATIC,
"default_locale": "", "default_locale": "",
}, },
@ -59,7 +59,7 @@ class TestTenants(APITestCase):
"branding_favicon": "/static/dist/assets/icons/icon.png", "branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "authentik", "branding_title": "authentik",
"matched_domain": "fallback", "matched_domain": "fallback",
"ui_footer_links": CONFIG.y("footer_links"), "ui_footer_links": CONFIG.get("footer_links"),
"ui_theme": Themes.AUTOMATIC, "ui_theme": Themes.AUTOMATIC,
"default_locale": "", "default_locale": "",
}, },

View file

@ -36,7 +36,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
trace = span.to_traceparent() trace = span.to_traceparent()
return { return {
"tenant": tenant, "tenant": tenant,
"footer_links": CONFIG.y("footer_links"), "footer_links": CONFIG.get("footer_links"),
"sentry_trace": trace, "sentry_trace": trace,
"version": get_full_version(), "version": get_full_version(),
} }

View file

@ -94,21 +94,21 @@ entries:
prompt_data = request.context.get("prompt_data") prompt_data = request.context.get("prompt_data")
if not request.user.group_attributes(request.http_request).get( if not request.user.group_attributes(request.http_request).get(
USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.y_bool("default_user_change_email", True) USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.get_bool("default_user_change_email", True)
): ):
if prompt_data.get("email") != request.user.email: if prompt_data.get("email") != request.user.email:
ak_message("Not allowed to change email address.") ak_message("Not allowed to change email address.")
return False return False
if not request.user.group_attributes(request.http_request).get( if not request.user.group_attributes(request.http_request).get(
USER_ATTRIBUTE_CHANGE_NAME, CONFIG.y_bool("default_user_change_name", True) USER_ATTRIBUTE_CHANGE_NAME, CONFIG.get_bool("default_user_change_name", True)
): ):
if prompt_data.get("name") != request.user.name: if prompt_data.get("name") != request.user.name:
ak_message("Not allowed to change name.") ak_message("Not allowed to change name.")
return False return False
if not request.user.group_attributes(request.http_request).get( if not request.user.group_attributes(request.http_request).get(
USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.y_bool("default_user_change_username", True) USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.get_bool("default_user_change_username", True)
): ):
if prompt_data.get("username") != request.user.username: if prompt_data.get("username") != request.user.username:
ak_message("Not allowed to change username.") ak_message("Not allowed to change username.")

View file

@ -37,7 +37,7 @@ makedirs(prometheus_tmp_dir, exist_ok=True)
max_requests = 1000 max_requests = 1000
max_requests_jitter = 50 max_requests_jitter = 50
_debug = CONFIG.y_bool("DEBUG", False) _debug = CONFIG.get_bool("DEBUG", False)
logconfig_dict = { logconfig_dict = {
"version": 1, "version": 1,
@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ:
else: else:
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
workers = int(CONFIG.y("web.workers", default_workers)) workers = int(CONFIG.get("web.workers", default_workers))
threads = int(CONFIG.y("web.threads", 4)) threads = int(CONFIG.get("web.threads", 4))
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker): def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
@ -133,7 +133,7 @@ def pre_fork(server: "Arbiter", worker: DjangoUvicornWorker):
worker._worker_id = _next_worker_id(server) worker._worker_id = _next_worker_id(server)
if not CONFIG.y_bool("disable_startup_analytics", False): if not CONFIG.get_bool("disable_startup_analytics", False):
env = get_env() env = get_env()
should_send = env not in ["dev", "ci"] should_send = env not in ["dev", "ci"]
if should_send: if should_send:
@ -158,7 +158,7 @@ if not CONFIG.y_bool("disable_startup_analytics", False):
except Exception: # nosec except Exception: # nosec
pass pass
if CONFIG.y_bool("remote_debug"): if CONFIG.get_bool("remote_debug"):
import debugpy import debugpy
debugpy.listen(("0.0.0.0", 6800)) # nosec debugpy.listen(("0.0.0.0", 6800)) # nosec

View file

@ -52,15 +52,15 @@ def release_lock():
if __name__ == "__main__": if __name__ == "__main__":
conn = connect( conn = connect(
dbname=CONFIG.y("postgresql.name"), dbname=CONFIG.get("postgresql.name"),
user=CONFIG.y("postgresql.user"), user=CONFIG.get("postgresql.user"),
password=CONFIG.y("postgresql.password"), password=CONFIG.get("postgresql.password"),
host=CONFIG.y("postgresql.host"), host=CONFIG.get("postgresql.host"),
port=int(CONFIG.y("postgresql.port")), port=int(CONFIG.get("postgresql.port")),
sslmode=CONFIG.y("postgresql.sslmode"), sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.y("postgresql.sslrootcert"), sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.y("postgresql.sslcert"), sslcert=CONFIG.get("postgresql.sslcert"),
sslkey=CONFIG.y("postgresql.sslkey"), sslkey=CONFIG.get("postgresql.sslkey"),
) )
curr = conn.cursor() curr = conn.cursor()
try: try:

View file

@ -25,7 +25,7 @@ class Migration(BaseMigration):
# If we already have migrations in the database, assume we're upgrading an existing install # If we already have migrations in the database, assume we're upgrading an existing install
# and set the install id to the secret key # and set the install id to the secret key
self.cur.execute( self.cur.execute(
"INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.y("secret_key"),) "INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.get("secret_key"),)
) )
else: else:
# Otherwise assume a new install, generate an install ID based on a UUID # Otherwise assume a new install, generate an install ID based on a UUID

View file

@ -108,14 +108,14 @@ class Migration(BaseMigration):
self.con.commit() self.con.commit()
# We also need to clean the cache to make sure no pickeled objects still exist # We also need to clean the cache to make sure no pickeled objects still exist
for db in [ for db in [
CONFIG.y("redis.message_queue_db"), CONFIG.get("redis.message_queue_db"),
CONFIG.y("redis.cache_db"), CONFIG.get("redis.cache_db"),
CONFIG.y("redis.ws_db"), CONFIG.get("redis.ws_db"),
]: ]:
redis = Redis( redis = Redis(
host=CONFIG.y("redis.host"), host=CONFIG.get("redis.host"),
port=6379, port=6379,
db=db, db=db,
password=CONFIG.y("redis.password"), password=CONFIG.get("redis.password"),
) )
redis.flushall() redis.flushall()

View file

@ -14,7 +14,7 @@ from authentik.lib.config import CONFIG
CONFIG.log("info", "Starting authentik bootstrap") CONFIG.log("info", "Starting authentik bootstrap")
# Sanity check, ensure SECRET_KEY is set before we even check for database connectivity # Sanity check, ensure SECRET_KEY is set before we even check for database connectivity
if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0: if CONFIG.get("secret_key") is None or len(CONFIG.get("secret_key")) == 0:
CONFIG.log("info", "----------------------------------------------------------------------") CONFIG.log("info", "----------------------------------------------------------------------")
CONFIG.log("info", "Secret key missing, check https://goauthentik.io/docs/installation/.") CONFIG.log("info", "Secret key missing, check https://goauthentik.io/docs/installation/.")
CONFIG.log("info", "----------------------------------------------------------------------") CONFIG.log("info", "----------------------------------------------------------------------")
@ -24,15 +24,15 @@ if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0:
while True: while True:
try: try:
conn = connect( conn = connect(
dbname=CONFIG.y("postgresql.name"), dbname=CONFIG.get("postgresql.name"),
user=CONFIG.y("postgresql.user"), user=CONFIG.get("postgresql.user"),
password=CONFIG.y("postgresql.password"), password=CONFIG.get("postgresql.password"),
host=CONFIG.y("postgresql.host"), host=CONFIG.get("postgresql.host"),
port=int(CONFIG.y("postgresql.port")), port=int(CONFIG.get("postgresql.port")),
sslmode=CONFIG.y("postgresql.sslmode"), sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.y("postgresql.sslrootcert"), sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.y("postgresql.sslcert"), sslcert=CONFIG.get("postgresql.sslcert"),
sslkey=CONFIG.y("postgresql.sslkey"), sslkey=CONFIG.get("postgresql.sslkey"),
) )
conn.cursor() conn.cursor()
break break
@ -42,12 +42,12 @@ while True:
CONFIG.log("info", "PostgreSQL connection successful") CONFIG.log("info", "PostgreSQL connection successful")
REDIS_PROTOCOL_PREFIX = "redis://" REDIS_PROTOCOL_PREFIX = "redis://"
if CONFIG.y_bool("redis.tls", False): if CONFIG.get_bool("redis.tls", False):
REDIS_PROTOCOL_PREFIX = "rediss://" REDIS_PROTOCOL_PREFIX = "rediss://"
REDIS_URL = ( REDIS_URL = (
f"{REDIS_PROTOCOL_PREFIX}:" f"{REDIS_PROTOCOL_PREFIX}:"
f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:" f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{int(CONFIG.y('redis.port'))}/{CONFIG.y('redis.db')}" f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}"
) )
while True: while True:
try: try:

View file

@ -1,7 +1,6 @@
"""Test Enroll flow""" """Test Enroll flow"""
from time import sleep from time import sleep
from django.test import override_settings
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as ec from selenium.webdriver.support import expected_conditions as ec
from selenium.webdriver.support.wait import WebDriverWait from selenium.webdriver.support.wait import WebDriverWait
@ -9,6 +8,7 @@ from selenium.webdriver.support.wait import WebDriverWait
from authentik.blueprints.tests import apply_blueprint from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import User from authentik.core.models import User
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.lib.config import CONFIG
from authentik.stages.identification.models import IdentificationStage from authentik.stages.identification.models import IdentificationStage
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@ -56,7 +56,7 @@ class TestFlowsEnroll(SeleniumTestCase):
@apply_blueprint( @apply_blueprint(
"example/flows-enrollment-email-verification.yaml", "example/flows-enrollment-email-verification.yaml",
) )
@override_settings(EMAIL_PORT=1025) @CONFIG.patch("email.port", 1025)
def test_enroll_email(self): def test_enroll_email(self):
"""Test enroll with Email verification""" """Test enroll with Email verification"""
# Attach enrollment flow to identification stage # Attach enrollment flow to identification stage

View file

@ -1,7 +1,6 @@
"""Test recovery flow""" """Test recovery flow"""
from time import sleep from time import sleep
from django.test import override_settings
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as ec from selenium.webdriver.support import expected_conditions as ec
from selenium.webdriver.support.wait import WebDriverWait from selenium.webdriver.support.wait import WebDriverWait
@ -10,6 +9,7 @@ from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import User from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.lib.config import CONFIG
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.stages.identification.models import IdentificationStage from authentik.stages.identification.models import IdentificationStage
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@ -47,7 +47,7 @@ class TestFlowsRecovery(SeleniumTestCase):
@apply_blueprint( @apply_blueprint(
"example/flows-recovery-email-verification.yaml", "example/flows-recovery-email-verification.yaml",
) )
@override_settings(EMAIL_PORT=1025) @CONFIG.patch("email.port", 1025)
def test_recover_email(self): def test_recover_email(self):
"""Test recovery with Email verification""" """Test recovery with Email verification"""
# Attach recovery flow to identification stage # Attach recovery flow to identification stage