diff --git a/authentik/providers/saml/views/slo.py b/authentik/providers/saml/views/slo.py index aecdadec7..9996c8781 100644 --- a/authentik/providers/saml/views/slo.py +++ b/authentik/providers/saml/views/slo.py @@ -3,7 +3,7 @@ from typing import Optional from django.http import HttpRequest from django.http.response import HttpResponse -from django.shortcuts import get_object_or_404, redirect +from django.shortcuts import get_object_or_404 from django.utils.decorators import method_decorator from django.views.decorators.clickjacking import xframe_options_sameorigin from django.views.decorators.csrf import csrf_exempt @@ -11,6 +11,11 @@ from structlog.stdlib import get_logger from authentik.core.models import Application from authentik.events.models import Event, EventAction +from authentik.flows.challenge import SessionEndChallenge +from authentik.flows.models import in_memory_stage +from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, FlowPlanner +from authentik.flows.views.executor import SESSION_KEY_PLAN +from authentik.lib.utils.urls import redirect_with_qs from authentik.lib.views import bad_request_message from authentik.policies.views import PolicyAccessView from authentik.providers.saml.exceptions import CannotHandleAssertion @@ -46,9 +51,20 @@ class SAMLSLOView(PolicyAccessView): method_response = self.check_saml_request() if method_response: return method_response - return redirect( - "authentik_core:if-session-end", - application_slug=self.kwargs["application_slug"], + planner = FlowPlanner(self.provider.invalidation_flow) + planner.allow_empty_flows = True + plan = planner.plan( + request, + { + PLAN_CONTEXT_APPLICATION: self.application, + }, + ) + plan.insert_stage(in_memory_stage(SessionEndChallenge)) + request.session[SESSION_KEY_PLAN] = plan + return redirect_with_qs( + "authentik_core:if-flow", + self.request.GET, + flow_slug=self.provider.invalidation_flow.slug, ) def post(self, request: HttpRequest, application_slug: str) -> HttpResponse: diff --git a/tests/e2e/test_provider_oauth2_grafana.py b/tests/e2e/test_provider_oauth2_grafana.py index fdb75e1b9..48153f3a5 100644 --- a/tests/e2e/test_provider_oauth2_grafana.py +++ b/tests/e2e/test_provider_oauth2_grafana.py @@ -177,6 +177,7 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): ) @apply_blueprint( "default/flow-default-provider-authorization-implicit-consent.yaml", + "default/flow-default-provider-invalidation.yaml", ) @apply_blueprint( "system/providers-oauth2.yaml", @@ -189,6 +190,7 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): authorization_flow = Flow.objects.get( slug="default-provider-authorization-implicit-consent" ) + invalidation_flow = Flow.objects.get(slug="default-provider-invalidation-flow") provider = OAuth2Provider.objects.create( name="grafana", client_type=ClientTypes.CONFIDENTIAL, @@ -197,6 +199,7 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): signing_key=create_test_cert(), redirect_uris="http://localhost:3000/login/generic_oauth", authorization_flow=authorization_flow, + invalidation_flow=invalidation_flow, ) provider.property_mappings.set( ScopeMapping.objects.filter( @@ -234,8 +237,8 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): self.driver.get("http://localhost:3000/logout") self.wait_for_url( self.url( - "authentik_core:if-session-end", - application_slug=self.app_slug, + "authentik_core:if-flow", + flow_slug=invalidation_flow.slug, ) ) self.driver.find_element(By.ID, "logout").click() diff --git a/tests/e2e/test_provider_saml.py b/tests/e2e/test_provider_saml.py index 9252ab0c0..858866e7a 100644 --- a/tests/e2e/test_provider_saml.py +++ b/tests/e2e/test_provider_saml.py @@ -414,6 +414,7 @@ class TestProviderSAML(SeleniumTestCase): ) @apply_blueprint( "default/flow-default-provider-authorization-implicit-consent.yaml", + "default/flow-default-provider-invalidation.yaml", ) @apply_blueprint( "system/providers-saml.yaml", @@ -425,6 +426,7 @@ class TestProviderSAML(SeleniumTestCase): authorization_flow = Flow.objects.get( slug="default-provider-authorization-implicit-consent" ) + invalidation_flow = Flow.objects.get(slug="default-provider-invalidation-flow") provider: SAMLProvider = SAMLProvider.objects.create( name="saml-test", acs_url="http://localhost:9009/saml/acs", @@ -432,6 +434,7 @@ class TestProviderSAML(SeleniumTestCase): issuer="authentik-e2e", sp_binding=SAMLBindings.POST, authorization_flow=authorization_flow, + invalidation_flow=invalidation_flow, signing_kp=create_test_cert(), ) provider.property_mappings.set(SAMLPropertyMapping.objects.all()) @@ -449,7 +452,7 @@ class TestProviderSAML(SeleniumTestCase): self.driver.get("http://localhost:9009/saml/logout") self.wait_for_url( self.url( - "authentik_core:if-session-end", - application_slug=app.slug, + "authentik_core:if-flow", + flow_slug=invalidation_flow.slug, ) )