diff --git a/internal/crypto/generate.go b/internal/crypto/generate.go index 7503ed0d4..c7b4792d5 100644 --- a/internal/crypto/generate.go +++ b/internal/crypto/generate.go @@ -1,24 +1,24 @@ package crypto import ( - "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" + "crypto/rsa" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "math/big" - "net" - "os" "time" log "github.com/sirupsen/logrus" ) -func GenerateKeypair(hosts []string) { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) +// GenerateSelfSignedCert Generate a self-signed TLS Certificate, to be used as fallback +func GenerateSelfSignedCert() (tls.Certificate, error) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { log.Fatalf("Failed to generate private key: %v", err) + return tls.Certificate{}, err } keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment @@ -30,12 +30,14 @@ func GenerateKeypair(hosts []string) { serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { log.Fatalf("Failed to generate serial number: %v", err) + return tls.Certificate{}, err } template := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ - Organization: []string{"BeryJu.org"}, + Organization: []string{"authentik"}, + CommonName: "authentik default certificate", }, NotBefore: notBefore, NotAfter: notAfter, @@ -45,46 +47,17 @@ func GenerateKeypair(hosts []string) { BasicConstraintsValid: true, } - for _, h := range hosts { - if ip := net.ParseIP(h); ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) - } else { - template.DNSNames = append(template.DNSNames, h) - } - } + template.DNSNames = []string{"*"} - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv, priv) + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - log.Fatalf("Failed to create certificate: %v", err) - } - - certOut, err := os.Create("cert.pem") - if err != nil { - log.Fatalf("Failed to open cert.pem for writing: %v", err) - } - if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - log.Fatalf("Failed to write data to cert.pem: %v", err) - } - if err := certOut.Close(); err != nil { - log.Fatalf("Error closing cert.pem: %v", err) - } - log.Print("wrote cert.pem\n") - - keyOut, err := os.OpenFile("key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - log.Fatalf("Failed to open key.pem for writing: %v", err) - return + log.Warning(err) } + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) privBytes, err := x509.MarshalPKCS8PrivateKey(priv) if err != nil { - log.Fatalf("Unable to marshal private key: %v", err) + log.Warning(err) } - if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { - log.Fatalf("Failed to write data to key.pem: %v", err) - } - if err := keyOut.Close(); err != nil { - log.Fatalf("Error closing key.pem: %v", err) - } - log.Print("wrote key.pem\n") - return + privPemByes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + return tls.X509KeyPair(pemBytes, privPemByes) } diff --git a/internal/web/web.go b/internal/web/web.go index e924ecb06..316a67c83 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -1,6 +1,9 @@ package web import ( + "context" + "errors" + "net" "net/http" "sync" @@ -16,6 +19,8 @@ type WebServer struct { LegacyProxy bool + stop chan struct{} // channel for waiting shutdown + m *mux.Router lh *mux.Router log *log.Entry @@ -49,12 +54,45 @@ func (ws *WebServer) Run() { }() go func() { defer wg.Done() - // ws.listenTLS() + ws.listenTLS() }() wg.Done() } func (ws *WebServer) listenPlain() { + ln, err := net.Listen("tcp", config.G.Web.Listen) + if err != nil { + ws.log.WithError(err).Fatalf("failed to listen") + } + ws.log.WithField("addr", config.G.Web.Listen).Info("Running") + + ws.serve(ln) + ws.log.WithField("addr", config.G.Web.Listen).Info("Running") http.ListenAndServe(config.G.Web.Listen, ws.m) } + +func (ws *WebServer) serve(listener net.Listener) { + srv := &http.Server{ + Handler: ws.m, + } + + // See https://golang.org/pkg/net/http/#Server.Shutdown + idleConnsClosed := make(chan struct{}) + go func() { + <-ws.stop // wait notification for stopping server + + // We received an interrupt signal, shut down. + if err := srv.Shutdown(context.Background()); err != nil { + // Error from closing listeners, or context timeout: + ws.log.Printf("HTTP server Shutdown: %v", err) + } + close(idleConnsClosed) + }() + + err := srv.Serve(listener) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + ws.log.Errorf("ERROR: http.Serve() - %s", err) + } + <-idleConnsClosed +} diff --git a/internal/web/web_ssl.go b/internal/web/web_ssl.go new file mode 100644 index 000000000..b89356908 --- /dev/null +++ b/internal/web/web_ssl.go @@ -0,0 +1,32 @@ +package web + +import ( + "crypto/tls" + "net" + + "goauthentik.io/internal/config" + "goauthentik.io/internal/crypto" +) + +// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests +func (ws *WebServer) listenTLS() { + cert, err := crypto.GenerateSelfSignedCert() + if err != nil { + ws.log.WithError(err).Error("failed to generate default cert") + } + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + + ln, err := net.Listen("tcp", config.G.Web.ListenTLS) + if err != nil { + ws.log.WithError(err).Fatalf("failed to listen") + } + ws.log.WithField("addr", config.G.Web.ListenTLS).Info("Running") + + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig) + ws.serve(tlsListener) + ws.log.Printf("closing %s", tlsListener.Addr()) +} diff --git a/internal/web/web_utils.go b/internal/web/web_utils.go new file mode 100644 index 000000000..08c71e68c --- /dev/null +++ b/internal/web/web_utils.go @@ -0,0 +1,31 @@ +package web + +import ( + "log" + "net" + "time" +) + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { + tc, err := ln.AcceptTCP() + if err != nil { + return nil, err + } + err = tc.SetKeepAlive(true) + if err != nil { + log.Printf("Error setting Keep-Alive: %v", err) + } + err = tc.SetKeepAlivePeriod(3 * time.Minute) + if err != nil { + log.Printf("Error setting Keep-Alive period: %v", err) + } + return tc, nil +} diff --git a/outpost/main.go b/outpost/main.go deleted file mode 100644 index 790580777..000000000 --- a/outpost/main.go +++ /dev/null @@ -1,5 +0,0 @@ -package main - -func main() { - -}