condition: generalize AST compilation, hot-load network prefix blocks as needed, walk the AST to detect and preload referenced networks

Author: WeebDataHoarder
Date: 2025-05-01 02:35:27 +02:00
parent 6e47cec540
commit d6c29846df
6 changed files with 223 additions and 132 deletions
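
In short: conditions are now compiled through a single State.RegisterCondition entry point, which builds the CEL AST, verifies that the output type is bool, walks the expression tree for network()/inNetwork() calls, eagerly loads any network named by a string literal, and only then turns the AST into a program. The network prefix lists themselves become lazy: the networks map now holds sync.OnceValue thunks, so a feed is fetched on first use (or at registration time when preloaded) instead of at startup. A minimal illustration of the shapes the walk reacts to; the receiver name remoteAddress and the variable in the second entry are assumptions for the example, not taken from this commit:

package example

// Hypothetical condition strings, for illustration only: the first matches the
// member-call shape the new AST walk detects (network() with a string-literal
// argument) and causes that network's prefixes to be fetched when the condition
// is registered; the second uses a non-literal name, so the walk skips it and
// the network is loaded lazily the first time the condition evaluates it.
var exampleConditions = []string{
	`remoteAddress.network("aws-cloud")`,     // literal name: preloaded at registration
	`remoteAddress.network(selectedNetwork)`, // dynamic name: lazy-loaded on first evaluation
}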

View File

@@ -1,37 +1,37 @@
networks:
# aws-cloud:
# - url: https://ip-ranges.amazonaws.com/ip-ranges.json
# jq-path: '(.prefixes[] | select(has("ip_prefix")) | .ip_prefix), (.prefixes[] | select(has("ipv6_prefix")) | .ipv6_prefix)'
# google-cloud:
# - url: https://www.gstatic.com/ipranges/cloud.json
# jq-path: '(.prefixes[] | select(has("ipv4Prefix")) | .ipv4Prefix), (.prefixes[] | select(has("ipv6Prefix")) | .ipv6Prefix)'
# oracle-cloud:
# - url: https://docs.oracle.com/en-us/iaas/tools/public_ip_ranges.json
# jq-path: '.regions[] | .cidrs[] | .cidr'
# azure-cloud:
# # todo: https://www.microsoft.com/en-us/download/details.aspx?id=56519 does not provide direct JSON
# - url: https://raw.githubusercontent.com/femueller/cloud-ip-ranges/refs/heads/master/microsoft-azure-ip-ranges.json
# jq-path: '.values[] | .properties.addressPrefixes[]'
#
# digitalocean:
# - url: https://www.digitalocean.com/geo/google.csv
# regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+),"
# linode:
# - url: https://geoip.linode.com/
# regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+),"
# vultr:
# - url: "https://geofeed.constant.com/?json"
# jq-path: '.subnets[] | .ip_prefix'
# cloudflare:
# - url: https://www.cloudflare.com/ips-v4
# regex: "(?P<prefix>[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+/[0-9]+)"
# - url: https://www.cloudflare.com/ips-v6
# regex: "(?P<prefix>[0-9a-f:]+::/[0-9]+)"
#
# icloud-private-relay:
# - url: https://mask-api.icloud.com/egress-ip-ranges.csv
# regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+),"
# tunnelbroker-relay:
# # HE Tunnelbroker
# - url: https://tunnelbroker.net/export/google
# regex: "(?P<prefix>([0-9a-f:]+::)/[0-9]+),"
  aws-cloud:
    - url: https://ip-ranges.amazonaws.com/ip-ranges.json
      jq-path: '(.prefixes[] | select(has("ip_prefix")) | .ip_prefix), (.prefixes[] | select(has("ipv6_prefix")) | .ipv6_prefix)'
  google-cloud:
    - url: https://www.gstatic.com/ipranges/cloud.json
      jq-path: '(.prefixes[] | select(has("ipv4Prefix")) | .ipv4Prefix), (.prefixes[] | select(has("ipv6Prefix")) | .ipv6Prefix)'
  oracle-cloud:
    - url: https://docs.oracle.com/en-us/iaas/tools/public_ip_ranges.json
      jq-path: '.regions[] | .cidrs[] | .cidr'
  azure-cloud:
    # todo: https://www.microsoft.com/en-us/download/details.aspx?id=56519 does not provide direct JSON
    - url: https://raw.githubusercontent.com/femueller/cloud-ip-ranges/refs/heads/master/microsoft-azure-ip-ranges.json
      jq-path: '.values[] | .properties.addressPrefixes[]'
  digitalocean:
    - url: https://www.digitalocean.com/geo/google.csv
      regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+),"
  linode:
    - url: https://geoip.linode.com/
      regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+),"
  vultr:
    - url: "https://geofeed.constant.com/?json"
      jq-path: '.subnets[] | .ip_prefix'
  cloudflare:
    - url: https://www.cloudflare.com/ips-v4
      regex: "(?P<prefix>[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+/[0-9]+)"
    - url: https://www.cloudflare.com/ips-v6
      regex: "(?P<prefix>[0-9a-f:]+::/[0-9]+)"
  icloud-private-relay:
    - url: https://mask-api.icloud.com/egress-ip-ranges.csv
      regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+),"
  tunnelbroker-relay:
    # HE Tunnelbroker
    - url: https://tunnelbroker.net/export/google
      regex: "(?P<prefix>([0-9a-f:]+::)/[0-9]+),"

View File

@@ -11,7 +11,6 @@ import (
"github.com/go-jose/go-jose/v4/jwt"
"github.com/goccy/go-yaml/ast"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"io"
"math/rand/v2"
"net/http"
@@ -68,20 +67,10 @@ func (r Register) Create(state StateInterface, name string, pol policy.Challenge
}
if len(conditions) > 0 {
ast, err := http_cel.NewAst(state.ProgramEnv(), http_cel.OperatorOr, conditions...)
var err error
reg.Condition, err = state.RegisterCondition(http_cel.OperatorOr, conditions...)
if err != nil {
return nil, 0, fmt.Errorf("error compiling conditions: %v", err)
}
if out := ast.OutputType(); out == nil {
return nil, 0, fmt.Errorf("error compiling conditions: no output")
} else if out != types.BoolType {
return nil, 0, fmt.Errorf("error compiling conditions: output type is not bool")
}
reg.Condition, err = http_cel.ProgramAst(state.ProgramEnv(), ast)
if err != nil {
return nil, 0, fmt.Errorf("error compiling program: %v", err)
return nil, 0, fmt.Errorf("error compiling condition: %w", err)
}
}

View File

@@ -86,7 +86,7 @@ func (r VerifyResult) String() string {
}
type StateInterface interface {
ProgramEnv() *cel.Env
RegisterCondition(operator string, conditions ...string) (cel.Program, error)
Client() *http.Client
PrivateKey() ed25519.PrivateKey
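
With RegisterCondition on the interface, challenge and rule code no longer deals with ASTs directly: it hands condition strings to the state and gets back a ready cel.Program. A minimal sketch of the caller side; the one-method interface below is a narrowed stand-in for StateInterface, defined only so the example is self-contained, and the "||" operator string is an assumption standing in for http_cel.OperatorOr:

package example

import (
	"fmt"

	"github.com/google/cel-go/cel"
)

// conditionRegistrar is a narrowed stand-in for challenge.StateInterface,
// declared here only for the sketch.
type conditionRegistrar interface {
	RegisterCondition(operator string, conditions ...string) (cel.Program, error)
}

// compileOr shows the caller side after this change: one call replaces the
// earlier NewAst / OutputType check / ProgramAst sequence, and any network
// named by a string literal inside the conditions is preloaded as a side effect.
func compileOr(state conditionRegistrar, conditions ...string) (cel.Program, error) {
	program, err := state.RegisterCondition("||", conditions...)
	if err != nil {
		return nil, fmt.Errorf("error compiling condition: %w", err)
	}
	return program, nil
}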

View File

@@ -4,6 +4,7 @@ import (
http_cel "codeberg.org/gone/http-cel"
"fmt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"log/slog"
@@ -55,7 +56,7 @@ func (state *State) initConditions() (err error) {
}
return types.Bool(ipNet.Contains(ip))
} else {
ok, err := network.Contains(ip)
ok, err := network().Contains(ip)
if err != nil {
panic(err)
}
@@ -96,7 +97,7 @@ func (state *State) initConditions() (err error) {
}
return types.Bool(ipNet.Contains(ip))
} else {
ok, err := network.Contains(ip)
ok, err := network().Contains(ip)
if err != nil {
panic(err)
}
@@ -111,3 +112,113 @@ func (state *State) initConditions() (err error) {
}
return nil
}

func (state *State) RegisterCondition(operator string, conditions ...string) (cel.Program, error) {
	compiledAst, err := http_cel.NewAst(state.ProgramEnv(), operator, conditions...)
	if err != nil {
		return nil, err
	}
	if out := compiledAst.OutputType(); out == nil {
		return nil, fmt.Errorf("no output")
	} else if out != types.BoolType {
		return nil, fmt.Errorf("output type is not bool")
	}
	walkExpr(compiledAst.NativeRep().Expr(), func(e ast.Expr) {
		if e.Kind() == ast.CallKind {
			call := e.AsCall()
			switch call.FunctionName() {
			// deprecated
			case "inNetwork":
				args := call.Args()
				if call.IsMemberFunction() && len(args) == 2 {
					// we have a network select function
					switch args[1].Kind() {
					case ast.LiteralKind:
						lit := args[1].AsLiteral()
						if lit.Type() == types.StringType {
							if fn, ok := state.networks[lit.Value().(string)]; ok {
								// preload
								fn()
							}
						}
					}
				}
			case "network":
				args := call.Args()
				if call.IsMemberFunction() && len(args) == 1 {
					// we have a network select function
					switch args[0].Kind() {
					case ast.LiteralKind:
						lit := args[0].AsLiteral()
						if lit.Type() == types.StringType {
							if fn, ok := state.networks[lit.Value().(string)]; ok {
								// preload
								fn()
							}
						}
					}
				}
			}
		}
	})
	return http_cel.ProgramAst(state.ProgramEnv(), compiledAst)
}

func walkExpr(e ast.Expr, fn func(ast.Expr)) {
	fn(e)
	switch e.Kind() {
	case ast.CallKind:
		ee := e.AsCall()
		walkExpr(ee.Target(), fn)
		for _, arg := range ee.Args() {
			walkExpr(arg, fn)
		}
	case ast.ComprehensionKind:
		ee := e.AsComprehension()
		walkExpr(ee.Result(), fn)
		walkExpr(ee.IterRange(), fn)
		walkExpr(ee.AccuInit(), fn)
		walkExpr(ee.LoopCondition(), fn)
		walkExpr(ee.LoopStep(), fn)
	case ast.ListKind:
		ee := e.AsList()
		for _, element := range ee.Elements() {
			walkExpr(element, fn)
		}
	case ast.MapKind:
		ee := e.AsMap()
		for _, entry := range ee.Entries() {
			switch entry.Kind() {
			case ast.MapEntryKind:
				eee := entry.AsMapEntry()
				walkExpr(eee.Key(), fn)
				walkExpr(eee.Value(), fn)
			case ast.StructFieldKind:
				eee := entry.AsStructField()
				walkExpr(eee.Value(), fn)
			}
		}
	case ast.SelectKind:
		ee := e.AsSelect()
		walkExpr(ee.Operand(), fn)
	case ast.StructKind:
		ee := e.AsStruct()
		for _, field := range ee.Fields() {
			switch field.Kind() {
			case ast.MapEntryKind:
				eee := field.AsMapEntry()
				walkExpr(eee.Key(), fn)
				walkExpr(eee.Value(), fn)
			case ast.StructFieldKind:
				eee := field.AsStructField()
				walkExpr(eee.Value(), fn)
			}
		}
	}
}
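
Only literal network names can be preloaded this way; a name computed at evaluation time is invisible to the walk and falls back to the lazy loader. The traversal itself is generic: given the native representation of any compiled expression it visits every sub-expression. A toy standalone version, assuming only the cel-go API already used above:

package example

import (
	"github.com/google/cel-go/cel"
	"github.com/google/cel-go/common/ast"
)

// listCalls compiles an expression in a stock CEL environment and reports the
// function name of every call node it contains, using the same kind of
// traversal as walkExpr above. CEL operators such as || and == also appear as
// call nodes. The string variable x exists only so the toy compiles something.
func listCalls(expr string) ([]string, error) {
	env, err := cel.NewEnv(cel.Variable("x", cel.StringType))
	if err != nil {
		return nil, err
	}
	compiled, iss := env.Compile(expr)
	if iss != nil && iss.Err() != nil {
		return nil, iss.Err()
	}
	var calls []string
	walk(compiled.NativeRep().Expr(), func(e ast.Expr) {
		if e.Kind() == ast.CallKind {
			calls = append(calls, e.AsCall().FunctionName())
		}
	})
	return calls, nil
}

// walk mirrors walkExpr, trimmed to the node kinds this toy can produce.
func walk(e ast.Expr, fn func(ast.Expr)) {
	fn(e)
	switch e.Kind() {
	case ast.CallKind:
		c := e.AsCall()
		walk(c.Target(), fn)
		for _, arg := range c.Args() {
			walk(arg, fn)
		}
	case ast.SelectKind:
		walk(e.AsSelect().Operand(), fn)
	case ast.ListKind:
		for _, element := range e.AsList().Elements() {
			walk(element, fn)
		}
	}
}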

View File

@@ -66,20 +66,9 @@ func NewRuleState(state challenge.StateInterface, r policy.Rule, replacer *strin
conditions = append(conditions, cond)
}
ast, err := http_cel.NewAst(state.ProgramEnv(), http_cel.OperatorOr, conditions...)
program, err := state.RegisterCondition(http_cel.OperatorOr, conditions...)
if err != nil {
return RuleState{}, fmt.Errorf("error compiling conditions: %w", err)
}
if out := ast.OutputType(); out == nil {
return RuleState{}, fmt.Errorf("error compiling conditions: no output")
} else if out != types.BoolType {
return RuleState{}, fmt.Errorf("error compiling conditions: output type is not bool")
}
program, err := http_cel.ProgramAst(state.ProgramEnv(), ast)
if err != nil {
return RuleState{}, fmt.Errorf("error compiling program: %w", err)
return RuleState{}, fmt.Errorf("error compiling condition: %w", err)
}
rule.Condition = program
}

View File

@@ -24,6 +24,7 @@ import (
"path"
"strconv"
"strings"
"sync"
"time"
)
@@ -40,7 +41,7 @@ type State struct {
opt settings.Settings
settings policy.StateSettings
networks map[string]cidranger.Ranger
networks map[string]func() cidranger.Ranger
challenges challenge.Register
@@ -54,6 +55,7 @@ type State struct {
}
func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (handler http.Handler, err error) {
state := new(State)
state.close = make(chan struct{})
state.settings = settings
@@ -114,89 +116,89 @@ func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSetti
return nil, fmt.Errorf("no template defined for %s", state.opt.ChallengeTemplate)
}
state.networks = make(map[string]cidranger.Ranger)
state.networks = make(map[string]func() cidranger.Ranger)
networkCache := utils.CachePrefix(state.Settings().Cache, "networks/")
for k, network := range p.Networks {
state.networks[k] = sync.OnceValue[cidranger.Ranger](func() cidranger.Ranger {
ranger := cidranger.NewPCTrieRanger()
for i, e := range network {
prefixes, err := func() ([]net.IPNet, error) {
var useCache bool
ranger := cidranger.NewPCTrieRanger()
for i, e := range network {
prefixes, err := func() ([]net.IPNet, error) {
var useCache bool
cacheKey := fmt.Sprintf("%s-%d-", k, i)
if e.Url != nil {
slog.Debug("loading network url list", "network", k, "url", *e.Url)
useCache = true
sum := sha256.Sum256([]byte(*e.Url))
cacheKey += hex.EncodeToString(sum[:4])
} else if e.ASN != nil {
slog.Debug("loading ASN", "network", k, "asn", *e.ASN)
useCache = true
cacheKey += strconv.FormatInt(int64(*e.ASN), 10)
}
cacheKey := fmt.Sprintf("%s-%d-", k, i)
if e.Url != nil {
slog.Debug("loading network url list", "network", k, "url", *e.Url)
useCache = true
sum := sha256.Sum256([]byte(*e.Url))
cacheKey += hex.EncodeToString(sum[:4])
} else if e.ASN != nil {
slog.Debug("loading ASN", "network", k, "asn", *e.ASN)
useCache = true
cacheKey += strconv.FormatInt(int64(*e.ASN), 10)
}
var cached []net.IPNet
if useCache && networkCache != nil {
//TODO: add randomness
cachedData, err := networkCache.Get(cacheKey, time.Hour*24)
var l []string
_ = json.Unmarshal(cachedData, &l)
for _, n := range l {
_, ipNet, err := net.ParseCIDR(n)
var cached []net.IPNet
if useCache && networkCache != nil {
//TODO: add randomness
cachedData, err := networkCache.Get(cacheKey, time.Hour*24)
var l []string
_ = json.Unmarshal(cachedData, &l)
for _, n := range l {
_, ipNet, err := net.ParseCIDR(n)
if err == nil {
cached = append(cached, *ipNet)
}
}
if err == nil {
cached = append(cached, *ipNet)
// use
return cached, nil
}
}
if err == nil {
// use
return cached, nil
prefixes, err := e.FetchPrefixes(state.client, state.radb)
if err != nil {
if len(cached) > 0 {
// use cached meanwhile
return cached, err
}
return nil, err
}
}
prefixes, err := e.FetchPrefixes(state.client, state.radb)
if useCache && networkCache != nil {
var l []string
for _, n := range prefixes {
l = append(l, n.String())
}
cachedData, err := json.Marshal(l)
if err == nil {
_ = networkCache.Set(cacheKey, cachedData)
}
}
return prefixes, nil
}()
if err != nil {
if len(cached) > 0 {
// use cached meanwhile
return cached, err
if e.Url != nil {
slog.Error("error loading network list", "network", k, "url", *e.Url, "error", err)
} else if e.ASN != nil {
slog.Error("error loading ASN", "network", k, "asn", *e.ASN, "error", err)
} else {
slog.Error("error loading list", "network", k, "error", err)
}
return nil, err
continue
}
if useCache && networkCache != nil {
var l []string
for _, n := range prefixes {
l = append(l, n.String())
for _, prefix := range prefixes {
err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix))
if err != nil {
slog.Error("error inserting prefix", "network", k, "prefix", prefix.String(), "error", err)
}
cachedData, err := json.Marshal(l)
if err == nil {
_ = networkCache.Set(cacheKey, cachedData)
}
}
return prefixes, nil
}()
if err != nil {
if e.Url != nil {
slog.Error("error loading network list", "network", k, "url", *e.Url, "error", err)
} else if e.ASN != nil {
slog.Error("error loading ASN", "network", k, "asn", *e.ASN, "error", err)
} else {
slog.Error("error loading list", "network", k, "error", err)
}
continue
}
for _, prefix := range prefixes {
err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix))
if err != nil {
return nil, fmt.Errorf("networks %s: error inserting prefix %s: %v", k, prefix.String(), err)
}
}
}
slog.Warn("loaded network prefixes", "network", k, "count", ranger.Len())
state.networks[k] = ranger
slog.Warn("loaded network prefixes", "network", k, "count", ranger.Len())
return ranger
})
}
err = state.initConditions()
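
The per-network closure above is wrapped in sync.OnceValue, so the fetch/cache/insert work runs at most once, on the first call, whether that call comes from the registration-time preload or from a condition evaluating network() on a live request; every later call returns the same memoized ranger. A self-contained sketch of the pattern, where the hard-coded prefixes stand in for whatever FetchPrefixes or the cache would return and the upstream cidranger import path is an assumption:

package example

import (
	"net"
	"sync"

	"github.com/yl2chen/cidranger"
)

// lazyRanger builds the prefix trie on first use only; concurrent callers all
// receive the same memoized Ranger.
var lazyRanger = sync.OnceValue(func() cidranger.Ranger {
	ranger := cidranger.NewPCTrieRanger()
	for _, cidr := range []string{"192.0.2.0/24", "2001:db8::/32"} {
		_, ipNet, err := net.ParseCIDR(cidr)
		if err != nil {
			continue
		}
		_ = ranger.Insert(cidranger.NewBasicRangerEntry(*ipNet))
	}
	return ranger
})

// inExampleNetwork is how a condition helper would consume the thunk: the first
// call builds the trie, every later call reuses it.
func inExampleNetwork(ip net.IP) bool {
	ok, err := lazyRanger().Contains(ip)
	if err != nil {
		return false
	}
	return ok
}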