Allow conditions on challenges, and early hint deadline

This commit is contained in:
WeebDataHoarder
2025-04-08 11:40:16 +02:00
parent b0ab78ef65
commit baf9df9f0a
7 changed files with 197 additions and 115 deletions

View File

@@ -116,12 +116,16 @@ challenges:
# Challenges with a redirect via Link header with rel=preload and early hints (non-JS, requires HTTP parsing, fetching and logic)
# Works on HTTP/2 and above!
self-preload-link:
# doesn't seem to work reliably on other stuff that firefox
# userAgent.contains("Firefox/") &&
condition: '"Sec-Fetch-Mode" in headers && headers["Sec-Fetch-Mode"] == "navigate"'
mode: "preload-link"
runtime:
# verifies that result = key
mode: "key"
probability: 0.1
parameters:
preload-early-hint-deadline: 3s
key-code: 200
key-mime: text/css
key-content: ""

View File

@@ -20,6 +20,18 @@ type ChallengeInformation struct {
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
}
func getRequestScheme(r *http.Request) string {
if proto := r.Header.Get("X-Forwarded-Proto"); proto == "http" || proto == "https" {
return proto
}
if r.TLS != nil {
return "https"
}
return "http"
}
func getRequestAddress(r *http.Request, clientHeader string) net.IP {
var ipStr string
if clientHeader != "" {

View File

@@ -7,6 +7,7 @@ import (
"git.gammaspectra.live/git/go-away/utils"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/cel-go/cel"
"math/rand/v2"
"net/http"
"time"
@@ -26,9 +27,10 @@ const (
type Id int
type Challenge struct {
Id Id
Name string
Path string
Id Id
Program cel.Program
Name string
Path string
Verify func(key []byte, result string, r *http.Request) (bool, error)
VerifyProbability float64
@@ -86,6 +88,7 @@ type VerifyResult int
const (
VerifyResultNONE = VerifyResult(iota)
VerifyResultFAIL
VerifyResultSKIP
// VerifyResultPASS Client just passed this challenge
VerifyResultPASS
@@ -95,7 +98,7 @@ const (
)
func (r VerifyResult) Ok() bool {
return r > VerifyResultFAIL
return r >= VerifyResultPASS
}
func (r VerifyResult) String() string {
@@ -104,6 +107,8 @@ func (r VerifyResult) String() string {
return "NONE"
case VerifyResultFAIL:
return "FAIL"
case VerifyResultSKIP:
return "SKIP"
case VerifyResultPASS:
return "PASS"
case VerifyResultOK:

69
lib/conditions.go Normal file
View File

@@ -0,0 +1,69 @@
package lib
import (
"fmt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"net"
)
func (state *State) initConditions() (err error) {
state.RulesEnv, err = cel.NewEnv(
cel.DefaultUTCTimeZone(true),
cel.Variable("remoteAddress", cel.BytesType),
cel.Variable("host", cel.StringType),
cel.Variable("method", cel.StringType),
cel.Variable("userAgent", cel.StringType),
cel.Variable("path", cel.StringType),
cel.Variable("query", cel.MapType(cel.StringType, cel.StringType)),
// http.Header
cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)),
//TODO: dynamic type?
cel.Function("inNetwork",
cel.Overload("inNetwork_string_ip",
[]*cel.Type{cel.StringType, cel.AnyType},
cel.BoolType,
cel.BinaryBinding(func(lhs ref.Val, rhs ref.Val) ref.Val {
var ip net.IP
switch v := rhs.Value().(type) {
case []byte:
ip = v
case net.IP:
ip = v
case string:
ip = net.ParseIP(v)
}
if ip == nil {
panic(fmt.Errorf("invalid ip %v", rhs.Value()))
}
val, ok := lhs.Value().(string)
if !ok {
panic(fmt.Errorf("invalid value %v", lhs.Value()))
}
network, ok := state.Networks[val]
if !ok {
_, ipNet, err := net.ParseCIDR(val)
if err != nil {
panic("network not found")
}
return types.Bool(ipNet.Contains(ip))
} else {
ok, err := network.Contains(ip)
if err != nil {
panic(err)
}
return types.Bool(ok)
}
}),
),
),
)
if err != nil {
return err
}
return nil
}

View File

@@ -159,29 +159,6 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
start := time.Now()
//TODO better matcher! combo ast?
env := map[string]any{
"host": host,
"method": r.Method,
"remoteAddress": getRequestAddress(r, state.Settings.ClientIpHeader),
"userAgent": r.UserAgent(),
"path": r.URL.Path,
"query": func() map[string]string {
result := make(map[string]string)
for k, v := range r.URL.Query() {
result[k] = strings.Join(v, ",")
}
return result
}(),
"headers": func() map[string]string {
result := make(map[string]string)
for k, v := range r.Header {
result[k] = strings.Join(v, ",")
}
return result
}(),
}
state.addTiming(w, "rule-env", "Setup the rule environment", time.Since(start))
var (
@@ -211,7 +188,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
continue
}
start = time.Now()
out, _, err := rule.Program.Eval(env)
out, _, err := rule.Program.Eval(data.ProgramEnv)
ruleEvalDuration += time.Since(start)
if err != nil {
@@ -230,7 +207,6 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
serve()
return
case policy.RuleActionCHALLENGE, policy.RuleActionCHECK:
for _, challengeId := range rule.Challenges {
if result := data.Challenges[challengeId]; !result.Ok() {
continue
@@ -249,6 +225,11 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
// none matched, issue first challenge in priority
for _, challengeId := range rule.Challenges {
result := data.Challenges[challengeId]
if result.Ok() || result == challenge.VerifyResultSKIP {
// skip already ok'd challenges for some reason, and also skip skipped challenges
continue
}
c := state.Challenges[challengeId]
if c.ServeChallenge != nil {
result := c.ServeChallenge(w, r, state.GetChallengeKeyForRequest(c.Name, data.Expires, r), data.Expires)
@@ -264,7 +245,10 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
}
state.logger(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", c.Name)
data.Challenges[c.Id] = challenge.VerifyResultOK
// set pass if caller didn't set one
if !data.Challenges[c.Id].Ok() {
data.Challenges[c.Id] = challenge.VerifyResultPASS
}
// we pass the challenge early!
lg.Debug("request passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", c.Name)
@@ -425,6 +409,27 @@ func (state *State) ServeHTTP(w http.ResponseWriter, r *http.Request) {
_, _ = rand.Read(data.Id[:])
data.Challenges = make(map[challenge.Id]challenge.VerifyResult, len(state.Challenges))
data.Expires = time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
data.ProgramEnv = map[string]any{
"host": r.Host,
"method": r.Method,
"remoteAddress": getRequestAddress(r, state.Settings.ClientIpHeader),
"userAgent": r.UserAgent(),
"path": r.URL.Path,
"query": func() map[string]string {
result := make(map[string]string)
for k, v := range r.URL.Query() {
result[k] = strings.Join(v, ",")
}
return result
}(),
"headers": func() map[string]string {
result := make(map[string]string)
for k, v := range r.Header {
result[k] = strings.Join(v, ",")
}
return result
}(),
}
for _, c := range state.Challenges {
key := state.GetChallengeKeyForRequest(c.Name, data.Expires, r)
@@ -433,6 +438,21 @@ func (state *State) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// clear invalid cookie
utils.ClearCookie(utils.CookiePrefix+c.Name, w)
}
// prevent the challenge if not solved
if !result.Ok() && c.Program != nil {
out, _, err := c.Program.Eval(data.ProgramEnv)
// verify eligibility
if err != nil {
state.logger(r).Error(err.Error(), "challenge", c.Name)
} else if out != nil && out.Type() == types.BoolType {
if out.Equal(types.True) != types.True {
// skip challenge match!
result = challenge.VerifyResultSKIP
continue
}
}
}
data.Challenges[c.Id] = result
}
@@ -449,6 +469,7 @@ func RequestDataFromContext(ctx context.Context) *RequestData {
type RequestData struct {
Id [16]byte
ProgramEnv map[string]any
Expires time.Time
Challenges map[challenge.Id]challenge.VerifyResult
}

View File

@@ -1,9 +1,10 @@
package policy
type Challenge struct {
Mode string `yaml:"mode"`
Asset *string `yaml:"asset,omitempty"`
Url *string `yaml:"url,omitempty"`
Conditions []string `yaml:"conditions"`
Mode string `yaml:"mode"`
Asset *string `yaml:"asset,omitempty"`
Url *string `yaml:"url,omitempty"`
Parameters map[string]string `json:"parameters,omitempty"`
Runtime struct {

View File

@@ -20,15 +20,12 @@ import (
"git.gammaspectra.live/git/go-away/utils"
"git.gammaspectra.live/git/go-away/utils/inline"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/tetratelabs/wazero/api"
"github.com/yl2chen/cidranger"
"html/template"
"io"
"io/fs"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
@@ -197,13 +194,56 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
state.Wasm = wasm.NewRunner(true)
err = state.initConditions()
if err != nil {
return nil, err
}
var replacements []string
for k, entries := range p.Conditions {
ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, entries...)
if err != nil {
return nil, fmt.Errorf("conditions %s: error compiling conditions: %v", k, err)
}
cond, err := cel.AstToString(ast)
if err != nil {
return nil, fmt.Errorf("conditions %s: error printing condition: %v", k, err)
}
replacements = append(replacements, fmt.Sprintf("($%s)", k))
replacements = append(replacements, "("+cond+")")
}
conditionReplacer := strings.NewReplacer(replacements...)
state.Challenges = make(map[challenge.Id]challenge.Challenge)
idCounter := challenge.Id(1)
for challengeName, p := range p.Challenges {
// allow nesting
var conditions []string
for _, cond := range p.Conditions {
cond = conditionReplacer.Replace(cond)
conditions = append(conditions, cond)
}
var program cel.Program
if len(conditions) > 0 {
ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, conditions...)
if err != nil {
return nil, fmt.Errorf("challenge %s: error compiling conditions: %v", challengeName, err)
}
program, err = state.RulesEnv.Program(ast)
if err != nil {
return nil, fmt.Errorf("challenge %s: error compiling program: %v", challengeName, err)
}
}
c := challenge.Challenge{
Id: idCounter,
Program: program,
Name: challengeName,
Path: fmt.Sprintf("%s/challenge/%s", state.UrlPath, challengeName),
VerifyProbability: p.Runtime.Probability,
@@ -383,13 +423,16 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
return challenge.ResultStop
}
case "preload-link":
deadline := time.Second * 5
deadline, _ := time.ParseDuration(p.Parameters["preload-early-hint-deadline"])
if deadline == 0 {
deadline = time.Second * 3
}
c.ServeChallenge = func(w http.ResponseWriter, r *http.Request, key []byte, expiry time.Time) challenge.Result {
// this only works on HTTP/2 and HTTP/3
if r.ProtoMajor < 2 {
// this can happen if we are an upgraded request
// this can happen if we are an upgraded request from HTTP/1.1 to HTTP/2 in H2C
if _, ok := w.(http.Pusher); !ok {
return challenge.ResultContinue
}
@@ -397,18 +440,19 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
data := RequestDataFromContext(r.Context())
redirectUri := new(url.URL)
redirectUri.Scheme = getRequestScheme(r)
redirectUri.Host = r.Host
redirectUri.Path = c.Path + "/verify-challenge"
values := make(url.Values)
values.Set("result", hex.EncodeToString(key))
values.Set("redirect", r.URL.String())
values.Set("requestId", r.Header.Get("X-Away-Id"))
redirectUri.RawQuery = values.Encode()
w.Header().Set("Link", fmt.Sprintf("<%s>; rel=preload; as=style; fetchpriority=high", redirectUri.String()))
w.Header().Set("Link", fmt.Sprintf("<%s>; rel=\"preload\"; as=\"style\"; fetchpriority=high", redirectUri.String()))
defer func() {
// remove old header header!
// remove old header so it won't show on response!
w.Header().Del("Link")
}()
w.WriteHeader(http.StatusEarlyHints)
@@ -656,80 +700,6 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
state.Challenges[c.Id] = c
}
state.RulesEnv, err = cel.NewEnv(
cel.DefaultUTCTimeZone(true),
cel.Variable("remoteAddress", cel.BytesType),
cel.Variable("host", cel.StringType),
cel.Variable("method", cel.StringType),
cel.Variable("userAgent", cel.StringType),
cel.Variable("path", cel.StringType),
cel.Variable("query", cel.MapType(cel.StringType, cel.StringType)),
// http.Header
cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)),
//TODO: dynamic type?
cel.Function("inNetwork",
cel.Overload("inNetwork_string_ip",
[]*cel.Type{cel.StringType, cel.AnyType},
cel.BoolType,
cel.BinaryBinding(func(lhs ref.Val, rhs ref.Val) ref.Val {
var ip net.IP
switch v := rhs.Value().(type) {
case []byte:
ip = v
case net.IP:
ip = v
case string:
ip = net.ParseIP(v)
}
if ip == nil {
panic(fmt.Errorf("invalid ip %v", rhs.Value()))
}
val, ok := lhs.Value().(string)
if !ok {
panic(fmt.Errorf("invalid value %v", lhs.Value()))
}
network, ok := state.Networks[val]
if !ok {
_, ipNet, err := net.ParseCIDR(val)
if err != nil {
panic("network not found")
}
return types.Bool(ipNet.Contains(ip))
} else {
ok, err := network.Contains(ip)
if err != nil {
panic(err)
}
return types.Bool(ok)
}
}),
),
),
)
if err != nil {
return nil, err
}
var replacements []string
for k, entries := range p.Conditions {
ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, entries...)
if err != nil {
return nil, fmt.Errorf("conditions %s: error compiling conditions: %v", k, err)
}
cond, err := cel.AstToString(ast)
if err != nil {
return nil, fmt.Errorf("conditions %s: error printing condition: %v", k, err)
}
replacements = append(replacements, fmt.Sprintf("($%s)", k))
replacements = append(replacements, "("+cond+")")
}
conditionReplacer := strings.NewReplacer(replacements...)
for _, rule := range p.Rules {
hasher := sha256.New()
hasher.Write([]byte(rule.Name))