http/query: preserve raw query state when modifying url query
This commit is contained in:
@@ -94,13 +94,13 @@ func VerifyUrl(r *http.Request, reg *Registration, token string) (*url.URL, erro
|
|||||||
uri.Path = reg.Path + VerifyChallengeUrlSuffix
|
uri.Path = reg.Path + VerifyChallengeUrlSuffix
|
||||||
|
|
||||||
data := RequestDataFromContext(r.Context())
|
data := RequestDataFromContext(r.Context())
|
||||||
values := uri.Query()
|
values, _ := utils.ParseRawQuery(r.URL.RawQuery)
|
||||||
values.Set(QueryArgRequestId, data.Id.String())
|
values.Set(QueryArgRequestId, url.QueryEscape(data.Id.String()))
|
||||||
values.Set(QueryArgRedirect, redirectUrl.String())
|
values.Set(QueryArgRedirect, url.QueryEscape(redirectUrl.String()))
|
||||||
values.Set(QueryArgToken, token)
|
values.Set(QueryArgToken, url.QueryEscape(token))
|
||||||
values.Set(QueryArgChallenge, reg.Name)
|
values.Set(QueryArgChallenge, url.QueryEscape(reg.Name))
|
||||||
values.Set(QueryArgBust, strconv.FormatInt(time.Now().UTC().UnixMilli(), 10))
|
values.Set(QueryArgBust, url.QueryEscape(strconv.FormatInt(time.Now().UTC().UnixMilli(), 10)))
|
||||||
uri.RawQuery = values.Encode()
|
uri.RawQuery = utils.EncodeRawQuery(values)
|
||||||
|
|
||||||
return uri, nil
|
return uri, nil
|
||||||
}
|
}
|
||||||
@@ -112,13 +112,13 @@ func RedirectUrl(r *http.Request, reg *Registration) (*url.URL, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
data := RequestDataFromContext(r.Context())
|
data := RequestDataFromContext(r.Context())
|
||||||
values := uri.Query()
|
values, _ := utils.ParseRawQuery(r.URL.RawQuery)
|
||||||
values.Set(QueryArgRequestId, data.Id.String())
|
values.Set(QueryArgRequestId, url.QueryEscape(data.Id.String()))
|
||||||
if ref := r.Referer(); ref != "" {
|
if ref := r.Referer(); ref != "" {
|
||||||
values.Set(QueryArgReferer, r.Referer())
|
values.Set(QueryArgReferer, url.QueryEscape(r.Referer()))
|
||||||
}
|
}
|
||||||
values.Set(QueryArgChallenge, reg.Name)
|
values.Set(QueryArgChallenge, url.QueryEscape(reg.Name))
|
||||||
uri.RawQuery = values.Encode()
|
uri.RawQuery = utils.EncodeRawQuery(values)
|
||||||
|
|
||||||
return uri, nil
|
return uri, nil
|
||||||
}
|
}
|
||||||
|
@@ -69,9 +69,9 @@ func FillRegistration(state challenge.StateInterface, reg *challenge.Registratio
|
|||||||
}
|
}
|
||||||
|
|
||||||
// remove redirect args
|
// remove redirect args
|
||||||
values := uri.Query()
|
values, _ := utils.ParseRawQuery(uri.RawQuery)
|
||||||
values.Del(challenge.QueryArgRedirect)
|
values.Del(challenge.QueryArgRedirect)
|
||||||
uri.RawQuery = values.Encode()
|
uri.RawQuery = utils.EncodeRawQuery(values)
|
||||||
|
|
||||||
// Redirect URI must be absolute to work
|
// Redirect URI must be absolute to work
|
||||||
uri.Scheme = utils.GetRequestScheme(r)
|
uri.Scheme = utils.GetRequestScheme(r)
|
||||||
|
@@ -246,19 +246,20 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
if fromChallenge {
|
if fromChallenge {
|
||||||
r.Header.Del("Referer")
|
r.Header.Del("Referer")
|
||||||
}
|
}
|
||||||
q := r.URL.Query()
|
|
||||||
|
|
||||||
|
q := r.URL.Query()
|
||||||
if ref := q.Get(challenge.QueryArgReferer); ref != "" {
|
if ref := q.Get(challenge.QueryArgReferer); ref != "" {
|
||||||
r.Header.Set("Referer", ref)
|
r.Header.Set("Referer", ref)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rawQ, _ := utils.ParseRawQuery(r.URL.RawQuery)
|
||||||
// delete query parameters that were set by go-away
|
// delete query parameters that were set by go-away
|
||||||
for k := range q {
|
for k := range rawQ {
|
||||||
if strings.HasPrefix(k, challenge.QueryArgPrefix) {
|
if strings.HasPrefix(k, challenge.QueryArgPrefix) {
|
||||||
q.Del(k)
|
rawQ.Del(k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.URL.RawQuery = q.Encode()
|
r.URL.RawQuery = utils.EncodeRawQuery(rawQ)
|
||||||
|
|
||||||
data.ExtraHeaders.Set("X-Away-Rule", ruleName)
|
data.ExtraHeaders.Set("X-Away-Rule", ruleName)
|
||||||
data.ExtraHeaders.Set("X-Away-Action", string(ruleAction))
|
data.ExtraHeaders.Set("X-Away-Action", string(ruleAction))
|
||||||
|
@@ -7,11 +7,13 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"maps"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -178,3 +180,40 @@ var staticCacheBust = RandomCacheBust(16)
|
|||||||
func StaticCacheBust() string {
|
func StaticCacheBust() string {
|
||||||
return staticCacheBust
|
return staticCacheBust
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ParseRawQuery(rawQuery string) (m url.Values, err error) {
|
||||||
|
m = make(url.Values)
|
||||||
|
for rawQuery != "" {
|
||||||
|
var key string
|
||||||
|
key, rawQuery, _ = strings.Cut(rawQuery, "&")
|
||||||
|
if strings.Contains(key, ";") {
|
||||||
|
err = fmt.Errorf("invalid semicolon separator in query")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key, value, _ := strings.Cut(key, "=")
|
||||||
|
m[key] = append(m[key], value)
|
||||||
|
}
|
||||||
|
return m, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeRawQuery(v url.Values) string {
|
||||||
|
if len(v) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var buf strings.Builder
|
||||||
|
for _, k := range slices.Sorted(maps.Keys(v)) {
|
||||||
|
vs := v[k]
|
||||||
|
for _, v := range vs {
|
||||||
|
if buf.Len() > 0 {
|
||||||
|
buf.WriteByte('&')
|
||||||
|
}
|
||||||
|
buf.WriteString(k)
|
||||||
|
buf.WriteByte('=')
|
||||||
|
buf.WriteString(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user