diff --git a/authentik/core/tests/test_source_flow_manager.py b/authentik/core/tests/test_source_flow_manager.py index 94ab52ba0..9c77b1b2c 100644 --- a/authentik/core/tests/test_source_flow_manager.py +++ b/authentik/core/tests/test_source_flow_manager.py @@ -6,8 +6,12 @@ from guardian.utils import get_anonymous_user from authentik.core.models import SourceUserMatchingModes, User from authentik.core.sources.flow_manager import Action +from authentik.flows.models import Flow, FlowDesignation from authentik.lib.generators import generate_id from authentik.lib.tests.utils import get_request +from authentik.policies.denied import AccessDeniedResponse +from authentik.policies.expression.models import ExpressionPolicy +from authentik.policies.models import PolicyBinding from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.views.callback import OAuthSourceFlowManager @@ -17,7 +21,7 @@ class TestSourceFlowManager(TestCase): def setUp(self) -> None: super().setUp() - self.source = OAuthSource.objects.create(name="test") + self.source: OAuthSource = OAuthSource.objects.create(name="test") self.factory = RequestFactory() self.identifier = generate_id() @@ -143,3 +147,34 @@ class TestSourceFlowManager(TestCase): action, _ = flow_manager.get_action() self.assertEqual(action, Action.ENROLL) flow_manager.get_flow() + + def test_error_non_applicable_flow(self): + """Test error handling when a source selected flow is non-applicable due to a policy""" + self.source.user_matching_mode = SourceUserMatchingModes.USERNAME_LINK + + flow = Flow.objects.create( + name="test", slug="test", title="test", designation=FlowDesignation.ENROLLMENT + ) + policy = ExpressionPolicy.objects.create( + name="false", expression="""ak_message("foo");return False""" + ) + PolicyBinding.objects.create( + policy=policy, + target=flow, + order=0, + ) + self.source.enrollment_flow = flow + self.source.save() + + flow_manager = OAuthSourceFlowManager( + self.source, + get_request("/", user=AnonymousUser()), + self.identifier, + {"username": "foo"}, + ) + action, _ = flow_manager.get_action() + self.assertEqual(action, Action.ENROLL) + response = flow_manager.get_flow() + self.assertIsInstance(response, AccessDeniedResponse) + # pylint: disable=no-member + self.assertEqual(response.error_message, "foo")