diff --git a/server/api/api.go b/server/api/api.go index f68f0ea8f..869dc5237 100644 --- a/server/api/api.go +++ b/server/api/api.go @@ -6,7 +6,6 @@ package api import ( "context" "net/http" - "net/url" middleware "github.com/deepmap/oapi-codegen/pkg/chi-middleware" "github.com/getkin/kin-openapi/openapi3" @@ -34,7 +33,10 @@ func New(ds model.DataStore) *Router { RequestErrorHandlerFunc: apiErrorHandler, ResponseErrorHandlerFunc: apiErrorHandler, }) - r.Handler = HandlerFromMux(handler, mux) + r.Handler = HandlerWithOptions(handler, ChiServerOptions{ + BaseRouter: mux, + Middlewares: []MiddlewareFunc{storeRequestInContext}, + }) return r } @@ -76,7 +78,7 @@ func (a *Router) GetTracks(ctx context.Context, request GetTracksRequestObject) if err != nil { return nil, err } - baseUrl, _ := url.JoinPath(spec.Servers[0].URL, "tracks") + baseUrl := baseResourceUrl(ctx, "tracks") links, meta := buildPaginationLinksAndMeta(int32(cnt), request.Params, baseUrl) return GetTracks200JSONResponse{ Data: toAPITracks(mfs), diff --git a/server/api/helpers.go b/server/api/helpers.go index c5ada9b82..5a524fb44 100644 --- a/server/api/helpers.go +++ b/server/api/helpers.go @@ -1,7 +1,9 @@ package api import ( + "context" "encoding/json" + "errors" "net/http" "net/url" "strconv" @@ -9,8 +11,21 @@ import ( "github.com/Masterminds/squirrel" "github.com/navidrome/navidrome/model" + "github.com/navidrome/navidrome/server" ) +type contextKey string + +const requestInContext contextKey = "request" + +// storeRequestInContext is a middleware function that adds the full request object to the context. +func storeRequestInContext(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), requestInContext, r) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func toAPITrack(mf model.MediaFile) Track { return Track{ Type: "track", @@ -88,10 +103,10 @@ func toQueryOptions(params GetTracksParams) model.QueryOptions { func apiErrorHandler(w http.ResponseWriter, r *http.Request, err error) { var res ErrorObject - switch err { - case model.ErrNotAuthorized: + switch { + case errors.Is(err, model.ErrNotAuthorized): res = ErrorObject{Status: p(strconv.Itoa(http.StatusForbidden)), Title: p(http.StatusText(http.StatusForbidden))} - case model.ErrNotFound: + case errors.Is(err, model.ErrNotFound): res = ErrorObject{Status: p(strconv.Itoa(http.StatusNotFound)), Title: p(http.StatusText(http.StatusNotFound))} default: res = ErrorObject{Status: p(strconv.Itoa(http.StatusInternalServerError)), Title: p(http.StatusText(http.StatusInternalServerError))} @@ -99,7 +114,7 @@ func apiErrorHandler(w http.ResponseWriter, r *http.Request, err error) { w.Header().Set("Content-Type", "application/vnd.api+json") w.WriteHeader(403) - json.NewEncoder(w).Encode(ErrorList{[]ErrorObject{res}}) + _ = json.NewEncoder(w).Encode(ErrorList{[]ErrorObject{res}}) } func validationErrorHandler(w http.ResponseWriter, message string, statusCode int) { @@ -151,15 +166,15 @@ func buildPaginationLinksAndMeta(totalItems int32, params GetTracksParams, resou addFilterParams("filter[endsWith]", params.FilterEndsWith) if params.Sort != nil { - query.Add("sort", string(*params.Sort)) + query.Add("sort", *params.Sort) } if params.Include != nil { - query.Add("include", string(*params.Include)) + query.Add("include", *params.Include) } link := resourceName if len(query) > 0 { - link += "&" + query.Encode() + link += "?" + query.Encode() } return &link } @@ -191,3 +206,9 @@ func buildPaginationLinksAndMeta(totalItems int32, params GetTracksParams, resou return links, meta } + +func baseResourceUrl(ctx context.Context, resourceName string) string { + r := ctx.Value(requestInContext).(*http.Request) + baseUrl, _ := url.JoinPath(spec.Servers[0].URL, resourceName) + return server.AbsoluteURL(r, baseUrl, nil) +} diff --git a/server/api/helpers_test.go b/server/api/helpers_test.go index 9efe1a132..0b4a9e7d2 100644 --- a/server/api/helpers_test.go +++ b/server/api/helpers_test.go @@ -1,6 +1,8 @@ package api import ( + "net/url" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -29,9 +31,9 @@ var _ = Describe("BuildPaginationLinksAndMeta", func() { It("returns correct pagination links and meta", func() { links, meta := buildPaginationLinksAndMeta(totalItems, params, resourceName) - Expect(links.First).To(Equal(p("api/resource?page[offset]=0&page[limit]=10"))) - Expect(links.Last).To(Equal(p("api/resource?page[offset]=140&page[limit]=10"))) - Expect(links.Next).To(Equal(p("api/resource?page[offset]=10&page[limit]=10"))) + testLinkEquality(links.First, p("api/resource?page[offset]=0&page[limit]=10")) + testLinkEquality(links.Last, p("api/resource?page[offset]=140&page[limit]=10")) + testLinkEquality(links.Next, p("api/resource?page[offset]=10&page[limit]=10")) Expect(links.Prev).To(BeNil()) Expect(meta.CurrentPage).To(Equal(p(int32(1)))) @@ -51,10 +53,10 @@ var _ = Describe("BuildPaginationLinksAndMeta", func() { It("returns correct pagination links and meta", func() { links, meta := buildPaginationLinksAndMeta(totalItems, params, resourceName) - Expect(links.First).To(Equal(p("api/resource?page[offset]=0&page[limit]=20"))) - Expect(links.Last).To(Equal(p("api/resource?page[offset]=140&page[limit]=20"))) - Expect(links.Next).To(Equal(p("api/resource?page[offset]=60&page[limit]=20"))) - Expect(links.Prev).To(Equal(p("api/resource?page[offset]=20&page[limit]=20"))) + testLinkEquality(links.First, p("api/resource?page[offset]=0&page[limit]=20")) + testLinkEquality(links.Last, p("api/resource?page[offset]=140&page[limit]=20")) + testLinkEquality(links.Next, p("api/resource?page[offset]=60&page[limit]=20")) + testLinkEquality(links.Prev, p("api/resource?page[offset]=20&page[limit]=20")) Expect(meta.CurrentPage).To(Equal(p(int32(3)))) Expect(meta.TotalItems).To(Equal(p(int32(150)))) @@ -81,30 +83,56 @@ var _ = Describe("BuildPaginationLinksAndMeta", func() { It("returns correct pagination links with filter params", func() { links, _ := buildPaginationLinksAndMeta(totalItems, params, resourceName) - expectedLinkPrefix := "api/resource?" - expectedParams := []string{ - "page[offset]=0&page[limit]=20", - "filter[equals]=property1:value1&filter[equals]=property2:value2", - "filter[contains]=property3:value3", - "filter[lessThan]=property4:value4", - "filter[lessOrEqual]=property5:value5", - "filter[greaterThan]=property6:value6", - "filter[greaterOrEqual]=property7:value7", - "filter[startsWith]=property8:value8", - "filter[endsWith]=property9:value9", + validateLink := func(link *string, expectedOffset string) { + parsedLink, err := url.Parse(*link) + Expect(err).NotTo(HaveOccurred()) + + queryParams, _ := url.ParseQuery(parsedLink.RawQuery) + Expect(queryParams["page[offset]"]).To(ConsistOf(expectedOffset)) + Expect(queryParams["page[limit]"]).To(ConsistOf("20")) + + for _, param := range *params.FilterEquals { + Expect(queryParams["filter[equals]"]).To(ContainElements(param)) + } + for _, param := range *params.FilterContains { + Expect(queryParams["filter[contains]"]).To(ContainElement(param)) + } + for _, param := range *params.FilterLessThan { + Expect(queryParams["filter[lessThan]"]).To(ContainElement(param)) + } + for _, param := range *params.FilterLessOrEqual { + Expect(queryParams["filter[lessOrEqual]"]).To(ContainElement(param)) + } + for _, param := range *params.FilterGreaterThan { + Expect(queryParams["filter[greaterThan]"]).To(ContainElement(param)) + } + for _, param := range *params.FilterGreaterOrEqual { + Expect(queryParams["filter[greaterOrEqual]"]).To(ContainElement(param)) + } + for _, param := range *params.FilterStartsWith { + Expect(queryParams["filter[startsWith]"]).To(ContainElement(param)) + } + for _, param := range *params.FilterEndsWith { + Expect(queryParams["filter[endsWith]"]).To(ContainElement(param)) + } } - Expect(*links.First).To(HavePrefix(expectedLinkPrefix)) - Expect(*links.Last).To(HavePrefix(expectedLinkPrefix)) - Expect(*links.Next).To(HavePrefix(expectedLinkPrefix)) - Expect(*links.Prev).To(HavePrefix(expectedLinkPrefix)) - - for _, param := range expectedParams { - Expect(*links.First).To(ContainSubstring(param)) - Expect(*links.Last).To(ContainSubstring(param)) - Expect(*links.Next).To(ContainSubstring(param)) - Expect(*links.Prev).To(ContainSubstring(param)) - } + validateLink(links.First, "0") + validateLink(links.Last, "140") + validateLink(links.Next, "60") + validateLink(links.Prev, "20") }) }) }) + +func testLinkEquality(link1, link2 *string) { + parsedLink1, err := url.Parse(*link1) + Expect(err).NotTo(HaveOccurred()) + queryParams1, _ := url.ParseQuery(parsedLink1.RawQuery) + + parsedLink2, err := url.Parse(*link2) + Expect(err).NotTo(HaveOccurred()) + queryParams2, _ := url.ParseQuery(parsedLink2.RawQuery) + + Expect(queryParams1).To(Equal(queryParams2)) +}