root: replace asgi-based logger with middleware
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
7cf8a31057
commit
e08077c73a
|
@ -12,7 +12,6 @@ LOCAL = local()
|
||||||
RESPONSE_HEADER_ID = "X-authentik-id"
|
RESPONSE_HEADER_ID = "X-authentik-id"
|
||||||
KEY_AUTH_VIA = "auth_via"
|
KEY_AUTH_VIA = "auth_via"
|
||||||
KEY_USER = "user"
|
KEY_USER = "user"
|
||||||
INTERNAL_HEADER_PREFIX = "X-authentik-internal-"
|
|
||||||
|
|
||||||
|
|
||||||
class ImpersonateMiddleware:
|
class ImpersonateMiddleware:
|
||||||
|
@ -53,9 +52,10 @@ class RequestIDMiddleware:
|
||||||
}
|
}
|
||||||
response = self.get_response(request)
|
response = self.get_response(request)
|
||||||
response[RESPONSE_HEADER_ID] = request.request_id
|
response[RESPONSE_HEADER_ID] = request.request_id
|
||||||
|
setattr(response, "ak_context", {})
|
||||||
if auth_via := LOCAL.authentik.get(KEY_AUTH_VIA, None):
|
if auth_via := LOCAL.authentik.get(KEY_AUTH_VIA, None):
|
||||||
response[INTERNAL_HEADER_PREFIX + KEY_AUTH_VIA] = auth_via
|
response.ak_context[KEY_AUTH_VIA] = auth_via
|
||||||
response[INTERNAL_HEADER_PREFIX + KEY_USER] = request.user.username
|
response.ak_context[KEY_USER] = request.user.username
|
||||||
for key in list(LOCAL.authentik.keys()):
|
for key in list(LOCAL.authentik.keys()):
|
||||||
del LOCAL.authentik[key]
|
del LOCAL.authentik[key]
|
||||||
return response
|
return response
|
||||||
|
|
|
@ -7,14 +7,11 @@ For more information on this file, see
|
||||||
https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
|
https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
|
||||||
"""
|
"""
|
||||||
import django
|
import django
|
||||||
from asgiref.compatibility import guarantee_single_callable
|
|
||||||
from channels.routing import ProtocolTypeRouter, URLRouter
|
from channels.routing import ProtocolTypeRouter, URLRouter
|
||||||
from defusedxml import defuse_stdlib
|
from defusedxml import defuse_stdlib
|
||||||
from django.core.asgi import get_asgi_application
|
from django.core.asgi import get_asgi_application
|
||||||
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
|
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
|
||||||
|
|
||||||
from authentik.root.asgi.logger import ASGILogger
|
|
||||||
|
|
||||||
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
|
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
|
||||||
|
|
||||||
defuse_stdlib()
|
defuse_stdlib()
|
||||||
|
@ -23,15 +20,11 @@ django.setup()
|
||||||
# pylint: disable=wrong-import-position
|
# pylint: disable=wrong-import-position
|
||||||
from authentik.root import websocket # noqa # isort:skip
|
from authentik.root import websocket # noqa # isort:skip
|
||||||
|
|
||||||
application = ASGILogger(
|
application = SentryAsgiMiddleware(
|
||||||
guarantee_single_callable(
|
|
||||||
SentryAsgiMiddleware(
|
|
||||||
ProtocolTypeRouter(
|
ProtocolTypeRouter(
|
||||||
{
|
{
|
||||||
"http": get_asgi_application(),
|
"http": get_asgi_application(),
|
||||||
"websocket": URLRouter(websocket.websocket_urlpatterns),
|
"websocket": URLRouter(websocket.websocket_urlpatterns),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
|
@ -1,107 +0,0 @@
|
||||||
"""ASGI Logger"""
|
|
||||||
from time import time
|
|
||||||
|
|
||||||
from structlog.stdlib import get_logger
|
|
||||||
|
|
||||||
from authentik.core.middleware import INTERNAL_HEADER_PREFIX, RESPONSE_HEADER_ID
|
|
||||||
from authentik.root.asgi.types import ASGIApp, Message, Receive, Scope, Send
|
|
||||||
|
|
||||||
ASGI_IP_HEADERS = (
|
|
||||||
b"x-forwarded-for",
|
|
||||||
b"x-real-ip",
|
|
||||||
)
|
|
||||||
|
|
||||||
LOGGER = get_logger("authentik.asgi")
|
|
||||||
|
|
||||||
|
|
||||||
class ASGILogger:
|
|
||||||
"""ASGI Logger, instantiated for each request"""
|
|
||||||
|
|
||||||
app: ASGIApp
|
|
||||||
|
|
||||||
def __init__(self, app: ASGIApp):
|
|
||||||
self.app = app
|
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
||||||
content_length = 0
|
|
||||||
status_code = 0
|
|
||||||
request_id = ""
|
|
||||||
# Copy all headers starting with X-authentik-internal
|
|
||||||
copied_headers = {}
|
|
||||||
location = ""
|
|
||||||
start = time()
|
|
||||||
|
|
||||||
async def send_hooked(message: Message) -> None:
|
|
||||||
"""Hooked send method, which records status code and content-length, and for the final
|
|
||||||
requests logs it"""
|
|
||||||
|
|
||||||
headers = dict(message.get("headers", []))
|
|
||||||
if "status" in message:
|
|
||||||
nonlocal status_code
|
|
||||||
status_code = message["status"]
|
|
||||||
|
|
||||||
if b"Content-Length" in headers:
|
|
||||||
nonlocal content_length
|
|
||||||
content_length += int(headers.get(b"Content-Length", b"0"))
|
|
||||||
|
|
||||||
if message["type"] == "http.response.start":
|
|
||||||
response_headers = dict(message["headers"])
|
|
||||||
nonlocal request_id
|
|
||||||
nonlocal copied_headers
|
|
||||||
nonlocal location
|
|
||||||
request_id = response_headers.get(RESPONSE_HEADER_ID.encode(), b"").decode()
|
|
||||||
location = response_headers.get(b"Location", b"").decode()
|
|
||||||
# Copy all internal headers to log, and remove them from the final response
|
|
||||||
for header in list(response_headers.keys()):
|
|
||||||
if not header.decode().startswith(INTERNAL_HEADER_PREFIX):
|
|
||||||
continue
|
|
||||||
copied_headers[
|
|
||||||
header.decode().replace(INTERNAL_HEADER_PREFIX, "")
|
|
||||||
] = response_headers[header].decode()
|
|
||||||
del response_headers[header]
|
|
||||||
message["headers"] = list(response_headers.items())
|
|
||||||
|
|
||||||
if message["type"] == "http.response.body" and not message.get("more_body", True):
|
|
||||||
nonlocal start
|
|
||||||
runtime = int((time() - start) * 1000)
|
|
||||||
kwargs = {"request_id": request_id}
|
|
||||||
if location != "":
|
|
||||||
kwargs["location"] = location
|
|
||||||
kwargs.update(copied_headers)
|
|
||||||
self.log(scope, runtime, content_length, status_code, **kwargs)
|
|
||||||
await send(message)
|
|
||||||
|
|
||||||
if scope["type"] == "lifespan":
|
|
||||||
# https://code.djangoproject.com/ticket/31508
|
|
||||||
# https://github.com/encode/uvicorn/issues/266
|
|
||||||
return
|
|
||||||
return await self.app(scope, receive, send_hooked)
|
|
||||||
|
|
||||||
def _get_ip(self, headers: dict[bytes, bytes], scope: Scope) -> str:
|
|
||||||
client_ip = None
|
|
||||||
for header in ASGI_IP_HEADERS:
|
|
||||||
if header in headers:
|
|
||||||
client_ip = headers[header].decode()
|
|
||||||
if not client_ip:
|
|
||||||
client_ip, _ = scope.get("client", ("", 0))
|
|
||||||
# Check if header has multiple values, and use the first one
|
|
||||||
return client_ip.split(", ")[0]
|
|
||||||
|
|
||||||
def log(self, scope: Scope, content_length: int, runtime: float, status_code: int, **kwargs):
|
|
||||||
"""Outpot access logs in a structured format"""
|
|
||||||
headers = dict(scope.get("headers", []))
|
|
||||||
host = self._get_ip(headers, scope)
|
|
||||||
query_string = ""
|
|
||||||
if scope.get("query_string", b"") != b"":
|
|
||||||
query_string = f"?{scope.get('query_string').decode()}"
|
|
||||||
LOGGER.info(
|
|
||||||
f"{scope.get('path', '')}{query_string}",
|
|
||||||
host=host,
|
|
||||||
method=scope.get("method", ""),
|
|
||||||
scheme=scope.get("scheme", ""),
|
|
||||||
status=status_code,
|
|
||||||
size=content_length / 1000 if content_length > 0 else 0,
|
|
||||||
runtime=runtime,
|
|
||||||
user_agent=headers.get(b"user-agent", b"").decode(),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
|
@ -1,11 +0,0 @@
|
||||||
"""ASGI Types"""
|
|
||||||
import typing
|
|
||||||
|
|
||||||
# See https://github.com/encode/starlette/blob/master/starlette/types.py
|
|
||||||
Scope = typing.MutableMapping[str, typing.Any]
|
|
||||||
Message = typing.MutableMapping[str, typing.Any]
|
|
||||||
|
|
||||||
Receive = typing.Callable[[], typing.Awaitable[Message]]
|
|
||||||
Send = typing.Callable[[Message], typing.Awaitable[None]]
|
|
||||||
|
|
||||||
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""Dynamically set SameSite depending if the upstream connection is TLS or not"""
|
"""Dynamically set SameSite depending if the upstream connection is TLS or not"""
|
||||||
import time
|
from time import time
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.sessions.backends.base import UpdateError
|
from django.contrib.sessions.backends.base import UpdateError
|
||||||
|
@ -9,6 +10,13 @@ from django.http.request import HttpRequest
|
||||||
from django.http.response import HttpResponse
|
from django.http.response import HttpResponse
|
||||||
from django.utils.cache import patch_vary_headers
|
from django.utils.cache import patch_vary_headers
|
||||||
from django.utils.http import http_date
|
from django.utils.http import http_date
|
||||||
|
from structlog.stdlib import get_logger
|
||||||
|
from typing_extensions import runtime
|
||||||
|
|
||||||
|
from authentik.core.middleware import KEY_AUTH_VIA, KEY_USER
|
||||||
|
from authentik.lib.utils.http import get_client_ip
|
||||||
|
|
||||||
|
LOGGER = get_logger("authentik.asgi")
|
||||||
|
|
||||||
|
|
||||||
class SessionMiddleware(UpstreamSessionMiddleware):
|
class SessionMiddleware(UpstreamSessionMiddleware):
|
||||||
|
@ -88,3 +96,36 @@ class SessionMiddleware(UpstreamSessionMiddleware):
|
||||||
samesite=same_site,
|
samesite=same_site,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingMiddleware:
|
||||||
|
"""Logger middleware"""
|
||||||
|
|
||||||
|
get_response: Callable[[HttpRequest], HttpResponse]
|
||||||
|
|
||||||
|
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
|
||||||
|
self.get_response = get_response
|
||||||
|
|
||||||
|
def __call__(self, request: HttpRequest) -> HttpResponse:
|
||||||
|
start = time()
|
||||||
|
response = self.get_response(request)
|
||||||
|
status_code = response.status_code
|
||||||
|
kwargs = {
|
||||||
|
"request_id": request.request_id,
|
||||||
|
}
|
||||||
|
kwargs.update(response.ak_context)
|
||||||
|
self.log(request, status_code, int((time() - start) * 1000), **kwargs)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs):
|
||||||
|
"""Log request"""
|
||||||
|
LOGGER.info(
|
||||||
|
request.get_full_path(),
|
||||||
|
remote=get_client_ip(request),
|
||||||
|
method=request.method,
|
||||||
|
scheme=request.scheme,
|
||||||
|
status=status_code,
|
||||||
|
runtime=runtime,
|
||||||
|
user_agent=request.META.get("HTTP_USER_AGENT", ""),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
|
@ -241,6 +241,7 @@ SESSION_EXPIRE_AT_BROWSER_CLOSE = True
|
||||||
MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage"
|
MESSAGE_STORAGE = "authentik.root.messages.storage.ChannelsStorage"
|
||||||
|
|
||||||
MIDDLEWARE = [
|
MIDDLEWARE = [
|
||||||
|
"authentik.root.middleware.LoggingMiddleware",
|
||||||
"django_prometheus.middleware.PrometheusBeforeMiddleware",
|
"django_prometheus.middleware.PrometheusBeforeMiddleware",
|
||||||
"authentik.root.middleware.SessionMiddleware",
|
"authentik.root.middleware.SessionMiddleware",
|
||||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||||
|
@ -275,7 +276,7 @@ TEMPLATES = [
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
ASGI_APPLICATION = "authentik.root.asgi.app.application"
|
ASGI_APPLICATION = "authentik.root.asgi.application"
|
||||||
|
|
||||||
CHANNEL_LAYERS = {
|
CHANNEL_LAYERS = {
|
||||||
"default": {
|
"default": {
|
||||||
|
|
|
@ -38,7 +38,7 @@ func NewGoUnicorn() *GoUnicorn {
|
||||||
|
|
||||||
func (g *GoUnicorn) initCmd() {
|
func (g *GoUnicorn) initCmd() {
|
||||||
command := "gunicorn"
|
command := "gunicorn"
|
||||||
args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi.app:application"}
|
args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"}
|
||||||
if config.G.Debug {
|
if config.G.Debug {
|
||||||
command = "./manage.py"
|
command = "./manage.py"
|
||||||
args = []string{"runserver"}
|
args = []string{"runserver"}
|
||||||
|
|
Reference in a new issue