diff --git a/authentik/core/models.py b/authentik/core/models.py index de4b59222..7f4c25ca7 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -160,8 +160,8 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): """Recursively get all groups this user is a member of. At least one query is done to get the direct groups of the user, with groups there are at most 3 queries done""" - direct_groups = tuple( - str(x) for x in self.ak_groups.all().values_list("pk", flat=True).iterator() + direct_groups = list( + x for x in self.ak_groups.all().values_list("pk", flat=True).iterator() ) if len(direct_groups) < 1: return Group.objects.none() @@ -169,7 +169,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): WITH RECURSIVE parents AS ( SELECT authentik_core_group.*, 0 AS relative_depth FROM authentik_core_group - WHERE authentik_core_group.group_uuid IN (%s) + WHERE authentik_core_group.group_uuid = ANY(%s) UNION ALL @@ -185,7 +185,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser): GROUP BY group_uuid, name ORDER BY name; """ - group_pks = [group.pk for group in Group.objects.raw(query, direct_groups).iterator()] + group_pks = [group.pk for group in Group.objects.raw(query, [direct_groups]).iterator()] return Group.objects.filter(pk__in=group_pks) def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]: diff --git a/authentik/core/tests/test_groups.py b/authentik/core/tests/test_groups.py index 5ac6e7e67..65a0cc3e6 100644 --- a/authentik/core/tests/test_groups.py +++ b/authentik/core/tests/test_groups.py @@ -13,7 +13,9 @@ class TestGroups(TestCase): user = User.objects.create(username=generate_id()) user2 = User.objects.create(username=generate_id()) group = Group.objects.create(name=generate_id()) + other_group = Group.objects.create(name=generate_id()) group.users.add(user) + other_group.users.add(user) self.assertTrue(group.is_member(user)) self.assertFalse(group.is_member(user2))