diff --git a/internal/outpost/ldap/bind/memory/memory.go b/internal/outpost/ldap/bind/memory/memory.go index b8ba768d0..11f7e5acd 100644 --- a/internal/outpost/ldap/bind/memory/memory.go +++ b/internal/outpost/ldap/bind/memory/memory.go @@ -48,8 +48,8 @@ func (sb *SessionBinder) Bind(username string, req *bind.Request) (ldap.LDAPResu result, err := sb.DirectBinder.Bind(username, req) // Only cache the result if there's been an error if err == nil { - flags, ok := sb.si.GetFlags(req.BindDN) - if !ok { + flags := sb.si.GetFlags(req.BindDN) + if flags == nil { sb.log.Error("user flags not set after bind") return result, err } diff --git a/internal/outpost/ldap/instance.go b/internal/outpost/ldap/instance.go index 606be3ce2..4ab8ab9a5 100644 --- a/internal/outpost/ldap/instance.go +++ b/internal/outpost/ldap/instance.go @@ -38,7 +38,7 @@ type ProviderInstance struct { outpostPk int32 searchAllowedGroups []*strfmt.UUID boundUsersMutex sync.RWMutex - boundUsers map[string]flags.UserFlags + boundUsers map[string]*flags.UserFlags uidStartNumber int32 gidStartNumber int32 @@ -68,16 +68,19 @@ func (pi *ProviderInstance) GetOutpostName() string { return pi.outpostName } -func (pi *ProviderInstance) GetFlags(dn string) (flags.UserFlags, bool) { +func (pi *ProviderInstance) GetFlags(dn string) *flags.UserFlags { pi.boundUsersMutex.RLock() + defer pi.boundUsersMutex.RUnlock() flags, ok := pi.boundUsers[dn] - pi.boundUsersMutex.RUnlock() - return flags, ok + if !ok { + return nil + } + return flags } func (pi *ProviderInstance) SetFlags(dn string, flag flags.UserFlags) { pi.boundUsersMutex.Lock() - pi.boundUsers[dn] = flag + pi.boundUsers[dn] = &flag pi.boundUsersMutex.Unlock() } diff --git a/internal/outpost/ldap/refresh.go b/internal/outpost/ldap/refresh.go index 97b695d32..6d469e754 100644 --- a/internal/outpost/ldap/refresh.go +++ b/internal/outpost/ldap/refresh.go @@ -44,7 +44,7 @@ func (ls *LDAPServer) Refresh() error { // Get existing instance so we can transfer boundUsers existing := ls.getCurrentProvider(provider.Pk) - users := make(map[string]flags.UserFlags) + users := make(map[string]*flags.UserFlags) if existing != nil { existing.boundUsersMutex.RLock() users = existing.boundUsers diff --git a/internal/outpost/ldap/search/direct/direct.go b/internal/outpost/ldap/search/direct/direct.go index 5895ee288..88063815d 100644 --- a/internal/outpost/ldap/search/direct/direct.go +++ b/internal/outpost/ldap/search/direct/direct.go @@ -70,8 +70,8 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ds.si.GetBaseDN()) } - flags, ok := ds.si.GetFlags(req.BindDN) - if !ok { + flags := ds.si.GetFlags(req.BindDN) + if flags == nil { req.Log().Debug("User info not cached") metrics.RequestsRejected.With(prometheus.Labels{ "outpost_name": ds.si.GetOutpostName(), diff --git a/internal/outpost/ldap/search/memory/memory.go b/internal/outpost/ldap/search/memory/memory.go index 496e40bb7..7403fe20c 100644 --- a/internal/outpost/ldap/search/memory/memory.go +++ b/internal/outpost/ldap/search/memory/memory.go @@ -73,8 +73,8 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult, return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ms.si.GetBaseDN()) } - flags, ok := ms.si.GetFlags(req.BindDN) - if !ok { + flags := ms.si.GetFlags(req.BindDN) + if flags == nil { req.Log().Debug("User info not cached") metrics.RequestsRejected.With(prometheus.Labels{ "outpost_name": ms.si.GetOutpostName(), diff --git a/internal/outpost/ldap/server/base.go b/internal/outpost/ldap/server/base.go index bc8b26e06..982aa4fc2 100644 --- a/internal/outpost/ldap/server/base.go +++ b/internal/outpost/ldap/server/base.go @@ -31,7 +31,7 @@ type LDAPServerInstance interface { UsersForGroup(api.Group) []string - GetFlags(dn string) (flags.UserFlags, bool) + GetFlags(dn string) *flags.UserFlags SetFlags(dn string, flags flags.UserFlags) GetBaseEntry() *ldap.Entry