diff --git a/authentik/api/authentication.py b/authentik/api/authentication.py index 25e1b6a6f..4f3e65909 100644 --- a/authentik/api/authentication.py +++ b/authentik/api/authentication.py @@ -7,7 +7,7 @@ from rest_framework.exceptions import AuthenticationFailed from rest_framework.request import Request from structlog.stdlib import get_logger -from authentik.core.middleware import KEY_AUTH_VIA, LOCAL +from authentik.core.middleware import CTX_AUTH_VIA from authentik.core.models import Token, TokenIntents, User from authentik.outposts.models import Outpost from authentik.providers.oauth2.constants import SCOPE_AUTHENTIK_API @@ -36,14 +36,12 @@ def bearer_auth(raw_header: bytes) -> Optional[User]: auth_credentials = validate_auth(raw_header) if not auth_credentials: return None - if not hasattr(LOCAL, "authentik"): - LOCAL.authentik = {} # first, check traditional tokens key_token = Token.filter_not_expired( key=auth_credentials, intent=TokenIntents.INTENT_API ).first() if key_token: - LOCAL.authentik[KEY_AUTH_VIA] = "api_token" + CTX_AUTH_VIA.set("api_token") return key_token.user # then try to auth via JWT jwt_token = RefreshToken.filter_not_expired( @@ -54,12 +52,12 @@ def bearer_auth(raw_header: bytes) -> Optional[User]: # we want to check the parsed version too if SCOPE_AUTHENTIK_API not in jwt_token.scope: raise AuthenticationFailed("Token invalid/expired") - LOCAL.authentik[KEY_AUTH_VIA] = "jwt" + CTX_AUTH_VIA.set("jwt") return jwt_token.user # then try to auth via secret key (for embedded outpost/etc) user = token_secret_key(auth_credentials) if user: - LOCAL.authentik[KEY_AUTH_VIA] = "secret_key" + CTX_AUTH_VIA.set("secret_key") return user raise AuthenticationFailed("Token invalid/expired") diff --git a/authentik/core/middleware.py b/authentik/core/middleware.py index 494e7f0bb..041fb700b 100644 --- a/authentik/core/middleware.py +++ b/authentik/core/middleware.py @@ -1,19 +1,22 @@ """authentik admin Middleware to impersonate users""" -from logging import Logger -from threading import local +from contextvars import ContextVar from typing import Callable from uuid import uuid4 from django.http import HttpRequest, HttpResponse from sentry_sdk.api import set_tag +from structlog.contextvars import STRUCTLOG_KEY_PREFIX SESSION_KEY_IMPERSONATE_USER = "authentik/impersonate/user" SESSION_KEY_IMPERSONATE_ORIGINAL_USER = "authentik/impersonate/original_user" -LOCAL = local() RESPONSE_HEADER_ID = "X-authentik-id" KEY_AUTH_VIA = "auth_via" KEY_USER = "user" +CTX_REQUEST_ID = ContextVar(STRUCTLOG_KEY_PREFIX + "request_id", default=None) +CTX_HOST = ContextVar(STRUCTLOG_KEY_PREFIX + "host", default=None) +CTX_AUTH_VIA = ContextVar(STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) + class ImpersonateMiddleware: """Middleware to impersonate users""" @@ -47,26 +50,20 @@ class RequestIDMiddleware: if not hasattr(request, "request_id"): request_id = uuid4().hex setattr(request, "request_id", request_id) - LOCAL.authentik = { - "request_id": request_id, - "host": request.get_host(), - } + CTX_REQUEST_ID.set(request_id) + CTX_HOST.set(request.get_host()) set_tag("authentik.request_id", request_id) + if hasattr(request, "user") and getattr(request.user, "is_authenticated", False): + CTX_AUTH_VIA.set("session") + else: + CTX_AUTH_VIA.set("unauthenticated") + response = self.get_response(request) + response[RESPONSE_HEADER_ID] = request.request_id setattr(response, "ak_context", {}) - response.ak_context.update(LOCAL.authentik) - response.ak_context.setdefault(KEY_USER, request.user.username) - for key in list(LOCAL.authentik.keys()): - del LOCAL.authentik[key] + response.ak_context["request_id"] = CTX_REQUEST_ID.get() + response.ak_context["host"] = CTX_HOST.get() + response.ak_context[KEY_AUTH_VIA] = CTX_AUTH_VIA.get() + response.ak_context[KEY_USER] = request.user.username return response - - -# pylint: disable=unused-argument -def structlog_add_request_id(logger: Logger, method_name: str, event_dict: dict): - """If threadlocal has authentik defined, add request_id to log""" - if hasattr(LOCAL, "authentik"): - event_dict.update(LOCAL.authentik) - if hasattr(LOCAL, "authentik_task"): - event_dict.update(LOCAL.authentik_task) - return event_dict diff --git a/authentik/events/middleware.py b/authentik/events/middleware.py index ab641e5c8..4beed8fe2 100644 --- a/authentik/events/middleware.py +++ b/authentik/events/middleware.py @@ -11,7 +11,6 @@ from django.http import HttpRequest, HttpResponse from django_otp.plugins.otp_static.models import StaticToken from guardian.models import UserObjectPermission -from authentik.core.middleware import LOCAL from authentik.core.models import AuthenticatedSession, User from authentik.events.models import Event, EventAction, Notification from authentik.events.signals import EventNewThread @@ -45,36 +44,46 @@ class AuditMiddleware: def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): self.get_response = get_response + def connect(self, request: HttpRequest): + """Connect signal for automatic logging""" + if not hasattr(request, "user"): + return + if not getattr(request.user, "is_authenticated", False): + return + if not hasattr(request, "request_id"): + return + post_save_handler = partial(self.post_save_handler, user=request.user, request=request) + pre_delete_handler = partial(self.pre_delete_handler, user=request.user, request=request) + post_save.connect( + post_save_handler, + dispatch_uid=request.request_id, + weak=False, + ) + pre_delete.connect( + pre_delete_handler, + dispatch_uid=request.request_id, + weak=False, + ) + + def disconnect(self, request: HttpRequest): + """Disconnect signals""" + if not hasattr(request, "request_id"): + return + post_save.disconnect(dispatch_uid=request.request_id) + pre_delete.disconnect(dispatch_uid=request.request_id) + def __call__(self, request: HttpRequest) -> HttpResponse: - # Connect signal for automatic logging - if hasattr(request, "user") and getattr(request.user, "is_authenticated", False): - post_save_handler = partial(self.post_save_handler, user=request.user, request=request) - pre_delete_handler = partial( - self.pre_delete_handler, user=request.user, request=request - ) - post_save.connect( - post_save_handler, - dispatch_uid=LOCAL.authentik["request_id"], - weak=False, - ) - pre_delete.connect( - pre_delete_handler, - dispatch_uid=LOCAL.authentik["request_id"], - weak=False, - ) + self.connect(request) response = self.get_response(request) - post_save.disconnect(dispatch_uid=LOCAL.authentik["request_id"]) - pre_delete.disconnect(dispatch_uid=LOCAL.authentik["request_id"]) - + self.disconnect(request) return response # pylint: disable=unused-argument def process_exception(self, request: HttpRequest, exception: Exception): """Disconnect handlers in case of exception""" - post_save.disconnect(dispatch_uid=LOCAL.authentik["request_id"]) - pre_delete.disconnect(dispatch_uid=LOCAL.authentik["request_id"]) + self.disconnect(request) if settings.DEBUG: return diff --git a/authentik/root/celery.py b/authentik/root/celery.py index 2b3d16c85..decd5b0d8 100644 --- a/authentik/root/celery.py +++ b/authentik/root/celery.py @@ -17,7 +17,7 @@ from django.conf import settings from django.db import ProgrammingError from structlog.stdlib import get_logger -from authentik.core.middleware import LOCAL +from authentik.core.middleware import CTX_AUTH_VIA, CTX_HOST, CTX_REQUEST_ID from authentik.lib.sentry import before_send from authentik.lib.utils.errors import exception_to_string @@ -48,9 +48,9 @@ def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs): def task_prerun_hook(task_id: str, task, *args, **kwargs): """Log task_id on worker""" request_id = "task-" + task_id.replace("-", "") - LOCAL.authentik_task = { - "request_id": request_id, - } + CTX_REQUEST_ID.set(request_id) + CTX_AUTH_VIA.set(Ellipsis) + CTX_HOST.set(Ellipsis) LOGGER.info("Task started", task_id=task_id, task_name=task.__name__) @@ -59,10 +59,6 @@ def task_prerun_hook(task_id: str, task, *args, **kwargs): def task_postrun_hook(task_id, task, *args, retval=None, state=None, **kwargs): """Log task_id on worker""" LOGGER.info("Task finished", task_id=task_id, task_name=task.__name__, state=state) - if not hasattr(LOCAL, "authentik_task"): - return - for key in list(LOCAL.authentik_task.keys()): - del LOCAL.authentik_task[key] # pylint: disable=unused-argument diff --git a/authentik/root/settings.py b/authentik/root/settings.py index a374ef8b4..ea6b37f9d 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -14,7 +14,6 @@ from celery.schedules import crontab from sentry_sdk import set_tag from authentik import ENV_GIT_HASH_KEY, __version__ -from authentik.core.middleware import structlog_add_request_id from authentik.lib.config import CONFIG from authentik.lib.logging import add_process_id from authentik.lib.sentry import sentry_init @@ -380,12 +379,12 @@ structlog.configure_once( processors=[ structlog.stdlib.add_log_level, structlog.stdlib.add_logger_name, - structlog.threadlocal.merge_threadlocal_context, + structlog.contextvars.merge_contextvars, add_process_id, - structlog_add_request_id, structlog.stdlib.PositionalArgumentsFormatter(), structlog.processors.TimeStamper(fmt="iso", utc=False), structlog.processors.StackInfoRenderer(), + structlog.processors.dict_tracebacks, structlog.stdlib.ProcessorFormatter.wrap_for_formatter, ], logger_factory=structlog.stdlib.LoggerFactory(), @@ -400,6 +399,7 @@ LOG_PRE_CHAIN = [ # is not from structlog. structlog.stdlib.add_log_level, structlog.stdlib.add_logger_name, + structlog.processors.dict_tracebacks, structlog.processors.TimeStamper(), structlog.processors.StackInfoRenderer(), ] diff --git a/poetry.lock b/poetry.lock index 1421ea9ec..918696501 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1715,16 +1715,16 @@ pbr = ">=2.0.0,<2.1.0 || >2.1.0" [[package]] name = "structlog" -version = "21.5.0" +version = "22.1.0" description = "Structured Logging for Python" category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.extras] -dev = ["pre-commit", "rich", "cogapp", "tomli", "coverage", "freezegun (>=0.2.8)", "pretend", "pytest-asyncio", "pytest (>=6.0)", "simplejson", "furo", "sphinx", "sphinx-notfound-page", "sphinxcontrib-mermaid", "twisted"] -docs = ["furo", "sphinx", "sphinx-notfound-page", "sphinxcontrib-mermaid", "twisted"] -tests = ["coverage", "freezegun (>=0.2.8)", "pretend", "pytest-asyncio", "pytest (>=6.0)", "simplejson"] +dev = ["pre-commit", "rich", "cogapp", "tomli", "coverage", "freezegun (>=0.2.8)", "pretend", "pytest-asyncio (>=0.17)", "pytest (>=6.0)", "simplejson", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-mermaid", "twisted"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-mermaid", "twisted"] +tests = ["coverage", "freezegun (>=0.2.8)", "pretend", "pytest-asyncio (>=0.17)", "pytest (>=6.0)", "simplejson"] [[package]] name = "swagger-spec-validator" @@ -3433,8 +3433,8 @@ stevedore = [ {file = "stevedore-3.5.0.tar.gz", hash = "sha256:f40253887d8712eaa2bb0ea3830374416736dc8ec0e22f5a65092c1174c44335"}, ] structlog = [ - {file = "structlog-21.5.0-py3-none-any.whl", hash = "sha256:fd7922e195262b337da85c2a91c84be94ccab1f8fd1957bd6986f6904e3761c8"}, - {file = "structlog-21.5.0.tar.gz", hash = "sha256:68c4c29c003714fe86834f347cb107452847ba52414390a7ee583472bde00fc9"}, + {file = "structlog-22.1.0-py3-none-any.whl", hash = "sha256:760d37b8839bd4fe1747bed7b80f7f4de160078405f4b6a1db9270ccbfce6c30"}, + {file = "structlog-22.1.0.tar.gz", hash = "sha256:94b29b1d62b2659db154f67a9379ec1770183933d6115d21f21aa25cfc9a7393"}, ] swagger-spec-validator = [ {file = "swagger-spec-validator-2.7.4.tar.gz", hash = "sha256:2aee5e1fc0503be9f8299378b10c92169572781573c6de3315e831fd0559ba73"},