mirror of
https://github.com/navidrome/navidrome.git
synced 2025-04-04 13:07:36 +03:00
Add more middleware tests
This commit is contained in:
parent
59a9c056b4
commit
05d381c26f
2 changed files with 245 additions and 16 deletions
|
@ -119,10 +119,14 @@ func compressMiddleware() func(http.Handler) http.Handler {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clientUniqueIDMiddleware is a middleware that sets a unique client ID as a cookie if it's provided in the request header.
|
||||||
|
// If the unique client ID is not in the header but present as a cookie, it adds the ID to the request context.
|
||||||
func clientUniqueIDMiddleware(next http.Handler) http.Handler {
|
func clientUniqueIDMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
clientUniqueId := r.Header.Get(consts.UIClientUniqueIDHeader)
|
clientUniqueId := r.Header.Get(consts.UIClientUniqueIDHeader)
|
||||||
|
|
||||||
|
// If clientUniqueId is found in the header, set it as a cookie
|
||||||
if clientUniqueId != "" {
|
if clientUniqueId != "" {
|
||||||
c := &http.Cookie{
|
c := &http.Cookie{
|
||||||
Name: consts.UIClientUniqueIDHeader,
|
Name: consts.UIClientUniqueIDHeader,
|
||||||
|
@ -135,45 +139,69 @@ func clientUniqueIDMiddleware(next http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
http.SetCookie(w, c)
|
http.SetCookie(w, c)
|
||||||
} else {
|
} else {
|
||||||
|
// If clientUniqueId is not found in the header, check if it's present as a cookie
|
||||||
c, err := r.Cookie(consts.UIClientUniqueIDHeader)
|
c, err := r.Cookie(consts.UIClientUniqueIDHeader)
|
||||||
if !errors.Is(err, http.ErrNoCookie) {
|
if !errors.Is(err, http.ErrNoCookie) {
|
||||||
clientUniqueId = c.Value
|
clientUniqueId = c.Value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If a valid clientUniqueId is found, add it to the request context
|
||||||
if clientUniqueId != "" {
|
if clientUniqueId != "" {
|
||||||
ctx = request.WithClientUniqueId(ctx, clientUniqueId)
|
ctx = request.WithClientUniqueId(ctx, clientUniqueId)
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Call the next middleware or handler in the chain
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// serverAddressMiddleware is a middleware function that modifies the request object
|
||||||
|
// to reflect the address of the server handling the request, as determined by the
|
||||||
|
// presence of X-Forwarded-* headers or the scheme and host of the request URL.
|
||||||
func serverAddressMiddleware(h http.Handler) http.Handler {
|
func serverAddressMiddleware(h http.Handler) http.Handler {
|
||||||
|
// Define a new handler function that will be returned by this middleware function.
|
||||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Call the serverAddress function to get the scheme and host of the server
|
||||||
|
// handling the request. If a host is found, modify the request object to use
|
||||||
|
// that host and scheme instead of the original ones.
|
||||||
if rScheme, rHost := serverAddress(r); rHost != "" {
|
if rScheme, rHost := serverAddress(r); rHost != "" {
|
||||||
r.Host = rHost
|
r.Host = rHost
|
||||||
r.URL.Scheme = rScheme
|
r.URL.Scheme = rScheme
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Call the next handler in the chain with the modified request and response.
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the new handler function as an http.Handler object.
|
||||||
return http.HandlerFunc(fn)
|
return http.HandlerFunc(fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Define constants for the X-Forwarded-* header keys.
|
||||||
var (
|
var (
|
||||||
xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host")
|
xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host")
|
||||||
xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto")
|
xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto")
|
||||||
xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Scheme")
|
xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Scheme")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// serverAddress is a helper function that returns the scheme and host of the server
|
||||||
|
// handling the given request, as determined by the presence of X-Forwarded-* headers
|
||||||
|
// or the scheme and host of the request URL.
|
||||||
func serverAddress(r *http.Request) (scheme, host string) {
|
func serverAddress(r *http.Request) (scheme, host string) {
|
||||||
|
// Save the original request host for later comparison.
|
||||||
origHost := r.Host
|
origHost := r.Host
|
||||||
|
|
||||||
|
// Determine the protocol of the request based on the presence of a TLS connection.
|
||||||
protocol := "http"
|
protocol := "http"
|
||||||
if r.TLS != nil {
|
if r.TLS != nil {
|
||||||
protocol = "https"
|
protocol = "https"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the X-Forwarded-Host header and extract the first host name if there are
|
||||||
|
// multiple hosts listed. If there is no X-Forwarded-Host header, use the original
|
||||||
|
// request host as the default.
|
||||||
xfh := r.Header.Get(xForwardedHost)
|
xfh := r.Header.Get(xForwardedHost)
|
||||||
if xfh != "" {
|
if xfh != "" {
|
||||||
i := strings.Index(xfh, ",")
|
i := strings.Index(xfh, ",")
|
||||||
|
@ -182,19 +210,29 @@ func serverAddress(r *http.Request) (scheme, host string) {
|
||||||
}
|
}
|
||||||
xfh = xfh[:i]
|
xfh = xfh[:i]
|
||||||
}
|
}
|
||||||
|
host = firstOr(r.Host, xfh)
|
||||||
|
|
||||||
|
// Determine the protocol and scheme of the request based on the presence of
|
||||||
|
// X-Forwarded-* headers or the scheme of the request URL.
|
||||||
scheme = firstOr(
|
scheme = firstOr(
|
||||||
protocol,
|
protocol,
|
||||||
r.Header.Get(xForwardedProto),
|
r.Header.Get(xForwardedProto),
|
||||||
r.Header.Get(xForwardedScheme),
|
r.Header.Get(xForwardedScheme),
|
||||||
r.URL.Scheme,
|
r.URL.Scheme,
|
||||||
)
|
)
|
||||||
host = firstOr(r.Host, xfh)
|
|
||||||
|
// If the request host has changed due to the X-Forwarded-Host header, log a trace
|
||||||
|
// message with the original and new host values, as well as the scheme and URL.
|
||||||
if host != origHost {
|
if host != origHost {
|
||||||
log.Trace(r.Context(), "Request host has changed", "origHost", origHost, "host", host, "scheme", scheme, "url", r.URL)
|
log.Trace(r.Context(), "Request host has changed", "origHost", origHost, "host", host, "scheme", scheme, "url", r.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the scheme and host of the server handling the request.
|
||||||
return scheme, host
|
return scheme, host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// firstOr is a helper function that returns the first non-empty string from a list
|
||||||
|
// of strings, or a default value if all the strings are empty.
|
||||||
func firstOr(or string, strings ...string) string {
|
func firstOr(or string, strings ...string) string {
|
||||||
for _, s := range strings {
|
for _, s := range strings {
|
||||||
if s != "" {
|
if s != "" {
|
||||||
|
@ -204,25 +242,33 @@ func firstOr(or string, strings ...string) string {
|
||||||
return or
|
return or
|
||||||
}
|
}
|
||||||
|
|
||||||
// URLParamsMiddleware convert Chi URL params (from Context) to query params, as expected by our REST package
|
// URLParamsMiddleware is a middleware function that decodes the query string of
|
||||||
|
// the incoming HTTP request, adds the URL parameters from the routing context,
|
||||||
|
// and re-encodes the modified query string.
|
||||||
func URLParamsMiddleware(next http.Handler) http.Handler {
|
func URLParamsMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Retrieve the routing context from the request context.
|
||||||
ctx := chi.RouteContext(r.Context())
|
ctx := chi.RouteContext(r.Context())
|
||||||
parts := make([]string, 0)
|
|
||||||
|
// Parse the existing query string into a URL values map.
|
||||||
|
params, _ := url.ParseQuery(r.URL.RawQuery)
|
||||||
|
|
||||||
|
// Loop through each URL parameter in the routing context.
|
||||||
for i, key := range ctx.URLParams.Keys {
|
for i, key := range ctx.URLParams.Keys {
|
||||||
value := ctx.URLParams.Values[i]
|
// Skip any wildcard URL parameter keys.
|
||||||
if key == "*" {
|
if strings.Contains(key, "*") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
parts = append(parts, url.QueryEscape(":"+key)+"="+url.QueryEscape(value))
|
|
||||||
}
|
// Add the URL parameter key-value pair to the URL values map.
|
||||||
q := strings.Join(parts, "&")
|
params.Add(":"+key, ctx.URLParams.Values[i])
|
||||||
if r.URL.RawQuery == "" {
|
|
||||||
r.URL.RawQuery = q
|
|
||||||
} else {
|
|
||||||
r.URL.RawQuery += "&" + q
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Re-encode the URL values map as a query string and replace the
|
||||||
|
// existing query string in the request.
|
||||||
|
r.URL.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
// Call the next handler in the chain with the modified request and response.
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,18 +3,27 @@ package server
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/navidrome/navidrome/conf"
|
||||||
|
"github.com/navidrome/navidrome/conf/configtest"
|
||||||
|
"github.com/navidrome/navidrome/consts"
|
||||||
|
"github.com/navidrome/navidrome/model/request"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("middlewares", func() {
|
var _ = Describe("middlewares", func() {
|
||||||
var nextCalled bool
|
BeforeEach(func() {
|
||||||
next := func(w http.ResponseWriter, r *http.Request) {
|
DeferCleanup(configtest.SetupConfig())
|
||||||
nextCalled = true
|
})
|
||||||
}
|
|
||||||
Describe("robotsTXT", func() {
|
Describe("robotsTXT", func() {
|
||||||
|
var nextCalled bool
|
||||||
|
next := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
nextCalled = true
|
||||||
|
}
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
nextCalled = false
|
nextCalled = false
|
||||||
})
|
})
|
||||||
|
@ -144,4 +153,178 @@ var _ = Describe("middlewares", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Describe("clientUniqueIDMiddleware", func() {
|
||||||
|
var (
|
||||||
|
nextHandler http.Handler
|
||||||
|
middleware http.Handler
|
||||||
|
req *http.Request
|
||||||
|
nextReq *http.Request
|
||||||
|
rec *httptest.ResponseRecorder
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
nextHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
nextReq = r
|
||||||
|
})
|
||||||
|
middleware = clientUniqueIDMiddleware(nextHandler)
|
||||||
|
req, _ = http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when the request header has the unique client ID", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
req.Header.Set(consts.UIClientUniqueIDHeader, "123456")
|
||||||
|
conf.Server.BasePath = "/music"
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets the unique client ID as a cookie and adds it to the request context", func() {
|
||||||
|
middleware.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Result().Cookies()).To(HaveLen(1))
|
||||||
|
Expect(rec.Result().Cookies()[0].Name).To(Equal(consts.UIClientUniqueIDHeader))
|
||||||
|
Expect(rec.Result().Cookies()[0].Value).To(Equal("123456"))
|
||||||
|
Expect(rec.Result().Cookies()[0].MaxAge).To(Equal(consts.CookieExpiry))
|
||||||
|
Expect(rec.Result().Cookies()[0].HttpOnly).To(BeTrue())
|
||||||
|
Expect(rec.Result().Cookies()[0].Secure).To(BeTrue())
|
||||||
|
Expect(rec.Result().Cookies()[0].SameSite).To(Equal(http.SameSiteStrictMode))
|
||||||
|
Expect(rec.Result().Cookies()[0].Path).To(Equal("/music"))
|
||||||
|
clientUniqueId, _ := request.ClientUniqueIdFrom(nextReq.Context())
|
||||||
|
Expect(clientUniqueId).To(Equal("123456"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when the request header does not have the unique client ID", func() {
|
||||||
|
Context("when the request has the unique client ID in a cookie", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: consts.UIClientUniqueIDHeader,
|
||||||
|
Value: "123456",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
It("adds the unique client ID to the request context", func() {
|
||||||
|
middleware.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Result().Cookies()).To(HaveLen(0))
|
||||||
|
|
||||||
|
clientUniqueId, _ := request.ClientUniqueIdFrom(nextReq.Context())
|
||||||
|
Expect(clientUniqueId).To(Equal("123456"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when the request does not have the unique client ID in a cookie", func() {
|
||||||
|
It("does not add the unique client ID to the request context", func() {
|
||||||
|
middleware.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
Expect(rec.Result().Cookies()).To(HaveLen(0))
|
||||||
|
|
||||||
|
clientUniqueId, _ := request.ClientUniqueIdFrom(nextReq.Context())
|
||||||
|
Expect(clientUniqueId).To(BeEmpty())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("URLParamsMiddleware", func() {
|
||||||
|
var (
|
||||||
|
router *chi.Mux
|
||||||
|
middleware http.Handler
|
||||||
|
recorder *httptest.ResponseRecorder
|
||||||
|
testHandler http.HandlerFunc
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
router = chi.NewRouter()
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
testHandler = func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when request has no query parameters", func() {
|
||||||
|
It("adds URL parameters to the request", func() {
|
||||||
|
middleware = URLParamsMiddleware(testHandler)
|
||||||
|
router.Mount("/", middleware)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "/?user=1", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
Expect(recorder.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(recorder.Body.String()).To(Equal("OK"))
|
||||||
|
Expect(req.URL.RawQuery).To(ContainSubstring("user=1"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when request has query parameters", func() {
|
||||||
|
It("merges URL parameters and query parameters", func() {
|
||||||
|
router.Route("/{key}", func(r chi.Router) {
|
||||||
|
r.Use(URLParamsMiddleware)
|
||||||
|
r.Get("/", testHandler)
|
||||||
|
})
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "/test?key=value", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
Expect(recorder.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(recorder.Body.String()).To(Equal("OK"))
|
||||||
|
Expect(req.URL.RawQuery).To(ContainSubstring("key=value"))
|
||||||
|
Expect(req.URL.RawQuery).To(ContainSubstring("%3Akey=test"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when URL parameter has wildcard key", func() {
|
||||||
|
It("does not include wildcard key in query parameters", func() {
|
||||||
|
router.Route("/{t*}", func(r chi.Router) {
|
||||||
|
r.Use(URLParamsMiddleware)
|
||||||
|
r.Get("/", testHandler)
|
||||||
|
})
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "/test?key=value", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
Expect(recorder.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(recorder.Body.String()).To(Equal("OK"))
|
||||||
|
Expect(req.URL.RawQuery).To(ContainSubstring("key=value"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when URL parameters require encoding", func() {
|
||||||
|
It("encodes URL parameters correctly", func() {
|
||||||
|
router.Route("/{key}", func(r chi.Router) {
|
||||||
|
r.Use(URLParamsMiddleware)
|
||||||
|
r.Get("/", testHandler)
|
||||||
|
})
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "/test with space?key=another value", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
Expect(recorder.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(recorder.Body.String()).To(Equal("OK"))
|
||||||
|
queryValues, _ := url.ParseQuery(req.URL.RawQuery)
|
||||||
|
Expect(queryValues.Get(":key")).To(Equal("test with space"))
|
||||||
|
Expect(queryValues.Get("key")).To(Equal("another value"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when there are multiple URL parameters", func() {
|
||||||
|
It("includes all URL parameters in the query string", func() {
|
||||||
|
router.Route("/{key}/{value}", func(r chi.Router) {
|
||||||
|
r.Use(URLParamsMiddleware)
|
||||||
|
r.Get("/", testHandler)
|
||||||
|
})
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "/test/value?key=other_value", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
Expect(recorder.Code).To(Equal(http.StatusOK))
|
||||||
|
Expect(recorder.Body.String()).To(Equal("OK"))
|
||||||
|
|
||||||
|
queryValues, _ := url.ParseQuery(req.URL.RawQuery)
|
||||||
|
Expect(queryValues.Get(":key")).To(Equal("test"))
|
||||||
|
Expect(queryValues.Get(":value")).To(Equal("value"))
|
||||||
|
Expect(queryValues.Get("key")).To(Equal("other_value"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue