http/query: preserve raw query state when modifying url query

This commit is contained in:
WeebDataHoarder
2025-06-09 13:49:31 +02:00
parent c16f0863ae
commit 9a6f25df59
4 changed files with 58 additions and 18 deletions

View File

@@ -94,13 +94,13 @@ func VerifyUrl(r *http.Request, reg *Registration, token string) (*url.URL, erro
uri.Path = reg.Path + VerifyChallengeUrlSuffix uri.Path = reg.Path + VerifyChallengeUrlSuffix
data := RequestDataFromContext(r.Context()) data := RequestDataFromContext(r.Context())
values := uri.Query() values, _ := utils.ParseRawQuery(r.URL.RawQuery)
values.Set(QueryArgRequestId, data.Id.String()) values.Set(QueryArgRequestId, url.QueryEscape(data.Id.String()))
values.Set(QueryArgRedirect, redirectUrl.String()) values.Set(QueryArgRedirect, url.QueryEscape(redirectUrl.String()))
values.Set(QueryArgToken, token) values.Set(QueryArgToken, url.QueryEscape(token))
values.Set(QueryArgChallenge, reg.Name) values.Set(QueryArgChallenge, url.QueryEscape(reg.Name))
values.Set(QueryArgBust, strconv.FormatInt(time.Now().UTC().UnixMilli(), 10)) values.Set(QueryArgBust, url.QueryEscape(strconv.FormatInt(time.Now().UTC().UnixMilli(), 10)))
uri.RawQuery = values.Encode() uri.RawQuery = utils.EncodeRawQuery(values)
return uri, nil return uri, nil
} }
@@ -112,13 +112,13 @@ func RedirectUrl(r *http.Request, reg *Registration) (*url.URL, error) {
} }
data := RequestDataFromContext(r.Context()) data := RequestDataFromContext(r.Context())
values := uri.Query() values, _ := utils.ParseRawQuery(r.URL.RawQuery)
values.Set(QueryArgRequestId, data.Id.String()) values.Set(QueryArgRequestId, url.QueryEscape(data.Id.String()))
if ref := r.Referer(); ref != "" { if ref := r.Referer(); ref != "" {
values.Set(QueryArgReferer, r.Referer()) values.Set(QueryArgReferer, url.QueryEscape(r.Referer()))
} }
values.Set(QueryArgChallenge, reg.Name) values.Set(QueryArgChallenge, url.QueryEscape(reg.Name))
uri.RawQuery = values.Encode() uri.RawQuery = utils.EncodeRawQuery(values)
return uri, nil return uri, nil
} }

View File

@@ -69,9 +69,9 @@ func FillRegistration(state challenge.StateInterface, reg *challenge.Registratio
} }
// remove redirect args // remove redirect args
values := uri.Query() values, _ := utils.ParseRawQuery(uri.RawQuery)
values.Del(challenge.QueryArgRedirect) values.Del(challenge.QueryArgRedirect)
uri.RawQuery = values.Encode() uri.RawQuery = utils.EncodeRawQuery(values)
// Redirect URI must be absolute to work // Redirect URI must be absolute to work
uri.Scheme = utils.GetRequestScheme(r) uri.Scheme = utils.GetRequestScheme(r)

View File

@@ -246,19 +246,20 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
if fromChallenge { if fromChallenge {
r.Header.Del("Referer") r.Header.Del("Referer")
} }
q := r.URL.Query()
q := r.URL.Query()
if ref := q.Get(challenge.QueryArgReferer); ref != "" { if ref := q.Get(challenge.QueryArgReferer); ref != "" {
r.Header.Set("Referer", ref) r.Header.Set("Referer", ref)
} }
rawQ, _ := utils.ParseRawQuery(r.URL.RawQuery)
// delete query parameters that were set by go-away // delete query parameters that were set by go-away
for k := range q { for k := range rawQ {
if strings.HasPrefix(k, challenge.QueryArgPrefix) { if strings.HasPrefix(k, challenge.QueryArgPrefix) {
q.Del(k) rawQ.Del(k)
} }
} }
r.URL.RawQuery = q.Encode() r.URL.RawQuery = utils.EncodeRawQuery(rawQ)
data.ExtraHeaders.Set("X-Away-Rule", ruleName) data.ExtraHeaders.Set("X-Away-Rule", ruleName)
data.ExtraHeaders.Set("X-Away-Action", string(ruleAction)) data.ExtraHeaders.Set("X-Away-Action", string(ruleAction))

View File

@@ -7,11 +7,13 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"maps"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/netip" "net/netip"
"net/url" "net/url"
"slices"
"strings" "strings"
"time" "time"
) )
@@ -178,3 +180,40 @@ var staticCacheBust = RandomCacheBust(16)
func StaticCacheBust() string { func StaticCacheBust() string {
return staticCacheBust return staticCacheBust
} }
func ParseRawQuery(rawQuery string) (m url.Values, err error) {
m = make(url.Values)
for rawQuery != "" {
var key string
key, rawQuery, _ = strings.Cut(rawQuery, "&")
if strings.Contains(key, ";") {
err = fmt.Errorf("invalid semicolon separator in query")
continue
}
if key == "" {
continue
}
key, value, _ := strings.Cut(key, "=")
m[key] = append(m[key], value)
}
return m, err
}
func EncodeRawQuery(v url.Values) string {
if len(v) == 0 {
return ""
}
var buf strings.Builder
for _, k := range slices.Sorted(maps.Keys(v)) {
vs := v[k]
for _, v := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(k)
buf.WriteByte('=')
buf.WriteString(v)
}
}
return buf.String()
}