diff --git a/go.mod b/go.mod index 833bfd1a8..a1c2e08fe 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,6 @@ require ( github.com/pkg/errors v0.9.1 github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect github.com/prometheus/client_golang v1.11.0 - github.com/recws-org/recws v1.3.1 github.com/sirupsen/logrus v1.8.1 goauthentik.io/api v0.2021104.11 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect diff --git a/go.sum b/go.sum index 40287d7d0..813f1d18f 100644 --- a/go.sum +++ b/go.sum @@ -356,7 +356,6 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= -github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -481,8 +480,6 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/recws-org/recws v1.3.1 h1:vtRhYpgNPBs3iFyu/+zxBqNzLYgID7UPC5siThkvbs0= -github.com/recws-org/recws v1.3.1/go.mod h1:gRH/uJLMsO7lbcecAB1Im1Zc6eKxs93ftGR0R39QeYA= github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= diff --git a/internal/outpost/ak/api.go b/internal/outpost/ak/api.go index 42fee6598..1e49fe7ab 100644 --- a/internal/outpost/ak/api.go +++ b/internal/outpost/ak/api.go @@ -11,10 +11,9 @@ import ( "syscall" "time" - "github.com/go-openapi/strfmt" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/prometheus/client_golang/prometheus" - "github.com/recws-org/recws" "goauthentik.io/api" "goauthentik.io/internal/constants" @@ -35,10 +34,13 @@ type APIController struct { logger *log.Entry - reloadOffset time.Duration - lastWsReconnect time.Time + reloadOffset time.Duration + + wsConn *websocket.Conn + lastWsReconnect time.Time + wsIsReconnecting bool + wsBackoffMultiplier int - wsConn *recws.RecConn instanceUUID uuid.UUID } @@ -85,12 +87,16 @@ func NewAPIController(akURL url.URL, token string) *APIController { token: token, logger: log, - reloadOffset: time.Duration(rand.Intn(10)) * time.Second, - instanceUUID: uuid.New(), - Outpost: outpost, + reloadOffset: time.Duration(rand.Intn(10)) * time.Second, + instanceUUID: uuid.New(), + Outpost: outpost, + wsBackoffMultiplier: 1, } ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset") - ac.initWS(akURL, strfmt.UUID(outpost.Pk)) + err = ac.initWS(akURL, outpost.Pk) + if err != nil { + go ac.reconnectWS() + } ac.configureRefreshSignal() return ac } @@ -148,10 +154,6 @@ func (a *APIController) StartBackgorundTasks() error { "version": constants.VERSION, "build": constants.BUILD(), }).Set(1) - go func() { - a.logger.Debug("Starting WS re-connector...") - a.startWSReConnector() - }() go func() { a.logger.Debug("Starting WS Handler...") a.startWSHandler() diff --git a/internal/outpost/ak/api_ws.go b/internal/outpost/ak/api_ws.go index b9e992540..2bb4f5ec2 100644 --- a/internal/outpost/ak/api_ws.go +++ b/internal/outpost/ak/api_ws.go @@ -6,18 +6,17 @@ import ( "net/http" "net/url" "os" + "strconv" "strings" "time" - "github.com/go-openapi/strfmt" "github.com/gorilla/websocket" "github.com/prometheus/client_golang/prometheus" - "github.com/recws-org/recws" "goauthentik.io/internal/constants" ) -func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) { - pathTemplate := "%s://%s/ws/outpost/%s/" +func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { + pathTemplate := "%s://%s/ws/outpost/%s/?%s" scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws") authHeader := fmt.Sprintf("Bearer %s", ac.token) @@ -32,15 +31,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) { value = "false" } - ws := &recws.RecConn{ - NonVerbose: true, + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 10 * time.Second, TLSClientConfig: &tls.Config{ InsecureSkipVerify: strings.ToLower(value) == "true", }, } - ws.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID.String()), header) - ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID.String()).Debug("Connecting to authentik") + ws, _, err := dialer.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID, akURL.Query().Encode()), header) + if err != nil { + ac.logger.WithError(err).Warning("failed to connect websocket") + return err + } ac.wsConn = ws // Send hello message with our version @@ -52,11 +55,14 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) { "uuid": ac.instanceUUID.String(), }, } - err := ws.WriteJSON(msg) + err = ws.WriteJSON(msg) if err != nil { ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik") + return err } ac.lastWsReconnect = time.Now() + ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Debug("Successfully connected websocket") + return nil } // Shutdown Gracefully stops all workers, disconnects from websocket @@ -65,21 +71,43 @@ func (ac *APIController) Shutdown() { // waiting (with timeout) for the server to close the connection. err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { - ac.logger.Println("write close:", err) + ac.logger.WithError(err).Warning("failed to write close message") return } + err = ac.wsConn.Close() + if err != nil { + ac.logger.WithError(err).Warning("failed to close websocket") + } + ac.logger.Info("finished shutdown") } -func (ac *APIController) startWSReConnector() { +func (ac *APIController) reconnectWS() { + if ac.wsIsReconnecting { + return + } + ac.wsIsReconnecting = true + u := url.URL{ + Host: ac.Client.GetConfig().Host, + Scheme: ac.Client.GetConfig().Scheme, + } + attempt := 1 for { - time.Sleep(time.Second * 5) - if ac.wsConn.IsConnected() { - continue - } - if time.Since(ac.lastWsReconnect).Seconds() > 30 { - ac.wsConn.CloseAndReconnect() - ac.logger.Info("Reconnecting websocket") - ac.lastWsReconnect = time.Now() + q := u.Query() + q.Set("attempt", strconv.Itoa(attempt)) + u.RawQuery = q.Encode() + err := ac.initWS(u, ac.Outpost.Pk) + attempt += 1 + if err != nil { + ac.logger.Infof("waiting %d seconds to reconnect", ac.wsBackoffMultiplier) + time.Sleep(time.Duration(ac.wsBackoffMultiplier) * time.Second) + ac.wsBackoffMultiplier = ac.wsBackoffMultiplier * 2 + // Limit to 300 seconds (5m) + if ac.wsBackoffMultiplier >= 300 { + ac.wsBackoffMultiplier = 300 + } + } else { + ac.wsIsReconnecting = false + return } } } @@ -96,6 +124,7 @@ func (ac *APIController) startWSHandler() { "uuid": ac.instanceUUID.String(), }).Set(0) logger.WithError(err).Warning("ws read error") + go ac.reconnectWS() time.Sleep(time.Second * 5) continue } @@ -126,9 +155,6 @@ func (ac *APIController) startWSHandler() { func (ac *APIController) startWSHealth() { ticker := time.NewTicker(time.Second * 10) for ; true; <-ticker.C { - if !ac.wsConn.IsConnected() { - continue - } aliveMsg := websocketMessage{ Instruction: WebsocketInstructionHello, Args: map[string]interface{}{ @@ -141,6 +167,7 @@ func (ac *APIController) startWSHealth() { ac.logger.WithField("loop", "ws-health").Trace("hello'd") if err != nil { ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error") + go ac.reconnectWS() time.Sleep(time.Second * 5) continue } else {