root: connect to backend via socket (#6720)

* root: connect to gunicorn via socket

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* put socket in temp folder

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use non-socket connection for debug

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* don't hardcode local url

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix dev_server missing websocket

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* dedupe logging config between gunicorn and main app

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* slight refactor for proxy errors

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-09-02 17:58:37 +02:00 committed by GitHub
parent c04e83c86c
commit fd561ac802
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 229 additions and 197 deletions

View file

@ -0,0 +1,9 @@
"""custom runserver command"""
from daphne.management.commands.runserver import Command as RunServer
class Command(RunServer):
"""custom runserver command, which doesn't show the misleading django startup message"""
def on_bind(self, server_port):
pass

View file

@ -1,7 +1,112 @@
"""logging helpers"""
import logging
from logging import Logger
from os import getpid
import structlog
from authentik.lib.config import CONFIG
LOG_PRE_CHAIN = [
# Add the log level and a timestamp to the event_dict if the log entry
# is not from structlog.
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(),
structlog.processors.StackInfoRenderer(),
]
def get_log_level():
"""Get log level, clamp trace to debug"""
level = CONFIG.get("log_level").upper()
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
# Additionally, the entire code uses debug as highest level
# so that would have to be re-written too
if level == "TRACE":
level = "DEBUG"
return level
def structlog_configure():
"""Configure structlog itself"""
structlog.configure_once(
processors=[
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.contextvars.merge_contextvars,
add_process_id,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso", utc=False),
structlog.processors.StackInfoRenderer(),
structlog.processors.dict_tracebacks,
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.make_filtering_bound_logger(
getattr(logging, get_log_level(), logging.WARNING)
),
cache_logger_on_first_use=True,
)
def get_logger_config():
"""Configure python stdlib's logging"""
debug = CONFIG.get_bool("debug")
global_level = get_log_level()
base_config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"json": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.processors.JSONRenderer(sort_keys=True),
"foreign_pre_chain": LOG_PRE_CHAIN + [structlog.processors.dict_tracebacks],
},
"console": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.dev.ConsoleRenderer(colors=debug),
"foreign_pre_chain": LOG_PRE_CHAIN,
},
},
"handlers": {
"console": {
"level": "DEBUG",
"class": "logging.StreamHandler",
"formatter": "console" if debug else "json",
},
},
"loggers": {},
}
handler_level_map = {
"": global_level,
"authentik": global_level,
"django": "WARNING",
"django.request": "ERROR",
"celery": "WARNING",
"selenium": "WARNING",
"docker": "WARNING",
"urllib3": "WARNING",
"websockets": "WARNING",
"daphne": "WARNING",
"kubernetes": "INFO",
"asyncio": "WARNING",
"redis": "WARNING",
"silk": "INFO",
"fsevents": "WARNING",
"uvicorn": "WARNING",
"gunicorn": "INFO",
}
for handler_name, level in handler_level_map.items():
base_config["loggers"][handler_name] = {
"handlers": ["console"],
"level": level,
"propagate": False,
}
return base_config
def add_process_id(logger: Logger, method_name: str, event_dict):
"""Add the current process ID"""

View file

@ -172,7 +172,7 @@ class ChannelsLoggingMiddleware:
LOGGER.info(
scope["path"],
scheme="ws",
remote=scope.get("client", [""])[0],
remote=headers.get(b"x-forwarded-for", b"").decode(),
user_agent=headers.get(b"user-agent", b"").decode(),
**kwargs,
)

View file

@ -1,25 +1,21 @@
"""root settings for authentik"""
import importlib
import logging
import os
from hashlib import sha512
from pathlib import Path
from urllib.parse import quote_plus
import structlog
from celery.schedules import crontab
from sentry_sdk import set_tag
from authentik import ENV_GIT_HASH_KEY, __version__
from authentik.lib.config import CONFIG
from authentik.lib.logging import add_process_id
from authentik.lib.logging import get_logger_config, structlog_configure
from authentik.lib.sentry import sentry_init
from authentik.lib.utils.reflection import get_env
from authentik.stages.password import BACKEND_APP_PASSWORD, BACKEND_INBUILT, BACKEND_LDAP
LOGGER = structlog.get_logger()
BASE_DIR = Path(__file__).absolute().parent.parent.parent
STATICFILES_DIRS = [BASE_DIR / Path("web")]
MEDIA_ROOT = BASE_DIR / Path("media")
@ -368,90 +364,9 @@ MEDIA_URL = "/media/"
TEST = False
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
# We can't check TEST here as its set later by the test runner
LOG_LEVEL = CONFIG.get("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
# Additionally, the entire code uses debug as highest level so that would have to be re-written too
if LOG_LEVEL == "TRACE":
LOG_LEVEL = "DEBUG"
structlog.configure_once(
processors=[
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.contextvars.merge_contextvars,
add_process_id,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso", utc=False),
structlog.processors.StackInfoRenderer(),
structlog.processors.dict_tracebacks,
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.make_filtering_bound_logger(
getattr(logging, LOG_LEVEL, logging.WARNING)
),
cache_logger_on_first_use=True,
)
LOG_PRE_CHAIN = [
# Add the log level and a timestamp to the event_dict if the log entry
# is not from structlog.
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(),
structlog.processors.StackInfoRenderer(),
]
LOGGING = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"json": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.processors.JSONRenderer(sort_keys=True),
"foreign_pre_chain": LOG_PRE_CHAIN + [structlog.processors.dict_tracebacks],
},
"console": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.dev.ConsoleRenderer(colors=DEBUG),
"foreign_pre_chain": LOG_PRE_CHAIN,
},
},
"handlers": {
"console": {
"level": "DEBUG",
"class": "logging.StreamHandler",
"formatter": "console" if DEBUG else "json",
},
},
"loggers": {},
}
_LOGGING_HANDLER_MAP = {
"": LOG_LEVEL,
"authentik": LOG_LEVEL,
"django": "WARNING",
"django.request": "ERROR",
"celery": "WARNING",
"selenium": "WARNING",
"docker": "WARNING",
"urllib3": "WARNING",
"websockets": "WARNING",
"daphne": "WARNING",
"kubernetes": "INFO",
"asyncio": "WARNING",
"redis": "WARNING",
"silk": "INFO",
"fsevents": "WARNING",
}
for handler_name, level in _LOGGING_HANDLER_MAP.items():
LOGGING["loggers"][handler_name] = {
"handlers": ["console"],
"level": level,
"propagate": False,
}
structlog_configure()
LOGGING = get_logger_config()
_DISALLOWED_ITEMS = [

View file

@ -13,7 +13,6 @@ import (
"goauthentik.io/internal/config"
"goauthentik.io/internal/constants"
"goauthentik.io/internal/debug"
"goauthentik.io/internal/gounicorn"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/proxyv2"
sentryutils "goauthentik.io/internal/utils/sentry"
@ -22,8 +21,6 @@ import (
"goauthentik.io/internal/web/tenant_tls"
)
var running = true
var rootCmd = &cobra.Command{
Use: "authentik",
Short: "Start authentik instance",
@ -63,38 +60,23 @@ var rootCmd = &cobra.Command{
ex := common.Init()
defer common.Defer()
u, _ := url.Parse("http://localhost:8000")
g := gounicorn.New()
defer func() {
l.Info("shutting down gunicorn")
g.Kill()
}()
ws := web.NewWebServer(g)
g.HealthyCallback = func() {
if !config.Get().Outposts.DisableEmbeddedOutpost {
go attemptProxyStart(ws, u)
}
}
go web.RunMetricsServer()
go attemptStartBackend(g)
ws.Start()
<-ex
running = false
l.Info("shutting down webserver")
go ws.Shutdown()
},
u, err := url.Parse(fmt.Sprintf("http://%s", config.Get().Listen.HTTP))
if err != nil {
panic(err)
}
func attemptStartBackend(g *gounicorn.GoUnicorn) {
for {
if !running {
ws := web.NewWebServer()
ws.Core().HealthyCallback = func() {
if config.Get().Outposts.DisableEmbeddedOutpost {
return
}
err := g.Start()
log.WithField("logger", "authentik.router").WithError(err).Warning("gunicorn process died, restarting")
go attemptProxyStart(ws, u)
}
ws.Start()
<-ex
l.Info("shutting down webserver")
go ws.Shutdown()
},
}
func attemptProxyStart(ws *web.WebServer, u *url.URL) {

View file

@ -1,7 +1,6 @@
package gounicorn
import (
"net/http"
"os"
"os/exec"
"runtime"
@ -10,10 +9,10 @@ import (
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/config"
"goauthentik.io/internal/utils/web"
)
type GoUnicorn struct {
Healthcheck func() bool
HealthyCallback func()
log *log.Entry
@ -23,9 +22,10 @@ type GoUnicorn struct {
alive bool
}
func New() *GoUnicorn {
func New(healthcheck func() bool) *GoUnicorn {
logger := log.WithField("logger", "authentik.router.unicorn")
g := &GoUnicorn{
Healthcheck: healthcheck,
log: logger,
started: false,
killed: false,
@ -41,7 +41,7 @@ func (g *GoUnicorn) initCmd() {
args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"}
if config.Get().Debug {
command = "./manage.py"
args = []string{"runserver"}
args = []string{"dev_server"}
}
g.log.WithField("args", args).WithField("cmd", command).Debug("Starting gunicorn")
g.p = exec.Command(command, args...)
@ -69,22 +69,11 @@ func (g *GoUnicorn) Start() error {
func (g *GoUnicorn) healthcheck() {
g.log.Debug("starting healthcheck")
h := &http.Client{
Transport: web.NewUserAgentTransport("goauthentik.io/proxy/healthcheck", http.DefaultTransport),
}
check := func() bool {
res, err := h.Get("http://localhost:8000/-/health/live/")
if err == nil && res.StatusCode == 204 {
g.alive = true
return true
}
return false
}
// Default healthcheck is every 1 second on startup
// once we've been healthy once, increase to 30 seconds
for range time.Tick(time.Second) {
if check() {
if g.Healthcheck() {
g.alive = true
g.log.Info("backend is alive, backing off with healthchecks")
g.HealthyCallback()
break
@ -92,7 +81,7 @@ func (g *GoUnicorn) healthcheck() {
g.log.Debug("backend not alive yet")
}
for range time.Tick(30 * time.Second) {
check()
g.Healthcheck()
}
}

View file

@ -1,6 +1,7 @@
package web
import (
"fmt"
"io"
"net/http"
@ -26,7 +27,7 @@ var (
}, []string{"dest"})
)
func RunMetricsServer() {
func (ws *WebServer) runMetricsServer() {
m := mux.NewRouter()
l := log.WithField("logger", "authentik.router.metrics")
m.Use(sentry.SentryNoSampleMiddleware)
@ -38,13 +39,13 @@ func RunMetricsServer() {
).ServeHTTP(rw, r)
// Get upstream metrics
re, err := http.NewRequest("GET", "http://localhost:8000/-/metrics/", nil)
re, err := http.NewRequest("GET", fmt.Sprintf("%s/-/metrics/", ws.ul.String()), nil)
if err != nil {
l.WithError(err).Warning("failed to get upstream metrics")
return
}
re.SetBasicAuth("monitor", config.Get().SecretKey)
res, err := http.DefaultClient.Do(re)
res, err := ws.upstreamHttpClient().Do(re)
if err != nil {
l.WithError(err).Warning("failed to get upstream metrics")
return

View file

@ -2,10 +2,10 @@ package web
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/prometheus/client_golang/prometheus"
@ -14,10 +14,9 @@ import (
func (ws *WebServer) configureProxy() {
// Reverse proxy to the application server
u, _ := url.Parse("http://localhost:8000")
director := func(req *http.Request) {
req.URL.Scheme = u.Scheme
req.URL.Host = u.Host
req.URL.Scheme = ws.ul.Scheme
req.URL.Host = ws.ul.Host
if _, ok := req.Header["User-Agent"]; !ok {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
@ -27,7 +26,10 @@ func (ws *WebServer) configureProxy() {
}
ws.log.WithField("url", req.URL.String()).WithField("headers", req.Header).Trace("tracing request to backend")
}
rp := &httputil.ReverseProxy{Director: director}
rp := &httputil.ReverseProxy{
Director: director,
Transport: ws.upstreamHttpClient().Transport,
}
rp.ErrorHandler = ws.proxyErrorHandler
rp.ModifyResponse = ws.proxyModifyResponse
ws.m.PathPrefix("/outpost.goauthentik.io").HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -43,14 +45,14 @@ func (ws *WebServer) configureProxy() {
}).Observe(float64(elapsed))
return
}
ws.proxyErrorHandler(rw, r, fmt.Errorf("proxy not running"))
ws.proxyErrorHandler(rw, r, errors.New("proxy not running"))
})
ws.m.Path("/-/health/live/").HandlerFunc(sentry.SentryNoSample(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(204)
}))
ws.m.PathPrefix("/").HandlerFunc(sentry.SentryNoSample(func(rw http.ResponseWriter, r *http.Request) {
if !ws.p.IsRunning() {
ws.proxyErrorHandler(rw, r, fmt.Errorf("authentik core not running yet"))
if !ws.g.IsRunning() {
ws.proxyErrorHandler(rw, r, errors.New("authentik starting"))
return
}
before := time.Now()
@ -82,17 +84,14 @@ func (ws *WebServer) proxyErrorHandler(rw http.ResponseWriter, req *http.Request
ws.log.WithError(err).Warning("failed to proxy to backend")
rw.WriteHeader(http.StatusBadGateway)
em := fmt.Sprintf("failed to connect to authentik backend: %v", err)
if !ws.p.IsRunning() {
em = "authentik starting..."
}
// return json if the client asks for json
if req.Header.Get("Accept") == "application/json" {
eem, _ := json.Marshal(map[string]string{
err = json.NewEncoder(rw).Encode(map[string]string{
"error": em,
})
em = string(eem)
}
} else {
_, err = rw.Write([]byte(em))
}
if err != nil {
ws.log.WithError(err).Warning("failed to write error message")
}

View file

@ -3,8 +3,12 @@ package web
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path"
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
@ -26,13 +30,18 @@ type WebServer struct {
ProxyServer *proxyv2.ProxyServer
TenantTLS *tenant_tls.Watcher
g *gounicorn.GoUnicorn
gr bool
m *mux.Router
lh *mux.Router
log *log.Entry
p *gounicorn.GoUnicorn
uc *http.Client
ul *url.URL
}
func NewWebServer(g *gounicorn.GoUnicorn) *WebServer {
const UnixSocketName = "authentik-core.sock"
func NewWebServer() *WebServer {
l := log.WithField("logger", "authentik.router")
mainHandler := mux.NewRouter()
mainHandler.Use(web.ProxyHeaders())
@ -40,23 +49,80 @@ func NewWebServer(g *gounicorn.GoUnicorn) *WebServer {
loggingHandler := mainHandler.NewRoute().Subrouter()
loggingHandler.Use(web.NewLoggingHandler(l, nil))
tmp := os.TempDir()
socketPath := path.Join(tmp, "authentik-core.sock")
// create http client to talk to backend, normal client if we're in debug more
// and a client that connects to our socket when in non debug mode
var upstreamClient *http.Client
if config.Get().Debug {
upstreamClient = http.DefaultClient
} else {
upstreamClient = &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", socketPath)
},
},
}
}
u, _ := url.Parse("http://localhost:8000")
ws := &WebServer{
m: mainHandler,
lh: loggingHandler,
log: l,
p: g,
gr: true,
uc: upstreamClient,
ul: u,
}
ws.configureStatic()
ws.configureProxy()
ws.g = gounicorn.New(func() bool {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/-/health/live/", ws.ul.String()), nil)
if err != nil {
ws.log.WithError(err).Warning("failed to create request for healthcheck")
return false
}
req.Header.Set("User-Agent", "goauthentik.io/router/healthcheck")
res, err := ws.upstreamHttpClient().Do(req)
if err == nil && res.StatusCode == 204 {
return true
}
return false
})
return ws
}
func (ws *WebServer) Start() {
go ws.runMetricsServer()
go ws.attemptStartBackend()
go ws.listenPlain()
go ws.listenTLS()
}
func (ws *WebServer) attemptStartBackend() {
for {
if !ws.gr {
return
}
err := ws.g.Start()
log.WithField("logger", "authentik.router").WithError(err).Warning("gunicorn process died, restarting")
}
}
func (ws *WebServer) Core() *gounicorn.GoUnicorn {
return ws.g
}
func (ws *WebServer) upstreamHttpClient() *http.Client {
return ws.uc
}
func (ws *WebServer) Shutdown() {
ws.log.Info("shutting down gunicorn")
ws.g.Kill()
ws.stop <- struct{}{}
}

View file

@ -7,12 +7,12 @@ from pathlib import Path
from tempfile import gettempdir
from typing import TYPE_CHECKING
import structlog
from kubernetes.config.incluster_config import SERVICE_HOST_ENV_NAME
from prometheus_client.values import MultiProcessValue
from authentik import get_full_version
from authentik.lib.config import CONFIG
from authentik.lib.logging import get_logger_config
from authentik.lib.utils.http import get_http_session
from authentik.lib.utils.reflection import get_env
from authentik.root.install_id import get_install_id_raw
@ -21,57 +21,23 @@ from lifecycle.worker import DjangoUvicornWorker
if TYPE_CHECKING:
from gunicorn.arbiter import Arbiter
bind = "127.0.0.1:8000"
_tmp = Path(gettempdir())
worker_class = "lifecycle.worker.DjangoUvicornWorker"
worker_tmp_dir = str(_tmp.joinpath("authentik_worker_tmp"))
prometheus_tmp_dir = str(_tmp.joinpath("authentik_prometheus_tmp"))
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", prometheus_tmp_dir)
makedirs(worker_tmp_dir, exist_ok=True)
makedirs(prometheus_tmp_dir, exist_ok=True)
bind = f"unix://{str(_tmp.joinpath('authentik-core.sock'))}"
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", prometheus_tmp_dir)
max_requests = 1000
max_requests_jitter = 50
_debug = CONFIG.get_bool("DEBUG", False)
logconfig_dict = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"json": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.processors.JSONRenderer(),
"foreign_pre_chain": [
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(),
structlog.processors.StackInfoRenderer(),
],
},
"console": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.dev.ConsoleRenderer(colors=True),
"foreign_pre_chain": [
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(),
structlog.processors.StackInfoRenderer(),
],
},
},
"handlers": {
"console": {"class": "logging.StreamHandler", "formatter": "json" if _debug else "console"},
},
"loggers": {
"uvicorn": {"handlers": ["console"], "level": "WARNING", "propagate": False},
"gunicorn": {"handlers": ["console"], "level": "INFO", "propagate": False},
},
}
logconfig_dict = get_logger_config()
# if we're running in kubernetes, use fixed workers because we can scale with more pods
# otherwise (assume docker-compose), use as much as we can