diff --git a/internal/outpost/ldap/instance.go b/internal/outpost/ldap/instance.go index 4ab8ab9a5..1d98ef1df 100644 --- a/internal/outpost/ldap/instance.go +++ b/internal/outpost/ldap/instance.go @@ -34,6 +34,7 @@ type ProviderInstance struct { tlsServerName *string cert *tls.Certificate + certUUID string outpostName string outpostPk int32 searchAllowedGroups []*strfmt.UUID diff --git a/internal/outpost/ldap/ldap_tls.go b/internal/outpost/ldap/ldap_tls.go index 764ec086a..e46a28090 100644 --- a/internal/outpost/ldap/ldap_tls.go +++ b/internal/outpost/ldap/ldap_tls.go @@ -15,6 +15,7 @@ func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certifica return ls.providers[0].cert, nil } } + allIdenticalCerts := true for _, provider := range ls.providers { if provider.tlsServerName == &info.ServerName { if provider.cert == nil { @@ -23,6 +24,13 @@ func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certifica } return provider.cert, nil } + if provider.certUUID != ls.providers[0].certUUID { + allIdenticalCerts = false + } + } + if allIdenticalCerts { + ls.log.WithField("server-name", info.ServerName).Debug("all providers have the same keypair, using keypair") + return ls.providers[0].cert, nil } ls.log.WithField("server-name", info.ServerName).Debug("Fallback to default cert") return ls.defaultCert, nil diff --git a/internal/outpost/ldap/refresh.go b/internal/outpost/ldap/refresh.go index 57fc60e3f..34067a876 100644 --- a/internal/outpost/ldap/refresh.go +++ b/internal/outpost/ldap/refresh.go @@ -70,13 +70,13 @@ func (ls *LDAPServer) Refresh() error { outpostName: ls.ac.Outpost.Name, outpostPk: provider.Pk, } - if provider.Certificate.Get() != nil { - kp := provider.Certificate.Get() + if kp := provider.Certificate.Get(); kp != nil { err := ls.cs.AddKeypair(*kp) if err != nil { ls.log.WithError(err).Warning("Failed to initially fetch certificate") } providers[idx].cert = ls.cs.Get(*kp) + providers[idx].certUUID = *kp } if *provider.SearchMode.Ptr() == api.LDAPAPIACCESSMODE_CACHED { providers[idx].searcher = memorysearch.NewMemorySearcher(providers[idx])