From 6a8be0dc710d16fe7c6ff313ed99f31aca3b12f0 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Fri, 23 Jul 2021 15:41:09 +0200 Subject: [PATCH] outposts/ldap: improve parsing of LDAP filters Signed-off-by: Jens Langhammer --- internal/outpost/ldap/instance_search.go | 11 +++-- .../outpost/ldap/instance_search_group.go | 36 +++++++++++---- internal/outpost/ldap/instance_search_user.go | 46 +++++++++++++------ 3 files changed, 65 insertions(+), 28 deletions(-) diff --git a/internal/outpost/ldap/instance_search.go b/internal/outpost/ldap/instance_search.go index 6484efcca..3b2781cb3 100644 --- a/internal/outpost/ldap/instance_search.go +++ b/internal/outpost/ldap/instance_search.go @@ -46,6 +46,11 @@ func (pi *ProviderInstance) Search(req SearchRequest) (ldap.ServerSearchResult, } accsp.Finish() + parsedFilter, err := ldap.CompileFilter(req.Filter) + if err != nil { + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter) + } + switch filterEntity { default: return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter) @@ -59,7 +64,7 @@ func (pi *ProviderInstance) Search(req SearchRequest) (ldap.ServerSearchResult, go func() { defer wg.Done() gapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_group") - groups, _, err := parseFilterForGroup(pi.s.ac.Client.CoreApi.CoreGroupsList(gapisp.Context()), req.Filter).Execute() + groups, _, err := parseFilterForGroup(pi.s.ac.Client.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter).Execute() gapisp.Finish() if err != nil { req.log.WithError(err).Warning("failed to get groups") @@ -75,7 +80,7 @@ func (pi *ProviderInstance) Search(req SearchRequest) (ldap.ServerSearchResult, go func() { defer wg.Done() uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user") - users, _, err := parseFilterForUser(pi.s.ac.Client.CoreApi.CoreUsersList(uapisp.Context()), req.Filter).Execute() + users, _, err := parseFilterForUser(pi.s.ac.Client.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter).Execute() uapisp.Finish() if err != nil { req.log.WithError(err).Warning("failed to get groups") @@ -90,7 +95,7 @@ func (pi *ProviderInstance) Search(req SearchRequest) (ldap.ServerSearchResult, entries = append(gEntries, uEntries...) case UserObjectClass, "": uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user") - users, _, err := parseFilterForUser(pi.s.ac.Client.CoreApi.CoreUsersList(uapisp.Context()), req.Filter).Execute() + users, _, err := parseFilterForUser(pi.s.ac.Client.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter).Execute() uapisp.Finish() if err != nil { diff --git a/internal/outpost/ldap/instance_search_group.go b/internal/outpost/ldap/instance_search_group.go index 6c629a702..db1be7f68 100644 --- a/internal/outpost/ldap/instance_search_group.go +++ b/internal/outpost/ldap/instance_search_group.go @@ -6,17 +6,13 @@ import ( "goauthentik.io/api" ) -func parseFilterForGroup(req api.ApiCoreGroupsListRequest, filter string) api.ApiCoreGroupsListRequest { - f, err := ldap.CompileFilter(filter) - if err != nil { - return req - } +func parseFilterForGroup(req api.ApiCoreGroupsListRequest, f *ber.Packet) api.ApiCoreGroupsListRequest { switch f.Tag { case ldap.FilterEqualityMatch: return parseFilterForGroupSingle(req, f) case ldap.FilterAnd: for _, child := range f.Children { - req = parseFilterForGroupSingle(req, child) + req = parseFilterForGroup(req, child) } return req } @@ -24,10 +20,30 @@ func parseFilterForGroup(req api.ApiCoreGroupsListRequest, filter string) api.Ap } func parseFilterForGroupSingle(req api.ApiCoreGroupsListRequest, f *ber.Packet) api.ApiCoreGroupsListRequest { - v := f.Children[1].Value.(string) - switch f.Children[0].Value.(string) { - case "cn": - return req.Name(v) + // We can only handle key = value pairs here + if len(f.Children) < 2 { + return req + } + k := f.Children[0].Value + // Ensure key is string + if _, ok := k.(string); !ok { + return req + } + v := f.Children[1].Value + // Null values are ignored + if v == nil { + return req + } + // Switch on type of the value, then check the key + switch vv := v.(type) { + case string: + switch k { + case "cn": + return req.Name(vv) + } + // TODO: Support int + default: + return req } return req } diff --git a/internal/outpost/ldap/instance_search_user.go b/internal/outpost/ldap/instance_search_user.go index e5bafa0d9..bf3851a97 100644 --- a/internal/outpost/ldap/instance_search_user.go +++ b/internal/outpost/ldap/instance_search_user.go @@ -6,17 +6,13 @@ import ( "goauthentik.io/api" ) -func parseFilterForUser(req api.ApiCoreUsersListRequest, filter string) api.ApiCoreUsersListRequest { - f, err := ldap.CompileFilter(filter) - if err != nil { - return req - } +func parseFilterForUser(req api.ApiCoreUsersListRequest, f *ber.Packet) api.ApiCoreUsersListRequest { switch f.Tag { case ldap.FilterEqualityMatch: return parseFilterForUserSingle(req, f) case ldap.FilterAnd: for _, child := range f.Children { - req = parseFilterForUserSingle(req, child) + req = parseFilterForUser(req, child) } return req } @@ -24,15 +20,35 @@ func parseFilterForUser(req api.ApiCoreUsersListRequest, filter string) api.ApiC } func parseFilterForUserSingle(req api.ApiCoreUsersListRequest, f *ber.Packet) api.ApiCoreUsersListRequest { - v := f.Children[1].Value.(string) - switch f.Children[0].Value.(string) { - case "cn": - return req.Username(v) - case "name": - case "displayName": - return req.Name(v) - case "mail": - return req.Email(v) + // We can only handle key = value pairs here + if len(f.Children) < 2 { + return req + } + k := f.Children[0].Value + // Ensure key is string + if _, ok := k.(string); !ok { + return req + } + v := f.Children[1].Value + // Null values are ignored + if v == nil { + return req + } + // Switch on type of the value, then check the key + switch vv := v.(type) { + case string: + switch k { + case "cn": + return req.Username(vv) + case "name": + case "displayName": + return req.Name(vv) + case "mail": + return req.Email(vv) + } + // TODO: Support int + default: + return req } return req }