outpost: rewrite re-connect logic without recws

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-12-11 22:48:33 +01:00
parent 8abc9cc031
commit f4988bc45e
4 changed files with 64 additions and 39 deletions

1
go.mod
View File

@ -27,7 +27,6 @@ require (
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect
github.com/prometheus/client_golang v1.11.0 github.com/prometheus/client_golang v1.11.0
github.com/recws-org/recws v1.3.1
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
goauthentik.io/api v0.2021104.11 goauthentik.io/api v0.2021104.11
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect

3
go.sum
View File

@ -356,7 +356,6 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@ -481,8 +480,6 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/recws-org/recws v1.3.1 h1:vtRhYpgNPBs3iFyu/+zxBqNzLYgID7UPC5siThkvbs0=
github.com/recws-org/recws v1.3.1/go.mod h1:gRH/uJLMsO7lbcecAB1Im1Zc6eKxs93ftGR0R39QeYA=
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=

View File

@ -11,10 +11,9 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/go-openapi/strfmt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/recws-org/recws"
"goauthentik.io/api" "goauthentik.io/api"
"goauthentik.io/internal/constants" "goauthentik.io/internal/constants"
@ -36,9 +35,12 @@ type APIController struct {
logger *log.Entry logger *log.Entry
reloadOffset time.Duration reloadOffset time.Duration
lastWsReconnect time.Time
wsConn *recws.RecConn wsConn *websocket.Conn
lastWsReconnect time.Time
wsIsReconnecting bool
wsBackoffMultiplier int
instanceUUID uuid.UUID instanceUUID uuid.UUID
} }
@ -88,9 +90,13 @@ func NewAPIController(akURL url.URL, token string) *APIController {
reloadOffset: time.Duration(rand.Intn(10)) * time.Second, reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(), instanceUUID: uuid.New(),
Outpost: outpost, Outpost: outpost,
wsBackoffMultiplier: 1,
} }
ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset") ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset")
ac.initWS(akURL, strfmt.UUID(outpost.Pk)) err = ac.initWS(akURL, outpost.Pk)
if err != nil {
go ac.reconnectWS()
}
ac.configureRefreshSignal() ac.configureRefreshSignal()
return ac return ac
} }
@ -148,10 +154,6 @@ func (a *APIController) StartBackgorundTasks() error {
"version": constants.VERSION, "version": constants.VERSION,
"build": constants.BUILD(), "build": constants.BUILD(),
}).Set(1) }).Set(1)
go func() {
a.logger.Debug("Starting WS re-connector...")
a.startWSReConnector()
}()
go func() { go func() {
a.logger.Debug("Starting WS Handler...") a.logger.Debug("Starting WS Handler...")
a.startWSHandler() a.startWSHandler()

View File

@ -6,18 +6,17 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
"github.com/go-openapi/strfmt"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/recws-org/recws"
"goauthentik.io/internal/constants" "goauthentik.io/internal/constants"
) )
func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) { func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
pathTemplate := "%s://%s/ws/outpost/%s/" pathTemplate := "%s://%s/ws/outpost/%s/?%s"
scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws") scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws")
authHeader := fmt.Sprintf("Bearer %s", ac.token) authHeader := fmt.Sprintf("Bearer %s", ac.token)
@ -32,15 +31,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
value = "false" value = "false"
} }
ws := &recws.RecConn{ dialer := websocket.Dialer{
NonVerbose: true, Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
InsecureSkipVerify: strings.ToLower(value) == "true", InsecureSkipVerify: strings.ToLower(value) == "true",
}, },
} }
ws.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID.String()), header)
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID.String()).Debug("Connecting to authentik") ws, _, err := dialer.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID, akURL.Query().Encode()), header)
if err != nil {
ac.logger.WithError(err).Warning("failed to connect websocket")
return err
}
ac.wsConn = ws ac.wsConn = ws
// Send hello message with our version // Send hello message with our version
@ -52,11 +55,14 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
"uuid": ac.instanceUUID.String(), "uuid": ac.instanceUUID.String(),
}, },
} }
err := ws.WriteJSON(msg) err = ws.WriteJSON(msg)
if err != nil { if err != nil {
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik") ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik")
return err
} }
ac.lastWsReconnect = time.Now() ac.lastWsReconnect = time.Now()
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Debug("Successfully connected websocket")
return nil
} }
// Shutdown Gracefully stops all workers, disconnects from websocket // Shutdown Gracefully stops all workers, disconnects from websocket
@ -65,21 +71,43 @@ func (ac *APIController) Shutdown() {
// waiting (with timeout) for the server to close the connection. // waiting (with timeout) for the server to close the connection.
err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil { if err != nil {
ac.logger.Println("write close:", err) ac.logger.WithError(err).Warning("failed to write close message")
return return
} }
err = ac.wsConn.Close()
if err != nil {
ac.logger.WithError(err).Warning("failed to close websocket")
}
ac.logger.Info("finished shutdown")
} }
func (ac *APIController) startWSReConnector() { func (ac *APIController) reconnectWS() {
for { if ac.wsIsReconnecting {
time.Sleep(time.Second * 5) return
if ac.wsConn.IsConnected() {
continue
} }
if time.Since(ac.lastWsReconnect).Seconds() > 30 { ac.wsIsReconnecting = true
ac.wsConn.CloseAndReconnect() u := url.URL{
ac.logger.Info("Reconnecting websocket") Host: ac.Client.GetConfig().Host,
ac.lastWsReconnect = time.Now() Scheme: ac.Client.GetConfig().Scheme,
}
attempt := 1
for {
q := u.Query()
q.Set("attempt", strconv.Itoa(attempt))
u.RawQuery = q.Encode()
err := ac.initWS(u, ac.Outpost.Pk)
attempt += 1
if err != nil {
ac.logger.Infof("waiting %d seconds to reconnect", ac.wsBackoffMultiplier)
time.Sleep(time.Duration(ac.wsBackoffMultiplier) * time.Second)
ac.wsBackoffMultiplier = ac.wsBackoffMultiplier * 2
// Limit to 300 seconds (5m)
if ac.wsBackoffMultiplier >= 300 {
ac.wsBackoffMultiplier = 300
}
} else {
ac.wsIsReconnecting = false
return
} }
} }
} }
@ -96,6 +124,7 @@ func (ac *APIController) startWSHandler() {
"uuid": ac.instanceUUID.String(), "uuid": ac.instanceUUID.String(),
}).Set(0) }).Set(0)
logger.WithError(err).Warning("ws read error") logger.WithError(err).Warning("ws read error")
go ac.reconnectWS()
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
continue continue
} }
@ -126,9 +155,6 @@ func (ac *APIController) startWSHandler() {
func (ac *APIController) startWSHealth() { func (ac *APIController) startWSHealth() {
ticker := time.NewTicker(time.Second * 10) ticker := time.NewTicker(time.Second * 10)
for ; true; <-ticker.C { for ; true; <-ticker.C {
if !ac.wsConn.IsConnected() {
continue
}
aliveMsg := websocketMessage{ aliveMsg := websocketMessage{
Instruction: WebsocketInstructionHello, Instruction: WebsocketInstructionHello,
Args: map[string]interface{}{ Args: map[string]interface{}{
@ -141,6 +167,7 @@ func (ac *APIController) startWSHealth() {
ac.logger.WithField("loop", "ws-health").Trace("hello'd") ac.logger.WithField("loop", "ws-health").Trace("hello'd")
if err != nil { if err != nil {
ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error") ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error")
go ac.reconnectWS()
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
continue continue
} else { } else {