outpost: rewrite re-connect logic without recws
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
8abc9cc031
commit
f4988bc45e
1
go.mod
1
go.mod
|
@ -27,7 +27,6 @@ require (
|
|||
github.com/pkg/errors v0.9.1
|
||||
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect
|
||||
github.com/prometheus/client_golang v1.11.0
|
||||
github.com/recws-org/recws v1.3.1
|
||||
github.com/sirupsen/logrus v1.8.1
|
||||
goauthentik.io/api v0.2021104.11
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect
|
||||
|
|
3
go.sum
3
go.sum
|
@ -356,7 +356,6 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC
|
|||
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
|
||||
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
|
||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
|
@ -481,8 +480,6 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT
|
|||
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
||||
github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4=
|
||||
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
||||
github.com/recws-org/recws v1.3.1 h1:vtRhYpgNPBs3iFyu/+zxBqNzLYgID7UPC5siThkvbs0=
|
||||
github.com/recws-org/recws v1.3.1/go.mod h1:gRH/uJLMsO7lbcecAB1Im1Zc6eKxs93ftGR0R39QeYA=
|
||||
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
|
|
|
@ -11,10 +11,9 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/recws-org/recws"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/constants"
|
||||
|
||||
|
@ -36,9 +35,12 @@ type APIController struct {
|
|||
logger *log.Entry
|
||||
|
||||
reloadOffset time.Duration
|
||||
lastWsReconnect time.Time
|
||||
|
||||
wsConn *recws.RecConn
|
||||
wsConn *websocket.Conn
|
||||
lastWsReconnect time.Time
|
||||
wsIsReconnecting bool
|
||||
wsBackoffMultiplier int
|
||||
|
||||
instanceUUID uuid.UUID
|
||||
}
|
||||
|
||||
|
@ -88,9 +90,13 @@ func NewAPIController(akURL url.URL, token string) *APIController {
|
|||
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
|
||||
instanceUUID: uuid.New(),
|
||||
Outpost: outpost,
|
||||
wsBackoffMultiplier: 1,
|
||||
}
|
||||
ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset")
|
||||
ac.initWS(akURL, strfmt.UUID(outpost.Pk))
|
||||
err = ac.initWS(akURL, outpost.Pk)
|
||||
if err != nil {
|
||||
go ac.reconnectWS()
|
||||
}
|
||||
ac.configureRefreshSignal()
|
||||
return ac
|
||||
}
|
||||
|
@ -148,10 +154,6 @@ func (a *APIController) StartBackgorundTasks() error {
|
|||
"version": constants.VERSION,
|
||||
"build": constants.BUILD(),
|
||||
}).Set(1)
|
||||
go func() {
|
||||
a.logger.Debug("Starting WS re-connector...")
|
||||
a.startWSReConnector()
|
||||
}()
|
||||
go func() {
|
||||
a.logger.Debug("Starting WS Handler...")
|
||||
a.startWSHandler()
|
||||
|
|
|
@ -6,18 +6,17 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/recws-org/recws"
|
||||
"goauthentik.io/internal/constants"
|
||||
)
|
||||
|
||||
func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
|
||||
pathTemplate := "%s://%s/ws/outpost/%s/"
|
||||
func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
|
||||
pathTemplate := "%s://%s/ws/outpost/%s/?%s"
|
||||
scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws")
|
||||
|
||||
authHeader := fmt.Sprintf("Bearer %s", ac.token)
|
||||
|
@ -32,15 +31,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
|
|||
value = "false"
|
||||
}
|
||||
|
||||
ws := &recws.RecConn{
|
||||
NonVerbose: true,
|
||||
dialer := websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: strings.ToLower(value) == "true",
|
||||
},
|
||||
}
|
||||
ws.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID.String()), header)
|
||||
|
||||
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID.String()).Debug("Connecting to authentik")
|
||||
ws, _, err := dialer.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID, akURL.Query().Encode()), header)
|
||||
if err != nil {
|
||||
ac.logger.WithError(err).Warning("failed to connect websocket")
|
||||
return err
|
||||
}
|
||||
|
||||
ac.wsConn = ws
|
||||
// Send hello message with our version
|
||||
|
@ -52,11 +55,14 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
|
|||
"uuid": ac.instanceUUID.String(),
|
||||
},
|
||||
}
|
||||
err := ws.WriteJSON(msg)
|
||||
err = ws.WriteJSON(msg)
|
||||
if err != nil {
|
||||
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik")
|
||||
return err
|
||||
}
|
||||
ac.lastWsReconnect = time.Now()
|
||||
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Debug("Successfully connected websocket")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown Gracefully stops all workers, disconnects from websocket
|
||||
|
@ -65,21 +71,43 @@ func (ac *APIController) Shutdown() {
|
|||
// waiting (with timeout) for the server to close the connection.
|
||||
err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
if err != nil {
|
||||
ac.logger.Println("write close:", err)
|
||||
ac.logger.WithError(err).Warning("failed to write close message")
|
||||
return
|
||||
}
|
||||
err = ac.wsConn.Close()
|
||||
if err != nil {
|
||||
ac.logger.WithError(err).Warning("failed to close websocket")
|
||||
}
|
||||
ac.logger.Info("finished shutdown")
|
||||
}
|
||||
|
||||
func (ac *APIController) startWSReConnector() {
|
||||
for {
|
||||
time.Sleep(time.Second * 5)
|
||||
if ac.wsConn.IsConnected() {
|
||||
continue
|
||||
func (ac *APIController) reconnectWS() {
|
||||
if ac.wsIsReconnecting {
|
||||
return
|
||||
}
|
||||
if time.Since(ac.lastWsReconnect).Seconds() > 30 {
|
||||
ac.wsConn.CloseAndReconnect()
|
||||
ac.logger.Info("Reconnecting websocket")
|
||||
ac.lastWsReconnect = time.Now()
|
||||
ac.wsIsReconnecting = true
|
||||
u := url.URL{
|
||||
Host: ac.Client.GetConfig().Host,
|
||||
Scheme: ac.Client.GetConfig().Scheme,
|
||||
}
|
||||
attempt := 1
|
||||
for {
|
||||
q := u.Query()
|
||||
q.Set("attempt", strconv.Itoa(attempt))
|
||||
u.RawQuery = q.Encode()
|
||||
err := ac.initWS(u, ac.Outpost.Pk)
|
||||
attempt += 1
|
||||
if err != nil {
|
||||
ac.logger.Infof("waiting %d seconds to reconnect", ac.wsBackoffMultiplier)
|
||||
time.Sleep(time.Duration(ac.wsBackoffMultiplier) * time.Second)
|
||||
ac.wsBackoffMultiplier = ac.wsBackoffMultiplier * 2
|
||||
// Limit to 300 seconds (5m)
|
||||
if ac.wsBackoffMultiplier >= 300 {
|
||||
ac.wsBackoffMultiplier = 300
|
||||
}
|
||||
} else {
|
||||
ac.wsIsReconnecting = false
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -96,6 +124,7 @@ func (ac *APIController) startWSHandler() {
|
|||
"uuid": ac.instanceUUID.String(),
|
||||
}).Set(0)
|
||||
logger.WithError(err).Warning("ws read error")
|
||||
go ac.reconnectWS()
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
}
|
||||
|
@ -126,9 +155,6 @@ func (ac *APIController) startWSHandler() {
|
|||
func (ac *APIController) startWSHealth() {
|
||||
ticker := time.NewTicker(time.Second * 10)
|
||||
for ; true; <-ticker.C {
|
||||
if !ac.wsConn.IsConnected() {
|
||||
continue
|
||||
}
|
||||
aliveMsg := websocketMessage{
|
||||
Instruction: WebsocketInstructionHello,
|
||||
Args: map[string]interface{}{
|
||||
|
@ -141,6 +167,7 @@ func (ac *APIController) startWSHealth() {
|
|||
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
|
||||
if err != nil {
|
||||
ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error")
|
||||
go ac.reconnectWS()
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
} else {
|
||||
|
|
Reference in New Issue