Implement cache for networks

This commit is contained in:
WeebDataHoarder
2025-04-23 20:12:02 +02:00
parent a0224cb21c
commit 6bb7ca979d
6 changed files with 164 additions and 13 deletions

View File

@@ -298,7 +298,7 @@ However, a few points are left before go-away can be called v1.0.0:
* [ ] Replace Anubis-like default template with own one.
* [ ] Define strings and multi-language support for quick modification by operators without custom templates.
* [ ] Have highly tested paths that match examples.
* [ ] Caching of temporary fetches, for example, network ranges.
* [x] Caching of temporary fetches, for example, network ranges.
* [x] Allow live and dynamic policy reloading.
* [x] Multiple domains / subdomains -> one backend handling, CEL rules for backends
* [ ] Merge all rules and conditions into one large AST for higher performance.

View File

@@ -223,11 +223,23 @@ func main() {
log.Fatal(fmt.Errorf("no backends defined in policy file"))
}
var cache utils.Cache
if *cachePath != "" {
err = os.MkdirAll(*cachePath, 0755)
if err != nil {
log.Fatal(fmt.Errorf("failed to create cache directory: %w", err))
}
for _, n := range []string{"networks", "acme"} {
err = os.MkdirAll(path.Join(*cachePath, n), 0755)
if err != nil {
log.Fatal(fmt.Errorf("failed to create cache sub directory %s: %w", n, err))
}
}
cache, err = utils.CacheDirectory(*cachePath)
if err != nil {
log.Fatal(fmt.Errorf("failed to open cache directory: %w", err))
}
}
var tlsConfig *tls.Config
@@ -293,6 +305,7 @@ func main() {
}
settings := policy.Settings{
Cache: cache,
Backends: createdBackends,
Debug: *debugMode,
PackageName: *packageName,

View File

@@ -27,6 +27,7 @@ type Network struct {
}
func (n Network) FetchPrefixes(c *http.Client, whois *utils.RADb) (output []net.IPNet, err error) {
if len(n.Prefixes) > 0 {
for _, prefix := range n.Prefixes {
ipNet, err := parseCIDROrIP(prefix)

View File

@@ -1,10 +1,12 @@
package policy
import (
"git.gammaspectra.live/git/go-away/utils"
"net/http"
)
type Settings struct {
Cache utils.Cache
Backends map[string]http.Handler
PrivateKeySeed []byte
Debug bool

View File

@@ -3,6 +3,7 @@ package lib
import (
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"fmt"
"git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/condition"
@@ -11,11 +12,13 @@ import (
"github.com/google/cel-go/cel"
"github.com/yl2chen/cidranger"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"os"
"path"
"strings"
"time"
)
type State struct {
@@ -105,26 +108,72 @@ func NewState(p policy.Policy, settings policy.Settings) (handler http.Handler,
}
state.networks = make(map[string]cidranger.Ranger)
networkCache := utils.CachePrefix(state.Settings().Cache, "networks/")
for k, network := range p.Networks {
ranger := cidranger.NewPCTrieRanger()
for _, e := range network {
if e.Url != nil {
slog.Debug("loading network url list", "network", k, "url", *e.Url)
}
if e.ASN != nil {
slog.Debug("loading ASN", "network", k, "asn", *e.ASN)
}
prefixes, err := e.FetchPrefixes(state.client, state.radb)
if err != nil {
slog.Error("error fetching network list", "network", k, "url", *e.Url)
continue
}
for i, e := range network {
prefixes, err := func() ([]net.IPNet, error) {
var useCache bool
if e.Url != nil {
slog.Debug("loading network url list", "network", k, "url", *e.Url)
useCache = true
} else if e.ASN != nil {
slog.Debug("loading ASN", "network", k, "asn", *e.ASN)
useCache = true
}
cacheKey := fmt.Sprintf("%s-%d", k, i)
var cached []net.IPNet
if useCache && networkCache != nil {
cachedData, err := networkCache.Get(cacheKey, time.Hour*24)
var l []string
_ = json.Unmarshal(cachedData, &l)
for _, n := range l {
_, ipNet, err := net.ParseCIDR(n)
if err == nil {
cached = append(cached, *ipNet)
}
}
if err == nil {
// use
return cached, nil
}
}
prefixes, err := e.FetchPrefixes(state.client, state.radb)
if err != nil {
if len(cached) > 0 {
// use cached meanwhile
return cached, err
}
return nil, err
}
if useCache && networkCache != nil {
var l []string
for _, n := range prefixes {
l = append(l, n.String())
}
cachedData, err := json.Marshal(l)
if err == nil {
_ = networkCache.Set(cacheKey, cachedData)
}
}
return prefixes, nil
}()
for _, prefix := range prefixes {
err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix))
if err != nil {
return nil, fmt.Errorf("networks %s: error inserting prefix %s: %v", k, prefix.String(), err)
}
}
if err != nil {
slog.Error("error loading network list", "network", k, "url", *e.Url, "error", err)
continue
}
}
slog.Warn("loaded network prefixes", "network", k, "count", ranger.Len())

86
utils/cache.go Normal file
View File

@@ -0,0 +1,86 @@
package utils
import (
"errors"
"os"
"path"
"time"
)
type Cache interface {
Get(key string, maxAge time.Duration) ([]byte, error)
Set(key string, value []byte) error
}
func CachePrefix(c Cache, prefix string) Cache {
if c == nil {
return nil
}
return prefixCache{
c: c,
prefix: prefix,
}
}
func CacheDirectory(directory string) (Cache, error) {
if stat, err := os.Stat(directory); err != nil {
return nil, err
} else if !stat.IsDir() {
return nil, errors.New("not a directory")
}
return dirCache(directory), nil
}
type prefixCache struct {
c Cache
prefix string
}
func (c prefixCache) Get(key string, maxAge time.Duration) ([]byte, error) {
return c.c.Get(c.prefix+key, maxAge)
}
func (c prefixCache) Set(key string, value []byte) error {
return c.c.Set(c.prefix+key, value)
}
type dirCache string
var ErrExpired = errors.New("key expired")
func (d dirCache) Get(key string, maxAge time.Duration) ([]byte, error) {
fname := path.Join(string(d), key)
stat, err := os.Stat(fname)
if err != nil {
return nil, err
}
if stat.IsDir() {
return nil, errors.New("key is directory")
}
data, err := os.ReadFile(fname)
if err != nil {
return nil, err
}
if stat.ModTime().Before(time.Now().Add(-maxAge)) {
return data, ErrExpired
} else {
return data, nil
}
}
func (d dirCache) Set(key string, value []byte) error {
fname := path.Join(string(d), key)
fs, err := os.Create(fname)
if err != nil {
return err
}
defer fs.Close()
_, err = fs.Write(value)
fs.Sync()
fs.Close()
_ = os.Chtimes(fname, time.Time{}, time.Now())
return err
}