web: directly read csrf token before injecting into request

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-01-16 16:10:55 +01:00
parent eaeab27004
commit 8008aba450

View file

@ -1,4 +1,4 @@
import { Config, Configuration, CoreApi, CurrentTenant, Middleware, ResponseContext, RootApi } from "@goauthentik/api"; import { Config, Configuration, CoreApi, CurrentTenant, FetchParams, Middleware, RequestContext, ResponseContext, RootApi } from "@goauthentik/api";
import { getCookie } from "../utils"; import { getCookie } from "../utils";
import { APIMiddleware } from "../elements/notifications/APIDrawer"; import { APIMiddleware } from "../elements/notifications/APIDrawer";
import { MessageMiddleware } from "../elements/messages/Middleware"; import { MessageMiddleware } from "../elements/messages/Middleware";
@ -50,27 +50,21 @@ export function tenant(): Promise<CurrentTenant> {
return globalTenantPromise; return globalTenantPromise;
} }
let csrfToken = getCookie("authentik_csrf"); export class CSRFMiddleware implements Middleware {
pre?(context: RequestContext): Promise<FetchParams | void> {
export class CSRFUpdaterMiddleware implements Middleware { // @ts-ignore
post?(context: ResponseContext): Promise<Response | void> { context.init.headers["X-CSRFToken"] = getCookie("authentik_csrf");
const newCsrf = getCookie("authentik_csrf"); return Promise.resolve(context);
if (newCsrf !== csrfToken) {
console.log("authentik/api: rotated CSRF token");
csrfToken = newCsrf;
}
return Promise.resolve(context.response);
} }
} }
export const DEFAULT_CONFIG = new Configuration({ export const DEFAULT_CONFIG = new Configuration({
basePath: process.env.AK_API_BASE_PATH + "/api/v3", basePath: process.env.AK_API_BASE_PATH + "/api/v3",
headers: { headers: {
"X-CSRFToken": csrfToken,
"sentry-trace": getMetaContent("sentry-trace") || "", "sentry-trace": getMetaContent("sentry-trace") || "",
}, },
middleware: [ middleware: [
new CSRFUpdaterMiddleware(), new CSRFMiddleware(),
new APIMiddleware(), new APIMiddleware(),
new MessageMiddleware(), new MessageMiddleware(),
new LoggingMiddleware(), new LoggingMiddleware(),