Move most code under lib
This commit is contained in:
10
cmd/away.go
10
cmd/away.go
@@ -5,7 +5,9 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
go_away "git.gammaspectra.live/git/go-away"
|
||||
"git.gammaspectra.live/git/go-away/lib"
|
||||
"git.gammaspectra.live/git/go-away/lib/network"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"gopkg.in/yaml.v3"
|
||||
"log"
|
||||
"log/slog"
|
||||
@@ -36,7 +38,7 @@ func makeReverseProxy(target string) (http.Handler, error) {
|
||||
return dialer.DialContext(ctx, "unix", addr)
|
||||
}
|
||||
// tell transport how to handle the unix url scheme
|
||||
transport.RegisterProtocol("unix", go_away.UnixRoundTripper{Transport: transport})
|
||||
transport.RegisterProtocol("unix", network.UnixRoundTripper{Transport: transport})
|
||||
}
|
||||
|
||||
rp := httputil.NewSingleHostReverseProxy(u)
|
||||
@@ -116,7 +118,7 @@ func main() {
|
||||
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
|
||||
}
|
||||
|
||||
var policy go_away.Policy
|
||||
var policy policy.Policy
|
||||
|
||||
if err = yaml.Unmarshal(policyData, &policy); err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
|
||||
@@ -127,7 +129,7 @@ func main() {
|
||||
log.Fatal(fmt.Errorf("failed to create reverse proxy for %s: %w", *target, err))
|
||||
}
|
||||
|
||||
state, err := go_away.NewState(policy, "git.gammaspectra.live/git/go-away/cmd", backend)
|
||||
state, err := lib.NewState(policy, "git.gammaspectra.live/git/go-away/cmd", backend)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("failed to create state: %w", err))
|
||||
|
12
embed.go
Normal file
12
embed.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package go_away
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed assets
|
||||
var AssetsFs embed.FS
|
||||
|
||||
//go:embed challenge
|
||||
var ChallengeFs embed.FS
|
||||
|
||||
//go:embed templates
|
||||
var TemplatesFs embed.FS
|
@@ -1,4 +1,4 @@
|
||||
package go_away
|
||||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
@@ -1,4 +1,4 @@
|
||||
package go_away
|
||||
package condition
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -15,7 +15,7 @@ const (
|
||||
OperatorAnd = "&&"
|
||||
)
|
||||
|
||||
func ConditionFromStrings(env *cel.Env, operator string, conditions ...string) (*cel.Ast, error) {
|
||||
func FromStrings(env *cel.Env, operator string, conditions ...string) (*cel.Ast, error) {
|
||||
var asts []*cel.Ast
|
||||
for _, c := range conditions {
|
||||
ast, issues := env.Compile(c)
|
||||
@@ -25,10 +25,10 @@ func ConditionFromStrings(env *cel.Env, operator string, conditions ...string) (
|
||||
asts = append(asts, ast)
|
||||
}
|
||||
|
||||
return MergeConditions(env, operator, asts...)
|
||||
return Merge(env, operator, asts...)
|
||||
}
|
||||
|
||||
func MergeConditions(env *cel.Env, operator string, conditions ...*cel.Ast) (*cel.Ast, error) {
|
||||
func Merge(env *cel.Env, operator string, conditions ...*cel.Ast) (*cel.Ast, error) {
|
||||
if len(conditions) == 0 {
|
||||
return nil, nil
|
||||
} else if len(conditions) == 1 {
|
@@ -1,4 +1,4 @@
|
||||
package go_away
|
||||
package lib
|
||||
|
||||
import (
|
||||
"net/http"
|
@@ -1,12 +1,13 @@
|
||||
package go_away
|
||||
package lib
|
||||
|
||||
import (
|
||||
"codeberg.org/meta/gzipped/v2"
|
||||
"crypto/rand"
|
||||
"embed"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
go_away "git.gammaspectra.live/git/go-away"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"html/template"
|
||||
"net/http"
|
||||
@@ -15,15 +16,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:embed assets
|
||||
var assetsFs embed.FS
|
||||
|
||||
//go:embed challenge
|
||||
var challengesFs embed.FS
|
||||
|
||||
//go:embed templates
|
||||
var templatesFs embed.FS
|
||||
|
||||
var templates map[string]*template.Template
|
||||
|
||||
var cacheBust string
|
||||
@@ -39,7 +31,7 @@ func init() {
|
||||
|
||||
templates = make(map[string]*template.Template)
|
||||
|
||||
dir, err := templatesFs.ReadDir("templates")
|
||||
dir, err := go_away.TemplatesFs.ReadDir("templates")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -47,7 +39,7 @@ func init() {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
data, err := templatesFs.ReadFile(filepath.Join("templates", e.Name()))
|
||||
data, err := go_away.TemplatesFs.ReadFile(filepath.Join("templates", e.Name()))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -92,10 +84,10 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
switch rule.Action {
|
||||
default:
|
||||
panic(fmt.Errorf("unknown action %s", rule.Action))
|
||||
case PolicyRuleActionPASS:
|
||||
case policy.RuleActionPASS:
|
||||
state.Backend.ServeHTTP(w, r)
|
||||
return
|
||||
case PolicyRuleActionCHALLENGE, PolicyRuleActionCHECK:
|
||||
case policy.RuleActionCHALLENGE, policy.RuleActionCHECK:
|
||||
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
|
||||
|
||||
for _, challengeName := range rule.Challenges {
|
||||
@@ -106,7 +98,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
}
|
||||
} else {
|
||||
if rule.Action == PolicyRuleActionCHECK {
|
||||
if rule.Action == policy.RuleActionCHECK {
|
||||
goto nextRule
|
||||
}
|
||||
// we passed the challenge!
|
||||
@@ -127,7 +119,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
case ChallengeResultContinue:
|
||||
continue
|
||||
case ChallengeResultPass:
|
||||
if rule.Action == PolicyRuleActionCHECK {
|
||||
if rule.Action == policy.RuleActionCHECK {
|
||||
goto nextRule
|
||||
}
|
||||
// we pass the challenge early!
|
||||
@@ -138,11 +130,11 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
panic("challenge not found")
|
||||
}
|
||||
}
|
||||
case PolicyRuleActionDENY:
|
||||
case policy.RuleActionDENY:
|
||||
//TODO: config error code
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
case PolicyRuleActionBLOCK:
|
||||
case policy.RuleActionBLOCK:
|
||||
//TODO: config error code
|
||||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
||||
return
|
||||
@@ -161,7 +153,7 @@ func (state *State) setupRoutes() error {
|
||||
|
||||
state.Mux.HandleFunc("/", state.handleRequest)
|
||||
|
||||
state.Mux.Handle("GET "+state.UrlPath+"/assets/", http.StripPrefix(state.UrlPath, gzipped.FileServer(gzipped.FS(assetsFs))))
|
||||
state.Mux.Handle("GET "+state.UrlPath+"/assets/", http.StripPrefix(state.UrlPath, gzipped.FileServer(gzipped.FS(go_away.AssetsFs))))
|
||||
|
||||
for challengeName, c := range state.Challenges {
|
||||
if c.Static != nil {
|
||||
@@ -213,19 +205,3 @@ func (state *State) setupRoutes() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnixRoundTripper https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124
|
||||
type UnixRoundTripper struct {
|
||||
Transport *http.Transport
|
||||
}
|
||||
|
||||
// RoundTrip set bare minimum stuff
|
||||
func (t UnixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
if req.Host == "" {
|
||||
req.Host = "localhost"
|
||||
}
|
||||
req.URL.Host = req.Host // proxy error: no Host in request URL
|
||||
req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
19
lib/network/unix.go
Normal file
19
lib/network/unix.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package network
|
||||
|
||||
import "net/http"
|
||||
|
||||
// UnixRoundTripper https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124
|
||||
type UnixRoundTripper struct {
|
||||
Transport *http.Transport
|
||||
}
|
||||
|
||||
// RoundTrip set bare minimum stuff
|
||||
func (t UnixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
if req.Host == "" {
|
||||
req.Host = "localhost"
|
||||
}
|
||||
req.URL.Host = req.Host // proxy error: no Host in request URL
|
||||
req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
14
lib/policy/challenge.go
Normal file
14
lib/policy/challenge.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package policy
|
||||
|
||||
type Challenge struct {
|
||||
Mode string `yaml:"mode"`
|
||||
Asset *string `yaml:"asset,omitempty"`
|
||||
Url *string `yaml:"url,omitempty"`
|
||||
|
||||
Parameters map[string]string `json:"parameters,omitempty"`
|
||||
Runtime struct {
|
||||
Mode string `yaml:"mode,omitempty"`
|
||||
Asset string `yaml:"asset,omitempty"`
|
||||
Probability float64 `yaml:"probability,omitempty"`
|
||||
} `yaml:"runtime"`
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package go_away
|
||||
package policy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -12,78 +12,7 @@ import (
|
||||
"regexp"
|
||||
)
|
||||
|
||||
func parseCIDROrIP(value string) (net.IPNet, error) {
|
||||
_, ipNet, err := net.ParseCIDR(value)
|
||||
if err != nil {
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil {
|
||||
return net.IPNet{}, fmt.Errorf("failed to parse CIDR: %s", err)
|
||||
}
|
||||
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return net.IPNet{
|
||||
IP: ip4,
|
||||
// single ip
|
||||
Mask: net.CIDRMask(len(ip4)*8, len(ip4)*8),
|
||||
}, nil
|
||||
}
|
||||
return net.IPNet{
|
||||
IP: ip,
|
||||
// single ip
|
||||
Mask: net.CIDRMask(len(ip)*8, len(ip)*8),
|
||||
}, nil
|
||||
} else if ipNet != nil {
|
||||
return *ipNet, nil
|
||||
} else {
|
||||
return net.IPNet{}, errors.New("invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
type Policy struct {
|
||||
|
||||
// Networks map of networks and prefixes to be loaded
|
||||
Networks map[string][]PolicyNetwork `yaml:"networks"`
|
||||
|
||||
Conditions map[string][]string `yaml:"conditions"`
|
||||
|
||||
Challenges map[string]PolicyChallenge `yaml:"challenges"`
|
||||
|
||||
Rules []PolicyRule `yaml:"rules"`
|
||||
}
|
||||
|
||||
type PolicyRuleAction string
|
||||
|
||||
const (
|
||||
PolicyRuleActionPASS PolicyRuleAction = "PASS"
|
||||
PolicyRuleActionDENY PolicyRuleAction = "DENY"
|
||||
PolicyRuleActionBLOCK PolicyRuleAction = "BLOCK"
|
||||
PolicyRuleActionCHALLENGE PolicyRuleAction = "CHALLENGE"
|
||||
PolicyRuleActionCHECK PolicyRuleAction = "CHECK"
|
||||
)
|
||||
|
||||
type PolicyRule struct {
|
||||
Name string `yaml:"name"`
|
||||
Conditions []string `yaml:"conditions"`
|
||||
|
||||
Action string `yaml:"action"`
|
||||
|
||||
Challenges []string `yaml:"challenges"`
|
||||
}
|
||||
|
||||
type PolicyChallenge struct {
|
||||
Mode string `yaml:"mode"`
|
||||
Asset *string `yaml:"asset,omitempty"`
|
||||
Url *string `yaml:"url,omitempty"`
|
||||
|
||||
Parameters map[string]string `json:"parameters,omitempty"`
|
||||
Runtime struct {
|
||||
Mode string `yaml:"mode,omitempty"`
|
||||
Asset string `yaml:"asset,omitempty"`
|
||||
Probability float64 `yaml:"probability,omitempty"`
|
||||
} `yaml:"runtime"`
|
||||
}
|
||||
|
||||
type PolicyNetwork struct {
|
||||
type Network struct {
|
||||
Url *string `yaml:"url,omitempty"`
|
||||
File *string `yaml:"file,omitempty"`
|
||||
|
||||
@@ -93,7 +22,7 @@ type PolicyNetwork struct {
|
||||
Prefixes []string `yaml:"prefixes,omitempty"`
|
||||
}
|
||||
|
||||
func (n PolicyNetwork) FetchPrefixes(c *http.Client) (output []net.IPNet, err error) {
|
||||
func (n Network) FetchPrefixes(c *http.Client) (output []net.IPNet, err error) {
|
||||
if len(n.Prefixes) > 0 {
|
||||
for _, prefix := range n.Prefixes {
|
||||
ipNet, err := parseCIDROrIP(prefix)
|
46
lib/policy/policy.go
Normal file
46
lib/policy/policy.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func parseCIDROrIP(value string) (net.IPNet, error) {
|
||||
_, ipNet, err := net.ParseCIDR(value)
|
||||
if err != nil {
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil {
|
||||
return net.IPNet{}, fmt.Errorf("failed to parse CIDR: %s", err)
|
||||
}
|
||||
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return net.IPNet{
|
||||
IP: ip4,
|
||||
// single ip
|
||||
Mask: net.CIDRMask(len(ip4)*8, len(ip4)*8),
|
||||
}, nil
|
||||
}
|
||||
return net.IPNet{
|
||||
IP: ip,
|
||||
// single ip
|
||||
Mask: net.CIDRMask(len(ip)*8, len(ip)*8),
|
||||
}, nil
|
||||
} else if ipNet != nil {
|
||||
return *ipNet, nil
|
||||
} else {
|
||||
return net.IPNet{}, errors.New("invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
type Policy struct {
|
||||
|
||||
// Networks map of networks and prefixes to be loaded
|
||||
Networks map[string][]Network `yaml:"networks"`
|
||||
|
||||
Conditions map[string][]string `yaml:"conditions"`
|
||||
|
||||
Challenges map[string]Challenge `yaml:"challenges"`
|
||||
|
||||
Rules []Rule `yaml:"rules"`
|
||||
}
|
20
lib/policy/rule.go
Normal file
20
lib/policy/rule.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package policy
|
||||
|
||||
type RuleAction string
|
||||
|
||||
const (
|
||||
RuleActionPASS RuleAction = "PASS"
|
||||
RuleActionDENY RuleAction = "DENY"
|
||||
RuleActionBLOCK RuleAction = "BLOCK"
|
||||
RuleActionCHALLENGE RuleAction = "CHALLENGE"
|
||||
RuleActionCHECK RuleAction = "CHECK"
|
||||
)
|
||||
|
||||
type Rule struct {
|
||||
Name string `yaml:"name"`
|
||||
Conditions []string `yaml:"conditions"`
|
||||
|
||||
Action string `yaml:"action"`
|
||||
|
||||
Challenges []string `yaml:"challenges"`
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package go_away
|
||||
package lib
|
||||
|
||||
import (
|
||||
"codeberg.org/meta/gzipped/v2"
|
||||
@@ -10,8 +10,11 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
go_away "git.gammaspectra.live/git/go-away"
|
||||
"git.gammaspectra.live/git/go-away/challenge"
|
||||
"git.gammaspectra.live/git/go-away/challenge/inline"
|
||||
"git.gammaspectra.live/git/go-away/lib/condition"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
@@ -56,7 +59,7 @@ type RuleState struct {
|
||||
Name string
|
||||
|
||||
Program cel.Program
|
||||
Action PolicyRuleAction
|
||||
Action policy.RuleAction
|
||||
Continue bool
|
||||
Challenges []string
|
||||
}
|
||||
@@ -88,7 +91,7 @@ type ChallengeState struct {
|
||||
Verify func(key []byte, result string) (bool, error)
|
||||
}
|
||||
|
||||
func NewState(policy Policy, packagePath string, backend http.Handler) (state *State, err error) {
|
||||
func NewState(p policy.Policy, packagePath string, backend http.Handler) (state *State, err error) {
|
||||
state = new(State)
|
||||
state.Client = &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
@@ -100,7 +103,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
|
||||
state.Backend = backend
|
||||
|
||||
state.Networks = make(map[string]cidranger.Ranger)
|
||||
for k, network := range policy.Networks {
|
||||
for k, network := range p.Networks {
|
||||
ranger := cidranger.NewPCTrieRanger()
|
||||
for _, e := range network {
|
||||
if e.Url != nil {
|
||||
@@ -129,7 +132,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
|
||||
|
||||
state.Challenges = make(map[string]ChallengeState)
|
||||
|
||||
for challengeName, p := range policy.Challenges {
|
||||
for challengeName, p := range p.Challenges {
|
||||
c := ChallengeState{
|
||||
Path: fmt.Sprintf("%s/challenge/%s", state.UrlPath, challengeName),
|
||||
VerifyProbability: p.Runtime.Probability,
|
||||
@@ -143,7 +146,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
|
||||
}
|
||||
|
||||
assetPath := c.Path + "/static/"
|
||||
subFs, err := fs.Sub(challengesFs, fmt.Sprintf("challenge/%s/static", challengeName))
|
||||
subFs, err := fs.Sub(go_away.ChallengeFs, fmt.Sprintf("challenge/%s/static", challengeName))
|
||||
if err == nil {
|
||||
c.Static = http.StripPrefix(
|
||||
assetPath,
|
||||
@@ -294,6 +297,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
params, _ := json.Marshal(p.Parameters)
|
||||
context.Background()
|
||||
|
||||
err := templates["challenge.mjs"].Execute(w, map[string]any{
|
||||
"Path": c.Path,
|
||||
@@ -336,7 +340,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
|
||||
}
|
||||
|
||||
case "wasm":
|
||||
wasmData, err := challengesFs.ReadFile(fmt.Sprintf("challenge/%s/runtime/%s", challengeName, p.Runtime.Asset))
|
||||
wasmData, err := go_away.ChallengeFs.ReadFile(fmt.Sprintf("challenge/%s/runtime/%s", challengeName, p.Runtime.Asset))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("c %s: could not load runtime: %w", challengeName, err)
|
||||
}
|
||||
@@ -463,8 +467,8 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
|
||||
}
|
||||
|
||||
var replacements []string
|
||||
for k, entries := range policy.Conditions {
|
||||
ast, err := ConditionFromStrings(state.RulesEnv, OperatorOr, entries...)
|
||||
for k, entries := range p.Conditions {
|
||||
ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, entries...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("conditions %s: error compiling conditions: %v", k, err)
|
||||
}
|
||||
@@ -479,14 +483,14 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
|
||||
}
|
||||
conditionReplacer := strings.NewReplacer(replacements...)
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
for _, rule := range p.Rules {
|
||||
r := RuleState{
|
||||
Name: rule.Name,
|
||||
Action: PolicyRuleAction(strings.ToUpper(rule.Action)),
|
||||
Action: policy.RuleAction(strings.ToUpper(rule.Action)),
|
||||
Challenges: rule.Challenges,
|
||||
}
|
||||
|
||||
if (r.Action == PolicyRuleActionCHALLENGE || r.Action == PolicyRuleActionCHECK) && len(r.Challenges) == 0 {
|
||||
if (r.Action == policy.RuleActionCHALLENGE || r.Action == policy.RuleActionCHECK) && len(r.Challenges) == 0 {
|
||||
return nil, fmt.Errorf("no challenges found in rule %s", rule.Name)
|
||||
}
|
||||
|
||||
@@ -497,7 +501,7 @@ func NewState(policy Policy, packagePath string, backend http.Handler) (state *S
|
||||
conditions = append(conditions, cond)
|
||||
}
|
||||
|
||||
ast, err := ConditionFromStrings(state.RulesEnv, OperatorOr, conditions...)
|
||||
ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, conditions...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rules %s: error compiling conditions: %v", rule.Name, err)
|
||||
}
|
Reference in New Issue
Block a user