outpost: rewrite re-connect logic without recws

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-12-11 22:48:33 +01:00
parent 8abc9cc031
commit f4988bc45e
4 changed files with 64 additions and 39 deletions

1
go.mod
View file

@ -27,7 +27,6 @@ require (
github.com/pkg/errors v0.9.1
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect
github.com/prometheus/client_golang v1.11.0
github.com/recws-org/recws v1.3.1
github.com/sirupsen/logrus v1.8.1
goauthentik.io/api v0.2021104.11
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect

3
go.sum
View file

@ -356,7 +356,6 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@ -481,8 +480,6 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/recws-org/recws v1.3.1 h1:vtRhYpgNPBs3iFyu/+zxBqNzLYgID7UPC5siThkvbs0=
github.com/recws-org/recws v1.3.1/go.mod h1:gRH/uJLMsO7lbcecAB1Im1Zc6eKxs93ftGR0R39QeYA=
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=

View file

@ -11,10 +11,9 @@ import (
"syscall"
"time"
"github.com/go-openapi/strfmt"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"github.com/recws-org/recws"
"goauthentik.io/api"
"goauthentik.io/internal/constants"
@ -35,10 +34,13 @@ type APIController struct {
logger *log.Entry
reloadOffset time.Duration
lastWsReconnect time.Time
reloadOffset time.Duration
wsConn *websocket.Conn
lastWsReconnect time.Time
wsIsReconnecting bool
wsBackoffMultiplier int
wsConn *recws.RecConn
instanceUUID uuid.UUID
}
@ -85,12 +87,16 @@ func NewAPIController(akURL url.URL, token string) *APIController {
token: token,
logger: log,
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(),
Outpost: outpost,
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(),
Outpost: outpost,
wsBackoffMultiplier: 1,
}
ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset")
ac.initWS(akURL, strfmt.UUID(outpost.Pk))
err = ac.initWS(akURL, outpost.Pk)
if err != nil {
go ac.reconnectWS()
}
ac.configureRefreshSignal()
return ac
}
@ -148,10 +154,6 @@ func (a *APIController) StartBackgorundTasks() error {
"version": constants.VERSION,
"build": constants.BUILD(),
}).Set(1)
go func() {
a.logger.Debug("Starting WS re-connector...")
a.startWSReConnector()
}()
go func() {
a.logger.Debug("Starting WS Handler...")
a.startWSHandler()

View file

@ -6,18 +6,17 @@ import (
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/go-openapi/strfmt"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"github.com/recws-org/recws"
"goauthentik.io/internal/constants"
)
func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
pathTemplate := "%s://%s/ws/outpost/%s/"
func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
pathTemplate := "%s://%s/ws/outpost/%s/?%s"
scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws")
authHeader := fmt.Sprintf("Bearer %s", ac.token)
@ -32,15 +31,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
value = "false"
}
ws := &recws.RecConn{
NonVerbose: true,
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: strings.ToLower(value) == "true",
},
}
ws.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID.String()), header)
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID.String()).Debug("Connecting to authentik")
ws, _, err := dialer.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID, akURL.Query().Encode()), header)
if err != nil {
ac.logger.WithError(err).Warning("failed to connect websocket")
return err
}
ac.wsConn = ws
// Send hello message with our version
@ -52,11 +55,14 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
"uuid": ac.instanceUUID.String(),
},
}
err := ws.WriteJSON(msg)
err = ws.WriteJSON(msg)
if err != nil {
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik")
return err
}
ac.lastWsReconnect = time.Now()
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Debug("Successfully connected websocket")
return nil
}
// Shutdown Gracefully stops all workers, disconnects from websocket
@ -65,21 +71,43 @@ func (ac *APIController) Shutdown() {
// waiting (with timeout) for the server to close the connection.
err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
ac.logger.Println("write close:", err)
ac.logger.WithError(err).Warning("failed to write close message")
return
}
err = ac.wsConn.Close()
if err != nil {
ac.logger.WithError(err).Warning("failed to close websocket")
}
ac.logger.Info("finished shutdown")
}
func (ac *APIController) startWSReConnector() {
func (ac *APIController) reconnectWS() {
if ac.wsIsReconnecting {
return
}
ac.wsIsReconnecting = true
u := url.URL{
Host: ac.Client.GetConfig().Host,
Scheme: ac.Client.GetConfig().Scheme,
}
attempt := 1
for {
time.Sleep(time.Second * 5)
if ac.wsConn.IsConnected() {
continue
}
if time.Since(ac.lastWsReconnect).Seconds() > 30 {
ac.wsConn.CloseAndReconnect()
ac.logger.Info("Reconnecting websocket")
ac.lastWsReconnect = time.Now()
q := u.Query()
q.Set("attempt", strconv.Itoa(attempt))
u.RawQuery = q.Encode()
err := ac.initWS(u, ac.Outpost.Pk)
attempt += 1
if err != nil {
ac.logger.Infof("waiting %d seconds to reconnect", ac.wsBackoffMultiplier)
time.Sleep(time.Duration(ac.wsBackoffMultiplier) * time.Second)
ac.wsBackoffMultiplier = ac.wsBackoffMultiplier * 2
// Limit to 300 seconds (5m)
if ac.wsBackoffMultiplier >= 300 {
ac.wsBackoffMultiplier = 300
}
} else {
ac.wsIsReconnecting = false
return
}
}
}
@ -96,6 +124,7 @@ func (ac *APIController) startWSHandler() {
"uuid": ac.instanceUUID.String(),
}).Set(0)
logger.WithError(err).Warning("ws read error")
go ac.reconnectWS()
time.Sleep(time.Second * 5)
continue
}
@ -126,9 +155,6 @@ func (ac *APIController) startWSHandler() {
func (ac *APIController) startWSHealth() {
ticker := time.NewTicker(time.Second * 10)
for ; true; <-ticker.C {
if !ac.wsConn.IsConnected() {
continue
}
aliveMsg := websocketMessage{
Instruction: WebsocketInstructionHello,
Args: map[string]interface{}{
@ -141,6 +167,7 @@ func (ac *APIController) startWSHealth() {
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
if err != nil {
ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error")
go ac.reconnectWS()
time.Sleep(time.Second * 5)
continue
} else {