Build collection Links

This commit is contained in:
Deluan 2023-03-25 23:00:08 -04:00 committed by Deluan
parent dcb5725642
commit ed87e703ff
3 changed files with 90 additions and 39 deletions

View file

@ -6,7 +6,6 @@ package api
import ( import (
"context" "context"
"net/http" "net/http"
"net/url"
middleware "github.com/deepmap/oapi-codegen/pkg/chi-middleware" middleware "github.com/deepmap/oapi-codegen/pkg/chi-middleware"
"github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3"
@ -34,7 +33,10 @@ func New(ds model.DataStore) *Router {
RequestErrorHandlerFunc: apiErrorHandler, RequestErrorHandlerFunc: apiErrorHandler,
ResponseErrorHandlerFunc: apiErrorHandler, ResponseErrorHandlerFunc: apiErrorHandler,
}) })
r.Handler = HandlerFromMux(handler, mux) r.Handler = HandlerWithOptions(handler, ChiServerOptions{
BaseRouter: mux,
Middlewares: []MiddlewareFunc{storeRequestInContext},
})
return r return r
} }
@ -76,7 +78,7 @@ func (a *Router) GetTracks(ctx context.Context, request GetTracksRequestObject)
if err != nil { if err != nil {
return nil, err return nil, err
} }
baseUrl, _ := url.JoinPath(spec.Servers[0].URL, "tracks") baseUrl := baseResourceUrl(ctx, "tracks")
links, meta := buildPaginationLinksAndMeta(int32(cnt), request.Params, baseUrl) links, meta := buildPaginationLinksAndMeta(int32(cnt), request.Params, baseUrl)
return GetTracks200JSONResponse{ return GetTracks200JSONResponse{
Data: toAPITracks(mfs), Data: toAPITracks(mfs),

View file

@ -1,7 +1,9 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@ -9,8 +11,21 @@ import (
"github.com/Masterminds/squirrel" "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/server"
) )
type contextKey string
const requestInContext contextKey = "request"
// storeRequestInContext is a middleware function that adds the full request object to the context.
func storeRequestInContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), requestInContext, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func toAPITrack(mf model.MediaFile) Track { func toAPITrack(mf model.MediaFile) Track {
return Track{ return Track{
Type: "track", Type: "track",
@ -88,10 +103,10 @@ func toQueryOptions(params GetTracksParams) model.QueryOptions {
func apiErrorHandler(w http.ResponseWriter, r *http.Request, err error) { func apiErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
var res ErrorObject var res ErrorObject
switch err { switch {
case model.ErrNotAuthorized: case errors.Is(err, model.ErrNotAuthorized):
res = ErrorObject{Status: p(strconv.Itoa(http.StatusForbidden)), Title: p(http.StatusText(http.StatusForbidden))} res = ErrorObject{Status: p(strconv.Itoa(http.StatusForbidden)), Title: p(http.StatusText(http.StatusForbidden))}
case model.ErrNotFound: case errors.Is(err, model.ErrNotFound):
res = ErrorObject{Status: p(strconv.Itoa(http.StatusNotFound)), Title: p(http.StatusText(http.StatusNotFound))} res = ErrorObject{Status: p(strconv.Itoa(http.StatusNotFound)), Title: p(http.StatusText(http.StatusNotFound))}
default: default:
res = ErrorObject{Status: p(strconv.Itoa(http.StatusInternalServerError)), Title: p(http.StatusText(http.StatusInternalServerError))} res = ErrorObject{Status: p(strconv.Itoa(http.StatusInternalServerError)), Title: p(http.StatusText(http.StatusInternalServerError))}
@ -99,7 +114,7 @@ func apiErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
w.Header().Set("Content-Type", "application/vnd.api+json") w.Header().Set("Content-Type", "application/vnd.api+json")
w.WriteHeader(403) w.WriteHeader(403)
json.NewEncoder(w).Encode(ErrorList{[]ErrorObject{res}}) _ = json.NewEncoder(w).Encode(ErrorList{[]ErrorObject{res}})
} }
func validationErrorHandler(w http.ResponseWriter, message string, statusCode int) { func validationErrorHandler(w http.ResponseWriter, message string, statusCode int) {
@ -151,15 +166,15 @@ func buildPaginationLinksAndMeta(totalItems int32, params GetTracksParams, resou
addFilterParams("filter[endsWith]", params.FilterEndsWith) addFilterParams("filter[endsWith]", params.FilterEndsWith)
if params.Sort != nil { if params.Sort != nil {
query.Add("sort", string(*params.Sort)) query.Add("sort", *params.Sort)
} }
if params.Include != nil { if params.Include != nil {
query.Add("include", string(*params.Include)) query.Add("include", *params.Include)
} }
link := resourceName link := resourceName
if len(query) > 0 { if len(query) > 0 {
link += "&" + query.Encode() link += "?" + query.Encode()
} }
return &link return &link
} }
@ -191,3 +206,9 @@ func buildPaginationLinksAndMeta(totalItems int32, params GetTracksParams, resou
return links, meta return links, meta
} }
func baseResourceUrl(ctx context.Context, resourceName string) string {
r := ctx.Value(requestInContext).(*http.Request)
baseUrl, _ := url.JoinPath(spec.Servers[0].URL, resourceName)
return server.AbsoluteURL(r, baseUrl, nil)
}

View file

@ -1,6 +1,8 @@
package api package api
import ( import (
"net/url"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -29,9 +31,9 @@ var _ = Describe("BuildPaginationLinksAndMeta", func() {
It("returns correct pagination links and meta", func() { It("returns correct pagination links and meta", func() {
links, meta := buildPaginationLinksAndMeta(totalItems, params, resourceName) links, meta := buildPaginationLinksAndMeta(totalItems, params, resourceName)
Expect(links.First).To(Equal(p("api/resource?page[offset]=0&page[limit]=10"))) testLinkEquality(links.First, p("api/resource?page[offset]=0&page[limit]=10"))
Expect(links.Last).To(Equal(p("api/resource?page[offset]=140&page[limit]=10"))) testLinkEquality(links.Last, p("api/resource?page[offset]=140&page[limit]=10"))
Expect(links.Next).To(Equal(p("api/resource?page[offset]=10&page[limit]=10"))) testLinkEquality(links.Next, p("api/resource?page[offset]=10&page[limit]=10"))
Expect(links.Prev).To(BeNil()) Expect(links.Prev).To(BeNil())
Expect(meta.CurrentPage).To(Equal(p(int32(1)))) Expect(meta.CurrentPage).To(Equal(p(int32(1))))
@ -51,10 +53,10 @@ var _ = Describe("BuildPaginationLinksAndMeta", func() {
It("returns correct pagination links and meta", func() { It("returns correct pagination links and meta", func() {
links, meta := buildPaginationLinksAndMeta(totalItems, params, resourceName) links, meta := buildPaginationLinksAndMeta(totalItems, params, resourceName)
Expect(links.First).To(Equal(p("api/resource?page[offset]=0&page[limit]=20"))) testLinkEquality(links.First, p("api/resource?page[offset]=0&page[limit]=20"))
Expect(links.Last).To(Equal(p("api/resource?page[offset]=140&page[limit]=20"))) testLinkEquality(links.Last, p("api/resource?page[offset]=140&page[limit]=20"))
Expect(links.Next).To(Equal(p("api/resource?page[offset]=60&page[limit]=20"))) testLinkEquality(links.Next, p("api/resource?page[offset]=60&page[limit]=20"))
Expect(links.Prev).To(Equal(p("api/resource?page[offset]=20&page[limit]=20"))) testLinkEquality(links.Prev, p("api/resource?page[offset]=20&page[limit]=20"))
Expect(meta.CurrentPage).To(Equal(p(int32(3)))) Expect(meta.CurrentPage).To(Equal(p(int32(3))))
Expect(meta.TotalItems).To(Equal(p(int32(150)))) Expect(meta.TotalItems).To(Equal(p(int32(150))))
@ -81,30 +83,56 @@ var _ = Describe("BuildPaginationLinksAndMeta", func() {
It("returns correct pagination links with filter params", func() { It("returns correct pagination links with filter params", func() {
links, _ := buildPaginationLinksAndMeta(totalItems, params, resourceName) links, _ := buildPaginationLinksAndMeta(totalItems, params, resourceName)
expectedLinkPrefix := "api/resource?" validateLink := func(link *string, expectedOffset string) {
expectedParams := []string{ parsedLink, err := url.Parse(*link)
"page[offset]=0&page[limit]=20", Expect(err).NotTo(HaveOccurred())
"filter[equals]=property1:value1&filter[equals]=property2:value2",
"filter[contains]=property3:value3", queryParams, _ := url.ParseQuery(parsedLink.RawQuery)
"filter[lessThan]=property4:value4", Expect(queryParams["page[offset]"]).To(ConsistOf(expectedOffset))
"filter[lessOrEqual]=property5:value5", Expect(queryParams["page[limit]"]).To(ConsistOf("20"))
"filter[greaterThan]=property6:value6",
"filter[greaterOrEqual]=property7:value7", for _, param := range *params.FilterEquals {
"filter[startsWith]=property8:value8", Expect(queryParams["filter[equals]"]).To(ContainElements(param))
"filter[endsWith]=property9:value9", }
for _, param := range *params.FilterContains {
Expect(queryParams["filter[contains]"]).To(ContainElement(param))
}
for _, param := range *params.FilterLessThan {
Expect(queryParams["filter[lessThan]"]).To(ContainElement(param))
}
for _, param := range *params.FilterLessOrEqual {
Expect(queryParams["filter[lessOrEqual]"]).To(ContainElement(param))
}
for _, param := range *params.FilterGreaterThan {
Expect(queryParams["filter[greaterThan]"]).To(ContainElement(param))
}
for _, param := range *params.FilterGreaterOrEqual {
Expect(queryParams["filter[greaterOrEqual]"]).To(ContainElement(param))
}
for _, param := range *params.FilterStartsWith {
Expect(queryParams["filter[startsWith]"]).To(ContainElement(param))
}
for _, param := range *params.FilterEndsWith {
Expect(queryParams["filter[endsWith]"]).To(ContainElement(param))
}
} }
Expect(*links.First).To(HavePrefix(expectedLinkPrefix)) validateLink(links.First, "0")
Expect(*links.Last).To(HavePrefix(expectedLinkPrefix)) validateLink(links.Last, "140")
Expect(*links.Next).To(HavePrefix(expectedLinkPrefix)) validateLink(links.Next, "60")
Expect(*links.Prev).To(HavePrefix(expectedLinkPrefix)) validateLink(links.Prev, "20")
for _, param := range expectedParams {
Expect(*links.First).To(ContainSubstring(param))
Expect(*links.Last).To(ContainSubstring(param))
Expect(*links.Next).To(ContainSubstring(param))
Expect(*links.Prev).To(ContainSubstring(param))
}
}) })
}) })
}) })
func testLinkEquality(link1, link2 *string) {
parsedLink1, err := url.Parse(*link1)
Expect(err).NotTo(HaveOccurred())
queryParams1, _ := url.ParseQuery(parsedLink1.RawQuery)
parsedLink2, err := url.Parse(*link2)
Expect(err).NotTo(HaveOccurred())
queryParams2, _ := url.ParseQuery(parsedLink2.RawQuery)
Expect(queryParams1).To(Equal(queryParams2))
}