diff --git a/lib/challenge/register.go b/lib/challenge/register.go index 449b975..a92c336 100644 --- a/lib/challenge/register.go +++ b/lib/challenge/register.go @@ -11,6 +11,7 @@ import ( "github.com/go-jose/go-jose/v4/jwt" "github.com/goccy/go-yaml/ast" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" "io" "math/rand/v2" "net/http" @@ -71,6 +72,13 @@ func (r Register) Create(state StateInterface, name string, pol policy.Challenge if err != nil { return nil, 0, fmt.Errorf("error compiling conditions: %v", err) } + + if out := ast.OutputType(); out == nil { + return nil, 0, fmt.Errorf("error compiling conditions: no output") + } else if out != types.BoolType { + return nil, 0, fmt.Errorf("error compiling conditions: output type is not bool") + } + reg.Condition, err = http_cel.ProgramAst(state.ProgramEnv(), ast) if err != nil { return nil, 0, fmt.Errorf("error compiling program: %v", err) diff --git a/lib/rule.go b/lib/rule.go index 2fd5670..c43f41b 100644 --- a/lib/rule.go +++ b/lib/rule.go @@ -71,6 +71,12 @@ func NewRuleState(state challenge.StateInterface, r policy.Rule, replacer *strin return RuleState{}, fmt.Errorf("error compiling conditions: %w", err) } + if out := ast.OutputType(); out == nil { + return RuleState{}, fmt.Errorf("error compiling conditions: no output") + } else if out != types.BoolType { + return RuleState{}, fmt.Errorf("error compiling conditions: output type is not bool") + } + program, err := http_cel.ProgramAst(state.ProgramEnv(), ast) if err != nil { return RuleState{}, fmt.Errorf("error compiling program: %w", err) diff --git a/lib/state.go b/lib/state.go index d820d71..1524262 100644 --- a/lib/state.go +++ b/lib/state.go @@ -13,6 +13,7 @@ import ( "git.gammaspectra.live/git/go-away/lib/settings" "git.gammaspectra.live/git/go-away/utils" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" "github.com/yl2chen/cidranger" "golang.org/x/net/html" "log/slog" @@ -210,6 +211,12 @@ func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSetti return nil, fmt.Errorf("conditions %s: error compiling conditions: %v", k, err) } + if out := ast.OutputType(); out == nil { + return nil, fmt.Errorf("conditions %s: error compiling conditions: no output", k) + } else if out != types.BoolType { + return nil, fmt.Errorf("conditions %s: error compiling conditions: output type is not bool", k) + } + cond, err := cel.AstToString(ast) if err != nil { return nil, fmt.Errorf("conditions %s: error printing condition: %v", k, err)