diff --git a/lib/challenge.go b/lib/challenge.go index 9eac21e..4d79293 100644 --- a/lib/challenge.go +++ b/lib/challenge.go @@ -3,6 +3,8 @@ package lib import ( "crypto/sha256" "encoding/binary" + "encoding/hex" + "errors" "github.com/go-jose/go-jose/v4/jwt" "net" "net/http" @@ -52,12 +54,44 @@ func getRequestAddress(r *http.Request, clientHeader string) net.IP { return net.ParseIP(ipStr) } -func (state *State) GetChallengeKeyForRequest(challengeName string, until time.Time, r *http.Request) []byte { +type ChallengeKey []byte + +const ChallengeKeySize = sha256.Size + +func (k *ChallengeKey) Set(flags ChallengeKeyFlags) { + (*k)[0] |= uint8(flags) +} +func (k *ChallengeKey) Get(flags ChallengeKeyFlags) ChallengeKeyFlags { + return ChallengeKeyFlags((*k)[0] & uint8(flags)) +} +func (k *ChallengeKey) Unset(flags ChallengeKeyFlags) { + (*k)[0] = (*k)[0] & ^(uint8(flags)) +} + +type ChallengeKeyFlags uint8 + +const ( + ChallengeKeyFlagIsIPv4 = ChallengeKeyFlags(1 << iota) +) + +func ChallengeKeyFromString(s string) (ChallengeKey, error) { + b, err := hex.DecodeString(s) + if err != nil { + return nil, err + } + if len(b) != ChallengeKeySize { + return nil, errors.New("invalid challenge key") + } + return ChallengeKey(b), nil +} + +func (state *State) GetChallengeKeyForRequest(challengeName string, until time.Time, r *http.Request) ChallengeKey { + address := getRequestAddress(r, state.Settings.ClientIpHeader) hasher := sha256.New() hasher.Write([]byte("challenge\x00")) hasher.Write([]byte(challengeName)) hasher.Write([]byte{0}) - hasher.Write(getRequestAddress(r, state.Settings.ClientIpHeader).To16()) + hasher.Write(address.To16()) hasher.Write([]byte{0}) // specific headers @@ -78,5 +112,13 @@ func (state *State) GetChallengeKeyForRequest(challengeName string, until time.T hasher.Write(state.publicKey) hasher.Write([]byte{0}) - return hasher.Sum(nil) + sum := ChallengeKey(hasher.Sum(nil)) + + sum[0] = 0 + + if address.To4() != nil { + // Is IPv4, mark + sum.Set(ChallengeKeyFlagIsIPv4) + } + return ChallengeKey(sum) } diff --git a/lib/state.go b/lib/state.go index 68fbb35..f9b001c 100644 --- a/lib/state.go +++ b/lib/state.go @@ -566,7 +566,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) redirect, err := utils.EnsureNoOpenRedirect(r.FormValue("redirect")) if err != nil { - _ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusInternalServerError, err, "") + _ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadRequest, err, "") return } @@ -585,29 +585,37 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) if ok, err := c.Verify(key, result, r); err != nil { return err } else if !ok { - state.logger(r).Warn("challenge failed", "challenge", challengeName, "redirect", redirect) utils.ClearCookie(utils.CookiePrefix+challengeName, w) - + data.Challenges[c.Id] = challenge.VerifyResultFAIL state.SolveChallenge(key, challenge.VerifyResultFAIL) + state.logger(r).Warn("challenge failed", "challenge", challengeName, "redirect", redirect) - _ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName), redirect) - return nil - } + // catch happy eyeballs IPv4 -> IPv6 migration, re-direct to try again + if resultKey, err := ChallengeKeyFromString(result); err == nil && resultKey.Get(ChallengeKeyFlagIsIPv4) > 0 && key.Get(ChallengeKeyFlagIsIPv4) == 0 { - state.logger(r).Warn("challenge passed", "challenge", challengeName, "redirect", redirect) - - token, err := c.IssueChallengeToken(state.privateKey, key, []byte(result), data.Expires) - if err != nil { - utils.ClearCookie(utils.CookiePrefix+challengeName, w) + } else { + _ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName), redirect) + return nil + } } else { - utils.SetCookie(utils.CookiePrefix+challengeName, token, data.Expires, w) - } - data.Challenges[c.Id] = challenge.VerifyResultPASS + state.logger(r).Warn("challenge passed", "challenge", challengeName, "redirect", redirect) - state.SolveChallenge(key, challenge.VerifyResultPASS) + token, err := c.IssueChallengeToken(state.privateKey, key, []byte(result), data.Expires) + if err != nil { + utils.ClearCookie(utils.CookiePrefix+challengeName, w) + } else { + utils.SetCookie(utils.CookiePrefix+challengeName, token, data.Expires, w) + } + data.Challenges[c.Id] = challenge.VerifyResultPASS + state.SolveChallenge(key, challenge.VerifyResultPASS) + } switch httpCode { case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect: + if redirect == "" { + _ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadRequest, errors.New("no redirect found"), "") + return nil + } http.Redirect(w, r, redirect, httpCode) default: w.Header().Set("Content-Type", mimeType)