outposts/proxy: cache basic and bearer credentials for one minute
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
4c45d35507
commit
0ddcefce80
|
@ -18,6 +18,7 @@ import (
|
||||||
sentryhttp "github.com/getsentry/sentry-go/http"
|
sentryhttp "github.com/getsentry/sentry-go/http"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
|
"github.com/jellydator/ttlcache/v3"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/api/v3"
|
"goauthentik.io/api/v3"
|
||||||
|
@ -48,7 +49,8 @@ type Application struct {
|
||||||
mux *mux.Router
|
mux *mux.Router
|
||||||
ak *ak.APIController
|
ak *ak.APIController
|
||||||
|
|
||||||
errorTemplates *template.Template
|
errorTemplates *template.Template
|
||||||
|
authHeaderCache *ttlcache.Cache[string, Claims]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewApplication(p api.ProxyOutpostConfig, c *http.Client, cs *ak.CryptoStore, ak *ak.APIController) (*Application, error) {
|
func NewApplication(p api.ProxyOutpostConfig, c *http.Client, cs *ak.CryptoStore, ak *ak.APIController) (*Application, error) {
|
||||||
|
@ -90,18 +92,20 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, cs *ak.CryptoStore
|
||||||
}
|
}
|
||||||
mux := mux.NewRouter()
|
mux := mux.NewRouter()
|
||||||
a := &Application{
|
a := &Application{
|
||||||
Host: externalHost.Host,
|
Host: externalHost.Host,
|
||||||
log: muxLogger,
|
log: muxLogger,
|
||||||
outpostName: ak.Outpost.Name,
|
outpostName: ak.Outpost.Name,
|
||||||
endpoint: endpoint,
|
endpoint: endpoint,
|
||||||
oauthConfig: oauth2Config,
|
oauthConfig: oauth2Config,
|
||||||
tokenVerifier: verifier,
|
tokenVerifier: verifier,
|
||||||
proxyConfig: p,
|
proxyConfig: p,
|
||||||
httpClient: c,
|
httpClient: c,
|
||||||
mux: mux,
|
mux: mux,
|
||||||
errorTemplates: templates.GetTemplates(),
|
errorTemplates: templates.GetTemplates(),
|
||||||
ak: ak,
|
ak: ak,
|
||||||
|
authHeaderCache: ttlcache.New(ttlcache.WithDisableTouchOnHit[string, Claims]()),
|
||||||
}
|
}
|
||||||
|
go a.authHeaderCache.Start()
|
||||||
a.sessions = a.getStore(p, externalHost)
|
a.sessions = a.getStore(p, externalHost)
|
||||||
mux.Use(web.NewLoggingHandler(muxLogger, func(l *log.Entry, r *http.Request) *log.Entry {
|
mux.Use(web.NewLoggingHandler(muxLogger, func(l *log.Entry, r *http.Request) *log.Entry {
|
||||||
c := a.getClaimsFromSession(r)
|
c := a.getClaimsFromSession(r)
|
||||||
|
@ -216,6 +220,10 @@ func (a *Application) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
a.mux.ServeHTTP(rw, r)
|
a.mux.ServeHTTP(rw, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Application) Stop() {
|
||||||
|
a.authHeaderCache.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) {
|
func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) {
|
||||||
redirect := a.endpoint.EndSessionEndpoint
|
redirect := a.endpoint.EndSessionEndpoint
|
||||||
s, err := a.sessions.Get(r, constants.SessionName)
|
s, err := a.sessions.Get(r, constants.SessionName)
|
||||||
|
|
|
@ -3,6 +3,7 @@ package application
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"goauthentik.io/internal/outpost/proxyv2/constants"
|
"goauthentik.io/internal/outpost/proxyv2/constants"
|
||||||
)
|
)
|
||||||
|
@ -13,8 +14,6 @@ const AuthBearer = "Bearer "
|
||||||
// checkAuth Get claims which are currently in session
|
// checkAuth Get claims which are currently in session
|
||||||
// Returns an error if the session can't be loaded or the claims can't be parsed/type-cast
|
// Returns an error if the session can't be loaded or the claims can't be parsed/type-cast
|
||||||
func (a *Application) checkAuth(rw http.ResponseWriter, r *http.Request) (*Claims, error) {
|
func (a *Application) checkAuth(rw http.ResponseWriter, r *http.Request) (*Claims, error) {
|
||||||
s, _ := a.sessions.Get(r, constants.SessionName)
|
|
||||||
|
|
||||||
c := a.getClaimsFromSession(r)
|
c := a.getClaimsFromSession(r)
|
||||||
if c != nil {
|
if c != nil {
|
||||||
return c, nil
|
return c, nil
|
||||||
|
@ -23,19 +22,18 @@ func (a *Application) checkAuth(rw http.ResponseWriter, r *http.Request) (*Claim
|
||||||
if rw == nil {
|
if rw == nil {
|
||||||
return nil, fmt.Errorf("no response writer")
|
return nil, fmt.Errorf("no response writer")
|
||||||
}
|
}
|
||||||
|
// Check TTL cache
|
||||||
|
c = a.getClaimsFromCache(r)
|
||||||
|
if c != nil {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
// Check bearer token if set
|
// Check bearer token if set
|
||||||
bearer := a.checkAuthHeaderBearer(r)
|
bearer := a.checkAuthHeaderBearer(r)
|
||||||
if bearer != "" {
|
if bearer != "" {
|
||||||
a.log.Trace("checking bearer token")
|
a.log.Trace("checking bearer token")
|
||||||
tc := a.attemptBearerAuth(r, bearer)
|
tc := a.attemptBearerAuth(r, bearer)
|
||||||
if tc != nil {
|
if tc != nil {
|
||||||
s.Values[constants.SessionClaims] = tc.Claims
|
return a.saveAndCacheClaims(rw, r, tc.Claims)
|
||||||
err := s.Save(r, rw)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
r.Header.Del(HeaderAuthorization)
|
|
||||||
return &tc.Claims, nil
|
|
||||||
}
|
}
|
||||||
a.log.Trace("no/invalid bearer token")
|
a.log.Trace("no/invalid bearer token")
|
||||||
}
|
}
|
||||||
|
@ -45,13 +43,7 @@ func (a *Application) checkAuth(rw http.ResponseWriter, r *http.Request) (*Claim
|
||||||
a.log.Trace("checking basic auth")
|
a.log.Trace("checking basic auth")
|
||||||
tc := a.attemptBasicAuth(username, password)
|
tc := a.attemptBasicAuth(username, password)
|
||||||
if tc != nil {
|
if tc != nil {
|
||||||
s.Values[constants.SessionClaims] = *tc
|
return a.saveAndCacheClaims(rw, r, *tc)
|
||||||
err := s.Save(r, rw)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
r.Header.Del(HeaderAuthorization)
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
a.log.Trace("no/invalid basic auth")
|
a.log.Trace("no/invalid basic auth")
|
||||||
}
|
}
|
||||||
|
@ -76,3 +68,32 @@ func (a *Application) getClaimsFromSession(r *http.Request) *Claims {
|
||||||
}
|
}
|
||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Application) getClaimsFromCache(r *http.Request) *Claims {
|
||||||
|
key := r.Header.Get(HeaderAuthorization)
|
||||||
|
item := a.authHeaderCache.Get(key)
|
||||||
|
if item != nil && !item.IsExpired() {
|
||||||
|
v := item.Value()
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Application) saveAndCacheClaims(rw http.ResponseWriter, r *http.Request, claims Claims) (*Claims, error) {
|
||||||
|
s, _ := a.sessions.Get(r, constants.SessionName)
|
||||||
|
|
||||||
|
s.Values[constants.SessionClaims] = claims
|
||||||
|
err := s.Save(r, rw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
key := r.Header.Get(HeaderAuthorization)
|
||||||
|
item := a.authHeaderCache.Get(key)
|
||||||
|
// Don't set when the key is already found
|
||||||
|
if item == nil {
|
||||||
|
a.authHeaderCache.Set(key, claims, time.Second*60)
|
||||||
|
}
|
||||||
|
r.Header.Del(HeaderAuthorization)
|
||||||
|
return &claims, nil
|
||||||
|
}
|
||||||
|
|
|
@ -25,9 +25,19 @@ func (ps *ProxyServer) Refresh() error {
|
||||||
rsp := sentry.StartSpan(context.Background(), "authentik.outposts.proxy.application_ss")
|
rsp := sentry.StartSpan(context.Background(), "authentik.outposts.proxy.application_ss")
|
||||||
ua := fmt.Sprintf(" (provider=%s)", provider.Name)
|
ua := fmt.Sprintf(" (provider=%s)", provider.Name)
|
||||||
hc := &http.Client{
|
hc := &http.Client{
|
||||||
Transport: web.NewUserAgentTransport(constants.OutpostUserAgent()+ua, web.NewTracingTransport(rsp.Context(), ak.GetTLSTransport())),
|
Transport: web.NewUserAgentTransport(
|
||||||
|
constants.OutpostUserAgent()+ua,
|
||||||
|
web.NewTracingTransport(
|
||||||
|
rsp.Context(),
|
||||||
|
ak.GetTLSTransport(),
|
||||||
|
),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
a, err := application.NewApplication(provider, hc, ps.cryptoStore, ps.akAPI)
|
a, err := application.NewApplication(provider, hc, ps.cryptoStore, ps.akAPI)
|
||||||
|
existing, ok := apps[a.Host]
|
||||||
|
if ok {
|
||||||
|
existing.Stop()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ps.log.WithError(err).Warning("failed to setup application")
|
ps.log.WithError(err).Warning("failed to setup application")
|
||||||
} else {
|
} else {
|
||||||
|
|
Reference in a new issue