http/query: preserve raw query state when modifying url query
This commit is contained in:
@@ -94,13 +94,13 @@ func VerifyUrl(r *http.Request, reg *Registration, token string) (*url.URL, erro
|
||||
uri.Path = reg.Path + VerifyChallengeUrlSuffix
|
||||
|
||||
data := RequestDataFromContext(r.Context())
|
||||
values := uri.Query()
|
||||
values.Set(QueryArgRequestId, data.Id.String())
|
||||
values.Set(QueryArgRedirect, redirectUrl.String())
|
||||
values.Set(QueryArgToken, token)
|
||||
values.Set(QueryArgChallenge, reg.Name)
|
||||
values.Set(QueryArgBust, strconv.FormatInt(time.Now().UTC().UnixMilli(), 10))
|
||||
uri.RawQuery = values.Encode()
|
||||
values, _ := utils.ParseRawQuery(r.URL.RawQuery)
|
||||
values.Set(QueryArgRequestId, url.QueryEscape(data.Id.String()))
|
||||
values.Set(QueryArgRedirect, url.QueryEscape(redirectUrl.String()))
|
||||
values.Set(QueryArgToken, url.QueryEscape(token))
|
||||
values.Set(QueryArgChallenge, url.QueryEscape(reg.Name))
|
||||
values.Set(QueryArgBust, url.QueryEscape(strconv.FormatInt(time.Now().UTC().UnixMilli(), 10)))
|
||||
uri.RawQuery = utils.EncodeRawQuery(values)
|
||||
|
||||
return uri, nil
|
||||
}
|
||||
@@ -112,13 +112,13 @@ func RedirectUrl(r *http.Request, reg *Registration) (*url.URL, error) {
|
||||
}
|
||||
|
||||
data := RequestDataFromContext(r.Context())
|
||||
values := uri.Query()
|
||||
values.Set(QueryArgRequestId, data.Id.String())
|
||||
values, _ := utils.ParseRawQuery(r.URL.RawQuery)
|
||||
values.Set(QueryArgRequestId, url.QueryEscape(data.Id.String()))
|
||||
if ref := r.Referer(); ref != "" {
|
||||
values.Set(QueryArgReferer, r.Referer())
|
||||
values.Set(QueryArgReferer, url.QueryEscape(r.Referer()))
|
||||
}
|
||||
values.Set(QueryArgChallenge, reg.Name)
|
||||
uri.RawQuery = values.Encode()
|
||||
values.Set(QueryArgChallenge, url.QueryEscape(reg.Name))
|
||||
uri.RawQuery = utils.EncodeRawQuery(values)
|
||||
|
||||
return uri, nil
|
||||
}
|
||||
|
@@ -69,9 +69,9 @@ func FillRegistration(state challenge.StateInterface, reg *challenge.Registratio
|
||||
}
|
||||
|
||||
// remove redirect args
|
||||
values := uri.Query()
|
||||
values, _ := utils.ParseRawQuery(uri.RawQuery)
|
||||
values.Del(challenge.QueryArgRedirect)
|
||||
uri.RawQuery = values.Encode()
|
||||
uri.RawQuery = utils.EncodeRawQuery(values)
|
||||
|
||||
// Redirect URI must be absolute to work
|
||||
uri.Scheme = utils.GetRequestScheme(r)
|
||||
|
@@ -246,19 +246,20 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if fromChallenge {
|
||||
r.Header.Del("Referer")
|
||||
}
|
||||
q := r.URL.Query()
|
||||
|
||||
q := r.URL.Query()
|
||||
if ref := q.Get(challenge.QueryArgReferer); ref != "" {
|
||||
r.Header.Set("Referer", ref)
|
||||
}
|
||||
|
||||
rawQ, _ := utils.ParseRawQuery(r.URL.RawQuery)
|
||||
// delete query parameters that were set by go-away
|
||||
for k := range q {
|
||||
for k := range rawQ {
|
||||
if strings.HasPrefix(k, challenge.QueryArgPrefix) {
|
||||
q.Del(k)
|
||||
rawQ.Del(k)
|
||||
}
|
||||
}
|
||||
r.URL.RawQuery = q.Encode()
|
||||
r.URL.RawQuery = utils.EncodeRawQuery(rawQ)
|
||||
|
||||
data.ExtraHeaders.Set("X-Away-Rule", ruleName)
|
||||
data.ExtraHeaders.Set("X-Away-Action", string(ruleAction))
|
||||
|
@@ -7,11 +7,13 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -178,3 +180,40 @@ var staticCacheBust = RandomCacheBust(16)
|
||||
func StaticCacheBust() string {
|
||||
return staticCacheBust
|
||||
}
|
||||
|
||||
func ParseRawQuery(rawQuery string) (m url.Values, err error) {
|
||||
m = make(url.Values)
|
||||
for rawQuery != "" {
|
||||
var key string
|
||||
key, rawQuery, _ = strings.Cut(rawQuery, "&")
|
||||
if strings.Contains(key, ";") {
|
||||
err = fmt.Errorf("invalid semicolon separator in query")
|
||||
continue
|
||||
}
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
key, value, _ := strings.Cut(key, "=")
|
||||
m[key] = append(m[key], value)
|
||||
}
|
||||
return m, err
|
||||
}
|
||||
|
||||
func EncodeRawQuery(v url.Values) string {
|
||||
if len(v) == 0 {
|
||||
return ""
|
||||
}
|
||||
var buf strings.Builder
|
||||
for _, k := range slices.Sorted(maps.Keys(v)) {
|
||||
vs := v[k]
|
||||
for _, v := range vs {
|
||||
if buf.Len() > 0 {
|
||||
buf.WriteByte('&')
|
||||
}
|
||||
buf.WriteString(k)
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(v)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
Reference in New Issue
Block a user