From 150927e7ba88d7548481e021e33dce9907397720 Mon Sep 17 00:00:00 2001 From: WeebDataHoarder <57538841+WeebDataHoarder@users.noreply.github.com> Date: Wed, 2 Apr 2025 19:23:09 +0200 Subject: [PATCH] Allow multiple backends --- cmd/away.go | 49 +++------------------------------------- lib/http.go | 54 ++++++++++++++++++++++++++++++++++++++++---- lib/policy/policy.go | 2 ++ lib/policy/rule.go | 1 + lib/state.go | 23 ++++++++++++++++--- policy.yml | 16 +++++++++++++ 6 files changed, 92 insertions(+), 53 deletions(-) diff --git a/cmd/away.go b/cmd/away.go index f354147..d575084 100644 --- a/cmd/away.go +++ b/cmd/away.go @@ -1,52 +1,20 @@ package main import ( - "context" "errors" "flag" "fmt" "git.gammaspectra.live/git/go-away/lib" - "git.gammaspectra.live/git/go-away/lib/network" "git.gammaspectra.live/git/go-away/lib/policy" "gopkg.in/yaml.v3" "log" "log/slog" "net" "net/http" - "net/http/httputil" - "net/url" "os" "strconv" ) -func makeReverseProxy(target string) (http.Handler, error) { - u, err := url.Parse(target) - if err != nil { - return nil, fmt.Errorf("failed to parse target URL: %w", err) - } - - transport := http.DefaultTransport.(*http.Transport).Clone() - - // https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124 - if u.Scheme == "unix" { - // clean path up so we don't use the socket path in proxied requests - addr := u.Path - u.Path = "" - // tell transport how to dial unix sockets - transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - dialer := net.Dialer{} - return dialer.DialContext(ctx, "unix", addr) - } - // tell transport how to handle the unix url scheme - transport.RegisterProtocol("unix", network.UnixRoundTripper{Transport: transport}) - } - - rp := httputil.NewSingleHostReverseProxy(u) - rp.Transport = transport - - return rp, nil -} - func setupListener(network, address, socketMode string) (net.Listener, string) { formattedAddress := "" switch network { @@ -88,15 +56,11 @@ func main() { slogLevel := flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)") - target := flag.String("target", "http://localhost:80", "target to reverse proxy to") - policyFile := flag.String("policy", "", "path to policy YAML file") challengeTemplate := flag.String("challenge-template", "anubis", "name of the challenge template to use") flag.Parse() - _, _, _, _ = bind, bindNetwork, socketMode, target - { var programLevel slog.Level if err := (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil { @@ -119,19 +83,13 @@ func main() { log.Fatal(fmt.Errorf("failed to read policy file: %w", err)) } - var policy policy.Policy + var p policy.Policy - if err = yaml.Unmarshal(policyData, &policy); err != nil { + if err = yaml.Unmarshal(policyData, &p); err != nil { log.Fatal(fmt.Errorf("failed to parse policy file: %w", err)) } - backend, err := makeReverseProxy(*target) - if err != nil { - log.Fatal(fmt.Errorf("failed to create reverse proxy for %s: %w", *target, err)) - } - - state, err := lib.NewState(policy, lib.StateSettings{ - Backend: backend, + state, err := lib.NewState(p, lib.StateSettings{ PackagePath: "git.gammaspectra.live/git/go-away/cmd", ChallengeTemplate: *challengeTemplate, }) @@ -144,7 +102,6 @@ func main() { slog.Info( "listening", "url", listenUrl, - "target", *target, ) server := http.Server{ diff --git a/lib/http.go b/lib/http.go index 8d76a16..02298ba 100644 --- a/lib/http.go +++ b/lib/http.go @@ -3,16 +3,21 @@ package lib import ( "bytes" "codeberg.org/meta/gzipped/v2" + "context" "crypto/rand" "encoding/base64" "errors" "fmt" go_away "git.gammaspectra.live/git/go-away" + "git.gammaspectra.live/git/go-away/lib/network" "git.gammaspectra.live/git/go-away/lib/policy" "github.com/google/cel-go/common/types" "html/template" "maps" + "net" "net/http" + "net/http/httputil" + "net/url" "path/filepath" "strings" "time" @@ -54,6 +59,34 @@ func init() { } } +func makeReverseProxy(target string) (http.Handler, error) { + u, err := url.Parse(target) + if err != nil { + return nil, fmt.Errorf("failed to parse target URL: %w", err) + } + + transport := http.DefaultTransport.(*http.Transport).Clone() + + // https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124 + if u.Scheme == "unix" { + // clean path up so we don't use the socket path in proxied requests + addr := u.Path + u.Path = "" + // tell transport how to dial unix sockets + transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, "unix", addr) + } + // tell transport how to handle the unix url scheme + transport.RegisterProtocol("unix", network.UnixRoundTripper{Transport: transport}) + } + + rp := httputil.NewSingleHostReverseProxy(u) + rp.Transport = transport + + return rp, nil +} + func (state *State) challengePage(w http.ResponseWriter, status int, challenge string, params map[string]any) error { input := make(map[string]any) input["Random"] = cacheBust @@ -104,8 +137,17 @@ func (state *State) errorPage(w http.ResponseWriter, status int, err error) erro func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { + host := r.Host + + backend, ok := state.Backends[host] + if !ok { + http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) + return + } + //TODO better matcher! combo ast? env := map[string]any{ + "host": host, "method": r.Method, "remoteAddress": state.GetRequestAddress(r), "userAgent": r.UserAgent(), @@ -127,6 +169,10 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { } for _, rule := range state.Rules { + // skip rules that have host match + if rule.Host != nil && *rule.Host != host { + continue + } if out, _, err := rule.Program.Eval(env); err != nil { //TODO error panic(err) @@ -136,7 +182,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { default: panic(fmt.Errorf("unknown action %s", rule.Action)) case policy.RuleActionPASS: - state.Backend.ServeHTTP(w, r) + backend.ServeHTTP(w, r) return case policy.RuleActionCHALLENGE, policy.RuleActionCHECK: expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity) @@ -154,7 +200,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { } // we passed the challenge! //TODO log? - state.Backend.ServeHTTP(w, r) + backend.ServeHTTP(w, r) return } } @@ -174,7 +220,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { goto nextRule } // we pass the challenge early! - state.Backend.ServeHTTP(w, r) + backend.ServeHTTP(w, r) return } } else { @@ -197,7 +243,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { nextRule: } - state.Backend.ServeHTTP(w, r) + backend.ServeHTTP(w, r) return } diff --git a/lib/policy/policy.go b/lib/policy/policy.go index 41e5377..9a824b2 100644 --- a/lib/policy/policy.go +++ b/lib/policy/policy.go @@ -43,4 +43,6 @@ type Policy struct { Challenges map[string]Challenge `yaml:"challenges"` Rules []Rule `yaml:"rules"` + + Backends map[string]string `json:"backends"` } diff --git a/lib/policy/rule.go b/lib/policy/rule.go index 526de67..874d727 100644 --- a/lib/policy/rule.go +++ b/lib/policy/rule.go @@ -12,6 +12,7 @@ const ( type Rule struct { Name string `yaml:"name"` + Host *string `yaml:"host"` Conditions []string `yaml:"conditions"` Action string `yaml:"action"` diff --git a/lib/state.go b/lib/state.go index a019e95..d450fa3 100644 --- a/lib/state.go +++ b/lib/state.go @@ -40,7 +40,7 @@ type State struct { Settings StateSettings UrlPath string Mux *http.ServeMux - Backend http.Handler + Backends map[string]http.Handler Networks map[string]cidranger.Ranger @@ -61,6 +61,8 @@ type RuleState struct { Name string Hash string + Host *string + Program cel.Program Action policy.RuleAction Challenges []string @@ -94,7 +96,6 @@ type ChallengeState struct { } type StateSettings struct { - Backend http.Handler PackagePath string ChallengeTemplate string } @@ -108,7 +109,16 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) }, } state.UrlPath = "/.well-known/." + state.Settings.PackagePath - state.Backend = settings.Backend + + state.Backends = make(map[string]http.Handler) + + for k, v := range p.Backends { + backend, err := makeReverseProxy(v) + if err != nil { + return nil, fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err) + } + state.Backends[k] = backend + } state.PublicKey, state.PrivateKey, err = ed25519.GenerateKey(rand.Reader) if err != nil { @@ -492,6 +502,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) state.RulesEnv, err = cel.NewEnv( cel.DefaultUTCTimeZone(true), cel.Variable("remoteAddress", cel.BytesType), + cel.Variable("host", cel.StringType), cel.Variable("method", cel.StringType), cel.Variable("userAgent", cel.StringType), cel.Variable("path", cel.StringType), @@ -565,12 +576,18 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) for _, rule := range p.Rules { hasher := sha256.New() hasher.Write([]byte(rule.Name)) + hasher.Write([]byte{0}) + if rule.Host != nil { + hasher.Write([]byte(*rule.Host)) + } + hasher.Write([]byte{0}) hasher.Write(privateKeyFingerprint[:]) sum := hasher.Sum(nil) r := RuleState{ Name: rule.Name, Hash: hex.EncodeToString(sum[:8]), + Host: rule.Host, Action: policy.RuleAction(strings.ToUpper(rule.Action)), Challenges: rule.Challenges, } diff --git a/policy.yml b/policy.yml index 28ce40c..d95f271 100644 --- a/policy.yml +++ b/policy.yml @@ -1,4 +1,8 @@ +# Define backends to use. Rules can be done generally, or only applying to specific hosts +backends: + git.gammaspectra.live: http://gitea:3000 + # Define networks to be used later below networks: # todo: support direct ASN lookups @@ -218,6 +222,18 @@ conditions: # user activity tab - 'path.matches("^/[^/]") && "tab" in query && query.tab == "activity"' +# Rules and conditions are served this environment +# remoteAddress (net.IP) - Connecting client remote address from headers or properties +# host (string) - HTTP Host +# method (string) - HTTP Method/Verb +# userAgent (string) - HTTP User-Agent header +# path (string) - HTTP request Path +# query (map[string]string) - HTTP request Query arguments +# headers (map[string]string) - HTTP request headers +# +# Additionally these functions are available +# inNetwork(networkName string, address net.IP) bool +# inNetwork(networkCIDR string, address net.IP) bool rules: - name: undesired-networks conditions: