diff --git a/persistence/album_repository.go b/persistence/album_repository.go index 840ad16cc..3ef3528ae 100644 --- a/persistence/album_repository.go +++ b/persistence/album_repository.go @@ -162,7 +162,7 @@ func (r *albumRepository) Get(id string) (*model.Album, error) { return nil, model.ErrNotFound } res := dba.toModels() - err := r.loadAlbumGenres(&res) + err := loadAllGenres(r, res) return &res[0], err } @@ -171,7 +171,7 @@ func (r *albumRepository) Put(m *model.Album) error { if err != nil { return err } - return r.updateGenres(m.ID, r.tableName, m.Genres) + return r.updateGenres(m.ID, m.Genres) } func (r *albumRepository) GetAll(options ...model.QueryOptions) (model.Albums, error) { @@ -179,7 +179,7 @@ func (r *albumRepository) GetAll(options ...model.QueryOptions) (model.Albums, e if err != nil { return nil, err } - err = r.loadAlbumGenres(&res) + err = loadAllGenres(r, res) return res, err } @@ -211,7 +211,7 @@ func (r *albumRepository) Search(q string, offset int, size int) (model.Albums, return nil, err } res := dba.toModels() - err = r.loadAlbumGenres(&res) + err = loadAllGenres(r, res) return res, err } diff --git a/persistence/artist_repository.go b/persistence/artist_repository.go index 25288c3a1..74b44d2f5 100644 --- a/persistence/artist_repository.go +++ b/persistence/artist_repository.go @@ -100,9 +100,9 @@ func (r *artistRepository) Put(a *model.Artist, colsToUpdate ...string) error { return err } if a.ID == consts.VariousArtistsID { - return r.updateGenres(a.ID, r.tableName, nil) + return r.updateGenres(a.ID, nil) } - return r.updateGenres(a.ID, r.tableName, a.Genres) + return r.updateGenres(a.ID, a.Genres) } func (r *artistRepository) Get(id string) (*model.Artist, error) { @@ -115,7 +115,7 @@ func (r *artistRepository) Get(id string) (*model.Artist, error) { return nil, model.ErrNotFound } res := r.toModels(dba) - err := r.loadArtistGenres(&res) + err := loadAllGenres(r, res) return &res[0], err } @@ -127,7 +127,7 @@ func (r *artistRepository) GetAll(options ...model.QueryOptions) (model.Artists, return nil, err } res := r.toModels(dba) - err = r.loadArtistGenres(&res) + err = loadAllGenres(r, res) return res, err } diff --git a/persistence/mediafile_repository.go b/persistence/mediafile_repository.go index 3ff036569..139ea65b1 100644 --- a/persistence/mediafile_repository.go +++ b/persistence/mediafile_repository.go @@ -65,7 +65,7 @@ func (r *mediaFileRepository) Put(m *model.MediaFile) error { if err != nil { return err } - return r.updateGenres(m.ID, r.tableName, m.Genres) + return r.updateGenres(m.ID, m.Genres) } func (r *mediaFileRepository) selectMediaFile(options ...model.QueryOptions) SelectBuilder { @@ -94,7 +94,7 @@ func (r *mediaFileRepository) Get(id string) (*model.MediaFile, error) { if len(res) == 0 { return nil, model.ErrNotFound } - err := r.loadMediaFileGenres(&res) + err := loadAllGenres(r, res) return &res[0], err } @@ -105,7 +105,7 @@ func (r *mediaFileRepository) GetAll(options ...model.QueryOptions) (model.Media if err != nil { return nil, err } - err = r.loadMediaFileGenres(&res) + err = loadAllGenres(r, res) return res, err } @@ -200,7 +200,7 @@ func (r *mediaFileRepository) Search(q string, offset int, size int) (model.Medi if err != nil { return nil, err } - err = r.loadMediaFileGenres(&results) + err = loadAllGenres(r, results) return results, err } diff --git a/persistence/playlist_track_repository.go b/persistence/playlist_track_repository.go index 24a73f63e..54077dbba 100644 --- a/persistence/playlist_track_repository.go +++ b/persistence/playlist_track_repository.go @@ -80,7 +80,7 @@ func (r *playlistTrackRepository) GetAll(options ...model.QueryOptions) (model.P return nil, err } mfs := tracks.MediaFiles() - err = r.loadMediaFileGenres(&mfs) + err = loadAllGenres(r, mfs) if err != nil { log.Error(r.ctx, "Error loading genres for playlist", "playlist", r.playlist.Name, "id", r.playlist.ID, err) return nil, err diff --git a/persistence/sql_base_repository.go b/persistence/sql_base_repository.go index 43bb35c56..3bb050e89 100644 --- a/persistence/sql_base_repository.go +++ b/persistence/sql_base_repository.go @@ -42,6 +42,10 @@ func loggedUser(ctx context.Context) *model.User { } } +func (r sqlRepository) getTableName() string { + return r.tableName +} + func (r sqlRepository) newSelect(options ...model.QueryOptions) SelectBuilder { sq := Select().From(r.tableName) sq = r.applyOptions(sq, options...) diff --git a/persistence/sql_bookmarks.go b/persistence/sql_bookmarks.go index 58c062c50..33bf95b44 100644 --- a/persistence/sql_bookmarks.go +++ b/persistence/sql_bookmarks.go @@ -104,7 +104,7 @@ func (r sqlRepository) GetBookmarks() (model.Bookmarks, error) { log.Error(r.ctx, "Error getting mediafiles with bookmarks", "user", user.UserName, err) return nil, err } - err = r.loadMediaFileGenres(&mfs) + err = loadAllGenres(r, mfs) if err != nil { log.Error(r.ctx, "Error loading genres for bookmarked songs", "user", user.UserName, err) return nil, err diff --git a/persistence/sql_genres.go b/persistence/sql_genres.go index be245721a..4332c60e4 100644 --- a/persistence/sql_genres.go +++ b/persistence/sql_genres.go @@ -11,7 +11,8 @@ func (r sqlRepository) withGenres(sql SelectBuilder) SelectBuilder { LeftJoin("genre on ag.genre_id = genre.id") } -func (r *sqlRepository) updateGenres(id string, tableName string, genres model.Genres) error { +func (r *sqlRepository) updateGenres(id string, genres model.Genres) error { + tableName := r.getTableName() del := Delete(tableName + "_genres").Where(Eq{tableName + "_id": id}) _, err := r.executeSQL(del) if err != nil { @@ -36,89 +37,70 @@ func (r *sqlRepository) updateGenres(id string, tableName string, genres model.G return err } -func (r *sqlRepository) loadMediaFileGenres(mfs *model.MediaFiles) error { - var ids []string - m := map[string]*model.MediaFile{} - for i := range *mfs { - mf := &(*mfs)[i] - ids = append(ids, mf.ID) - m[mf.ID] = mf - } +type baseRepository interface { + queryAll(SelectBuilder, any, ...model.QueryOptions) error + getTableName() string +} +type modelWithGenres interface { + model.Album | model.Artist | model.MediaFile +} + +func getID[T modelWithGenres](item T) string { + switch v := any(item).(type) { + case model.Album: + return v.ID + case model.Artist: + return v.ID + case model.MediaFile: + return v.ID + } + return "" +} + +func appendGenre[T modelWithGenres](item *T, genre model.Genre) { + switch v := any(item).(type) { + case *model.Album: + v.Genres = append(v.Genres, genre) + case *model.Artist: + v.Genres = append(v.Genres, genre) + case *model.MediaFile: + v.Genres = append(v.Genres, genre) + } +} + +func loadGenres[T modelWithGenres](r baseRepository, ids []string, items map[string]*T) error { + tableName := r.getTableName() return slice.RangeByChunks(ids, 900, func(ids []string) error { - sql := Select("g.*", "mg.media_file_id").From("genre g").Join("media_file_genres mg on mg.genre_id = g.id"). - Where(Eq{"mg.media_file_id": ids}).OrderBy("mg.media_file_id", "mg.rowid") + sql := Select("genre.*", tableName+"_id as item_id").From("genre"). + Join(tableName+"_genres ig on genre.id = ig.genre_id"). + OrderBy(tableName+"_id", "ig.rowid").Where(Eq{tableName + "_id": ids}) + var genres []struct { model.Genre - MediaFileId string + ItemID string } - err := r.queryAll(sql, &genres) if err != nil { return err } for _, g := range genres { - mf := m[g.MediaFileId] - mf.Genres = append(mf.Genres, g.Genre) + appendGenre(items[g.ItemID], g.Genre) } return nil }) } -func (r *sqlRepository) loadAlbumGenres(mfs *model.Albums) error { +func loadAllGenres[T modelWithGenres](r baseRepository, items []T) error { + // Map references to items by ID and collect all IDs + m := map[string]*T{} var ids []string - m := map[string]*model.Album{} - for i := range *mfs { - mf := &(*mfs)[i] - ids = append(ids, mf.ID) - m[mf.ID] = mf + for i := range items { + item := &(items)[i] + id := getID(*item) + ids = append(ids, id) + m[id] = item } - return slice.RangeByChunks(ids, 900, func(ids []string) error { - sql := Select("g.*", "ag.album_id").From("genre g").Join("album_genres ag on ag.genre_id = g.id"). - Where(Eq{"ag.album_id": ids}).OrderBy("ag.album_id", "ag.rowid") - var genres []struct { - model.Genre - AlbumId string - } - - err := r.queryAll(sql, &genres) - if err != nil { - return err - } - for _, g := range genres { - mf := m[g.AlbumId] - mf.Genres = append(mf.Genres, g.Genre) - } - return nil - }) -} - -func (r *sqlRepository) loadArtistGenres(mfs *model.Artists) error { - var ids []string - m := map[string]*model.Artist{} - for i := range *mfs { - mf := &(*mfs)[i] - ids = append(ids, mf.ID) - m[mf.ID] = mf - } - - return slice.RangeByChunks(ids, 900, func(ids []string) error { - sql := Select("g.*", "ag.artist_id").From("genre g").Join("artist_genres ag on ag.genre_id = g.id"). - Where(Eq{"ag.artist_id": ids}).OrderBy("ag.artist_id", "ag.rowid") - var genres []struct { - model.Genre - ArtistId string - } - - err := r.queryAll(sql, &genres) - if err != nil { - return err - } - for _, g := range genres { - mf := m[g.ArtistId] - mf.Genres = append(mf.Genres, g.Genre) - } - return nil - }) + return loadGenres(r, ids, m) }