root: partial Live-updating config (#5959)
* stages/email: directly use email credentials from config Signed-off-by: Jens Langhammer <jens@goauthentik.io> * use custom database backend that supports dynamic credentials Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> * add crude config reloader Signed-off-by: Jens Langhammer <jens@goauthentik.io> * make method names for CONFIG clearer Signed-off-by: Jens Langhammer <jens@goauthentik.io> * replace config.set with environ Not sure if this is the cleanest way, but it persists through a config reload Signed-off-by: Jens Langhammer <jens@goauthentik.io> * re-add set for @patch Signed-off-by: Jens Langhammer <jens@goauthentik.io> * even more crudeness Signed-off-by: Jens Langhammer <jens@goauthentik.io> * clean up some old stuff? Signed-off-by: Jens Langhammer <jens@goauthentik.io> * somewhat rewrite config loader to keep track of a source of an attribute so we can refresh it Signed-off-by: Jens Langhammer <jens@goauthentik.io> * cleanup old things Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix flow e2e Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
fb4e4dc8db
commit
2f469d2709
|
@ -58,7 +58,7 @@ def clear_update_notifications():
|
||||||
@prefill_task
|
@prefill_task
|
||||||
def update_latest_version(self: MonitoredTask):
|
def update_latest_version(self: MonitoredTask):
|
||||||
"""Update latest version info"""
|
"""Update latest version info"""
|
||||||
if CONFIG.y_bool("disable_update_check"):
|
if CONFIG.get_bool("disable_update_check"):
|
||||||
cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
|
cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
|
||||||
self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."]))
|
self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."]))
|
||||||
return
|
return
|
||||||
|
|
|
@ -70,7 +70,7 @@ class ConfigView(APIView):
|
||||||
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
||||||
if GEOIP_READER.enabled:
|
if GEOIP_READER.enabled:
|
||||||
caps.append(Capabilities.CAN_GEO_IP)
|
caps.append(Capabilities.CAN_GEO_IP)
|
||||||
if CONFIG.y_bool("impersonation"):
|
if CONFIG.get_bool("impersonation"):
|
||||||
caps.append(Capabilities.CAN_IMPERSONATE)
|
caps.append(Capabilities.CAN_IMPERSONATE)
|
||||||
if settings.DEBUG: # pragma: no cover
|
if settings.DEBUG: # pragma: no cover
|
||||||
caps.append(Capabilities.CAN_DEBUG)
|
caps.append(Capabilities.CAN_DEBUG)
|
||||||
|
@ -86,17 +86,17 @@ class ConfigView(APIView):
|
||||||
return ConfigSerializer(
|
return ConfigSerializer(
|
||||||
{
|
{
|
||||||
"error_reporting": {
|
"error_reporting": {
|
||||||
"enabled": CONFIG.y("error_reporting.enabled"),
|
"enabled": CONFIG.get("error_reporting.enabled"),
|
||||||
"sentry_dsn": CONFIG.y("error_reporting.sentry_dsn"),
|
"sentry_dsn": CONFIG.get("error_reporting.sentry_dsn"),
|
||||||
"environment": CONFIG.y("error_reporting.environment"),
|
"environment": CONFIG.get("error_reporting.environment"),
|
||||||
"send_pii": CONFIG.y("error_reporting.send_pii"),
|
"send_pii": CONFIG.get("error_reporting.send_pii"),
|
||||||
"traces_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.4)),
|
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
|
||||||
},
|
},
|
||||||
"capabilities": self.get_capabilities(),
|
"capabilities": self.get_capabilities(),
|
||||||
"cache_timeout": int(CONFIG.y("redis.cache_timeout")),
|
"cache_timeout": int(CONFIG.get("redis.cache_timeout")),
|
||||||
"cache_timeout_flows": int(CONFIG.y("redis.cache_timeout_flows")),
|
"cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
|
||||||
"cache_timeout_policies": int(CONFIG.y("redis.cache_timeout_policies")),
|
"cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
|
||||||
"cache_timeout_reputation": int(CONFIG.y("redis.cache_timeout_reputation")),
|
"cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
|
||||||
return
|
return
|
||||||
blueprint_file.seek(0)
|
blueprint_file.seek(0)
|
||||||
instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first()
|
instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first()
|
||||||
rel_path = path.relative_to(Path(CONFIG.y("blueprints_dir")))
|
rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir")))
|
||||||
meta = None
|
meta = None
|
||||||
if metadata:
|
if metadata:
|
||||||
meta = from_dict(BlueprintMetadata, metadata)
|
meta = from_dict(BlueprintMetadata, metadata)
|
||||||
|
@ -55,7 +55,7 @@ def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
|
||||||
Flow = apps.get_model("authentik_flows", "Flow")
|
Flow = apps.get_model("authentik_flows", "Flow")
|
||||||
|
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
for file in glob(f"{CONFIG.y('blueprints_dir')}/**/*.yaml", recursive=True):
|
for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True):
|
||||||
check_blueprint_v1_file(BlueprintInstance, Path(file))
|
check_blueprint_v1_file(BlueprintInstance, Path(file))
|
||||||
|
|
||||||
for blueprint in BlueprintInstance.objects.using(db_alias).all():
|
for blueprint in BlueprintInstance.objects.using(db_alias).all():
|
||||||
|
|
|
@ -82,7 +82,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
|
||||||
def retrieve_file(self) -> str:
|
def retrieve_file(self) -> str:
|
||||||
"""Get blueprint from path"""
|
"""Get blueprint from path"""
|
||||||
try:
|
try:
|
||||||
base = Path(CONFIG.y("blueprints_dir"))
|
base = Path(CONFIG.get("blueprints_dir"))
|
||||||
full_path = base.joinpath(Path(self.path)).resolve()
|
full_path = base.joinpath(Path(self.path)).resolve()
|
||||||
if not str(full_path).startswith(str(base.resolve())):
|
if not str(full_path).startswith(str(base.resolve())):
|
||||||
raise BlueprintRetrievalFailed("Invalid blueprint path")
|
raise BlueprintRetrievalFailed("Invalid blueprint path")
|
||||||
|
|
|
@ -62,7 +62,7 @@ def start_blueprint_watcher():
|
||||||
if _file_watcher_started:
|
if _file_watcher_started:
|
||||||
return
|
return
|
||||||
observer = Observer()
|
observer = Observer()
|
||||||
observer.schedule(BlueprintEventHandler(), CONFIG.y("blueprints_dir"), recursive=True)
|
observer.schedule(BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True)
|
||||||
observer.start()
|
observer.start()
|
||||||
_file_watcher_started = True
|
_file_watcher_started = True
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
|
||||||
blueprints_discovery.delay()
|
blueprints_discovery.delay()
|
||||||
if isinstance(event, FileModifiedEvent):
|
if isinstance(event, FileModifiedEvent):
|
||||||
path = Path(event.src_path)
|
path = Path(event.src_path)
|
||||||
root = Path(CONFIG.y("blueprints_dir")).absolute()
|
root = Path(CONFIG.get("blueprints_dir")).absolute()
|
||||||
rel_path = str(path.relative_to(root))
|
rel_path = str(path.relative_to(root))
|
||||||
for instance in BlueprintInstance.objects.filter(path=rel_path):
|
for instance in BlueprintInstance.objects.filter(path=rel_path):
|
||||||
LOGGER.debug("modified blueprint file, starting apply", instance=instance)
|
LOGGER.debug("modified blueprint file, starting apply", instance=instance)
|
||||||
|
@ -101,7 +101,7 @@ def blueprints_find_dict():
|
||||||
def blueprints_find():
|
def blueprints_find():
|
||||||
"""Find blueprints and return valid ones"""
|
"""Find blueprints and return valid ones"""
|
||||||
blueprints = []
|
blueprints = []
|
||||||
root = Path(CONFIG.y("blueprints_dir"))
|
root = Path(CONFIG.get("blueprints_dir"))
|
||||||
for path in root.rglob("**/*.yaml"):
|
for path in root.rglob("**/*.yaml"):
|
||||||
# Check if any part in the path starts with a dot and assume a hidden file
|
# Check if any part in the path starts with a dot and assume a hidden file
|
||||||
if any(part for part in path.parts if part.startswith(".")):
|
if any(part for part in path.parts if part.startswith(".")):
|
||||||
|
|
|
@ -596,7 +596,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||||
@action(detail=True, methods=["POST"])
|
@action(detail=True, methods=["POST"])
|
||||||
def impersonate(self, request: Request, pk: int) -> Response:
|
def impersonate(self, request: Request, pk: int) -> Response:
|
||||||
"""Impersonate a user"""
|
"""Impersonate a user"""
|
||||||
if not CONFIG.y_bool("impersonation"):
|
if not CONFIG.get_bool("impersonation"):
|
||||||
LOGGER.debug("User attempted to impersonate", user=request.user)
|
LOGGER.debug("User attempted to impersonate", user=request.user)
|
||||||
return Response(status=401)
|
return Response(status=401)
|
||||||
if not request.user.has_perm("impersonate"):
|
if not request.user.has_perm("impersonate"):
|
||||||
|
|
|
@ -18,7 +18,7 @@ class Command(BaseCommand):
|
||||||
|
|
||||||
def handle(self, **options):
|
def handle(self, **options):
|
||||||
close_old_connections()
|
close_old_connections()
|
||||||
if CONFIG.y_bool("remote_debug"):
|
if CONFIG.get_bool("remote_debug"):
|
||||||
import debugpy
|
import debugpy
|
||||||
|
|
||||||
debugpy.listen(("0.0.0.0", 6900)) # nosec
|
debugpy.listen(("0.0.0.0", 6900)) # nosec
|
||||||
|
|
|
@ -60,7 +60,7 @@ def default_token_key():
|
||||||
"""Default token key"""
|
"""Default token key"""
|
||||||
# We use generate_id since the chars in the key should be easy
|
# We use generate_id since the chars in the key should be easy
|
||||||
# to use in Emails (for verification) and URLs (for recovery)
|
# to use in Emails (for verification) and URLs (for recovery)
|
||||||
return generate_id(int(CONFIG.y("default_token_length")))
|
return generate_id(int(CONFIG.get("default_token_length")))
|
||||||
|
|
||||||
|
|
||||||
class UserTypes(models.TextChoices):
|
class UserTypes(models.TextChoices):
|
||||||
|
|
|
@ -46,7 +46,7 @@ def certificate_discovery(self: MonitoredTask):
|
||||||
certs = {}
|
certs = {}
|
||||||
private_keys = {}
|
private_keys = {}
|
||||||
discovered = 0
|
discovered = 0
|
||||||
for file in glob(CONFIG.y("cert_discovery_dir") + "/**", recursive=True):
|
for file in glob(CONFIG.get("cert_discovery_dir") + "/**", recursive=True):
|
||||||
path = Path(file)
|
path = Path(file)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -33,7 +33,7 @@ class GeoIPReader:
|
||||||
|
|
||||||
def __open(self):
|
def __open(self):
|
||||||
"""Get GeoIP Reader, if configured, otherwise none"""
|
"""Get GeoIP Reader, if configured, otherwise none"""
|
||||||
path = CONFIG.y("geoip")
|
path = CONFIG.get("geoip")
|
||||||
if path == "" or not path:
|
if path == "" or not path:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
@ -46,7 +46,7 @@ class GeoIPReader:
|
||||||
def __check_expired(self):
|
def __check_expired(self):
|
||||||
"""Check if the modification date of the GeoIP database has
|
"""Check if the modification date of the GeoIP database has
|
||||||
changed, and reload it if so"""
|
changed, and reload it if so"""
|
||||||
path = CONFIG.y("geoip")
|
path = CONFIG.get("geoip")
|
||||||
try:
|
try:
|
||||||
mtime = stat(path).st_mtime
|
mtime = stat(path).st_mtime
|
||||||
diff = self.__last_mtime < mtime
|
diff = self.__last_mtime < mtime
|
||||||
|
|
|
@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
|
||||||
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
|
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
|
||||||
# was restored.
|
# was restored.
|
||||||
PLAN_CONTEXT_IS_RESTORED = "is_restored"
|
PLAN_CONTEXT_IS_RESTORED = "is_restored"
|
||||||
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_flows"))
|
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
|
||||||
CACHE_PREFIX = "goauthentik.io/flows/planner/"
|
CACHE_PREFIX = "goauthentik.io/flows/planner/"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,6 @@ from authentik.flows.planner import FlowPlan, FlowPlanner
|
||||||
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
|
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
|
||||||
from authentik.flows.tests import FlowTestCase
|
from authentik.flows.tests import FlowTestCase
|
||||||
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
|
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
|
||||||
from authentik.lib.config import CONFIG
|
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.policies.dummy.models import DummyPolicy
|
from authentik.policies.dummy.models import DummyPolicy
|
||||||
from authentik.policies.models import PolicyBinding
|
from authentik.policies.models import PolicyBinding
|
||||||
|
@ -85,7 +84,6 @@ class TestFlowExecutor(FlowTestCase):
|
||||||
FlowDesignation.AUTHENTICATION,
|
FlowDesignation.AUTHENTICATION,
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG.update_from_dict({"domain": "testserver"})
|
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||||
)
|
)
|
||||||
|
@ -111,7 +109,6 @@ class TestFlowExecutor(FlowTestCase):
|
||||||
denied_action=FlowDeniedAction.CONTINUE,
|
denied_action=FlowDeniedAction.CONTINUE,
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG.update_from_dict({"domain": "testserver"})
|
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||||
)
|
)
|
||||||
|
@ -128,7 +125,6 @@ class TestFlowExecutor(FlowTestCase):
|
||||||
FlowDesignation.AUTHENTICATION,
|
FlowDesignation.AUTHENTICATION,
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG.update_from_dict({"domain": "testserver"})
|
|
||||||
dest = "/unique-string"
|
dest = "/unique-string"
|
||||||
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
||||||
response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}")
|
response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}")
|
||||||
|
@ -145,7 +141,6 @@ class TestFlowExecutor(FlowTestCase):
|
||||||
FlowDesignation.AUTHENTICATION,
|
FlowDesignation.AUTHENTICATION,
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG.update_from_dict({"domain": "testserver"})
|
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||||
)
|
)
|
||||||
|
|
|
@ -175,7 +175,7 @@ def get_avatar(user: "User") -> str:
|
||||||
"initials": avatar_mode_generated,
|
"initials": avatar_mode_generated,
|
||||||
"gravatar": avatar_mode_gravatar,
|
"gravatar": avatar_mode_gravatar,
|
||||||
}
|
}
|
||||||
modes: str = CONFIG.y("avatars", "none")
|
modes: str = CONFIG.get("avatars", "none")
|
||||||
for mode in modes.split(","):
|
for mode in modes.split(","):
|
||||||
avatar = None
|
avatar = None
|
||||||
if mode in mode_map:
|
if mode in mode_map:
|
||||||
|
|
|
@ -2,13 +2,15 @@
|
||||||
import os
|
import os
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from json import dumps, loads
|
from json import JSONEncoder, dumps, loads
|
||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from sys import argv, stderr
|
from sys import argv, stderr
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -32,15 +34,44 @@ def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Attr:
|
||||||
|
"""Single configuration attribute"""
|
||||||
|
|
||||||
|
class Source(Enum):
|
||||||
|
"""Sources a configuration attribute can come from, determines what should be done with
|
||||||
|
Attr.source (and if it's set at all)"""
|
||||||
|
|
||||||
|
UNSPECIFIED = "unspecified"
|
||||||
|
ENV = "env"
|
||||||
|
CONFIG_FILE = "config_file"
|
||||||
|
URI = "uri"
|
||||||
|
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
source_type: Source = field(default=Source.UNSPECIFIED)
|
||||||
|
|
||||||
|
# depending on source_type, might contain the environment variable or the path
|
||||||
|
# to the config file containing this change or the file containing this value
|
||||||
|
source: Optional[str] = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class AttrEncoder(JSONEncoder):
|
||||||
|
"""JSON encoder that can deal with `Attr` classes"""
|
||||||
|
|
||||||
|
def default(self, o: Any) -> Any:
|
||||||
|
if isinstance(o, Attr):
|
||||||
|
return o.value
|
||||||
|
return super().default(o)
|
||||||
|
|
||||||
|
|
||||||
class ConfigLoader:
|
class ConfigLoader:
|
||||||
"""Search through SEARCH_PATHS and load configuration. Environment variables starting with
|
"""Search through SEARCH_PATHS and load configuration. Environment variables starting with
|
||||||
`ENV_PREFIX` are also applied.
|
`ENV_PREFIX` are also applied.
|
||||||
|
|
||||||
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
|
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
|
||||||
|
|
||||||
loaded_file = []
|
def __init__(self, **kwargs):
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.__config = {}
|
self.__config = {}
|
||||||
base_dir = Path(__file__).parent.joinpath(Path("../..")).resolve()
|
base_dir = Path(__file__).parent.joinpath(Path("../..")).resolve()
|
||||||
|
@ -65,6 +96,7 @@ class ConfigLoader:
|
||||||
# Update config with env file
|
# Update config with env file
|
||||||
self.update_from_file(env_file)
|
self.update_from_file(env_file)
|
||||||
self.update_from_env()
|
self.update_from_env()
|
||||||
|
self.update(self.__config, kwargs)
|
||||||
|
|
||||||
def log(self, level: str, message: str, **kwargs):
|
def log(self, level: str, message: str, **kwargs):
|
||||||
"""Custom Log method, we want to ensure ConfigLoader always logs JSON even when
|
"""Custom Log method, we want to ensure ConfigLoader always logs JSON even when
|
||||||
|
@ -86,22 +118,32 @@ class ConfigLoader:
|
||||||
else:
|
else:
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = self.parse_uri(value)
|
value = self.parse_uri(value)
|
||||||
|
elif not isinstance(value, Attr):
|
||||||
|
value = Attr(value)
|
||||||
root[key] = value
|
root[key] = value
|
||||||
return root
|
return root
|
||||||
|
|
||||||
def parse_uri(self, value: str) -> str:
|
def refresh(self, key: str):
|
||||||
|
"""Update a single value"""
|
||||||
|
attr: Attr = get_path_from_dict(self.raw, key)
|
||||||
|
if attr.source_type != Attr.Source.URI:
|
||||||
|
return
|
||||||
|
attr.value = self.parse_uri(attr.source).value
|
||||||
|
|
||||||
|
def parse_uri(self, value: str) -> Attr:
|
||||||
"""Parse string values which start with a URI"""
|
"""Parse string values which start with a URI"""
|
||||||
url = urlparse(value)
|
url = urlparse(value)
|
||||||
|
parsed_value = value
|
||||||
if url.scheme == "env":
|
if url.scheme == "env":
|
||||||
value = os.getenv(url.netloc, url.query)
|
parsed_value = os.getenv(url.netloc, url.query)
|
||||||
if url.scheme == "file":
|
if url.scheme == "file":
|
||||||
try:
|
try:
|
||||||
with open(url.path, "r", encoding="utf8") as _file:
|
with open(url.path, "r", encoding="utf8") as _file:
|
||||||
value = _file.read().strip()
|
parsed_value = _file.read().strip()
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
self.log("error", f"Failed to read config value from {url.path}: {exc}")
|
self.log("error", f"Failed to read config value from {url.path}: {exc}")
|
||||||
value = url.query
|
parsed_value = url.query
|
||||||
return value
|
return Attr(parsed_value, Attr.Source.URI, value)
|
||||||
|
|
||||||
def update_from_file(self, path: Path):
|
def update_from_file(self, path: Path):
|
||||||
"""Update config from file contents"""
|
"""Update config from file contents"""
|
||||||
|
@ -110,7 +152,6 @@ class ConfigLoader:
|
||||||
try:
|
try:
|
||||||
self.update(self.__config, yaml.safe_load(file))
|
self.update(self.__config, yaml.safe_load(file))
|
||||||
self.log("debug", "Loaded config", file=str(path))
|
self.log("debug", "Loaded config", file=str(path))
|
||||||
self.loaded_file.append(path)
|
|
||||||
except yaml.YAMLError as exc:
|
except yaml.YAMLError as exc:
|
||||||
raise ImproperlyConfigured from exc
|
raise ImproperlyConfigured from exc
|
||||||
except PermissionError as exc:
|
except PermissionError as exc:
|
||||||
|
@ -121,10 +162,6 @@ class ConfigLoader:
|
||||||
error=str(exc),
|
error=str(exc),
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_from_dict(self, update: dict):
|
|
||||||
"""Update config from dict"""
|
|
||||||
self.__config.update(update)
|
|
||||||
|
|
||||||
def update_from_env(self):
|
def update_from_env(self):
|
||||||
"""Check environment variables"""
|
"""Check environment variables"""
|
||||||
outer = {}
|
outer = {}
|
||||||
|
@ -145,7 +182,7 @@ class ConfigLoader:
|
||||||
value = loads(value)
|
value = loads(value)
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
current_obj[dot_parts[-1]] = value
|
current_obj[dot_parts[-1]] = Attr(value, Attr.Source.ENV, key)
|
||||||
idx += 1
|
idx += 1
|
||||||
if idx > 0:
|
if idx > 0:
|
||||||
self.log("debug", "Loaded environment variables", count=idx)
|
self.log("debug", "Loaded environment variables", count=idx)
|
||||||
|
@ -154,28 +191,32 @@ class ConfigLoader:
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch(self, path: str, value: Any):
|
def patch(self, path: str, value: Any):
|
||||||
"""Context manager for unittests to patch a value"""
|
"""Context manager for unittests to patch a value"""
|
||||||
original_value = self.y(path)
|
original_value = self.get(path)
|
||||||
self.y_set(path, value)
|
self.set(path, value)
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
self.y_set(path, original_value)
|
self.set(path, original_value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def raw(self) -> dict:
|
def raw(self) -> dict:
|
||||||
"""Get raw config dictionary"""
|
"""Get raw config dictionary"""
|
||||||
return self.__config
|
return self.__config
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
def get(self, path: str, default=None, sep=".") -> Any:
|
||||||
def y(self, path: str, default=None, sep=".") -> Any:
|
|
||||||
"""Access attribute by using yaml path"""
|
"""Access attribute by using yaml path"""
|
||||||
# Walk sub_dicts before parsing path
|
# Walk sub_dicts before parsing path
|
||||||
root = self.raw
|
root = self.raw
|
||||||
# Walk each component of the path
|
# Walk each component of the path
|
||||||
return get_path_from_dict(root, path, sep=sep, default=default)
|
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
|
||||||
|
return attr.value
|
||||||
|
|
||||||
def y_set(self, path: str, value: Any, sep="."):
|
def get_bool(self, path: str, default=False) -> bool:
|
||||||
"""Set value using same syntax as y()"""
|
"""Wrapper for get that converts value into boolean"""
|
||||||
|
return str(self.get(path, default)).lower() == "true"
|
||||||
|
|
||||||
|
def set(self, path: str, value: Any, sep="."):
|
||||||
|
"""Set value using same syntax as get()"""
|
||||||
# Walk sub_dicts before parsing path
|
# Walk sub_dicts before parsing path
|
||||||
root = self.raw
|
root = self.raw
|
||||||
# Walk each component of the path
|
# Walk each component of the path
|
||||||
|
@ -184,17 +225,14 @@ class ConfigLoader:
|
||||||
if comp not in root:
|
if comp not in root:
|
||||||
root[comp] = {}
|
root[comp] = {}
|
||||||
root = root.get(comp, {})
|
root = root.get(comp, {})
|
||||||
root[path_parts[-1]] = value
|
root[path_parts[-1]] = Attr(value)
|
||||||
|
|
||||||
def y_bool(self, path: str, default=False) -> bool:
|
|
||||||
"""Wrapper for y that converts value into boolean"""
|
|
||||||
return str(self.y(path, default)).lower() == "true"
|
|
||||||
|
|
||||||
|
|
||||||
CONFIG = ConfigLoader()
|
CONFIG = ConfigLoader()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if len(argv) < 2:
|
if len(argv) < 2:
|
||||||
print(dumps(CONFIG.raw, indent=4))
|
print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder))
|
||||||
else:
|
else:
|
||||||
print(CONFIG.y(argv[1]))
|
print(CONFIG.get(argv[1]))
|
||||||
|
|
|
@ -51,18 +51,18 @@ class SentryTransport(HttpTransport):
|
||||||
|
|
||||||
def sentry_init(**sentry_init_kwargs):
|
def sentry_init(**sentry_init_kwargs):
|
||||||
"""Configure sentry SDK"""
|
"""Configure sentry SDK"""
|
||||||
sentry_env = CONFIG.y("error_reporting.environment", "customer")
|
sentry_env = CONFIG.get("error_reporting.environment", "customer")
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"environment": sentry_env,
|
"environment": sentry_env,
|
||||||
"send_default_pii": CONFIG.y_bool("error_reporting.send_pii", False),
|
"send_default_pii": CONFIG.get_bool("error_reporting.send_pii", False),
|
||||||
"_experiments": {
|
"_experiments": {
|
||||||
"profiles_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.1)),
|
"profiles_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.1)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
kwargs.update(**sentry_init_kwargs)
|
kwargs.update(**sentry_init_kwargs)
|
||||||
# pylint: disable=abstract-class-instantiated
|
# pylint: disable=abstract-class-instantiated
|
||||||
sentry_sdk_init(
|
sentry_sdk_init(
|
||||||
dsn=CONFIG.y("error_reporting.sentry_dsn"),
|
dsn=CONFIG.get("error_reporting.sentry_dsn"),
|
||||||
integrations=[
|
integrations=[
|
||||||
ArgvIntegration(),
|
ArgvIntegration(),
|
||||||
StdlibIntegration(),
|
StdlibIntegration(),
|
||||||
|
@ -92,7 +92,7 @@ def traces_sampler(sampling_context: dict) -> float:
|
||||||
return 0
|
return 0
|
||||||
if _type == "websocket":
|
if _type == "websocket":
|
||||||
return 0
|
return 0
|
||||||
return float(CONFIG.y("error_reporting.sample_rate", 0.1))
|
return float(CONFIG.get("error_reporting.sample_rate", 0.1))
|
||||||
|
|
||||||
|
|
||||||
def before_send(event: dict, hint: dict) -> Optional[dict]:
|
def before_send(event: dict, hint: dict) -> Optional[dict]:
|
||||||
|
|
|
@ -16,23 +16,23 @@ class TestConfig(TestCase):
|
||||||
config = ConfigLoader()
|
config = ConfigLoader()
|
||||||
environ[ENV_PREFIX + "_test__test"] = "bar"
|
environ[ENV_PREFIX + "_test__test"] = "bar"
|
||||||
config.update_from_env()
|
config.update_from_env()
|
||||||
self.assertEqual(config.y("test.test"), "bar")
|
self.assertEqual(config.get("test.test"), "bar")
|
||||||
|
|
||||||
def test_patch(self):
|
def test_patch(self):
|
||||||
"""Test patch decorator"""
|
"""Test patch decorator"""
|
||||||
config = ConfigLoader()
|
config = ConfigLoader()
|
||||||
config.y_set("foo.bar", "bar")
|
config.set("foo.bar", "bar")
|
||||||
self.assertEqual(config.y("foo.bar"), "bar")
|
self.assertEqual(config.get("foo.bar"), "bar")
|
||||||
with config.patch("foo.bar", "baz"):
|
with config.patch("foo.bar", "baz"):
|
||||||
self.assertEqual(config.y("foo.bar"), "baz")
|
self.assertEqual(config.get("foo.bar"), "baz")
|
||||||
self.assertEqual(config.y("foo.bar"), "bar")
|
self.assertEqual(config.get("foo.bar"), "bar")
|
||||||
|
|
||||||
def test_uri_env(self):
|
def test_uri_env(self):
|
||||||
"""Test URI parsing (environment)"""
|
"""Test URI parsing (environment)"""
|
||||||
config = ConfigLoader()
|
config = ConfigLoader()
|
||||||
environ["foo"] = "bar"
|
environ["foo"] = "bar"
|
||||||
self.assertEqual(config.parse_uri("env://foo"), "bar")
|
self.assertEqual(config.parse_uri("env://foo").value, "bar")
|
||||||
self.assertEqual(config.parse_uri("env://foo?bar"), "bar")
|
self.assertEqual(config.parse_uri("env://foo?bar").value, "bar")
|
||||||
|
|
||||||
def test_uri_file(self):
|
def test_uri_file(self):
|
||||||
"""Test URI parsing (file load)"""
|
"""Test URI parsing (file load)"""
|
||||||
|
@ -41,11 +41,25 @@ class TestConfig(TestCase):
|
||||||
write(file, "foo".encode())
|
write(file, "foo".encode())
|
||||||
_, file2_name = mkstemp()
|
_, file2_name = mkstemp()
|
||||||
chmod(file2_name, 0o000) # Remove all permissions so we can't read the file
|
chmod(file2_name, 0o000) # Remove all permissions so we can't read the file
|
||||||
self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo")
|
self.assertEqual(config.parse_uri(f"file://{file_name}").value, "foo")
|
||||||
self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def")
|
self.assertEqual(config.parse_uri(f"file://{file2_name}?def").value, "def")
|
||||||
unlink(file_name)
|
unlink(file_name)
|
||||||
unlink(file2_name)
|
unlink(file2_name)
|
||||||
|
|
||||||
|
def test_uri_file_update(self):
|
||||||
|
"""Test URI parsing (file load and update)"""
|
||||||
|
file, file_name = mkstemp()
|
||||||
|
write(file, "foo".encode())
|
||||||
|
config = ConfigLoader(file_test=f"file://{file_name}")
|
||||||
|
self.assertEqual(config.get("file_test"), "foo")
|
||||||
|
|
||||||
|
# Update config file
|
||||||
|
write(file, "bar".encode())
|
||||||
|
config.refresh("file_test")
|
||||||
|
self.assertEqual(config.get("file_test"), "foobar")
|
||||||
|
|
||||||
|
unlink(file_name)
|
||||||
|
|
||||||
def test_file_update(self):
|
def test_file_update(self):
|
||||||
"""Test update_from_file"""
|
"""Test update_from_file"""
|
||||||
config = ConfigLoader()
|
config = ConfigLoader()
|
||||||
|
|
|
@ -50,7 +50,7 @@ def get_env() -> str:
|
||||||
"""Get environment in which authentik is currently running"""
|
"""Get environment in which authentik is currently running"""
|
||||||
if "CI" in os.environ:
|
if "CI" in os.environ:
|
||||||
return "ci"
|
return "ci"
|
||||||
if CONFIG.y_bool("debug"):
|
if CONFIG.get_bool("debug"):
|
||||||
return "dev"
|
return "dev"
|
||||||
if SERVICE_HOST_ENV_NAME in os.environ:
|
if SERVICE_HOST_ENV_NAME in os.environ:
|
||||||
return "kubernetes"
|
return "kubernetes"
|
||||||
|
|
|
@ -97,7 +97,7 @@ class BaseController:
|
||||||
if self.outpost.config.container_image is not None:
|
if self.outpost.config.container_image is not None:
|
||||||
return self.outpost.config.container_image
|
return self.outpost.config.container_image
|
||||||
|
|
||||||
image_name_template: str = CONFIG.y("outposts.container_image_base")
|
image_name_template: str = CONFIG.get("outposts.container_image_base")
|
||||||
return image_name_template % {
|
return image_name_template % {
|
||||||
"type": self.outpost.type,
|
"type": self.outpost.type,
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
|
|
|
@ -58,7 +58,7 @@ class OutpostConfig:
|
||||||
authentik_host_insecure: bool = False
|
authentik_host_insecure: bool = False
|
||||||
authentik_host_browser: str = ""
|
authentik_host_browser: str = ""
|
||||||
|
|
||||||
log_level: str = CONFIG.y("log_level")
|
log_level: str = CONFIG.get("log_level")
|
||||||
object_naming_template: str = field(default="ak-outpost-%(name)s")
|
object_naming_template: str = field(default="ak-outpost-%(name)s")
|
||||||
|
|
||||||
container_image: Optional[str] = field(default=None)
|
container_image: Optional[str] = field(default=None)
|
||||||
|
|
|
@ -256,7 +256,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
|
||||||
def outpost_connection_discovery(self: MonitoredTask):
|
def outpost_connection_discovery(self: MonitoredTask):
|
||||||
"""Checks the local environment and create Service connections."""
|
"""Checks the local environment and create Service connections."""
|
||||||
status = TaskResult(TaskResultStatus.SUCCESSFUL)
|
status = TaskResult(TaskResultStatus.SUCCESSFUL)
|
||||||
if not CONFIG.y_bool("outposts.discover"):
|
if not CONFIG.get_bool("outposts.discover"):
|
||||||
status.messages.append("Outpost integration discovery is disabled")
|
status.messages.append("Outpost integration discovery is disabled")
|
||||||
self.set_status(status)
|
self.set_status(status)
|
||||||
return
|
return
|
||||||
|
|
|
@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
FORK_CTX = get_context("fork")
|
FORK_CTX = get_context("fork")
|
||||||
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_policies"))
|
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
|
||||||
PROCESS_CLASS = FORK_CTX.Process
|
PROCESS_CLASS = FORK_CTX.Process
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
|
||||||
from authentik.stages.identification.signals import identification_failed
|
from authentik.stages.identification.signals import identification_failed
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_reputation"))
|
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
|
||||||
|
|
||||||
|
|
||||||
def update_score(request: HttpRequest, identifier: str, amount: int):
|
def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||||
|
|
|
@ -46,7 +46,7 @@ class DeviceView(View):
|
||||||
|
|
||||||
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||||
throttle = AnonRateThrottle()
|
throttle = AnonRateThrottle()
|
||||||
throttle.rate = CONFIG.y("throttle.providers.oauth2.device", "20/hour")
|
throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
|
||||||
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
|
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
|
||||||
if not throttle.allow_request(request, self):
|
if not throttle.allow_request(request, self):
|
||||||
return HttpResponse(status=429)
|
return HttpResponse(status=429)
|
||||||
|
|
0
authentik/root/db/__init__.py
Normal file
0
authentik/root/db/__init__.py
Normal file
15
authentik/root/db/base.py
Normal file
15
authentik/root/db/base.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
"""authentik database backend"""
|
||||||
|
from django_prometheus.db.backends.postgresql.base import DatabaseWrapper as BaseDatabaseWrapper
|
||||||
|
|
||||||
|
from authentik.lib.config import CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
|
"""database backend which supports rotating credentials"""
|
||||||
|
|
||||||
|
def get_connection_params(self):
|
||||||
|
CONFIG.refresh("postgresql.password")
|
||||||
|
conn_params = super().get_connection_params()
|
||||||
|
conn_params["user"] = CONFIG.get("postgresql.user")
|
||||||
|
conn_params["password"] = CONFIG.get("postgresql.password")
|
||||||
|
return conn_params
|
|
@ -26,15 +26,15 @@ def get_install_id_raw():
|
||||||
"""Get install_id without django loaded, this is required for the startup when we get
|
"""Get install_id without django loaded, this is required for the startup when we get
|
||||||
the install_id but django isn't loaded yet and we can't use the function above."""
|
the install_id but django isn't loaded yet and we can't use the function above."""
|
||||||
conn = connect(
|
conn = connect(
|
||||||
dbname=CONFIG.y("postgresql.name"),
|
dbname=CONFIG.get("postgresql.name"),
|
||||||
user=CONFIG.y("postgresql.user"),
|
user=CONFIG.get("postgresql.user"),
|
||||||
password=CONFIG.y("postgresql.password"),
|
password=CONFIG.get("postgresql.password"),
|
||||||
host=CONFIG.y("postgresql.host"),
|
host=CONFIG.get("postgresql.host"),
|
||||||
port=int(CONFIG.y("postgresql.port")),
|
port=int(CONFIG.get("postgresql.port")),
|
||||||
sslmode=CONFIG.y("postgresql.sslmode"),
|
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||||
sslrootcert=CONFIG.y("postgresql.sslrootcert"),
|
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||||
sslcert=CONFIG.y("postgresql.sslcert"),
|
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||||
sslkey=CONFIG.y("postgresql.sslkey"),
|
sslkey=CONFIG.get("postgresql.sslkey"),
|
||||||
)
|
)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;")
|
cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;")
|
||||||
|
|
|
@ -24,8 +24,8 @@ BASE_DIR = Path(__file__).absolute().parent.parent.parent
|
||||||
STATICFILES_DIRS = [BASE_DIR / Path("web")]
|
STATICFILES_DIRS = [BASE_DIR / Path("web")]
|
||||||
MEDIA_ROOT = BASE_DIR / Path("media")
|
MEDIA_ROOT = BASE_DIR / Path("media")
|
||||||
|
|
||||||
DEBUG = CONFIG.y_bool("debug")
|
DEBUG = CONFIG.get_bool("debug")
|
||||||
SECRET_KEY = CONFIG.y("secret_key")
|
SECRET_KEY = CONFIG.get("secret_key")
|
||||||
|
|
||||||
INTERNAL_IPS = ["127.0.0.1"]
|
INTERNAL_IPS = ["127.0.0.1"]
|
||||||
ALLOWED_HOSTS = ["*"]
|
ALLOWED_HOSTS = ["*"]
|
||||||
|
@ -40,7 +40,7 @@ CSRF_COOKIE_NAME = "authentik_csrf"
|
||||||
CSRF_HEADER_NAME = "HTTP_X_AUTHENTIK_CSRF"
|
CSRF_HEADER_NAME = "HTTP_X_AUTHENTIK_CSRF"
|
||||||
LANGUAGE_COOKIE_NAME = "authentik_language"
|
LANGUAGE_COOKIE_NAME = "authentik_language"
|
||||||
SESSION_COOKIE_NAME = "authentik_session"
|
SESSION_COOKIE_NAME = "authentik_session"
|
||||||
SESSION_COOKIE_DOMAIN = CONFIG.y("cookie_domain", None)
|
SESSION_COOKIE_DOMAIN = CONFIG.get("cookie_domain", None)
|
||||||
|
|
||||||
AUTHENTICATION_BACKENDS = [
|
AUTHENTICATION_BACKENDS = [
|
||||||
"django.contrib.auth.backends.ModelBackend",
|
"django.contrib.auth.backends.ModelBackend",
|
||||||
|
@ -179,26 +179,26 @@ REST_FRAMEWORK = {
|
||||||
"TEST_REQUEST_DEFAULT_FORMAT": "json",
|
"TEST_REQUEST_DEFAULT_FORMAT": "json",
|
||||||
"DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.AnonRateThrottle"],
|
"DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.AnonRateThrottle"],
|
||||||
"DEFAULT_THROTTLE_RATES": {
|
"DEFAULT_THROTTLE_RATES": {
|
||||||
"anon": CONFIG.y("throttle.default"),
|
"anon": CONFIG.get("throttle.default"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_redis_protocol_prefix = "redis://"
|
_redis_protocol_prefix = "redis://"
|
||||||
_redis_celery_tls_requirements = ""
|
_redis_celery_tls_requirements = ""
|
||||||
if CONFIG.y_bool("redis.tls", False):
|
if CONFIG.get_bool("redis.tls", False):
|
||||||
_redis_protocol_prefix = "rediss://"
|
_redis_protocol_prefix = "rediss://"
|
||||||
_redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.y('redis.tls_reqs')}"
|
_redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}"
|
||||||
_redis_url = (
|
_redis_url = (
|
||||||
f"{_redis_protocol_prefix}:"
|
f"{_redis_protocol_prefix}:"
|
||||||
f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:"
|
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
||||||
f"{int(CONFIG.y('redis.port'))}"
|
f"{int(CONFIG.get('redis.port'))}"
|
||||||
)
|
)
|
||||||
|
|
||||||
CACHES = {
|
CACHES = {
|
||||||
"default": {
|
"default": {
|
||||||
"BACKEND": "django_redis.cache.RedisCache",
|
"BACKEND": "django_redis.cache.RedisCache",
|
||||||
"LOCATION": f"{_redis_url}/{CONFIG.y('redis.db')}",
|
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
|
||||||
"TIMEOUT": int(CONFIG.y("redis.cache_timeout", 300)),
|
"TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
|
||||||
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
|
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
|
||||||
"KEY_PREFIX": "authentik_cache",
|
"KEY_PREFIX": "authentik_cache",
|
||||||
}
|
}
|
||||||
|
@ -238,7 +238,7 @@ ROOT_URLCONF = "authentik.root.urls"
|
||||||
TEMPLATES = [
|
TEMPLATES = [
|
||||||
{
|
{
|
||||||
"BACKEND": "django.template.backends.django.DjangoTemplates",
|
"BACKEND": "django.template.backends.django.DjangoTemplates",
|
||||||
"DIRS": [CONFIG.y("email.template_dir")],
|
"DIRS": [CONFIG.get("email.template_dir")],
|
||||||
"APP_DIRS": True,
|
"APP_DIRS": True,
|
||||||
"OPTIONS": {
|
"OPTIONS": {
|
||||||
"context_processors": [
|
"context_processors": [
|
||||||
|
@ -258,7 +258,7 @@ CHANNEL_LAYERS = {
|
||||||
"default": {
|
"default": {
|
||||||
"BACKEND": "channels_redis.core.RedisChannelLayer",
|
"BACKEND": "channels_redis.core.RedisChannelLayer",
|
||||||
"CONFIG": {
|
"CONFIG": {
|
||||||
"hosts": [f"{_redis_url}/{CONFIG.y('redis.db')}"],
|
"hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
|
||||||
"prefix": "authentik_channels",
|
"prefix": "authentik_channels",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -270,34 +270,37 @@ CHANNEL_LAYERS = {
|
||||||
|
|
||||||
DATABASES = {
|
DATABASES = {
|
||||||
"default": {
|
"default": {
|
||||||
"ENGINE": "django_prometheus.db.backends.postgresql",
|
"ENGINE": "authentik.root.db",
|
||||||
"HOST": CONFIG.y("postgresql.host"),
|
"HOST": CONFIG.get("postgresql.host"),
|
||||||
"NAME": CONFIG.y("postgresql.name"),
|
"NAME": CONFIG.get("postgresql.name"),
|
||||||
"USER": CONFIG.y("postgresql.user"),
|
"USER": CONFIG.get("postgresql.user"),
|
||||||
"PASSWORD": CONFIG.y("postgresql.password"),
|
"PASSWORD": CONFIG.get("postgresql.password"),
|
||||||
"PORT": int(CONFIG.y("postgresql.port")),
|
"PORT": int(CONFIG.get("postgresql.port")),
|
||||||
"SSLMODE": CONFIG.y("postgresql.sslmode"),
|
"SSLMODE": CONFIG.get("postgresql.sslmode"),
|
||||||
"SSLROOTCERT": CONFIG.y("postgresql.sslrootcert"),
|
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
|
||||||
"SSLCERT": CONFIG.y("postgresql.sslcert"),
|
"SSLCERT": CONFIG.get("postgresql.sslcert"),
|
||||||
"SSLKEY": CONFIG.y("postgresql.sslkey"),
|
"SSLKEY": CONFIG.get("postgresql.sslkey"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if CONFIG.y_bool("postgresql.use_pgbouncer", False):
|
if CONFIG.get_bool("postgresql.use_pgbouncer", False):
|
||||||
# https://docs.djangoproject.com/en/4.0/ref/databases/#transaction-pooling-server-side-cursors
|
# https://docs.djangoproject.com/en/4.0/ref/databases/#transaction-pooling-server-side-cursors
|
||||||
DATABASES["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True
|
DATABASES["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True
|
||||||
# https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections
|
# https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections
|
||||||
DATABASES["default"]["CONN_MAX_AGE"] = None # persistent
|
DATABASES["default"]["CONN_MAX_AGE"] = None # persistent
|
||||||
|
|
||||||
# Email
|
# Email
|
||||||
EMAIL_HOST = CONFIG.y("email.host")
|
# These values should never actually be used, emails are only sent from email stages, which
|
||||||
EMAIL_PORT = int(CONFIG.y("email.port"))
|
# loads the config directly from CONFIG
|
||||||
EMAIL_HOST_USER = CONFIG.y("email.username")
|
# See authentik/stages/email/models.py, line 105
|
||||||
EMAIL_HOST_PASSWORD = CONFIG.y("email.password")
|
EMAIL_HOST = CONFIG.get("email.host")
|
||||||
EMAIL_USE_TLS = CONFIG.y_bool("email.use_tls", False)
|
EMAIL_PORT = int(CONFIG.get("email.port"))
|
||||||
EMAIL_USE_SSL = CONFIG.y_bool("email.use_ssl", False)
|
EMAIL_HOST_USER = CONFIG.get("email.username")
|
||||||
EMAIL_TIMEOUT = int(CONFIG.y("email.timeout"))
|
EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
|
||||||
DEFAULT_FROM_EMAIL = CONFIG.y("email.from")
|
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
|
||||||
|
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
|
||||||
|
EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
|
||||||
|
DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
|
||||||
SERVER_EMAIL = DEFAULT_FROM_EMAIL
|
SERVER_EMAIL = DEFAULT_FROM_EMAIL
|
||||||
EMAIL_SUBJECT_PREFIX = "[authentik] "
|
EMAIL_SUBJECT_PREFIX = "[authentik] "
|
||||||
|
|
||||||
|
@ -345,15 +348,15 @@ CELERY = {
|
||||||
},
|
},
|
||||||
"task_create_missing_queues": True,
|
"task_create_missing_queues": True,
|
||||||
"task_default_queue": "authentik",
|
"task_default_queue": "authentik",
|
||||||
"broker_url": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
|
"broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
|
||||||
"result_backend": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
|
"result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Sentry integration
|
# Sentry integration
|
||||||
env = get_env()
|
env = get_env()
|
||||||
_ERROR_REPORTING = CONFIG.y_bool("error_reporting.enabled", False)
|
_ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False)
|
||||||
if _ERROR_REPORTING:
|
if _ERROR_REPORTING:
|
||||||
sentry_env = CONFIG.y("error_reporting.environment", "customer")
|
sentry_env = CONFIG.get("error_reporting.environment", "customer")
|
||||||
sentry_init()
|
sentry_init()
|
||||||
set_tag("authentik.uuid", sha512(str(SECRET_KEY).encode("ascii")).hexdigest()[:16])
|
set_tag("authentik.uuid", sha512(str(SECRET_KEY).encode("ascii")).hexdigest()[:16])
|
||||||
|
|
||||||
|
@ -367,7 +370,7 @@ MEDIA_URL = "/media/"
|
||||||
TEST = False
|
TEST = False
|
||||||
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
|
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
|
||||||
# We can't check TEST here as its set later by the test runner
|
# We can't check TEST here as its set later by the test runner
|
||||||
LOG_LEVEL = CONFIG.y("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
|
LOG_LEVEL = CONFIG.get("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
|
||||||
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
|
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
|
||||||
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
|
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
|
||||||
# Additionally, the entire code uses debug as highest level so that would have to be re-written too
|
# Additionally, the entire code uses debug as highest level so that would have to be re-written too
|
||||||
|
|
|
@ -31,14 +31,14 @@ class PytestTestRunner: # pragma: no cover
|
||||||
|
|
||||||
settings.TEST = True
|
settings.TEST = True
|
||||||
settings.CELERY["task_always_eager"] = True
|
settings.CELERY["task_always_eager"] = True
|
||||||
CONFIG.y_set("avatars", "none")
|
CONFIG.set("avatars", "none")
|
||||||
CONFIG.y_set("geoip", "tests/GeoLite2-City-Test.mmdb")
|
CONFIG.set("geoip", "tests/GeoLite2-City-Test.mmdb")
|
||||||
CONFIG.y_set("blueprints_dir", "./blueprints")
|
CONFIG.set("blueprints_dir", "./blueprints")
|
||||||
CONFIG.y_set(
|
CONFIG.set(
|
||||||
"outposts.container_image_base",
|
"outposts.container_image_base",
|
||||||
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
|
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
|
||||||
)
|
)
|
||||||
CONFIG.y_set("error_reporting.sample_rate", 0)
|
CONFIG.set("error_reporting.sample_rate", 0)
|
||||||
sentry_init(
|
sentry_init(
|
||||||
environment="testing",
|
environment="testing",
|
||||||
send_default_pii=True,
|
send_default_pii=True,
|
||||||
|
|
|
@ -136,7 +136,7 @@ class LDAPSource(Source):
|
||||||
chmod(private_key_file, 0o600)
|
chmod(private_key_file, 0o600)
|
||||||
tls_kwargs["local_private_key_file"] = private_key_file
|
tls_kwargs["local_private_key_file"] = private_key_file
|
||||||
tls_kwargs["local_certificate_file"] = certificate_file
|
tls_kwargs["local_certificate_file"] = certificate_file
|
||||||
if ciphers := CONFIG.y("ldap.tls.ciphers", None):
|
if ciphers := CONFIG.get("ldap.tls.ciphers", None):
|
||||||
tls_kwargs["ciphers"] = ciphers.strip()
|
tls_kwargs["ciphers"] = ciphers.strip()
|
||||||
if self.sni:
|
if self.sni:
|
||||||
tls_kwargs["sni"] = self.server_uri.split(",", maxsplit=1)[0].strip()
|
tls_kwargs["sni"] = self.server_uri.split(",", maxsplit=1)[0].strip()
|
||||||
|
|
|
@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
|
||||||
types_only=False,
|
types_only=False,
|
||||||
get_operational_attributes=False,
|
get_operational_attributes=False,
|
||||||
controls=None,
|
controls=None,
|
||||||
paged_size=int(CONFIG.y("ldap.page_size", 50)),
|
paged_size=int(CONFIG.get("ldap.page_size", 50)),
|
||||||
paged_criticality=False,
|
paged_criticality=False,
|
||||||
):
|
):
|
||||||
"""Search in pages, returns each page"""
|
"""Search in pages, returns each page"""
|
||||||
|
|
|
@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
||||||
@CELERY_APP.task(
|
@CELERY_APP.task(
|
||||||
bind=True,
|
bind=True,
|
||||||
base=MonitoredTask,
|
base=MonitoredTask,
|
||||||
soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
|
soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
|
||||||
task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
|
task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
|
||||||
)
|
)
|
||||||
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
|
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
|
||||||
"""Synchronization of an LDAP Source"""
|
"""Synchronization of an LDAP Source"""
|
||||||
self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours"))
|
self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
|
||||||
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
|
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
|
||||||
if not source:
|
if not source:
|
||||||
# Because the source couldn't be found, we don't have a UID
|
# Because the source couldn't be found, we don't have a UID
|
||||||
|
|
|
@ -13,6 +13,7 @@ from rest_framework.serializers import BaseSerializer
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.flows.models import Stage
|
from authentik.flows.models import Stage
|
||||||
|
from authentik.lib.config import CONFIG
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
@ -104,7 +105,16 @@ class EmailStage(Stage):
|
||||||
def backend(self) -> BaseEmailBackend:
|
def backend(self) -> BaseEmailBackend:
|
||||||
"""Get fully configured Email Backend instance"""
|
"""Get fully configured Email Backend instance"""
|
||||||
if self.use_global_settings:
|
if self.use_global_settings:
|
||||||
return self.backend_class()
|
CONFIG.refresh("email.password")
|
||||||
|
return self.backend_class(
|
||||||
|
host=CONFIG.get("email.host"),
|
||||||
|
port=int(CONFIG.get("email.port")),
|
||||||
|
username=CONFIG.get("email.username"),
|
||||||
|
password=CONFIG.get("email.password"),
|
||||||
|
use_tls=CONFIG.get_bool("email.use_tls", False),
|
||||||
|
use_ssl=CONFIG.get_bool("email.use_ssl", False),
|
||||||
|
timeout=int(CONFIG.get("email.timeout")),
|
||||||
|
)
|
||||||
return self.backend_class(
|
return self.backend_class(
|
||||||
host=self.host,
|
host=self.host,
|
||||||
port=self.port,
|
port=self.port,
|
||||||
|
|
|
@ -13,6 +13,7 @@ from authentik.flows.models import FlowDesignation, FlowStageBinding, FlowToken
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||||
from authentik.flows.tests import FlowTestCase
|
from authentik.flows.tests import FlowTestCase
|
||||||
from authentik.flows.views.executor import QS_KEY_TOKEN, SESSION_KEY_PLAN
|
from authentik.flows.views.executor import QS_KEY_TOKEN, SESSION_KEY_PLAN
|
||||||
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.stages.email.models import EmailStage
|
from authentik.stages.email.models import EmailStage
|
||||||
from authentik.stages.email.stage import PLAN_CONTEXT_EMAIL_OVERRIDE
|
from authentik.stages.email.stage import PLAN_CONTEXT_EMAIL_OVERRIDE
|
||||||
|
|
||||||
|
@ -120,7 +121,7 @@ class TestEmailStage(FlowTestCase):
|
||||||
def test_use_global_settings(self):
|
def test_use_global_settings(self):
|
||||||
"""Test use_global_settings"""
|
"""Test use_global_settings"""
|
||||||
host = "some-unique-string"
|
host = "some-unique-string"
|
||||||
with self.settings(EMAIL_HOST=host):
|
with CONFIG.patch("email.host", host):
|
||||||
self.assertEqual(EmailStage(use_global_settings=True).backend.host, host)
|
self.assertEqual(EmailStage(use_global_settings=True).backend.host, host)
|
||||||
|
|
||||||
def test_token(self):
|
def test_token(self):
|
||||||
|
|
|
@ -78,7 +78,7 @@ class CurrentTenantSerializer(PassiveSerializer):
|
||||||
ui_footer_links = ListField(
|
ui_footer_links = ListField(
|
||||||
child=FooterLinkSerializer(),
|
child=FooterLinkSerializer(),
|
||||||
read_only=True,
|
read_only=True,
|
||||||
default=CONFIG.y("footer_links", []),
|
default=CONFIG.get("footer_links", []),
|
||||||
)
|
)
|
||||||
ui_theme = ChoiceField(
|
ui_theme = ChoiceField(
|
||||||
choices=Themes.choices,
|
choices=Themes.choices,
|
||||||
|
|
|
@ -24,7 +24,7 @@ class TestTenants(APITestCase):
|
||||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||||
"branding_title": "authentik",
|
"branding_title": "authentik",
|
||||||
"matched_domain": tenant.domain,
|
"matched_domain": tenant.domain,
|
||||||
"ui_footer_links": CONFIG.y("footer_links"),
|
"ui_footer_links": CONFIG.get("footer_links"),
|
||||||
"ui_theme": Themes.AUTOMATIC,
|
"ui_theme": Themes.AUTOMATIC,
|
||||||
"default_locale": "",
|
"default_locale": "",
|
||||||
},
|
},
|
||||||
|
@ -43,7 +43,7 @@ class TestTenants(APITestCase):
|
||||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||||
"branding_title": "custom",
|
"branding_title": "custom",
|
||||||
"matched_domain": "bar.baz",
|
"matched_domain": "bar.baz",
|
||||||
"ui_footer_links": CONFIG.y("footer_links"),
|
"ui_footer_links": CONFIG.get("footer_links"),
|
||||||
"ui_theme": Themes.AUTOMATIC,
|
"ui_theme": Themes.AUTOMATIC,
|
||||||
"default_locale": "",
|
"default_locale": "",
|
||||||
},
|
},
|
||||||
|
@ -59,7 +59,7 @@ class TestTenants(APITestCase):
|
||||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||||
"branding_title": "authentik",
|
"branding_title": "authentik",
|
||||||
"matched_domain": "fallback",
|
"matched_domain": "fallback",
|
||||||
"ui_footer_links": CONFIG.y("footer_links"),
|
"ui_footer_links": CONFIG.get("footer_links"),
|
||||||
"ui_theme": Themes.AUTOMATIC,
|
"ui_theme": Themes.AUTOMATIC,
|
||||||
"default_locale": "",
|
"default_locale": "",
|
||||||
},
|
},
|
||||||
|
|
|
@ -36,7 +36,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
|
||||||
trace = span.to_traceparent()
|
trace = span.to_traceparent()
|
||||||
return {
|
return {
|
||||||
"tenant": tenant,
|
"tenant": tenant,
|
||||||
"footer_links": CONFIG.y("footer_links"),
|
"footer_links": CONFIG.get("footer_links"),
|
||||||
"sentry_trace": trace,
|
"sentry_trace": trace,
|
||||||
"version": get_full_version(),
|
"version": get_full_version(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,21 +94,21 @@ entries:
|
||||||
prompt_data = request.context.get("prompt_data")
|
prompt_data = request.context.get("prompt_data")
|
||||||
|
|
||||||
if not request.user.group_attributes(request.http_request).get(
|
if not request.user.group_attributes(request.http_request).get(
|
||||||
USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.y_bool("default_user_change_email", True)
|
USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.get_bool("default_user_change_email", True)
|
||||||
):
|
):
|
||||||
if prompt_data.get("email") != request.user.email:
|
if prompt_data.get("email") != request.user.email:
|
||||||
ak_message("Not allowed to change email address.")
|
ak_message("Not allowed to change email address.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not request.user.group_attributes(request.http_request).get(
|
if not request.user.group_attributes(request.http_request).get(
|
||||||
USER_ATTRIBUTE_CHANGE_NAME, CONFIG.y_bool("default_user_change_name", True)
|
USER_ATTRIBUTE_CHANGE_NAME, CONFIG.get_bool("default_user_change_name", True)
|
||||||
):
|
):
|
||||||
if prompt_data.get("name") != request.user.name:
|
if prompt_data.get("name") != request.user.name:
|
||||||
ak_message("Not allowed to change name.")
|
ak_message("Not allowed to change name.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not request.user.group_attributes(request.http_request).get(
|
if not request.user.group_attributes(request.http_request).get(
|
||||||
USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.y_bool("default_user_change_username", True)
|
USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.get_bool("default_user_change_username", True)
|
||||||
):
|
):
|
||||||
if prompt_data.get("username") != request.user.username:
|
if prompt_data.get("username") != request.user.username:
|
||||||
ak_message("Not allowed to change username.")
|
ak_message("Not allowed to change username.")
|
||||||
|
|
|
@ -37,7 +37,7 @@ makedirs(prometheus_tmp_dir, exist_ok=True)
|
||||||
max_requests = 1000
|
max_requests = 1000
|
||||||
max_requests_jitter = 50
|
max_requests_jitter = 50
|
||||||
|
|
||||||
_debug = CONFIG.y_bool("DEBUG", False)
|
_debug = CONFIG.get_bool("DEBUG", False)
|
||||||
|
|
||||||
logconfig_dict = {
|
logconfig_dict = {
|
||||||
"version": 1,
|
"version": 1,
|
||||||
|
@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ:
|
||||||
else:
|
else:
|
||||||
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
|
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
|
||||||
|
|
||||||
workers = int(CONFIG.y("web.workers", default_workers))
|
workers = int(CONFIG.get("web.workers", default_workers))
|
||||||
threads = int(CONFIG.y("web.threads", 4))
|
threads = int(CONFIG.get("web.threads", 4))
|
||||||
|
|
||||||
|
|
||||||
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
||||||
|
@ -133,7 +133,7 @@ def pre_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
||||||
worker._worker_id = _next_worker_id(server)
|
worker._worker_id = _next_worker_id(server)
|
||||||
|
|
||||||
|
|
||||||
if not CONFIG.y_bool("disable_startup_analytics", False):
|
if not CONFIG.get_bool("disable_startup_analytics", False):
|
||||||
env = get_env()
|
env = get_env()
|
||||||
should_send = env not in ["dev", "ci"]
|
should_send = env not in ["dev", "ci"]
|
||||||
if should_send:
|
if should_send:
|
||||||
|
@ -158,7 +158,7 @@ if not CONFIG.y_bool("disable_startup_analytics", False):
|
||||||
except Exception: # nosec
|
except Exception: # nosec
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if CONFIG.y_bool("remote_debug"):
|
if CONFIG.get_bool("remote_debug"):
|
||||||
import debugpy
|
import debugpy
|
||||||
|
|
||||||
debugpy.listen(("0.0.0.0", 6800)) # nosec
|
debugpy.listen(("0.0.0.0", 6800)) # nosec
|
||||||
|
|
|
@ -52,15 +52,15 @@ def release_lock():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
conn = connect(
|
conn = connect(
|
||||||
dbname=CONFIG.y("postgresql.name"),
|
dbname=CONFIG.get("postgresql.name"),
|
||||||
user=CONFIG.y("postgresql.user"),
|
user=CONFIG.get("postgresql.user"),
|
||||||
password=CONFIG.y("postgresql.password"),
|
password=CONFIG.get("postgresql.password"),
|
||||||
host=CONFIG.y("postgresql.host"),
|
host=CONFIG.get("postgresql.host"),
|
||||||
port=int(CONFIG.y("postgresql.port")),
|
port=int(CONFIG.get("postgresql.port")),
|
||||||
sslmode=CONFIG.y("postgresql.sslmode"),
|
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||||
sslrootcert=CONFIG.y("postgresql.sslrootcert"),
|
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||||
sslcert=CONFIG.y("postgresql.sslcert"),
|
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||||
sslkey=CONFIG.y("postgresql.sslkey"),
|
sslkey=CONFIG.get("postgresql.sslkey"),
|
||||||
)
|
)
|
||||||
curr = conn.cursor()
|
curr = conn.cursor()
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -25,7 +25,7 @@ class Migration(BaseMigration):
|
||||||
# If we already have migrations in the database, assume we're upgrading an existing install
|
# If we already have migrations in the database, assume we're upgrading an existing install
|
||||||
# and set the install id to the secret key
|
# and set the install id to the secret key
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
"INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.y("secret_key"),)
|
"INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.get("secret_key"),)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Otherwise assume a new install, generate an install ID based on a UUID
|
# Otherwise assume a new install, generate an install ID based on a UUID
|
||||||
|
|
|
@ -108,14 +108,14 @@ class Migration(BaseMigration):
|
||||||
self.con.commit()
|
self.con.commit()
|
||||||
# We also need to clean the cache to make sure no pickeled objects still exist
|
# We also need to clean the cache to make sure no pickeled objects still exist
|
||||||
for db in [
|
for db in [
|
||||||
CONFIG.y("redis.message_queue_db"),
|
CONFIG.get("redis.message_queue_db"),
|
||||||
CONFIG.y("redis.cache_db"),
|
CONFIG.get("redis.cache_db"),
|
||||||
CONFIG.y("redis.ws_db"),
|
CONFIG.get("redis.ws_db"),
|
||||||
]:
|
]:
|
||||||
redis = Redis(
|
redis = Redis(
|
||||||
host=CONFIG.y("redis.host"),
|
host=CONFIG.get("redis.host"),
|
||||||
port=6379,
|
port=6379,
|
||||||
db=db,
|
db=db,
|
||||||
password=CONFIG.y("redis.password"),
|
password=CONFIG.get("redis.password"),
|
||||||
)
|
)
|
||||||
redis.flushall()
|
redis.flushall()
|
||||||
|
|
|
@ -14,7 +14,7 @@ from authentik.lib.config import CONFIG
|
||||||
CONFIG.log("info", "Starting authentik bootstrap")
|
CONFIG.log("info", "Starting authentik bootstrap")
|
||||||
|
|
||||||
# Sanity check, ensure SECRET_KEY is set before we even check for database connectivity
|
# Sanity check, ensure SECRET_KEY is set before we even check for database connectivity
|
||||||
if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0:
|
if CONFIG.get("secret_key") is None or len(CONFIG.get("secret_key")) == 0:
|
||||||
CONFIG.log("info", "----------------------------------------------------------------------")
|
CONFIG.log("info", "----------------------------------------------------------------------")
|
||||||
CONFIG.log("info", "Secret key missing, check https://goauthentik.io/docs/installation/.")
|
CONFIG.log("info", "Secret key missing, check https://goauthentik.io/docs/installation/.")
|
||||||
CONFIG.log("info", "----------------------------------------------------------------------")
|
CONFIG.log("info", "----------------------------------------------------------------------")
|
||||||
|
@ -24,15 +24,15 @@ if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
conn = connect(
|
conn = connect(
|
||||||
dbname=CONFIG.y("postgresql.name"),
|
dbname=CONFIG.get("postgresql.name"),
|
||||||
user=CONFIG.y("postgresql.user"),
|
user=CONFIG.get("postgresql.user"),
|
||||||
password=CONFIG.y("postgresql.password"),
|
password=CONFIG.get("postgresql.password"),
|
||||||
host=CONFIG.y("postgresql.host"),
|
host=CONFIG.get("postgresql.host"),
|
||||||
port=int(CONFIG.y("postgresql.port")),
|
port=int(CONFIG.get("postgresql.port")),
|
||||||
sslmode=CONFIG.y("postgresql.sslmode"),
|
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||||
sslrootcert=CONFIG.y("postgresql.sslrootcert"),
|
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||||
sslcert=CONFIG.y("postgresql.sslcert"),
|
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||||
sslkey=CONFIG.y("postgresql.sslkey"),
|
sslkey=CONFIG.get("postgresql.sslkey"),
|
||||||
)
|
)
|
||||||
conn.cursor()
|
conn.cursor()
|
||||||
break
|
break
|
||||||
|
@ -42,12 +42,12 @@ while True:
|
||||||
CONFIG.log("info", "PostgreSQL connection successful")
|
CONFIG.log("info", "PostgreSQL connection successful")
|
||||||
|
|
||||||
REDIS_PROTOCOL_PREFIX = "redis://"
|
REDIS_PROTOCOL_PREFIX = "redis://"
|
||||||
if CONFIG.y_bool("redis.tls", False):
|
if CONFIG.get_bool("redis.tls", False):
|
||||||
REDIS_PROTOCOL_PREFIX = "rediss://"
|
REDIS_PROTOCOL_PREFIX = "rediss://"
|
||||||
REDIS_URL = (
|
REDIS_URL = (
|
||||||
f"{REDIS_PROTOCOL_PREFIX}:"
|
f"{REDIS_PROTOCOL_PREFIX}:"
|
||||||
f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:"
|
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
||||||
f"{int(CONFIG.y('redis.port'))}/{CONFIG.y('redis.db')}"
|
f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}"
|
||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
"""Test Enroll flow"""
|
"""Test Enroll flow"""
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
from django.test import override_settings
|
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
from selenium.webdriver.support import expected_conditions as ec
|
from selenium.webdriver.support import expected_conditions as ec
|
||||||
from selenium.webdriver.support.wait import WebDriverWait
|
from selenium.webdriver.support.wait import WebDriverWait
|
||||||
|
@ -9,6 +8,7 @@ from selenium.webdriver.support.wait import WebDriverWait
|
||||||
from authentik.blueprints.tests import apply_blueprint
|
from authentik.blueprints.tests import apply_blueprint
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.flows.models import Flow
|
from authentik.flows.models import Flow
|
||||||
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.stages.identification.models import IdentificationStage
|
from authentik.stages.identification.models import IdentificationStage
|
||||||
from tests.e2e.utils import SeleniumTestCase, retry
|
from tests.e2e.utils import SeleniumTestCase, retry
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ class TestFlowsEnroll(SeleniumTestCase):
|
||||||
@apply_blueprint(
|
@apply_blueprint(
|
||||||
"example/flows-enrollment-email-verification.yaml",
|
"example/flows-enrollment-email-verification.yaml",
|
||||||
)
|
)
|
||||||
@override_settings(EMAIL_PORT=1025)
|
@CONFIG.patch("email.port", 1025)
|
||||||
def test_enroll_email(self):
|
def test_enroll_email(self):
|
||||||
"""Test enroll with Email verification"""
|
"""Test enroll with Email verification"""
|
||||||
# Attach enrollment flow to identification stage
|
# Attach enrollment flow to identification stage
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
"""Test recovery flow"""
|
"""Test recovery flow"""
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
from django.test import override_settings
|
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
from selenium.webdriver.support import expected_conditions as ec
|
from selenium.webdriver.support import expected_conditions as ec
|
||||||
from selenium.webdriver.support.wait import WebDriverWait
|
from selenium.webdriver.support.wait import WebDriverWait
|
||||||
|
@ -10,6 +9,7 @@ from authentik.blueprints.tests import apply_blueprint
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.core.tests.utils import create_test_admin_user
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
from authentik.flows.models import Flow
|
from authentik.flows.models import Flow
|
||||||
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.stages.identification.models import IdentificationStage
|
from authentik.stages.identification.models import IdentificationStage
|
||||||
from tests.e2e.utils import SeleniumTestCase, retry
|
from tests.e2e.utils import SeleniumTestCase, retry
|
||||||
|
@ -47,7 +47,7 @@ class TestFlowsRecovery(SeleniumTestCase):
|
||||||
@apply_blueprint(
|
@apply_blueprint(
|
||||||
"example/flows-recovery-email-verification.yaml",
|
"example/flows-recovery-email-verification.yaml",
|
||||||
)
|
)
|
||||||
@override_settings(EMAIL_PORT=1025)
|
@CONFIG.patch("email.port", 1025)
|
||||||
def test_recover_email(self):
|
def test_recover_email(self):
|
||||||
"""Test recovery with Email verification"""
|
"""Test recovery with Email verification"""
|
||||||
# Attach recovery flow to identification stage
|
# Attach recovery flow to identification stage
|
||||||
|
|
Reference in a new issue