From d6c29846dffc6a3f3f4a266b94f2113f675105fd Mon Sep 17 00:00:00 2001 From: WeebDataHoarder Date: Thu, 1 May 2025 02:35:27 +0200 Subject: [PATCH] condition: generalize AST compilation, hot load network prefix blocks as needed, walk the AST and detect and preload networks --- examples/snippets/networks-other.yml | 72 +++++++------- lib/challenge/register.go | 17 +--- lib/challenge/types.go | 2 +- lib/conditions.go | 115 ++++++++++++++++++++++- lib/rule.go | 15 +-- lib/state.go | 134 ++++++++++++++------------- 6 files changed, 223 insertions(+), 132 deletions(-) diff --git a/examples/snippets/networks-other.yml b/examples/snippets/networks-other.yml index 3952e1f..c5ffdda 100644 --- a/examples/snippets/networks-other.yml +++ b/examples/snippets/networks-other.yml @@ -1,37 +1,37 @@ networks: -# aws-cloud: -# - url: https://ip-ranges.amazonaws.com/ip-ranges.json -# jq-path: '(.prefixes[] | select(has("ip_prefix")) | .ip_prefix), (.prefixes[] | select(has("ipv6_prefix")) | .ipv6_prefix)' -# google-cloud: -# - url: https://www.gstatic.com/ipranges/cloud.json -# jq-path: '(.prefixes[] | select(has("ipv4Prefix")) | .ipv4Prefix), (.prefixes[] | select(has("ipv6Prefix")) | .ipv6Prefix)' -# oracle-cloud: -# - url: https://docs.oracle.com/en-us/iaas/tools/public_ip_ranges.json -# jq-path: '.regions[] | .cidrs[] | .cidr' -# azure-cloud: -# # todo: https://www.microsoft.com/en-us/download/details.aspx?id=56519 does not provide direct JSON -# - url: https://raw.githubusercontent.com/femueller/cloud-ip-ranges/refs/heads/master/microsoft-azure-ip-ranges.json -# jq-path: '.values[] | .properties.addressPrefixes[]' -# -# digitalocean: -# - url: https://www.digitalocean.com/geo/google.csv -# regex: "(?P(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+)," -# linode: -# - url: https://geoip.linode.com/ -# regex: "(?P(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+)," -# vultr: -# - url: "https://geofeed.constant.com/?json" -# jq-path: '.subnets[] | .ip_prefix' -# cloudflare: -# - url: https://www.cloudflare.com/ips-v4 -# regex: "(?P[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+/[0-9]+)" -# - url: https://www.cloudflare.com/ips-v6 -# regex: "(?P[0-9a-f:]+::/[0-9]+)" -# -# icloud-private-relay: -# - url: https://mask-api.icloud.com/egress-ip-ranges.csv -# regex: "(?P(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+)," -# tunnelbroker-relay: -# # HE Tunnelbroker -# - url: https://tunnelbroker.net/export/google -# regex: "(?P([0-9a-f:]+::)/[0-9]+)," + aws-cloud: + - url: https://ip-ranges.amazonaws.com/ip-ranges.json + jq-path: '(.prefixes[] | select(has("ip_prefix")) | .ip_prefix), (.prefixes[] | select(has("ipv6_prefix")) | .ipv6_prefix)' + google-cloud: + - url: https://www.gstatic.com/ipranges/cloud.json + jq-path: '(.prefixes[] | select(has("ipv4Prefix")) | .ipv4Prefix), (.prefixes[] | select(has("ipv6Prefix")) | .ipv6Prefix)' + oracle-cloud: + - url: https://docs.oracle.com/en-us/iaas/tools/public_ip_ranges.json + jq-path: '.regions[] | .cidrs[] | .cidr' + azure-cloud: + # todo: https://www.microsoft.com/en-us/download/details.aspx?id=56519 does not provide direct JSON + - url: https://raw.githubusercontent.com/femueller/cloud-ip-ranges/refs/heads/master/microsoft-azure-ip-ranges.json + jq-path: '.values[] | .properties.addressPrefixes[]' + + digitalocean: + - url: https://www.digitalocean.com/geo/google.csv + regex: "(?P(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+)," + linode: + - url: https://geoip.linode.com/ + regex: "(?P(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+)," + vultr: + - url: "https://geofeed.constant.com/?json" + jq-path: '.subnets[] | .ip_prefix' + cloudflare: + - url: https://www.cloudflare.com/ips-v4 + regex: "(?P[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+/[0-9]+)" + - url: https://www.cloudflare.com/ips-v6 + regex: "(?P[0-9a-f:]+::/[0-9]+)" + + icloud-private-relay: + - url: https://mask-api.icloud.com/egress-ip-ranges.csv + regex: "(?P(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+)," + tunnelbroker-relay: + # HE Tunnelbroker + - url: https://tunnelbroker.net/export/google + regex: "(?P([0-9a-f:]+::)/[0-9]+)," diff --git a/lib/challenge/register.go b/lib/challenge/register.go index a92c336..353206b 100644 --- a/lib/challenge/register.go +++ b/lib/challenge/register.go @@ -11,7 +11,6 @@ import ( "github.com/go-jose/go-jose/v4/jwt" "github.com/goccy/go-yaml/ast" "github.com/google/cel-go/cel" - "github.com/google/cel-go/common/types" "io" "math/rand/v2" "net/http" @@ -68,20 +67,10 @@ func (r Register) Create(state StateInterface, name string, pol policy.Challenge } if len(conditions) > 0 { - ast, err := http_cel.NewAst(state.ProgramEnv(), http_cel.OperatorOr, conditions...) + var err error + reg.Condition, err = state.RegisterCondition(http_cel.OperatorOr, conditions...) if err != nil { - return nil, 0, fmt.Errorf("error compiling conditions: %v", err) - } - - if out := ast.OutputType(); out == nil { - return nil, 0, fmt.Errorf("error compiling conditions: no output") - } else if out != types.BoolType { - return nil, 0, fmt.Errorf("error compiling conditions: output type is not bool") - } - - reg.Condition, err = http_cel.ProgramAst(state.ProgramEnv(), ast) - if err != nil { - return nil, 0, fmt.Errorf("error compiling program: %v", err) + return nil, 0, fmt.Errorf("error compiling condition: %w", err) } } diff --git a/lib/challenge/types.go b/lib/challenge/types.go index 570fe58..51a3a75 100644 --- a/lib/challenge/types.go +++ b/lib/challenge/types.go @@ -86,7 +86,7 @@ func (r VerifyResult) String() string { } type StateInterface interface { - ProgramEnv() *cel.Env + RegisterCondition(operator string, conditions ...string) (cel.Program, error) Client() *http.Client PrivateKey() ed25519.PrivateKey diff --git a/lib/conditions.go b/lib/conditions.go index 5e94b05..bdfe8d9 100644 --- a/lib/conditions.go +++ b/lib/conditions.go @@ -4,6 +4,7 @@ import ( http_cel "codeberg.org/gone/http-cel" "fmt" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "log/slog" @@ -55,7 +56,7 @@ func (state *State) initConditions() (err error) { } return types.Bool(ipNet.Contains(ip)) } else { - ok, err := network.Contains(ip) + ok, err := network().Contains(ip) if err != nil { panic(err) } @@ -96,7 +97,7 @@ func (state *State) initConditions() (err error) { } return types.Bool(ipNet.Contains(ip)) } else { - ok, err := network.Contains(ip) + ok, err := network().Contains(ip) if err != nil { panic(err) } @@ -111,3 +112,113 @@ func (state *State) initConditions() (err error) { } return nil } + +func (state *State) RegisterCondition(operator string, conditions ...string) (cel.Program, error) { + compiledAst, err := http_cel.NewAst(state.ProgramEnv(), operator, conditions...) + if err != nil { + return nil, err + } + + if out := compiledAst.OutputType(); out == nil { + return nil, fmt.Errorf("no output") + } else if out != types.BoolType { + return nil, fmt.Errorf("output type is not bool") + } + + walkExpr(compiledAst.NativeRep().Expr(), func(e ast.Expr) { + if e.Kind() == ast.CallKind { + call := e.AsCall() + switch call.FunctionName() { + // deprecated + case "inNetwork": + args := call.Args() + if call.IsMemberFunction() && len(args) == 2 { + // we have a network select function + switch args[1].Kind() { + case ast.LiteralKind: + lit := args[1].AsLiteral() + if lit.Type() == types.StringType { + if fn, ok := state.networks[lit.Value().(string)]; ok { + // preload + fn() + } + } + } + + } + case "network": + args := call.Args() + if call.IsMemberFunction() && len(args) == 1 { + // we have a network select function + switch args[0].Kind() { + case ast.LiteralKind: + lit := args[0].AsLiteral() + if lit.Type() == types.StringType { + if fn, ok := state.networks[lit.Value().(string)]; ok { + // preload + fn() + } + } + } + + } + } + } + }) + + return http_cel.ProgramAst(state.ProgramEnv(), compiledAst) +} + +func walkExpr(e ast.Expr, fn func(ast.Expr)) { + fn(e) + + switch e.Kind() { + case ast.CallKind: + ee := e.AsCall() + walkExpr(ee.Target(), fn) + for _, arg := range ee.Args() { + walkExpr(arg, fn) + } + case ast.ComprehensionKind: + ee := e.AsComprehension() + walkExpr(ee.Result(), fn) + walkExpr(ee.IterRange(), fn) + walkExpr(ee.AccuInit(), fn) + walkExpr(ee.LoopCondition(), fn) + walkExpr(ee.LoopStep(), fn) + case ast.ListKind: + ee := e.AsList() + for _, element := range ee.Elements() { + walkExpr(element, fn) + } + case ast.MapKind: + ee := e.AsMap() + for _, entry := range ee.Entries() { + switch entry.Kind() { + case ast.MapEntryKind: + eee := entry.AsMapEntry() + walkExpr(eee.Key(), fn) + walkExpr(eee.Value(), fn) + case ast.StructFieldKind: + eee := entry.AsStructField() + walkExpr(eee.Value(), fn) + } + } + case ast.SelectKind: + ee := e.AsSelect() + walkExpr(ee.Operand(), fn) + case ast.StructKind: + ee := e.AsStruct() + for _, field := range ee.Fields() { + switch field.Kind() { + case ast.MapEntryKind: + eee := field.AsMapEntry() + walkExpr(eee.Key(), fn) + walkExpr(eee.Value(), fn) + case ast.StructFieldKind: + eee := field.AsStructField() + walkExpr(eee.Value(), fn) + } + } + } +} diff --git a/lib/rule.go b/lib/rule.go index c43f41b..88c37b7 100644 --- a/lib/rule.go +++ b/lib/rule.go @@ -66,20 +66,9 @@ func NewRuleState(state challenge.StateInterface, r policy.Rule, replacer *strin conditions = append(conditions, cond) } - ast, err := http_cel.NewAst(state.ProgramEnv(), http_cel.OperatorOr, conditions...) + program, err := state.RegisterCondition(http_cel.OperatorOr, conditions...) if err != nil { - return RuleState{}, fmt.Errorf("error compiling conditions: %w", err) - } - - if out := ast.OutputType(); out == nil { - return RuleState{}, fmt.Errorf("error compiling conditions: no output") - } else if out != types.BoolType { - return RuleState{}, fmt.Errorf("error compiling conditions: output type is not bool") - } - - program, err := http_cel.ProgramAst(state.ProgramEnv(), ast) - if err != nil { - return RuleState{}, fmt.Errorf("error compiling program: %w", err) + return RuleState{}, fmt.Errorf("error compiling condition: %w", err) } rule.Condition = program } diff --git a/lib/state.go b/lib/state.go index 1524262..8c16eab 100644 --- a/lib/state.go +++ b/lib/state.go @@ -24,6 +24,7 @@ import ( "path" "strconv" "strings" + "sync" "time" ) @@ -40,7 +41,7 @@ type State struct { opt settings.Settings settings policy.StateSettings - networks map[string]cidranger.Ranger + networks map[string]func() cidranger.Ranger challenges challenge.Register @@ -54,6 +55,7 @@ type State struct { } func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (handler http.Handler, err error) { + state := new(State) state.close = make(chan struct{}) state.settings = settings @@ -114,89 +116,89 @@ func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSetti return nil, fmt.Errorf("no template defined for %s", state.opt.ChallengeTemplate) } - state.networks = make(map[string]cidranger.Ranger) + state.networks = make(map[string]func() cidranger.Ranger) networkCache := utils.CachePrefix(state.Settings().Cache, "networks/") for k, network := range p.Networks { + state.networks[k] = sync.OnceValue[cidranger.Ranger](func() cidranger.Ranger { + ranger := cidranger.NewPCTrieRanger() + for i, e := range network { + prefixes, err := func() ([]net.IPNet, error) { + var useCache bool - ranger := cidranger.NewPCTrieRanger() - for i, e := range network { - prefixes, err := func() ([]net.IPNet, error) { - var useCache bool + cacheKey := fmt.Sprintf("%s-%d-", k, i) + if e.Url != nil { + slog.Debug("loading network url list", "network", k, "url", *e.Url) + useCache = true + sum := sha256.Sum256([]byte(*e.Url)) + cacheKey += hex.EncodeToString(sum[:4]) + } else if e.ASN != nil { + slog.Debug("loading ASN", "network", k, "asn", *e.ASN) + useCache = true + cacheKey += strconv.FormatInt(int64(*e.ASN), 10) + } - cacheKey := fmt.Sprintf("%s-%d-", k, i) - if e.Url != nil { - slog.Debug("loading network url list", "network", k, "url", *e.Url) - useCache = true - sum := sha256.Sum256([]byte(*e.Url)) - cacheKey += hex.EncodeToString(sum[:4]) - } else if e.ASN != nil { - slog.Debug("loading ASN", "network", k, "asn", *e.ASN) - useCache = true - cacheKey += strconv.FormatInt(int64(*e.ASN), 10) - } - - var cached []net.IPNet - if useCache && networkCache != nil { - //TODO: add randomness - cachedData, err := networkCache.Get(cacheKey, time.Hour*24) - var l []string - _ = json.Unmarshal(cachedData, &l) - for _, n := range l { - _, ipNet, err := net.ParseCIDR(n) + var cached []net.IPNet + if useCache && networkCache != nil { + //TODO: add randomness + cachedData, err := networkCache.Get(cacheKey, time.Hour*24) + var l []string + _ = json.Unmarshal(cachedData, &l) + for _, n := range l { + _, ipNet, err := net.ParseCIDR(n) + if err == nil { + cached = append(cached, *ipNet) + } + } if err == nil { - cached = append(cached, *ipNet) + // use + return cached, nil + } } - if err == nil { - // use - return cached, nil + prefixes, err := e.FetchPrefixes(state.client, state.radb) + if err != nil { + if len(cached) > 0 { + // use cached meanwhile + return cached, err + } + return nil, err } - } - - prefixes, err := e.FetchPrefixes(state.client, state.radb) + if useCache && networkCache != nil { + var l []string + for _, n := range prefixes { + l = append(l, n.String()) + } + cachedData, err := json.Marshal(l) + if err == nil { + _ = networkCache.Set(cacheKey, cachedData) + } + } + return prefixes, nil + }() if err != nil { - if len(cached) > 0 { - // use cached meanwhile - return cached, err + if e.Url != nil { + slog.Error("error loading network list", "network", k, "url", *e.Url, "error", err) + } else if e.ASN != nil { + slog.Error("error loading ASN", "network", k, "asn", *e.ASN, "error", err) + } else { + slog.Error("error loading list", "network", k, "error", err) } - return nil, err + continue } - if useCache && networkCache != nil { - var l []string - for _, n := range prefixes { - l = append(l, n.String()) + for _, prefix := range prefixes { + err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix)) + if err != nil { + slog.Error("error inserting prefix", "network", k, "prefix", prefix.String(), "error", err) } - cachedData, err := json.Marshal(l) - if err == nil { - _ = networkCache.Set(cacheKey, cachedData) - } - } - return prefixes, nil - }() - if err != nil { - if e.Url != nil { - slog.Error("error loading network list", "network", k, "url", *e.Url, "error", err) - } else if e.ASN != nil { - slog.Error("error loading ASN", "network", k, "asn", *e.ASN, "error", err) - } else { - slog.Error("error loading list", "network", k, "error", err) - } - continue - } - for _, prefix := range prefixes { - err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix)) - if err != nil { - return nil, fmt.Errorf("networks %s: error inserting prefix %s: %v", k, prefix.String(), err) } } - } - slog.Warn("loaded network prefixes", "network", k, "count", ranger.Len()) - - state.networks[k] = ranger + slog.Warn("loaded network prefixes", "network", k, "count", ranger.Len()) + return ranger + }) } err = state.initConditions()