diff --git a/authentik/core/middleware.py b/authentik/core/middleware.py index d9d1be46e..b194349b8 100644 --- a/authentik/core/middleware.py +++ b/authentik/core/middleware.py @@ -12,7 +12,6 @@ LOCAL = local() RESPONSE_HEADER_ID = "X-authentik-id" KEY_AUTH_VIA = "auth_via" KEY_USER = "user" -INTERNAL_HEADER_PREFIX = "X-authentik-internal-" class ImpersonateMiddleware: @@ -53,9 +52,10 @@ class RequestIDMiddleware: } response = self.get_response(request) response[RESPONSE_HEADER_ID] = request.request_id + setattr(response, "ak_context", {}) if auth_via := LOCAL.authentik.get(KEY_AUTH_VIA, None): - response[INTERNAL_HEADER_PREFIX + KEY_AUTH_VIA] = auth_via - response[INTERNAL_HEADER_PREFIX + KEY_USER] = request.user.username + response.ak_context[KEY_AUTH_VIA] = auth_via + response.ak_context[KEY_USER] = request.user.username for key in list(LOCAL.authentik.keys()): del LOCAL.authentik[key] return response diff --git a/authentik/root/asgi/app.py b/authentik/root/asgi.py similarity index 60% rename from authentik/root/asgi/app.py rename to authentik/root/asgi.py index 00010f86a..dc40287f7 100644 --- a/authentik/root/asgi/app.py +++ b/authentik/root/asgi.py @@ -7,14 +7,11 @@ For more information on this file, see https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ """ import django -from asgiref.compatibility import guarantee_single_callable from channels.routing import ProtocolTypeRouter, URLRouter from defusedxml import defuse_stdlib from django.core.asgi import get_asgi_application from sentry_sdk.integrations.asgi import SentryAsgiMiddleware -from authentik.root.asgi.logger import ASGILogger - # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py defuse_stdlib() @@ -23,15 +20,11 @@ django.setup() # pylint: disable=wrong-import-position from authentik.root import websocket # noqa # isort:skip -application = ASGILogger( - guarantee_single_callable( - SentryAsgiMiddleware( - ProtocolTypeRouter( - { - "http": get_asgi_application(), - "websocket": URLRouter(websocket.websocket_urlpatterns), - } - ) - ) +application = SentryAsgiMiddleware( + ProtocolTypeRouter( + { + "http": get_asgi_application(), + "websocket": URLRouter(websocket.websocket_urlpatterns), + } ) ) diff --git a/authentik/root/asgi/__init__.py b/authentik/root/asgi/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/authentik/root/asgi/logger.py b/authentik/root/asgi/logger.py deleted file mode 100644 index 6b5cade99..000000000 --- a/authentik/root/asgi/logger.py +++ /dev/null @@ -1,107 +0,0 @@ -"""ASGI Logger""" -from time import time - -from structlog.stdlib import get_logger - -from authentik.core.middleware import INTERNAL_HEADER_PREFIX, RESPONSE_HEADER_ID -from authentik.root.asgi.types import ASGIApp, Message, Receive, Scope, Send - -ASGI_IP_HEADERS = ( - b"x-forwarded-for", - b"x-real-ip", -) - -LOGGER = get_logger("authentik.asgi") - - -class ASGILogger: - """ASGI Logger, instantiated for each request""" - - app: ASGIApp - - def __init__(self, app: ASGIApp): - self.app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - content_length = 0 - status_code = 0 - request_id = "" - # Copy all headers starting with X-authentik-internal - copied_headers = {} - location = "" - start = time() - - async def send_hooked(message: Message) -> None: - """Hooked send method, which records status code and content-length, and for the final - requests logs it""" - - headers = dict(message.get("headers", [])) - if "status" in message: - nonlocal status_code - status_code = message["status"] - - if b"Content-Length" in headers: - nonlocal content_length - content_length += int(headers.get(b"Content-Length", b"0")) - - if message["type"] == "http.response.start": - response_headers = dict(message["headers"]) - nonlocal request_id - nonlocal copied_headers - nonlocal location - request_id = response_headers.get(RESPONSE_HEADER_ID.encode(), b"").decode() - location = response_headers.get(b"Location", b"").decode() - # Copy all internal headers to log, and remove them from the final response - for header in list(response_headers.keys()): - if not header.decode().startswith(INTERNAL_HEADER_PREFIX): - continue - copied_headers[ - header.decode().replace(INTERNAL_HEADER_PREFIX, "") - ] = response_headers[header].decode() - del response_headers[header] - message["headers"] = list(response_headers.items()) - - if message["type"] == "http.response.body" and not message.get("more_body", True): - nonlocal start - runtime = int((time() - start) * 1000) - kwargs = {"request_id": request_id} - if location != "": - kwargs["location"] = location - kwargs.update(copied_headers) - self.log(scope, runtime, content_length, status_code, **kwargs) - await send(message) - - if scope["type"] == "lifespan": - # https://code.djangoproject.com/ticket/31508 - # https://github.com/encode/uvicorn/issues/266 - return - return await self.app(scope, receive, send_hooked) - - def _get_ip(self, headers: dict[bytes, bytes], scope: Scope) -> str: - client_ip = None - for header in ASGI_IP_HEADERS: - if header in headers: - client_ip = headers[header].decode() - if not client_ip: - client_ip, _ = scope.get("client", ("", 0)) - # Check if header has multiple values, and use the first one - return client_ip.split(", ")[0] - - def log(self, scope: Scope, content_length: int, runtime: float, status_code: int, **kwargs): - """Outpot access logs in a structured format""" - headers = dict(scope.get("headers", [])) - host = self._get_ip(headers, scope) - query_string = "" - if scope.get("query_string", b"") != b"": - query_string = f"?{scope.get('query_string').decode()}" - LOGGER.info( - f"{scope.get('path', '')}{query_string}", - host=host, - method=scope.get("method", ""), - scheme=scope.get("scheme", ""), - status=status_code, - size=content_length / 1000 if content_length > 0 else 0, - runtime=runtime, - user_agent=headers.get(b"user-agent", b"").decode(), - **kwargs, - ) diff --git a/authentik/root/asgi/types.py b/authentik/root/asgi/types.py deleted file mode 100644 index 4d82a197d..000000000 --- a/authentik/root/asgi/types.py +++ /dev/null @@ -1,11 +0,0 @@ -"""ASGI Types""" -import typing - -# See https://github.com/encode/starlette/blob/master/starlette/types.py -Scope = typing.MutableMapping[str, typing.Any] -Message = typing.MutableMapping[str, typing.Any] - -Receive = typing.Callable[[], typing.Awaitable[Message]] -Send = typing.Callable[[Message], typing.Awaitable[None]] - -ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] diff --git a/authentik/root/middleware.py b/authentik/root/middleware.py index cd56adf0f..88b556b69 100644 --- a/authentik/root/middleware.py +++ b/authentik/root/middleware.py @@ -1,5 +1,6 @@ """Dynamically set SameSite depending if the upstream connection is TLS or not""" -import time +from time import time +from typing import Callable from django.conf import settings from django.contrib.sessions.backends.base import UpdateError @@ -9,6 +10,13 @@ from django.http.request import HttpRequest from django.http.response import HttpResponse from django.utils.cache import patch_vary_headers from django.utils.http import http_date +from structlog.stdlib import get_logger +from typing_extensions import runtime + +from authentik.core.middleware import KEY_AUTH_VIA, KEY_USER +from authentik.lib.utils.http import get_client_ip + +LOGGER = get_logger("authentik.asgi") class SessionMiddleware(UpstreamSessionMiddleware): @@ -88,3 +96,36 @@ class SessionMiddleware(UpstreamSessionMiddleware): samesite=same_site, ) return response + + +class LoggingMiddleware: + """Logger middleware""" + + get_response: Callable[[HttpRequest], HttpResponse] + + def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): + self.get_response = get_response + + def __call__(self, request: HttpRequest) -> HttpResponse: + start = time() + response = self.get_response(request) + status_code = response.status_code + kwargs = { + "request_id": request.request_id, + } + kwargs.update(response.ak_context) + self.log(request, status_code, int((time() - start) * 1000), **kwargs) + return response + + def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs): + """Log request""" + LOGGER.info( + request.get_full_path(), + remote=get_client_ip(request), + method=request.method, + scheme=request.scheme, + status=status_code, + runtime=runtime, + user_agent=request.META.get("HTTP_USER_AGENT", ""), + **kwargs, + ) diff --git a/authentik/root/settings.py b/authentik/root/settings.py index b51232c1a..ba3246666 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -241,6 +241,7 @@ SESSION_EXPIRE_AT_BROWSER_CLOSE = True MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage" MIDDLEWARE = [ + "authentik.root.middleware.LoggingMiddleware", "django_prometheus.middleware.PrometheusBeforeMiddleware", "authentik.root.middleware.SessionMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", @@ -275,7 +276,7 @@ TEMPLATES = [ }, ] -ASGI_APPLICATION = "authentik.root.asgi.app.application" +ASGI_APPLICATION = "authentik.root.asgi.application" CHANNEL_LAYERS = { "default": { diff --git a/internal/gounicorn/gounicorn.go b/internal/gounicorn/gounicorn.go index 3d338f2cf..0ec080b6b 100644 --- a/internal/gounicorn/gounicorn.go +++ b/internal/gounicorn/gounicorn.go @@ -38,7 +38,7 @@ func NewGoUnicorn() *GoUnicorn { func (g *GoUnicorn) initCmd() { command := "gunicorn" - args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi.app:application"} + args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"} if config.G.Debug { command = "./manage.py" args = []string{"runserver"}