settings: introduce settings YAML file to complement cmd arguments

This commit is contained in:
WeebDataHoarder
2025-04-24 15:25:41 +02:00
parent fc7d67ad70
commit 9541c58eeb
15 changed files with 523 additions and 230 deletions

View File

@@ -4,78 +4,27 @@ import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"encoding/hex"
"errors"
"flag"
"fmt"
"git.gammaspectra.live/git/go-away/lib"
"git.gammaspectra.live/git/go-away/lib/policy"
"git.gammaspectra.live/git/go-away/lib/settings"
"git.gammaspectra.live/git/go-away/utils"
"github.com/pires/go-proxyproto"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
"github.com/goccy/go-yaml"
"log"
"log/slog"
"net"
"net/http"
"net/http/pprof"
"os"
"os/signal"
"path"
"runtime/debug"
"strconv"
"strings"
"sync/atomic"
"syscall"
)
func setupListener(network, address, socketMode string, proxy bool) (net.Listener, string) {
if network == "proxy" {
network = "tcp"
proxy = true
}
formattedAddress := ""
switch network {
case "unix":
formattedAddress = "unix:" + address
case "tcp":
formattedAddress = "http://localhost" + address
default:
formattedAddress = fmt.Sprintf(`(%s) %s`, network, address)
}
listener, err := net.Listen(network, address)
if err != nil {
log.Fatal(fmt.Errorf("failed to bind to %s: %w", formattedAddress, err))
}
// additional permission handling for unix sockets
if network == "unix" {
mode, err := strconv.ParseUint(socketMode, 8, 0)
if err != nil {
listener.Close()
log.Fatal(fmt.Errorf("could not parse socket mode %s: %w", socketMode, err))
}
err = os.Chmod(address, os.FileMode(mode))
if err != nil {
listener.Close()
log.Fatal(fmt.Errorf("could not change socket mode: %w", err))
}
}
if proxy {
slog.Warn("listener PROXY enabled")
formattedAddress += " +PROXY"
listener = &proxyproto.Listener{
Listener: listener,
}
}
return listener, formattedAddress
}
var internalCmdName = "go-away"
var internalMainName = "go-away"
var internalMainVersion = "dev"
@@ -101,40 +50,20 @@ func (v *MultiVar) Set(value string) error {
return nil
}
func newACMEManager(clientDirectory string, backends map[string]http.Handler) *autocert.Manager {
var domains []string
for d := range backends {
parts := strings.Split(d, ":")
d = parts[0]
if net.ParseIP(d) != nil {
continue
}
domains = append(domains, d)
}
manager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(domains...),
Client: &acme.Client{
HTTPClient: http.DefaultClient,
DirectoryURL: clientDirectory,
},
}
return manager
}
func main() {
bind := flag.String("bind", ":8080", "network address to bind HTTP/HTTP(s) to")
bindNetwork := flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp")
bindProxy := flag.Bool("bind-proxy", false, "use PROXY protocol in front of the listener")
socketMode := flag.String("socket-mode", "0770", "socket mode (permissions) for unix domain sockets.")
opt := settings.DefaultSettings
flag.StringVar(&opt.Bind.Address, "bind", opt.Bind.Address, "network address to bind HTTP/HTTP(s) to")
flag.StringVar(&opt.Bind.Network, "bind-network", opt.Bind.Network, "network family to bind HTTP to, e.g. unix, tcp")
flag.BoolVar(&opt.Bind.Proxy, "bind-proxy", opt.Bind.Proxy, "use PROXY protocol in front of the listener")
flag.StringVar(&opt.Bind.SocketMode, "socket-mode", opt.Bind.SocketMode, "socket mode (permissions) for unix domain sockets.")
slogLevel := flag.String("slog-level", "WARN", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
debugMode := flag.Bool("debug", false, "debug mode with logs and server timings")
passThrough := flag.Bool("passthrough", false, "passthrough mode sends all requests to matching backends until state is loaded")
flag.BoolVar(&opt.Bind.Passthrough, "passthrough", opt.Bind.Passthrough, "passthrough mode sends all requests to matching backends until state is loaded")
check := flag.Bool("check", false, "check configuration and policies, then exit")
acmeAutocert := flag.String("acme-autocert", "", "enables HTTP(s) mode and uses the provided ACME server URL or available service (available: letsencrypt)")
flag.StringVar(&opt.Bind.TLSAcmeAutoCert, "acme-autocert", opt.Bind.TLSAcmeAutoCert, "enables HTTP(s) mode and uses the provided ACME server URL or available service (available: letsencrypt)")
clientIpHeader := flag.String("client-ip-header", "", "Client HTTP header to fetch their IP address from (X-Real-Ip, X-Client-Ip, X-Forwarded-For, Cf-Connecting-Ip, etc.)")
backendIpHeader := flag.String("backend-ip-header", "", "Backend HTTP header to set the client IP address from, if empty defaults to leaving Client header alone (X-Real-Ip, X-Client-Ip, X-Forwarded-For, Cf-Connecting-Ip, etc.)")
@@ -143,8 +72,9 @@ func main() {
policyFile := flag.String("policy", "", "path to policy YAML file")
policySnippets := flag.String("policy-snippets", "", "path to YAML snippets folder")
challengeTemplate := flag.String("challenge-template", "anubis", "name or path of the challenge template to use (anubis, forgejo)")
challengeTemplateTheme := flag.String("challenge-template-theme", "", "name of the challenge template theme to use (forgejo => [forgejo-auto, forgejo-dark, forgejo-light, gitea...])")
flag.StringVar(&opt.ChallengeTemplate, "challenge-template", opt.ChallengeTemplate, "name or path of the challenge template to use (anubis, forgejo)")
templateTheme := flag.String("challenge-template-theme", opt.ChallengeTemplateOverrides["Theme"], "name of the challenge template theme to use (forgejo => [forgejo-auto, forgejo-dark, forgejo-light, gitea...])")
packageName := flag.String("package-path", internalCmdName, "package name to expose in .well-known url path")
@@ -153,6 +83,8 @@ func main() {
var backends MultiVar
flag.Var(&backends, "backend", "backend definition in the form of an.example.com=http://backend:1234 (can be specified multiple times)")
settingsFile := flag.String("config", "", "path to config override YAML file")
flag.Parse()
var err error
@@ -176,6 +108,21 @@ func main() {
slog.Info("go-away", "package", internalMainName, "version", internalMainVersion, "cmd", internalCmdName)
// preload missing settings
opt.ChallengeTemplateOverrides["Theme"] = *templateTheme
// load overrides
if *settingsFile != "" {
settingsData, err := os.ReadFile(*settingsFile)
if err != nil {
log.Fatal(fmt.Errorf("could not read settings file: %w", err))
}
err = yaml.Unmarshal(settingsData, &opt)
if err != nil {
log.Fatal(fmt.Errorf("could not parse settings file: %w", err))
}
}
var seed []byte
var kValue string
@@ -207,18 +154,24 @@ func main() {
}
createdBackends := make(map[string]http.Handler)
parsedBackends := make(map[string]string)
for _, backend := range backends {
if backend == "" {
// skip empty to allow no values
continue
}
parts := strings.Split(backend, "=")
if len(parts) != 2 {
log.Fatal(fmt.Errorf("invalid backend definition: %s, expected 2 parts, got %v", backend, parts))
}
parsedBackends[parts[0]] = parts[1]
// make no-settings, default backend
opt.Backends[parts[0]] = settings.Backend{
URL: parts[1],
}
}
for k, v := range parsedBackends {
backend, err := utils.MakeReverseProxy(v)
for k, v := range opt.Backends {
backend, err := v.Create()
if err != nil {
log.Fatal(fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err))
}
@@ -228,10 +181,11 @@ func main() {
}
if len(createdBackends) == 0 {
log.Fatal(fmt.Errorf("no backends defined in policy file"))
log.Fatal(fmt.Errorf("no backends defined in cmdline or settings file"))
}
var cache utils.Cache
var acmeCache string
if *cachePath != "" {
err = os.MkdirAll(*cachePath, 0755)
if err != nil {
@@ -248,29 +202,8 @@ func main() {
if err != nil {
log.Fatal(fmt.Errorf("failed to open cache directory: %w", err))
}
}
var tlsConfig *tls.Config
if *acmeAutocert != "" {
switch *acmeAutocert {
case "letsencrypt":
*acmeAutocert = acme.LetsEncryptURL
}
acmeManager := newACMEManager(*acmeAutocert, createdBackends)
if *cachePath != "" {
err = os.MkdirAll(path.Join(*cachePath, "acme"), 0755)
if err != nil {
log.Fatal(fmt.Errorf("failed to create acme cache directory: %w", err))
}
acmeManager.Cache = autocert.DirCache(path.Join(*cachePath, "acme"))
}
slog.Warn(
"acme-autocert enabled",
"directory", *acmeAutocert,
)
tlsConfig = acmeManager.TLSConfig()
acmeCache = path.Join(*cachePath, "acme")
}
loadPolicyState := func() (http.Handler, error) {
@@ -284,22 +217,19 @@ func main() {
return nil, fmt.Errorf("failed to parse policy file: %w", err)
}
settings := policy.Settings{
Cache: cache,
Backends: createdBackends,
Debug: *debugMode,
MainName: internalMainName,
MainVersion: internalMainVersion,
PackageName: *packageName,
ChallengeTemplate: *challengeTemplate,
ChallengeTemplateTheme: *challengeTemplateTheme,
PrivateKeySeed: seed,
ClientIpHeader: *clientIpHeader,
BackendIpHeader: *backendIpHeader,
ChallengeResponseCode: http.StatusTeapot,
stateSettings := policy.StateSettings{
Cache: cache,
Backends: createdBackends,
MainName: internalMainName,
MainVersion: internalMainVersion,
PackageName: *packageName,
PrivateKeySeed: seed,
ClientIpHeader: *clientIpHeader,
BackendIpHeader: *backendIpHeader,
ChallengeResponseCode: http.StatusTeapot,
}
state, err := lib.NewState(*p, settings)
state, err := lib.NewState(*p, opt, stateSettings)
if err != nil {
return nil, fmt.Errorf("failed to create state: %w", err)
@@ -317,32 +247,15 @@ func main() {
os.Exit(0)
}
listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy)
listener, listenUrl := opt.Bind.Listener()
slog.Warn(
"listening",
"url", listenUrl,
)
var serverHandler atomic.Pointer[http.Handler]
server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler := serverHandler.Load(); handler == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
} else {
(*handler).ServeHTTP(w, r)
}
}), tlsConfig)
if *passThrough {
// setup a passthrough handler temporarily
fn := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backend := utils.SelectHTTPHandler(createdBackends, r.Host)
if backend == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
} else {
backend.ServeHTTP(w, r)
}
}))
serverHandler.Store(&fn)
server, swap, err := opt.Bind.Server(createdBackends, acmeCache)
if err != nil {
log.Fatal(fmt.Errorf("failed to create server: %w", err))
}
go func() {
@@ -351,7 +264,7 @@ func main() {
log.Fatal(fmt.Errorf("failed to load policy state: %w", err))
}
serverHandler.Store(&handler)
swap(handler)
slog.Warn(
"handler configuration loaded",
)
@@ -369,12 +282,34 @@ func main() {
continue
}
serverHandler.Store(&handler)
swap(handler)
slog.Warn("handler configuration reloaded")
}
}()
if tlsConfig != nil {
if opt.BindDebug != "" {
go func() {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
debugServer := http.Server{
Addr: opt.BindDebug,
Handler: mux,
}
slog.Warn(
"listening metrics",
"bind", opt.BindDebug,
)
if err = debugServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}
}()
}
if server.TLSConfig != nil {
if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}

View File

@@ -12,8 +12,8 @@ import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/traits"
"net"
"net/http"
"net/netip"
"net/textproto"
"time"
)
@@ -36,7 +36,7 @@ type RequestData struct {
Time time.Time
ChallengeVerify map[Id]VerifyResult
ChallengeState map[Id]VerifyState
RemoteAddress net.IP
RemoteAddress netip.AddrPort
State StateInterface
CookiePrefix string
@@ -57,7 +57,6 @@ func CreateRequestData(r *http.Request, state StateInterface) (*http.Request, *R
data.ChallengeState = make(map[Id]VerifyState, len(state.GetChallenges()))
data.Time = time.Now().UTC()
data.State = state
data.r = r
data.fp = make(map[string]string, 2)
@@ -85,6 +84,8 @@ func CreateRequestData(r *http.Request, state StateInterface) (*http.Request, *R
data.CookiePrefix = utils.CookiePrefix + hex.EncodeToString(sum.Sum(nil)[:4]) + "-"
r = r.WithContext(context.WithValue(r.Context(), requestDataContextKey{}, &data))
r = utils.SetRemoteAddress(r, data.RemoteAddress)
data.r = r
return r, &data
}
@@ -96,7 +97,7 @@ func (d *RequestData) ResolveName(name string) (any, bool) {
case "method":
return d.r.Method, true
case "remoteAddress":
return d.RemoteAddress, true
return d.RemoteAddress.Addr().AsSlice(), true
case "userAgent":
return d.r.UserAgent(), true
case "path":

View File

@@ -119,9 +119,9 @@ func FillRegistration(state challenge.StateInterface, reg *challenge.Registratio
data := challenge.RequestDataFromContext(r.Context())
result, err := lookup(r.Context(), params.Decay, params.Timeout, dnsbl, decayMap, data.RemoteAddress)
result, err := lookup(r.Context(), params.Decay, params.Timeout, dnsbl, decayMap, data.RemoteAddress.Addr().Unmap().AsSlice())
if err != nil {
data.State.Logger(r).Debug("dnsbl lookup failed", "address", data.RemoteAddress.String(), "result", result, "err", err)
data.State.Logger(r).Debug("dnsbl lookup failed", "address", data.RemoteAddress.Addr().String(), "result", result, "err", err)
}
if result.Bad() {

View File

@@ -47,7 +47,8 @@ func GetChallengeKeyForRequest(state StateInterface, reg *Registration, until ti
hasher.Write([]byte("challenge\x00"))
hasher.Write([]byte(reg.Name))
hasher.Write([]byte{0})
hasher.Write(address.To16())
ipBuf := address.Addr().Unmap().As16()
hasher.Write(ipBuf[:])
hasher.Write([]byte{0})
// specific headers
@@ -72,7 +73,7 @@ func GetChallengeKeyForRequest(state StateInterface, reg *Registration, until ti
sum[0] = 0
if address.To4() != nil {
if address.Addr().Unmap().Is4() {
// Is IPv4, mark
sum.Set(KeyFlagIsIPv4)
}

View File

@@ -3,6 +3,7 @@ package challenge
import (
"crypto/ed25519"
"git.gammaspectra.live/git/go-away/lib/policy"
"git.gammaspectra.live/git/go-away/lib/settings"
"github.com/google/cel-go/cel"
"log/slog"
"net/http"
@@ -106,7 +107,9 @@ type StateInterface interface {
GetChallengeByName(name string) (*Registration, bool)
GetChallenges() Register
Settings() policy.Settings
Settings() policy.StateSettings
Options() settings.Settings
GetBackend(host string) http.Handler
}

View File

@@ -10,10 +10,7 @@ import (
"html/template"
"log/slog"
"net/http"
"net/http/pprof"
"strconv"
"strings"
"time"
)
var templates map[string]*template.Template
@@ -51,17 +48,11 @@ func initTemplate(name, data string) error {
return nil
}
func (state *State) addTiming(w http.ResponseWriter, name, desc string, duration time.Duration) {
if state.Settings().Debug {
w.Header().Add("Server-Timing", fmt.Sprintf("%s;desc=%s;dur=%d", name, strconv.Quote(desc), duration.Milliseconds()))
}
}
func GetLoggerForRequest(r *http.Request) *slog.Logger {
data := challenge.RequestDataFromContext(r.Context())
args := []any{
"request_id", data.Id.String(),
"remote_address", data.RemoteAddress.String(),
"remote_address", data.RemoteAddress.Addr().String(),
"user_agent", r.UserAgent(),
"host", r.Host,
"path", r.URL.Path,
@@ -152,14 +143,6 @@ func (state *State) setupRoutes() error {
state.Mux.HandleFunc("/", state.handleRequest)
if state.Settings().Debug {
//TODO: split this to a different listener, metrics listener
http.HandleFunc(state.urlPath+"/debug/pprof/", pprof.Index)
http.HandleFunc(state.urlPath+"/debug/pprof/profile", pprof.Profile)
http.HandleFunc(state.urlPath+"/debug/pprof/symbol", pprof.Symbol)
http.HandleFunc(state.urlPath+"/debug/pprof/trace", pprof.Trace)
}
state.Mux.Handle("GET "+state.urlPath+"/assets/", http.StripPrefix(state.UrlPath()+"/assets/", gzipped.FileServer(gzipped.FS(embed.AssetsFs))))
for _, reg := range state.challenges {

View File

@@ -5,6 +5,7 @@ import (
"crypto/ed25519"
"git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/policy"
"git.gammaspectra.live/git/go-away/lib/settings"
"git.gammaspectra.live/git/go-away/utils"
"github.com/google/cel-go/cel"
"log/slog"
@@ -72,23 +73,30 @@ func (state *State) ChallengePage(w http.ResponseWriter, r *http.Request, status
input := make(map[string]any)
input["Id"] = data.Id.String()
input["Random"] = utils.CacheBust()
input["Path"] = state.UrlPath()
for k, v := range state.Options().ChallengeTemplateOverrides {
input[k] = v
}
for k, v := range state.Options().Strings {
input["str_"+k] = v
}
if reg != nil {
input["Challenge"] = reg.Name
input["Path"] = state.UrlPath()
}
input["Theme"] = state.Settings().ChallengeTemplateTheme
maps.Copy(input, params)
if _, ok := input["Title"]; !ok {
input["Title"] = "Checking you are not a bot"
input["Title"] = state.Options().Strings.Get("challenge_are_you_bot")
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
buf := bytes.NewBuffer(make([]byte, 0, 8192))
err := templates["challenge-"+state.Settings().ChallengeTemplate+".gohtml"].Execute(buf, input)
err := templates["challenge-"+state.Options().ChallengeTemplate+".gohtml"].Execute(buf, input)
if err != nil {
state.ErrorPage(w, r, http.StatusInternalServerError, err, "")
} else {
@@ -103,17 +111,25 @@ func (state *State) ErrorPage(w http.ResponseWriter, r *http.Request, status int
buf := bytes.NewBuffer(make([]byte, 0, 8192))
err2 := templates["challenge-"+state.Settings().ChallengeTemplate+".gohtml"].Execute(buf, map[string]any{
input := map[string]any{
"Id": data.Id.String(),
"Random": utils.CacheBust(),
"Error": err.Error(),
"Path": state.UrlPath(),
"Theme": state.Settings().ChallengeTemplateTheme,
"Title": "Oh no! " + http.StatusText(status),
"Theme": "",
"Title": state.Options().Strings.Get("error") + " " + http.StatusText(status),
"HideSpinner": true,
"Challenge": "",
"Redirect": redirect,
})
}
for k, v := range state.Options().ChallengeTemplateOverrides {
input[k] = v
}
for k, v := range state.Options().Strings {
input["str_"+k] = v
}
err2 := templates["challenge-"+state.Options().ChallengeTemplate+".gohtml"].Execute(buf, input)
if err2 != nil {
// nested errors!
panic(err2)
@@ -136,10 +152,14 @@ func (state *State) GetChallengeByName(name string) (*challenge.Registration, bo
reg, _, ok := state.challenges.GetByName(name)
return reg, ok
}
func (state *State) Settings() policy.Settings {
func (state *State) Settings() policy.StateSettings {
return state.settings
}
func (state *State) Options() settings.Settings {
return state.opt
}
func (state *State) GetBackend(host string) http.Handler {
return utils.SelectHTTPHandler(state.Settings().Backends, host)
}

View File

@@ -1,22 +0,0 @@
package policy
import (
"git.gammaspectra.live/git/go-away/utils"
"net/http"
)
type Settings struct {
Cache utils.Cache
Backends map[string]http.Handler
PrivateKeySeed []byte
Debug bool
MainName string
MainVersion string
PackageName string
ChallengeTemplate string
ChallengeTemplateTheme string
ClientIpHeader string
BackendIpHeader string
ChallengeResponseCode int
}

19
lib/policy/state.go Normal file
View File

@@ -0,0 +1,19 @@
package policy
import (
"git.gammaspectra.live/git/go-away/utils"
"net/http"
)
type StateSettings struct {
Cache utils.Cache
Backends map[string]http.Handler
PrivateKeySeed []byte
MainName string
MainVersion string
PackageName string
ClientIpHeader string
BackendIpHeader string
ChallengeResponseCode int
}

85
lib/settings/backend.go Normal file
View File

@@ -0,0 +1,85 @@
package settings
import (
"git.gammaspectra.live/git/go-away/utils"
"net/http"
"net/http/httputil"
)
type Backend struct {
// URL Target server backend path. Supports http/https/unix protocols.
URL string `yaml:"url"`
// Host Override the Host header and TLS SNI with this value if specified
Host string `yaml:"host"`
//ProxyProtocol uint8 `yaml:"proxy-protocol"`
// HTTP2Enabled Enable HTTP2 to backend
HTTP2Enabled bool `yaml:"http2-enabled"`
// TLSSkipVerify Disable TLS certificate verification, if any
TLSSkipVerify bool `yaml:"tls-skip-verify"`
}
func (b Backend) Create() (*httputil.ReverseProxy, error) {
proxy, err := utils.MakeReverseProxy(b.URL)
if err != nil {
return nil, err
}
transport := proxy.Transport.(*http.Transport)
if b.HTTP2Enabled {
transport.ForceAttemptHTTP2 = true
}
if b.TLSSkipVerify {
transport.TLSClientConfig.InsecureSkipVerify = true
}
if b.Host != "" {
transport.TLSClientConfig.ServerName = b.Host
director := proxy.Director
proxy.Director = func(req *http.Request) {
req.Host = b.Host
director(req)
}
}
/*if b.ProxyProtocol > 0 {
dialContext := transport.DialContext
if dialContext == nil {
dialContext = (&net.Dialer{}).DialContext
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := dialContext(ctx, network, addr)
if err != nil {
return nil, err
}
addrPort := utils.GetRemoteAddress(ctx)
if addrPort == nil {
// pass as is
hdr := proxyproto.HeaderProxyFromAddrs(b.ProxyProtocol, conn.LocalAddr(), conn.RemoteAddr())
_, err = hdr.WriteTo(conn)
if err != nil {
conn.Close()
return nil, err
}
} else {
// set proper headers!
hdr := proxyproto.HeaderProxyFromAddrs(b.ProxyProtocol, net.TCPAddrFromAddrPort(*addrPort), conn.RemoteAddr())
_, err = hdr.WriteTo(conn)
if err != nil {
conn.Close()
return nil, err
}
}
return conn, nil
}
}*/
proxy.Transport = transport
return proxy, nil
}

169
lib/settings/bind.go Normal file
View File

@@ -0,0 +1,169 @@
package settings
import (
"context"
"crypto/tls"
"fmt"
"git.gammaspectra.live/git/go-away/utils"
"github.com/pires/go-proxyproto"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
"log"
"log/slog"
"net"
"net/http"
"os"
"strconv"
"sync/atomic"
)
type Bind struct {
Address string `yaml:"address"`
Network string `yaml:"network"`
SocketMode string `yaml:"socket-mode"`
Proxy bool `yaml:"proxy"`
Passthrough bool `yaml:"passthrough"`
// TLSAcmeAutoCert URL to ACME directory, or letsencrypt
TLSAcmeAutoCert string `yaml:"tls-acme-autocert"`
// TLSCertificate Alternate to TLSAcmeAutoCert
TLSCertificate string `yaml:"tls-certificate"`
// TLSPrivateKey Alternate to TLSAcmeAutoCert
TLSPrivateKey string `yaml:"tls-key"`
}
func (b *Bind) Listener() (net.Listener, string) {
return setupListener(b.Network, b.Address, b.SocketMode, b.Proxy)
}
func (b *Bind) Server(backends map[string]http.Handler, acmeCachePath string) (*http.Server, func(http.Handler), error) {
var tlsConfig *tls.Config
if b.TLSAcmeAutoCert != "" {
switch b.TLSAcmeAutoCert {
case "letsencrypt":
b.TLSAcmeAutoCert = acme.LetsEncryptURL
}
acmeManager := newACMEManager(b.TLSAcmeAutoCert, backends)
if acmeCachePath != "" {
err := os.MkdirAll(acmeCachePath, 0755)
if err != nil {
return nil, nil, fmt.Errorf("failed to create acme cache directory: %w", err)
}
acmeManager.Cache = autocert.DirCache(acmeCachePath)
}
slog.Warn(
"acme-autocert enabled",
"directory", b.TLSAcmeAutoCert,
)
tlsConfig = acmeManager.TLSConfig()
} else if b.TLSCertificate != "" && b.TLSPrivateKey != "" {
tlsConfig = &tls.Config{}
var err error
tlsConfig.Certificates = make([]tls.Certificate, 1)
tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(b.TLSCertificate, b.TLSPrivateKey)
if err != nil {
return nil, nil, err
}
slog.Warn(
"TLS enabled",
"certificate", b.TLSCertificate,
)
}
var serverHandler atomic.Pointer[http.Handler]
server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler := serverHandler.Load(); handler == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
} else {
(*handler).ServeHTTP(w, r)
}
}), tlsConfig)
swap := func(handler http.Handler) {
serverHandler.Store(&handler)
}
if b.Passthrough {
// setup a passthrough handler temporarily
swap(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backend := utils.SelectHTTPHandler(backends, r.Host)
if backend == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
} else {
backend.ServeHTTP(w, r)
}
})))
}
return server, swap, nil
}
func setupListener(network, address, socketMode string, proxy bool) (net.Listener, string) {
if network == "proxy" {
network = "tcp"
proxy = true
}
formattedAddress := ""
switch network {
case "unix":
formattedAddress = "unix:" + address
case "tcp":
formattedAddress = "http://localhost" + address
default:
formattedAddress = fmt.Sprintf(`(%s) %s`, network, address)
}
listener, err := net.Listen(network, address)
if err != nil {
log.Fatal(fmt.Errorf("failed to bind to %s: %w", formattedAddress, err))
}
// additional permission handling for unix sockets
if network == "unix" {
mode, err := strconv.ParseUint(socketMode, 8, 0)
if err != nil {
listener.Close()
log.Fatal(fmt.Errorf("could not parse socket mode %s: %w", socketMode, err))
}
err = os.Chmod(address, os.FileMode(mode))
if err != nil {
listener.Close()
log.Fatal(fmt.Errorf("could not change socket mode: %w", err))
}
}
if proxy {
slog.Warn("listener PROXY enabled")
formattedAddress += " +PROXY"
listener = &proxyproto.Listener{
Listener: listener,
}
}
return listener, formattedAddress
}
func newACMEManager(clientDirectory string, backends map[string]http.Handler) *autocert.Manager {
manager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostPolicy(func(ctx context.Context, host string) error {
if utils.SelectHTTPHandler(backends, host) != nil {
return nil
}
return fmt.Errorf("acme/autocert: host %s not configured in backends", host)
}),
Client: &acme.Client{
HTTPClient: http.DefaultClient,
DirectoryURL: clientDirectory,
},
}
return manager
}

51
lib/settings/settings.go Normal file
View File

@@ -0,0 +1,51 @@
package settings
import "maps"
type Settings struct {
Bind Bind `json:"bind"`
Backends map[string]Backend `json:"backends"`
BindDebug string `json:"bind-debug"`
BindMetrics string `json:"bind-metrics"`
Strings Strings `yaml:"strings"`
// Links to add to challenge/error pages like privacy/impressum.
Links []Link `yaml:"links"`
ChallengeTemplate string `yaml:"challenge-template"`
// ChallengeTemplateOverrides Key/Value overrides for the current chosen template
// Replacements TODO:
// Path -> go-away path
ChallengeTemplateOverrides map[string]string `yaml:"challenge-template-overrides"`
}
type Link struct {
Name string `yaml:"name"`
URL string `yaml:"url"`
}
var DefaultSettings = Settings{
Strings: DefaultStrings,
ChallengeTemplate: "anubis",
ChallengeTemplateOverrides: func() map[string]string {
m := make(map[string]string)
maps.Copy(m, map[string]string{
"Theme": "",
"Logo": "",
})
return m
}(),
Bind: Bind{
Address: ":8080",
Network: "tcp",
SocketMode: "0770",
Proxy: false,
TLSAcmeAutoCert: "",
},
Backends: make(map[string]Backend),
}

24
lib/settings/strings.go Normal file
View File

@@ -0,0 +1,24 @@
package settings
import "maps"
type Strings map[string]string
var DefaultStrings = make(Strings).set(map[string]string{
"challenge_are_you_bot": "Checking you are not a bot",
"error": "Oh no!",
})
func (s Strings) set(v map[string]string) Strings {
maps.Copy(s, v)
return s
}
func (s Strings) Get(value string) string {
v, ok := (s)[value]
if !ok {
// fallback
return "string:" + value
}
return v
}

View File

@@ -8,6 +8,7 @@ import (
"git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/condition"
"git.gammaspectra.live/git/go-away/lib/policy"
"git.gammaspectra.live/git/go-away/lib/settings"
"git.gammaspectra.live/git/go-away/utils"
"github.com/google/cel-go/cel"
"github.com/yl2chen/cidranger"
@@ -31,7 +32,8 @@ type State struct {
publicKey ed25519.PublicKey
privateKey ed25519.PrivateKey
settings policy.Settings
opt settings.Settings
settings policy.StateSettings
networks map[string]cidranger.Ranger
@@ -44,10 +46,11 @@ type State struct {
Mux *http.ServeMux
}
func NewState(p policy.Policy, settings policy.Settings) (handler http.Handler, err error) {
func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (handler http.Handler, err error) {
state := new(State)
state.close = make(chan struct{})
state.settings = settings
state.opt = opt
state.client = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
@@ -89,22 +92,18 @@ func NewState(p policy.Policy, settings policy.Settings) (handler http.Handler,
}
}
if state.Settings().ChallengeTemplate == "" {
state.settings.ChallengeTemplate = "anubis"
}
if templates["challenge-"+state.Options().ChallengeTemplate+".gohtml"] == nil {
if templates["challenge-"+state.Settings().ChallengeTemplate+".gohtml"] == nil {
if data, err := os.ReadFile(state.Settings().ChallengeTemplate); err == nil && len(data) > 0 {
name := path.Base(state.Settings().ChallengeTemplate)
if data, err := os.ReadFile(state.Options().ChallengeTemplate); err == nil && len(data) > 0 {
name := path.Base(state.Options().ChallengeTemplate)
err := initTemplate(name, string(data))
if err != nil {
return nil, fmt.Errorf("error loading template %s: %w", settings.ChallengeTemplate, err)
return nil, fmt.Errorf("error loading template %s: %w", state.Options().ChallengeTemplate, err)
}
state.settings.ChallengeTemplate = name
state.opt.ChallengeTemplate = name
}
return nil, fmt.Errorf("no template defined for %s", settings.ChallengeTemplate)
return nil, fmt.Errorf("no template defined for %s", state.Options().ChallengeTemplate)
}
state.networks = make(map[string]cidranger.Ranger)

View File

@@ -10,6 +10,7 @@ import (
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"strings"
)
@@ -75,6 +76,7 @@ func MakeReverseProxy(target string) (*httputil.ReverseProxy, error) {
}
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{}
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
if u.Scheme == "unix" {
@@ -91,6 +93,7 @@ func MakeReverseProxy(target string) (*httputil.ReverseProxy, error) {
}
rp := httputil.NewSingleHostReverseProxy(u)
rp.Transport = transport
return rp, nil
@@ -108,22 +111,44 @@ func GetRequestScheme(r *http.Request) string {
return "http"
}
func GetRequestAddress(r *http.Request, clientHeader string) net.IP {
var ipStr string
func GetRequestAddress(r *http.Request, clientHeader string) netip.AddrPort {
strVal := r.RemoteAddr
if clientHeader != "" {
ipStr = r.Header.Get(clientHeader)
strVal = r.Header.Get(clientHeader)
}
if ipStr != "" {
if strVal != "" {
// handle X-Forwarded-For
ipStr = strings.Split(ipStr, ",")[0]
strVal = strings.Split(strVal, ",")[0]
}
// fallback
if ipStr == "" {
ipStr, _, _ = net.SplitHostPort(r.RemoteAddr)
if strVal == "" {
strVal = r.RemoteAddr
}
ipStr = strings.Trim(ipStr, "[]")
return net.ParseIP(ipStr)
addrPort, err := netip.ParseAddrPort(strVal)
if err != nil {
addr, err2 := netip.ParseAddr(strVal)
if err2 != nil {
return netip.AddrPort{}
}
addrPort = netip.AddrPortFrom(addr, 0)
}
return addrPort
}
type remoteAddress struct{}
func SetRemoteAddress(r *http.Request, addrPort netip.AddrPort) *http.Request {
return r.WithContext(context.WithValue(r.Context(), remoteAddress{}, addrPort))
}
func GetRemoteAddress(ctx context.Context) *netip.AddrPort {
ip, ok := ctx.Value(remoteAddress{}).(netip.AddrPort)
if !ok {
return nil
}
return &ip
}
func CacheBust() string {