root: replace asgi-based logger with middleware
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
7cf8a31057
commit
e08077c73a
|
@ -12,7 +12,6 @@ LOCAL = local()
|
|||
RESPONSE_HEADER_ID = "X-authentik-id"
|
||||
KEY_AUTH_VIA = "auth_via"
|
||||
KEY_USER = "user"
|
||||
INTERNAL_HEADER_PREFIX = "X-authentik-internal-"
|
||||
|
||||
|
||||
class ImpersonateMiddleware:
|
||||
|
@ -53,9 +52,10 @@ class RequestIDMiddleware:
|
|||
}
|
||||
response = self.get_response(request)
|
||||
response[RESPONSE_HEADER_ID] = request.request_id
|
||||
setattr(response, "ak_context", {})
|
||||
if auth_via := LOCAL.authentik.get(KEY_AUTH_VIA, None):
|
||||
response[INTERNAL_HEADER_PREFIX + KEY_AUTH_VIA] = auth_via
|
||||
response[INTERNAL_HEADER_PREFIX + KEY_USER] = request.user.username
|
||||
response.ak_context[KEY_AUTH_VIA] = auth_via
|
||||
response.ak_context[KEY_USER] = request.user.username
|
||||
for key in list(LOCAL.authentik.keys()):
|
||||
del LOCAL.authentik[key]
|
||||
return response
|
||||
|
|
|
@ -7,14 +7,11 @@ For more information on this file, see
|
|||
https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
|
||||
"""
|
||||
import django
|
||||
from asgiref.compatibility import guarantee_single_callable
|
||||
from channels.routing import ProtocolTypeRouter, URLRouter
|
||||
from defusedxml import defuse_stdlib
|
||||
from django.core.asgi import get_asgi_application
|
||||
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
|
||||
|
||||
from authentik.root.asgi.logger import ASGILogger
|
||||
|
||||
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
|
||||
|
||||
defuse_stdlib()
|
||||
|
@ -23,9 +20,7 @@ django.setup()
|
|||
# pylint: disable=wrong-import-position
|
||||
from authentik.root import websocket # noqa # isort:skip
|
||||
|
||||
application = ASGILogger(
|
||||
guarantee_single_callable(
|
||||
SentryAsgiMiddleware(
|
||||
application = SentryAsgiMiddleware(
|
||||
ProtocolTypeRouter(
|
||||
{
|
||||
"http": get_asgi_application(),
|
||||
|
@ -33,5 +28,3 @@ application = ASGILogger(
|
|||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
|
@ -1,107 +0,0 @@
|
|||
"""ASGI Logger"""
|
||||
from time import time
|
||||
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.middleware import INTERNAL_HEADER_PREFIX, RESPONSE_HEADER_ID
|
||||
from authentik.root.asgi.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
ASGI_IP_HEADERS = (
|
||||
b"x-forwarded-for",
|
||||
b"x-real-ip",
|
||||
)
|
||||
|
||||
LOGGER = get_logger("authentik.asgi")
|
||||
|
||||
|
||||
class ASGILogger:
|
||||
"""ASGI Logger, instantiated for each request"""
|
||||
|
||||
app: ASGIApp
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
content_length = 0
|
||||
status_code = 0
|
||||
request_id = ""
|
||||
# Copy all headers starting with X-authentik-internal
|
||||
copied_headers = {}
|
||||
location = ""
|
||||
start = time()
|
||||
|
||||
async def send_hooked(message: Message) -> None:
|
||||
"""Hooked send method, which records status code and content-length, and for the final
|
||||
requests logs it"""
|
||||
|
||||
headers = dict(message.get("headers", []))
|
||||
if "status" in message:
|
||||
nonlocal status_code
|
||||
status_code = message["status"]
|
||||
|
||||
if b"Content-Length" in headers:
|
||||
nonlocal content_length
|
||||
content_length += int(headers.get(b"Content-Length", b"0"))
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_headers = dict(message["headers"])
|
||||
nonlocal request_id
|
||||
nonlocal copied_headers
|
||||
nonlocal location
|
||||
request_id = response_headers.get(RESPONSE_HEADER_ID.encode(), b"").decode()
|
||||
location = response_headers.get(b"Location", b"").decode()
|
||||
# Copy all internal headers to log, and remove them from the final response
|
||||
for header in list(response_headers.keys()):
|
||||
if not header.decode().startswith(INTERNAL_HEADER_PREFIX):
|
||||
continue
|
||||
copied_headers[
|
||||
header.decode().replace(INTERNAL_HEADER_PREFIX, "")
|
||||
] = response_headers[header].decode()
|
||||
del response_headers[header]
|
||||
message["headers"] = list(response_headers.items())
|
||||
|
||||
if message["type"] == "http.response.body" and not message.get("more_body", True):
|
||||
nonlocal start
|
||||
runtime = int((time() - start) * 1000)
|
||||
kwargs = {"request_id": request_id}
|
||||
if location != "":
|
||||
kwargs["location"] = location
|
||||
kwargs.update(copied_headers)
|
||||
self.log(scope, runtime, content_length, status_code, **kwargs)
|
||||
await send(message)
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
# https://code.djangoproject.com/ticket/31508
|
||||
# https://github.com/encode/uvicorn/issues/266
|
||||
return
|
||||
return await self.app(scope, receive, send_hooked)
|
||||
|
||||
def _get_ip(self, headers: dict[bytes, bytes], scope: Scope) -> str:
|
||||
client_ip = None
|
||||
for header in ASGI_IP_HEADERS:
|
||||
if header in headers:
|
||||
client_ip = headers[header].decode()
|
||||
if not client_ip:
|
||||
client_ip, _ = scope.get("client", ("", 0))
|
||||
# Check if header has multiple values, and use the first one
|
||||
return client_ip.split(", ")[0]
|
||||
|
||||
def log(self, scope: Scope, content_length: int, runtime: float, status_code: int, **kwargs):
|
||||
"""Outpot access logs in a structured format"""
|
||||
headers = dict(scope.get("headers", []))
|
||||
host = self._get_ip(headers, scope)
|
||||
query_string = ""
|
||||
if scope.get("query_string", b"") != b"":
|
||||
query_string = f"?{scope.get('query_string').decode()}"
|
||||
LOGGER.info(
|
||||
f"{scope.get('path', '')}{query_string}",
|
||||
host=host,
|
||||
method=scope.get("method", ""),
|
||||
scheme=scope.get("scheme", ""),
|
||||
status=status_code,
|
||||
size=content_length / 1000 if content_length > 0 else 0,
|
||||
runtime=runtime,
|
||||
user_agent=headers.get(b"user-agent", b"").decode(),
|
||||
**kwargs,
|
||||
)
|
|
@ -1,11 +0,0 @@
|
|||
"""ASGI Types"""
|
||||
import typing
|
||||
|
||||
# See https://github.com/encode/starlette/blob/master/starlette/types.py
|
||||
Scope = typing.MutableMapping[str, typing.Any]
|
||||
Message = typing.MutableMapping[str, typing.Any]
|
||||
|
||||
Receive = typing.Callable[[], typing.Awaitable[Message]]
|
||||
Send = typing.Callable[[Message], typing.Awaitable[None]]
|
||||
|
||||
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
|
@ -1,5 +1,6 @@
|
|||
"""Dynamically set SameSite depending if the upstream connection is TLS or not"""
|
||||
import time
|
||||
from time import time
|
||||
from typing import Callable
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.sessions.backends.base import UpdateError
|
||||
|
@ -9,6 +10,13 @@ from django.http.request import HttpRequest
|
|||
from django.http.response import HttpResponse
|
||||
from django.utils.cache import patch_vary_headers
|
||||
from django.utils.http import http_date
|
||||
from structlog.stdlib import get_logger
|
||||
from typing_extensions import runtime
|
||||
|
||||
from authentik.core.middleware import KEY_AUTH_VIA, KEY_USER
|
||||
from authentik.lib.utils.http import get_client_ip
|
||||
|
||||
LOGGER = get_logger("authentik.asgi")
|
||||
|
||||
|
||||
class SessionMiddleware(UpstreamSessionMiddleware):
|
||||
|
@ -88,3 +96,36 @@ class SessionMiddleware(UpstreamSessionMiddleware):
|
|||
samesite=same_site,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class LoggingMiddleware:
|
||||
"""Logger middleware"""
|
||||
|
||||
get_response: Callable[[HttpRequest], HttpResponse]
|
||||
|
||||
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
|
||||
self.get_response = get_response
|
||||
|
||||
def __call__(self, request: HttpRequest) -> HttpResponse:
|
||||
start = time()
|
||||
response = self.get_response(request)
|
||||
status_code = response.status_code
|
||||
kwargs = {
|
||||
"request_id": request.request_id,
|
||||
}
|
||||
kwargs.update(response.ak_context)
|
||||
self.log(request, status_code, int((time() - start) * 1000), **kwargs)
|
||||
return response
|
||||
|
||||
def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs):
|
||||
"""Log request"""
|
||||
LOGGER.info(
|
||||
request.get_full_path(),
|
||||
remote=get_client_ip(request),
|
||||
method=request.method,
|
||||
scheme=request.scheme,
|
||||
status=status_code,
|
||||
runtime=runtime,
|
||||
user_agent=request.META.get("HTTP_USER_AGENT", ""),
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
@ -241,6 +241,7 @@ SESSION_EXPIRE_AT_BROWSER_CLOSE = True
|
|||
MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage"
|
||||
|
||||
MIDDLEWARE = [
|
||||
"authentik.root.middleware.LoggingMiddleware",
|
||||
"django_prometheus.middleware.PrometheusBeforeMiddleware",
|
||||
"authentik.root.middleware.SessionMiddleware",
|
||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||
|
@ -275,7 +276,7 @@ TEMPLATES = [
|
|||
},
|
||||
]
|
||||
|
||||
ASGI_APPLICATION = "authentik.root.asgi.app.application"
|
||||
ASGI_APPLICATION = "authentik.root.asgi.application"
|
||||
|
||||
CHANNEL_LAYERS = {
|
||||
"default": {
|
||||
|
|
|
@ -38,7 +38,7 @@ func NewGoUnicorn() *GoUnicorn {
|
|||
|
||||
func (g *GoUnicorn) initCmd() {
|
||||
command := "gunicorn"
|
||||
args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi.app:application"}
|
||||
args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"}
|
||||
if config.G.Debug {
|
||||
command = "./manage.py"
|
||||
args = []string{"runserver"}
|
||||
|
|
Reference in a new issue