Rearranged wasm challenge utils

This commit is contained in:
WeebDataHoarder
2025-04-06 12:51:27 +02:00
parent 65561ab00e
commit 02f3c1cb19
12 changed files with 149 additions and 125 deletions

View File

@@ -25,12 +25,17 @@ type ChallengeInformation struct {
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
}
func getRequestAddress(r *http.Request) net.IP {
//TODO: verified upstream
ipStr := r.Header.Get("X-Real-Ip")
if ipStr == "" {
ipStr = strings.Split(r.Header.Get("X-Forwarded-For"), ",")[0]
func getRequestAddress(r *http.Request, clientHeader string) net.IP {
var ipStr string
if clientHeader != "" {
ipStr = r.Header.Get(clientHeader)
}
if ipStr != "" {
// handle X-Forwarded-For
ipStr = strings.Split(ipStr, ",")[0]
}
// fallback
if ipStr == "" {
parts := strings.Split(r.RemoteAddr, ":")
// drop port
@@ -44,7 +49,7 @@ func (state *State) GetChallengeKeyForRequest(name string, until time.Time, r *h
hasher.Write([]byte("challenge\x00"))
hasher.Write([]byte(name))
hasher.Write([]byte{0})
hasher.Write(getRequestAddress(r).To16())
hasher.Write(getRequestAddress(r, state.Settings.ClientIpHeader).To16())
hasher.Write([]byte{0})
// specific headers

View File

@@ -1,4 +1,4 @@
package challenge
package _interface
import (
"encoding/json"

View File

@@ -0,0 +1,10 @@
//go:build !tinygo || !wasip1
package _interface
func PtrToBytes(ptr uint32, size uint32) []byte { panic("not implemented") }
func BytesToPtr(s []byte) (uint32, uint32) { panic("not implemented") }
func BytesToLeakedPtr(s []byte) (uint32, uint32) { panic("not implemented") }
func PtrToString(ptr uint32, size uint32) string { panic("not implemented") }
func StringToPtr(s string) (uint32, uint32) { panic("not implemented") }
func StringToLeakedPtr(s string) (uint32, uint32) { panic("not implemented") }

View File

@@ -1,6 +1,6 @@
//go:build tinygo
package challenge
package _interface
// #include <stdlib.h>
import "C"

View File

@@ -1,10 +1,7 @@
//go:build !tinygo || !wasip1
package challenge
package wasm
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/tetratelabs/wazero"
@@ -123,73 +120,3 @@ func (r *Runner) Instantiate(key string, f func(ctx context.Context, mod api.Mod
return f(r.context, mod)
}
func MakeChallengeCall(ctx context.Context, mod api.Module, in MakeChallengeInput) (*MakeChallengeOutput, error) {
makeChallengeFunc := mod.ExportedFunction("MakeChallenge")
malloc := mod.ExportedFunction("malloc")
free := mod.ExportedFunction("free")
inData, err := json.Marshal(in)
if err != nil {
return nil, err
}
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
if err != nil {
return nil, err
}
defer free.Call(ctx, mallocResult[0])
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
return nil, errors.New("could not write memory")
}
result, err := makeChallengeFunc.Call(ctx, uint64(NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
if err != nil {
return nil, err
}
resultPtr := Allocation(result[0])
outData, ok := mod.Memory().Read(resultPtr.Pointer(), resultPtr.Size())
if !ok {
return nil, errors.New("could not read result")
}
defer free.Call(ctx, uint64(resultPtr.Pointer()))
var out MakeChallengeOutput
err = json.Unmarshal(outData, &out)
if err != nil {
return nil, err
}
return &out, nil
}
func VerifyChallengeCall(ctx context.Context, mod api.Module, in VerifyChallengeInput) (VerifyChallengeOutput, error) {
verifyChallengeFunc := mod.ExportedFunction("VerifyChallenge")
malloc := mod.ExportedFunction("malloc")
free := mod.ExportedFunction("free")
inData, err := json.Marshal(in)
if err != nil {
return VerifyChallengeOutputError, err
}
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
if err != nil {
return VerifyChallengeOutputError, err
}
defer free.Call(ctx, mallocResult[0])
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
return VerifyChallengeOutputError, errors.New("could not write memory")
}
result, err := verifyChallengeFunc.Call(ctx, uint64(NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
if err != nil {
return VerifyChallengeOutputError, err
}
return VerifyChallengeOutput(result[0]), nil
}
func PtrToBytes(ptr uint32, size uint32) []byte { panic("not implemented") }
func BytesToPtr(s []byte) (uint32, uint32) { panic("not implemented") }
func BytesToLeakedPtr(s []byte) (uint32, uint32) { panic("not implemented") }
func PtrToString(ptr uint32, size uint32) string { panic("not implemented") }
func StringToPtr(s string) (uint32, uint32) { panic("not implemented") }
func StringToLeakedPtr(s string) (uint32, uint32) { panic("not implemented") }

View File

@@ -0,0 +1,72 @@
package wasm
import (
"context"
"encoding/json"
"errors"
"git.gammaspectra.live/git/go-away/lib/challenge/wasm/interface"
"github.com/tetratelabs/wazero/api"
)
func MakeChallengeCall(ctx context.Context, mod api.Module, in _interface.MakeChallengeInput) (*_interface.MakeChallengeOutput, error) {
makeChallengeFunc := mod.ExportedFunction("MakeChallenge")
malloc := mod.ExportedFunction("malloc")
free := mod.ExportedFunction("free")
inData, err := json.Marshal(in)
if err != nil {
return nil, err
}
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
if err != nil {
return nil, err
}
defer free.Call(ctx, mallocResult[0])
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
return nil, errors.New("could not write memory")
}
result, err := makeChallengeFunc.Call(ctx, uint64(_interface.NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
if err != nil {
return nil, err
}
resultPtr := _interface.Allocation(result[0])
outData, ok := mod.Memory().Read(resultPtr.Pointer(), resultPtr.Size())
if !ok {
return nil, errors.New("could not read result")
}
defer free.Call(ctx, uint64(resultPtr.Pointer()))
var out _interface.MakeChallengeOutput
err = json.Unmarshal(outData, &out)
if err != nil {
return nil, err
}
return &out, nil
}
func VerifyChallengeCall(ctx context.Context, mod api.Module, in _interface.VerifyChallengeInput) (_interface.VerifyChallengeOutput, error) {
verifyChallengeFunc := mod.ExportedFunction("VerifyChallenge")
malloc := mod.ExportedFunction("malloc")
free := mod.ExportedFunction("free")
inData, err := json.Marshal(in)
if err != nil {
return _interface.VerifyChallengeOutputError, err
}
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
if err != nil {
return _interface.VerifyChallengeOutputError, err
}
defer free.Call(ctx, mallocResult[0])
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
return _interface.VerifyChallengeOutputError, errors.New("could not write memory")
}
result, err := verifyChallengeFunc.Call(ctx, uint64(_interface.NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
if err != nil {
return _interface.VerifyChallengeOutputError, err
}
return _interface.VerifyChallengeOutput(result[0]), nil
}

View File

@@ -125,10 +125,10 @@ func (state *State) addTiming(w http.ResponseWriter, name, desc string, duration
}
}
func GetLoggerForRequest(r *http.Request) *slog.Logger {
func GetLoggerForRequest(r *http.Request, clientHeader string) *slog.Logger {
return slog.With(
"request_id", r.Header.Get("X-Away-Id"),
"remote_address", getRequestAddress(r),
"remote_address", getRequestAddress(r, clientHeader),
"user_agent", r.UserAgent(),
"host", r.Host,
"path", r.URL.Path,
@@ -136,6 +136,10 @@ func GetLoggerForRequest(r *http.Request) *slog.Logger {
)
}
func (state *State) logger(r *http.Request) *slog.Logger {
return GetLoggerForRequest(r, state.Settings.ClientIpHeader)
}
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
host := r.Host
@@ -145,7 +149,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
return
}
lg := GetLoggerForRequest(r)
lg := state.logger(r)
start := time.Now()
@@ -153,7 +157,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
env := map[string]any{
"host": host,
"method": r.Method,
"remoteAddress": getRequestAddress(r),
"remoteAddress": getRequestAddress(r, state.Settings.ClientIpHeader),
"userAgent": r.UserAgent(),
"path": r.URL.Path,
"query": func() map[string]string {
@@ -259,7 +263,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
if rule.Action == policy.RuleActionCHECK {
goto nextRule
}
GetLoggerForRequest(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
state.logger(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
// we pass the challenge early!
r.Header.Set(fmt.Sprintf("X-Away-Challenge-%s-Verify", challengeName), "PASS")
@@ -374,15 +378,15 @@ func (state *State) setupRoutes() error {
state.addTiming(w, "challenge-verify", "Verify client challenge", time.Since(start))
if err != nil {
GetLoggerForRequest(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect"))
state.logger(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect"))
return err
} else if !ok {
GetLoggerForRequest(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
state.logger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
ClearCookie(CookiePrefix+challengeName, w)
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
return nil
}
GetLoggerForRequest(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
state.logger(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
if err != nil {

View File

@@ -12,7 +12,8 @@ import (
"errors"
"fmt"
"git.gammaspectra.live/git/go-away/embed"
"git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/challenge/wasm"
"git.gammaspectra.live/git/go-away/lib/challenge/wasm/interface"
"git.gammaspectra.live/git/go-away/lib/condition"
"git.gammaspectra.live/git/go-away/lib/policy"
"git.gammaspectra.live/git/go-away/utils/inline"
@@ -44,7 +45,7 @@ type State struct {
Networks map[string]cidranger.Ranger
Wasm *challenge.Runner
Wasm *wasm.Runner
Challenges map[string]ChallengeState
@@ -101,6 +102,7 @@ type StateSettings struct {
PackageName string
ChallengeTemplate string
ChallengeTemplateTheme string
ClientIpHeader string
}
func NewState(p policy.Policy, settings StateSettings) (state *State, err error) {
@@ -118,7 +120,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
if proxy, ok := backend.(*httputil.ReverseProxy); ok {
if proxy.ErrorHandler == nil {
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
GetLoggerForRequest(r).Error(err.Error())
state.logger(r).Error(err.Error())
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadGateway, err)
}
}
@@ -186,7 +188,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
state.Networks[k] = ranger
}
state.Wasm = challenge.NewRunner(true)
state.Wasm = wasm.NewRunner(true)
state.Challenges = make(map[string]ChallengeState)
@@ -429,12 +431,12 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
if ok, err := c.Verify(key, result); err != nil {
return err
} else if !ok {
GetLoggerForRequest(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
state.logger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
ClearCookie(CookiePrefix+challengeName, w)
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
return nil
}
GetLoggerForRequest(r).Warn("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
state.logger(r).Warn("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
if err != nil {
@@ -476,7 +478,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
c.MakeChallenge = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := state.Wasm.Instantiate(challengeName, func(ctx context.Context, mod api.Module) (err error) {
in := challenge.MakeChallengeInput{
in := _interface.MakeChallengeInput{
Key: state.GetChallengeKeyForRequest(challengeName, time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity), r),
Parameters: p.Parameters,
Headers: inline.MIMEHeader(r.Header),
@@ -486,7 +488,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
return err
}
out, err := challenge.MakeChallengeCall(ctx, mod, in)
out, err := wasm.MakeChallengeCall(ctx, mod, in)
if err != nil {
return err
}
@@ -508,21 +510,21 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
c.Verify = func(key []byte, result string) (ok bool, err error) {
err = state.Wasm.Instantiate(challengeName, func(ctx context.Context, mod api.Module) (err error) {
in := challenge.VerifyChallengeInput{
in := _interface.VerifyChallengeInput{
Key: key,
Parameters: p.Parameters,
Result: []byte(result),
}
out, err := challenge.VerifyChallengeCall(ctx, mod, in)
out, err := wasm.VerifyChallengeCall(ctx, mod, in)
if err != nil {
return err
}
if out == challenge.VerifyChallengeOutputError {
if out == _interface.VerifyChallengeOutputError {
return errors.New("error checking challenge")
}
ok = out == challenge.VerifyChallengeOutputOK
ok = out == _interface.VerifyChallengeOutputOK
return nil
})
if err != nil {