testdata: Initial action/challenges testing

This commit is contained in:
WeebDataHoarder
2025-04-29 05:06:46 +02:00
parent 467ad9c5a9
commit 6a6c3fef07
9 changed files with 889 additions and 4 deletions

View File

@@ -12,6 +12,7 @@ local Build(mirror, go, alpine, os, arch) = {
CGO_ENABLED: "0",
GOOS: os,
GOARCH: arch,
GORACE: "halt_on_error=1"
},
steps: [
{
@@ -26,6 +27,16 @@ local Build(mirror, go, alpine, os, arch) = {
"go build -v -o ./.bin/test-wasm-runtime ./cmd/test-wasm-runtime",
],
},
{
name: "test",
image: "golang:" + go +"-alpine" + alpine,
mirror: mirror,
commands: [
"apk update",
"apk add --no-cache git",
"go test -p 1 -timeout 20m -v ./tests/"
],
},
{
name: "check-policy-forgejo",
image: "alpine:" + alpine,
@@ -92,6 +103,16 @@ local Publish(mirror, registry, repo, secret, go, alpine, os, arch, trigger, pla
},
trigger: trigger,
steps: [
{
name: "test",
image: "golang:" + go +"-alpine" + alpine,
mirror: mirror,
commands: [
"apk update",
"apk add --no-cache git",
"go test -p 1 -timeout 20m -v ./tests/"
],
},
{
name: "setup-buildkitd",
image: "alpine:" + alpine,

View File

@@ -3,6 +3,7 @@ environment:
CGO_ENABLED: "0"
GOARCH: amd64
GOOS: linux
GORACE: halt_on_error=1
GOTOOLCHAIN: local
kind: pipeline
name: build-1.24-alpine3.21-amd64
@@ -19,6 +20,13 @@ steps:
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: build
- commands:
- apk update
- apk add --no-cache git
- go test -p 1 -timeout 20m -v ./tests/
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: test
- commands:
- ./.bin/go-away --check --slog-level DEBUG --backend example.com=http://127.0.0.1:80
--policy examples/forgejo.yml --policy-snippets examples/snippets/
@@ -71,6 +79,7 @@ environment:
CGO_ENABLED: "0"
GOARCH: arm64
GOOS: linux
GORACE: halt_on_error=1
GOTOOLCHAIN: local
kind: pipeline
name: build-1.24-alpine3.21-arm64
@@ -87,6 +96,13 @@ steps:
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: build
- commands:
- apk update
- apk add --no-cache git
- go test -p 1 -timeout 20m -v ./tests/
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: test
- commands:
- ./.bin/go-away --check --slog-level DEBUG --backend example.com=http://127.0.0.1:80
--policy examples/forgejo.yml --policy-snippets examples/snippets/
@@ -141,6 +157,13 @@ platform:
arch: amd64
os: linux
steps:
- commands:
- apk update
- apk add --no-cache git
- go test -p 1 -timeout 20m -v ./tests/
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: test
- commands:
- echo '[registry."docker.io"]' > buildkitd.toml
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
@@ -189,6 +212,13 @@ platform:
arch: amd64
os: linux
steps:
- commands:
- apk update
- apk add --no-cache git
- go test -p 1 -timeout 20m -v ./tests/
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: test
- commands:
- echo '[registry."docker.io"]' > buildkitd.toml
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
@@ -237,6 +267,13 @@ platform:
arch: amd64
os: linux
steps:
- commands:
- apk update
- apk add --no-cache git
- go test -p 1 -timeout 20m -v ./tests/
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: test
- commands:
- echo '[registry."docker.io"]' > buildkitd.toml
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
@@ -285,6 +322,13 @@ platform:
arch: amd64
os: linux
steps:
- commands:
- apk update
- apk add --no-cache git
- go test -p 1 -timeout 20m -v ./tests/
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: test
- commands:
- echo '[registry."docker.io"]' > buildkitd.toml
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
@@ -333,6 +377,13 @@ platform:
arch: amd64
os: linux
steps:
- commands:
- apk update
- apk add --no-cache git
- go test -p 1 -timeout 20m -v ./tests/
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: test
- commands:
- echo '[registry."docker.io"]' > buildkitd.toml
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
@@ -381,6 +432,13 @@ platform:
arch: amd64
os: linux
steps:
- commands:
- apk update
- apk add --no-cache git
- go test -p 1 -timeout 20m -v ./tests/
image: golang:1.24-alpine3.21
mirror: https://mirror.gcr.io
name: test
- commands:
- echo '[registry."docker.io"]' > buildkitd.toml
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
@@ -424,6 +482,6 @@ trigger:
type: docker
---
kind: signature
hmac: 6eab8ae9773b048e780db2bf9d440095eb5615d0baf8da71878069ad7124e167
hmac: 07ac33f9298a9910aacb29ef18931cb999841f76be8a95ca210f9f3704c347f9
...

View File

@@ -7,6 +7,7 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/policy"
@@ -52,8 +53,8 @@ type State struct {
Mux *http.ServeMux
}
func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (handler http.Handler, err error) {
state := new(State)
func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (state *State, err error) {
state = new(State)
state.close = make(chan struct{})
state.settings = settings
state.opt = opt
@@ -264,3 +265,13 @@ func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSetti
return state, nil
}
func (state *State) Close() error {
select {
case <-state.close:
return errors.New("already closed")
default:
close(state.close)
}
return nil
}

280
tests/action_test.go Normal file
View File

@@ -0,0 +1,280 @@
package tests
import (
"encoding/base64"
"errors"
"fmt"
"git.gammaspectra.live/git/go-away/lib/policy"
"git.gammaspectra.live/git/go-away/utils"
"io"
"net/http"
"net/url"
"strings"
"testing"
)
func testAction(t *testing.T, pol policy.Policy, expected int, url string) (*http.Response, error) {
settings := setupDefaultSettings(t)
var r *http.Response
err := MakeGoAwayState(pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error {
request, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return err
}
request.Header.Set(settings.ClientIpHeader, "127.0.0.1")
response, err := do(request)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != expected {
return fmt.Errorf("expected status code %d, got %d", expected, response.StatusCode)
}
r = response
return nil
})
return r, err
}
func TestActionPass(t *testing.T) {
pol, err := policy.NewPolicy(strings.NewReader(
`
rules:
- name: test
conditions: ["true"]
action: pass
settings:
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
_, err = testAction(t, *pol, http.StatusOK, "/test")
if err != nil {
t.Fatal(err)
}
}
func TestActionNone(t *testing.T) {
pol, err := policy.NewPolicy(strings.NewReader(
`
rules:
- name: test
conditions: ["true"]
action: none
settings:
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
_, err = testAction(t, *pol, http.StatusOK, "/test")
if err != nil {
t.Fatal(err)
}
}
func TestActionDrop(t *testing.T) {
pol, err := policy.NewPolicy(strings.NewReader(
`
rules:
- name: test
conditions: ["true"]
action: drop
settings:
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
response, err := testAction(t, *pol, http.StatusForbidden, "/test")
if err != nil {
t.Fatal(err)
}
data, err := io.ReadAll(response.Body)
if err != nil {
t.Fatal(err)
}
if len(data) != 0 {
t.Fatal(fmt.Errorf("expected empty response, got %s", string(data)))
}
}
func TestActionDeny(t *testing.T) {
pol, err := policy.NewPolicy(strings.NewReader(
`
rules:
- name: test
conditions: ["true"]
action: deny
settings:
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
response, err := testAction(t, *pol, http.StatusForbidden, "/test")
if err != nil {
t.Fatal(err)
}
data, err := io.ReadAll(response.Body)
if err != nil {
t.Fatal(err)
}
if len(data) == 0 {
t.Fatal(errors.New("expected non-empty response, got none"))
}
}
func TestActionBlock(t *testing.T) {
pol, err := policy.NewPolicy(strings.NewReader(
`
rules:
- name: test
conditions: ["true"]
action: block
settings:
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
response, err := testAction(t, *pol, http.StatusForbidden, "/test")
if err != nil {
t.Fatal(err)
}
data, err := io.ReadAll(response.Body)
if err != nil {
t.Fatal(err)
}
if len(data) == 0 {
t.Fatal(errors.New("expected non-empty response, got none"))
}
}
func TestActionCode(t *testing.T) {
pol, err := policy.NewPolicy(strings.NewReader(
`
rules:
- name: test
conditions: ["true"]
action: code
settings:
http-code: 418
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
_, err = testAction(t, *pol, http.StatusTeapot, "/test")
if err != nil {
t.Fatal(err)
}
}
func TestActionContextResponseHeaders(t *testing.T) {
pol, err := policy.NewPolicy(strings.NewReader(
`
rules:
- name: test
conditions: ["true"]
action: context
settings:
response-headers:
X-World-Domination: yes
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
response, err := testAction(t, *pol, http.StatusOK, "/test")
if err != nil {
t.Fatal(err)
}
if response.Header.Get("X-World-Domination") != "yes" {
t.Fatal(fmt.Errorf("expected header set, got %s", response.Header.Get("X-World-Domination")))
}
}
func TestActionContextSetMetaTags(t *testing.T) {
pol, err := policy.NewPolicy(strings.NewReader(
`
rules:
- name: test-context
conditions: ["true"]
action: context
settings:
context-set:
proxy-meta-tags: yes
- name: test
conditions: ["true"]
action: deny
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
uri, err := url.Parse("/test")
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
q := uri.Query()
q.Set("mime-type", "text/html")
q.Set("content", base64.RawURLEncoding.EncodeToString([]byte(`
<!DOCTYPE html>
<html>
<head>
<meta name="description" content="test">
</head>
</html>
`)))
uri.RawQuery = q.Encode()
response, err := testAction(t, *pol, http.StatusForbidden, uri.String())
if err != nil {
t.Fatal(err)
}
tags := utils.FetchTagsFromReader(response.Body, "meta")
if str := func() string {
for _, t := range tags {
var is bool
var val string
for _, a := range t.Attr {
if a.Key == "name" && a.Val == "description" {
is = true
}
if a.Key == "content" {
val = a.Val
}
}
if is {
return val
}
}
return "NONE"
}(); str != "test" {
t.Fatal(fmt.Errorf("expected meta tag with 'test', got %s", str))
}
}

34
tests/away.go Normal file
View File

@@ -0,0 +1,34 @@
package tests
import (
"git.gammaspectra.live/git/go-away/lib"
"git.gammaspectra.live/git/go-away/lib/policy"
"git.gammaspectra.live/git/go-away/lib/settings"
"net/http"
"net/http/httptest"
)
var DefaultSettings = policy.StateSettings{
Cache: nil,
Backends: map[string]http.Handler{
"*": MakeTestBackend(),
},
MainName: "go-away/tests",
MainVersion: "testing",
BasePath: "/.go-away",
ChallengeResponseCode: http.StatusTeapot,
ClientIpHeader: "X-Forwarded-For",
}
func MakeGoAwayState(pol policy.Policy, stateSettings policy.StateSettings, f func(do func(r *http.Request, errs ...error) (*http.Response, error)) error) error {
state, err := lib.NewState(pol, settings.DefaultSettings, stateSettings)
if err != nil {
return err
}
return f(func(r *http.Request, errs ...error) (*http.Response, error) {
recorder := httptest.NewRecorder()
state.ServeHTTP(recorder, r)
return recorder.Result(), nil
})
}

57
tests/backend.go Normal file
View File

@@ -0,0 +1,57 @@
package tests
import (
"encoding/base64"
"encoding/json"
"net/http"
"strconv"
)
func MakeTestBackend() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
responseCode := http.StatusOK
var err error
if opt := q.Get("http-code"); opt != "" {
rc, err := strconv.ParseInt(opt, 10, 64)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
responseCode = int(rc)
}
type ResponseJson struct {
Method string `json:"method"`
Path string `json:"path"`
Query string `json:"query"`
}
if opt := q.Get("mime-type"); opt != "" {
w.Header().Set("Content-Type", opt)
} else {
w.Header().Set("Content-Type", "application/json")
}
var data []byte
if opt := q.Get("content"); opt != "" {
data, err = base64.RawURLEncoding.DecodeString(opt)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
} else {
data, err = json.Marshal(ResponseJson{
Method: r.Method,
Path: r.URL.Path,
Query: r.URL.RawQuery,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
}
w.WriteHeader(responseCode)
_, _ = w.Write(data)
})
}

362
tests/challenge_test.go Normal file
View File

@@ -0,0 +1,362 @@
package tests
import (
"encoding/hex"
"fmt"
challenge2 "git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/policy"
"golang.org/x/net/html"
"log/slog"
"net/http"
"net/url"
"strings"
"testing"
)
func setupDefaultSettings(t *testing.T) policy.StateSettings {
settings := DefaultSettings
slog.SetDefault(slog.New(initLogger(t)))
return settings
}
func TestChallengeCookie(t *testing.T) {
settings := setupDefaultSettings(t)
pol, err := policy.NewPolicy(strings.NewReader(
`
challenges:
"challenge-cookie":
runtime: "cookie"
rules:
- name: catch-all
conditions: ["true"]
action: challenge
settings:
challenges: ["challenge-cookie"]
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
var expectedCode = http.StatusTemporaryRedirect
err = MakeGoAwayState(*pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error {
challenge, err := http.NewRequest(http.MethodGet, "/test", nil)
challenge.Header.Set(settings.ClientIpHeader, "127.0.0.1")
challengeResponse, err := do(challenge)
if err != nil {
return err
}
defer challengeResponse.Body.Close()
if challengeResponse.StatusCode != expectedCode {
return fmt.Errorf("expected challenge status code %d, got %d", expectedCode, challengeResponse.StatusCode)
} else if cookies := challengeResponse.Cookies(); len(cookies) == 0 {
return fmt.Errorf("expected set cookies to be non-empty, got none")
} else if challengeResponse.Header.Get("Location") == "" {
return fmt.Errorf("expected header 'Location' to be non-empty, got none")
}
solveLocation := challengeResponse.Header.Get("Location")
if !strings.HasPrefix(solveLocation, "/test") {
return fmt.Errorf("expected next location to start with '/test', got %s", solveLocation)
}
// test pass
pass, err := http.NewRequest(http.MethodGet, solveLocation, nil)
pass.Header.Set(settings.ClientIpHeader, "127.0.0.1")
if err != nil {
return err
}
for _, c := range challengeResponse.Cookies() {
pass.AddCookie(c)
}
response, err := do(pass)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("expected pass status code %d, got %d", http.StatusOK, response.StatusCode)
}
// test failure
fail, err := http.NewRequest(http.MethodGet, solveLocation, nil)
fail.Header.Set(settings.ClientIpHeader, "127.0.0.1")
if err != nil {
return err
}
response, err = do(fail)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusForbidden {
return fmt.Errorf("expected fail status code %d, got %d", http.StatusForbidden, response.StatusCode)
}
return nil
})
if err != nil {
t.Fatal(err)
}
}
func TestChallengeHeaderRefresh(t *testing.T) {
settings := setupDefaultSettings(t)
pol, err := policy.NewPolicy(strings.NewReader(
`
challenges:
"challenge-header-refresh":
runtime: "refresh"
parameters:
refresh-via: "header"
rules:
- name: catch-all
conditions: ["true"]
action: challenge
settings:
challenges: ["challenge-header-refresh"]
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
var expectedCode = settings.ChallengeResponseCode
err = MakeGoAwayState(*pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error {
challenge, err := http.NewRequest(http.MethodGet, "/test", nil)
challenge.Header.Set(settings.ClientIpHeader, "127.0.0.1")
challengeResponse, err := do(challenge)
if err != nil {
return err
}
defer challengeResponse.Body.Close()
if challengeResponse.StatusCode != expectedCode {
return fmt.Errorf("expected challenge status code %d, got %d", expectedCode, challengeResponse.StatusCode)
} else if challengeResponse.Header.Get("Refresh") == "" {
return fmt.Errorf("expected header 'Refresh' to be non-empty, got none")
}
solveLocation, err := url.QueryUnescape(strings.Split(challengeResponse.Header.Get("Refresh"), "; url=")[1])
if err != nil {
return err
}
// test solve
solve, err := http.NewRequest(http.MethodGet, solveLocation, nil)
solve.Header.Set(settings.ClientIpHeader, "127.0.0.1")
if err != nil {
return err
}
response, err := do(solve)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusTemporaryRedirect {
return fmt.Errorf("expected solve status code %d, got %d", http.StatusTemporaryRedirect, response.StatusCode)
} else if cookies := response.Cookies(); len(cookies) == 0 {
return fmt.Errorf("expected set cookies to be non-empty, got none")
} else if response.Header.Get("Location") == "" {
return fmt.Errorf("expected header 'Location' to be non-empty, got none")
} else if !strings.HasPrefix(response.Header.Get("Location"), "/test") {
return fmt.Errorf("expected next location to start with '/test', got %s", response.Header.Get("Location"))
}
// test pass
pass, err := http.NewRequest(http.MethodGet, response.Header.Get("Location"), nil)
pass.Header.Set(settings.ClientIpHeader, "127.0.0.1")
if err != nil {
return err
}
for _, c := range response.Cookies() {
pass.AddCookie(c)
}
response, err = do(pass)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("expected pass status code %d, got %d", http.StatusOK, response.StatusCode)
}
// test failure
uri, err := url.Parse(solveLocation)
q := uri.Query()
q.Set(challenge2.QueryArgToken, hex.EncodeToString(make([]byte, challenge2.KeySize)))
uri.RawQuery = q.Encode()
fail, err := http.NewRequest(http.MethodGet, uri.String(), nil)
fail.Header.Set(settings.ClientIpHeader, "127.0.0.1")
if err != nil {
return err
}
response, err = do(fail)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected fail status code %d, got %d", http.StatusBadRequest, response.StatusCode)
}
return nil
})
if err != nil {
t.Fatal(err)
}
}
func TestChallengeMetaRefresh(t *testing.T) {
settings := setupDefaultSettings(t)
pol, err := policy.NewPolicy(strings.NewReader(
`
challenges:
"challenge-meta-refresh":
runtime: "refresh"
parameters:
refresh-via: "meta"
rules:
- name: catch-all
conditions: ["true"]
action: challenge
settings:
challenges: ["challenge-meta-refresh"]
`,
))
if err != nil {
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
}
var expectedCode = settings.ChallengeResponseCode
err = MakeGoAwayState(*pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error {
challenge, err := http.NewRequest(http.MethodGet, "/test", nil)
challenge.Header.Set(settings.ClientIpHeader, "127.0.0.1")
challengeResponse, err := do(challenge)
if err != nil {
return err
}
defer challengeResponse.Body.Close()
if challengeResponse.StatusCode != expectedCode {
return fmt.Errorf("expected challenge status code %d, got %d", expectedCode, challengeResponse.StatusCode)
} else if challengeResponse.Header.Get("Refresh") != "" {
return fmt.Errorf("expected header 'Refresh' to be empty, got \"%s\"", challengeResponse.Header.Get("Refresh"))
}
node, err := html.ParseWithOptions(challengeResponse.Body, html.ParseOptionEnableScripting(false))
if err != nil {
return nil
}
var refresh string
for n := range node.Descendants() {
if n.Type == html.ElementNode && n.Data == "meta" {
var is bool
var val string
for _, a := range n.Attr {
if a.Key == "http-equiv" && a.Val == "refresh" {
is = true
}
if a.Key == "content" {
val = a.Val
}
}
if is {
refresh = val
break
}
}
}
solveLocation, err := url.QueryUnescape(strings.Split(refresh, "; url=")[1])
if err != nil {
return err
}
// test solve
solve, err := http.NewRequest(http.MethodGet, solveLocation, nil)
solve.Header.Set(settings.ClientIpHeader, "127.0.0.1")
if err != nil {
return err
}
response, err := do(solve)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusTemporaryRedirect {
return fmt.Errorf("expected solve status code %d, got %d", http.StatusTemporaryRedirect, response.StatusCode)
} else if cookies := response.Cookies(); len(cookies) == 0 {
return fmt.Errorf("expected set cookies to be non-empty, got none")
} else if response.Header.Get("Location") == "" {
return fmt.Errorf("expected header 'Location' to be non-empty, got none")
} else if !strings.HasPrefix(response.Header.Get("Location"), "/test") {
return fmt.Errorf("expected next location to start with '/test', got %s", response.Header.Get("Location"))
}
// test pass
pass, err := http.NewRequest(http.MethodGet, response.Header.Get("Location"), nil)
pass.Header.Set(settings.ClientIpHeader, "127.0.0.1")
if err != nil {
return err
}
for _, c := range response.Cookies() {
pass.AddCookie(c)
}
response, err = do(pass)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("expected pass status code %d, got %d", http.StatusOK, response.StatusCode)
}
// test failure
uri, err := url.Parse(solveLocation)
q := uri.Query()
q.Set(challenge2.QueryArgToken, hex.EncodeToString(make([]byte, challenge2.KeySize)))
uri.RawQuery = q.Encode()
fail, err := http.NewRequest(http.MethodGet, uri.String(), nil)
fail.Header.Set(settings.ClientIpHeader, "127.0.0.1")
if err != nil {
return err
}
response, err = do(fail)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected fail status code %d, got %d", http.StatusBadRequest, response.StatusCode)
}
return nil
})
if err != nil {
t.Fatal(err)
}
}

57
tests/logger_test.go Normal file
View File

@@ -0,0 +1,57 @@
package tests
import (
"context"
"fmt"
"log/slog"
"testing"
)
type logger struct {
t *testing.T
attrs []slog.Attr
}
func (l logger) Enabled(ctx context.Context, level slog.Level) bool {
return true
}
func (l logger) Handle(ctx context.Context, record slog.Record) error {
str := fmt.Sprintf("[%s] %s", record.Level, record.Message)
if record.NumAttrs() > 0 || len(l.attrs) > 0 {
str += ": "
}
for _, attr := range l.attrs {
str += fmt.Sprintf("%s=%s ", attr.Key, attr.Value.String())
}
record.Attrs(func(attr slog.Attr) bool {
str += fmt.Sprintf("%s=%s ", attr.Key, attr.Value.String())
return true
})
if record.Level == slog.LevelError {
l.t.Error(str)
} else {
l.t.Log(str)
}
return nil
}
func (l logger) WithAttrs(attrs []slog.Attr) slog.Handler {
newAttrs := make([]slog.Attr, 0, len(attrs)+len(l.attrs))
newAttrs = append(newAttrs, l.attrs...)
newAttrs = append(newAttrs, attrs...)
return logger{
t: l.t,
attrs: newAttrs,
}
}
func (l logger) WithGroup(name string) slog.Handler {
return l
}
func initLogger(t *testing.T) slog.Handler {
return logger{t: t}
}

View File

@@ -2,6 +2,7 @@ package utils
import (
"golang.org/x/net/html"
"io"
"mime"
"net/http"
"net/http/httptest"
@@ -32,8 +33,12 @@ func FetchTags(backend http.Handler, uri *url.URL, kind string) (result []html.N
return nil
}
return FetchTagsFromReader(response.Body, kind)
}
func FetchTagsFromReader(r io.Reader, kind string) (result []html.Node) {
//TODO: handle non UTF-8 documents
node, err := html.ParseWithOptions(response.Body, html.ParseOptionEnableScripting(false))
node, err := html.ParseWithOptions(r, html.ParseOptionEnableScripting(false))
if err != nil {
return nil
}