Implement cache for networks

This commit is contained in:
WeebDataHoarder
2025-04-23 20:12:02 +02:00
parent a0224cb21c
commit 6bb7ca979d
6 changed files with 164 additions and 13 deletions

View File

@@ -298,7 +298,7 @@ However, a few points are left before go-away can be called v1.0.0:
* [ ] Replace Anubis-like default template with own one. * [ ] Replace Anubis-like default template with own one.
* [ ] Define strings and multi-language support for quick modification by operators without custom templates. * [ ] Define strings and multi-language support for quick modification by operators without custom templates.
* [ ] Have highly tested paths that match examples. * [ ] Have highly tested paths that match examples.
* [ ] Caching of temporary fetches, for example, network ranges. * [x] Caching of temporary fetches, for example, network ranges.
* [x] Allow live and dynamic policy reloading. * [x] Allow live and dynamic policy reloading.
* [x] Multiple domains / subdomains -> one backend handling, CEL rules for backends * [x] Multiple domains / subdomains -> one backend handling, CEL rules for backends
* [ ] Merge all rules and conditions into one large AST for higher performance. * [ ] Merge all rules and conditions into one large AST for higher performance.

View File

@@ -223,11 +223,23 @@ func main() {
log.Fatal(fmt.Errorf("no backends defined in policy file")) log.Fatal(fmt.Errorf("no backends defined in policy file"))
} }
var cache utils.Cache
if *cachePath != "" { if *cachePath != "" {
err = os.MkdirAll(*cachePath, 0755) err = os.MkdirAll(*cachePath, 0755)
if err != nil { if err != nil {
log.Fatal(fmt.Errorf("failed to create cache directory: %w", err)) log.Fatal(fmt.Errorf("failed to create cache directory: %w", err))
} }
for _, n := range []string{"networks", "acme"} {
err = os.MkdirAll(path.Join(*cachePath, n), 0755)
if err != nil {
log.Fatal(fmt.Errorf("failed to create cache sub directory %s: %w", n, err))
}
}
cache, err = utils.CacheDirectory(*cachePath)
if err != nil {
log.Fatal(fmt.Errorf("failed to open cache directory: %w", err))
}
} }
var tlsConfig *tls.Config var tlsConfig *tls.Config
@@ -293,6 +305,7 @@ func main() {
} }
settings := policy.Settings{ settings := policy.Settings{
Cache: cache,
Backends: createdBackends, Backends: createdBackends,
Debug: *debugMode, Debug: *debugMode,
PackageName: *packageName, PackageName: *packageName,

View File

@@ -27,6 +27,7 @@ type Network struct {
} }
func (n Network) FetchPrefixes(c *http.Client, whois *utils.RADb) (output []net.IPNet, err error) { func (n Network) FetchPrefixes(c *http.Client, whois *utils.RADb) (output []net.IPNet, err error) {
if len(n.Prefixes) > 0 { if len(n.Prefixes) > 0 {
for _, prefix := range n.Prefixes { for _, prefix := range n.Prefixes {
ipNet, err := parseCIDROrIP(prefix) ipNet, err := parseCIDROrIP(prefix)

View File

@@ -1,10 +1,12 @@
package policy package policy
import ( import (
"git.gammaspectra.live/git/go-away/utils"
"net/http" "net/http"
) )
type Settings struct { type Settings struct {
Cache utils.Cache
Backends map[string]http.Handler Backends map[string]http.Handler
PrivateKeySeed []byte PrivateKeySeed []byte
Debug bool Debug bool

View File

@@ -3,6 +3,7 @@ package lib
import ( import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"encoding/json"
"fmt" "fmt"
"git.gammaspectra.live/git/go-away/lib/challenge" "git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/condition" "git.gammaspectra.live/git/go-away/lib/condition"
@@ -11,11 +12,13 @@ import (
"github.com/google/cel-go/cel" "github.com/google/cel-go/cel"
"github.com/yl2chen/cidranger" "github.com/yl2chen/cidranger"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"os" "os"
"path" "path"
"strings" "strings"
"time"
) )
type State struct { type State struct {
@@ -105,26 +108,72 @@ func NewState(p policy.Policy, settings policy.Settings) (handler http.Handler,
} }
state.networks = make(map[string]cidranger.Ranger) state.networks = make(map[string]cidranger.Ranger)
networkCache := utils.CachePrefix(state.Settings().Cache, "networks/")
for k, network := range p.Networks { for k, network := range p.Networks {
ranger := cidranger.NewPCTrieRanger() ranger := cidranger.NewPCTrieRanger()
for _, e := range network { for i, e := range network {
if e.Url != nil { prefixes, err := func() ([]net.IPNet, error) {
slog.Debug("loading network url list", "network", k, "url", *e.Url) var useCache bool
} if e.Url != nil {
if e.ASN != nil { slog.Debug("loading network url list", "network", k, "url", *e.Url)
slog.Debug("loading ASN", "network", k, "asn", *e.ASN) useCache = true
} } else if e.ASN != nil {
prefixes, err := e.FetchPrefixes(state.client, state.radb) slog.Debug("loading ASN", "network", k, "asn", *e.ASN)
if err != nil { useCache = true
slog.Error("error fetching network list", "network", k, "url", *e.Url) }
continue
} cacheKey := fmt.Sprintf("%s-%d", k, i)
var cached []net.IPNet
if useCache && networkCache != nil {
cachedData, err := networkCache.Get(cacheKey, time.Hour*24)
var l []string
_ = json.Unmarshal(cachedData, &l)
for _, n := range l {
_, ipNet, err := net.ParseCIDR(n)
if err == nil {
cached = append(cached, *ipNet)
}
}
if err == nil {
// use
return cached, nil
}
}
prefixes, err := e.FetchPrefixes(state.client, state.radb)
if err != nil {
if len(cached) > 0 {
// use cached meanwhile
return cached, err
}
return nil, err
}
if useCache && networkCache != nil {
var l []string
for _, n := range prefixes {
l = append(l, n.String())
}
cachedData, err := json.Marshal(l)
if err == nil {
_ = networkCache.Set(cacheKey, cachedData)
}
}
return prefixes, nil
}()
for _, prefix := range prefixes { for _, prefix := range prefixes {
err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix)) err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix))
if err != nil { if err != nil {
return nil, fmt.Errorf("networks %s: error inserting prefix %s: %v", k, prefix.String(), err) return nil, fmt.Errorf("networks %s: error inserting prefix %s: %v", k, prefix.String(), err)
} }
} }
if err != nil {
slog.Error("error loading network list", "network", k, "url", *e.Url, "error", err)
continue
}
} }
slog.Warn("loaded network prefixes", "network", k, "count", ranger.Len()) slog.Warn("loaded network prefixes", "network", k, "count", ranger.Len())

86
utils/cache.go Normal file
View File

@@ -0,0 +1,86 @@
package utils
import (
"errors"
"os"
"path"
"time"
)
type Cache interface {
Get(key string, maxAge time.Duration) ([]byte, error)
Set(key string, value []byte) error
}
func CachePrefix(c Cache, prefix string) Cache {
if c == nil {
return nil
}
return prefixCache{
c: c,
prefix: prefix,
}
}
func CacheDirectory(directory string) (Cache, error) {
if stat, err := os.Stat(directory); err != nil {
return nil, err
} else if !stat.IsDir() {
return nil, errors.New("not a directory")
}
return dirCache(directory), nil
}
type prefixCache struct {
c Cache
prefix string
}
func (c prefixCache) Get(key string, maxAge time.Duration) ([]byte, error) {
return c.c.Get(c.prefix+key, maxAge)
}
func (c prefixCache) Set(key string, value []byte) error {
return c.c.Set(c.prefix+key, value)
}
type dirCache string
var ErrExpired = errors.New("key expired")
func (d dirCache) Get(key string, maxAge time.Duration) ([]byte, error) {
fname := path.Join(string(d), key)
stat, err := os.Stat(fname)
if err != nil {
return nil, err
}
if stat.IsDir() {
return nil, errors.New("key is directory")
}
data, err := os.ReadFile(fname)
if err != nil {
return nil, err
}
if stat.ModTime().Before(time.Now().Add(-maxAge)) {
return data, ErrExpired
} else {
return data, nil
}
}
func (d dirCache) Set(key string, value []byte) error {
fname := path.Join(string(d), key)
fs, err := os.Create(fname)
if err != nil {
return err
}
defer fs.Close()
_, err = fs.Write(value)
fs.Sync()
fs.Close()
_ = os.Chtimes(fname, time.Time{}, time.Now())
return err
}