From 7d107991a26f0233ad47afe61939ba30e4266922 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Thu, 4 Feb 2021 20:22:28 +0100 Subject: [PATCH] sources/ldap: fix count for membership, fix wrong attribute being searched --- authentik/sources/ldap/sync/base.py | 2 +- authentik/sources/ldap/sync/membership.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/authentik/sources/ldap/sync/base.py b/authentik/sources/ldap/sync/base.py index 501d5f956..d9405b0ad 100644 --- a/authentik/sources/ldap/sync/base.py +++ b/authentik/sources/ldap/sync/base.py @@ -30,6 +30,6 @@ class BaseLDAPSynchronizer: return f"{self._source.additional_group_dn},{self._source.base_dn}" return self._source.base_dn - def sync(self): + def sync(self) -> int: """Sync function, implemented in subclass""" raise NotImplementedError() diff --git a/authentik/sources/ldap/sync/membership.py b/authentik/sources/ldap/sync/membership.py index 7fd7e9a1a..72f9c8e12 100644 --- a/authentik/sources/ldap/sync/membership.py +++ b/authentik/sources/ldap/sync/membership.py @@ -19,7 +19,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): super().__init__(source) self.group_cache: dict[str, Group] = {} - def sync(self): + def sync(self) -> int: """Iterate over all Users and assign Groups using memberOf Field""" groups = self._source.connection.extend.standard.paged_search( search_base=self.base_dn_groups, @@ -28,8 +28,10 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): attributes=[ self._source.group_membership_field, self._source.object_uniqueness_field, + LDAP_DISTINGUISHED_NAME, ], ) + membership_count = 0 for group in groups: members = group.get("attributes", {}).get( self._source.group_membership_field, [] @@ -41,13 +43,16 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): ak_group = self.get_group(group) if not ak_group: continue + membership_count += 1 + membership_count += users.count() ak_group.users.set(users) ak_group.save() self._logger.debug("Successfully updated group membership") + return membership_count def get_group(self, group_dict: dict[str, Any]) -> Optional[Group]: """Check if we fetched the group already, and if not cache it for later""" - group_uniq = group_dict.get("attributes", {}).get(LDAP_UNIQUENESS, "") + group_uniq = group_dict.get("attributes", {}).get(self._source.object_uniqueness_field, "") group_dn = group_dict.get("attributes", {}).get(LDAP_DISTINGUISHED_NAME, "") if group_uniq not in self.group_cache: groups = Group.objects.filter(