diff --git a/internal/outpost/proxyv2/application/mode_forward_caddy_test.go b/internal/outpost/proxyv2/application/mode_forward_caddy_test.go index 9942eec0b..f33c025f4 100644 --- a/internal/outpost/proxyv2/application/mode_forward_caddy_test.go +++ b/internal/outpost/proxyv2/application/mode_forward_caddy_test.go @@ -52,7 +52,7 @@ func TestForwardHandleCaddy_Single_Headers(t *testing.T) { "client_id": []string{*a.proxyConfig.ClientId}, "redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"}, "response_type": []string{"code"}, - "state": []string{s.Values[constants.SessionOAuthState].([]string)[0]}, + "state": []string{s.Values[constants.SessionOAuthState].(string)}, } assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String()) assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect]) @@ -137,7 +137,7 @@ func TestForwardHandleCaddy_Domain_Header(t *testing.T) { "client_id": []string{*a.proxyConfig.ClientId}, "redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"}, "response_type": []string{"code"}, - "state": []string{s.Values[constants.SessionOAuthState].([]string)[0]}, + "state": []string{s.Values[constants.SessionOAuthState].(string)}, } assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String()) assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect]) diff --git a/internal/outpost/proxyv2/application/mode_forward_envoy_test.go b/internal/outpost/proxyv2/application/mode_forward_envoy_test.go index 71b222139..dc56d0d0a 100644 --- a/internal/outpost/proxyv2/application/mode_forward_envoy_test.go +++ b/internal/outpost/proxyv2/application/mode_forward_envoy_test.go @@ -37,7 +37,7 @@ func TestForwardHandleEnvoy_Single_Headers(t *testing.T) { "client_id": []string{*a.proxyConfig.ClientId}, "redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"}, "response_type": []string{"code"}, - "state": []string{s.Values[constants.SessionOAuthState].([]string)[0]}, + "state": []string{s.Values[constants.SessionOAuthState].(string)}, } assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String()) assert.Equal(t, "http://ext.t.goauthentik.io/app", s.Values[constants.SessionRedirect]) @@ -106,7 +106,7 @@ func TestForwardHandleEnvoy_Domain_Header(t *testing.T) { "client_id": []string{*a.proxyConfig.ClientId}, "redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"}, "response_type": []string{"code"}, - "state": []string{s.Values[constants.SessionOAuthState].([]string)[0]}, + "state": []string{s.Values[constants.SessionOAuthState].(string)}, } assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String()) assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect]) diff --git a/internal/outpost/proxyv2/application/mode_forward_traefik_test.go b/internal/outpost/proxyv2/application/mode_forward_traefik_test.go index 87598b6c4..e328bab1b 100644 --- a/internal/outpost/proxyv2/application/mode_forward_traefik_test.go +++ b/internal/outpost/proxyv2/application/mode_forward_traefik_test.go @@ -52,7 +52,7 @@ func TestForwardHandleTraefik_Single_Headers(t *testing.T) { "client_id": []string{*a.proxyConfig.ClientId}, "redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"}, "response_type": []string{"code"}, - "state": []string{s.Values[constants.SessionOAuthState].([]string)[0]}, + "state": []string{s.Values[constants.SessionOAuthState].(string)}, } assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String()) assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect]) @@ -137,7 +137,7 @@ func TestForwardHandleTraefik_Domain_Header(t *testing.T) { "client_id": []string{*a.proxyConfig.ClientId}, "redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"}, "response_type": []string{"code"}, - "state": []string{s.Values[constants.SessionOAuthState].([]string)[0]}, + "state": []string{s.Values[constants.SessionOAuthState].(string)}, } assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String()) assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect]) diff --git a/internal/outpost/proxyv2/application/oauth.go b/internal/outpost/proxyv2/application/oauth.go index 5f3b5212e..69b4f6495 100644 --- a/internal/outpost/proxyv2/application/oauth.go +++ b/internal/outpost/proxyv2/application/oauth.go @@ -45,22 +45,28 @@ func (a *Application) checkRedirectParam(r *http.Request) (string, bool) { func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request) { newState := base64.RawURLEncoding.EncodeToString(securecookie.GenerateRandomKey(32)) - s, err := a.sessions.Get(r, constants.SessionName) - if err != nil { - s.Values[constants.SessionOAuthState] = []string{} - } - state, ok := s.Values[constants.SessionOAuthState].([]string) - if !ok { - s.Values[constants.SessionOAuthState] = []string{} - state = []string{} + s, _ := a.sessions.Get(r, constants.SessionName) + // Check if we already have a state in the session, + // and if we do we don't do anything here + currentState, ok := s.Values[constants.SessionOAuthState].(string) + if ok { + claims, err := a.getClaims(r) + if err != nil && claims != nil { + a.log.Trace("auth start request with existing authenticated session") + a.redirect(rw, r) + return + } + a.log.Trace("session already has state, sending redirect to current state") + http.Redirect(rw, r, a.oauthConfig.AuthCodeURL(currentState), http.StatusFound) + return } rd, ok := a.checkRedirectParam(r) if ok { s.Values[constants.SessionRedirect] = rd a.log.WithField("rd", rd).Trace("Setting redirect") } - s.Values[constants.SessionOAuthState] = append(state, newState) - err = s.Save(r, rw) + s.Values[constants.SessionOAuthState] = newState + err := s.Save(r, rw) if err != nil { a.log.WithError(err).Warning("failed to save session") } @@ -75,10 +81,10 @@ func (a *Application) handleAuthCallback(rw http.ResponseWriter, r *http.Request state, ok := s.Values[constants.SessionOAuthState] if !ok { a.log.Warning("No state saved in session") - http.Redirect(rw, r, a.proxyConfig.ExternalHost, http.StatusFound) + a.redirect(rw, r) return } - claims, err := a.redeemCallback(state.([]string), r.URL, r.Context()) + claims, err := a.redeemCallback(state.(string), r.URL, r.Context()) if err != nil { a.log.WithError(err).Warning("failed to redeem code") rw.WriteHeader(400) @@ -101,11 +107,5 @@ func (a *Application) handleAuthCallback(rw http.ResponseWriter, r *http.Request rw.WriteHeader(400) return } - redirect := a.proxyConfig.ExternalHost - redirectR, ok := s.Values[constants.SessionRedirect] - if ok { - redirect = redirectR.(string) - } - a.log.WithField("redirect", redirect).Trace("final redirect") - http.Redirect(rw, r, redirect, http.StatusFound) + a.redirect(rw, r) } diff --git a/internal/outpost/proxyv2/application/oauth_callback.go b/internal/outpost/proxyv2/application/oauth_callback.go index f32a1cc42..73dd6b618 100644 --- a/internal/outpost/proxyv2/application/oauth_callback.go +++ b/internal/outpost/proxyv2/application/oauth_callback.go @@ -9,23 +9,13 @@ import ( "golang.org/x/oauth2" ) -func (a *Application) redeemCallback(states []string, u *url.URL, c context.Context) (*Claims, error) { +func (a *Application) redeemCallback(savedState string, u *url.URL, c context.Context) (*Claims, error) { state := u.Query().Get("state") - if len(states) < 1 { - return nil, fmt.Errorf("no states") - } - found := false - for _, fstate := range states { - if fstate == state { - found = true - } - } a.log.WithFields(log.Fields{ - "states": states, + "states": savedState, "expected": state, - "found": found, }).Trace("tracing states") - if !found { + if savedState != state { return nil, fmt.Errorf("invalid state") } diff --git a/internal/outpost/proxyv2/application/utils.go b/internal/outpost/proxyv2/application/utils.go index 67405c2b3..79147d336 100644 --- a/internal/outpost/proxyv2/application/utils.go +++ b/internal/outpost/proxyv2/application/utils.go @@ -56,12 +56,27 @@ func (a *Application) redirectToStart(rw http.ResponseWriter, r *http.Request) { } urlArgs := url.Values{ - "rd": []string{redirectUrl}, + redirectParam: []string{redirectUrl}, } authUrl := urlJoin(a.proxyConfig.ExternalHost, "/outpost.goauthentik.io/start") http.Redirect(rw, r, authUrl+"?"+urlArgs.Encode(), http.StatusFound) } +func (a *Application) redirect(rw http.ResponseWriter, r *http.Request) { + redirect := a.proxyConfig.ExternalHost + rd, ok := a.checkRedirectParam(r) + if ok { + redirect = rd + } + s, _ := a.sessions.Get(r, constants.SessionName) + redirectR, ok := s.Values[constants.SessionRedirect] + if ok { + redirect = redirectR.(string) + } + a.log.WithField("redirect", redirect).Trace("final redirect") + http.Redirect(rw, r, redirect, http.StatusFound) +} + // getClaims Get claims which are currently in session // Returns an error if the session can't be loaded or the claims can't be parsed/type-cast func (a *Application) getClaims(r *http.Request) (*Claims, error) {