sources/ldap: make schema optional (#5213)

* sources/ldap: make schema optional

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* create one connection and re-use it

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use magicmock

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-04-10 21:55:56 +02:00 committed by GitHub
parent c1615d044b
commit 1ca8feb5fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 63 additions and 47 deletions

View File

@ -122,7 +122,7 @@ def blueprints_find():
) )
blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None
blueprints.append(blueprint) blueprints.append(blueprint)
LOGGER.info( LOGGER.debug(
"parsed & loaded blueprint", "parsed & loaded blueprint",
hash=file_hash, hash=file_hash,
path=str(path), path=str(path),

View File

@ -16,7 +16,7 @@ class PytestTestRunner: # pragma: no cover
self.failfast = failfast self.failfast = failfast
self.keepdb = keepdb self.keepdb = keepdb
self.args = ["-vv"] self.args = ["-vv", "--full-trace"]
if self.failfast: if self.failfast:
self.args.append("--exitfirst") self.args.append("--exitfirst")
if self.keepdb: if self.keepdb:

View File

@ -2,13 +2,12 @@
from typing import Optional from typing import Optional
from django.http import HttpRequest from django.http import HttpRequest
from ldap3 import Connection
from ldap3.core.exceptions import LDAPException, LDAPInvalidCredentialsResult from ldap3.core.exceptions import LDAPException, LDAPInvalidCredentialsResult
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.auth import InbuiltBackend from authentik.core.auth import InbuiltBackend
from authentik.core.models import User from authentik.core.models import User
from authentik.sources.ldap.models import LDAP_TIMEOUT, LDAPSource from authentik.sources.ldap.models import LDAPSource
LOGGER = get_logger() LOGGER = get_logger()
LDAP_DISTINGUISHED_NAME = "distinguishedName" LDAP_DISTINGUISHED_NAME = "distinguishedName"
@ -58,12 +57,11 @@ class LDAPBackend(InbuiltBackend):
# Try to bind as new user # Try to bind as new user
LOGGER.debug("Attempting Binding as user", user=user) LOGGER.debug("Attempting Binding as user", user=user)
try: try:
temp_connection = Connection( temp_connection = source.connection(
source.server, connection_kwargs={
user=user.attributes.get(LDAP_DISTINGUISHED_NAME), "user": user.attributes.get(LDAP_DISTINGUISHED_NAME),
password=password, "password": password,
raise_exceptions=True, }
receive_timeout=LDAP_TIMEOUT,
) )
temp_connection.bind() temp_connection.bind()
return user return user

View File

@ -1,9 +1,11 @@
"""authentik LDAP Models""" """authentik LDAP Models"""
from ssl import CERT_REQUIRED from ssl import CERT_REQUIRED
from typing import Optional
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from ldap3 import ALL, RANDOM, Connection, Server, ServerPool, Tls from ldap3 import ALL, NONE, RANDOM, Connection, Server, ServerPool, Tls
from ldap3.core.exceptions import LDAPSchemaError
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from authentik.core.models import Group, PropertyMapping, Source from authentik.core.models import Group, PropertyMapping, Source
@ -103,8 +105,7 @@ class LDAPSource(Source):
return LDAPSourceSerializer return LDAPSourceSerializer
@property def server(self, **kwargs) -> Server:
def server(self) -> Server:
"""Get LDAP Server/ServerPool""" """Get LDAP Server/ServerPool"""
servers = [] servers = []
tls_kwargs = {} tls_kwargs = {}
@ -113,32 +114,45 @@ class LDAPSource(Source):
tls_kwargs["validate"] = CERT_REQUIRED tls_kwargs["validate"] = CERT_REQUIRED
if ciphers := CONFIG.y("ldap.tls.ciphers", None): if ciphers := CONFIG.y("ldap.tls.ciphers", None):
tls_kwargs["ciphers"] = ciphers.strip() tls_kwargs["ciphers"] = ciphers.strip()
kwargs = { server_kwargs = {
"get_info": ALL, "get_info": ALL,
"connect_timeout": LDAP_TIMEOUT, "connect_timeout": LDAP_TIMEOUT,
"tls": Tls(**tls_kwargs), "tls": Tls(**tls_kwargs),
} }
server_kwargs.update(kwargs)
if "," in self.server_uri: if "," in self.server_uri:
for server in self.server_uri.split(","): for server in self.server_uri.split(","):
servers.append(Server(server, **kwargs)) servers.append(Server(server, **server_kwargs))
else: else:
servers = [Server(self.server_uri, **kwargs)] servers = [Server(self.server_uri, **server_kwargs)]
return ServerPool(servers, RANDOM, active=True, exhaust=True) return ServerPool(servers, RANDOM, active=True, exhaust=True)
@property def connection(
def connection(self) -> Connection: self, server_kwargs: Optional[dict] = None, connection_kwargs: Optional[dict] = None
) -> Connection:
"""Get a fully connected and bound LDAP Connection""" """Get a fully connected and bound LDAP Connection"""
server_kwargs = server_kwargs or {}
connection_kwargs = connection_kwargs or {}
connection_kwargs.setdefault("user", self.bind_cn)
connection_kwargs.setdefault("password", self.bind_password)
connection = Connection( connection = Connection(
self.server, self.server(**server_kwargs),
raise_exceptions=True, raise_exceptions=True,
user=self.bind_cn,
password=self.bind_password,
receive_timeout=LDAP_TIMEOUT, receive_timeout=LDAP_TIMEOUT,
**connection_kwargs,
) )
if self.start_tls: if self.start_tls:
connection.start_tls(read_server_info=False) connection.start_tls(read_server_info=False)
try:
connection.bind() connection.bind()
except LDAPSchemaError as exc:
# Schema error, so try connecting without schema info
# See https://github.com/goauthentik/authentik/issues/4590
if server_kwargs.get("get_info", ALL) == NONE:
raise exc
server_kwargs["get_info"] = NONE
return self.connection(server_kwargs, connection_kwargs)
return connection return connection
class Meta: class Meta:

View File

@ -47,10 +47,11 @@ class LDAPPasswordChanger:
def __init__(self, source: LDAPSource) -> None: def __init__(self, source: LDAPSource) -> None:
self._source = source self._source = source
self._connection = source.connection()
def get_domain_root_dn(self) -> str: def get_domain_root_dn(self) -> str:
"""Attempt to get root DN via MS specific fields or generic LDAP fields""" """Attempt to get root DN via MS specific fields or generic LDAP fields"""
info = self._source.connection.server.info info = self._connection.server.info
if "rootDomainNamingContext" in info.other: if "rootDomainNamingContext" in info.other:
return info.other["rootDomainNamingContext"][0] return info.other["rootDomainNamingContext"][0]
naming_contexts = info.naming_contexts naming_contexts = info.naming_contexts
@ -61,7 +62,7 @@ class LDAPPasswordChanger:
"""Check if DOMAIN_PASSWORD_COMPLEX is enabled""" """Check if DOMAIN_PASSWORD_COMPLEX is enabled"""
root_dn = self.get_domain_root_dn() root_dn = self.get_domain_root_dn()
try: try:
root_attrs = self._source.connection.extend.standard.paged_search( root_attrs = self._connection.extend.standard.paged_search(
search_base=root_dn, search_base=root_dn,
search_filter="(objectClass=*)", search_filter="(objectClass=*)",
search_scope=BASE, search_scope=BASE,
@ -90,14 +91,14 @@ class LDAPPasswordChanger:
LOGGER.info(f"User has no {LDAP_DISTINGUISHED_NAME} set.") LOGGER.info(f"User has no {LDAP_DISTINGUISHED_NAME} set.")
return return
try: try:
self._source.connection.extend.microsoft.modify_password(user_dn, password) self._connection.extend.microsoft.modify_password(user_dn, password)
except LDAPAttributeError: except LDAPAttributeError:
self._source.connection.extend.standard.modify_password(user_dn, new_password=password) self._connection.extend.standard.modify_password(user_dn, new_password=password)
def _ad_check_password_existing(self, password: str, user_dn: str) -> bool: def _ad_check_password_existing(self, password: str, user_dn: str) -> bool:
"""Check if a password contains sAMAccount or displayName""" """Check if a password contains sAMAccount or displayName"""
users = list( users = list(
self._source.connection.extend.standard.paged_search( self._connection.extend.standard.paged_search(
search_base=user_dn, search_base=user_dn,
search_filter=self._source.user_object_filter, search_filter=self._source.user_object_filter,
search_scope=BASE, search_scope=BASE,

View File

@ -3,6 +3,7 @@ from typing import Any, Generator
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from ldap3 import Connection
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
from authentik.core.exceptions import PropertyMappingExpressionException from authentik.core.exceptions import PropertyMappingExpressionException
@ -19,10 +20,12 @@ class BaseLDAPSynchronizer:
_source: LDAPSource _source: LDAPSource
_logger: BoundLogger _logger: BoundLogger
_connection: Connection
_messages: list[str] _messages: list[str]
def __init__(self, source: LDAPSource): def __init__(self, source: LDAPSource):
self._source = source self._source = source
self._connection = source.connection()
self._messages = [] self._messages = []
self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__) self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__)

View File

@ -14,7 +14,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
"""Sync LDAP Users and groups into authentik""" """Sync LDAP Users and groups into authentik"""
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
return self._source.connection.extend.standard.paged_search( return self._connection.extend.standard.paged_search(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,
search_scope=SUBTREE, search_scope=SUBTREE,

View File

@ -20,7 +20,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
self.group_cache: dict[str, Group] = {} self.group_cache: dict[str, Group] = {}
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
return self._source.connection.extend.standard.paged_search( return self._connection.extend.standard.paged_search(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,
search_scope=SUBTREE, search_scope=SUBTREE,

View File

@ -16,7 +16,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
"""Sync LDAP Users into authentik""" """Sync LDAP Users into authentik"""
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
return self._source.connection.extend.standard.paged_search( return self._connection.extend.standard.paged_search(
search_base=self.base_dn_users, search_base=self.base_dn_users,
search_filter=self._source.user_object_filter, search_filter=self._source.user_object_filter,
search_scope=SUBTREE, search_scope=SUBTREE,

View File

@ -1,5 +1,5 @@
"""LDAP Source tests""" """LDAP Source tests"""
from unittest.mock import Mock, PropertyMock, patch from unittest.mock import MagicMock, Mock, patch
from django.db.models import Q from django.db.models import Q
from django.test import TestCase from django.test import TestCase
@ -37,7 +37,7 @@ class LDAPSyncTests(TestCase):
| Q(managed__startswith="goauthentik.io/sources/ldap/ms-") | Q(managed__startswith="goauthentik.io/sources/ldap/ms-")
) )
) )
connection = PropertyMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync()
@ -64,7 +64,7 @@ class LDAPSyncTests(TestCase):
) )
) )
self.source.save() self.source.save()
connection = PropertyMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync()

View File

@ -1,5 +1,5 @@
"""LDAP Source tests""" """LDAP Source tests"""
from unittest.mock import PropertyMock, patch from unittest.mock import MagicMock, patch
from django.test import TestCase from django.test import TestCase
@ -10,7 +10,7 @@ from authentik.sources.ldap.password import LDAPPasswordChanger
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
LDAP_PASSWORD = generate_key() LDAP_PASSWORD = generate_key()
LDAP_CONNECTION_PATCH = PropertyMock(return_value=mock_ad_connection(LDAP_PASSWORD)) LDAP_CONNECTION_PATCH = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
class LDAPPasswordTests(TestCase): class LDAPPasswordTests(TestCase):

View File

@ -1,5 +1,5 @@
"""LDAP Source tests""" """LDAP Source tests"""
from unittest.mock import PropertyMock, patch from unittest.mock import MagicMock, patch
from django.db.models import Q from django.db.models import Q
from django.test import TestCase from django.test import TestCase
@ -48,7 +48,7 @@ class LDAPSyncTests(TestCase):
) )
self.source.property_mappings.set([mapping]) self.source.property_mappings.set([mapping])
self.source.save() self.source.save()
connection = PropertyMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync()
@ -69,7 +69,7 @@ class LDAPSyncTests(TestCase):
) )
) )
self.source.save() self.source.save()
connection = PropertyMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
# Create the user beforehand so we can set attributes and check they aren't removed # Create the user beforehand so we can set attributes and check they aren't removed
user = User.objects.create( user = User.objects.create(
@ -103,7 +103,7 @@ class LDAPSyncTests(TestCase):
) )
) )
self.source.save() self.source.save()
connection = PropertyMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync()
@ -121,11 +121,11 @@ class LDAPSyncTests(TestCase):
self.source.property_mappings_group.set( self.source.property_mappings_group.set(
LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/default-name") LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/default-name")
) )
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
_user = create_test_admin_user() _user = create_test_admin_user()
parent_group = Group.objects.get(name=_user.username) parent_group = Group.objects.get(name=_user.username)
self.source.sync_parent_group = parent_group self.source.sync_parent_group = parent_group
connection = PropertyMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save() self.source.save()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source)
group_sync.sync() group_sync.sync()
@ -148,7 +148,7 @@ class LDAPSyncTests(TestCase):
self.source.property_mappings_group.set( self.source.property_mappings_group.set(
LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/openldap-cn") LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/openldap-cn")
) )
connection = PropertyMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save() self.source.save()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source)
@ -173,7 +173,7 @@ class LDAPSyncTests(TestCase):
self.source.property_mappings_group.set( self.source.property_mappings_group.set(
LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/openldap-cn") LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/openldap-cn")
) )
connection = PropertyMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save() self.source.save()
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
@ -195,7 +195,7 @@ class LDAPSyncTests(TestCase):
) )
) )
self.source.save() self.source.save()
connection = PropertyMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
ldap_sync_all.delay().get() ldap_sync_all.delay().get()
@ -210,6 +210,6 @@ class LDAPSyncTests(TestCase):
) )
) )
self.source.save() self.source.save()
connection = PropertyMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
ldap_sync_all.delay().get() ldap_sync_all.delay().get()