Refactor: Consolidate scrobbling logic in play_tracker

This commit is contained in:
Deluan 2021-06-22 23:56:29 -04:00 committed by Deluan Quintão
parent 76acd7da89
commit 056f0b944f
11 changed files with 243 additions and 120 deletions

View file

@ -50,8 +50,8 @@ func CreateSubsonicAPIRouter() *subsonic.Router {
externalMetadata := core.NewExternalMetadata(dataStore, agentsAgents) externalMetadata := core.NewExternalMetadata(dataStore, agentsAgents)
scanner := GetScanner() scanner := GetScanner()
broker := events.GetBroker() broker := events.GetBroker()
scrobblerBroker := scrobbler.GetBroker(dataStore) playTracker := scrobbler.GetPlayTracker(dataStore, broker)
router := subsonic.New(dataStore, artwork, mediaStreamer, archiver, players, externalMetadata, scanner, broker, scrobblerBroker) router := subsonic.New(dataStore, artwork, mediaStreamer, archiver, players, externalMetadata, scanner, broker, playTracker)
return router return router
} }

View file

@ -5,6 +5,8 @@ import (
"sort" "sort"
"time" "time"
"github.com/navidrome/navidrome/server/events"
"github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/log"
"github.com/ReneKroon/ttlcache/v2" "github.com/ReneKroon/ttlcache/v2"
@ -23,28 +25,34 @@ type NowPlayingInfo struct {
PlayerName string PlayerName string
} }
type Broker interface { type Submission struct {
NowPlaying(ctx context.Context, playerId string, playerName string, trackId string) error TrackID string
GetNowPlaying(ctx context.Context) ([]NowPlayingInfo, error) Timestamp time.Time
Submit(ctx context.Context, trackId string, playTime time.Time) error
} }
type broker struct { type PlayTracker interface {
NowPlaying(ctx context.Context, playerId string, playerName string, trackId string) error
GetNowPlaying(ctx context.Context) ([]NowPlayingInfo, error)
Submit(ctx context.Context, submissions []Submission) error
}
type playTracker struct {
ds model.DataStore ds model.DataStore
broker events.Broker
playMap *ttlcache.Cache playMap *ttlcache.Cache
} }
func GetBroker(ds model.DataStore) Broker { func GetPlayTracker(ds model.DataStore, broker events.Broker) PlayTracker {
instance := singleton.Get(broker{}, func() interface{} { instance := singleton.Get(playTracker{}, func() interface{} {
m := ttlcache.NewCache() m := ttlcache.NewCache()
m.SkipTTLExtensionOnHit(true) m.SkipTTLExtensionOnHit(true)
_ = m.SetTTL(nowPlayingExpire) _ = m.SetTTL(nowPlayingExpire)
return &broker{ds: ds, playMap: m} return &playTracker{ds: ds, playMap: m, broker: broker}
}) })
return instance.(*broker) return instance.(*playTracker)
} }
func (s *broker) NowPlaying(ctx context.Context, playerId string, playerName string, trackId string) error { func (p *playTracker) NowPlaying(ctx context.Context, playerId string, playerName string, trackId string) error {
user, _ := request.UserFrom(ctx) user, _ := request.UserFrom(ctx)
info := NowPlayingInfo{ info := NowPlayingInfo{
TrackID: trackId, TrackID: trackId,
@ -53,13 +61,13 @@ func (s *broker) NowPlaying(ctx context.Context, playerId string, playerName str
PlayerId: playerId, PlayerId: playerId,
PlayerName: playerName, PlayerName: playerName,
} }
_ = s.playMap.Set(playerId, info) _ = p.playMap.Set(playerId, info)
s.dispatchNowPlaying(ctx, user.ID, trackId) p.dispatchNowPlaying(ctx, user.ID, trackId)
return nil return nil
} }
func (s *broker) dispatchNowPlaying(ctx context.Context, userId string, trackId string) { func (p *playTracker) dispatchNowPlaying(ctx context.Context, userId string, trackId string) {
t, err := s.ds.MediaFile(ctx).Get(trackId) t, err := p.ds.MediaFile(ctx).Get(trackId)
if err != nil { if err != nil {
log.Error(ctx, "Error retrieving mediaFile", "id", trackId, err) log.Error(ctx, "Error retrieving mediaFile", "id", trackId, err)
return return
@ -67,7 +75,7 @@ func (s *broker) dispatchNowPlaying(ctx context.Context, userId string, trackId
// TODO Parallelize // TODO Parallelize
for name, constructor := range scrobblers { for name, constructor := range scrobblers {
err := func() error { err := func() error {
s := constructor(s.ds) s := constructor(p.ds)
if !s.IsAuthorized(ctx, userId) { if !s.IsAuthorized(ctx, userId) {
return nil return nil
} }
@ -81,10 +89,10 @@ func (s *broker) dispatchNowPlaying(ctx context.Context, userId string, trackId
} }
} }
func (s *broker) GetNowPlaying(ctx context.Context) ([]NowPlayingInfo, error) { func (p *playTracker) GetNowPlaying(ctx context.Context) ([]NowPlayingInfo, error) {
var res []NowPlayingInfo var res []NowPlayingInfo
for _, playerId := range s.playMap.GetKeys() { for _, playerId := range p.playMap.GetKeys() {
value, err := s.playMap.Get(playerId) value, err := p.playMap.Get(playerId)
if err != nil { if err != nil {
continue continue
} }
@ -97,18 +105,56 @@ func (s *broker) GetNowPlaying(ctx context.Context) ([]NowPlayingInfo, error) {
return res, nil return res, nil
} }
func (s *broker) Submit(ctx context.Context, trackId string, playTime time.Time) error { func (p *playTracker) Submit(ctx context.Context, submissions []Submission) error {
u, _ := request.UserFrom(ctx) username, _ := request.UsernameFrom(ctx)
t, err := s.ds.MediaFile(ctx).Get(trackId) event := &events.RefreshResource{}
if err != nil { success := 0
log.Error(ctx, "Error retrieving mediaFile", "id", trackId, err)
return err for _, s := range submissions {
mf, err := p.ds.MediaFile(ctx).Get(s.TrackID)
if err != nil {
log.Error("Cannot find track for scrobbling", "id", s.TrackID, "user", username, err)
continue
}
err = p.incPlay(ctx, mf, s.Timestamp)
if err != nil {
log.Error("Error updating play counts", "id", mf.ID, "track", mf.Title, "user", username, err)
} else {
success++
event.With("song", mf.ID).With("album", mf.AlbumID).With("artist", mf.AlbumArtistID)
log.Info("Scrobbled", "title", mf.Title, "artist", mf.Artist, "user", username)
_ = p.dispatchScrobble(ctx, mf, s.Timestamp)
}
} }
if success > 0 {
p.broker.SendMessage(ctx, event)
}
return nil
}
func (p *playTracker) incPlay(ctx context.Context, track *model.MediaFile, timestamp time.Time) error {
return p.ds.WithTx(func(tx model.DataStore) error {
err := p.ds.MediaFile(ctx).IncPlayCount(track.ID, timestamp)
if err != nil {
return err
}
err = p.ds.Album(ctx).IncPlayCount(track.AlbumID, timestamp)
if err != nil {
return err
}
err = p.ds.Artist(ctx).IncPlayCount(track.ArtistID, timestamp)
return err
})
}
func (p *playTracker) dispatchScrobble(ctx context.Context, t *model.MediaFile, playTime time.Time) error {
u, _ := request.UserFrom(ctx)
scrobbles := []Scrobble{{MediaFile: *t, TimeStamp: playTime}} scrobbles := []Scrobble{{MediaFile: *t, TimeStamp: playTime}}
// TODO Parallelize // TODO Parallelize
for name, constructor := range scrobblers { for name, constructor := range scrobblers {
err := func() error { err := func() error {
s := constructor(s.ds) s := constructor(p.ds)
if !s.IsAuthorized(ctx, u.ID) { if !s.IsAuthorized(ctx, u.ID) {
return nil return nil
} }

View file

@ -2,8 +2,11 @@ package scrobbler
import ( import (
"context" "context"
"errors"
"time" "time"
"github.com/navidrome/navidrome/server/events"
"github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request" "github.com/navidrome/navidrome/model/request"
"github.com/navidrome/navidrome/tests" "github.com/navidrome/navidrome/tests"
@ -11,17 +14,19 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
var _ = Describe("Broker", func() { var _ = Describe("PlayTracker", func() {
var ctx context.Context var ctx context.Context
var ds model.DataStore var ds model.DataStore
var broker Broker var broker PlayTracker
var track model.MediaFile var track model.MediaFile
var album model.Album
var artist model.Artist
var fake *fakeScrobbler var fake *fakeScrobbler
BeforeEach(func() { BeforeEach(func() {
ctx = context.Background() ctx = context.Background()
ctx = request.WithUser(ctx, model.User{ID: "u-1"}) ctx = request.WithUser(ctx, model.User{ID: "u-1"})
ds = &tests.MockDataStore{} ds = &tests.MockDataStore{}
broker = GetBroker(ds) broker = GetPlayTracker(ds, events.GetBroker())
fake = &fakeScrobbler{Authorized: true} fake = &fakeScrobbler{Authorized: true}
Register("fake", func(ds model.DataStore) Scrobbler { Register("fake", func(ds model.DataStore) Scrobbler {
return fake return fake
@ -31,13 +36,19 @@ var _ = Describe("Broker", func() {
ID: "123", ID: "123",
Title: "Track Title", Title: "Track Title",
Album: "Track Album", Album: "Track Album",
AlbumID: "al-1",
Artist: "Track Artist", Artist: "Track Artist",
ArtistID: "ar-1",
AlbumArtist: "Track AlbumArtist", AlbumArtist: "Track AlbumArtist",
TrackNumber: 1, TrackNumber: 1,
Duration: 180, Duration: 180,
MbzTrackID: "mbz-123", MbzTrackID: "mbz-123",
} }
_ = ds.MediaFile(ctx).Put(&track) _ = ds.MediaFile(ctx).Put(&track)
artist = model.Artist{ID: "ar-1"}
_ = ds.Artist(ctx).Put(&artist)
album = model.Album{ID: "al-1"}
_ = ds.Album(ctx).(*tests.MockAlbumRepo).Put(&album)
}) })
Describe("NowPlaying", func() { Describe("NowPlaying", func() {
@ -50,9 +61,11 @@ var _ = Describe("Broker", func() {
}) })
It("does not send track to agent if user has not authorized", func() { It("does not send track to agent if user has not authorized", func() {
fake.Authorized = false fake.Authorized = false
err := broker.NowPlaying(ctx, "player-1", "player-one", "123") err := broker.NowPlaying(ctx, "player-1", "player-one", "123")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(fake.NowPlayingCalled).ToNot(BeTrue()) Expect(fake.NowPlayingCalled).To(BeFalse())
}) })
}) })
@ -90,7 +103,7 @@ var _ = Describe("Broker", func() {
ctx = request.WithUser(ctx, model.User{ID: "u-1", UserName: "user-1"}) ctx = request.WithUser(ctx, model.User{ID: "u-1", UserName: "user-1"})
ts := time.Now() ts := time.Now()
err := broker.Submit(ctx, "123", ts) err := broker.Submit(ctx, []Submission{{TrackID: "123", Timestamp: ts}})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(fake.ScrobbleCalled).To(BeTrue()) Expect(fake.ScrobbleCalled).To(BeTrue())
@ -98,11 +111,38 @@ var _ = Describe("Broker", func() {
Expect(fake.Scrobbles[0].ID).To(Equal("123")) Expect(fake.Scrobbles[0].ID).To(Equal("123"))
}) })
It("increments play counts in the DB", func() {
ctx = request.WithUser(ctx, model.User{ID: "u-1", UserName: "user-1"})
ts := time.Now()
err := broker.Submit(ctx, []Submission{{TrackID: "123", Timestamp: ts}})
Expect(err).ToNot(HaveOccurred())
Expect(track.PlayCount).To(Equal(int64(1)))
Expect(album.PlayCount).To(Equal(int64(1)))
Expect(artist.PlayCount).To(Equal(int64(1)))
})
It("does not send track to agent if user has not authorized", func() { It("does not send track to agent if user has not authorized", func() {
fake.Authorized = false fake.Authorized = false
err := broker.Submit(ctx, "123", time.Now())
err := broker.Submit(ctx, []Submission{{TrackID: "123", Timestamp: time.Now()}})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(fake.ScrobbleCalled).ToNot(BeTrue()) Expect(fake.ScrobbleCalled).To(BeFalse())
})
It("increments play counts even if it cannot scrobble", func() {
fake.Error = errors.New("error")
err := broker.Submit(ctx, []Submission{{TrackID: "123", Timestamp: time.Now()}})
Expect(err).ToNot(HaveOccurred())
Expect(fake.ScrobbleCalled).To(BeFalse())
Expect(track.PlayCount).To(Equal(int64(1)))
Expect(album.PlayCount).To(Equal(int64(1)))
Expect(artist.PlayCount).To(Equal(int64(1)))
}) })
}) })

View file

@ -18,6 +18,6 @@ var Set = wire.NewSet(
NewPlayers, NewPlayers,
agents.New, agents.New,
transcoder.New, transcoder.New,
scrobbler.GetBroker, scrobbler.GetPlayTracker,
NewShare, NewShare,
) )

View file

@ -16,10 +16,10 @@ import (
type AlbumListController struct { type AlbumListController struct {
ds model.DataStore ds model.DataStore
scrobbler scrobbler.Broker scrobbler scrobbler.PlayTracker
} }
func NewAlbumListController(ds model.DataStore, scrobbler scrobbler.Broker) *AlbumListController { func NewAlbumListController(ds model.DataStore, scrobbler scrobbler.PlayTracker) *AlbumListController {
c := &AlbumListController{ c := &AlbumListController{
ds: ds, ds: ds,
scrobbler: scrobbler, scrobbler: scrobbler,

View file

@ -34,11 +34,11 @@ type Router struct {
ExternalMetadata core.ExternalMetadata ExternalMetadata core.ExternalMetadata
Scanner scanner.Scanner Scanner scanner.Scanner
Broker events.Broker Broker events.Broker
Scrobbler scrobbler.Broker Scrobbler scrobbler.PlayTracker
} }
func New(ds model.DataStore, artwork core.Artwork, streamer core.MediaStreamer, archiver core.Archiver, players core.Players, func New(ds model.DataStore, artwork core.Artwork, streamer core.MediaStreamer, archiver core.Archiver, players core.Players,
externalMetadata core.ExternalMetadata, scanner scanner.Scanner, broker events.Broker, scrobbler scrobbler.Broker) *Router { externalMetadata core.ExternalMetadata, scanner scanner.Scanner, broker events.Broker, scrobbler scrobbler.PlayTracker) *Router {
r := &Router{ r := &Router{
DataStore: ds, DataStore: ds,
Artwork: artwork, Artwork: artwork,

View file

@ -18,11 +18,11 @@ import (
type MediaAnnotationController struct { type MediaAnnotationController struct {
ds model.DataStore ds model.DataStore
scrobbler scrobbler.Broker scrobbler scrobbler.PlayTracker
broker events.Broker broker events.Broker
} }
func NewMediaAnnotationController(ds model.DataStore, scrobbler scrobbler.Broker, broker events.Broker) *MediaAnnotationController { func NewMediaAnnotationController(ds model.DataStore, scrobbler scrobbler.PlayTracker, broker events.Broker) *MediaAnnotationController {
return &MediaAnnotationController{ds: ds, scrobbler: scrobbler, broker: broker} return &MediaAnnotationController{ds: ds, scrobbler: scrobbler, broker: broker}
} }
@ -126,10 +126,25 @@ func (c *MediaAnnotationController) Scrobble(w http.ResponseWriter, r *http.Requ
} }
submission := utils.ParamBool(r, "submission", true) submission := utils.ParamBool(r, "submission", true)
ctx := r.Context() ctx := r.Context()
event := &events.RefreshResource{}
submissions := 0
log.Debug(r, "Scrobbling tracks", "ids", ids, "times", times, "submission", submission) if submission {
err := c.scrobblerSubmit(ctx, ids, times)
if err != nil {
log.Error(ctx, "Error registering scrobbles", "ids", ids, "times", times, err)
}
} else {
err := c.scrobblerNowPlaying(ctx, ids[0])
if err != nil {
log.Error(ctx, "Error setting NowPlaying", "id", ids[0], err)
}
}
return newResponse(), nil
}
func (c *MediaAnnotationController) scrobblerSubmit(ctx context.Context, ids []string, times []time.Time) error {
var submissions []scrobbler.Submission
log.Debug(ctx, "Scrobbling tracks", "ids", ids, "times", times)
for i, id := range ids { for i, id := range ids {
var t time.Time var t time.Time
if len(times) > 0 { if len(times) > 0 {
@ -137,57 +152,10 @@ func (c *MediaAnnotationController) Scrobble(w http.ResponseWriter, r *http.Requ
} else { } else {
t = time.Now() t = time.Now()
} }
if submission { submissions = append(submissions, scrobbler.Submission{TrackID: id, Timestamp: t})
mf, err := c.scrobblerRegister(ctx, id, t)
if err != nil {
log.Error(r, "Error scrobbling track", "id", id, err)
continue
}
submissions++
event.With("song", mf.ID).With("album", mf.AlbumID).With("artist", mf.AlbumArtistID)
}
if !submission || len(times) == 0 {
err := c.scrobblerNowPlaying(ctx, id)
if err != nil {
log.Error(r, "Error setting current song", "id", id, err)
continue
}
}
} }
if submissions > 0 {
c.broker.SendMessage(ctx, event)
}
return newResponse(), nil
}
func (c *MediaAnnotationController) scrobblerRegister(ctx context.Context, trackId string, playTime time.Time) (*model.MediaFile, error) { return c.scrobbler.Submit(ctx, submissions)
var mf *model.MediaFile
var err error
err = c.ds.WithTx(func(tx model.DataStore) error {
mf, err = c.ds.MediaFile(ctx).Get(trackId)
if err != nil {
return err
}
err = c.ds.MediaFile(ctx).IncPlayCount(trackId, playTime)
if err != nil {
return err
}
err = c.ds.Album(ctx).IncPlayCount(mf.AlbumID, playTime)
if err != nil {
return err
}
err = c.ds.Artist(ctx).IncPlayCount(mf.ArtistID, playTime)
return err
})
username, _ := request.UsernameFrom(ctx)
if err != nil {
log.Error("Error while scrobbling", "trackId", trackId, "user", username, err)
} else {
log.Info("Scrobbled", "title", mf.Title, "artist", mf.Artist, "user", username)
}
_ = c.scrobbler.Submit(ctx, trackId, playTime)
return mf, err
} }
func (c *MediaAnnotationController) scrobblerNowPlaying(ctx context.Context, trackId string) error { func (c *MediaAnnotationController) scrobblerNowPlaying(ctx context.Context, trackId string) error {

View file

@ -25,16 +25,16 @@ func initBrowsingController(router *Router) *BrowsingController {
func initAlbumListController(router *Router) *AlbumListController { func initAlbumListController(router *Router) *AlbumListController {
dataStore := router.DataStore dataStore := router.DataStore
broker := router.Scrobbler playTracker := router.Scrobbler
albumListController := NewAlbumListController(dataStore, broker) albumListController := NewAlbumListController(dataStore, playTracker)
return albumListController return albumListController
} }
func initMediaAnnotationController(router *Router) *MediaAnnotationController { func initMediaAnnotationController(router *Router) *MediaAnnotationController {
dataStore := router.DataStore dataStore := router.DataStore
broker := router.Scrobbler playTracker := router.Scrobbler
eventsBroker := router.Broker broker := router.Broker
mediaAnnotationController := NewMediaAnnotationController(dataStore, broker, eventsBroker) mediaAnnotationController := NewMediaAnnotationController(dataStore, playTracker, broker)
return mediaAnnotationController return mediaAnnotationController
} }

View file

@ -2,17 +2,22 @@ package tests
import ( import (
"errors" "errors"
"time"
"github.com/google/uuid"
"github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model"
) )
func CreateMockAlbumRepo() *MockAlbumRepo { func CreateMockAlbumRepo() *MockAlbumRepo {
return &MockAlbumRepo{} return &MockAlbumRepo{
data: make(map[string]*model.Album),
}
} }
type MockAlbumRepo struct { type MockAlbumRepo struct {
model.AlbumRepository model.AlbumRepository
data map[string]model.Album data map[string]*model.Album
all model.Albums all model.Albums
err bool err bool
Options model.QueryOptions Options model.QueryOptions
@ -23,10 +28,10 @@ func (m *MockAlbumRepo) SetError(err bool) {
} }
func (m *MockAlbumRepo) SetData(albums model.Albums) { func (m *MockAlbumRepo) SetData(albums model.Albums) {
m.data = make(map[string]model.Album) m.data = make(map[string]*model.Album)
m.all = albums m.all = albums
for _, a := range m.all { for i, a := range m.all {
m.data[a.ID] = a m.data[a.ID] = &m.all[i]
} }
} }
@ -43,11 +48,22 @@ func (m *MockAlbumRepo) Get(id string) (*model.Album, error) {
return nil, errors.New("Error!") return nil, errors.New("Error!")
} }
if d, ok := m.data[id]; ok { if d, ok := m.data[id]; ok {
return &d, nil return d, nil
} }
return nil, model.ErrNotFound return nil, model.ErrNotFound
} }
func (m *MockAlbumRepo) Put(al *model.Album) error {
if m.err {
return errors.New("error")
}
if al.ID == "" {
al.ID = uuid.NewString()
}
m.data[al.ID] = al
return nil
}
func (m *MockAlbumRepo) GetAll(qo ...model.QueryOptions) (model.Albums, error) { func (m *MockAlbumRepo) GetAll(qo ...model.QueryOptions) (model.Albums, error) {
if len(qo) > 0 { if len(qo) > 0 {
m.Options = qo[0] m.Options = qo[0]
@ -58,6 +74,18 @@ func (m *MockAlbumRepo) GetAll(qo ...model.QueryOptions) (model.Albums, error) {
return m.all, nil return m.all, nil
} }
func (m *MockAlbumRepo) IncPlayCount(id string, timestamp time.Time) error {
if m.err {
return errors.New("error")
}
if d, ok := m.data[id]; ok {
d.PlayCount++
d.PlayDate = timestamp
return nil
}
return model.ErrNotFound
}
func (m *MockAlbumRepo) FindByArtist(artistId string) (model.Albums, error) { func (m *MockAlbumRepo) FindByArtist(artistId string) (model.Albums, error) {
if m.err { if m.err {
return nil, errors.New("Error!") return nil, errors.New("Error!")
@ -66,7 +94,7 @@ func (m *MockAlbumRepo) FindByArtist(artistId string) (model.Albums, error) {
i := 0 i := 0
for _, a := range m.data { for _, a := range m.data {
if a.AlbumArtistID == artistId { if a.AlbumArtistID == artistId {
res[i] = a res[i] = *a
i++ i++
} }
} }

View file

@ -2,17 +2,22 @@ package tests
import ( import (
"errors" "errors"
"time"
"github.com/google/uuid"
"github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model"
) )
func CreateMockArtistRepo() *MockArtistRepo { func CreateMockArtistRepo() *MockArtistRepo {
return &MockArtistRepo{} return &MockArtistRepo{
data: make(map[string]*model.Artist),
}
} }
type MockArtistRepo struct { type MockArtistRepo struct {
model.ArtistRepository model.ArtistRepository
data map[string]model.Artist data map[string]*model.Artist
err bool err bool
} }
@ -21,9 +26,9 @@ func (m *MockArtistRepo) SetError(err bool) {
} }
func (m *MockArtistRepo) SetData(artists model.Artists) { func (m *MockArtistRepo) SetData(artists model.Artists) {
m.data = make(map[string]model.Artist) m.data = make(map[string]*model.Artist)
for _, a := range artists { for i, a := range artists {
m.data[a.ID] = a m.data[a.ID] = &artists[i]
} }
} }
@ -40,9 +45,32 @@ func (m *MockArtistRepo) Get(id string) (*model.Artist, error) {
return nil, errors.New("Error!") return nil, errors.New("Error!")
} }
if d, ok := m.data[id]; ok { if d, ok := m.data[id]; ok {
return &d, nil return d, nil
} }
return nil, model.ErrNotFound return nil, model.ErrNotFound
} }
func (m *MockArtistRepo) Put(ar *model.Artist) error {
if m.err {
return errors.New("error")
}
if ar.ID == "" {
ar.ID = uuid.NewString()
}
m.data[ar.ID] = ar
return nil
}
func (m *MockArtistRepo) IncPlayCount(id string, timestamp time.Time) error {
if m.err {
return errors.New("error")
}
if d, ok := m.data[id]; ok {
d.PlayCount++
d.PlayDate = timestamp
return nil
}
return model.ErrNotFound
}
var _ model.ArtistRepository = (*MockArtistRepo)(nil) var _ model.ArtistRepository = (*MockArtistRepo)(nil)

View file

@ -2,6 +2,7 @@ package tests
import ( import (
"errors" "errors"
"time"
"github.com/google/uuid" "github.com/google/uuid"
@ -10,13 +11,13 @@ import (
func CreateMockMediaFileRepo() *MockMediaFileRepo { func CreateMockMediaFileRepo() *MockMediaFileRepo {
return &MockMediaFileRepo{ return &MockMediaFileRepo{
data: make(map[string]model.MediaFile), data: make(map[string]*model.MediaFile),
} }
} }
type MockMediaFileRepo struct { type MockMediaFileRepo struct {
model.MediaFileRepository model.MediaFileRepository
data map[string]model.MediaFile data map[string]*model.MediaFile
err bool err bool
} }
@ -25,9 +26,9 @@ func (m *MockMediaFileRepo) SetError(err bool) {
} }
func (m *MockMediaFileRepo) SetData(mfs model.MediaFiles) { func (m *MockMediaFileRepo) SetData(mfs model.MediaFiles) {
m.data = make(map[string]model.MediaFile) m.data = make(map[string]*model.MediaFile)
for _, mf := range mfs { for i, mf := range mfs {
m.data[mf.ID] = mf m.data[mf.ID] = &mfs[i]
} }
} }
@ -44,22 +45,34 @@ func (m *MockMediaFileRepo) Get(id string) (*model.MediaFile, error) {
return nil, errors.New("Error!") return nil, errors.New("Error!")
} }
if d, ok := m.data[id]; ok { if d, ok := m.data[id]; ok {
return &d, nil return d, nil
} }
return nil, model.ErrNotFound return nil, model.ErrNotFound
} }
func (m *MockMediaFileRepo) Put(mf *model.MediaFile) error { func (m *MockMediaFileRepo) Put(mf *model.MediaFile) error {
if m.err { if m.err {
return errors.New("error!") return errors.New("error")
} }
if mf.ID == "" { if mf.ID == "" {
mf.ID = uuid.NewString() mf.ID = uuid.NewString()
} }
m.data[mf.ID] = *mf m.data[mf.ID] = mf
return nil return nil
} }
func (m *MockMediaFileRepo) IncPlayCount(id string, timestamp time.Time) error {
if m.err {
return errors.New("error")
}
if d, ok := m.data[id]; ok {
d.PlayCount++
d.PlayDate = timestamp
return nil
}
return model.ErrNotFound
}
func (m *MockMediaFileRepo) FindByAlbum(artistId string) (model.MediaFiles, error) { func (m *MockMediaFileRepo) FindByAlbum(artistId string) (model.MediaFiles, error) {
if m.err { if m.err {
return nil, errors.New("Error!") return nil, errors.New("Error!")
@ -68,7 +81,7 @@ func (m *MockMediaFileRepo) FindByAlbum(artistId string) (model.MediaFiles, erro
i := 0 i := 0
for _, a := range m.data { for _, a := range m.data {
if a.AlbumID == artistId { if a.AlbumID == artistId {
res[i] = a res[i] = *a
i++ i++
} }
} }