diff --git a/authentik/core/models.py b/authentik/core/models.py index e27a104b7..766df053b 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -81,6 +81,27 @@ class Group(models.Model): ) attributes = models.JSONField(default=dict, blank=True) + def is_member(self, user: "User") -> bool: + """Recursively check if `user` is member of us, or any parent.""" + query = """ + WITH RECURSIVE parents AS ( + SELECT authentik_core_group.*, 0 AS relative_depth + FROM authentik_core_group + WHERE authentik_core_group.group_uuid = %s + + UNION ALL + + SELECT authentik_core_group.*, parents.relative_depth - 1 + FROM authentik_core_group,parents + WHERE authentik_core_group.parent_id = parents.group_uuid + ) + SELECT group_uuid + FROM parents + GROUP BY group_uuid; + """ + groups = Group.objects.raw(query, [self.group_uuid]) + return user.ak_groups.filter(pk__in=[group.pk for group in groups]).exists() + def __str__(self): return f"Group {self.name}" diff --git a/authentik/core/tests/test_groups.py b/authentik/core/tests/test_groups.py new file mode 100644 index 000000000..9d66341f9 --- /dev/null +++ b/authentik/core/tests/test_groups.py @@ -0,0 +1,40 @@ +"""group tests""" +from django.test.testcases import TestCase + +from authentik.core.models import Group, User + + +class TestGroups(TestCase): + """Test group membership""" + + def test_group_membership_simple(self): + """Test simple membership""" + user = User.objects.create(username="user") + user2 = User.objects.create(username="user2") + group = Group.objects.create(name="group") + group.users.add(user) + self.assertTrue(group.is_member(user)) + self.assertFalse(group.is_member(user2)) + + def test_group_membership_parent(self): + """Test parent membership""" + user = User.objects.create(username="user") + user2 = User.objects.create(username="user2") + first = Group.objects.create(name="first") + second = Group.objects.create(name="second", parent=first) + second.users.add(user) + self.assertTrue(first.is_member(user)) + self.assertFalse(first.is_member(user2)) + + def test_group_membership_parent_extra(self): + """Test parent membership""" + user = User.objects.create(username="user") + user2 = User.objects.create(username="user2") + first = Group.objects.create(name="first") + second = Group.objects.create(name="second", parent=first) + third = Group.objects.create(name="third", parent=second) + second.users.add(user) + self.assertTrue(first.is_member(user)) + self.assertFalse(first.is_member(user2)) + self.assertFalse(third.is_member(user)) + self.assertFalse(third.is_member(user2)) diff --git a/authentik/policies/models.py b/authentik/policies/models.py index 3a4f30ffa..fa4a77e61 100644 --- a/authentik/policies/models.py +++ b/authentik/policies/models.py @@ -65,14 +65,14 @@ class PolicyBinding(SerializerModel): # This is quite an ugly hack to prevent pylint from trying # to resolve authentik_core.models.Group # as python import path - "authentik_core." + "Group", + "authentik_core.Group", on_delete=models.CASCADE, default=None, null=True, blank=True, ) user = models.ForeignKey( - "authentik_core." + "User", + "authentik_core.User", on_delete=models.CASCADE, default=None, null=True, @@ -96,7 +96,7 @@ class PolicyBinding(SerializerModel): self.policy: Policy return self.policy.passes(request) if self.group: - return PolicyResult(self.group.users.filter(pk=request.user.pk).exists()) + return PolicyResult(self.group.is_member(request.user)) if self.user: return PolicyResult(request.user == self.user) return PolicyResult(False)