Add more middleware tests

This commit is contained in:
Deluan 2023-03-26 20:22:18 -04:00
parent 59a9c056b4
commit 05d381c26f
2 changed files with 245 additions and 16 deletions

View file

@ -119,10 +119,14 @@ func compressMiddleware() func(http.Handler) http.Handler {
)
}
// clientUniqueIDMiddleware is a middleware that sets a unique client ID as a cookie if it's provided in the request header.
// If the unique client ID is not in the header but present as a cookie, it adds the ID to the request context.
func clientUniqueIDMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
clientUniqueId := r.Header.Get(consts.UIClientUniqueIDHeader)
// If clientUniqueId is found in the header, set it as a cookie
if clientUniqueId != "" {
c := &http.Cookie{
Name: consts.UIClientUniqueIDHeader,
@ -135,45 +139,69 @@ func clientUniqueIDMiddleware(next http.Handler) http.Handler {
}
http.SetCookie(w, c)
} else {
// If clientUniqueId is not found in the header, check if it's present as a cookie
c, err := r.Cookie(consts.UIClientUniqueIDHeader)
if !errors.Is(err, http.ErrNoCookie) {
clientUniqueId = c.Value
}
}
// If a valid clientUniqueId is found, add it to the request context
if clientUniqueId != "" {
ctx = request.WithClientUniqueId(ctx, clientUniqueId)
r = r.WithContext(ctx)
}
// Call the next middleware or handler in the chain
next.ServeHTTP(w, r)
})
}
// serverAddressMiddleware is a middleware function that modifies the request object
// to reflect the address of the server handling the request, as determined by the
// presence of X-Forwarded-* headers or the scheme and host of the request URL.
func serverAddressMiddleware(h http.Handler) http.Handler {
// Define a new handler function that will be returned by this middleware function.
fn := func(w http.ResponseWriter, r *http.Request) {
// Call the serverAddress function to get the scheme and host of the server
// handling the request. If a host is found, modify the request object to use
// that host and scheme instead of the original ones.
if rScheme, rHost := serverAddress(r); rHost != "" {
r.Host = rHost
r.URL.Scheme = rScheme
}
// Call the next handler in the chain with the modified request and response.
h.ServeHTTP(w, r)
}
// Return the new handler function as an http.Handler object.
return http.HandlerFunc(fn)
}
// Define constants for the X-Forwarded-* header keys.
var (
xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host")
xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto")
xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Scheme")
)
// serverAddress is a helper function that returns the scheme and host of the server
// handling the given request, as determined by the presence of X-Forwarded-* headers
// or the scheme and host of the request URL.
func serverAddress(r *http.Request) (scheme, host string) {
// Save the original request host for later comparison.
origHost := r.Host
// Determine the protocol of the request based on the presence of a TLS connection.
protocol := "http"
if r.TLS != nil {
protocol = "https"
}
// Get the X-Forwarded-Host header and extract the first host name if there are
// multiple hosts listed. If there is no X-Forwarded-Host header, use the original
// request host as the default.
xfh := r.Header.Get(xForwardedHost)
if xfh != "" {
i := strings.Index(xfh, ",")
@ -182,19 +210,29 @@ func serverAddress(r *http.Request) (scheme, host string) {
}
xfh = xfh[:i]
}
host = firstOr(r.Host, xfh)
// Determine the protocol and scheme of the request based on the presence of
// X-Forwarded-* headers or the scheme of the request URL.
scheme = firstOr(
protocol,
r.Header.Get(xForwardedProto),
r.Header.Get(xForwardedScheme),
r.URL.Scheme,
)
host = firstOr(r.Host, xfh)
// If the request host has changed due to the X-Forwarded-Host header, log a trace
// message with the original and new host values, as well as the scheme and URL.
if host != origHost {
log.Trace(r.Context(), "Request host has changed", "origHost", origHost, "host", host, "scheme", scheme, "url", r.URL)
}
// Return the scheme and host of the server handling the request.
return scheme, host
}
// firstOr is a helper function that returns the first non-empty string from a list
// of strings, or a default value if all the strings are empty.
func firstOr(or string, strings ...string) string {
for _, s := range strings {
if s != "" {
@ -204,25 +242,33 @@ func firstOr(or string, strings ...string) string {
return or
}
// URLParamsMiddleware convert Chi URL params (from Context) to query params, as expected by our REST package
// URLParamsMiddleware is a middleware function that decodes the query string of
// the incoming HTTP request, adds the URL parameters from the routing context,
// and re-encodes the modified query string.
func URLParamsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Retrieve the routing context from the request context.
ctx := chi.RouteContext(r.Context())
parts := make([]string, 0)
// Parse the existing query string into a URL values map.
params, _ := url.ParseQuery(r.URL.RawQuery)
// Loop through each URL parameter in the routing context.
for i, key := range ctx.URLParams.Keys {
value := ctx.URLParams.Values[i]
if key == "*" {
// Skip any wildcard URL parameter keys.
if strings.Contains(key, "*") {
continue
}
parts = append(parts, url.QueryEscape(":"+key)+"="+url.QueryEscape(value))
}
q := strings.Join(parts, "&")
if r.URL.RawQuery == "" {
r.URL.RawQuery = q
} else {
r.URL.RawQuery += "&" + q
// Add the URL parameter key-value pair to the URL values map.
params.Add(":"+key, ctx.URLParams.Values[i])
}
// Re-encode the URL values map as a query string and replace the
// existing query string in the request.
r.URL.RawQuery = params.Encode()
// Call the next handler in the chain with the modified request and response.
next.ServeHTTP(w, r)
})
}