diff --git a/.drone.jsonnet b/.drone.jsonnet index 3fa475f..baa8e54 100644 --- a/.drone.jsonnet +++ b/.drone.jsonnet @@ -12,6 +12,7 @@ local Build(mirror, go, alpine, os, arch) = { CGO_ENABLED: "0", GOOS: os, GOARCH: arch, + GORACE: "halt_on_error=1" }, steps: [ { @@ -26,6 +27,16 @@ local Build(mirror, go, alpine, os, arch) = { "go build -v -o ./.bin/test-wasm-runtime ./cmd/test-wasm-runtime", ], }, + { + name: "test", + image: "golang:" + go +"-alpine" + alpine, + mirror: mirror, + commands: [ + "apk update", + "apk add --no-cache git", + "go test -p 1 -timeout 20m -v ./tests/" + ], + }, { name: "check-policy-forgejo", image: "alpine:" + alpine, @@ -92,6 +103,16 @@ local Publish(mirror, registry, repo, secret, go, alpine, os, arch, trigger, pla }, trigger: trigger, steps: [ + { + name: "test", + image: "golang:" + go +"-alpine" + alpine, + mirror: mirror, + commands: [ + "apk update", + "apk add --no-cache git", + "go test -p 1 -timeout 20m -v ./tests/" + ], + }, { name: "setup-buildkitd", image: "alpine:" + alpine, diff --git a/.drone.yml b/.drone.yml index 1d6067f..6c33826 100644 --- a/.drone.yml +++ b/.drone.yml @@ -3,6 +3,7 @@ environment: CGO_ENABLED: "0" GOARCH: amd64 GOOS: linux + GORACE: halt_on_error=1 GOTOOLCHAIN: local kind: pipeline name: build-1.24-alpine3.21-amd64 @@ -19,6 +20,13 @@ steps: image: golang:1.24-alpine3.21 mirror: https://mirror.gcr.io name: build +- commands: + - apk update + - apk add --no-cache git + - go test -p 1 -timeout 20m -v ./tests/ + image: golang:1.24-alpine3.21 + mirror: https://mirror.gcr.io + name: test - commands: - ./.bin/go-away --check --slog-level DEBUG --backend example.com=http://127.0.0.1:80 --policy examples/forgejo.yml --policy-snippets examples/snippets/ @@ -71,6 +79,7 @@ environment: CGO_ENABLED: "0" GOARCH: arm64 GOOS: linux + GORACE: halt_on_error=1 GOTOOLCHAIN: local kind: pipeline name: build-1.24-alpine3.21-arm64 @@ -87,6 +96,13 @@ steps: image: golang:1.24-alpine3.21 mirror: https://mirror.gcr.io name: build +- commands: + - apk update + - apk add --no-cache git + - go test -p 1 -timeout 20m -v ./tests/ + image: golang:1.24-alpine3.21 + mirror: https://mirror.gcr.io + name: test - commands: - ./.bin/go-away --check --slog-level DEBUG --backend example.com=http://127.0.0.1:80 --policy examples/forgejo.yml --policy-snippets examples/snippets/ @@ -141,6 +157,13 @@ platform: arch: amd64 os: linux steps: +- commands: + - apk update + - apk add --no-cache git + - go test -p 1 -timeout 20m -v ./tests/ + image: golang:1.24-alpine3.21 + mirror: https://mirror.gcr.io + name: test - commands: - echo '[registry."docker.io"]' > buildkitd.toml - echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml @@ -189,6 +212,13 @@ platform: arch: amd64 os: linux steps: +- commands: + - apk update + - apk add --no-cache git + - go test -p 1 -timeout 20m -v ./tests/ + image: golang:1.24-alpine3.21 + mirror: https://mirror.gcr.io + name: test - commands: - echo '[registry."docker.io"]' > buildkitd.toml - echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml @@ -237,6 +267,13 @@ platform: arch: amd64 os: linux steps: +- commands: + - apk update + - apk add --no-cache git + - go test -p 1 -timeout 20m -v ./tests/ + image: golang:1.24-alpine3.21 + mirror: https://mirror.gcr.io + name: test - commands: - echo '[registry."docker.io"]' > buildkitd.toml - echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml @@ -285,6 +322,13 @@ platform: arch: amd64 os: linux steps: +- commands: + - apk update + - apk add --no-cache git + - go test -p 1 -timeout 20m -v ./tests/ + image: golang:1.24-alpine3.21 + mirror: https://mirror.gcr.io + name: test - commands: - echo '[registry."docker.io"]' > buildkitd.toml - echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml @@ -333,6 +377,13 @@ platform: arch: amd64 os: linux steps: +- commands: + - apk update + - apk add --no-cache git + - go test -p 1 -timeout 20m -v ./tests/ + image: golang:1.24-alpine3.21 + mirror: https://mirror.gcr.io + name: test - commands: - echo '[registry."docker.io"]' > buildkitd.toml - echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml @@ -381,6 +432,13 @@ platform: arch: amd64 os: linux steps: +- commands: + - apk update + - apk add --no-cache git + - go test -p 1 -timeout 20m -v ./tests/ + image: golang:1.24-alpine3.21 + mirror: https://mirror.gcr.io + name: test - commands: - echo '[registry."docker.io"]' > buildkitd.toml - echo ' mirrors = ["mirror.gcr.io"]' >> buildkitd.toml @@ -424,6 +482,6 @@ trigger: type: docker --- kind: signature -hmac: 6eab8ae9773b048e780db2bf9d440095eb5615d0baf8da71878069ad7124e167 +hmac: 07ac33f9298a9910aacb29ef18931cb999841f76be8a95ca210f9f3704c347f9 ... diff --git a/lib/state.go b/lib/state.go index b853297..16b98f2 100644 --- a/lib/state.go +++ b/lib/state.go @@ -7,6 +7,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "git.gammaspectra.live/git/go-away/lib/challenge" "git.gammaspectra.live/git/go-away/lib/policy" @@ -52,8 +53,8 @@ type State struct { Mux *http.ServeMux } -func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (handler http.Handler, err error) { - state := new(State) +func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (state *State, err error) { + state = new(State) state.close = make(chan struct{}) state.settings = settings state.opt = opt @@ -264,3 +265,13 @@ func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSetti return state, nil } + +func (state *State) Close() error { + select { + case <-state.close: + return errors.New("already closed") + default: + close(state.close) + } + return nil +} diff --git a/tests/action_test.go b/tests/action_test.go new file mode 100644 index 0000000..effeb6f --- /dev/null +++ b/tests/action_test.go @@ -0,0 +1,280 @@ +package tests + +import ( + "encoding/base64" + "errors" + "fmt" + "git.gammaspectra.live/git/go-away/lib/policy" + "git.gammaspectra.live/git/go-away/utils" + "io" + "net/http" + "net/url" + "strings" + "testing" +) + +func testAction(t *testing.T, pol policy.Policy, expected int, url string) (*http.Response, error) { + settings := setupDefaultSettings(t) + var r *http.Response + err := MakeGoAwayState(pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error { + request, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return err + } + request.Header.Set(settings.ClientIpHeader, "127.0.0.1") + response, err := do(request) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != expected { + return fmt.Errorf("expected status code %d, got %d", expected, response.StatusCode) + } + r = response + + return nil + }) + return r, err +} + +func TestActionPass(t *testing.T) { + pol, err := policy.NewPolicy(strings.NewReader( + ` +rules: + - name: test + conditions: ["true"] + action: pass + settings: + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + _, err = testAction(t, *pol, http.StatusOK, "/test") + if err != nil { + t.Fatal(err) + } +} + +func TestActionNone(t *testing.T) { + pol, err := policy.NewPolicy(strings.NewReader( + ` +rules: + - name: test + conditions: ["true"] + action: none + settings: + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + _, err = testAction(t, *pol, http.StatusOK, "/test") + if err != nil { + t.Fatal(err) + } +} + +func TestActionDrop(t *testing.T) { + pol, err := policy.NewPolicy(strings.NewReader( + ` +rules: + - name: test + conditions: ["true"] + action: drop + settings: + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + response, err := testAction(t, *pol, http.StatusForbidden, "/test") + if err != nil { + t.Fatal(err) + } + data, err := io.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + if len(data) != 0 { + t.Fatal(fmt.Errorf("expected empty response, got %s", string(data))) + } +} + +func TestActionDeny(t *testing.T) { + pol, err := policy.NewPolicy(strings.NewReader( + ` +rules: + - name: test + conditions: ["true"] + action: deny + settings: + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + response, err := testAction(t, *pol, http.StatusForbidden, "/test") + if err != nil { + t.Fatal(err) + } + data, err := io.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + if len(data) == 0 { + t.Fatal(errors.New("expected non-empty response, got none")) + } +} + +func TestActionBlock(t *testing.T) { + pol, err := policy.NewPolicy(strings.NewReader( + ` +rules: + - name: test + conditions: ["true"] + action: block + settings: + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + response, err := testAction(t, *pol, http.StatusForbidden, "/test") + if err != nil { + t.Fatal(err) + } + data, err := io.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + if len(data) == 0 { + t.Fatal(errors.New("expected non-empty response, got none")) + } +} + +func TestActionCode(t *testing.T) { + pol, err := policy.NewPolicy(strings.NewReader( + ` +rules: + - name: test + conditions: ["true"] + action: code + settings: + http-code: 418 + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + _, err = testAction(t, *pol, http.StatusTeapot, "/test") + if err != nil { + t.Fatal(err) + } +} + +func TestActionContextResponseHeaders(t *testing.T) { + pol, err := policy.NewPolicy(strings.NewReader( + ` +rules: + - name: test + conditions: ["true"] + action: context + settings: + response-headers: + X-World-Domination: yes + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + response, err := testAction(t, *pol, http.StatusOK, "/test") + if err != nil { + t.Fatal(err) + } + + if response.Header.Get("X-World-Domination") != "yes" { + t.Fatal(fmt.Errorf("expected header set, got %s", response.Header.Get("X-World-Domination"))) + } +} + +func TestActionContextSetMetaTags(t *testing.T) { + pol, err := policy.NewPolicy(strings.NewReader( + ` +rules: + - name: test-context + conditions: ["true"] + action: context + settings: + context-set: + proxy-meta-tags: yes + + - name: test + conditions: ["true"] + action: deny + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + uri, err := url.Parse("/test") + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + q := uri.Query() + q.Set("mime-type", "text/html") + q.Set("content", base64.RawURLEncoding.EncodeToString([]byte(` + + + + + + +`))) + + uri.RawQuery = q.Encode() + + response, err := testAction(t, *pol, http.StatusForbidden, uri.String()) + if err != nil { + t.Fatal(err) + } + + tags := utils.FetchTagsFromReader(response.Body, "meta") + + if str := func() string { + for _, t := range tags { + var is bool + var val string + for _, a := range t.Attr { + if a.Key == "name" && a.Val == "description" { + is = true + } + if a.Key == "content" { + val = a.Val + } + } + if is { + return val + } + } + return "NONE" + }(); str != "test" { + t.Fatal(fmt.Errorf("expected meta tag with 'test', got %s", str)) + } +} diff --git a/tests/away.go b/tests/away.go new file mode 100644 index 0000000..8a81f09 --- /dev/null +++ b/tests/away.go @@ -0,0 +1,34 @@ +package tests + +import ( + "git.gammaspectra.live/git/go-away/lib" + "git.gammaspectra.live/git/go-away/lib/policy" + "git.gammaspectra.live/git/go-away/lib/settings" + "net/http" + "net/http/httptest" +) + +var DefaultSettings = policy.StateSettings{ + Cache: nil, + Backends: map[string]http.Handler{ + "*": MakeTestBackend(), + }, + MainName: "go-away/tests", + MainVersion: "testing", + BasePath: "/.go-away", + ChallengeResponseCode: http.StatusTeapot, + ClientIpHeader: "X-Forwarded-For", +} + +func MakeGoAwayState(pol policy.Policy, stateSettings policy.StateSettings, f func(do func(r *http.Request, errs ...error) (*http.Response, error)) error) error { + state, err := lib.NewState(pol, settings.DefaultSettings, stateSettings) + if err != nil { + return err + } + + return f(func(r *http.Request, errs ...error) (*http.Response, error) { + recorder := httptest.NewRecorder() + state.ServeHTTP(recorder, r) + return recorder.Result(), nil + }) +} diff --git a/tests/backend.go b/tests/backend.go new file mode 100644 index 0000000..12a35bf --- /dev/null +++ b/tests/backend.go @@ -0,0 +1,57 @@ +package tests + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "strconv" +) + +func MakeTestBackend() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + responseCode := http.StatusOK + var err error + if opt := q.Get("http-code"); opt != "" { + rc, err := strconv.ParseInt(opt, 10, 64) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + responseCode = int(rc) + } + type ResponseJson struct { + Method string `json:"method"` + Path string `json:"path"` + Query string `json:"query"` + } + + if opt := q.Get("mime-type"); opt != "" { + w.Header().Set("Content-Type", opt) + } else { + w.Header().Set("Content-Type", "application/json") + } + + var data []byte + if opt := q.Get("content"); opt != "" { + data, err = base64.RawURLEncoding.DecodeString(opt) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + } else { + data, err = json.Marshal(ResponseJson{ + Method: r.Method, + Path: r.URL.Path, + Query: r.URL.RawQuery, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + } + + w.WriteHeader(responseCode) + _, _ = w.Write(data) + }) +} diff --git a/tests/challenge_test.go b/tests/challenge_test.go new file mode 100644 index 0000000..0aca367 --- /dev/null +++ b/tests/challenge_test.go @@ -0,0 +1,362 @@ +package tests + +import ( + "encoding/hex" + "fmt" + challenge2 "git.gammaspectra.live/git/go-away/lib/challenge" + "git.gammaspectra.live/git/go-away/lib/policy" + "golang.org/x/net/html" + "log/slog" + "net/http" + "net/url" + "strings" + "testing" +) + +func setupDefaultSettings(t *testing.T) policy.StateSettings { + settings := DefaultSettings + slog.SetDefault(slog.New(initLogger(t))) + + return settings +} + +func TestChallengeCookie(t *testing.T) { + settings := setupDefaultSettings(t) + + pol, err := policy.NewPolicy(strings.NewReader( + ` +challenges: + "challenge-cookie": + runtime: "cookie" + +rules: + - name: catch-all + conditions: ["true"] + action: challenge + settings: + challenges: ["challenge-cookie"] + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + var expectedCode = http.StatusTemporaryRedirect + + err = MakeGoAwayState(*pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error { + challenge, err := http.NewRequest(http.MethodGet, "/test", nil) + challenge.Header.Set(settings.ClientIpHeader, "127.0.0.1") + challengeResponse, err := do(challenge) + if err != nil { + return err + } + defer challengeResponse.Body.Close() + if challengeResponse.StatusCode != expectedCode { + return fmt.Errorf("expected challenge status code %d, got %d", expectedCode, challengeResponse.StatusCode) + } else if cookies := challengeResponse.Cookies(); len(cookies) == 0 { + return fmt.Errorf("expected set cookies to be non-empty, got none") + } else if challengeResponse.Header.Get("Location") == "" { + return fmt.Errorf("expected header 'Location' to be non-empty, got none") + } + + solveLocation := challengeResponse.Header.Get("Location") + + if !strings.HasPrefix(solveLocation, "/test") { + return fmt.Errorf("expected next location to start with '/test', got %s", solveLocation) + } + + // test pass + pass, err := http.NewRequest(http.MethodGet, solveLocation, nil) + pass.Header.Set(settings.ClientIpHeader, "127.0.0.1") + if err != nil { + return err + } + for _, c := range challengeResponse.Cookies() { + pass.AddCookie(c) + } + + response, err := do(pass) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return fmt.Errorf("expected pass status code %d, got %d", http.StatusOK, response.StatusCode) + } + + // test failure + fail, err := http.NewRequest(http.MethodGet, solveLocation, nil) + fail.Header.Set(settings.ClientIpHeader, "127.0.0.1") + if err != nil { + return err + } + + response, err = do(fail) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != http.StatusForbidden { + return fmt.Errorf("expected fail status code %d, got %d", http.StatusForbidden, response.StatusCode) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +func TestChallengeHeaderRefresh(t *testing.T) { + settings := setupDefaultSettings(t) + + pol, err := policy.NewPolicy(strings.NewReader( + ` +challenges: + "challenge-header-refresh": + runtime: "refresh" + parameters: + refresh-via: "header" + +rules: + - name: catch-all + conditions: ["true"] + action: challenge + settings: + challenges: ["challenge-header-refresh"] + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + var expectedCode = settings.ChallengeResponseCode + + err = MakeGoAwayState(*pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error { + challenge, err := http.NewRequest(http.MethodGet, "/test", nil) + challenge.Header.Set(settings.ClientIpHeader, "127.0.0.1") + challengeResponse, err := do(challenge) + if err != nil { + return err + } + defer challengeResponse.Body.Close() + if challengeResponse.StatusCode != expectedCode { + return fmt.Errorf("expected challenge status code %d, got %d", expectedCode, challengeResponse.StatusCode) + } else if challengeResponse.Header.Get("Refresh") == "" { + return fmt.Errorf("expected header 'Refresh' to be non-empty, got none") + } + + solveLocation, err := url.QueryUnescape(strings.Split(challengeResponse.Header.Get("Refresh"), "; url=")[1]) + if err != nil { + return err + } + + // test solve + solve, err := http.NewRequest(http.MethodGet, solveLocation, nil) + solve.Header.Set(settings.ClientIpHeader, "127.0.0.1") + if err != nil { + return err + } + + response, err := do(solve) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != http.StatusTemporaryRedirect { + return fmt.Errorf("expected solve status code %d, got %d", http.StatusTemporaryRedirect, response.StatusCode) + } else if cookies := response.Cookies(); len(cookies) == 0 { + return fmt.Errorf("expected set cookies to be non-empty, got none") + } else if response.Header.Get("Location") == "" { + return fmt.Errorf("expected header 'Location' to be non-empty, got none") + } else if !strings.HasPrefix(response.Header.Get("Location"), "/test") { + return fmt.Errorf("expected next location to start with '/test', got %s", response.Header.Get("Location")) + } + + // test pass + pass, err := http.NewRequest(http.MethodGet, response.Header.Get("Location"), nil) + pass.Header.Set(settings.ClientIpHeader, "127.0.0.1") + if err != nil { + return err + } + for _, c := range response.Cookies() { + pass.AddCookie(c) + } + + response, err = do(pass) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return fmt.Errorf("expected pass status code %d, got %d", http.StatusOK, response.StatusCode) + } + + // test failure + uri, err := url.Parse(solveLocation) + q := uri.Query() + q.Set(challenge2.QueryArgToken, hex.EncodeToString(make([]byte, challenge2.KeySize))) + uri.RawQuery = q.Encode() + + fail, err := http.NewRequest(http.MethodGet, uri.String(), nil) + fail.Header.Set(settings.ClientIpHeader, "127.0.0.1") + if err != nil { + return err + } + + response, err = do(fail) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected fail status code %d, got %d", http.StatusBadRequest, response.StatusCode) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +func TestChallengeMetaRefresh(t *testing.T) { + settings := setupDefaultSettings(t) + + pol, err := policy.NewPolicy(strings.NewReader( + ` +challenges: + "challenge-meta-refresh": + runtime: "refresh" + parameters: + refresh-via: "meta" + +rules: + - name: catch-all + conditions: ["true"] + action: challenge + settings: + challenges: ["challenge-meta-refresh"] + +`, + )) + if err != nil { + t.Fatal(fmt.Errorf("failed to create policy: %w", err)) + } + + var expectedCode = settings.ChallengeResponseCode + + err = MakeGoAwayState(*pol, settings, func(do func(r *http.Request, errs ...error) (*http.Response, error)) error { + challenge, err := http.NewRequest(http.MethodGet, "/test", nil) + challenge.Header.Set(settings.ClientIpHeader, "127.0.0.1") + challengeResponse, err := do(challenge) + if err != nil { + return err + } + defer challengeResponse.Body.Close() + if challengeResponse.StatusCode != expectedCode { + return fmt.Errorf("expected challenge status code %d, got %d", expectedCode, challengeResponse.StatusCode) + } else if challengeResponse.Header.Get("Refresh") != "" { + return fmt.Errorf("expected header 'Refresh' to be empty, got \"%s\"", challengeResponse.Header.Get("Refresh")) + } + + node, err := html.ParseWithOptions(challengeResponse.Body, html.ParseOptionEnableScripting(false)) + if err != nil { + return nil + } + + var refresh string + for n := range node.Descendants() { + if n.Type == html.ElementNode && n.Data == "meta" { + var is bool + var val string + for _, a := range n.Attr { + if a.Key == "http-equiv" && a.Val == "refresh" { + is = true + } + if a.Key == "content" { + val = a.Val + } + } + if is { + refresh = val + break + } + } + } + + solveLocation, err := url.QueryUnescape(strings.Split(refresh, "; url=")[1]) + if err != nil { + return err + } + + // test solve + solve, err := http.NewRequest(http.MethodGet, solveLocation, nil) + solve.Header.Set(settings.ClientIpHeader, "127.0.0.1") + if err != nil { + return err + } + + response, err := do(solve) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != http.StatusTemporaryRedirect { + return fmt.Errorf("expected solve status code %d, got %d", http.StatusTemporaryRedirect, response.StatusCode) + } else if cookies := response.Cookies(); len(cookies) == 0 { + return fmt.Errorf("expected set cookies to be non-empty, got none") + } else if response.Header.Get("Location") == "" { + return fmt.Errorf("expected header 'Location' to be non-empty, got none") + } else if !strings.HasPrefix(response.Header.Get("Location"), "/test") { + return fmt.Errorf("expected next location to start with '/test', got %s", response.Header.Get("Location")) + } + + // test pass + pass, err := http.NewRequest(http.MethodGet, response.Header.Get("Location"), nil) + pass.Header.Set(settings.ClientIpHeader, "127.0.0.1") + if err != nil { + return err + } + for _, c := range response.Cookies() { + pass.AddCookie(c) + } + + response, err = do(pass) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return fmt.Errorf("expected pass status code %d, got %d", http.StatusOK, response.StatusCode) + } + + // test failure + uri, err := url.Parse(solveLocation) + q := uri.Query() + q.Set(challenge2.QueryArgToken, hex.EncodeToString(make([]byte, challenge2.KeySize))) + uri.RawQuery = q.Encode() + + fail, err := http.NewRequest(http.MethodGet, uri.String(), nil) + fail.Header.Set(settings.ClientIpHeader, "127.0.0.1") + if err != nil { + return err + } + + response, err = do(fail) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected fail status code %d, got %d", http.StatusBadRequest, response.StatusCode) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} diff --git a/tests/logger_test.go b/tests/logger_test.go new file mode 100644 index 0000000..e608ddb --- /dev/null +++ b/tests/logger_test.go @@ -0,0 +1,57 @@ +package tests + +import ( + "context" + "fmt" + "log/slog" + "testing" +) + +type logger struct { + t *testing.T + attrs []slog.Attr +} + +func (l logger) Enabled(ctx context.Context, level slog.Level) bool { + return true +} + +func (l logger) Handle(ctx context.Context, record slog.Record) error { + str := fmt.Sprintf("[%s] %s", record.Level, record.Message) + + if record.NumAttrs() > 0 || len(l.attrs) > 0 { + str += ": " + } + for _, attr := range l.attrs { + str += fmt.Sprintf("%s=%s ", attr.Key, attr.Value.String()) + } + record.Attrs(func(attr slog.Attr) bool { + str += fmt.Sprintf("%s=%s ", attr.Key, attr.Value.String()) + return true + }) + + if record.Level == slog.LevelError { + l.t.Error(str) + } else { + l.t.Log(str) + } + return nil +} + +func (l logger) WithAttrs(attrs []slog.Attr) slog.Handler { + newAttrs := make([]slog.Attr, 0, len(attrs)+len(l.attrs)) + newAttrs = append(newAttrs, l.attrs...) + newAttrs = append(newAttrs, attrs...) + return logger{ + t: l.t, + attrs: newAttrs, + } +} + +func (l logger) WithGroup(name string) slog.Handler { + return l +} + +func initLogger(t *testing.T) slog.Handler { + return logger{t: t} +} diff --git a/utils/tagfetcher.go b/utils/tagfetcher.go index 6ee90bd..f646f73 100644 --- a/utils/tagfetcher.go +++ b/utils/tagfetcher.go @@ -2,6 +2,7 @@ package utils import ( "golang.org/x/net/html" + "io" "mime" "net/http" "net/http/httptest" @@ -32,8 +33,12 @@ func FetchTags(backend http.Handler, uri *url.URL, kind string) (result []html.N return nil } + return FetchTagsFromReader(response.Body, kind) +} + +func FetchTagsFromReader(r io.Reader, kind string) (result []html.Node) { //TODO: handle non UTF-8 documents - node, err := html.ParseWithOptions(response.Body, html.ParseOptionEnableScripting(false)) + node, err := html.ParseWithOptions(r, html.ParseOptionEnableScripting(false)) if err != nil { return nil }