diff --git a/Dockerfile b/Dockerfile index 3eb2675..e01dc39 100644 --- a/Dockerfile +++ b/Dockerfile @@ -41,6 +41,7 @@ ENV GOAWAY_CHALLENGE_TEMPLATE="anubis" ENV GOAWAY_CHALLENGE_TEMPLATE_THEME="" ENV GOAWAY_SLOG_LEVEL="WARN" ENV GOAWAY_CLIENT_IP_HEADER="" +ENV GOAWAY_BACKEND_IP_HEADER="" ENV GOAWAY_JWT_PRIVATE_KEY_SEED="" ENV GOAWAY_BACKEND="" ENV GOAWAY_DNSBL="dnsbl.dronebl.org" @@ -53,7 +54,7 @@ EXPOSE 8080/udp ENV JWT_PRIVATE_KEY_SEED="${GOAWAY_JWT_PRIVATE_KEY_SEED}" ENTRYPOINT /bin/go-away --bind ${GOAWAY_BIND} --bind-network ${GOAWAY_BIND_NETWORK} --socket-mode ${GOAWAY_SOCKET_MODE} \ - --policy ${GOAWAY_POLICY} --client-ip-header ${GOAWAY_CLIENT_IP_HEADER} \ + --policy ${GOAWAY_POLICY} --client-ip-header ${GOAWAY_CLIENT_IP_HEADER} --backend-ip-header ${GOAWAY_BACKEND_IP_HEADER} \ --cache ${GOAWAY_CACHE} \ --dnsbl ${GOAWAY_DNSBL} \ --challenge-template ${GOAWAY_CHALLENGE_TEMPLATE} --challenge-template-theme ${GOAWAY_CHALLENGE_TEMPLATE_THEME} \ diff --git a/cmd/go-away/main.go b/cmd/go-away/main.go index dac6566..f8f0fad 100644 --- a/cmd/go-away/main.go +++ b/cmd/go-away/main.go @@ -153,6 +153,7 @@ func main() { acmeAutocert := flag.String("acme-autocert", "", "enables HTTP(s) mode and uses the provided ACME server URL or available service (available: letsencrypt)") clientIpHeader := flag.String("client-ip-header", "", "Client HTTP header to fetch their IP address from (X-Real-Ip, X-Client-Ip, X-Forwarded-For, Cf-Connecting-Ip, etc.)") + backendIpHeader := flag.String("backend-ip-header", "", "Backend HTTP header to set the client IP address from, if empty defaults to leaving Client header alone (X-Real-Ip, X-Client-Ip, X-Forwarded-For, Cf-Connecting-Ip, etc.)") dnsbl := flag.String("dnsbl", "dnsbl.dronebl.org", "blocklist for DNSBL (default DroneBL)") @@ -345,6 +346,7 @@ func main() { ChallengeTemplateTheme: *challengeTemplateTheme, PrivateKeySeed: seed, ClientIpHeader: *clientIpHeader, + BackendIpHeader: *backendIpHeader, } if *dnsbl != "" { diff --git a/lib/challenge.go b/lib/challenge.go index 4d79293..4dbad49 100644 --- a/lib/challenge.go +++ b/lib/challenge.go @@ -86,7 +86,8 @@ func ChallengeKeyFromString(s string) (ChallengeKey, error) { } func (state *State) GetChallengeKeyForRequest(challengeName string, until time.Time, r *http.Request) ChallengeKey { - address := getRequestAddress(r, state.Settings.ClientIpHeader) + data := RequestDataFromContext(r.Context()) + address := data.RemoteAddress hasher := sha256.New() hasher.Write([]byte("challenge\x00")) hasher.Write([]byte(challengeName)) diff --git a/lib/http.go b/lib/http.go index c8051b6..229a6c3 100644 --- a/lib/http.go +++ b/lib/http.go @@ -18,6 +18,7 @@ import ( "io" "log/slog" "maps" + "net" "net/http" "net/http/pprof" "path" @@ -130,10 +131,11 @@ func (state *State) addTiming(w http.ResponseWriter, name, desc string, duration } } -func GetLoggerForRequest(r *http.Request, clientHeader string) *slog.Logger { +func GetLoggerForRequest(r *http.Request) *slog.Logger { + data := RequestDataFromContext(r.Context()) return slog.With( - "request_id", r.Header.Get("X-Away-Id"), - "remote_address", getRequestAddress(r, clientHeader), + "request_id", hex.EncodeToString(data.Id[:]), + "remote_address", data.RemoteAddress.String(), "user_agent", r.UserAgent(), "host", r.Host, "path", r.URL.Path, @@ -142,7 +144,7 @@ func GetLoggerForRequest(r *http.Request, clientHeader string) *slog.Logger { } func (state *State) logger(r *http.Request) *slog.Logger { - return GetLoggerForRequest(r, state.Settings.ClientIpHeader) + return GetLoggerForRequest(r) } func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { @@ -412,15 +414,17 @@ func (state *State) setupRoutes() error { } func (state *State) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var data RequestData // generate random id, todo: is this fast? _, _ = rand.Read(data.Id[:]) + data.RemoteAddress = getRequestAddress(r, state.Settings.ClientIpHeader) data.Challenges = make(map[challenge.Id]challenge.VerifyResult, len(state.Challenges)) data.Expires = time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity) data.ProgramEnv = map[string]any{ "host": r.Host, "method": r.Method, - "remoteAddress": getRequestAddress(r, state.Settings.ClientIpHeader), + "remoteAddress": data.RemoteAddress, "userAgent": r.UserAgent(), "path": r.URL.Path, "query": func() map[string]string { @@ -465,7 +469,11 @@ func (state *State) ServeHTTP(w http.ResponseWriter, r *http.Request) { } r.Header.Set("X-Away-Id", hex.EncodeToString(data.Id[:])) - w.Header().Set("X-Away-Id", hex.EncodeToString(data.Id[:])) + if state.Settings.BackendIpHeader != "" { + r.Header.Del(state.Settings.ClientIpHeader) + r.Header.Set(state.Settings.BackendIpHeader, data.RemoteAddress.String()) + } + w.Header().Add("Via", fmt.Sprintf("%s %s", r.Proto, "go-away")) // send these to client so we consistently get the headers //w.Header().Set("Accept-CH", "Sec-CH-UA, Sec-CH-UA-Platform") @@ -481,10 +489,11 @@ func RequestDataFromContext(ctx context.Context) *RequestData { } type RequestData struct { - Id [16]byte - ProgramEnv map[string]any - Expires time.Time - Challenges map[challenge.Id]challenge.VerifyResult + Id [16]byte + ProgramEnv map[string]any + Expires time.Time + Challenges map[challenge.Id]challenge.VerifyResult + RemoteAddress net.IP } func (d *RequestData) HasValidChallenge(id challenge.Id) bool { diff --git a/lib/state.go b/lib/state.go index 57a42cd..d8476f0 100644 --- a/lib/state.go +++ b/lib/state.go @@ -112,6 +112,7 @@ type StateSettings struct { ChallengeTemplate string ChallengeTemplateTheme string ClientIpHeader string + BackendIpHeader string DNSBL *utils.DNSBL }