testdata: Initial action/challenges testing
This commit is contained in:
@@ -12,6 +12,7 @@ local Build(mirror, go, alpine, os, arch) = {
|
||||
CGO_ENABLED: "0",
|
||||
GOOS: os,
|
||||
GOARCH: arch,
|
||||
GORACE: "halt_on_error=1"
|
||||
},
|
||||
steps: [
|
||||
{
|
||||
@@ -26,6 +27,16 @@ local Build(mirror, go, alpine, os, arch) = {
|
||||
"go build -v -o ./.bin/test-wasm-runtime ./cmd/test-wasm-runtime",
|
||||
],
|
||||
},
|
||||
{
|
||||
name: "test",
|
||||
image: "golang:" + go +"-alpine" + alpine,
|
||||
mirror: mirror,
|
||||
commands: [
|
||||
"apk update",
|
||||
"apk add --no-cache git",
|
||||
"go test -p 1 -timeout 20m -v ./tests/"
|
||||
],
|
||||
},
|
||||
{
|
||||
name: "check-policy-forgejo",
|
||||
image: "alpine:" + alpine,
|
||||
@@ -92,6 +103,16 @@ local Publish(mirror, registry, repo, secret, go, alpine, os, arch, trigger, pla
|
||||
},
|
||||
trigger: trigger,
|
||||
steps: [
|
||||
{
|
||||
name: "test",
|
||||
image: "golang:" + go +"-alpine" + alpine,
|
||||
mirror: mirror,
|
||||
commands: [
|
||||
"apk update",
|
||||
"apk add --no-cache git",
|
||||
"go test -p 1 -timeout 20m -v ./tests/"
|
||||
],
|
||||
},
|
||||
{
|
||||
name: "setup-buildkitd",
|
||||
image: "alpine:" + alpine,
|
||||
|
60
.drone.yml
60
.drone.yml
@@ -3,6 +3,7 @@ environment:
|
||||
CGO_ENABLED: "0"
|
||||
GOARCH: amd64
|
||||
GOOS: linux
|
||||
GORACE: halt_on_error=1
|
||||
GOTOOLCHAIN: local
|
||||
kind: pipeline
|
||||
name: build-1.24-alpine3.21-amd64
|
||||
@@ -19,6 +20,13 @@ steps:
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: build
|
||||
- commands:
|
||||
- apk update
|
||||
- apk add --no-cache git
|
||||
- go test -p 1 -timeout 20m -v ./tests/
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: test
|
||||
- commands:
|
||||
- ./.bin/go-away --check --slog-level DEBUG --backend example.com=http://127.0.0.1:80
|
||||
--policy examples/forgejo.yml --policy-snippets examples/snippets/
|
||||
@@ -71,6 +79,7 @@ environment:
|
||||
CGO_ENABLED: "0"
|
||||
GOARCH: arm64
|
||||
GOOS: linux
|
||||
GORACE: halt_on_error=1
|
||||
GOTOOLCHAIN: local
|
||||
kind: pipeline
|
||||
name: build-1.24-alpine3.21-arm64
|
||||
@@ -87,6 +96,13 @@ steps:
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: build
|
||||
- commands:
|
||||
- apk update
|
||||
- apk add --no-cache git
|
||||
- go test -p 1 -timeout 20m -v ./tests/
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: test
|
||||
- commands:
|
||||
- ./.bin/go-away --check --slog-level DEBUG --backend example.com=http://127.0.0.1:80
|
||||
--policy examples/forgejo.yml --policy-snippets examples/snippets/
|
||||
@@ -141,6 +157,13 @@ platform:
|
||||
arch: amd64
|
||||
os: linux
|
||||
steps:
|
||||
- commands:
|
||||
- apk update
|
||||
- apk add --no-cache git
|
||||
- go test -p 1 -timeout 20m -v ./tests/
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: test
|
||||
- commands:
|
||||
- echo '[registry."docker.io"]' > buildkitd.toml
|
||||
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
|
||||
@@ -189,6 +212,13 @@ platform:
|
||||
arch: amd64
|
||||
os: linux
|
||||
steps:
|
||||
- commands:
|
||||
- apk update
|
||||
- apk add --no-cache git
|
||||
- go test -p 1 -timeout 20m -v ./tests/
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: test
|
||||
- commands:
|
||||
- echo '[registry."docker.io"]' > buildkitd.toml
|
||||
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
|
||||
@@ -237,6 +267,13 @@ platform:
|
||||
arch: amd64
|
||||
os: linux
|
||||
steps:
|
||||
- commands:
|
||||
- apk update
|
||||
- apk add --no-cache git
|
||||
- go test -p 1 -timeout 20m -v ./tests/
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: test
|
||||
- commands:
|
||||
- echo '[registry."docker.io"]' > buildkitd.toml
|
||||
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
|
||||
@@ -285,6 +322,13 @@ platform:
|
||||
arch: amd64
|
||||
os: linux
|
||||
steps:
|
||||
- commands:
|
||||
- apk update
|
||||
- apk add --no-cache git
|
||||
- go test -p 1 -timeout 20m -v ./tests/
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: test
|
||||
- commands:
|
||||
- echo '[registry."docker.io"]' > buildkitd.toml
|
||||
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
|
||||
@@ -333,6 +377,13 @@ platform:
|
||||
arch: amd64
|
||||
os: linux
|
||||
steps:
|
||||
- commands:
|
||||
- apk update
|
||||
- apk add --no-cache git
|
||||
- go test -p 1 -timeout 20m -v ./tests/
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: test
|
||||
- commands:
|
||||
- echo '[registry."docker.io"]' > buildkitd.toml
|
||||
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
|
||||
@@ -381,6 +432,13 @@ platform:
|
||||
arch: amd64
|
||||
os: linux
|
||||
steps:
|
||||
- commands:
|
||||
- apk update
|
||||
- apk add --no-cache git
|
||||
- go test -p 1 -timeout 20m -v ./tests/
|
||||
image: golang:1.24-alpine3.21
|
||||
mirror: https://mirror.gcr.io
|
||||
name: test
|
||||
- commands:
|
||||
- echo '[registry."docker.io"]' > buildkitd.toml
|
||||
- echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml
|
||||
@@ -424,6 +482,6 @@ trigger:
|
||||
type: docker
|
||||
---
|
||||
kind: signature
|
||||
hmac: 6eab8ae9773b048e780db2bf9d440095eb5615d0baf8da71878069ad7124e167
|
||||
hmac: 07ac33f9298a9910aacb29ef18931cb999841f76be8a95ca210f9f3704c347f9
|
||||
|
||||
...
|
||||
|
15
lib/state.go
15
lib/state.go
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
@@ -52,8 +53,8 @@ type State struct {
|
||||
Mux *http.ServeMux
|
||||
}
|
||||
|
||||
func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (handler http.Handler, err error) {
|
||||
state := new(State)
|
||||
func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (state *State, err error) {
|
||||
state = new(State)
|
||||
state.close = make(chan struct{})
|
||||
state.settings = settings
|
||||
state.opt = opt
|
||||
@@ -264,3 +265,13 @@ func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSetti
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (state *State) Close() error {
|
||||
select {
|
||||
case <-state.close:
|
||||
return errors.New("already closed")
|
||||
default:
|
||||
close(state.close)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
280
tests/action_test.go
Normal file
280
tests/action_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"git.gammaspectra.live/git/go-away/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testAction(t *testing.T, pol policy.Policy, expected int, url string) (*http.Response, error) {
|
||||
settings := setupDefaultSettings(t)
|
||||
var r *http.Response
|
||||
err := MakeGoAwayState(pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error {
|
||||
request, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
request.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
response, err := do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != expected {
|
||||
return fmt.Errorf("expected status code %d, got %d", expected, response.StatusCode)
|
||||
}
|
||||
r = response
|
||||
|
||||
return nil
|
||||
})
|
||||
return r, err
|
||||
}
|
||||
|
||||
func TestActionPass(t *testing.T) {
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
rules:
|
||||
- name: test
|
||||
conditions: ["true"]
|
||||
action: pass
|
||||
settings:
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
_, err = testAction(t, *pol, http.StatusOK, "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionNone(t *testing.T) {
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
rules:
|
||||
- name: test
|
||||
conditions: ["true"]
|
||||
action: none
|
||||
settings:
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
_, err = testAction(t, *pol, http.StatusOK, "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionDrop(t *testing.T) {
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
rules:
|
||||
- name: test
|
||||
conditions: ["true"]
|
||||
action: drop
|
||||
settings:
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
response, err := testAction(t, *pol, http.StatusForbidden, "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 0 {
|
||||
t.Fatal(fmt.Errorf("expected empty response, got %s", string(data)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionDeny(t *testing.T) {
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
rules:
|
||||
- name: test
|
||||
conditions: ["true"]
|
||||
action: deny
|
||||
settings:
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
response, err := testAction(t, *pol, http.StatusForbidden, "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Fatal(errors.New("expected non-empty response, got none"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionBlock(t *testing.T) {
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
rules:
|
||||
- name: test
|
||||
conditions: ["true"]
|
||||
action: block
|
||||
settings:
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
response, err := testAction(t, *pol, http.StatusForbidden, "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Fatal(errors.New("expected non-empty response, got none"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionCode(t *testing.T) {
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
rules:
|
||||
- name: test
|
||||
conditions: ["true"]
|
||||
action: code
|
||||
settings:
|
||||
http-code: 418
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
_, err = testAction(t, *pol, http.StatusTeapot, "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionContextResponseHeaders(t *testing.T) {
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
rules:
|
||||
- name: test
|
||||
conditions: ["true"]
|
||||
action: context
|
||||
settings:
|
||||
response-headers:
|
||||
X-World-Domination: yes
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
response, err := testAction(t, *pol, http.StatusOK, "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if response.Header.Get("X-World-Domination") != "yes" {
|
||||
t.Fatal(fmt.Errorf("expected header set, got %s", response.Header.Get("X-World-Domination")))
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionContextSetMetaTags(t *testing.T) {
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
rules:
|
||||
- name: test-context
|
||||
conditions: ["true"]
|
||||
action: context
|
||||
settings:
|
||||
context-set:
|
||||
proxy-meta-tags: yes
|
||||
|
||||
- name: test
|
||||
conditions: ["true"]
|
||||
action: deny
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
uri, err := url.Parse("/test")
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
q := uri.Query()
|
||||
q.Set("mime-type", "text/html")
|
||||
q.Set("content", base64.RawURLEncoding.EncodeToString([]byte(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta name="description" content="test">
|
||||
</head>
|
||||
</html>
|
||||
`)))
|
||||
|
||||
uri.RawQuery = q.Encode()
|
||||
|
||||
response, err := testAction(t, *pol, http.StatusForbidden, uri.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tags := utils.FetchTagsFromReader(response.Body, "meta")
|
||||
|
||||
if str := func() string {
|
||||
for _, t := range tags {
|
||||
var is bool
|
||||
var val string
|
||||
for _, a := range t.Attr {
|
||||
if a.Key == "name" && a.Val == "description" {
|
||||
is = true
|
||||
}
|
||||
if a.Key == "content" {
|
||||
val = a.Val
|
||||
}
|
||||
}
|
||||
if is {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return "NONE"
|
||||
}(); str != "test" {
|
||||
t.Fatal(fmt.Errorf("expected meta tag with 'test', got %s", str))
|
||||
}
|
||||
}
|
34
tests/away.go
Normal file
34
tests/away.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"git.gammaspectra.live/git/go-away/lib"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"git.gammaspectra.live/git/go-away/lib/settings"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
)
|
||||
|
||||
var DefaultSettings = policy.StateSettings{
|
||||
Cache: nil,
|
||||
Backends: map[string]http.Handler{
|
||||
"*": MakeTestBackend(),
|
||||
},
|
||||
MainName: "go-away/tests",
|
||||
MainVersion: "testing",
|
||||
BasePath: "/.go-away",
|
||||
ChallengeResponseCode: http.StatusTeapot,
|
||||
ClientIpHeader: "X-Forwarded-For",
|
||||
}
|
||||
|
||||
func MakeGoAwayState(pol policy.Policy, stateSettings policy.StateSettings, f func(do func(r *http.Request, errs ...error) (*http.Response, error)) error) error {
|
||||
state, err := lib.NewState(pol, settings.DefaultSettings, stateSettings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return f(func(r *http.Request, errs ...error) (*http.Response, error) {
|
||||
recorder := httptest.NewRecorder()
|
||||
state.ServeHTTP(recorder, r)
|
||||
return recorder.Result(), nil
|
||||
})
|
||||
}
|
57
tests/backend.go
Normal file
57
tests/backend.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func MakeTestBackend() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
responseCode := http.StatusOK
|
||||
var err error
|
||||
if opt := q.Get("http-code"); opt != "" {
|
||||
rc, err := strconv.ParseInt(opt, 10, 64)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
responseCode = int(rc)
|
||||
}
|
||||
type ResponseJson struct {
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
if opt := q.Get("mime-type"); opt != "" {
|
||||
w.Header().Set("Content-Type", opt)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if opt := q.Get("content"); opt != "" {
|
||||
data, err = base64.RawURLEncoding.DecodeString(opt)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
data, err = json.Marshal(ResponseJson{
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
Query: r.URL.RawQuery,
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(responseCode)
|
||||
_, _ = w.Write(data)
|
||||
})
|
||||
}
|
362
tests/challenge_test.go
Normal file
362
tests/challenge_test.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
challenge2 "git.gammaspectra.live/git/go-away/lib/challenge"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"golang.org/x/net/html"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setupDefaultSettings(t *testing.T) policy.StateSettings {
|
||||
settings := DefaultSettings
|
||||
slog.SetDefault(slog.New(initLogger(t)))
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func TestChallengeCookie(t *testing.T) {
|
||||
settings := setupDefaultSettings(t)
|
||||
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
challenges:
|
||||
"challenge-cookie":
|
||||
runtime: "cookie"
|
||||
|
||||
rules:
|
||||
- name: catch-all
|
||||
conditions: ["true"]
|
||||
action: challenge
|
||||
settings:
|
||||
challenges: ["challenge-cookie"]
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
var expectedCode = http.StatusTemporaryRedirect
|
||||
|
||||
err = MakeGoAwayState(*pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error {
|
||||
challenge, err := http.NewRequest(http.MethodGet, "/test", nil)
|
||||
challenge.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
challengeResponse, err := do(challenge)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer challengeResponse.Body.Close()
|
||||
if challengeResponse.StatusCode != expectedCode {
|
||||
return fmt.Errorf("expected challenge status code %d, got %d", expectedCode, challengeResponse.StatusCode)
|
||||
} else if cookies := challengeResponse.Cookies(); len(cookies) == 0 {
|
||||
return fmt.Errorf("expected set cookies to be non-empty, got none")
|
||||
} else if challengeResponse.Header.Get("Location") == "" {
|
||||
return fmt.Errorf("expected header 'Location' to be non-empty, got none")
|
||||
}
|
||||
|
||||
solveLocation := challengeResponse.Header.Get("Location")
|
||||
|
||||
if !strings.HasPrefix(solveLocation, "/test") {
|
||||
return fmt.Errorf("expected next location to start with '/test', got %s", solveLocation)
|
||||
}
|
||||
|
||||
// test pass
|
||||
pass, err := http.NewRequest(http.MethodGet, solveLocation, nil)
|
||||
pass.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, c := range challengeResponse.Cookies() {
|
||||
pass.AddCookie(c)
|
||||
}
|
||||
|
||||
response, err := do(pass)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("expected pass status code %d, got %d", http.StatusOK, response.StatusCode)
|
||||
}
|
||||
|
||||
// test failure
|
||||
fail, err := http.NewRequest(http.MethodGet, solveLocation, nil)
|
||||
fail.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err = do(fail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusForbidden {
|
||||
return fmt.Errorf("expected fail status code %d, got %d", http.StatusForbidden, response.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChallengeHeaderRefresh(t *testing.T) {
|
||||
settings := setupDefaultSettings(t)
|
||||
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
challenges:
|
||||
"challenge-header-refresh":
|
||||
runtime: "refresh"
|
||||
parameters:
|
||||
refresh-via: "header"
|
||||
|
||||
rules:
|
||||
- name: catch-all
|
||||
conditions: ["true"]
|
||||
action: challenge
|
||||
settings:
|
||||
challenges: ["challenge-header-refresh"]
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
var expectedCode = settings.ChallengeResponseCode
|
||||
|
||||
err = MakeGoAwayState(*pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error {
|
||||
challenge, err := http.NewRequest(http.MethodGet, "/test", nil)
|
||||
challenge.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
challengeResponse, err := do(challenge)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer challengeResponse.Body.Close()
|
||||
if challengeResponse.StatusCode != expectedCode {
|
||||
return fmt.Errorf("expected challenge status code %d, got %d", expectedCode, challengeResponse.StatusCode)
|
||||
} else if challengeResponse.Header.Get("Refresh") == "" {
|
||||
return fmt.Errorf("expected header 'Refresh' to be non-empty, got none")
|
||||
}
|
||||
|
||||
solveLocation, err := url.QueryUnescape(strings.Split(challengeResponse.Header.Get("Refresh"), "; url=")[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// test solve
|
||||
solve, err := http.NewRequest(http.MethodGet, solveLocation, nil)
|
||||
solve.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := do(solve)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusTemporaryRedirect {
|
||||
return fmt.Errorf("expected solve status code %d, got %d", http.StatusTemporaryRedirect, response.StatusCode)
|
||||
} else if cookies := response.Cookies(); len(cookies) == 0 {
|
||||
return fmt.Errorf("expected set cookies to be non-empty, got none")
|
||||
} else if response.Header.Get("Location") == "" {
|
||||
return fmt.Errorf("expected header 'Location' to be non-empty, got none")
|
||||
} else if !strings.HasPrefix(response.Header.Get("Location"), "/test") {
|
||||
return fmt.Errorf("expected next location to start with '/test', got %s", response.Header.Get("Location"))
|
||||
}
|
||||
|
||||
// test pass
|
||||
pass, err := http.NewRequest(http.MethodGet, response.Header.Get("Location"), nil)
|
||||
pass.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, c := range response.Cookies() {
|
||||
pass.AddCookie(c)
|
||||
}
|
||||
|
||||
response, err = do(pass)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("expected pass status code %d, got %d", http.StatusOK, response.StatusCode)
|
||||
}
|
||||
|
||||
// test failure
|
||||
uri, err := url.Parse(solveLocation)
|
||||
q := uri.Query()
|
||||
q.Set(challenge2.QueryArgToken, hex.EncodeToString(make([]byte, challenge2.KeySize)))
|
||||
uri.RawQuery = q.Encode()
|
||||
|
||||
fail, err := http.NewRequest(http.MethodGet, uri.String(), nil)
|
||||
fail.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err = do(fail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected fail status code %d, got %d", http.StatusBadRequest, response.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChallengeMetaRefresh(t *testing.T) {
|
||||
settings := setupDefaultSettings(t)
|
||||
|
||||
pol, err := policy.NewPolicy(strings.NewReader(
|
||||
`
|
||||
challenges:
|
||||
"challenge-meta-refresh":
|
||||
runtime: "refresh"
|
||||
parameters:
|
||||
refresh-via: "meta"
|
||||
|
||||
rules:
|
||||
- name: catch-all
|
||||
conditions: ["true"]
|
||||
action: challenge
|
||||
settings:
|
||||
challenges: ["challenge-meta-refresh"]
|
||||
|
||||
`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to create policy: %w", err))
|
||||
}
|
||||
|
||||
var expectedCode = settings.ChallengeResponseCode
|
||||
|
||||
err = MakeGoAwayState(*pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error {
|
||||
challenge, err := http.NewRequest(http.MethodGet, "/test", nil)
|
||||
challenge.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
challengeResponse, err := do(challenge)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer challengeResponse.Body.Close()
|
||||
if challengeResponse.StatusCode != expectedCode {
|
||||
return fmt.Errorf("expected challenge status code %d, got %d", expectedCode, challengeResponse.StatusCode)
|
||||
} else if challengeResponse.Header.Get("Refresh") != "" {
|
||||
return fmt.Errorf("expected header 'Refresh' to be empty, got \"%s\"", challengeResponse.Header.Get("Refresh"))
|
||||
}
|
||||
|
||||
node, err := html.ParseWithOptions(challengeResponse.Body, html.ParseOptionEnableScripting(false))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var refresh string
|
||||
for n := range node.Descendants() {
|
||||
if n.Type == html.ElementNode && n.Data == "meta" {
|
||||
var is bool
|
||||
var val string
|
||||
for _, a := range n.Attr {
|
||||
if a.Key == "http-equiv" && a.Val == "refresh" {
|
||||
is = true
|
||||
}
|
||||
if a.Key == "content" {
|
||||
val = a.Val
|
||||
}
|
||||
}
|
||||
if is {
|
||||
refresh = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
solveLocation, err := url.QueryUnescape(strings.Split(refresh, "; url=")[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// test solve
|
||||
solve, err := http.NewRequest(http.MethodGet, solveLocation, nil)
|
||||
solve.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := do(solve)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusTemporaryRedirect {
|
||||
return fmt.Errorf("expected solve status code %d, got %d", http.StatusTemporaryRedirect, response.StatusCode)
|
||||
} else if cookies := response.Cookies(); len(cookies) == 0 {
|
||||
return fmt.Errorf("expected set cookies to be non-empty, got none")
|
||||
} else if response.Header.Get("Location") == "" {
|
||||
return fmt.Errorf("expected header 'Location' to be non-empty, got none")
|
||||
} else if !strings.HasPrefix(response.Header.Get("Location"), "/test") {
|
||||
return fmt.Errorf("expected next location to start with '/test', got %s", response.Header.Get("Location"))
|
||||
}
|
||||
|
||||
// test pass
|
||||
pass, err := http.NewRequest(http.MethodGet, response.Header.Get("Location"), nil)
|
||||
pass.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, c := range response.Cookies() {
|
||||
pass.AddCookie(c)
|
||||
}
|
||||
|
||||
response, err = do(pass)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("expected pass status code %d, got %d", http.StatusOK, response.StatusCode)
|
||||
}
|
||||
|
||||
// test failure
|
||||
uri, err := url.Parse(solveLocation)
|
||||
q := uri.Query()
|
||||
q.Set(challenge2.QueryArgToken, hex.EncodeToString(make([]byte, challenge2.KeySize)))
|
||||
uri.RawQuery = q.Encode()
|
||||
|
||||
fail, err := http.NewRequest(http.MethodGet, uri.String(), nil)
|
||||
fail.Header.Set(settings.ClientIpHeader, "127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err = do(fail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected fail status code %d, got %d", http.StatusBadRequest, response.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
57
tests/logger_test.go
Normal file
57
tests/logger_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type logger struct {
|
||||
t *testing.T
|
||||
attrs []slog.Attr
|
||||
}
|
||||
|
||||
func (l logger) Enabled(ctx context.Context, level slog.Level) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (l logger) Handle(ctx context.Context, record slog.Record) error {
|
||||
str := fmt.Sprintf("[%s] %s", record.Level, record.Message)
|
||||
|
||||
if record.NumAttrs() > 0 || len(l.attrs) > 0 {
|
||||
str += ": "
|
||||
}
|
||||
for _, attr := range l.attrs {
|
||||
str += fmt.Sprintf("%s=%s ", attr.Key, attr.Value.String())
|
||||
}
|
||||
record.Attrs(func(attr slog.Attr) bool {
|
||||
str += fmt.Sprintf("%s=%s ", attr.Key, attr.Value.String())
|
||||
return true
|
||||
})
|
||||
|
||||
if record.Level == slog.LevelError {
|
||||
l.t.Error(str)
|
||||
} else {
|
||||
l.t.Log(str)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l logger) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
newAttrs := make([]slog.Attr, 0, len(attrs)+len(l.attrs))
|
||||
newAttrs = append(newAttrs, l.attrs...)
|
||||
newAttrs = append(newAttrs, attrs...)
|
||||
return logger{
|
||||
t: l.t,
|
||||
attrs: newAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
func (l logger) WithGroup(name string) slog.Handler {
|
||||
return l
|
||||
}
|
||||
|
||||
func initLogger(t *testing.T) slog.Handler {
|
||||
return logger{t: t}
|
||||
}
|
@@ -2,6 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"golang.org/x/net/html"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -32,8 +33,12 @@ func FetchTags(backend http.Handler, uri *url.URL, kind string) (result []html.N
|
||||
return nil
|
||||
}
|
||||
|
||||
return FetchTagsFromReader(response.Body, kind)
|
||||
}
|
||||
|
||||
func FetchTagsFromReader(r io.Reader, kind string) (result []html.Node) {
|
||||
//TODO: handle non UTF-8 documents
|
||||
node, err := html.ParseWithOptions(response.Body, html.ParseOptionEnableScripting(false))
|
||||
node, err := html.ParseWithOptions(r, html.ParseOptionEnableScripting(false))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user