Move most code under lib

This commit is contained in:
WeebDataHoarder
2025-04-01 21:22:19 +02:00
parent cccc06cb54
commit df5e125cf2
13 changed files with 155 additions and 134 deletions

View File

@@ -1 +0,0 @@
package go_away

View File

@@ -5,7 +5,9 @@ import (
"errors"
"flag"
"fmt"
go_away "git.gammaspectra.live/git/go-away"
"git.gammaspectra.live/git/go-away/lib"
"git.gammaspectra.live/git/go-away/lib/network"
"git.gammaspectra.live/git/go-away/lib/policy"
"gopkg.in/yaml.v3"
"log"
"log/slog"
@@ -36,7 +38,7 @@ func makeReverseProxy(target string) (http.Handler, error) {
return dialer.DialContext(ctx, "unix", addr)
}
// tell transport how to handle the unix url scheme
transport.RegisterProtocol("unix", go_away.UnixRoundTripper{Transport: transport})
transport.RegisterProtocol("unix", network.UnixRoundTripper{Transport: transport})
}
rp := httputil.NewSingleHostReverseProxy(u)
@@ -116,7 +118,7 @@ func main() {
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
}
var policy go_away.Policy
var policy policy.Policy
if err = yaml.Unmarshal(policyData, &policy); err != nil {
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
@@ -127,7 +129,7 @@ func main() {
log.Fatal(fmt.Errorf("failed to create reverse proxy for %s: %w", *target, err))
}
state, err := go_away.NewState(policy, "git.gammaspectra.live/git/go-away/cmd", backend)
state, err := lib.NewState(policy, "git.gammaspectra.live/git/go-away/cmd", backend)
if err != nil {
log.Fatal(fmt.Errorf("failed to create state: %w", err))

12
embed.go Normal file
View File

@@ -0,0 +1,12 @@
package go_away
import "embed"
//go:embed assets
var AssetsFs embed.FS
//go:embed challenge
var ChallengeFs embed.FS
//go:embed templates
var TemplatesFs embed.FS

View File

@@ -1,4 +1,4 @@
package go_away
package lib
import (
"bytes"

View File

@@ -1,4 +1,4 @@
package go_away
package condition
import (
"fmt"
@@ -15,7 +15,7 @@ const (
OperatorAnd = "&&"
)
func ConditionFromStrings(env *cel.Env, operator string, conditions ...string) (*cel.Ast, error) {
func FromStrings(env *cel.Env, operator string, conditions ...string) (*cel.Ast, error) {
var asts []*cel.Ast
for _, c := range conditions {
ast, issues := env.Compile(c)
@@ -25,10 +25,10 @@ func ConditionFromStrings(env *cel.Env, operator string, conditions ...string) (
asts = append(asts, ast)
}
return MergeConditions(env, operator, asts...)
return Merge(env, operator, asts...)
}
func MergeConditions(env *cel.Env, operator string, conditions ...*cel.Ast) (*cel.Ast, error) {
func Merge(env *cel.Env, operator string, conditions ...*cel.Ast) (*cel.Ast, error) {
if len(conditions) == 0 {
return nil, nil
} else if len(conditions) == 1 {

View File

@@ -1,4 +1,4 @@
package go_away
package lib
import (
"net/http"

View File

@@ -1,12 +1,13 @@
package go_away
package lib
import (
"codeberg.org/meta/gzipped/v2"
"crypto/rand"
"embed"
"encoding/base64"
"errors"
"fmt"
go_away "git.gammaspectra.live/git/go-away"
"git.gammaspectra.live/git/go-away/lib/policy"
"github.com/google/cel-go/common/types"
"html/template"
"net/http"
@@ -15,15 +16,6 @@ import (
"time"
)
//go:embed assets
var assetsFs embed.FS
//go:embed challenge
var challengesFs embed.FS
//go:embed templates
var templatesFs embed.FS
var templates map[string]*template.Template
var cacheBust string
@@ -39,7 +31,7 @@ func init() {
templates = make(map[string]*template.Template)
dir, err := templatesFs.ReadDir("templates")
dir, err := go_away.TemplatesFs.ReadDir("templates")
if err != nil {
panic(err)
}
@@ -47,7 +39,7 @@ func init() {
if e.IsDir() {
continue
}
data, err := templatesFs.ReadFile(filepath.Join("templates", e.Name()))
data, err := go_away.TemplatesFs.ReadFile(filepath.Join("templates", e.Name()))
if err != nil {
panic(err)
}
@@ -92,10 +84,10 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
switch rule.Action {
default:
panic(fmt.Errorf("unknown action %s", rule.Action))
case PolicyRuleActionPASS:
case policy.RuleActionPASS:
state.Backend.ServeHTTP(w, r)
return
case PolicyRuleActionCHALLENGE, PolicyRuleActionCHECK:
case policy.RuleActionCHALLENGE, policy.RuleActionCHECK:
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
for _, challengeName := range rule.Challenges {
@@ -106,7 +98,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
ClearCookie(CookiePrefix+challengeName, w)
}
} else {
if rule.Action == PolicyRuleActionCHECK {
if rule.Action == policy.RuleActionCHECK {
goto nextRule
}
// we passed the challenge!
@@ -127,7 +119,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
case ChallengeResultContinue:
continue
case ChallengeResultPass:
if rule.Action == PolicyRuleActionCHECK {
if rule.Action == policy.RuleActionCHECK {
goto nextRule
}
// we pass the challenge early!
@@ -138,11 +130,11 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
panic("challenge not found")
}
}
case PolicyRuleActionDENY:
case policy.RuleActionDENY:
//TODO: config error code
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
case PolicyRuleActionBLOCK:
case policy.RuleActionBLOCK:
//TODO: config error code
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
@@ -161,7 +153,7 @@ func (state *State) setupRoutes() error {
state.Mux.HandleFunc("/", state.handleRequest)
state.Mux.Handle("GET "+state.UrlPath+"/assets/", http.StripPrefix(state.UrlPath, gzipped.FileServer(gzipped.FS(assetsFs))))
state.Mux.Handle("GET "+state.UrlPath+"/assets/", http.StripPrefix(state.UrlPath, gzipped.FileServer(gzipped.FS(go_away.AssetsFs))))
for challengeName, c := range state.Challenges {
if c.Static != nil {
@@ -213,19 +205,3 @@ func (state *State) setupRoutes() error {
return nil
}
// UnixRoundTripper https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124
type UnixRoundTripper struct {
Transport *http.Transport
}
// RoundTrip set bare minimum stuff
func (t UnixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
if req.Host == "" {
req.Host = "localhost"
}
req.URL.Host = req.Host // proxy error: no Host in request URL
req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion
return t.Transport.RoundTrip(req)
}

19
lib/network/unix.go Normal file
View File

@@ -0,0 +1,19 @@
package network
import "net/http"
// UnixRoundTripper https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124
type UnixRoundTripper struct {
Transport *http.Transport
}
// RoundTrip set bare minimum stuff
func (t UnixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
if req.Host == "" {
req.Host = "localhost"
}
req.URL.Host = req.Host // proxy error: no Host in request URL
req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion
return t.Transport.RoundTrip(req)
}

14
lib/policy/challenge.go Normal file
View File

@@ -0,0 +1,14 @@
package policy
type Challenge struct {
Mode string `yaml:"mode"`
Asset *string `yaml:"asset,omitempty"`
Url *string `yaml:"url,omitempty"`
Parameters map[string]string `json:"parameters,omitempty"`
Runtime struct {
Mode string `yaml:"mode,omitempty"`
Asset string `yaml:"asset,omitempty"`
Probability float64 `yaml:"probability,omitempty"`
} `yaml:"runtime"`
}

View File

@@ -1,4 +1,4 @@
package go_away
package policy
import (
"encoding/json"
@@ -12,78 +12,7 @@ import (
"regexp"
)
func parseCIDROrIP(value string) (net.IPNet, error) {
_, ipNet, err := net.ParseCIDR(value)
if err != nil {
ip := net.ParseIP(value)
if ip == nil {
return net.IPNet{}, fmt.Errorf("failed to parse CIDR: %s", err)
}
if ip4 := ip.To4(); ip4 != nil {
return net.IPNet{
IP: ip4,
// single ip
Mask: net.CIDRMask(len(ip4)*8, len(ip4)*8),
}, nil
}
return net.IPNet{
IP: ip,
// single ip
Mask: net.CIDRMask(len(ip)*8, len(ip)*8),
}, nil
} else if ipNet != nil {
return *ipNet, nil
} else {
return net.IPNet{}, errors.New("invalid CIDR")
}
}
type Policy struct {
// Networks map of networks and prefixes to be loaded
Networks map[string][]PolicyNetwork `yaml:"networks"`
Conditions map[string][]string `yaml:"conditions"`
Challenges map[string]PolicyChallenge `yaml:"challenges"`
Rules []PolicyRule `yaml:"rules"`
}
type PolicyRuleAction string
const (
PolicyRuleActionPASS PolicyRuleAction = "PASS"
PolicyRuleActionDENY PolicyRuleAction = "DENY"
PolicyRuleActionBLOCK PolicyRuleAction = "BLOCK"
PolicyRuleActionCHALLENGE PolicyRuleAction = "CHALLENGE"
PolicyRuleActionCHECK PolicyRuleAction = "CHECK"
)
type PolicyRule struct {
Name string `yaml:"name"`
Conditions []string `yaml:"conditions"`
Action string `yaml:"action"`
Challenges []string `yaml:"challenges"`
}
type PolicyChallenge struct {
Mode string `yaml:"mode"`
Asset *string `yaml:"asset,omitempty"`
Url *string `yaml:"url,omitempty"`
Parameters map[string]string `json:"parameters,omitempty"`
Runtime struct {
Mode string `yaml:"mode,omitempty"`
Asset string `yaml:"asset,omitempty"`
Probability float64 `yaml:"probability,omitempty"`
} `yaml:"runtime"`
}
type PolicyNetwork struct {
type Network struct {
Url *string `yaml:"url,omitempty"`
File *string `yaml:"file,omitempty"`
@@ -93,7 +22,7 @@ type PolicyNetwork struct {
Prefixes []string `yaml:"prefixes,omitempty"`
}
func (n PolicyNetwork) FetchPrefixes(c *http.Client) (output []net.IPNet, err error) {
func (n Network) FetchPrefixes(c *http.Client) (output []net.IPNet, err error) {
if len(n.Prefixes) > 0 {
for _, prefix := range n.Prefixes {
ipNet, err := parseCIDROrIP(prefix)

46
lib/policy/policy.go Normal file
View File

@@ -0,0 +1,46 @@
package policy
import (
"errors"
"fmt"
"net"
)
func parseCIDROrIP(value string) (net.IPNet, error) {
_, ipNet, err := net.ParseCIDR(value)
if err != nil {
ip := net.ParseIP(value)
if ip == nil {
return net.IPNet{}, fmt.Errorf("failed to parse CIDR: %s", err)
}
if ip4 := ip.To4(); ip4 != nil {
return net.IPNet{
IP: ip4,
// single ip
Mask: net.CIDRMask(len(ip4)*8, len(ip4)*8),
}, nil
}
return net.IPNet{
IP: ip,
// single ip
Mask: net.CIDRMask(len(ip)*8, len(ip)*8),
}, nil
} else if ipNet != nil {
return *ipNet, nil
} else {
return net.IPNet{}, errors.New("invalid CIDR")
}
}
type Policy struct {
// Networks map of networks and prefixes to be loaded
Networks map[string][]Network `yaml:"networks"`
Conditions map[string][]string `yaml:"conditions"`
Challenges map[string]Challenge `yaml:"challenges"`
Rules []Rule `yaml:"rules"`
}

20
lib/policy/rule.go Normal file
View File

@@ -0,0 +1,20 @@
package policy
type RuleAction string
const (
RuleActionPASS RuleAction = "PASS"
RuleActionDENY RuleAction = "DENY"
RuleActionBLOCK RuleAction = "BLOCK"
RuleActionCHALLENGE RuleAction = "CHALLENGE"
RuleActionCHECK RuleAction = "CHECK"
)
type Rule struct {
Name string `yaml:"name"`
Conditions []string `yaml:"conditions"`
Action string `yaml:"action"`
Challenges []string `yaml:"challenges"`
}

View File

@@ -1,4 +1,4 @@
package go_away
package lib
import (
"codeberg.org/meta/gzipped/v2"
@@ -10,8 +10,11 @@ import (
"encoding/json"
"errors"
"fmt"
go_away "git.gammaspectra.live/git/go-away"
"git.gammaspectra.live/git/go-away/challenge"
"git.gammaspectra.live/git/go-away/challenge/inline"
"git.gammaspectra.live/git/go-away/lib/condition"
"git.gammaspectra.live/git/go-away/lib/policy"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
@@ -56,7 +59,7 @@ type RuleState struct {
Name string
Program cel.Program
Action PolicyRuleAction
Action policy.RuleAction
Continue bool
Challenges []string
}
@@ -88,7 +91,7 @@ type ChallengeState struct {
Verify func(key []byte, result string) (bool, error)
}
func NewState(policy Policy, packagePath string, backend http.Handler) (state *State, err error) {
func NewState(p policy.Policy, packagePath string, backend http.Handler) (state *State, err error) {
state = new(State)
state.Client = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
@@ -100,7 +103,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
state.Backend = backend
state.Networks = make(map[string]cidranger.Ranger)
for k, network := range policy.Networks {
for k, network := range p.Networks {
ranger := cidranger.NewPCTrieRanger()
for _, e := range network {
if e.Url != nil {
@@ -129,7 +132,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
state.Challenges = make(map[string]ChallengeState)
for challengeName, p := range policy.Challenges {
for challengeName, p := range p.Challenges {
c := ChallengeState{
Path: fmt.Sprintf("%s/challenge/%s", state.UrlPath, challengeName),
VerifyProbability: p.Runtime.Probability,
@@ -143,7 +146,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
}
assetPath := c.Path + "/static/"
subFs, err := fs.Sub(challengesFs, fmt.Sprintf("challenge/%s/static", challengeName))
subFs, err := fs.Sub(go_away.ChallengeFs, fmt.Sprintf("challenge/%s/static", challengeName))
if err == nil {
c.Static = http.StripPrefix(
assetPath,
@@ -294,6 +297,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
w.WriteHeader(http.StatusOK)
params, _ := json.Marshal(p.Parameters)
context.Background()
err := templates["challenge.mjs"].Execute(w, map[string]any{
"Path": c.Path,
@@ -336,7 +340,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
}
case "wasm":
wasmData, err := challengesFs.ReadFile(fmt.Sprintf("challenge/%s/runtime/%s", challengeName, p.Runtime.Asset))
wasmData, err := go_away.ChallengeFs.ReadFile(fmt.Sprintf("challenge/%s/runtime/%s", challengeName, p.Runtime.Asset))
if err != nil {
return nil, fmt.Errorf("c %s: could not load runtime: %w", challengeName, err)
}
@@ -463,8 +467,8 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
}
var replacements []string
for k, entries := range policy.Conditions {
ast, err := ConditionFromStrings(state.RulesEnv, OperatorOr, entries...)
for k, entries := range p.Conditions {
ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, entries...)
if err != nil {
return nil, fmt.Errorf("conditions %s: error compiling conditions: %v", k, err)
}
@@ -479,14 +483,14 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
}
conditionReplacer := strings.NewReplacer(replacements...)
for _, rule := range policy.Rules {
for _, rule := range p.Rules {
r := RuleState{
Name: rule.Name,
Action: PolicyRuleAction(strings.ToUpper(rule.Action)),
Action: policy.RuleAction(strings.ToUpper(rule.Action)),
Challenges: rule.Challenges,
}
if (r.Action == PolicyRuleActionCHALLENGE || r.Action == PolicyRuleActionCHECK) && len(r.Challenges) == 0 {
if (r.Action == policy.RuleActionCHALLENGE || r.Action == policy.RuleActionCHECK) && len(r.Challenges) == 0 {
return nil, fmt.Errorf("no challenges found in rule %s", rule.Name)
}
@@ -497,7 +501,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
conditions = append(conditions, cond)
}
ast, err := ConditionFromStrings(state.RulesEnv, OperatorOr, conditions...)
ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, conditions...)
if err != nil {
return nil, fmt.Errorf("rules %s: error compiling conditions: %v", rule.Name, err)
}