Add wasm helper, wasm test utility

This commit is contained in:
WeebDataHoarder
2025-04-06 11:44:06 +02:00
parent 6623824d44
commit 65561ab00e
13 changed files with 423 additions and 130 deletions

View File

@@ -2,15 +2,12 @@ package lib
import (
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"math/rand/v2"
"net"
"net/http"
@@ -171,29 +168,3 @@ func (state *State) VerifyChallengeToken(name string, expectedKey []byte, w http
return true, nil
}
func (state *State) ChallengeMod(name string, cb func(ctx context.Context, mod api.Module) error) error {
c, ok := state.Challenges[name]
if !ok {
return errors.New("challenge not found")
}
if c.RuntimeModule == nil {
return errors.New("challenge module is nil")
}
ctx := state.WasmContext
mod, err := state.WasmRuntime.InstantiateModule(
ctx,
c.RuntimeModule,
wazero.NewModuleConfig().WithName(name).WithStartFunctions("_initialize"),
)
if err != nil {
return err
}
defer mod.Close(ctx)
err = cb(ctx, mod)
if err != nil {
return err
}
return nil
}

View File

@@ -1,74 +0,0 @@
//go:build !tinygo
package challenge
import (
"context"
"encoding/json"
"errors"
"github.com/tetratelabs/wazero/api"
)
func MakeChallengeCall(ctx context.Context, mod api.Module, in MakeChallengeInput) (*MakeChallengeOutput, error) {
makeChallengeFunc := mod.ExportedFunction("MakeChallenge")
malloc := mod.ExportedFunction("malloc")
free := mod.ExportedFunction("free")
inData, err := json.Marshal(in)
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
if err != nil {
return nil, err
}
defer free.Call(ctx, mallocResult[0])
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
return nil, errors.New("could not write memory")
}
result, err := makeChallengeFunc.Call(ctx, uint64(NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
if err != nil {
return nil, err
}
resultPtr := Allocation(result[0])
outData, ok := mod.Memory().Read(resultPtr.Pointer(), resultPtr.Size())
if !ok {
return nil, errors.New("could not read result")
}
defer free.Call(ctx, uint64(resultPtr.Pointer()))
var out MakeChallengeOutput
err = json.Unmarshal(outData, &out)
if err != nil {
return nil, err
}
return &out, nil
}
func VerifyChallengeCall(ctx context.Context, mod api.Module, in VerifyChallengeInput) (VerifyChallengeOutput, error) {
verifyChallengeFunc := mod.ExportedFunction("VerifyChallenge")
malloc := mod.ExportedFunction("malloc")
free := mod.ExportedFunction("free")
inData, err := json.Marshal(in)
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
if err != nil {
return VerifyChallengeOutputError, err
}
defer free.Call(ctx, mallocResult[0])
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
return VerifyChallengeOutputError, errors.New("could not write memory")
}
result, err := verifyChallengeFunc.Call(ctx, uint64(NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
if err != nil {
return VerifyChallengeOutputError, err
}
return VerifyChallengeOutput(result[0]), nil
}
func PtrToBytes(ptr uint32, size uint32) []byte { panic("not implemented") }
func BytesToPtr(s []byte) (uint32, uint32) { panic("not implemented") }
func BytesToLeakedPtr(s []byte) (uint32, uint32) { panic("not implemented") }
func PtrToString(ptr uint32, size uint32) string { panic("not implemented") }
func StringToPtr(s string) (uint32, uint32) { panic("not implemented") }
func StringToLeakedPtr(s string) (uint32, uint32) { panic("not implemented") }

View File

@@ -5,8 +5,7 @@ import (
"git.gammaspectra.live/git/go-away/utils/inline"
)
type MakeChallenge func(in Allocation) (out Allocation)
// Allocation is a combination of pointer location in WASM memory and size of it
type Allocation uint64
func NewAllocation(ptr, size uint32) Allocation {

View File

@@ -0,0 +1,195 @@
//go:build !tinygo || !wasip1
package challenge
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"slices"
)
type Runner struct {
context context.Context
runtime wazero.Runtime
modules map[string]wazero.CompiledModule
}
func NewRunner(useNativeCompiler bool) *Runner {
var r Runner
r.context = context.Background()
var runtimeConfig wazero.RuntimeConfig
if useNativeCompiler {
runtimeConfig = wazero.NewRuntimeConfigCompiler()
} else {
runtimeConfig = wazero.NewRuntimeConfigInterpreter()
}
r.runtime = wazero.NewRuntimeWithConfig(r.context, runtimeConfig)
wasi_snapshot_preview1.MustInstantiate(r.context, r.runtime)
r.modules = make(map[string]wazero.CompiledModule)
return &r
}
func (r *Runner) Compile(key string, binary []byte) error {
module, err := r.runtime.CompileModule(r.context, binary)
if err != nil {
return err
}
// check interface
functions := module.ExportedFunctions()
if f, ok := functions["MakeChallenge"]; ok {
if slices.Compare(f.ParamTypes(), []api.ValueType{api.ValueTypeI64}) != 0 {
return fmt.Errorf("MakeChallenge does not follow parameter interface")
}
if slices.Compare(f.ResultTypes(), []api.ValueType{api.ValueTypeI64}) != 0 {
return fmt.Errorf("MakeChallenge does not follow result interface")
}
} else {
module.Close(r.context)
return errors.New("no MakeChallenge exported")
}
if f, ok := functions["VerifyChallenge"]; ok {
if slices.Compare(f.ParamTypes(), []api.ValueType{api.ValueTypeI64}) != 0 {
return fmt.Errorf("VerifyChallenge does not follow parameter interface")
}
if slices.Compare(f.ResultTypes(), []api.ValueType{api.ValueTypeI64}) != 0 {
return fmt.Errorf("VerifyChallenge does not follow result interface")
}
} else {
module.Close(r.context)
return errors.New("no VerifyChallenge exported")
}
if f, ok := functions["malloc"]; ok {
if slices.Compare(f.ParamTypes(), []api.ValueType{api.ValueTypeI32}) != 0 {
return fmt.Errorf("malloc does not follow parameter interface")
}
if slices.Compare(f.ResultTypes(), []api.ValueType{api.ValueTypeI32}) != 0 {
return fmt.Errorf("malloc does not follow result interface")
}
} else {
module.Close(r.context)
return errors.New("no malloc exported")
}
if f, ok := functions["free"]; ok {
if slices.Compare(f.ParamTypes(), []api.ValueType{api.ValueTypeI32}) != 0 {
return fmt.Errorf("free does not follow parameter interface")
}
if slices.Compare(f.ResultTypes(), []api.ValueType{}) != 0 {
return fmt.Errorf("free does not follow result interface")
}
} else {
module.Close(r.context)
return errors.New("no free exported")
}
r.modules[key] = module
return nil
}
func (r *Runner) Close() {
for _, module := range r.modules {
module.Close(r.context)
}
r.runtime.Close(r.context)
}
var ErrModuleNotFound = errors.New("module not found")
func (r *Runner) Instantiate(key string, f func(ctx context.Context, mod api.Module) error) (err error) {
compiledModule, ok := r.modules[key]
if !ok {
return ErrModuleNotFound
}
mod, err := r.runtime.InstantiateModule(
r.context,
compiledModule,
wazero.NewModuleConfig().WithName(key).WithStartFunctions("_initialize"),
)
if err != nil {
return err
}
defer mod.Close(r.context)
return f(r.context, mod)
}
func MakeChallengeCall(ctx context.Context, mod api.Module, in MakeChallengeInput) (*MakeChallengeOutput, error) {
makeChallengeFunc := mod.ExportedFunction("MakeChallenge")
malloc := mod.ExportedFunction("malloc")
free := mod.ExportedFunction("free")
inData, err := json.Marshal(in)
if err != nil {
return nil, err
}
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
if err != nil {
return nil, err
}
defer free.Call(ctx, mallocResult[0])
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
return nil, errors.New("could not write memory")
}
result, err := makeChallengeFunc.Call(ctx, uint64(NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
if err != nil {
return nil, err
}
resultPtr := Allocation(result[0])
outData, ok := mod.Memory().Read(resultPtr.Pointer(), resultPtr.Size())
if !ok {
return nil, errors.New("could not read result")
}
defer free.Call(ctx, uint64(resultPtr.Pointer()))
var out MakeChallengeOutput
err = json.Unmarshal(outData, &out)
if err != nil {
return nil, err
}
return &out, nil
}
func VerifyChallengeCall(ctx context.Context, mod api.Module, in VerifyChallengeInput) (VerifyChallengeOutput, error) {
verifyChallengeFunc := mod.ExportedFunction("VerifyChallenge")
malloc := mod.ExportedFunction("malloc")
free := mod.ExportedFunction("free")
inData, err := json.Marshal(in)
if err != nil {
return VerifyChallengeOutputError, err
}
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
if err != nil {
return VerifyChallengeOutputError, err
}
defer free.Call(ctx, mallocResult[0])
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
return VerifyChallengeOutputError, errors.New("could not write memory")
}
result, err := verifyChallengeFunc.Call(ctx, uint64(NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
if err != nil {
return VerifyChallengeOutputError, err
}
return VerifyChallengeOutput(result[0]), nil
}
func PtrToBytes(ptr uint32, size uint32) []byte { panic("not implemented") }
func BytesToPtr(s []byte) (uint32, uint32) { panic("not implemented") }
func BytesToLeakedPtr(s []byte) (uint32, uint32) { panic("not implemented") }
func PtrToString(ptr uint32, size uint32) string { panic("not implemented") }
func StringToPtr(s string) (uint32, uint32) { panic("not implemented") }
func StringToLeakedPtr(s string) (uint32, uint32) { panic("not implemented") }

View File

@@ -12,16 +12,14 @@ import (
"errors"
"fmt"
"git.gammaspectra.live/git/go-away/embed"
challenge2 "git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/condition"
"git.gammaspectra.live/git/go-away/lib/policy"
"git.gammaspectra.live/git/go-away/utils/inline"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/yl2chen/cidranger"
"html/template"
"io"
@@ -46,8 +44,7 @@ type State struct {
Networks map[string]cidranger.Ranger
WasmRuntime wazero.Runtime
WasmContext context.Context
Wasm *challenge.Runner
Challenges map[string]ChallengeState
@@ -84,8 +81,6 @@ const (
)
type ChallengeState struct {
RuntimeModule wazero.CompiledModule
Path string
Static http.Handler
@@ -191,9 +186,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
state.Networks[k] = ranger
}
state.WasmContext = context.Background()
state.WasmRuntime = wazero.NewRuntimeWithConfig(state.WasmContext, wazero.NewRuntimeConfigCompiler())
wasi_snapshot_preview1.MustInstantiate(state.WasmContext, state.WasmRuntime)
state.Wasm = challenge.NewRunner(true)
state.Challenges = make(map[string]ChallengeState)
@@ -475,15 +468,15 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
if err != nil {
return nil, fmt.Errorf("c %s: could not load runtime: %w", challengeName, err)
}
c.RuntimeModule, err = state.WasmRuntime.CompileModule(state.WasmContext, wasmData)
err = state.Wasm.Compile(challengeName, wasmData)
if err != nil {
return nil, fmt.Errorf("c %s: compiling runtime: %w", challengeName, err)
}
c.MakeChallenge = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := state.ChallengeMod(challengeName, func(ctx context.Context, mod api.Module) (err error) {
err := state.Wasm.Instantiate(challengeName, func(ctx context.Context, mod api.Module) (err error) {
in := challenge2.MakeChallengeInput{
in := challenge.MakeChallengeInput{
Key: state.GetChallengeKeyForRequest(challengeName, time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity), r),
Parameters: p.Parameters,
Headers: inline.MIMEHeader(r.Header),
@@ -493,7 +486,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
return err
}
out, err := challenge2.MakeChallengeCall(state.WasmContext, mod, in)
out, err := challenge.MakeChallengeCall(ctx, mod, in)
if err != nil {
return err
}
@@ -514,22 +507,22 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
})
c.Verify = func(key []byte, result string) (ok bool, err error) {
err = state.ChallengeMod(challengeName, func(ctx context.Context, mod api.Module) (err error) {
in := challenge2.VerifyChallengeInput{
err = state.Wasm.Instantiate(challengeName, func(ctx context.Context, mod api.Module) (err error) {
in := challenge.VerifyChallengeInput{
Key: key,
Parameters: p.Parameters,
Result: []byte(result),
}
out, err := challenge2.VerifyChallengeCall(state.WasmContext, mod, in)
out, err := challenge.VerifyChallengeCall(ctx, mod, in)
if err != nil {
return err
}
if out == challenge2.VerifyChallengeOutputError {
if out == challenge.VerifyChallengeOutputError {
return errors.New("error checking challenge")
}
ok = out == challenge2.VerifyChallengeOutputOK
ok = out == challenge.VerifyChallengeOutputOK
return nil
})
if err != nil {