diff --git a/authentik/root/middleware.py b/authentik/root/middleware.py index a6b0d9c7e..590884c92 100644 --- a/authentik/root/middleware.py +++ b/authentik/root/middleware.py @@ -10,6 +10,8 @@ from django.contrib.sessions.exceptions import SessionInterrupted from django.contrib.sessions.middleware import SessionMiddleware as UpstreamSessionMiddleware from django.http.request import HttpRequest from django.http.response import HttpResponse +from django.middleware.csrf import CSRF_SESSION_KEY +from django.middleware.csrf import CsrfViewMiddleware as UpstreamCsrfViewMiddleware from django.utils.cache import patch_vary_headers from django.utils.http import http_date from jwt import PyJWTError, decode, encode @@ -131,6 +133,29 @@ class SessionMiddleware(UpstreamSessionMiddleware): return response +class CsrfViewMiddleware(UpstreamCsrfViewMiddleware): + """Dynamically set secure depending if the upstream connection is TLS or not""" + + def _set_csrf_cookie(self, request: HttpRequest, response: HttpResponse): + if settings.CSRF_USE_SESSIONS: + if request.session.get(CSRF_SESSION_KEY) != request.META["CSRF_COOKIE"]: + request.session[CSRF_SESSION_KEY] = request.META["CSRF_COOKIE"] + else: + secure = SessionMiddleware.is_secure(request) + response.set_cookie( + settings.CSRF_COOKIE_NAME, + request.META["CSRF_COOKIE"], + max_age=settings.CSRF_COOKIE_AGE, + domain=settings.CSRF_COOKIE_DOMAIN, + path=settings.CSRF_COOKIE_PATH, + secure=secure, + httponly=settings.CSRF_COOKIE_HTTPONLY, + samesite=settings.CSRF_COOKIE_SAMESITE, + ) + # Set the Vary header since content varies with the CSRF cookie. + patch_vary_headers(response, ("Cookie",)) + + class ChannelsLoggingMiddleware: """Logging middleware for channels""" diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 838cfddf2..01553a1c8 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -226,7 +226,7 @@ MIDDLEWARE = [ "authentik.events.middleware.AuditMiddleware", "django.middleware.security.SecurityMiddleware", "django.middleware.common.CommonMiddleware", - "django.middleware.csrf.CsrfViewMiddleware", + "authentik.root.middleware.CsrfViewMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware", "authentik.core.middleware.ImpersonateMiddleware",