root: replace asgi-based logger with middleware

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-11-15 16:32:56 +01:00
parent 7cf8a31057
commit e08077c73a
8 changed files with 54 additions and 137 deletions

View file

@ -12,7 +12,6 @@ LOCAL = local()
RESPONSE_HEADER_ID = "X-authentik-id" RESPONSE_HEADER_ID = "X-authentik-id"
KEY_AUTH_VIA = "auth_via" KEY_AUTH_VIA = "auth_via"
KEY_USER = "user" KEY_USER = "user"
INTERNAL_HEADER_PREFIX = "X-authentik-internal-"
class ImpersonateMiddleware: class ImpersonateMiddleware:
@ -53,9 +52,10 @@ class RequestIDMiddleware:
} }
response = self.get_response(request) response = self.get_response(request)
response[RESPONSE_HEADER_ID] = request.request_id response[RESPONSE_HEADER_ID] = request.request_id
setattr(response, "ak_context", {})
if auth_via := LOCAL.authentik.get(KEY_AUTH_VIA, None): if auth_via := LOCAL.authentik.get(KEY_AUTH_VIA, None):
response[INTERNAL_HEADER_PREFIX + KEY_AUTH_VIA] = auth_via response.ak_context[KEY_AUTH_VIA] = auth_via
response[INTERNAL_HEADER_PREFIX + KEY_USER] = request.user.username response.ak_context[KEY_USER] = request.user.username
for key in list(LOCAL.authentik.keys()): for key in list(LOCAL.authentik.keys()):
del LOCAL.authentik[key] del LOCAL.authentik[key]
return response return response

View file

@ -7,14 +7,11 @@ For more information on this file, see
https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
""" """
import django import django
from asgiref.compatibility import guarantee_single_callable
from channels.routing import ProtocolTypeRouter, URLRouter from channels.routing import ProtocolTypeRouter, URLRouter
from defusedxml import defuse_stdlib from defusedxml import defuse_stdlib
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from authentik.root.asgi.logger import ASGILogger
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
defuse_stdlib() defuse_stdlib()
@ -23,15 +20,11 @@ django.setup()
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from authentik.root import websocket # noqa # isort:skip from authentik.root import websocket # noqa # isort:skip
application = ASGILogger( application = SentryAsgiMiddleware(
guarantee_single_callable( ProtocolTypeRouter(
SentryAsgiMiddleware( {
ProtocolTypeRouter( "http": get_asgi_application(),
{ "websocket": URLRouter(websocket.websocket_urlpatterns),
"http": get_asgi_application(), }
"websocket": URLRouter(websocket.websocket_urlpatterns),
}
)
)
) )
) )

View file

@ -1,107 +0,0 @@
"""ASGI Logger"""
from time import time
from structlog.stdlib import get_logger
from authentik.core.middleware import INTERNAL_HEADER_PREFIX, RESPONSE_HEADER_ID
from authentik.root.asgi.types import ASGIApp, Message, Receive, Scope, Send
ASGI_IP_HEADERS = (
b"x-forwarded-for",
b"x-real-ip",
)
LOGGER = get_logger("authentik.asgi")
class ASGILogger:
"""ASGI Logger, instantiated for each request"""
app: ASGIApp
def __init__(self, app: ASGIApp):
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
content_length = 0
status_code = 0
request_id = ""
# Copy all headers starting with X-authentik-internal
copied_headers = {}
location = ""
start = time()
async def send_hooked(message: Message) -> None:
"""Hooked send method, which records status code and content-length, and for the final
requests logs it"""
headers = dict(message.get("headers", []))
if "status" in message:
nonlocal status_code
status_code = message["status"]
if b"Content-Length" in headers:
nonlocal content_length
content_length += int(headers.get(b"Content-Length", b"0"))
if message["type"] == "http.response.start":
response_headers = dict(message["headers"])
nonlocal request_id
nonlocal copied_headers
nonlocal location
request_id = response_headers.get(RESPONSE_HEADER_ID.encode(), b"").decode()
location = response_headers.get(b"Location", b"").decode()
# Copy all internal headers to log, and remove them from the final response
for header in list(response_headers.keys()):
if not header.decode().startswith(INTERNAL_HEADER_PREFIX):
continue
copied_headers[
header.decode().replace(INTERNAL_HEADER_PREFIX, "")
] = response_headers[header].decode()
del response_headers[header]
message["headers"] = list(response_headers.items())
if message["type"] == "http.response.body" and not message.get("more_body", True):
nonlocal start
runtime = int((time() - start) * 1000)
kwargs = {"request_id": request_id}
if location != "":
kwargs["location"] = location
kwargs.update(copied_headers)
self.log(scope, runtime, content_length, status_code, **kwargs)
await send(message)
if scope["type"] == "lifespan":
# https://code.djangoproject.com/ticket/31508
# https://github.com/encode/uvicorn/issues/266
return
return await self.app(scope, receive, send_hooked)
def _get_ip(self, headers: dict[bytes, bytes], scope: Scope) -> str:
client_ip = None
for header in ASGI_IP_HEADERS:
if header in headers:
client_ip = headers[header].decode()
if not client_ip:
client_ip, _ = scope.get("client", ("", 0))
# Check if header has multiple values, and use the first one
return client_ip.split(", ")[0]
def log(self, scope: Scope, content_length: int, runtime: float, status_code: int, **kwargs):
"""Outpot access logs in a structured format"""
headers = dict(scope.get("headers", []))
host = self._get_ip(headers, scope)
query_string = ""
if scope.get("query_string", b"") != b"":
query_string = f"?{scope.get('query_string').decode()}"
LOGGER.info(
f"{scope.get('path', '')}{query_string}",
host=host,
method=scope.get("method", ""),
scheme=scope.get("scheme", ""),
status=status_code,
size=content_length / 1000 if content_length > 0 else 0,
runtime=runtime,
user_agent=headers.get(b"user-agent", b"").decode(),
**kwargs,
)

View file

@ -1,11 +0,0 @@
"""ASGI Types"""
import typing
# See https://github.com/encode/starlette/blob/master/starlette/types.py
Scope = typing.MutableMapping[str, typing.Any]
Message = typing.MutableMapping[str, typing.Any]
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Message], typing.Awaitable[None]]
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

View file

@ -1,5 +1,6 @@
"""Dynamically set SameSite depending if the upstream connection is TLS or not""" """Dynamically set SameSite depending if the upstream connection is TLS or not"""
import time from time import time
from typing import Callable
from django.conf import settings from django.conf import settings
from django.contrib.sessions.backends.base import UpdateError from django.contrib.sessions.backends.base import UpdateError
@ -9,6 +10,13 @@ from django.http.request import HttpRequest
from django.http.response import HttpResponse from django.http.response import HttpResponse
from django.utils.cache import patch_vary_headers from django.utils.cache import patch_vary_headers
from django.utils.http import http_date from django.utils.http import http_date
from structlog.stdlib import get_logger
from typing_extensions import runtime
from authentik.core.middleware import KEY_AUTH_VIA, KEY_USER
from authentik.lib.utils.http import get_client_ip
LOGGER = get_logger("authentik.asgi")
class SessionMiddleware(UpstreamSessionMiddleware): class SessionMiddleware(UpstreamSessionMiddleware):
@ -88,3 +96,36 @@ class SessionMiddleware(UpstreamSessionMiddleware):
samesite=same_site, samesite=same_site,
) )
return response return response
class LoggingMiddleware:
"""Logger middleware"""
get_response: Callable[[HttpRequest], HttpResponse]
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
def __call__(self, request: HttpRequest) -> HttpResponse:
start = time()
response = self.get_response(request)
status_code = response.status_code
kwargs = {
"request_id": request.request_id,
}
kwargs.update(response.ak_context)
self.log(request, status_code, int((time() - start) * 1000), **kwargs)
return response
def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs):
"""Log request"""
LOGGER.info(
request.get_full_path(),
remote=get_client_ip(request),
method=request.method,
scheme=request.scheme,
status=status_code,
runtime=runtime,
user_agent=request.META.get("HTTP_USER_AGENT", ""),
**kwargs,
)

View file

@ -241,6 +241,7 @@ SESSION_EXPIRE_AT_BROWSER_CLOSE = True
MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage" MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage"
MIDDLEWARE = [ MIDDLEWARE = [
"authentik.root.middleware.LoggingMiddleware",
"django_prometheus.middleware.PrometheusBeforeMiddleware", "django_prometheus.middleware.PrometheusBeforeMiddleware",
"authentik.root.middleware.SessionMiddleware", "authentik.root.middleware.SessionMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware",
@ -275,7 +276,7 @@ TEMPLATES = [
}, },
] ]
ASGI_APPLICATION = "authentik.root.asgi.app.application" ASGI_APPLICATION = "authentik.root.asgi.application"
CHANNEL_LAYERS = { CHANNEL_LAYERS = {
"default": { "default": {

View file

@ -38,7 +38,7 @@ func NewGoUnicorn() *GoUnicorn {
func (g *GoUnicorn) initCmd() { func (g *GoUnicorn) initCmd() {
command := "gunicorn" command := "gunicorn"
args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi.app:application"} args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"}
if config.G.Debug { if config.G.Debug {
command = "./manage.py" command = "./manage.py"
args = []string{"runserver"} args = []string{"runserver"}