Move most code under lib
This commit is contained in:
171
lib/challenge.go
Normal file
171
lib/challenge.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChallengeInformation struct {
|
||||
Name string `json:"name"`
|
||||
Key []byte `json:"key"`
|
||||
Result []byte `json:"result"`
|
||||
|
||||
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
||||
NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
|
||||
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
||||
}
|
||||
|
||||
func (state *State) GetRequestAddress(r *http.Request) net.IP {
|
||||
//TODO: verified upstream
|
||||
ipStr := r.Header.Get("X-Real-Ip")
|
||||
if ipStr == "" {
|
||||
ipStr = strings.Split(r.Header.Get("X-Forwarded-For"), ",")[0]
|
||||
}
|
||||
if ipStr == "" {
|
||||
parts := strings.Split(r.RemoteAddr, ":")
|
||||
// drop port
|
||||
ipStr = strings.Join(parts[:len(parts)-1], ":")
|
||||
}
|
||||
return net.ParseIP(ipStr)
|
||||
}
|
||||
|
||||
func (state *State) GetChallengeKeyForRequest(name string, until time.Time, r *http.Request) []byte {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte("challenge\x00"))
|
||||
hasher.Write([]byte(name))
|
||||
hasher.Write([]byte{0})
|
||||
hasher.Write(state.GetRequestAddress(r).To16())
|
||||
hasher.Write([]byte{0})
|
||||
|
||||
// specific headers
|
||||
for _, k := range []string{
|
||||
"Accept-Language",
|
||||
// General browser information
|
||||
"User-Agent",
|
||||
"Sec-Ch-Ua",
|
||||
"Sec-Ch-Ua-Platform",
|
||||
} {
|
||||
hasher.Write([]byte(r.Header.Get(k)))
|
||||
hasher.Write([]byte{0})
|
||||
}
|
||||
hasher.Write([]byte{0})
|
||||
_ = binary.Write(hasher, binary.LittleEndian, until.UTC().Unix())
|
||||
hasher.Write([]byte{0})
|
||||
hasher.Write(state.PublicKey)
|
||||
hasher.Write([]byte{0})
|
||||
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
func (state *State) IssueChallengeToken(name string, key, result []byte, until time.Time) (token string, err error) {
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.EdDSA,
|
||||
Key: state.PrivateKey,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
expiry := jwt.NumericDate(until.Unix())
|
||||
notBefore := jwt.NumericDate(time.Now().UTC().AddDate(0, 0, -1).Unix())
|
||||
issuedAt := jwt.NumericDate(time.Now().UTC().Unix())
|
||||
|
||||
token, err = jwt.Signed(signer).Claims(ChallengeInformation{
|
||||
Name: name,
|
||||
Key: key,
|
||||
Result: result,
|
||||
Expiry: &expiry,
|
||||
NotBefore: ¬Before,
|
||||
IssuedAt: &issuedAt,
|
||||
}).Serialize()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (state *State) VerifyChallengeToken(name string, expectedKey []byte, r *http.Request) (ok bool, err error) {
|
||||
c, ok := state.Challenges[name]
|
||||
if !ok {
|
||||
return false, errors.New("challenge not found")
|
||||
}
|
||||
|
||||
cookie, err := r.Cookie(CookiePrefix + name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
token, err := jwt.ParseSigned(cookie.Value, []jose.SignatureAlgorithm{jose.EdDSA})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var i ChallengeInformation
|
||||
err = token.Claims(state.PublicKey, &i)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if i.Name != name {
|
||||
return false, errors.New("token invalid name")
|
||||
}
|
||||
if i.Expiry == nil && i.Expiry.Time().Compare(time.Now()) < 0 {
|
||||
return false, errors.New("token expired")
|
||||
}
|
||||
if i.NotBefore == nil && i.NotBefore.Time().Compare(time.Now()) > 0 {
|
||||
return false, errors.New("token not valid yet")
|
||||
}
|
||||
|
||||
if bytes.Compare(expectedKey, i.Key) != 0 {
|
||||
return false, errors.New("key mismatch")
|
||||
}
|
||||
|
||||
if c.Verify != nil && rand.Float64() < c.VerifyProbability {
|
||||
// random spot check
|
||||
if ok, err := c.Verify(expectedKey, string(i.Result)); err != nil {
|
||||
return false, err
|
||||
} else if !ok {
|
||||
return false, errors.New("failed challenge verification")
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (state *State) ChallengeMod(name string, cb func(ctx context.Context, mod api.Module) error) error {
|
||||
c, ok := state.Challenges[name]
|
||||
if !ok {
|
||||
return errors.New("challenge not found")
|
||||
}
|
||||
if c.RuntimeModule == nil {
|
||||
return errors.New("challenge module is nil")
|
||||
}
|
||||
|
||||
ctx := state.WasmContext
|
||||
mod, err := state.WasmRuntime.InstantiateModule(
|
||||
ctx,
|
||||
c.RuntimeModule,
|
||||
wazero.NewModuleConfig().WithName(name).WithStartFunctions("_initialize"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mod.Close(ctx)
|
||||
err = cb(ctx, mod)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
53
lib/condition/condition.go
Normal file
53
lib/condition/condition.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package condition
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/google/cel-go/cel"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Condition struct {
|
||||
Expression *cel.Ast
|
||||
}
|
||||
|
||||
const (
|
||||
OperatorOr = "||"
|
||||
OperatorAnd = "&&"
|
||||
)
|
||||
|
||||
func FromStrings(env *cel.Env, operator string, conditions ...string) (*cel.Ast, error) {
|
||||
var asts []*cel.Ast
|
||||
for _, c := range conditions {
|
||||
ast, issues := env.Compile(c)
|
||||
if issues != nil && issues.Err() != nil {
|
||||
return nil, fmt.Errorf("condition %s: %s", issues.Err(), c)
|
||||
}
|
||||
asts = append(asts, ast)
|
||||
}
|
||||
|
||||
return Merge(env, operator, asts...)
|
||||
}
|
||||
|
||||
func Merge(env *cel.Env, operator string, conditions ...*cel.Ast) (*cel.Ast, error) {
|
||||
if len(conditions) == 0 {
|
||||
return nil, nil
|
||||
} else if len(conditions) == 1 {
|
||||
return conditions[0], nil
|
||||
}
|
||||
var asts []string
|
||||
for _, c := range conditions {
|
||||
ast, err := cel.AstToString(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
asts = append(asts, "("+ast+")")
|
||||
}
|
||||
|
||||
condition := strings.Join(asts, " "+operator+" ")
|
||||
ast, issues := env.Compile(condition)
|
||||
if issues != nil && issues.Err() != nil {
|
||||
return nil, issues.Err()
|
||||
}
|
||||
|
||||
return ast, nil
|
||||
}
|
27
lib/cookie.go
Normal file
27
lib/cookie.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var CookiePrefix = ".go-away-"
|
||||
|
||||
func SetCookie(name, value string, expiry time.Time, w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Expires: expiry,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
func ClearCookie(name string, w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Expires: time.Now().Add(-1 * time.Hour),
|
||||
MaxAge: -1,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
207
lib/http.go
Normal file
207
lib/http.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"codeberg.org/meta/gzipped/v2"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
go_away "git.gammaspectra.live/git/go-away"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var templates map[string]*template.Template
|
||||
|
||||
var cacheBust string
|
||||
|
||||
// DefaultValidity TODO: adjust
|
||||
const DefaultValidity = time.Hour * 24 * 7
|
||||
|
||||
func init() {
|
||||
|
||||
buf := make([]byte, 16)
|
||||
_, _ = rand.Read(buf)
|
||||
cacheBust = base64.RawURLEncoding.EncodeToString(buf)
|
||||
|
||||
templates = make(map[string]*template.Template)
|
||||
|
||||
dir, err := go_away.TemplatesFs.ReadDir("templates")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, e := range dir {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
data, err := go_away.TemplatesFs.ReadFile(filepath.Join("templates", e.Name()))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tpl := template.New(e.Name())
|
||||
_, err = tpl.Parse(string(data))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
templates[e.Name()] = tpl
|
||||
}
|
||||
}
|
||||
|
||||
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
//TODO better matcher! combo ast?
|
||||
env := map[string]any{
|
||||
"remoteAddress": state.GetRequestAddress(r),
|
||||
"userAgent": r.UserAgent(),
|
||||
"path": r.URL.Path,
|
||||
"query": func() map[string]string {
|
||||
result := make(map[string]string)
|
||||
for k, v := range r.URL.Query() {
|
||||
result[k] = strings.Join(v, ",")
|
||||
}
|
||||
return result
|
||||
}(),
|
||||
"headers": func() map[string]string {
|
||||
result := make(map[string]string)
|
||||
for k, v := range r.Header {
|
||||
result[k] = strings.Join(v, ",")
|
||||
}
|
||||
return result
|
||||
}(),
|
||||
}
|
||||
|
||||
for _, rule := range state.Rules {
|
||||
if out, _, err := rule.Program.Eval(env); err != nil {
|
||||
//TODO error
|
||||
panic(err)
|
||||
} else if out != nil && out.Type() == types.BoolType {
|
||||
if out.Equal(types.True) == types.True {
|
||||
switch rule.Action {
|
||||
default:
|
||||
panic(fmt.Errorf("unknown action %s", rule.Action))
|
||||
case policy.RuleActionPASS:
|
||||
state.Backend.ServeHTTP(w, r)
|
||||
return
|
||||
case policy.RuleActionCHALLENGE, policy.RuleActionCHECK:
|
||||
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
|
||||
|
||||
for _, challengeName := range rule.Challenges {
|
||||
key := state.GetChallengeKeyForRequest(challengeName, expiry, r)
|
||||
ok, err := state.VerifyChallengeToken(challengeName, key, r)
|
||||
if !ok || err != nil {
|
||||
if !errors.Is(err, http.ErrNoCookie) {
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
}
|
||||
} else {
|
||||
if rule.Action == policy.RuleActionCHECK {
|
||||
goto nextRule
|
||||
}
|
||||
// we passed the challenge!
|
||||
//TODO log?
|
||||
state.Backend.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// none matched, issue first challenge in priority
|
||||
for _, challengeName := range rule.Challenges {
|
||||
c := state.Challenges[challengeName]
|
||||
if c.Challenge != nil {
|
||||
result := c.Challenge(w, r, state.GetChallengeKeyForRequest(challengeName, expiry, r), expiry)
|
||||
switch result {
|
||||
case ChallengeResultStop:
|
||||
return
|
||||
case ChallengeResultContinue:
|
||||
continue
|
||||
case ChallengeResultPass:
|
||||
if rule.Action == policy.RuleActionCHECK {
|
||||
goto nextRule
|
||||
}
|
||||
// we pass the challenge early!
|
||||
state.Backend.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
panic("challenge not found")
|
||||
}
|
||||
}
|
||||
case policy.RuleActionDENY:
|
||||
//TODO: config error code
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
case policy.RuleActionBLOCK:
|
||||
//TODO: config error code
|
||||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nextRule:
|
||||
}
|
||||
|
||||
state.Backend.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
func (state *State) setupRoutes() error {
|
||||
|
||||
state.Mux.HandleFunc("/", state.handleRequest)
|
||||
|
||||
state.Mux.Handle("GET "+state.UrlPath+"/assets/", http.StripPrefix(state.UrlPath, gzipped.FileServer(gzipped.FS(go_away.AssetsFs))))
|
||||
|
||||
for challengeName, c := range state.Challenges {
|
||||
if c.Static != nil {
|
||||
state.Mux.Handle("GET "+c.Path+"/static/", c.Static)
|
||||
}
|
||||
|
||||
if c.ChallengeScript != nil {
|
||||
state.Mux.Handle("GET "+c.ChallengeScriptPath, c.ChallengeScript)
|
||||
}
|
||||
|
||||
if c.MakeChallenge != nil {
|
||||
state.Mux.Handle(fmt.Sprintf("POST %s/make-challenge", c.Path), c.MakeChallenge)
|
||||
}
|
||||
|
||||
if c.Verify != nil {
|
||||
state.Mux.HandleFunc(fmt.Sprintf("GET %s/verify-challenge", c.Path), func(w http.ResponseWriter, r *http.Request) {
|
||||
err := func() (err error) {
|
||||
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
|
||||
key := state.GetChallengeKeyForRequest(challengeName, expiry, r)
|
||||
result := r.FormValue("result")
|
||||
|
||||
if ok, err := c.Verify(key, result); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
|
||||
if err != nil {
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
} else {
|
||||
SetCookie(CookiePrefix+challengeName, token, expiry, w)
|
||||
}
|
||||
|
||||
http.Redirect(w, r, r.FormValue("redirect"), http.StatusTemporaryRedirect)
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
19
lib/network/unix.go
Normal file
19
lib/network/unix.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package network
|
||||
|
||||
import "net/http"
|
||||
|
||||
// UnixRoundTripper https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124
|
||||
type UnixRoundTripper struct {
|
||||
Transport *http.Transport
|
||||
}
|
||||
|
||||
// RoundTrip set bare minimum stuff
|
||||
func (t UnixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
if req.Host == "" {
|
||||
req.Host = "localhost"
|
||||
}
|
||||
req.URL.Host = req.Host // proxy error: no Host in request URL
|
||||
req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
14
lib/policy/challenge.go
Normal file
14
lib/policy/challenge.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package policy
|
||||
|
||||
type Challenge struct {
|
||||
Mode string `yaml:"mode"`
|
||||
Asset *string `yaml:"asset,omitempty"`
|
||||
Url *string `yaml:"url,omitempty"`
|
||||
|
||||
Parameters map[string]string `json:"parameters,omitempty"`
|
||||
Runtime struct {
|
||||
Mode string `yaml:"mode,omitempty"`
|
||||
Asset string `yaml:"asset,omitempty"`
|
||||
Probability float64 `yaml:"probability,omitempty"`
|
||||
} `yaml:"runtime"`
|
||||
}
|
117
lib/policy/network.go
Normal file
117
lib/policy/network.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/itchyny/gojq"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
Url *string `yaml:"url,omitempty"`
|
||||
File *string `yaml:"file,omitempty"`
|
||||
|
||||
JqPath *string `yaml:"jq-path,omitempty"`
|
||||
Regex *string `yaml:"regex,omitempty"`
|
||||
|
||||
Prefixes []string `yaml:"prefixes,omitempty"`
|
||||
}
|
||||
|
||||
func (n Network) FetchPrefixes(c *http.Client) (output []net.IPNet, err error) {
|
||||
if len(n.Prefixes) > 0 {
|
||||
for _, prefix := range n.Prefixes {
|
||||
ipNet, err := parseCIDROrIP(prefix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
output = append(output, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
if n.Url != nil {
|
||||
response, err := c.Get(*n.Url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", response.StatusCode)
|
||||
}
|
||||
reader = response.Body
|
||||
} else if n.File != nil {
|
||||
file, err := os.Open(*n.File)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
reader = file
|
||||
} else {
|
||||
if len(output) > 0 {
|
||||
return output, nil
|
||||
}
|
||||
return nil, errors.New("no url, file or prefixes specified")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n.JqPath != nil {
|
||||
var jsonData any
|
||||
err = json.Unmarshal(data, &jsonData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query, err := gojq.Parse(*n.JqPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iter := query.Run(jsonData)
|
||||
for {
|
||||
value, more := iter.Next()
|
||||
if !more {
|
||||
break
|
||||
}
|
||||
|
||||
if strValue, ok := value.(string); ok {
|
||||
ipNet, err := parseCIDROrIP(strValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
output = append(output, ipNet)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid value from jq-query: %v", value)
|
||||
}
|
||||
}
|
||||
return output, nil
|
||||
} else if n.Regex != nil {
|
||||
expr, err := regexp.Compile(*n.Regex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prefixName := expr.SubexpIndex("prefix")
|
||||
if prefixName == -1 {
|
||||
return nil, fmt.Errorf("invalid regex %q: could not find prefix named match", *n.Regex)
|
||||
}
|
||||
matches := expr.FindAllSubmatch(data, -1)
|
||||
for _, match := range matches {
|
||||
matchName := string(match[prefixName])
|
||||
ipNet, err := parseCIDROrIP(matchName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
output = append(output, ipNet)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("no jq-path or regex specified")
|
||||
}
|
||||
return output, nil
|
||||
}
|
46
lib/policy/policy.go
Normal file
46
lib/policy/policy.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func parseCIDROrIP(value string) (net.IPNet, error) {
|
||||
_, ipNet, err := net.ParseCIDR(value)
|
||||
if err != nil {
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil {
|
||||
return net.IPNet{}, fmt.Errorf("failed to parse CIDR: %s", err)
|
||||
}
|
||||
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return net.IPNet{
|
||||
IP: ip4,
|
||||
// single ip
|
||||
Mask: net.CIDRMask(len(ip4)*8, len(ip4)*8),
|
||||
}, nil
|
||||
}
|
||||
return net.IPNet{
|
||||
IP: ip,
|
||||
// single ip
|
||||
Mask: net.CIDRMask(len(ip)*8, len(ip)*8),
|
||||
}, nil
|
||||
} else if ipNet != nil {
|
||||
return *ipNet, nil
|
||||
} else {
|
||||
return net.IPNet{}, errors.New("invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
type Policy struct {
|
||||
|
||||
// Networks map of networks and prefixes to be loaded
|
||||
Networks map[string][]Network `yaml:"networks"`
|
||||
|
||||
Conditions map[string][]string `yaml:"conditions"`
|
||||
|
||||
Challenges map[string]Challenge `yaml:"challenges"`
|
||||
|
||||
Rules []Rule `yaml:"rules"`
|
||||
}
|
20
lib/policy/rule.go
Normal file
20
lib/policy/rule.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package policy
|
||||
|
||||
type RuleAction string
|
||||
|
||||
const (
|
||||
RuleActionPASS RuleAction = "PASS"
|
||||
RuleActionDENY RuleAction = "DENY"
|
||||
RuleActionBLOCK RuleAction = "BLOCK"
|
||||
RuleActionCHALLENGE RuleAction = "CHALLENGE"
|
||||
RuleActionCHECK RuleAction = "CHECK"
|
||||
)
|
||||
|
||||
type Rule struct {
|
||||
Name string `yaml:"name"`
|
||||
Conditions []string `yaml:"conditions"`
|
||||
|
||||
Action string `yaml:"action"`
|
||||
|
||||
Challenges []string `yaml:"challenges"`
|
||||
}
|
533
lib/state.go
Normal file
533
lib/state.go
Normal file
@@ -0,0 +1,533 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"codeberg.org/meta/gzipped/v2"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
go_away "git.gammaspectra.live/git/go-away"
|
||||
"git.gammaspectra.live/git/go-away/challenge"
|
||||
"git.gammaspectra.live/git/go-away/challenge/inline"
|
||||
"git.gammaspectra.live/git/go-away/lib/condition"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
|
||||
"github.com/yl2chen/cidranger"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type State struct {
|
||||
Client *http.Client
|
||||
PackagePath string
|
||||
UrlPath string
|
||||
Mux *http.ServeMux
|
||||
Backend http.Handler
|
||||
|
||||
Networks map[string]cidranger.Ranger
|
||||
|
||||
WasmRuntime wazero.Runtime
|
||||
WasmContext context.Context
|
||||
|
||||
Challenges map[string]ChallengeState
|
||||
|
||||
RulesEnv *cel.Env
|
||||
|
||||
Rules []RuleState
|
||||
|
||||
PublicKey ed25519.PublicKey
|
||||
PrivateKey ed25519.PrivateKey
|
||||
}
|
||||
|
||||
type RuleState struct {
|
||||
Name string
|
||||
|
||||
Program cel.Program
|
||||
Action policy.RuleAction
|
||||
Continue bool
|
||||
Challenges []string
|
||||
}
|
||||
|
||||
type ChallengeResult int
|
||||
|
||||
const (
|
||||
// ChallengeResultStop Stop testing challenges and return
|
||||
ChallengeResultStop = ChallengeResult(iota)
|
||||
// ChallengeResultContinue Test next challenge
|
||||
ChallengeResultContinue
|
||||
// ChallengeResultPass Challenge passed, return and proxy
|
||||
ChallengeResultPass
|
||||
)
|
||||
|
||||
type ChallengeState struct {
|
||||
RuntimeModule wazero.CompiledModule
|
||||
|
||||
Path string
|
||||
|
||||
Static http.Handler
|
||||
Challenge func(w http.ResponseWriter, r *http.Request, key []byte, expiry time.Time) ChallengeResult
|
||||
ChallengeScriptPath string
|
||||
ChallengeScript http.Handler
|
||||
MakeChallenge http.Handler
|
||||
VerifyChallenge http.Handler
|
||||
|
||||
VerifyProbability float64
|
||||
Verify func(key []byte, result string) (bool, error)
|
||||
}
|
||||
|
||||
func NewState(p policy.Policy, packagePath string, backend http.Handler) (state *State, err error) {
|
||||
state = new(State)
|
||||
state.Client = &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
state.PackagePath = packagePath
|
||||
state.UrlPath = "/.well-known/." + state.PackagePath
|
||||
state.Backend = backend
|
||||
|
||||
state.Networks = make(map[string]cidranger.Ranger)
|
||||
for k, network := range p.Networks {
|
||||
ranger := cidranger.NewPCTrieRanger()
|
||||
for _, e := range network {
|
||||
if e.Url != nil {
|
||||
slog.Debug("loading network url list", "network", k, "url", *e.Url)
|
||||
}
|
||||
prefixes, err := e.FetchPrefixes(state.Client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("networks %s: error fetching prefixes: %v", k, err)
|
||||
}
|
||||
for _, prefix := range prefixes {
|
||||
err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("networks %s: error inserting prefix %s: %v", k, prefix.String(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("loaded network prefixes", "network", k, "count", ranger.Len())
|
||||
|
||||
state.Networks[k] = ranger
|
||||
}
|
||||
|
||||
state.WasmContext = context.Background()
|
||||
state.WasmRuntime = wazero.NewRuntimeWithConfig(state.WasmContext, wazero.NewRuntimeConfigCompiler())
|
||||
wasi_snapshot_preview1.MustInstantiate(state.WasmContext, state.WasmRuntime)
|
||||
|
||||
state.Challenges = make(map[string]ChallengeState)
|
||||
|
||||
for challengeName, p := range p.Challenges {
|
||||
c := ChallengeState{
|
||||
Path: fmt.Sprintf("%s/challenge/%s", state.UrlPath, challengeName),
|
||||
VerifyProbability: p.Runtime.Probability,
|
||||
}
|
||||
|
||||
if c.VerifyProbability <= 0 {
|
||||
//10% default
|
||||
c.VerifyProbability = 0.1
|
||||
} else if c.VerifyProbability > 1.0 {
|
||||
c.VerifyProbability = 1.0
|
||||
}
|
||||
|
||||
assetPath := c.Path + "/static/"
|
||||
subFs, err := fs.Sub(go_away.ChallengeFs, fmt.Sprintf("challenge/%s/static", challengeName))
|
||||
if err == nil {
|
||||
c.Static = http.StripPrefix(
|
||||
assetPath,
|
||||
gzipped.FileServer(gzipped.FS(subFs)),
|
||||
)
|
||||
}
|
||||
|
||||
switch p.Mode {
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown challenge mode: %s", p.Mode)
|
||||
case "http":
|
||||
if p.Url == nil {
|
||||
return nil, fmt.Errorf("challenge %s: missing url", challengeName)
|
||||
}
|
||||
method := p.Parameters["http-method"]
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
|
||||
httpCode, _ := strconv.Atoi(p.Parameters["http-code"])
|
||||
if httpCode == 0 {
|
||||
httpCode = http.StatusOK
|
||||
}
|
||||
|
||||
expectedCookie := p.Parameters["http-cookie"]
|
||||
|
||||
//todo
|
||||
c.Challenge = func(w http.ResponseWriter, r *http.Request, key []byte, expiry time.Time) ChallengeResult {
|
||||
if expectedCookie != "" {
|
||||
if cookie, err := r.Cookie(expectedCookie); err != nil || cookie == nil {
|
||||
// skip check if we don't have cookie or it's expired
|
||||
return ChallengeResultContinue
|
||||
}
|
||||
}
|
||||
|
||||
request, err := http.NewRequest(method, *p.Url, nil)
|
||||
if err != nil {
|
||||
return ChallengeResultContinue
|
||||
}
|
||||
|
||||
request.Header = r.Header
|
||||
response, err := state.Client.Do(request)
|
||||
if err != nil {
|
||||
return ChallengeResultContinue
|
||||
}
|
||||
defer response.Body.Close()
|
||||
defer io.Copy(io.Discard, response.Body)
|
||||
|
||||
if response.StatusCode != httpCode {
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
// continue other challenges!
|
||||
return ChallengeResultContinue
|
||||
} else {
|
||||
token, err := state.IssueChallengeToken(challengeName, key, nil, expiry)
|
||||
if err != nil {
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
} else {
|
||||
SetCookie(CookiePrefix+challengeName, token, expiry, w)
|
||||
}
|
||||
|
||||
// we passed it!
|
||||
return ChallengeResultPass
|
||||
}
|
||||
}
|
||||
|
||||
case "cookie":
|
||||
c.Challenge = func(w http.ResponseWriter, r *http.Request, key []byte, expiry time.Time) ChallengeResult {
|
||||
token, err := state.IssueChallengeToken(challengeName, key, nil, expiry)
|
||||
if err != nil {
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
} else {
|
||||
SetCookie(CookiePrefix+challengeName, token, expiry, w)
|
||||
}
|
||||
// self redirect!
|
||||
//TODO: add redirect loop detect parameter
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
return ChallengeResultStop
|
||||
}
|
||||
case "meta-refresh":
|
||||
c.Challenge = func(w http.ResponseWriter, r *http.Request, key []byte, expiry time.Time) ChallengeResult {
|
||||
redirectUri := new(url.URL)
|
||||
redirectUri.Path = c.Path + "/verify-challenge"
|
||||
|
||||
values := make(url.Values)
|
||||
values.Set("result", hex.EncodeToString(key))
|
||||
values.Set("redirect", r.URL.String())
|
||||
|
||||
redirectUri.RawQuery = values.Encode()
|
||||
|
||||
// self redirect!
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
|
||||
_ = templates["challenge.gohtml"].Execute(w, map[string]any{
|
||||
"Title": "Bot",
|
||||
"Path": state.UrlPath,
|
||||
"Random": cacheBust,
|
||||
"Challenge": "",
|
||||
"Meta": map[string]string{
|
||||
"refresh": "0; url=" + redirectUri.String(),
|
||||
},
|
||||
})
|
||||
return ChallengeResultStop
|
||||
}
|
||||
case "header-refresh":
|
||||
c.Challenge = func(w http.ResponseWriter, r *http.Request, key []byte, expiry time.Time) ChallengeResult {
|
||||
redirectUri := new(url.URL)
|
||||
redirectUri.Path = c.Path + "/verify-challenge"
|
||||
|
||||
values := make(url.Values)
|
||||
values.Set("result", hex.EncodeToString(key))
|
||||
values.Set("redirect", r.URL.String())
|
||||
|
||||
redirectUri.RawQuery = values.Encode()
|
||||
|
||||
// self redirect!
|
||||
w.Header().Set("Refresh", "0; url="+redirectUri.String())
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
|
||||
_ = templates["challenge.gohtml"].Execute(w, map[string]any{
|
||||
"Title": "Bot",
|
||||
"Path": state.UrlPath,
|
||||
"Random": cacheBust,
|
||||
"Challenge": "",
|
||||
})
|
||||
return ChallengeResultStop
|
||||
}
|
||||
case "js":
|
||||
c.Challenge = func(w http.ResponseWriter, r *http.Request, key []byte, expiry time.Time) ChallengeResult {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
|
||||
err := templates["challenge.gohtml"].Execute(w, map[string]any{
|
||||
"Title": "Bot",
|
||||
"Path": state.UrlPath,
|
||||
"Random": cacheBust,
|
||||
"Challenge": challengeName,
|
||||
})
|
||||
if err != nil {
|
||||
//TODO: log
|
||||
}
|
||||
return ChallengeResultStop
|
||||
}
|
||||
c.ChallengeScriptPath = c.Path + "/challenge.mjs"
|
||||
c.ChallengeScript = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/javascript; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
params, _ := json.Marshal(p.Parameters)
|
||||
context.Background()
|
||||
|
||||
err := templates["challenge.mjs"].Execute(w, map[string]any{
|
||||
"Path": c.Path,
|
||||
"Parameters": string(params),
|
||||
"Random": cacheBust,
|
||||
"Challenge": challengeName,
|
||||
"ChallengeScript": func() string {
|
||||
if p.Asset != nil {
|
||||
return assetPath + *p.Asset
|
||||
} else if p.Url != nil {
|
||||
return *p.Url
|
||||
} else {
|
||||
panic("not implemented")
|
||||
}
|
||||
}(),
|
||||
})
|
||||
if err != nil {
|
||||
//TODO: log
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// how to runtime
|
||||
switch p.Runtime.Mode {
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown challenge runtime mode: %s", p.Runtime.Mode)
|
||||
case "":
|
||||
case "http":
|
||||
case "key":
|
||||
c.Verify = func(key []byte, result string) (bool, error) {
|
||||
resultBytes, err := hex.DecodeString(result)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare(resultBytes, key) != 1 {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
case "wasm":
|
||||
wasmData, err := go_away.ChallengeFs.ReadFile(fmt.Sprintf("challenge/%s/runtime/%s", challengeName, p.Runtime.Asset))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("c %s: could not load runtime: %w", challengeName, err)
|
||||
}
|
||||
c.RuntimeModule, err = state.WasmRuntime.CompileModule(state.WasmContext, wasmData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("c %s: compiling runtime: %w", challengeName, err)
|
||||
}
|
||||
|
||||
c.MakeChallenge = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := state.ChallengeMod(challengeName, func(ctx context.Context, mod api.Module) (err error) {
|
||||
|
||||
in := challenge.MakeChallengeInput{
|
||||
Key: state.GetChallengeKeyForRequest(challengeName, time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity), r),
|
||||
Parameters: p.Parameters,
|
||||
Headers: inline.MIMEHeader(r.Header),
|
||||
}
|
||||
in.Data, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out, err := challenge.MakeChallengeCall(state.WasmContext, mod, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set output headers
|
||||
for k, v := range out.Headers {
|
||||
w.Header()[k] = v
|
||||
}
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out.Data)))
|
||||
w.WriteHeader(out.Code)
|
||||
_, _ = w.Write(out.Data)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
c.Verify = func(key []byte, result string) (ok bool, err error) {
|
||||
err = state.ChallengeMod(challengeName, func(ctx context.Context, mod api.Module) (err error) {
|
||||
in := challenge.VerifyChallengeInput{
|
||||
Key: key,
|
||||
Parameters: p.Parameters,
|
||||
Result: []byte(result),
|
||||
}
|
||||
|
||||
out, err := challenge.VerifyChallengeCall(state.WasmContext, mod, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if out == challenge.VerifyChallengeOutputError {
|
||||
return errors.New("error checking challenge")
|
||||
}
|
||||
ok = out == challenge.VerifyChallengeOutputOK
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
state.Challenges[challengeName] = c
|
||||
}
|
||||
|
||||
state.RulesEnv, err = cel.NewEnv(
|
||||
cel.DefaultUTCTimeZone(true),
|
||||
cel.Variable("remoteAddress", cel.BytesType),
|
||||
cel.Variable("userAgent", cel.StringType),
|
||||
cel.Variable("path", cel.StringType),
|
||||
cel.Variable("query", cel.MapType(cel.StringType, cel.StringType)),
|
||||
// http.Header
|
||||
cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)),
|
||||
//TODO: dynamic type?
|
||||
cel.Function("inNetwork",
|
||||
cel.Overload("inNetwork_string_ip",
|
||||
[]*cel.Type{cel.StringType, cel.AnyType},
|
||||
cel.BoolType,
|
||||
cel.BinaryBinding(func(lhs ref.Val, rhs ref.Val) ref.Val {
|
||||
var ip net.IP
|
||||
switch v := rhs.Value().(type) {
|
||||
case []byte:
|
||||
ip = v
|
||||
case net.IP:
|
||||
ip = v
|
||||
case string:
|
||||
ip = net.ParseIP(v)
|
||||
}
|
||||
|
||||
if ip == nil {
|
||||
panic(fmt.Errorf("invalid ip %v", rhs.Value()))
|
||||
}
|
||||
|
||||
val, ok := lhs.Value().(string)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("invalid value %v", lhs.Value()))
|
||||
}
|
||||
|
||||
network, ok := state.Networks[val]
|
||||
if !ok {
|
||||
_, ipNet, err := net.ParseCIDR(val)
|
||||
if err != nil {
|
||||
panic("network not found")
|
||||
}
|
||||
return types.Bool(ipNet.Contains(ip))
|
||||
} else {
|
||||
ok, err := network.Contains(ip)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return types.Bool(ok)
|
||||
}
|
||||
}),
|
||||
),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var replacements []string
|
||||
for k, entries := range p.Conditions {
|
||||
ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, entries...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("conditions %s: error compiling conditions: %v", k, err)
|
||||
}
|
||||
|
||||
cond, err := cel.AstToString(ast)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("conditions %s: error printing condition: %v", k, err)
|
||||
}
|
||||
|
||||
replacements = append(replacements, fmt.Sprintf("($%s)", k))
|
||||
replacements = append(replacements, "("+cond+")")
|
||||
}
|
||||
conditionReplacer := strings.NewReplacer(replacements...)
|
||||
|
||||
for _, rule := range p.Rules {
|
||||
r := RuleState{
|
||||
Name: rule.Name,
|
||||
Action: policy.RuleAction(strings.ToUpper(rule.Action)),
|
||||
Challenges: rule.Challenges,
|
||||
}
|
||||
|
||||
if (r.Action == policy.RuleActionCHALLENGE || r.Action == policy.RuleActionCHECK) && len(r.Challenges) == 0 {
|
||||
return nil, fmt.Errorf("no challenges found in rule %s", rule.Name)
|
||||
}
|
||||
|
||||
// allow nesting
|
||||
var conditions []string
|
||||
for _, cond := range rule.Conditions {
|
||||
cond = conditionReplacer.Replace(cond)
|
||||
conditions = append(conditions, cond)
|
||||
}
|
||||
|
||||
ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, conditions...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rules %s: error compiling conditions: %v", rule.Name, err)
|
||||
}
|
||||
program, err := state.RulesEnv.Program(ast)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rules %s: error compiling program: %v", rule.Name, err)
|
||||
}
|
||||
r.Program = program
|
||||
|
||||
state.Rules = append(state.Rules, r)
|
||||
}
|
||||
|
||||
state.Mux = http.NewServeMux()
|
||||
|
||||
state.PublicKey, state.PrivateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = state.setupRoutes(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (state *State) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
state.Mux.ServeHTTP(w, r)
|
||||
}
|
Reference in New Issue
Block a user