From 9541c58eeb400806c6dce526c8f2e9c513ecc342 Mon Sep 17 00:00:00 2001 From: WeebDataHoarder Date: Thu, 24 Apr 2025 15:25:41 +0200 Subject: [PATCH] settings: introduce settings YAML file to complement cmd arguments --- cmd/go-away/main.go | 237 +++++++++++++---------------------- lib/challenge/data.go | 9 +- lib/challenge/dnsbl/dnsbl.go | 4 +- lib/challenge/key.go | 5 +- lib/challenge/types.go | 5 +- lib/http.go | 19 +-- lib/interface.go | 38 ++++-- lib/policy/options.go | 22 ---- lib/policy/state.go | 19 +++ lib/settings/backend.go | 85 +++++++++++++ lib/settings/bind.go | 169 +++++++++++++++++++++++++ lib/settings/settings.go | 51 ++++++++ lib/settings/strings.go | 24 ++++ lib/state.go | 23 ++-- utils/http.go | 43 +++++-- 15 files changed, 523 insertions(+), 230 deletions(-) delete mode 100644 lib/policy/options.go create mode 100644 lib/policy/state.go create mode 100644 lib/settings/backend.go create mode 100644 lib/settings/bind.go create mode 100644 lib/settings/settings.go create mode 100644 lib/settings/strings.go diff --git a/cmd/go-away/main.go b/cmd/go-away/main.go index 2fe05bd..0fcc7a7 100644 --- a/cmd/go-away/main.go +++ b/cmd/go-away/main.go @@ -4,78 +4,27 @@ import ( "bytes" "crypto/ed25519" "crypto/rand" - "crypto/tls" "encoding/hex" "errors" "flag" "fmt" "git.gammaspectra.live/git/go-away/lib" "git.gammaspectra.live/git/go-away/lib/policy" + "git.gammaspectra.live/git/go-away/lib/settings" "git.gammaspectra.live/git/go-away/utils" - "github.com/pires/go-proxyproto" - "golang.org/x/crypto/acme" - "golang.org/x/crypto/acme/autocert" + "github.com/goccy/go-yaml" "log" "log/slog" - "net" "net/http" + "net/http/pprof" "os" "os/signal" "path" "runtime/debug" - "strconv" "strings" - "sync/atomic" "syscall" ) -func setupListener(network, address, socketMode string, proxy bool) (net.Listener, string) { - if network == "proxy" { - network = "tcp" - proxy = true - } - - formattedAddress := "" - switch network { - case "unix": - formattedAddress = "unix:" + address - case "tcp": - formattedAddress = "http://localhost" + address - default: - formattedAddress = fmt.Sprintf(`(%s) %s`, network, address) - } - - listener, err := net.Listen(network, address) - if err != nil { - log.Fatal(fmt.Errorf("failed to bind to %s: %w", formattedAddress, err)) - } - - // additional permission handling for unix sockets - if network == "unix" { - mode, err := strconv.ParseUint(socketMode, 8, 0) - if err != nil { - listener.Close() - log.Fatal(fmt.Errorf("could not parse socket mode %s: %w", socketMode, err)) - } - - err = os.Chmod(address, os.FileMode(mode)) - if err != nil { - listener.Close() - log.Fatal(fmt.Errorf("could not change socket mode: %w", err)) - } - } - - if proxy { - slog.Warn("listener PROXY enabled") - formattedAddress += " +PROXY" - listener = &proxyproto.Listener{ - Listener: listener, - } - } - - return listener, formattedAddress -} - var internalCmdName = "go-away" var internalMainName = "go-away" var internalMainVersion = "dev" @@ -101,40 +50,20 @@ func (v *MultiVar) Set(value string) error { return nil } -func newACMEManager(clientDirectory string, backends map[string]http.Handler) *autocert.Manager { - - var domains []string - for d := range backends { - parts := strings.Split(d, ":") - d = parts[0] - if net.ParseIP(d) != nil { - continue - } - domains = append(domains, d) - } - - manager := &autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(domains...), - Client: &acme.Client{ - HTTPClient: http.DefaultClient, - DirectoryURL: clientDirectory, - }, - } - return manager -} - func main() { - bind := flag.String("bind", ":8080", "network address to bind HTTP/HTTP(s) to") - bindNetwork := flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp") - bindProxy := flag.Bool("bind-proxy", false, "use PROXY protocol in front of the listener") - socketMode := flag.String("socket-mode", "0770", "socket mode (permissions) for unix domain sockets.") + + opt := settings.DefaultSettings + + flag.StringVar(&opt.Bind.Address, "bind", opt.Bind.Address, "network address to bind HTTP/HTTP(s) to") + flag.StringVar(&opt.Bind.Network, "bind-network", opt.Bind.Network, "network family to bind HTTP to, e.g. unix, tcp") + flag.BoolVar(&opt.Bind.Proxy, "bind-proxy", opt.Bind.Proxy, "use PROXY protocol in front of the listener") + flag.StringVar(&opt.Bind.SocketMode, "socket-mode", opt.Bind.SocketMode, "socket mode (permissions) for unix domain sockets.") slogLevel := flag.String("slog-level", "WARN", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)") debugMode := flag.Bool("debug", false, "debug mode with logs and server timings") - passThrough := flag.Bool("passthrough", false, "passthrough mode sends all requests to matching backends until state is loaded") + flag.BoolVar(&opt.Bind.Passthrough, "passthrough", opt.Bind.Passthrough, "passthrough mode sends all requests to matching backends until state is loaded") check := flag.Bool("check", false, "check configuration and policies, then exit") - acmeAutocert := flag.String("acme-autocert", "", "enables HTTP(s) mode and uses the provided ACME server URL or available service (available: letsencrypt)") + flag.StringVar(&opt.Bind.TLSAcmeAutoCert, "acme-autocert", opt.Bind.TLSAcmeAutoCert, "enables HTTP(s) mode and uses the provided ACME server URL or available service (available: letsencrypt)") clientIpHeader := flag.String("client-ip-header", "", "Client HTTP header to fetch their IP address from (X-Real-Ip, X-Client-Ip, X-Forwarded-For, Cf-Connecting-Ip, etc.)") backendIpHeader := flag.String("backend-ip-header", "", "Backend HTTP header to set the client IP address from, if empty defaults to leaving Client header alone (X-Real-Ip, X-Client-Ip, X-Forwarded-For, Cf-Connecting-Ip, etc.)") @@ -143,8 +72,9 @@ func main() { policyFile := flag.String("policy", "", "path to policy YAML file") policySnippets := flag.String("policy-snippets", "", "path to YAML snippets folder") - challengeTemplate := flag.String("challenge-template", "anubis", "name or path of the challenge template to use (anubis, forgejo)") - challengeTemplateTheme := flag.String("challenge-template-theme", "", "name of the challenge template theme to use (forgejo => [forgejo-auto, forgejo-dark, forgejo-light, gitea...])") + flag.StringVar(&opt.ChallengeTemplate, "challenge-template", opt.ChallengeTemplate, "name or path of the challenge template to use (anubis, forgejo)") + + templateTheme := flag.String("challenge-template-theme", opt.ChallengeTemplateOverrides["Theme"], "name of the challenge template theme to use (forgejo => [forgejo-auto, forgejo-dark, forgejo-light, gitea...])") packageName := flag.String("package-path", internalCmdName, "package name to expose in .well-known url path") @@ -153,6 +83,8 @@ func main() { var backends MultiVar flag.Var(&backends, "backend", "backend definition in the form of an.example.com=http://backend:1234 (can be specified multiple times)") + settingsFile := flag.String("config", "", "path to config override YAML file") + flag.Parse() var err error @@ -176,6 +108,21 @@ func main() { slog.Info("go-away", "package", internalMainName, "version", internalMainVersion, "cmd", internalCmdName) + // preload missing settings + opt.ChallengeTemplateOverrides["Theme"] = *templateTheme + + // load overrides + if *settingsFile != "" { + settingsData, err := os.ReadFile(*settingsFile) + if err != nil { + log.Fatal(fmt.Errorf("could not read settings file: %w", err)) + } + err = yaml.Unmarshal(settingsData, &opt) + if err != nil { + log.Fatal(fmt.Errorf("could not parse settings file: %w", err)) + } + } + var seed []byte var kValue string @@ -207,18 +154,24 @@ func main() { } createdBackends := make(map[string]http.Handler) - - parsedBackends := make(map[string]string) for _, backend := range backends { + if backend == "" { + // skip empty to allow no values + continue + } parts := strings.Split(backend, "=") if len(parts) != 2 { log.Fatal(fmt.Errorf("invalid backend definition: %s, expected 2 parts, got %v", backend, parts)) } - parsedBackends[parts[0]] = parts[1] + + // make no-settings, default backend + opt.Backends[parts[0]] = settings.Backend{ + URL: parts[1], + } } - for k, v := range parsedBackends { - backend, err := utils.MakeReverseProxy(v) + for k, v := range opt.Backends { + backend, err := v.Create() if err != nil { log.Fatal(fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err)) } @@ -228,10 +181,11 @@ func main() { } if len(createdBackends) == 0 { - log.Fatal(fmt.Errorf("no backends defined in policy file")) + log.Fatal(fmt.Errorf("no backends defined in cmdline or settings file")) } var cache utils.Cache + var acmeCache string if *cachePath != "" { err = os.MkdirAll(*cachePath, 0755) if err != nil { @@ -248,29 +202,8 @@ func main() { if err != nil { log.Fatal(fmt.Errorf("failed to open cache directory: %w", err)) } - } - var tlsConfig *tls.Config - - if *acmeAutocert != "" { - switch *acmeAutocert { - case "letsencrypt": - *acmeAutocert = acme.LetsEncryptURL - } - - acmeManager := newACMEManager(*acmeAutocert, createdBackends) - if *cachePath != "" { - err = os.MkdirAll(path.Join(*cachePath, "acme"), 0755) - if err != nil { - log.Fatal(fmt.Errorf("failed to create acme cache directory: %w", err)) - } - acmeManager.Cache = autocert.DirCache(path.Join(*cachePath, "acme")) - } - slog.Warn( - "acme-autocert enabled", - "directory", *acmeAutocert, - ) - tlsConfig = acmeManager.TLSConfig() + acmeCache = path.Join(*cachePath, "acme") } loadPolicyState := func() (http.Handler, error) { @@ -284,22 +217,19 @@ func main() { return nil, fmt.Errorf("failed to parse policy file: %w", err) } - settings := policy.Settings{ - Cache: cache, - Backends: createdBackends, - Debug: *debugMode, - MainName: internalMainName, - MainVersion: internalMainVersion, - PackageName: *packageName, - ChallengeTemplate: *challengeTemplate, - ChallengeTemplateTheme: *challengeTemplateTheme, - PrivateKeySeed: seed, - ClientIpHeader: *clientIpHeader, - BackendIpHeader: *backendIpHeader, - ChallengeResponseCode: http.StatusTeapot, + stateSettings := policy.StateSettings{ + Cache: cache, + Backends: createdBackends, + MainName: internalMainName, + MainVersion: internalMainVersion, + PackageName: *packageName, + PrivateKeySeed: seed, + ClientIpHeader: *clientIpHeader, + BackendIpHeader: *backendIpHeader, + ChallengeResponseCode: http.StatusTeapot, } - state, err := lib.NewState(*p, settings) + state, err := lib.NewState(*p, opt, stateSettings) if err != nil { return nil, fmt.Errorf("failed to create state: %w", err) @@ -317,32 +247,15 @@ func main() { os.Exit(0) } - listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy) + listener, listenUrl := opt.Bind.Listener() slog.Warn( "listening", "url", listenUrl, ) - var serverHandler atomic.Pointer[http.Handler] - server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if handler := serverHandler.Load(); handler == nil { - http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) - } else { - (*handler).ServeHTTP(w, r) - } - }), tlsConfig) - - if *passThrough { - // setup a passthrough handler temporarily - fn := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - backend := utils.SelectHTTPHandler(createdBackends, r.Host) - if backend == nil { - http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) - } else { - backend.ServeHTTP(w, r) - } - })) - serverHandler.Store(&fn) + server, swap, err := opt.Bind.Server(createdBackends, acmeCache) + if err != nil { + log.Fatal(fmt.Errorf("failed to create server: %w", err)) } go func() { @@ -351,7 +264,7 @@ func main() { log.Fatal(fmt.Errorf("failed to load policy state: %w", err)) } - serverHandler.Store(&handler) + swap(handler) slog.Warn( "handler configuration loaded", ) @@ -369,12 +282,34 @@ func main() { continue } - serverHandler.Store(&handler) + swap(handler) slog.Warn("handler configuration reloaded") } }() - if tlsConfig != nil { + if opt.BindDebug != "" { + go func() { + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + debugServer := http.Server{ + Addr: opt.BindDebug, + Handler: mux, + } + + slog.Warn( + "listening metrics", + "bind", opt.BindDebug, + ) + if err = debugServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + log.Fatal(err) + } + }() + } + + if server.TLSConfig != nil { if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) { log.Fatal(err) } diff --git a/lib/challenge/data.go b/lib/challenge/data.go index ef1ac3a..c5c10dd 100644 --- a/lib/challenge/data.go +++ b/lib/challenge/data.go @@ -12,8 +12,8 @@ import ( "github.com/google/cel-go/cel" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/traits" - "net" "net/http" + "net/netip" "net/textproto" "time" ) @@ -36,7 +36,7 @@ type RequestData struct { Time time.Time ChallengeVerify map[Id]VerifyResult ChallengeState map[Id]VerifyState - RemoteAddress net.IP + RemoteAddress netip.AddrPort State StateInterface CookiePrefix string @@ -57,7 +57,6 @@ func CreateRequestData(r *http.Request, state StateInterface) (*http.Request, *R data.ChallengeState = make(map[Id]VerifyState, len(state.GetChallenges())) data.Time = time.Now().UTC() data.State = state - data.r = r data.fp = make(map[string]string, 2) @@ -85,6 +84,8 @@ func CreateRequestData(r *http.Request, state StateInterface) (*http.Request, *R data.CookiePrefix = utils.CookiePrefix + hex.EncodeToString(sum.Sum(nil)[:4]) + "-" r = r.WithContext(context.WithValue(r.Context(), requestDataContextKey{}, &data)) + r = utils.SetRemoteAddress(r, data.RemoteAddress) + data.r = r return r, &data } @@ -96,7 +97,7 @@ func (d *RequestData) ResolveName(name string) (any, bool) { case "method": return d.r.Method, true case "remoteAddress": - return d.RemoteAddress, true + return d.RemoteAddress.Addr().AsSlice(), true case "userAgent": return d.r.UserAgent(), true case "path": diff --git a/lib/challenge/dnsbl/dnsbl.go b/lib/challenge/dnsbl/dnsbl.go index d840506..19a1465 100644 --- a/lib/challenge/dnsbl/dnsbl.go +++ b/lib/challenge/dnsbl/dnsbl.go @@ -119,9 +119,9 @@ func FillRegistration(state challenge.StateInterface, reg *challenge.Registratio data := challenge.RequestDataFromContext(r.Context()) - result, err := lookup(r.Context(), params.Decay, params.Timeout, dnsbl, decayMap, data.RemoteAddress) + result, err := lookup(r.Context(), params.Decay, params.Timeout, dnsbl, decayMap, data.RemoteAddress.Addr().Unmap().AsSlice()) if err != nil { - data.State.Logger(r).Debug("dnsbl lookup failed", "address", data.RemoteAddress.String(), "result", result, "err", err) + data.State.Logger(r).Debug("dnsbl lookup failed", "address", data.RemoteAddress.Addr().String(), "result", result, "err", err) } if result.Bad() { diff --git a/lib/challenge/key.go b/lib/challenge/key.go index 710c2c5..007e4b4 100644 --- a/lib/challenge/key.go +++ b/lib/challenge/key.go @@ -47,7 +47,8 @@ func GetChallengeKeyForRequest(state StateInterface, reg *Registration, until ti hasher.Write([]byte("challenge\x00")) hasher.Write([]byte(reg.Name)) hasher.Write([]byte{0}) - hasher.Write(address.To16()) + ipBuf := address.Addr().Unmap().As16() + hasher.Write(ipBuf[:]) hasher.Write([]byte{0}) // specific headers @@ -72,7 +73,7 @@ func GetChallengeKeyForRequest(state StateInterface, reg *Registration, until ti sum[0] = 0 - if address.To4() != nil { + if address.Addr().Unmap().Is4() { // Is IPv4, mark sum.Set(KeyFlagIsIPv4) } diff --git a/lib/challenge/types.go b/lib/challenge/types.go index c5e343d..1001e41 100644 --- a/lib/challenge/types.go +++ b/lib/challenge/types.go @@ -3,6 +3,7 @@ package challenge import ( "crypto/ed25519" "git.gammaspectra.live/git/go-away/lib/policy" + "git.gammaspectra.live/git/go-away/lib/settings" "github.com/google/cel-go/cel" "log/slog" "net/http" @@ -106,7 +107,9 @@ type StateInterface interface { GetChallengeByName(name string) (*Registration, bool) GetChallenges() Register - Settings() policy.Settings + Settings() policy.StateSettings + + Options() settings.Settings GetBackend(host string) http.Handler } diff --git a/lib/http.go b/lib/http.go index c696b26..8f3f84c 100644 --- a/lib/http.go +++ b/lib/http.go @@ -10,10 +10,7 @@ import ( "html/template" "log/slog" "net/http" - "net/http/pprof" - "strconv" "strings" - "time" ) var templates map[string]*template.Template @@ -51,17 +48,11 @@ func initTemplate(name, data string) error { return nil } -func (state *State) addTiming(w http.ResponseWriter, name, desc string, duration time.Duration) { - if state.Settings().Debug { - w.Header().Add("Server-Timing", fmt.Sprintf("%s;desc=%s;dur=%d", name, strconv.Quote(desc), duration.Milliseconds())) - } -} - func GetLoggerForRequest(r *http.Request) *slog.Logger { data := challenge.RequestDataFromContext(r.Context()) args := []any{ "request_id", data.Id.String(), - "remote_address", data.RemoteAddress.String(), + "remote_address", data.RemoteAddress.Addr().String(), "user_agent", r.UserAgent(), "host", r.Host, "path", r.URL.Path, @@ -152,14 +143,6 @@ func (state *State) setupRoutes() error { state.Mux.HandleFunc("/", state.handleRequest) - if state.Settings().Debug { - //TODO: split this to a different listener, metrics listener - http.HandleFunc(state.urlPath+"/debug/pprof/", pprof.Index) - http.HandleFunc(state.urlPath+"/debug/pprof/profile", pprof.Profile) - http.HandleFunc(state.urlPath+"/debug/pprof/symbol", pprof.Symbol) - http.HandleFunc(state.urlPath+"/debug/pprof/trace", pprof.Trace) - } - state.Mux.Handle("GET "+state.urlPath+"/assets/", http.StripPrefix(state.UrlPath()+"/assets/", gzipped.FileServer(gzipped.FS(embed.AssetsFs)))) for _, reg := range state.challenges { diff --git a/lib/interface.go b/lib/interface.go index 03e8c31..3fd0466 100644 --- a/lib/interface.go +++ b/lib/interface.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "git.gammaspectra.live/git/go-away/lib/challenge" "git.gammaspectra.live/git/go-away/lib/policy" + "git.gammaspectra.live/git/go-away/lib/settings" "git.gammaspectra.live/git/go-away/utils" "github.com/google/cel-go/cel" "log/slog" @@ -72,23 +73,30 @@ func (state *State) ChallengePage(w http.ResponseWriter, r *http.Request, status input := make(map[string]any) input["Id"] = data.Id.String() input["Random"] = utils.CacheBust() + + input["Path"] = state.UrlPath() + for k, v := range state.Options().ChallengeTemplateOverrides { + input[k] = v + } + for k, v := range state.Options().Strings { + input["str_"+k] = v + } + if reg != nil { input["Challenge"] = reg.Name - input["Path"] = state.UrlPath() } - input["Theme"] = state.Settings().ChallengeTemplateTheme maps.Copy(input, params) if _, ok := input["Title"]; !ok { - input["Title"] = "Checking you are not a bot" + input["Title"] = state.Options().Strings.Get("challenge_are_you_bot") } w.Header().Set("Content-Type", "text/html; charset=utf-8") buf := bytes.NewBuffer(make([]byte, 0, 8192)) - err := templates["challenge-"+state.Settings().ChallengeTemplate+".gohtml"].Execute(buf, input) + err := templates["challenge-"+state.Options().ChallengeTemplate+".gohtml"].Execute(buf, input) if err != nil { state.ErrorPage(w, r, http.StatusInternalServerError, err, "") } else { @@ -103,17 +111,25 @@ func (state *State) ErrorPage(w http.ResponseWriter, r *http.Request, status int buf := bytes.NewBuffer(make([]byte, 0, 8192)) - err2 := templates["challenge-"+state.Settings().ChallengeTemplate+".gohtml"].Execute(buf, map[string]any{ + input := map[string]any{ "Id": data.Id.String(), "Random": utils.CacheBust(), "Error": err.Error(), "Path": state.UrlPath(), - "Theme": state.Settings().ChallengeTemplateTheme, - "Title": "Oh no! " + http.StatusText(status), + "Theme": "", + "Title": state.Options().Strings.Get("error") + " " + http.StatusText(status), "HideSpinner": true, "Challenge": "", "Redirect": redirect, - }) + } + for k, v := range state.Options().ChallengeTemplateOverrides { + input[k] = v + } + for k, v := range state.Options().Strings { + input["str_"+k] = v + } + + err2 := templates["challenge-"+state.Options().ChallengeTemplate+".gohtml"].Execute(buf, input) if err2 != nil { // nested errors! panic(err2) @@ -136,10 +152,14 @@ func (state *State) GetChallengeByName(name string) (*challenge.Registration, bo reg, _, ok := state.challenges.GetByName(name) return reg, ok } -func (state *State) Settings() policy.Settings { +func (state *State) Settings() policy.StateSettings { return state.settings } +func (state *State) Options() settings.Settings { + return state.opt +} + func (state *State) GetBackend(host string) http.Handler { return utils.SelectHTTPHandler(state.Settings().Backends, host) } diff --git a/lib/policy/options.go b/lib/policy/options.go deleted file mode 100644 index 1c934cd..0000000 --- a/lib/policy/options.go +++ /dev/null @@ -1,22 +0,0 @@ -package policy - -import ( - "git.gammaspectra.live/git/go-away/utils" - "net/http" -) - -type Settings struct { - Cache utils.Cache - Backends map[string]http.Handler - PrivateKeySeed []byte - Debug bool - MainName string - MainVersion string - PackageName string - ChallengeTemplate string - ChallengeTemplateTheme string - ClientIpHeader string - BackendIpHeader string - - ChallengeResponseCode int -} diff --git a/lib/policy/state.go b/lib/policy/state.go new file mode 100644 index 0000000..e3aff51 --- /dev/null +++ b/lib/policy/state.go @@ -0,0 +1,19 @@ +package policy + +import ( + "git.gammaspectra.live/git/go-away/utils" + "net/http" +) + +type StateSettings struct { + Cache utils.Cache + Backends map[string]http.Handler + PrivateKeySeed []byte + MainName string + MainVersion string + PackageName string + ClientIpHeader string + BackendIpHeader string + + ChallengeResponseCode int +} diff --git a/lib/settings/backend.go b/lib/settings/backend.go new file mode 100644 index 0000000..31d075a --- /dev/null +++ b/lib/settings/backend.go @@ -0,0 +1,85 @@ +package settings + +import ( + "git.gammaspectra.live/git/go-away/utils" + "net/http" + "net/http/httputil" +) + +type Backend struct { + // URL Target server backend path. Supports http/https/unix protocols. + URL string `yaml:"url"` + + // Host Override the Host header and TLS SNI with this value if specified + Host string `yaml:"host"` + + //ProxyProtocol uint8 `yaml:"proxy-protocol"` + + // HTTP2Enabled Enable HTTP2 to backend + HTTP2Enabled bool `yaml:"http2-enabled"` + + // TLSSkipVerify Disable TLS certificate verification, if any + TLSSkipVerify bool `yaml:"tls-skip-verify"` +} + +func (b Backend) Create() (*httputil.ReverseProxy, error) { + proxy, err := utils.MakeReverseProxy(b.URL) + if err != nil { + return nil, err + } + + transport := proxy.Transport.(*http.Transport) + + if b.HTTP2Enabled { + transport.ForceAttemptHTTP2 = true + } + + if b.TLSSkipVerify { + transport.TLSClientConfig.InsecureSkipVerify = true + } + + if b.Host != "" { + transport.TLSClientConfig.ServerName = b.Host + director := proxy.Director + proxy.Director = func(req *http.Request) { + req.Host = b.Host + director(req) + } + } + + /*if b.ProxyProtocol > 0 { + dialContext := transport.DialContext + if dialContext == nil { + dialContext = (&net.Dialer{}).DialContext + } + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := dialContext(ctx, network, addr) + if err != nil { + return nil, err + } + addrPort := utils.GetRemoteAddress(ctx) + if addrPort == nil { + // pass as is + hdr := proxyproto.HeaderProxyFromAddrs(b.ProxyProtocol, conn.LocalAddr(), conn.RemoteAddr()) + _, err = hdr.WriteTo(conn) + if err != nil { + conn.Close() + return nil, err + } + } else { + // set proper headers! + hdr := proxyproto.HeaderProxyFromAddrs(b.ProxyProtocol, net.TCPAddrFromAddrPort(*addrPort), conn.RemoteAddr()) + _, err = hdr.WriteTo(conn) + if err != nil { + conn.Close() + return nil, err + } + } + return conn, nil + } + }*/ + + proxy.Transport = transport + + return proxy, nil +} diff --git a/lib/settings/bind.go b/lib/settings/bind.go new file mode 100644 index 0000000..b614901 --- /dev/null +++ b/lib/settings/bind.go @@ -0,0 +1,169 @@ +package settings + +import ( + "context" + "crypto/tls" + "fmt" + "git.gammaspectra.live/git/go-away/utils" + "github.com/pires/go-proxyproto" + "golang.org/x/crypto/acme" + "golang.org/x/crypto/acme/autocert" + "log" + "log/slog" + "net" + "net/http" + "os" + "strconv" + "sync/atomic" +) + +type Bind struct { + Address string `yaml:"address"` + Network string `yaml:"network"` + SocketMode string `yaml:"socket-mode"` + Proxy bool `yaml:"proxy"` + + Passthrough bool `yaml:"passthrough"` + + // TLSAcmeAutoCert URL to ACME directory, or letsencrypt + TLSAcmeAutoCert string `yaml:"tls-acme-autocert"` + + // TLSCertificate Alternate to TLSAcmeAutoCert + TLSCertificate string `yaml:"tls-certificate"` + // TLSPrivateKey Alternate to TLSAcmeAutoCert + TLSPrivateKey string `yaml:"tls-key"` +} + +func (b *Bind) Listener() (net.Listener, string) { + return setupListener(b.Network, b.Address, b.SocketMode, b.Proxy) +} + +func (b *Bind) Server(backends map[string]http.Handler, acmeCachePath string) (*http.Server, func(http.Handler), error) { + + var tlsConfig *tls.Config + + if b.TLSAcmeAutoCert != "" { + switch b.TLSAcmeAutoCert { + case "letsencrypt": + b.TLSAcmeAutoCert = acme.LetsEncryptURL + } + + acmeManager := newACMEManager(b.TLSAcmeAutoCert, backends) + if acmeCachePath != "" { + err := os.MkdirAll(acmeCachePath, 0755) + if err != nil { + return nil, nil, fmt.Errorf("failed to create acme cache directory: %w", err) + } + acmeManager.Cache = autocert.DirCache(acmeCachePath) + } + slog.Warn( + "acme-autocert enabled", + "directory", b.TLSAcmeAutoCert, + ) + tlsConfig = acmeManager.TLSConfig() + } else if b.TLSCertificate != "" && b.TLSPrivateKey != "" { + tlsConfig = &tls.Config{} + var err error + tlsConfig.Certificates = make([]tls.Certificate, 1) + tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(b.TLSCertificate, b.TLSPrivateKey) + if err != nil { + return nil, nil, err + } + slog.Warn( + "TLS enabled", + "certificate", b.TLSCertificate, + ) + } + + var serverHandler atomic.Pointer[http.Handler] + server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if handler := serverHandler.Load(); handler == nil { + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) + } else { + (*handler).ServeHTTP(w, r) + } + }), tlsConfig) + + swap := func(handler http.Handler) { + serverHandler.Store(&handler) + } + + if b.Passthrough { + // setup a passthrough handler temporarily + swap(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backend := utils.SelectHTTPHandler(backends, r.Host) + if backend == nil { + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) + } else { + backend.ServeHTTP(w, r) + } + }))) + } + + return server, swap, nil + +} + +func setupListener(network, address, socketMode string, proxy bool) (net.Listener, string) { + if network == "proxy" { + network = "tcp" + proxy = true + } + + formattedAddress := "" + switch network { + case "unix": + formattedAddress = "unix:" + address + case "tcp": + formattedAddress = "http://localhost" + address + default: + formattedAddress = fmt.Sprintf(`(%s) %s`, network, address) + } + + listener, err := net.Listen(network, address) + if err != nil { + log.Fatal(fmt.Errorf("failed to bind to %s: %w", formattedAddress, err)) + } + + // additional permission handling for unix sockets + if network == "unix" { + mode, err := strconv.ParseUint(socketMode, 8, 0) + if err != nil { + listener.Close() + log.Fatal(fmt.Errorf("could not parse socket mode %s: %w", socketMode, err)) + } + + err = os.Chmod(address, os.FileMode(mode)) + if err != nil { + listener.Close() + log.Fatal(fmt.Errorf("could not change socket mode: %w", err)) + } + } + + if proxy { + slog.Warn("listener PROXY enabled") + formattedAddress += " +PROXY" + listener = &proxyproto.Listener{ + Listener: listener, + } + } + + return listener, formattedAddress +} + +func newACMEManager(clientDirectory string, backends map[string]http.Handler) *autocert.Manager { + manager := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostPolicy(func(ctx context.Context, host string) error { + if utils.SelectHTTPHandler(backends, host) != nil { + return nil + } + return fmt.Errorf("acme/autocert: host %s not configured in backends", host) + }), + Client: &acme.Client{ + HTTPClient: http.DefaultClient, + DirectoryURL: clientDirectory, + }, + } + return manager +} diff --git a/lib/settings/settings.go b/lib/settings/settings.go new file mode 100644 index 0000000..6f56616 --- /dev/null +++ b/lib/settings/settings.go @@ -0,0 +1,51 @@ +package settings + +import "maps" + +type Settings struct { + Bind Bind `json:"bind"` + + Backends map[string]Backend `json:"backends"` + + BindDebug string `json:"bind-debug"` + BindMetrics string `json:"bind-metrics"` + + Strings Strings `yaml:"strings"` + + // Links to add to challenge/error pages like privacy/impressum. + Links []Link `yaml:"links"` + + ChallengeTemplate string `yaml:"challenge-template"` + + // ChallengeTemplateOverrides Key/Value overrides for the current chosen template + // Replacements TODO: + // Path -> go-away path + ChallengeTemplateOverrides map[string]string `yaml:"challenge-template-overrides"` +} + +type Link struct { + Name string `yaml:"name"` + URL string `yaml:"url"` +} + +var DefaultSettings = Settings{ + Strings: DefaultStrings, + ChallengeTemplate: "anubis", + ChallengeTemplateOverrides: func() map[string]string { + m := make(map[string]string) + maps.Copy(m, map[string]string{ + "Theme": "", + "Logo": "", + }) + return m + }(), + + Bind: Bind{ + Address: ":8080", + Network: "tcp", + SocketMode: "0770", + Proxy: false, + TLSAcmeAutoCert: "", + }, + Backends: make(map[string]Backend), +} diff --git a/lib/settings/strings.go b/lib/settings/strings.go new file mode 100644 index 0000000..1fc4533 --- /dev/null +++ b/lib/settings/strings.go @@ -0,0 +1,24 @@ +package settings + +import "maps" + +type Strings map[string]string + +var DefaultStrings = make(Strings).set(map[string]string{ + "challenge_are_you_bot": "Checking you are not a bot", + "error": "Oh no!", +}) + +func (s Strings) set(v map[string]string) Strings { + maps.Copy(s, v) + return s +} + +func (s Strings) Get(value string) string { + v, ok := (s)[value] + if !ok { + // fallback + return "string:" + value + } + return v +} diff --git a/lib/state.go b/lib/state.go index 74d4521..d94d745 100644 --- a/lib/state.go +++ b/lib/state.go @@ -8,6 +8,7 @@ import ( "git.gammaspectra.live/git/go-away/lib/challenge" "git.gammaspectra.live/git/go-away/lib/condition" "git.gammaspectra.live/git/go-away/lib/policy" + "git.gammaspectra.live/git/go-away/lib/settings" "git.gammaspectra.live/git/go-away/utils" "github.com/google/cel-go/cel" "github.com/yl2chen/cidranger" @@ -31,7 +32,8 @@ type State struct { publicKey ed25519.PublicKey privateKey ed25519.PrivateKey - settings policy.Settings + opt settings.Settings + settings policy.StateSettings networks map[string]cidranger.Ranger @@ -44,10 +46,11 @@ type State struct { Mux *http.ServeMux } -func NewState(p policy.Policy, settings policy.Settings) (handler http.Handler, err error) { +func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (handler http.Handler, err error) { state := new(State) state.close = make(chan struct{}) state.settings = settings + state.opt = opt state.client = &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse @@ -89,22 +92,18 @@ func NewState(p policy.Policy, settings policy.Settings) (handler http.Handler, } } - if state.Settings().ChallengeTemplate == "" { - state.settings.ChallengeTemplate = "anubis" - } + if templates["challenge-"+state.Options().ChallengeTemplate+".gohtml"] == nil { - if templates["challenge-"+state.Settings().ChallengeTemplate+".gohtml"] == nil { - - if data, err := os.ReadFile(state.Settings().ChallengeTemplate); err == nil && len(data) > 0 { - name := path.Base(state.Settings().ChallengeTemplate) + if data, err := os.ReadFile(state.Options().ChallengeTemplate); err == nil && len(data) > 0 { + name := path.Base(state.Options().ChallengeTemplate) err := initTemplate(name, string(data)) if err != nil { - return nil, fmt.Errorf("error loading template %s: %w", settings.ChallengeTemplate, err) + return nil, fmt.Errorf("error loading template %s: %w", state.Options().ChallengeTemplate, err) } - state.settings.ChallengeTemplate = name + state.opt.ChallengeTemplate = name } - return nil, fmt.Errorf("no template defined for %s", settings.ChallengeTemplate) + return nil, fmt.Errorf("no template defined for %s", state.Options().ChallengeTemplate) } state.networks = make(map[string]cidranger.Ranger) diff --git a/utils/http.go b/utils/http.go index 9765ff5..8e739fb 100644 --- a/utils/http.go +++ b/utils/http.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/http/httputil" + "net/netip" "net/url" "strings" ) @@ -75,6 +76,7 @@ func MakeReverseProxy(target string) (*httputil.ReverseProxy, error) { } transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{} // https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124 if u.Scheme == "unix" { @@ -91,6 +93,7 @@ func MakeReverseProxy(target string) (*httputil.ReverseProxy, error) { } rp := httputil.NewSingleHostReverseProxy(u) + rp.Transport = transport return rp, nil @@ -108,22 +111,44 @@ func GetRequestScheme(r *http.Request) string { return "http" } -func GetRequestAddress(r *http.Request, clientHeader string) net.IP { - var ipStr string +func GetRequestAddress(r *http.Request, clientHeader string) netip.AddrPort { + strVal := r.RemoteAddr + if clientHeader != "" { - ipStr = r.Header.Get(clientHeader) + strVal = r.Header.Get(clientHeader) } - if ipStr != "" { + if strVal != "" { // handle X-Forwarded-For - ipStr = strings.Split(ipStr, ",")[0] + strVal = strings.Split(strVal, ",")[0] } // fallback - if ipStr == "" { - ipStr, _, _ = net.SplitHostPort(r.RemoteAddr) + if strVal == "" { + strVal = r.RemoteAddr } - ipStr = strings.Trim(ipStr, "[]") - return net.ParseIP(ipStr) + + addrPort, err := netip.ParseAddrPort(strVal) + if err != nil { + addr, err2 := netip.ParseAddr(strVal) + if err2 != nil { + return netip.AddrPort{} + } + addrPort = netip.AddrPortFrom(addr, 0) + } + return addrPort +} + +type remoteAddress struct{} + +func SetRemoteAddress(r *http.Request, addrPort netip.AddrPort) *http.Request { + return r.WithContext(context.WithValue(r.Context(), remoteAddress{}, addrPort)) +} +func GetRemoteAddress(ctx context.Context) *netip.AddrPort { + ip, ok := ctx.Value(remoteAddress{}).(netip.AddrPort) + if !ok { + return nil + } + return &ip } func CacheBust() string {