refactor: add Context to the persistence layer

This commit is contained in:
Deluan 2020-01-27 09:41:33 -05:00
parent 3c54b776d6
commit 0d2911daf9
18 changed files with 147 additions and 140 deletions

View file

@ -32,11 +32,11 @@ type browser struct {
} }
func (b *browser) MediaFolders(ctx context.Context) (model.MediaFolders, error) { func (b *browser) MediaFolders(ctx context.Context) (model.MediaFolders, error) {
return b.ds.MediaFolder().GetAll() return b.ds.MediaFolder(ctx).GetAll()
} }
func (b *browser) Indexes(ctx context.Context, ifModifiedSince time.Time) (model.ArtistIndexes, time.Time, error) { func (b *browser) Indexes(ctx context.Context, ifModifiedSince time.Time) (model.ArtistIndexes, time.Time, error) {
l, err := b.ds.Property().DefaultGet(model.PropLastScan, "-1") l, err := b.ds.Property(ctx).DefaultGet(model.PropLastScan, "-1")
ms, _ := strconv.ParseInt(l, 10, 64) ms, _ := strconv.ParseInt(l, 10, 64)
lastModified := utils.ToTime(ms) lastModified := utils.ToTime(ms)
@ -45,7 +45,7 @@ func (b *browser) Indexes(ctx context.Context, ifModifiedSince time.Time) (model
} }
if lastModified.After(ifModifiedSince) { if lastModified.After(ifModifiedSince) {
indexes, err := b.ds.Artist().GetIndex() indexes, err := b.ds.Artist(ctx).GetIndex()
return indexes, lastModified, err return indexes, lastModified, err
} }
@ -72,7 +72,7 @@ type DirectoryInfo struct {
} }
func (b *browser) Artist(ctx context.Context, id string) (*DirectoryInfo, error) { func (b *browser) Artist(ctx context.Context, id string) (*DirectoryInfo, error) {
a, albums, err := b.retrieveArtist(id) a, albums, err := b.retrieveArtist(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -81,12 +81,12 @@ func (b *browser) Artist(ctx context.Context, id string) (*DirectoryInfo, error)
for _, al := range albums { for _, al := range albums {
albumIds = append(albumIds, al.ID) albumIds = append(albumIds, al.ID)
} }
annMap, err := b.ds.Annotation().GetMap(getUserID(ctx), model.AlbumItemType, albumIds) annMap, err := b.ds.Annotation(ctx).GetMap(getUserID(ctx), model.AlbumItemType, albumIds)
return b.buildArtistDir(a, albums, annMap), nil return b.buildArtistDir(a, albums, annMap), nil
} }
func (b *browser) Album(ctx context.Context, id string) (*DirectoryInfo, error) { func (b *browser) Album(ctx context.Context, id string) (*DirectoryInfo, error) {
al, tracks, err := b.retrieveAlbum(id) al, tracks, err := b.retrieveAlbum(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -97,11 +97,11 @@ func (b *browser) Album(ctx context.Context, id string) (*DirectoryInfo, error)
} }
userID := getUserID(ctx) userID := getUserID(ctx)
trackAnnMap, err := b.ds.Annotation().GetMap(userID, model.MediaItemType, mfIds) trackAnnMap, err := b.ds.Annotation(ctx).GetMap(userID, model.MediaItemType, mfIds)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ann, err := b.ds.Annotation().Get(userID, model.AlbumItemType, al.ID) ann, err := b.ds.Annotation(ctx).Get(userID, model.AlbumItemType, al.ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,13 +121,13 @@ func (b *browser) Directory(ctx context.Context, id string) (*DirectoryInfo, err
} }
func (b *browser) GetSong(ctx context.Context, id string) (*Entry, error) { func (b *browser) GetSong(ctx context.Context, id string) (*Entry, error) {
mf, err := b.ds.MediaFile().Get(id) mf, err := b.ds.MediaFile(ctx).Get(id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userId := getUserID(ctx) userId := getUserID(ctx)
ann, err := b.ds.Annotation().Get(userId, model.MediaItemType, id) ann, err := b.ds.Annotation(ctx).Get(userId, model.MediaItemType, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -137,7 +137,7 @@ func (b *browser) GetSong(ctx context.Context, id string) (*Entry, error) {
} }
func (b *browser) GetGenres(ctx context.Context) (model.Genres, error) { func (b *browser) GetGenres(ctx context.Context) (model.Genres, error) {
genres, err := b.ds.Genre().GetAll() genres, err := b.ds.Genre(ctx).GetAll()
for i, g := range genres { for i, g := range genres {
if strings.TrimSpace(g.Name) == "" { if strings.TrimSpace(g.Name) == "" {
genres[i].Name = "<Empty>" genres[i].Name = "<Empty>"
@ -195,7 +195,7 @@ func (b *browser) buildAlbumDir(al *model.Album, albumAnn *model.Annotation, tra
} }
func (b *browser) isArtist(ctx context.Context, id string) bool { func (b *browser) isArtist(ctx context.Context, id string) bool {
found, err := b.ds.Artist().Exists(id) found, err := b.ds.Artist(ctx).Exists(id)
if err != nil { if err != nil {
log.Debug(ctx, "Error searching for Artist", "id", id, err) log.Debug(ctx, "Error searching for Artist", "id", id, err)
return false return false
@ -204,7 +204,7 @@ func (b *browser) isArtist(ctx context.Context, id string) bool {
} }
func (b *browser) isAlbum(ctx context.Context, id string) bool { func (b *browser) isAlbum(ctx context.Context, id string) bool {
found, err := b.ds.Album().Exists(id) found, err := b.ds.Album(ctx).Exists(id)
if err != nil { if err != nil {
log.Debug(ctx, "Error searching for Album", "id", id, err) log.Debug(ctx, "Error searching for Album", "id", id, err)
return false return false
@ -212,27 +212,27 @@ func (b *browser) isAlbum(ctx context.Context, id string) bool {
return found return found
} }
func (b *browser) retrieveArtist(id string) (a *model.Artist, as model.Albums, err error) { func (b *browser) retrieveArtist(ctx context.Context, id string) (a *model.Artist, as model.Albums, err error) {
a, err = b.ds.Artist().Get(id) a, err = b.ds.Artist(ctx).Get(id)
if err != nil { if err != nil {
err = fmt.Errorf("Error reading Artist %s from DB: %v", id, err) err = fmt.Errorf("Error reading Artist %s from DB: %v", id, err)
return return
} }
if as, err = b.ds.Album().FindByArtist(id); err != nil { if as, err = b.ds.Album(ctx).FindByArtist(id); err != nil {
err = fmt.Errorf("Error reading %s's albums from DB: %v", a.Name, err) err = fmt.Errorf("Error reading %s's albums from DB: %v", a.Name, err)
} }
return return
} }
func (b *browser) retrieveAlbum(id string) (al *model.Album, mfs model.MediaFiles, err error) { func (b *browser) retrieveAlbum(ctx context.Context, id string) (al *model.Album, mfs model.MediaFiles, err error) {
al, err = b.ds.Album().Get(id) al, err = b.ds.Album(ctx).Get(id)
if err != nil { if err != nil {
err = fmt.Errorf("Error reading Album %s from DB: %v", id, err) err = fmt.Errorf("Error reading Album %s from DB: %v", id, err)
return return
} }
if mfs, err = b.ds.MediaFile().FindByAlbum(id); err != nil { if mfs, err = b.ds.MediaFile(ctx).FindByAlbum(id); err != nil {
err = fmt.Errorf("Error reading %s's tracks from DB: %v", al.Name, err) err = fmt.Errorf("Error reading %s's tracks from DB: %v", al.Name, err)
} }
return return

View file

@ -31,17 +31,17 @@ func NewCover(ds model.DataStore) Cover {
return &cover{ds} return &cover{ds}
} }
func (c *cover) getCoverPath(id string) (string, error) { func (c *cover) getCoverPath(ctx context.Context, id string) (string, error) {
switch { switch {
case strings.HasPrefix(id, "al-"): case strings.HasPrefix(id, "al-"):
id = id[3:] id = id[3:]
al, err := c.ds.Album().Get(id) al, err := c.ds.Album(ctx).Get(id)
if err != nil { if err != nil {
return "", err return "", err
} }
return al.CoverArtPath, nil return al.CoverArtPath, nil
default: default:
mf, err := c.ds.MediaFile().Get(id) mf, err := c.ds.MediaFile(ctx).Get(id)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -53,7 +53,7 @@ func (c *cover) getCoverPath(id string) (string, error) {
} }
func (c *cover) Get(ctx context.Context, id string, size int, out io.Writer) error { func (c *cover) Get(ctx context.Context, id string, size int, out io.Writer) error {
path, err := c.getCoverPath(id) path, err := c.getCoverPath(ctx, id)
if err != nil && err != model.ErrNotFound { if err != nil && err != model.ErrNotFound {
return err return err
} }

View file

@ -17,8 +17,8 @@ func TestCover(t *testing.T) {
Init(t, false) Init(t, false)
ds := &persistence.MockDataStore{} ds := &persistence.MockDataStore{}
mockMediaFileRepo := ds.MediaFile().(*persistence.MockMediaFile) mockMediaFileRepo := ds.MediaFile(nil).(*persistence.MockMediaFile)
mockAlbumRepo := ds.Album().(*persistence.MockAlbum) mockAlbumRepo := ds.Album(nil).(*persistence.MockAlbum)
cover := engine.NewCover(ds) cover := engine.NewCover(ds)
out := new(bytes.Buffer) out := new(bytes.Buffer)

View file

@ -31,7 +31,7 @@ type listGenerator struct {
} }
func (g *listGenerator) query(ctx context.Context, qo model.QueryOptions) (Entries, error) { func (g *listGenerator) query(ctx context.Context, qo model.QueryOptions) (Entries, error) {
albums, err := g.ds.Album().GetAll(qo) albums, err := g.ds.Album(ctx).GetAll(qo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -39,7 +39,7 @@ func (g *listGenerator) query(ctx context.Context, qo model.QueryOptions) (Entri
for i, al := range albums { for i, al := range albums {
albumIds[i] = al.ID albumIds[i] = al.ID
} }
annMap, err := g.ds.Annotation().GetMap(getUserID(ctx), model.AlbumItemType, albumIds) annMap, err := g.ds.Annotation(ctx).GetMap(getUserID(ctx), model.AlbumItemType, albumIds)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -47,7 +47,7 @@ func (g *listGenerator) query(ctx context.Context, qo model.QueryOptions) (Entri
} }
func (g *listGenerator) queryByAnnotation(ctx context.Context, qo model.QueryOptions) (Entries, error) { func (g *listGenerator) queryByAnnotation(ctx context.Context, qo model.QueryOptions) (Entries, error) {
annotations, err := g.ds.Annotation().GetAll(getUserID(ctx), model.AlbumItemType, qo) annotations, err := g.ds.Annotation(ctx).GetAll(getUserID(ctx), model.AlbumItemType, qo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -56,7 +56,7 @@ func (g *listGenerator) queryByAnnotation(ctx context.Context, qo model.QueryOpt
albumIds[i] = ann.ItemID albumIds[i] = ann.ItemID
} }
albumMap, err := g.ds.Album().GetMap(albumIds) albumMap, err := g.ds.Album(ctx).GetMap(albumIds)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -103,7 +103,7 @@ func (g *listGenerator) GetByArtist(ctx context.Context, offset int, size int) (
} }
func (g *listGenerator) GetRandom(ctx context.Context, offset int, size int) (Entries, error) { func (g *listGenerator) GetRandom(ctx context.Context, offset int, size int) (Entries, error) {
albums, err := g.ds.Album().GetRandom(model.QueryOptions{Max: size, Offset: offset}) albums, err := g.ds.Album(ctx).GetRandom(model.QueryOptions{Max: size, Offset: offset})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -120,7 +120,7 @@ func (g *listGenerator) getAnnotationsForAlbums(ctx context.Context, albums mode
for i, al := range albums { for i, al := range albums {
albumIds[i] = al.ID albumIds[i] = al.ID
} }
return g.ds.Annotation().GetMap(getUserID(ctx), model.AlbumItemType, albumIds) return g.ds.Annotation(ctx).GetMap(getUserID(ctx), model.AlbumItemType, albumIds)
} }
func (g *listGenerator) GetRandomSongs(ctx context.Context, size int, genre string) (Entries, error) { func (g *listGenerator) GetRandomSongs(ctx context.Context, size int, genre string) (Entries, error) {
@ -128,14 +128,14 @@ func (g *listGenerator) GetRandomSongs(ctx context.Context, size int, genre stri
if genre != "" { if genre != "" {
options.Filters = map[string]interface{}{"genre": genre} options.Filters = map[string]interface{}{"genre": genre}
} }
mediaFiles, err := g.ds.MediaFile().GetRandom(options) mediaFiles, err := g.ds.MediaFile(ctx).GetRandom(options)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r := make(Entries, len(mediaFiles)) r := make(Entries, len(mediaFiles))
for i, mf := range mediaFiles { for i, mf := range mediaFiles {
ann, err := g.ds.Annotation().Get(getUserID(ctx), model.MediaItemType, mf.ID) ann, err := g.ds.Annotation(ctx).Get(getUserID(ctx), model.MediaItemType, mf.ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -146,7 +146,7 @@ func (g *listGenerator) GetRandomSongs(ctx context.Context, size int, genre stri
func (g *listGenerator) GetStarred(ctx context.Context, offset int, size int) (Entries, error) { func (g *listGenerator) GetStarred(ctx context.Context, offset int, size int) (Entries, error) {
qo := model.QueryOptions{Offset: offset, Max: size, Sort: "starred_at", Order: "desc"} qo := model.QueryOptions{Offset: offset, Max: size, Sort: "starred_at", Order: "desc"}
albums, err := g.ds.Album().GetStarred(getUserID(ctx), qo) albums, err := g.ds.Album(ctx).GetStarred(getUserID(ctx), qo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -161,17 +161,17 @@ func (g *listGenerator) GetStarred(ctx context.Context, offset int, size int) (E
func (g *listGenerator) GetAllStarred(ctx context.Context) (artists Entries, albums Entries, mediaFiles Entries, err error) { func (g *listGenerator) GetAllStarred(ctx context.Context) (artists Entries, albums Entries, mediaFiles Entries, err error) {
options := model.QueryOptions{Sort: "starred_at", Order: "desc"} options := model.QueryOptions{Sort: "starred_at", Order: "desc"}
ars, err := g.ds.Artist().GetStarred(getUserID(ctx), options) ars, err := g.ds.Artist(ctx).GetStarred(getUserID(ctx), options)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
als, err := g.ds.Album().GetStarred(getUserID(ctx), options) als, err := g.ds.Album(ctx).GetStarred(getUserID(ctx), options)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
mfs, err := g.ds.MediaFile().GetStarred(getUserID(ctx), options) mfs, err := g.ds.MediaFile(ctx).GetStarred(getUserID(ctx), options)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -180,7 +180,7 @@ func (g *listGenerator) GetAllStarred(ctx context.Context) (artists Entries, alb
for _, mf := range mfs { for _, mf := range mfs {
mfIds = append(mfIds, mf.ID) mfIds = append(mfIds, mf.ID)
} }
trackAnnMap, err := g.ds.Annotation().GetMap(getUserID(ctx), model.MediaItemType, mfIds) trackAnnMap, err := g.ds.Annotation(ctx).GetMap(getUserID(ctx), model.MediaItemType, mfIds)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -194,7 +194,7 @@ func (g *listGenerator) GetAllStarred(ctx context.Context) (artists Entries, alb
for _, ar := range ars { for _, ar := range ars {
artistIds = append(artistIds, ar.ID) artistIds = append(artistIds, ar.ID)
} }
artistAnnMap, err := g.ds.Annotation().GetMap(getUserID(ctx), model.MediaItemType, artistIds) artistAnnMap, err := g.ds.Annotation(ctx).GetMap(getUserID(ctx), model.MediaItemType, artistIds)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -213,11 +213,11 @@ func (g *listGenerator) GetNowPlaying(ctx context.Context) (Entries, error) {
} }
entries := make(Entries, len(npInfo)) entries := make(Entries, len(npInfo))
for i, np := range npInfo { for i, np := range npInfo {
mf, err := g.ds.MediaFile().Get(np.TrackID) mf, err := g.ds.MediaFile(ctx).Get(np.TrackID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ann, err := g.ds.Annotation().Get(getUserID(ctx), model.MediaItemType, mf.ID) ann, err := g.ds.Annotation(ctx).Get(getUserID(ctx), model.MediaItemType, mf.ID)
entries[i] = FromMediaFile(mf, ann) entries[i] = FromMediaFile(mf, ann)
entries[i].UserName = np.Username entries[i].UserName = np.Username
entries[i].MinutesAgo = int(time.Now().Sub(np.Start).Minutes()) entries[i].MinutesAgo = int(time.Now().Sub(np.Start).Minutes())

View file

@ -30,7 +30,7 @@ func (p *playlists) Create(ctx context.Context, playlistId, name string, ids []s
var err error var err error
// If playlistID is present, override tracks // If playlistID is present, override tracks
if playlistId != "" { if playlistId != "" {
pls, err = p.ds.Playlist().Get(playlistId) pls, err = p.ds.Playlist(ctx).Get(playlistId)
if err != nil { if err != nil {
return err return err
} }
@ -48,7 +48,7 @@ func (p *playlists) Create(ctx context.Context, playlistId, name string, ids []s
pls.Tracks = append(pls.Tracks, model.MediaFile{ID: id}) pls.Tracks = append(pls.Tracks, model.MediaFile{ID: id})
} }
return p.ds.Playlist().Put(pls) return p.ds.Playlist(ctx).Put(pls)
} }
func (p *playlists) getUser(ctx context.Context) string { func (p *playlists) getUser(ctx context.Context) string {
@ -61,7 +61,7 @@ func (p *playlists) getUser(ctx context.Context) string {
} }
func (p *playlists) Delete(ctx context.Context, playlistId string) error { func (p *playlists) Delete(ctx context.Context, playlistId string) error {
pls, err := p.ds.Playlist().Get(playlistId) pls, err := p.ds.Playlist(ctx).Get(playlistId)
if err != nil { if err != nil {
return err return err
} }
@ -70,11 +70,11 @@ func (p *playlists) Delete(ctx context.Context, playlistId string) error {
if owner != pls.Owner { if owner != pls.Owner {
return model.ErrNotAuthorized return model.ErrNotAuthorized
} }
return p.ds.Playlist().Delete(playlistId) return p.ds.Playlist(nil).Delete(playlistId)
} }
func (p *playlists) Update(ctx context.Context, playlistId string, name *string, idsToAdd []string, idxToRemove []int) error { func (p *playlists) Update(ctx context.Context, playlistId string, name *string, idsToAdd []string, idxToRemove []int) error {
pls, err := p.ds.Playlist().Get(playlistId) pls, err := p.ds.Playlist(ctx).Get(playlistId)
owner := p.getUser(ctx) owner := p.getUser(ctx)
if owner != pls.Owner { if owner != pls.Owner {
@ -100,11 +100,11 @@ func (p *playlists) Update(ctx context.Context, playlistId string, name *string,
} }
pls.Tracks = newTracks pls.Tracks = newTracks
return p.ds.Playlist().Put(pls) return p.ds.Playlist(ctx).Put(pls)
} }
func (p *playlists) GetAll(ctx context.Context) (model.Playlists, error) { func (p *playlists) GetAll(ctx context.Context) (model.Playlists, error) {
return p.ds.Playlist().GetAll(model.QueryOptions{}) return p.ds.Playlist(ctx).GetAll(model.QueryOptions{})
} }
type PlaylistInfo struct { type PlaylistInfo struct {
@ -119,7 +119,7 @@ type PlaylistInfo struct {
} }
func (p *playlists) Get(ctx context.Context, id string) (*PlaylistInfo, error) { func (p *playlists) Get(ctx context.Context, id string) (*PlaylistInfo, error) {
pl, err := p.ds.Playlist().GetWithTracks(id) pl, err := p.ds.Playlist(ctx).GetWithTracks(id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -141,7 +141,7 @@ func (p *playlists) Get(ctx context.Context, id string) (*PlaylistInfo, error) {
mfIds = append(mfIds, mf.ID) mfIds = append(mfIds, mf.ID)
} }
annMap, err := p.ds.Annotation().GetMap(getUserID(ctx), model.MediaItemType, mfIds) annMap, err := p.ds.Annotation(ctx).GetMap(getUserID(ctx), model.MediaItemType, mfIds)
for i, mf := range pl.Tracks { for i, mf := range pl.Tracks {
ann := annMap[mf.ID] ann := annMap[mf.ID]

View file

@ -21,14 +21,14 @@ type ratings struct {
} }
func (r ratings) SetRating(ctx context.Context, id string, rating int) error { func (r ratings) SetRating(ctx context.Context, id string, rating int) error {
exist, err := r.ds.Album().Exists(id) exist, err := r.ds.Album(ctx).Exists(id)
if err != nil { if err != nil {
return err return err
} }
if exist { if exist {
return r.ds.Annotation().SetRating(rating, getUserID(ctx), model.AlbumItemType, id) return r.ds.Annotation(ctx).SetRating(rating, getUserID(ctx), model.AlbumItemType, id)
} }
return r.ds.Annotation().SetRating(rating, getUserID(ctx), model.MediaItemType, id) return r.ds.Annotation(ctx).SetRating(rating, getUserID(ctx), model.MediaItemType, id)
} }
func (r ratings) SetStar(ctx context.Context, star bool, ids ...string) error { func (r ratings) SetStar(ctx context.Context, star bool, ids ...string) error {
@ -40,29 +40,29 @@ func (r ratings) SetStar(ctx context.Context, star bool, ids ...string) error {
return r.ds.WithTx(func(tx model.DataStore) error { return r.ds.WithTx(func(tx model.DataStore) error {
for _, id := range ids { for _, id := range ids {
exist, err := r.ds.Album().Exists(id) exist, err := r.ds.Album(ctx).Exists(id)
if err != nil { if err != nil {
return err return err
} }
if exist { if exist {
err = tx.Annotation().SetStar(star, userId, model.AlbumItemType, ids...) err = tx.Annotation(ctx).SetStar(star, userId, model.AlbumItemType, ids...)
if err != nil { if err != nil {
return err return err
} }
continue continue
} }
exist, err = r.ds.Artist().Exists(id) exist, err = r.ds.Artist(ctx).Exists(id)
if err != nil { if err != nil {
return err return err
} }
if exist { if exist {
err = tx.Annotation().SetStar(star, userId, model.ArtistItemType, ids...) err = tx.Annotation(ctx).SetStar(star, userId, model.ArtistItemType, ids...)
if err != nil { if err != nil {
return err return err
} }
continue continue
} }
err = tx.Annotation().SetStar(star, userId, model.MediaItemType, ids...) err = tx.Annotation(ctx).SetStar(star, userId, model.MediaItemType, ids...)
if err != nil { if err != nil {
return err return err
} }

View file

@ -29,15 +29,15 @@ func (s *scrobbler) Register(ctx context.Context, playerId int, trackId string,
var mf *model.MediaFile var mf *model.MediaFile
var err error var err error
err = s.ds.WithTx(func(tx model.DataStore) error { err = s.ds.WithTx(func(tx model.DataStore) error {
mf, err = s.ds.MediaFile().Get(trackId) mf, err = s.ds.MediaFile(ctx).Get(trackId)
if err != nil { if err != nil {
return err return err
} }
err = s.ds.Annotation().IncPlayCount(userId, model.MediaItemType, trackId, playTime) err = s.ds.Annotation(ctx).IncPlayCount(userId, model.MediaItemType, trackId, playTime)
if err != nil { if err != nil {
return err return err
} }
err = s.ds.Annotation().IncPlayCount(userId, model.AlbumItemType, mf.AlbumID, playTime) err = s.ds.Annotation(ctx).IncPlayCount(userId, model.AlbumItemType, mf.AlbumID, playTime)
return err return err
}) })
return mf, err return mf, err
@ -45,7 +45,7 @@ func (s *scrobbler) Register(ctx context.Context, playerId int, trackId string,
// TODO Validate if NowPlaying still works after all refactorings // TODO Validate if NowPlaying still works after all refactorings
func (s *scrobbler) NowPlaying(ctx context.Context, playerId int, playerName, trackId, username string) (*model.MediaFile, error) { func (s *scrobbler) NowPlaying(ctx context.Context, playerId int, playerName, trackId, username string) (*model.MediaFile, error) {
mf, err := s.ds.MediaFile().Get(trackId) mf, err := s.ds.MediaFile(ctx).Get(trackId)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -25,7 +25,7 @@ func NewSearch(ds model.DataStore) Search {
func (s *search) SearchArtist(ctx context.Context, q string, offset int, size int) (Entries, error) { func (s *search) SearchArtist(ctx context.Context, q string, offset int, size int) (Entries, error) {
q = sanitize.Accents(strings.ToLower(strings.TrimSuffix(q, "*"))) q = sanitize.Accents(strings.ToLower(strings.TrimSuffix(q, "*")))
artists, err := s.ds.Artist().Search(q, offset, size) artists, err := s.ds.Artist(ctx).Search(q, offset, size)
if len(artists) == 0 || err != nil { if len(artists) == 0 || err != nil {
return nil, nil return nil, nil
} }
@ -34,7 +34,7 @@ func (s *search) SearchArtist(ctx context.Context, q string, offset int, size in
for i, al := range artists { for i, al := range artists {
artistIds[i] = al.ID artistIds[i] = al.ID
} }
annMap, err := s.ds.Annotation().GetMap(getUserID(ctx), model.ArtistItemType, artistIds) annMap, err := s.ds.Annotation(ctx).GetMap(getUserID(ctx), model.ArtistItemType, artistIds)
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
@ -44,7 +44,7 @@ func (s *search) SearchArtist(ctx context.Context, q string, offset int, size in
func (s *search) SearchAlbum(ctx context.Context, q string, offset int, size int) (Entries, error) { func (s *search) SearchAlbum(ctx context.Context, q string, offset int, size int) (Entries, error) {
q = sanitize.Accents(strings.ToLower(strings.TrimSuffix(q, "*"))) q = sanitize.Accents(strings.ToLower(strings.TrimSuffix(q, "*")))
albums, err := s.ds.Album().Search(q, offset, size) albums, err := s.ds.Album(ctx).Search(q, offset, size)
if len(albums) == 0 || err != nil { if len(albums) == 0 || err != nil {
return nil, nil return nil, nil
} }
@ -53,7 +53,7 @@ func (s *search) SearchAlbum(ctx context.Context, q string, offset int, size int
for i, al := range albums { for i, al := range albums {
albumIds[i] = al.ID albumIds[i] = al.ID
} }
annMap, err := s.ds.Annotation().GetMap(getUserID(ctx), model.AlbumItemType, albumIds) annMap, err := s.ds.Annotation(ctx).GetMap(getUserID(ctx), model.AlbumItemType, albumIds)
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
@ -63,7 +63,7 @@ func (s *search) SearchAlbum(ctx context.Context, q string, offset int, size int
func (s *search) SearchSong(ctx context.Context, q string, offset int, size int) (Entries, error) { func (s *search) SearchSong(ctx context.Context, q string, offset int, size int) (Entries, error) {
q = sanitize.Accents(strings.ToLower(strings.TrimSuffix(q, "*"))) q = sanitize.Accents(strings.ToLower(strings.TrimSuffix(q, "*")))
mediaFiles, err := s.ds.MediaFile().Search(q, offset, size) mediaFiles, err := s.ds.MediaFile(ctx).Search(q, offset, size)
if len(mediaFiles) == 0 || err != nil { if len(mediaFiles) == 0 || err != nil {
return nil, nil return nil, nil
} }
@ -72,7 +72,7 @@ func (s *search) SearchSong(ctx context.Context, q string, offset int, size int)
for i, mf := range mediaFiles { for i, mf := range mediaFiles {
trackIds[i] = mf.ID trackIds[i] = mf.ID
} }
annMap, err := s.ds.Annotation().GetMap(getUserID(ctx), model.MediaItemType, trackIds) annMap, err := s.ds.Annotation(ctx).GetMap(getUserID(ctx), model.MediaItemType, trackIds)
if err != nil { if err != nil {
return nil, nil return nil, nil
} }

View file

@ -24,7 +24,7 @@ type users struct {
} }
func (u *users) Authenticate(ctx context.Context, username, pass, token, salt string) (*model.User, error) { func (u *users) Authenticate(ctx context.Context, username, pass, token, salt string) (*model.User, error) {
user, err := u.ds.User().FindByUsername(username) user, err := u.ds.User(ctx).FindByUsername(username)
if err == model.ErrNotFound { if err == model.ErrNotFound {
return nil, model.ErrInvalidAuth return nil, model.ErrInvalidAuth
} }
@ -50,7 +50,7 @@ func (u *users) Authenticate(ctx context.Context, username, pass, token, salt st
return nil, model.ErrInvalidAuth return nil, model.ErrInvalidAuth
} }
go func() { go func() {
err := u.ds.User().UpdateLastAccessAt(user.ID) err := u.ds.User(ctx).UpdateLastAccessAt(user.ID)
if err != nil { if err != nil {
log.Error(ctx, "Could not update user's lastAccessAt", "user", user.UserName) log.Error(ctx, "Could not update user's lastAccessAt", "user", user.UserName)
} }

View file

@ -1,6 +1,8 @@
package model package model
import ( import (
"context"
"github.com/deluan/rest" "github.com/deluan/rest"
) )
@ -22,17 +24,17 @@ type ResourceRepository interface {
} }
type DataStore interface { type DataStore interface {
Album() AlbumRepository Album(ctx context.Context) AlbumRepository
Artist() ArtistRepository Artist(ctx context.Context) ArtistRepository
MediaFile() MediaFileRepository MediaFile(ctx context.Context) MediaFileRepository
MediaFolder() MediaFolderRepository MediaFolder(ctx context.Context) MediaFolderRepository
Genre() GenreRepository Genre(ctx context.Context) GenreRepository
Playlist() PlaylistRepository Playlist(ctx context.Context) PlaylistRepository
Property() PropertyRepository Property(ctx context.Context) PropertyRepository
User() UserRepository User(ctx context.Context) UserRepository
Annotation() AnnotationRepository Annotation(ctx context.Context) AnnotationRepository
Resource(model interface{}) ResourceRepository Resource(ctx context.Context, model interface{}) ResourceRepository
WithTx(func(tx DataStore) error) error WithTx(func(tx DataStore) error) error
} }

View file

@ -1,6 +1,10 @@
package persistence package persistence
import "github.com/deluan/navidrome/model" import (
"context"
"github.com/deluan/navidrome/model"
)
type MockDataStore struct { type MockDataStore struct {
MockedGenre model.GenreRepository MockedGenre model.GenreRepository
@ -10,54 +14,54 @@ type MockDataStore struct {
MockedUser model.UserRepository MockedUser model.UserRepository
} }
func (db *MockDataStore) Album() model.AlbumRepository { func (db *MockDataStore) Album(context.Context) model.AlbumRepository {
if db.MockedAlbum == nil { if db.MockedAlbum == nil {
db.MockedAlbum = CreateMockAlbumRepo() db.MockedAlbum = CreateMockAlbumRepo()
} }
return db.MockedAlbum return db.MockedAlbum
} }
func (db *MockDataStore) Artist() model.ArtistRepository { func (db *MockDataStore) Artist(context.Context) model.ArtistRepository {
if db.MockedArtist == nil { if db.MockedArtist == nil {
db.MockedArtist = CreateMockArtistRepo() db.MockedArtist = CreateMockArtistRepo()
} }
return db.MockedArtist return db.MockedArtist
} }
func (db *MockDataStore) MediaFile() model.MediaFileRepository { func (db *MockDataStore) MediaFile(context.Context) model.MediaFileRepository {
if db.MockedMediaFile == nil { if db.MockedMediaFile == nil {
db.MockedMediaFile = CreateMockMediaFileRepo() db.MockedMediaFile = CreateMockMediaFileRepo()
} }
return db.MockedMediaFile return db.MockedMediaFile
} }
func (db *MockDataStore) MediaFolder() model.MediaFolderRepository { func (db *MockDataStore) MediaFolder(context.Context) model.MediaFolderRepository {
return struct{ model.MediaFolderRepository }{} return struct{ model.MediaFolderRepository }{}
} }
func (db *MockDataStore) Genre() model.GenreRepository { func (db *MockDataStore) Genre(context.Context) model.GenreRepository {
if db.MockedGenre != nil { if db.MockedGenre != nil {
return db.MockedGenre return db.MockedGenre
} }
return struct{ model.GenreRepository }{} return struct{ model.GenreRepository }{}
} }
func (db *MockDataStore) Playlist() model.PlaylistRepository { func (db *MockDataStore) Playlist(context.Context) model.PlaylistRepository {
return struct{ model.PlaylistRepository }{} return struct{ model.PlaylistRepository }{}
} }
func (db *MockDataStore) Property() model.PropertyRepository { func (db *MockDataStore) Property(context.Context) model.PropertyRepository {
return struct{ model.PropertyRepository }{} return struct{ model.PropertyRepository }{}
} }
func (db *MockDataStore) User() model.UserRepository { func (db *MockDataStore) User(context.Context) model.UserRepository {
if db.MockedUser == nil { if db.MockedUser == nil {
db.MockedUser = &mockedUserRepo{} db.MockedUser = &mockedUserRepo{}
} }
return db.MockedUser return db.MockedUser
} }
func (db *MockDataStore) Annotation() model.AnnotationRepository { func (db *MockDataStore) Annotation(context.Context) model.AnnotationRepository {
return struct{ model.AnnotationRepository }{} return struct{ model.AnnotationRepository }{}
} }
@ -65,7 +69,7 @@ func (db *MockDataStore) WithTx(block func(db model.DataStore) error) error {
return block(db) return block(db)
} }
func (db *MockDataStore) Resource(m interface{}) model.ResourceRepository { func (db *MockDataStore) Resource(ctx context.Context, m interface{}) model.ResourceRepository {
return struct{ model.ResourceRepository }{} return struct{ model.ResourceRepository }{}
} }

View file

@ -1,6 +1,7 @@
package persistence package persistence
import ( import (
"context"
"reflect" "reflect"
"strings" "strings"
"sync" "sync"
@ -41,43 +42,43 @@ func New() model.DataStore {
return &SQLStore{} return &SQLStore{}
} }
func (db *SQLStore) Album() model.AlbumRepository { func (db *SQLStore) Album(context.Context) model.AlbumRepository {
return NewAlbumRepository(db.getOrmer()) return NewAlbumRepository(db.getOrmer())
} }
func (db *SQLStore) Artist() model.ArtistRepository { func (db *SQLStore) Artist(context.Context) model.ArtistRepository {
return NewArtistRepository(db.getOrmer()) return NewArtistRepository(db.getOrmer())
} }
func (db *SQLStore) MediaFile() model.MediaFileRepository { func (db *SQLStore) MediaFile(context.Context) model.MediaFileRepository {
return NewMediaFileRepository(db.getOrmer()) return NewMediaFileRepository(db.getOrmer())
} }
func (db *SQLStore) MediaFolder() model.MediaFolderRepository { func (db *SQLStore) MediaFolder(context.Context) model.MediaFolderRepository {
return NewMediaFolderRepository(db.getOrmer()) return NewMediaFolderRepository(db.getOrmer())
} }
func (db *SQLStore) Genre() model.GenreRepository { func (db *SQLStore) Genre(context.Context) model.GenreRepository {
return NewGenreRepository(db.getOrmer()) return NewGenreRepository(db.getOrmer())
} }
func (db *SQLStore) Playlist() model.PlaylistRepository { func (db *SQLStore) Playlist(context.Context) model.PlaylistRepository {
return NewPlaylistRepository(db.getOrmer()) return NewPlaylistRepository(db.getOrmer())
} }
func (db *SQLStore) Property() model.PropertyRepository { func (db *SQLStore) Property(context.Context) model.PropertyRepository {
return NewPropertyRepository(db.getOrmer()) return NewPropertyRepository(db.getOrmer())
} }
func (db *SQLStore) User() model.UserRepository { func (db *SQLStore) User(context.Context) model.UserRepository {
return NewUserRepository(db.getOrmer()) return NewUserRepository(db.getOrmer())
} }
func (db *SQLStore) Annotation() model.AnnotationRepository { func (db *SQLStore) Annotation(context.Context) model.AnnotationRepository {
return NewAnnotationRepository(db.getOrmer()) return NewAnnotationRepository(db.getOrmer())
} }
func (db *SQLStore) Resource(model interface{}) model.ResourceRepository { func (db *SQLStore) Resource(ctx context.Context, model interface{}) model.ResourceRepository {
return NewResource(db.getOrmer(), model, getMappedModel(model)) return NewResource(db.getOrmer(), model, getMappedModel(model))
} }

View file

@ -63,21 +63,21 @@ var _ = Describe("Initialize test DB", func() {
BeforeSuite(func() { BeforeSuite(func() {
conf.Server.DbPath = ":memory:" conf.Server.DbPath = ":memory:"
ds := New() ds := New()
artistRepo := ds.Artist() artistRepo := ds.Artist(nil)
for _, a := range testArtists { for _, a := range testArtists {
err := artistRepo.Put(&a) err := artistRepo.Put(&a)
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
albumRepository := ds.Album() albumRepository := ds.Album(nil)
for _, a := range testAlbums { for _, a := range testAlbums {
err := albumRepository.Put(&a) err := albumRepository.Put(&a)
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
mediaFileRepository := ds.MediaFile() mediaFileRepository := ds.MediaFile(nil)
for _, s := range testSongs { for _, s := range testSongs {
err := mediaFileRepository.Put(&s) err := mediaFileRepository.Put(&s)
if err != nil { if err != nil {

View file

@ -60,7 +60,7 @@ func (s *Scanner) RescanAll(fullRescan bool) error {
func (s *Scanner) Status() []StatusInfo { return nil } func (s *Scanner) Status() []StatusInfo { return nil }
func (s *Scanner) getLastModifiedSince(folder string) time.Time { func (s *Scanner) getLastModifiedSince(folder string) time.Time {
ms, err := s.ds.Property().Get(model.PropLastScan + "-" + folder) ms, err := s.ds.Property(nil).Get(model.PropLastScan + "-" + folder)
if err != nil { if err != nil {
return time.Time{} return time.Time{}
} }
@ -73,11 +73,11 @@ func (s *Scanner) getLastModifiedSince(folder string) time.Time {
func (s *Scanner) updateLastModifiedSince(folder string, t time.Time) { func (s *Scanner) updateLastModifiedSince(folder string, t time.Time) {
millis := t.UnixNano() / int64(time.Millisecond) millis := t.UnixNano() / int64(time.Millisecond)
s.ds.Property().Put(model.PropLastScan+"-"+folder, fmt.Sprint(millis)) s.ds.Property(nil).Put(model.PropLastScan+"-"+folder, fmt.Sprint(millis))
} }
func (s *Scanner) loadFolders() { func (s *Scanner) loadFolders() {
fs, _ := s.ds.MediaFolder().GetAll() fs, _ := s.ds.MediaFolder(nil).GetAll()
for _, f := range fs { for _, f := range fs {
log.Info("Configuring Media Folder", "name", f.Name, "path", f.Path) log.Info("Configuring Media Folder", "name", f.Name, "path", f.Path)
s.folders[f.Path] = NewTagScanner(f.Path, s.ds) s.folders[f.Path] = NewTagScanner(f.Path, s.ds)

View file

@ -58,16 +58,16 @@ func (s *TagScanner) Scan(ctx context.Context, lastModifiedSince time.Time) erro
updatedAlbums := map[string]bool{} updatedAlbums := map[string]bool{}
for _, c := range changed { for _, c := range changed {
err := s.processChangedDir(c, updatedArtists, updatedAlbums) err := s.processChangedDir(ctx, c, updatedArtists, updatedAlbums)
if err != nil { if err != nil {
return err return err
} }
if len(updatedAlbums)+len(updatedArtists) > 100 { if len(updatedAlbums)+len(updatedArtists) > 100 {
err = s.refreshAlbums(updatedAlbums) err = s.refreshAlbums(ctx, updatedAlbums)
if err != nil { if err != nil {
return err return err
} }
err = s.refreshArtists(updatedArtists) err = s.refreshArtists(ctx, updatedArtists)
if err != nil { if err != nil {
return err return err
} }
@ -76,16 +76,16 @@ func (s *TagScanner) Scan(ctx context.Context, lastModifiedSince time.Time) erro
} }
} }
for _, c := range deleted { for _, c := range deleted {
err := s.processDeletedDir(c, updatedArtists, updatedAlbums) err := s.processDeletedDir(ctx, c, updatedArtists, updatedAlbums)
if err != nil { if err != nil {
return err return err
} }
if len(updatedAlbums)+len(updatedArtists) > 100 { if len(updatedAlbums)+len(updatedArtists) > 100 {
err = s.refreshAlbums(updatedAlbums) err = s.refreshAlbums(ctx, updatedAlbums)
if err != nil { if err != nil {
return err return err
} }
err = s.refreshArtists(updatedArtists) err = s.refreshArtists(ctx, updatedArtists)
if err != nil { if err != nil {
return err return err
} }
@ -94,22 +94,22 @@ func (s *TagScanner) Scan(ctx context.Context, lastModifiedSince time.Time) erro
} }
} }
err = s.refreshAlbums(updatedAlbums) err = s.refreshAlbums(ctx, updatedAlbums)
if err != nil { if err != nil {
return err return err
} }
err = s.refreshArtists(updatedArtists) err = s.refreshArtists(ctx, updatedArtists)
if err != nil { if err != nil {
return err return err
} }
err = s.ds.Album().PurgeEmpty() err = s.ds.Album(ctx).PurgeEmpty()
if err != nil { if err != nil {
return err return err
} }
err = s.ds.Artist().PurgeEmpty() err = s.ds.Artist(ctx).PurgeEmpty()
if err != nil { if err != nil {
return err return err
} }
@ -117,30 +117,30 @@ func (s *TagScanner) Scan(ctx context.Context, lastModifiedSince time.Time) erro
return nil return nil
} }
func (s *TagScanner) refreshAlbums(updatedAlbums map[string]bool) error { func (s *TagScanner) refreshAlbums(ctx context.Context, updatedAlbums map[string]bool) error {
var ids []string var ids []string
for id := range updatedAlbums { for id := range updatedAlbums {
ids = append(ids, id) ids = append(ids, id)
} }
return s.ds.Album().Refresh(ids...) return s.ds.Album(ctx).Refresh(ids...)
} }
func (s *TagScanner) refreshArtists(updatedArtists map[string]bool) error { func (s *TagScanner) refreshArtists(ctx context.Context, updatedArtists map[string]bool) error {
var ids []string var ids []string
for id := range updatedArtists { for id := range updatedArtists {
ids = append(ids, id) ids = append(ids, id)
} }
return s.ds.Artist().Refresh(ids...) return s.ds.Artist(ctx).Refresh(ids...)
} }
func (s *TagScanner) processChangedDir(dir string, updatedArtists map[string]bool, updatedAlbums map[string]bool) error { func (s *TagScanner) processChangedDir(ctx context.Context, dir string, updatedArtists map[string]bool, updatedAlbums map[string]bool) error {
dir = path.Join(s.rootFolder, dir) dir = path.Join(s.rootFolder, dir)
start := time.Now() start := time.Now()
// Load folder's current tracks from DB into a map // Load folder's current tracks from DB into a map
currentTracks := map[string]model.MediaFile{} currentTracks := map[string]model.MediaFile{}
ct, err := s.ds.MediaFile().FindByPath(dir) ct, err := s.ds.MediaFile(ctx).FindByPath(dir)
if err != nil { if err != nil {
return err return err
} }
@ -168,7 +168,7 @@ func (s *TagScanner) processChangedDir(dir string, updatedArtists map[string]boo
for _, n := range newTracks { for _, n := range newTracks {
c, ok := currentTracks[n.ID] c, ok := currentTracks[n.ID]
if !ok || (ok && n.UpdatedAt.After(c.UpdatedAt)) { if !ok || (ok && n.UpdatedAt.After(c.UpdatedAt)) {
err := s.ds.MediaFile().Put(&n) err := s.ds.MediaFile(ctx).Put(&n)
updatedArtists[n.ArtistID] = true updatedArtists[n.ArtistID] = true
updatedAlbums[n.AlbumID] = true updatedAlbums[n.AlbumID] = true
numUpdatedTracks++ numUpdatedTracks++
@ -182,7 +182,7 @@ func (s *TagScanner) processChangedDir(dir string, updatedArtists map[string]boo
// Remaining tracks from DB that are not in the folder are deleted // Remaining tracks from DB that are not in the folder are deleted
for id := range currentTracks { for id := range currentTracks {
numPurgedTracks++ numPurgedTracks++
if err := s.ds.MediaFile().Delete(id); err != nil { if err := s.ds.MediaFile(ctx).Delete(id); err != nil {
return err return err
} }
} }
@ -191,10 +191,10 @@ func (s *TagScanner) processChangedDir(dir string, updatedArtists map[string]boo
return nil return nil
} }
func (s *TagScanner) processDeletedDir(dir string, updatedArtists map[string]bool, updatedAlbums map[string]bool) error { func (s *TagScanner) processDeletedDir(ctx context.Context, dir string, updatedArtists map[string]bool, updatedAlbums map[string]bool) error {
dir = path.Join(s.rootFolder, dir) dir = path.Join(s.rootFolder, dir)
ct, err := s.ds.MediaFile().FindByPath(dir) ct, err := s.ds.MediaFile(ctx).FindByPath(dir)
if err != nil { if err != nil {
return err return err
} }
@ -203,7 +203,7 @@ func (s *TagScanner) processDeletedDir(dir string, updatedArtists map[string]boo
updatedAlbums[t.AlbumID] = true updatedAlbums[t.AlbumID] = true
} }
return s.ds.MediaFile().DeleteByPath(dir) return s.ds.MediaFile(ctx).DeleteByPath(dir)
} }
func (s *TagScanner) loadTracks(dirPath string) (model.MediaFiles, error) { func (s *TagScanner) loadTracks(dirPath string) (model.MediaFiles, error) {

View file

@ -64,7 +64,7 @@ func (app *Router) routes() http.Handler {
func (app *Router) R(r chi.Router, pathPrefix string, model interface{}) { func (app *Router) R(r chi.Router, pathPrefix string, model interface{}) {
constructor := func(ctx context.Context) rest.Repository { constructor := func(ctx context.Context) rest.Repository {
return app.ds.Resource(model) return app.ds.Resource(ctx, model)
} }
r.Route(pathPrefix, func(r chi.Router) { r.Route(pathPrefix, func(r chi.Router) {
r.Get("/", rest.GetAll(constructor)) r.Get("/", rest.GetAll(constructor))

View file

@ -41,7 +41,7 @@ func Login(ds model.DataStore) func(w http.ResponseWriter, r *http.Request) {
} }
func handleLogin(ds model.DataStore, username string, password string, w http.ResponseWriter, r *http.Request) { func handleLogin(ds model.DataStore, username string, password string, w http.ResponseWriter, r *http.Request) {
user, err := validateLogin(ds.User(), username, password) user, err := validateLogin(ds.User(r.Context()), username, password)
if err != nil { if err != nil {
rest.RespondWithError(w, http.StatusInternalServerError, "Unknown error authentication user. Please try again") rest.RespondWithError(w, http.StatusInternalServerError, "Unknown error authentication user. Please try again")
return return
@ -89,7 +89,7 @@ func CreateAdmin(ds model.DataStore) func(w http.ResponseWriter, r *http.Request
rest.RespondWithError(w, http.StatusUnprocessableEntity, err.Error()) rest.RespondWithError(w, http.StatusUnprocessableEntity, err.Error())
return return
} }
c, err := ds.User().CountAll() c, err := ds.User(r.Context()).CountAll()
if err != nil { if err != nil {
rest.RespondWithError(w, http.StatusInternalServerError, err.Error()) rest.RespondWithError(w, http.StatusInternalServerError, err.Error())
return return
@ -98,7 +98,7 @@ func CreateAdmin(ds model.DataStore) func(w http.ResponseWriter, r *http.Request
rest.RespondWithError(w, http.StatusForbidden, "Cannot create another first admin") rest.RespondWithError(w, http.StatusForbidden, "Cannot create another first admin")
return return
} }
err = createDefaultUser(ds, username, password) err = createDefaultUser(r.Context(), ds, username, password)
if err != nil { if err != nil {
rest.RespondWithError(w, http.StatusInternalServerError, err.Error()) rest.RespondWithError(w, http.StatusInternalServerError, err.Error())
return return
@ -107,7 +107,7 @@ func CreateAdmin(ds model.DataStore) func(w http.ResponseWriter, r *http.Request
} }
} }
func createDefaultUser(ds model.DataStore, username, password string) error { func createDefaultUser(ctx context.Context, ds model.DataStore, username, password string) error {
id, _ := uuid.NewRandom() id, _ := uuid.NewRandom()
log.Warn("Creating initial user", "user", consts.InitialUserName) log.Warn("Creating initial user", "user", consts.InitialUserName)
initialUser := model.User{ initialUser := model.User{
@ -118,7 +118,7 @@ func createDefaultUser(ds model.DataStore, username, password string) error {
Password: password, Password: password,
IsAdmin: true, IsAdmin: true,
} }
err := ds.User().Put(&initialUser) err := ds.User(ctx).Put(&initialUser)
if err != nil { if err != nil {
log.Error("Could not create initial user", "user", initialUser, err) log.Error("Could not create initial user", "user", initialUser, err)
} }
@ -127,7 +127,7 @@ func createDefaultUser(ds model.DataStore, username, password string) error {
func initTokenAuth(ds model.DataStore) { func initTokenAuth(ds model.DataStore) {
once.Do(func() { once.Do(func() {
secret, err := ds.Property().DefaultGet(consts.JWTSecretKey, "not so secret") secret, err := ds.Property(nil).DefaultGet(consts.JWTSecretKey, "not so secret")
if err != nil { if err != nil {
log.Error("No JWT secret found in DB. Setting a temp one, but please report this error", err) log.Error("No JWT secret found in DB. Setting a temp one, but please report this error", err)
} }
@ -190,7 +190,7 @@ func getToken(ds model.DataStore, ctx context.Context) (*jwt.Token, error) {
return token, nil return token, nil
} }
c, err := ds.User().CountAll() c, err := ds.User(ctx).CountAll()
firstTime := c == 0 && err == nil firstTime := c == 0 && err == nil
if firstTime { if firstTime {
return nil, ErrFirstTime return nil, ErrFirstTime

View file

@ -11,7 +11,7 @@ import (
func initialSetup(ds model.DataStore) { func initialSetup(ds model.DataStore) {
_ = ds.WithTx(func(tx model.DataStore) error { _ = ds.WithTx(func(tx model.DataStore) error {
_, err := ds.Property().Get(consts.InitialSetupFlagKey) _, err := ds.Property(nil).Get(consts.InitialSetupFlagKey)
if err == nil { if err == nil {
return nil return nil
} }
@ -20,19 +20,19 @@ func initialSetup(ds model.DataStore) {
return err return err
} }
err = ds.Property().Put(consts.InitialSetupFlagKey, time.Now().String()) err = ds.Property(nil).Put(consts.InitialSetupFlagKey, time.Now().String())
return err return err
}) })
} }
func createJWTSecret(ds model.DataStore) error { func createJWTSecret(ds model.DataStore) error {
_, err := ds.Property().Get(consts.JWTSecretKey) _, err := ds.Property(nil).Get(consts.JWTSecretKey)
if err == nil { if err == nil {
return nil return nil
} }
jwtSecret, _ := uuid.NewRandom() jwtSecret, _ := uuid.NewRandom()
log.Warn("Creating JWT secret, used for encrypting UI sessions") log.Warn("Creating JWT secret, used for encrypting UI sessions")
err = ds.Property().Put(consts.JWTSecretKey, jwtSecret.String()) err = ds.Property(nil).Put(consts.JWTSecretKey, jwtSecret.String())
if err != nil { if err != nil {
log.Error("Could not save JWT secret in DB", err) log.Error("Could not save JWT secret in DB", err)
} }