From ec42b597abefcda91022362501b17e6a2f77282b Mon Sep 17 00:00:00 2001 From: Jens L Date: Mon, 13 Feb 2023 16:34:47 +0100 Subject: [PATCH] providers/proxy: send token request internally, with overwritten host header (#4675) * send token request internally, with overwritten host header Signed-off-by: Jens Langhammer * fix Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer --- authentik/core/tasks.py | 1 + authentik/policies/tests/test_engine.py | 3 +- .../proxyv2/application/application.go | 37 ++++++++++--------- .../outpost/proxyv2/application/endpoint.go | 1 - .../proxyv2/application/oauth_callback.go | 2 +- internal/utils/web/http_host_interceptor.go | 31 ++++++++++++++++ 6 files changed, 55 insertions(+), 20 deletions(-) create mode 100644 internal/utils/web/http_host_interceptor.go diff --git a/authentik/core/tasks.py b/authentik/core/tasks.py index 9f7d73b10..0b4c839e8 100644 --- a/authentik/core/tasks.py +++ b/authentik/core/tasks.py @@ -43,6 +43,7 @@ def clean_expired_models(self: MonitoredTask): amount = 0 for session in AuthenticatedSession.objects.all(): cache_key = f"{KEY_PREFIX}{session.session_key}" + value = None try: value = cache.get(cache_key) # pylint: disable=broad-except diff --git a/authentik/policies/tests/test_engine.py b/authentik/policies/tests/test_engine.py index ad33aed8a..5dd7f95e7 100644 --- a/authentik/policies/tests/test_engine.py +++ b/authentik/policies/tests/test_engine.py @@ -5,6 +5,7 @@ from django.test import TestCase from authentik.core.models import User from authentik.policies.dummy.models import DummyPolicy from authentik.policies.engine import PolicyEngine +from authentik.policies.exceptions import PolicyEngineException from authentik.policies.expression.models import ExpressionPolicy from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel, PolicyEngineMode from authentik.policies.tests.test_process import clear_policy_cache @@ -93,7 +94,7 @@ class TestPolicyEngine(TestCase): """Test invalid policy type""" pbm = PolicyBindingModel.objects.create() PolicyBinding.objects.create(target=pbm, policy=self.policy_wrong_type, order=0) - with self.assertRaises(TypeError): + with self.assertRaises(PolicyEngineException): engine = PolicyEngine(pbm, self.user) engine.build() diff --git a/internal/outpost/proxyv2/application/application.go b/internal/outpost/proxyv2/application/application.go index f00720631..6cddbb929 100644 --- a/internal/outpost/proxyv2/application/application.go +++ b/internal/outpost/proxyv2/application/application.go @@ -43,9 +43,10 @@ type Application struct { outpostName string sessionName string - sessions sessions.Store - proxyConfig api.ProxyOutpostConfig - httpClient *http.Client + sessions sessions.Store + proxyConfig api.ProxyOutpostConfig + httpClient *http.Client + publicHostHTTPClient *http.Client log *log.Entry mux *mux.Router @@ -110,25 +111,27 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server) (*A } mux := mux.NewRouter() + // Save cookie name, based on hashed client ID h := sha256.New() bs := string(h.Sum([]byte(*p.ClientId))) sessionName := fmt.Sprintf("authentik_proxy_%s", bs[:8]) a := &Application{ - Host: externalHost.Host, - log: muxLogger, - outpostName: server.API().Outpost.Name, - sessionName: sessionName, - endpoint: endpoint, - oauthConfig: oauth2Config, - tokenVerifier: verifier, - proxyConfig: p, - httpClient: c, - mux: mux, - errorTemplates: templates.GetTemplates(), - ak: server.API(), - authHeaderCache: ttlcache.New(ttlcache.WithDisableTouchOnHit[string, Claims]()), - srv: server, + Host: externalHost.Host, + log: muxLogger, + outpostName: server.API().Outpost.Name, + sessionName: sessionName, + endpoint: endpoint, + oauthConfig: oauth2Config, + tokenVerifier: verifier, + proxyConfig: p, + httpClient: c, + publicHostHTTPClient: web.NewHostInterceptor(c, server.API().Outpost.Config["authentik_host"].(string)), + mux: mux, + errorTemplates: templates.GetTemplates(), + ak: server.API(), + authHeaderCache: ttlcache.New(ttlcache.WithDisableTouchOnHit[string, Claims]()), + srv: server, } go a.authHeaderCache.Start() a.sessions = a.getStore(p, externalHost) diff --git a/internal/outpost/proxyv2/application/endpoint.go b/internal/outpost/proxyv2/application/endpoint.go index 8c32a63fb..28735db14 100644 --- a/internal/outpost/proxyv2/application/endpoint.go +++ b/internal/outpost/proxyv2/application/endpoint.go @@ -64,7 +64,6 @@ func GetOIDCEndpoint(p api.ProxyOutpostConfig, authentikHost string, embedded bo ep.AuthURL = updateURL(authUrl, aku.Scheme, aku.Host) ep.EndSessionEndpoint = updateURL(endUrl, aku.Scheme, aku.Host) ep.JwksUri = updateURL(jwksUrl, aku.Scheme, aku.Host) - ep.TokenURL = updateURL(tokenUrl, aku.Scheme, aku.Host) ep.Issuer = updateURL(ep.Issuer, aku.Scheme, aku.Host) return ep } diff --git a/internal/outpost/proxyv2/application/oauth_callback.go b/internal/outpost/proxyv2/application/oauth_callback.go index 36ec3a3a5..eef418a84 100644 --- a/internal/outpost/proxyv2/application/oauth_callback.go +++ b/internal/outpost/proxyv2/application/oauth_callback.go @@ -24,7 +24,7 @@ func (a *Application) redeemCallback(savedState string, u *url.URL, c context.Co return nil, fmt.Errorf("blank code") } - ctx := context.WithValue(c, oauth2.HTTPClient, a.httpClient) + ctx := context.WithValue(c, oauth2.HTTPClient, a.publicHostHTTPClient) // Verify state and errors. oauth2Token, err := a.oauthConfig.Exchange(ctx, code) if err != nil { diff --git a/internal/utils/web/http_host_interceptor.go b/internal/utils/web/http_host_interceptor.go new file mode 100644 index 000000000..d3c7de862 --- /dev/null +++ b/internal/utils/web/http_host_interceptor.go @@ -0,0 +1,31 @@ +package web + +import ( + "net/http" + "net/url" + + log "github.com/sirupsen/logrus" +) + +type hostInterceptor struct { + inner http.RoundTripper + host string +} + +func (t hostInterceptor) RoundTrip(r *http.Request) (*http.Response, error) { + r.Host = t.host + return t.inner.RoundTrip(r) +} + +func NewHostInterceptor(inner *http.Client, host string) *http.Client { + aku, err := url.Parse(host) + if err != nil { + log.WithField("host", host).WithError(err).Warn("failed to parse host") + } + return &http.Client{ + Transport: hostInterceptor{ + inner: inner.Transport, + host: aku.Host, + }, + } +}