diff --git a/engine/common.go b/engine/common.go index 356da3ffe..5d1eecaf4 100644 --- a/engine/common.go +++ b/engine/common.go @@ -7,6 +7,7 @@ import ( "github.com/deluan/navidrome/consts" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" ) type Entry struct { @@ -159,10 +160,9 @@ func FromArtists(ars model.Artists) Entries { } func userName(ctx context.Context) string { - user := ctx.Value("user") - if user == nil { + if user, ok := request.UserFrom(ctx); !ok { return "UNKNOWN" + } else { + return user.UserName } - usr := user.(model.User) - return usr.UserName } diff --git a/engine/media_streamer.go b/engine/media_streamer.go index c37c5f31a..357310be0 100644 --- a/engine/media_streamer.go +++ b/engine/media_streamer.go @@ -13,6 +13,7 @@ import ( "github.com/deluan/navidrome/engine/transcoder" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/djherbis/fscache" ) @@ -161,7 +162,7 @@ func selectTranscodingOptions(ctx context.Context, ds model.DataStore, mf *model bitRate = mf.BitRate return } - trc, hasDefault := ctx.Value("transcoding").(model.Transcoding) + trc, hasDefault := request.TranscodingFrom(ctx) var cFormat string var cBitRate int if reqFormat != "" { @@ -170,7 +171,7 @@ func selectTranscodingOptions(ctx context.Context, ds model.DataStore, mf *model if hasDefault { cFormat = trc.TargetFormat cBitRate = trc.DefaultBitRate - if p, ok := ctx.Value("player").(model.Player); ok { + if p, ok := request.PlayerFrom(ctx); ok { cBitRate = p.MaxBitRate } } diff --git a/engine/media_streamer_test.go b/engine/media_streamer_test.go index 0f0dbbb57..53f8d7e1a 100644 --- a/engine/media_streamer_test.go +++ b/engine/media_streamer_test.go @@ -7,6 +7,7 @@ import ( "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/deluan/navidrome/persistence" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -101,7 +102,7 @@ var _ = Describe("MediaStreamer", func() { Context("player has format configured", func() { BeforeEach(func() { t := model.Transcoding{ID: "oga1", TargetFormat: "oga", DefaultBitRate: 96} - ctx = context.WithValue(ctx, "transcoding", t) + ctx = request.WithTranscoding(ctx, t) }) It("returns raw if raw is requested", func() { mf.Suffix = "flac" @@ -142,8 +143,8 @@ var _ = Describe("MediaStreamer", func() { BeforeEach(func() { t := model.Transcoding{ID: "oga1", TargetFormat: "oga", DefaultBitRate: 96} p := model.Player{ID: "player1", TranscodingId: t.ID, MaxBitRate: 80} - ctx = context.WithValue(ctx, "transcoding", t) - ctx = context.WithValue(ctx, "player", p) + ctx = request.WithTranscoding(ctx, t) + ctx = request.WithPlayer(ctx, p) }) It("returns raw if raw is requested", func() { mf.Suffix = "flac" diff --git a/engine/players.go b/engine/players.go index 9e5bacb79..132c602e8 100644 --- a/engine/players.go +++ b/engine/players.go @@ -7,6 +7,7 @@ import ( "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/google/uuid" ) @@ -27,7 +28,7 @@ func (p *players) Register(ctx context.Context, id, client, typ, ip string) (*mo var plr *model.Player var trc *model.Transcoding var err error - userName := ctx.Value("username").(string) + userName, _ := request.UsernameFrom(ctx) if id != "" { plr, err = p.ds.Player(ctx).Get(id) if err == nil && plr.Client != client { diff --git a/engine/players_test.go b/engine/players_test.go index 720274035..7d0a9a52a 100644 --- a/engine/players_test.go +++ b/engine/players_test.go @@ -6,6 +6,7 @@ import ( "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/deluan/navidrome/persistence" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -14,8 +15,9 @@ import ( var _ = Describe("Players", func() { var players Players var repo *mockPlayerRepository - ctx := context.WithValue(log.NewContext(context.TODO()), "user", model.User{ID: "userid", UserName: "johndoe"}) - ctx = context.WithValue(ctx, "username", "johndoe") + ctx := log.NewContext(context.TODO()) + ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "johndoe"}) + ctx = request.WithUsername(ctx, "johndoe") var beforeRegister time.Time BeforeEach(func() { diff --git a/engine/playlists.go b/engine/playlists.go index d37b4f730..ef0860ed7 100644 --- a/engine/playlists.go +++ b/engine/playlists.go @@ -5,6 +5,7 @@ import ( "time" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/deluan/navidrome/utils" ) @@ -52,7 +53,7 @@ func (p *playlists) Create(ctx context.Context, playlistId, name string, ids []s } func (p *playlists) getUser(ctx context.Context) string { - user, ok := ctx.Value("user").(model.User) + user, ok := request.UserFrom(ctx) if ok { return user.UserName } diff --git a/log/log.go b/log/log.go index 7532ef51b..b992a8df7 100644 --- a/log/log.go +++ b/log/log.go @@ -24,6 +24,10 @@ const ( LevelTrace = Level(logrus.TraceLevel) ) +type contextKey string + +const loggerCtxKey = contextKey("logger") + var ( currentLevel Level defaultLogger = logrus.New() @@ -66,7 +70,7 @@ func NewContext(ctx context.Context, keyValuePairs ...interface{}) context.Conte } logger := addFields(createNewLogger(), keyValuePairs) - ctx = context.WithValue(ctx, "logger", logger) + ctx = context.WithValue(ctx, loggerCtxKey, logger) return ctx } @@ -176,7 +180,7 @@ func extractLogger(ctx interface{}) (*logrus.Entry, error) { case *logrus.Entry: return ctx, nil case context.Context: - logger := ctx.Value("logger") + logger := ctx.Value(loggerCtxKey) if logger != nil { return logger.(*logrus.Entry), nil } diff --git a/log/log_test.go b/log/log_test.go index eca2a42dd..146baee3b 100644 --- a/log/log_test.go +++ b/log/log_test.go @@ -136,7 +136,7 @@ var _ = Describe("Logger", func() { It("returns the logger from context if it has one", func() { logger := logrus.NewEntry(logrus.New()) ctx := context.Background() - ctx = context.WithValue(ctx, "logger", logger) + ctx = context.WithValue(ctx, loggerCtxKey, logger) Expect(extractLogger(ctx)).To(Equal(logger)) }) @@ -144,7 +144,7 @@ var _ = Describe("Logger", func() { It("returns the logger from request's context if it has one", func() { logger := logrus.NewEntry(logrus.New()) ctx := context.Background() - ctx = context.WithValue(ctx, "logger", logger) + ctx = context.WithValue(ctx, loggerCtxKey, logger) req := httptest.NewRequest("get", "/", nil).WithContext(ctx) Expect(extractLogger(req)).To(Equal(logger)) diff --git a/model/request/request.go b/model/request/request.go new file mode 100644 index 000000000..16b3a9333 --- /dev/null +++ b/model/request/request.go @@ -0,0 +1,72 @@ +package request + +import ( + "context" + + "github.com/deluan/navidrome/model" +) + +type contextKey string + +const ( + User = contextKey("user") + Username = contextKey("username") + Client = contextKey("client") + Version = contextKey("version") + Player = contextKey("player") + Transcoding = contextKey("transcoding") +) + +func WithUser(ctx context.Context, u model.User) context.Context { + return context.WithValue(ctx, User, u) +} + +func WithUsername(ctx context.Context, username string) context.Context { + return context.WithValue(ctx, Username, username) +} + +func WithClient(ctx context.Context, client string) context.Context { + return context.WithValue(ctx, Client, client) +} + +func WithVersion(ctx context.Context, version string) context.Context { + return context.WithValue(ctx, Version, version) +} + +func WithPlayer(ctx context.Context, player model.Player) context.Context { + return context.WithValue(ctx, Player, player) +} + +func WithTranscoding(ctx context.Context, t model.Transcoding) context.Context { + return context.WithValue(ctx, Transcoding, t) +} + +func UserFrom(ctx context.Context) (model.User, bool) { + v, ok := ctx.Value(User).(model.User) + return v, ok +} + +func UsernameFrom(ctx context.Context) (string, bool) { + v, ok := ctx.Value(Username).(string) + return v, ok +} + +func ClientFrom(ctx context.Context) (string, bool) { + v, ok := ctx.Value(Client).(string) + return v, ok +} + +func VersionFrom(ctx context.Context) (string, bool) { + v, ok := ctx.Value(Version).(string) + return v, ok +} + +func PlayerFrom(ctx context.Context) (model.Player, bool) { + v, ok := ctx.Value(Player).(model.Player) + return v, ok +} + +func TranscodingFrom(ctx context.Context) (model.Transcoding, bool) { + v, ok := ctx.Value(Transcoding).(model.Transcoding) + return v, ok +} diff --git a/persistence/album_repository_test.go b/persistence/album_repository_test.go index 1472c2804..a5f729089 100644 --- a/persistence/album_repository_test.go +++ b/persistence/album_repository_test.go @@ -6,6 +6,7 @@ import ( "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -14,7 +15,7 @@ var _ = Describe("AlbumRepository", func() { var repo model.AlbumRepository BeforeEach(func() { - ctx := context.WithValue(log.NewContext(context.TODO()), "user", model.User{ID: "userid"}) + ctx := request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", UserName: "johndoe"}) repo = NewAlbumRepository(ctx, orm.NewOrm()) }) diff --git a/persistence/artist_repository_test.go b/persistence/artist_repository_test.go index 79726e96c..a6e986322 100644 --- a/persistence/artist_repository_test.go +++ b/persistence/artist_repository_test.go @@ -6,6 +6,7 @@ import ( "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -14,7 +15,8 @@ var _ = Describe("ArtistRepository", func() { var repo model.ArtistRepository BeforeEach(func() { - ctx := context.WithValue(log.NewContext(context.TODO()), "user", model.User{ID: "userid"}) + ctx := log.NewContext(context.TODO()) + ctx = request.WithUser(ctx, model.User{ID: "userid"}) repo = NewArtistRepository(ctx, orm.NewOrm()) }) diff --git a/persistence/mediafile_repository_test.go b/persistence/mediafile_repository_test.go index 5231e4e93..3b929d729 100644 --- a/persistence/mediafile_repository_test.go +++ b/persistence/mediafile_repository_test.go @@ -7,6 +7,7 @@ import ( "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/google/uuid" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -16,7 +17,8 @@ var _ = Describe("MediaRepository", func() { var mr model.MediaFileRepository BeforeEach(func() { - ctx := context.WithValue(log.NewContext(context.TODO()), "user", model.User{ID: "userid"}) + ctx := log.NewContext(context.TODO()) + ctx = request.WithUser(ctx, model.User{ID: "userid"}) mr = NewMediaFileRepository(ctx, orm.NewOrm()) }) diff --git a/persistence/persistence_suite_test.go b/persistence/persistence_suite_test.go index a85315e4a..1263a518a 100644 --- a/persistence/persistence_suite_test.go +++ b/persistence/persistence_suite_test.go @@ -10,6 +10,7 @@ import ( "github.com/deluan/navidrome/db" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/deluan/navidrome/tests" _ "github.com/mattn/go-sqlite3" . "github.com/onsi/ginkgo" @@ -84,7 +85,8 @@ var _ = Describe("Initialize test DB", func() { // TODO Load this data setup from file(s) BeforeSuite(func() { o := orm.NewOrm() - ctx := context.WithValue(log.NewContext(context.TODO()), "user", model.User{ID: "userid"}) + ctx := log.NewContext(context.TODO()) + ctx = request.WithUser(ctx, model.User{ID: "userid"}) mr := NewMediaFileRepository(ctx, o) for i := range testSongs { s := testSongs[i] diff --git a/persistence/sql_base_repository.go b/persistence/sql_base_repository.go index 1d8eaf81b..d4f25b15d 100644 --- a/persistence/sql_base_repository.go +++ b/persistence/sql_base_repository.go @@ -11,6 +11,7 @@ import ( "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/google/uuid" ) @@ -24,21 +25,19 @@ type sqlRepository struct { const invalidUserId = "-1" func userId(ctx context.Context) string { - user := ctx.Value("user") - if user == nil { + if user, ok := request.UserFrom(ctx); !ok { return invalidUserId + } else { + return user.ID } - usr := user.(model.User) - return usr.ID } func loggedUser(ctx context.Context) *model.User { - user := ctx.Value("user") - if user == nil { + if user, ok := request.UserFrom(ctx); !ok { return &model.User{} + } else { + return &user } - u := user.(model.User) - return &u } func (r sqlRepository) newSelect(options ...model.QueryOptions) SelectBuilder { diff --git a/server/app/auth.go b/server/app/auth.go index e800ad58c..16e9f1f72 100644 --- a/server/app/auth.go +++ b/server/app/auth.go @@ -12,6 +12,7 @@ import ( "github.com/deluan/navidrome/engine/auth" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/deluan/rest" "github.com/dgrijalva/jwt-go" "github.com/go-chi/jwtauth" @@ -146,7 +147,7 @@ func validateLogin(userRepo model.UserRepository, userName, password string) (*m func contextWithUser(ctx context.Context, ds model.DataStore, claims jwt.MapClaims) context.Context { userName := claims["sub"].(string) user, _ := ds.User(ctx).FindByUsername(userName) - return context.WithValue(ctx, "user", *user) + return request.WithUser(ctx, *user) } func getToken(ds model.DataStore, ctx context.Context) (*jwt.Token, error) { diff --git a/server/subsonic/helpers.go b/server/subsonic/helpers.go index 4b0489a0f..06233d574 100644 --- a/server/subsonic/helpers.go +++ b/server/subsonic/helpers.go @@ -10,6 +10,7 @@ import ( "github.com/deluan/navidrome/consts" "github.com/deluan/navidrome/engine" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/deluan/navidrome/server/subsonic/responses" "github.com/deluan/navidrome/utils" ) @@ -154,10 +155,10 @@ func ToGenres(genres model.Genres) *responses.Genres { } func getTranscoding(ctx context.Context) (format string, bitRate int) { - if trc, ok := ctx.Value("transcoding").(model.Transcoding); ok { + if trc, ok := request.TranscodingFrom(ctx); ok { format = trc.TargetFormat } - if plr, ok := ctx.Value("player").(model.Player); ok { + if plr, ok := request.PlayerFrom(ctx); ok { bitRate = plr.MaxBitRate } return diff --git a/server/subsonic/middlewares.go b/server/subsonic/middlewares.go index aa102bf56..cb9ae2562 100644 --- a/server/subsonic/middlewares.go +++ b/server/subsonic/middlewares.go @@ -1,7 +1,6 @@ package subsonic import ( - "context" "fmt" "net" "net/http" @@ -11,6 +10,7 @@ import ( "github.com/deluan/navidrome/engine" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" "github.com/deluan/navidrome/server/subsonic/responses" "github.com/deluan/navidrome/utils" ) @@ -50,14 +50,14 @@ func checkRequiredParameters(next http.Handler) http.Handler { } } - user := utils.ParamString(r, "u") + username := utils.ParamString(r, "u") client := utils.ParamString(r, "c") version := utils.ParamString(r, "v") ctx := r.Context() - ctx = context.WithValue(ctx, "username", user) - ctx = context.WithValue(ctx, "client", client) - ctx = context.WithValue(ctx, "version", version) - log.Debug(ctx, "API: New request "+r.URL.Path, "username", user, "client", client, "version", version) + ctx = request.WithUsername(ctx, username) + ctx = request.WithClient(ctx, client) + ctx = request.WithVersion(ctx, version) + log.Debug(ctx, "API: New request "+r.URL.Path, "username", username, "client", client, "version", version) r = r.WithContext(ctx) next.ServeHTTP(w, r) @@ -87,7 +87,7 @@ func authenticate(users engine.Users) func(next http.Handler) http.Handler { } ctx := r.Context() - ctx = context.WithValue(ctx, "user", *usr) + ctx = request.WithUser(ctx, *usr) r = r.WithContext(ctx) next.ServeHTTP(w, r) @@ -99,17 +99,17 @@ func getPlayer(players engine.Players) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userName := ctx.Value("username").(string) - client := ctx.Value("client").(string) + userName, _ := request.UsernameFrom(ctx) + client, _ := request.ClientFrom(ctx) playerId := playerIDFromCookie(r, userName) ip, _, _ := net.SplitHostPort(r.RemoteAddr) player, trc, err := players.Register(ctx, playerId, client, r.Header.Get("user-agent"), ip) if err != nil { log.Error("Could not register player", "userName", userName, "client", client) } else { - ctx = context.WithValue(ctx, "player", *player) + ctx = request.WithPlayer(ctx, *player) if trc != nil { - ctx = context.WithValue(ctx, "transcoding", *trc) + ctx = request.WithTranscoding(ctx, *trc) } r = r.WithContext(ctx) diff --git a/server/subsonic/middlewares_test.go b/server/subsonic/middlewares_test.go index fe6eaa15a..dbe640c67 100644 --- a/server/subsonic/middlewares_test.go +++ b/server/subsonic/middlewares_test.go @@ -10,6 +10,7 @@ import ( "github.com/deluan/navidrome/engine" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -73,9 +74,13 @@ var _ = Describe("Middlewares", func() { cp := checkRequiredParameters(next) cp.ServeHTTP(w, r) - Expect(next.req.Context().Value("username")).To(Equal("user")) - Expect(next.req.Context().Value("version")).To(Equal("1.15")) - Expect(next.req.Context().Value("client")).To(Equal("test")) + username, _ := request.UsernameFrom(next.req.Context()) + Expect(username).To(Equal("user")) + version, _ := request.VersionFrom(next.req.Context()) + Expect(version).To(Equal("1.15")) + client, _ := request.ClientFrom(next.req.Context()) + Expect(client).To(Equal("test")) + Expect(next.called).To(BeTrue()) }) @@ -124,7 +129,7 @@ var _ = Describe("Middlewares", func() { Expect(mockedUsers.salt).To(Equal("salt")) Expect(mockedUsers.jwt).To(Equal("jwt")) Expect(next.called).To(BeTrue()) - user := next.req.Context().Value("user").(model.User) + user, _ := request.UserFrom(next.req.Context()) Expect(user.UserName).To(Equal("valid")) }) @@ -144,8 +149,8 @@ var _ = Describe("Middlewares", func() { BeforeEach(func() { mockedPlayers = &mockPlayers{} r = newGetRequest() - ctx := context.WithValue(r.Context(), "username", "someone") - ctx = context.WithValue(ctx, "client", "client") + ctx := request.WithUsername(r.Context(), "someone") + ctx = request.WithClient(ctx, "client") r = r.WithContext(ctx) }) @@ -158,7 +163,7 @@ var _ = Describe("Middlewares", func() { }) It("does not add the cookie if there was an error", func() { - ctx := context.WithValue(r.Context(), "client", "error") + ctx := request.WithClient(r.Context(), "error") r = r.WithContext(ctx) gp := getPlayer(mockedPlayers)(next) @@ -183,9 +188,10 @@ var _ = Describe("Middlewares", func() { It("stores the player in the context", func() { Expect(next.called).To(BeTrue()) - player := next.req.Context().Value("player").(model.Player) + player, _ := request.PlayerFrom(next.req.Context()) Expect(player.ID).To(Equal("123")) - Expect(next.req.Context().Value("transcoding")).To(BeNil()) + _, ok := request.TranscodingFrom(next.req.Context()) + Expect(ok).To(BeFalse()) }) It("returns the playerId in the cookie", func() { @@ -208,9 +214,9 @@ var _ = Describe("Middlewares", func() { }) It("stores the player in the context", func() { - player := next.req.Context().Value("player").(model.Player) + player, _ := request.PlayerFrom(next.req.Context()) Expect(player.ID).To(Equal("123")) - transcoding := next.req.Context().Value("transcoding").(model.Transcoding) + transcoding, _ := request.TranscodingFrom(next.req.Context()) Expect(transcoding.ID).To(Equal("12")) }) }) diff --git a/server/subsonic/responses/responses_suite_test.go b/server/subsonic/responses/responses_suite_test.go index 742d4b13c..dcfb44b19 100644 --- a/server/subsonic/responses/responses_suite_test.go +++ b/server/subsonic/responses/responses_suite_test.go @@ -1,7 +1,6 @@ package responses import ( - "fmt" "testing" "github.com/bradleyjkemp/cupaloy" @@ -33,9 +32,9 @@ func (matcher snapshotMatcher) Match(actual interface{}) (success bool, err erro } func (matcher snapshotMatcher) FailureMessage(actual interface{}) (message string) { - return fmt.Sprintf("Expected to match saved snapshot\n") + return "Expected to match saved snapshot\n" } func (matcher snapshotMatcher) NegatedFailureMessage(actual interface{}) (message string) { - return fmt.Sprintf("Expected to not match saved snapshot\n") + return "Expected to not match saved snapshot\n" }