From 3107170afd9f557a10f7031f23cb3c9e975a71f9 Mon Sep 17 00:00:00 2001 From: Deluan Date: Mon, 9 Sep 2024 19:45:02 -0400 Subject: [PATCH] Improve SQL sanitization --- model/annotation.go | 10 +- persistence/album_repository.go | 14 ++- persistence/artist_repository.go | 10 +- persistence/genre_repository.go | 10 +- persistence/library_repository.go | 2 +- persistence/mediafile_repository.go | 32 +++--- persistence/player_repository.go | 10 +- persistence/playlist_repository.go | 10 +- persistence/playlist_track_repository.go | 6 +- persistence/radio_repository.go | 10 +- persistence/share_repository.go | 7 +- persistence/sql_base_repository.go | 32 +++++- persistence/sql_base_repository_test.go | 65 ++++++++---- persistence/sql_restful.go | 124 +++++++++++++++++++---- persistence/sql_restful_test.go | 16 +-- persistence/transcoding_repository.go | 7 +- persistence/user_repository.go | 21 ++-- ui/src/album/AlbumSongs.js | 7 +- ui/src/album/AlbumTableView.js | 3 +- ui/src/artist/ArtistList.js | 3 +- ui/src/common/ContextMenus.js | 4 +- ui/src/common/PlayButton.js | 2 +- ui/src/song/SongList.js | 13 +-- 23 files changed, 259 insertions(+), 159 deletions(-) diff --git a/model/annotation.go b/model/annotation.go index f96e926c0..b365e23ba 100644 --- a/model/annotation.go +++ b/model/annotation.go @@ -3,11 +3,11 @@ package model import "time" type Annotations struct { - PlayCount int64 `structs:"-" json:"playCount"` - PlayDate *time.Time `structs:"-" json:"playDate" ` - Rating int `structs:"-" json:"rating" ` - Starred bool `structs:"-" json:"starred" ` - StarredAt *time.Time `structs:"-" json:"starredAt"` + PlayCount int64 `structs:"play_count" json:"playCount"` + PlayDate *time.Time `structs:"play_date" json:"playDate" ` + Rating int `structs:"rating" json:"rating" ` + Starred bool `structs:"starred" json:"starred" ` + StarredAt *time.Time `structs:"starred_at" json:"starredAt"` } type AnnotatedRepository interface { diff --git a/persistence/album_repository.go b/persistence/album_repository.go index c820fc13a..a0a92f7ee 100644 --- a/persistence/album_repository.go +++ b/persistence/album_repository.go @@ -16,7 +16,6 @@ import ( type albumRepository struct { sqlRepository - sqlRestful } type dbAlbum struct { @@ -58,8 +57,7 @@ func NewAlbumRepository(ctx context.Context, db dbx.Builder) model.AlbumReposito r := &albumRepository{} r.ctx = ctx r.db = db - r.tableName = "album" - r.filterMappings = map[string]filterFunc{ + r.registerModel(&model.Album{}, map[string]filterFunc{ "id": idFilter(r.tableName), "name": fullTextFilter, "compilation": booleanFilter, @@ -68,12 +66,12 @@ func NewAlbumRepository(ctx context.Context, db dbx.Builder) model.AlbumReposito "recently_played": recentlyPlayedFilter, "starred": booleanFilter, "has_rating": hasRatingFilter, - } + }) if conf.Server.PreferSortTags { r.sortMappings = map[string]string{ "name": "COALESCE(NULLIF(sort_album_name,''),order_album_name)", "artist": "compilation asc, COALESCE(NULLIF(sort_album_artist_name,''),order_album_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc", - "albumArtist": "compilation asc, COALESCE(NULLIF(sort_album_artist_name,''),order_album_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc", + "album_artist": "compilation asc, COALESCE(NULLIF(sort_album_artist_name,''),order_album_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc", "max_year": "coalesce(nullif(original_date,''), cast(max_year as text)), release_date, name, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc", "random": r.seededRandomSort(), "recently_added": recentlyAddedSort(), @@ -82,7 +80,7 @@ func NewAlbumRepository(ctx context.Context, db dbx.Builder) model.AlbumReposito r.sortMappings = map[string]string{ "name": "order_album_name asc, order_album_artist_name asc", "artist": "compilation asc, order_album_artist_name asc, order_album_name asc", - "albumArtist": "compilation asc, order_album_artist_name asc, order_album_name asc", + "album_artist": "compilation asc, order_album_artist_name asc, order_album_name asc", "max_year": "coalesce(nullif(original_date,''), cast(max_year as text)), release_date, name, order_album_name asc", "random": r.seededRandomSort(), "recently_added": recentlyAddedSort(), @@ -213,7 +211,7 @@ func (r *albumRepository) Search(q string, offset int, size int) (model.Albums, } func (r *albumRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.CountAll(r.parseRestOptions(options...)) + return r.CountAll(r.parseRestOptions(r.ctx, options...)) } func (r *albumRepository) Read(id string) (interface{}, error) { @@ -221,7 +219,7 @@ func (r *albumRepository) Read(id string) (interface{}, error) { } func (r *albumRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - return r.GetAll(r.parseRestOptions(options...)) + return r.GetAll(r.parseRestOptions(r.ctx, options...)) } func (r *albumRepository) EntityName() string { diff --git a/persistence/artist_repository.go b/persistence/artist_repository.go index e1c7134c6..617fcb6df 100644 --- a/persistence/artist_repository.go +++ b/persistence/artist_repository.go @@ -20,7 +20,6 @@ import ( type artistRepository struct { sqlRepository - sqlRestful indexGroups utils.IndexGroups } @@ -60,12 +59,11 @@ func NewArtistRepository(ctx context.Context, db dbx.Builder) model.ArtistReposi r.ctx = ctx r.db = db r.indexGroups = utils.ParseIndexGroups(conf.Server.IndexGroups) - r.tableName = "artist" - r.filterMappings = map[string]filterFunc{ + r.registerModel(&model.Artist{}, map[string]filterFunc{ "id": idFilter(r.tableName), "name": fullTextFilter, "starred": booleanFilter, - } + }) if conf.Server.PreferSortTags { r.sortMappings = map[string]string{ "name": "COALESCE(NULLIF(sort_artist_name,''),order_artist_name)", @@ -200,7 +198,7 @@ func (r *artistRepository) Search(q string, offset int, size int) (model.Artists } func (r *artistRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.CountAll(r.parseRestOptions(options...)) + return r.CountAll(r.parseRestOptions(r.ctx, options...)) } func (r *artistRepository) Read(id string) (interface{}, error) { @@ -208,7 +206,7 @@ func (r *artistRepository) Read(id string) (interface{}, error) { } func (r *artistRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - return r.GetAll(r.parseRestOptions(options...)) + return r.GetAll(r.parseRestOptions(r.ctx, options...)) } func (r *artistRepository) EntityName() string { diff --git a/persistence/genre_repository.go b/persistence/genre_repository.go index 357a8b4d9..77f27b77b 100644 --- a/persistence/genre_repository.go +++ b/persistence/genre_repository.go @@ -14,17 +14,15 @@ import ( type genreRepository struct { sqlRepository - sqlRestful } func NewGenreRepository(ctx context.Context, db dbx.Builder) model.GenreRepository { r := &genreRepository{} r.ctx = ctx r.db = db - r.tableName = "genre" - r.filterMappings = map[string]filterFunc{ + r.registerModel(&model.Genre{}, map[string]filterFunc{ "name": containsFilter("name"), - } + }) return r } @@ -60,7 +58,7 @@ func (r *genreRepository) Put(m *model.Genre) error { } func (r *genreRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.count(Select(), r.parseRestOptions(options...)) + return r.count(Select(), r.parseRestOptions(r.ctx, options...)) } func (r *genreRepository) Read(id string) (interface{}, error) { @@ -71,7 +69,7 @@ func (r *genreRepository) Read(id string) (interface{}, error) { } func (r *genreRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - sel := r.newSelect(r.parseRestOptions(options...)).Columns("*") + sel := r.newSelect(r.parseRestOptions(r.ctx, options...)).Columns("*") res := model.Genres{} err := r.queryAll(sel, &res) return res, err diff --git a/persistence/library_repository.go b/persistence/library_repository.go index e5009eef9..4603c613a 100644 --- a/persistence/library_repository.go +++ b/persistence/library_repository.go @@ -18,7 +18,7 @@ func NewLibraryRepository(ctx context.Context, db dbx.Builder) model.LibraryRepo r := &libraryRepository{} r.ctx = ctx r.db = db - r.tableName = "library" + r.registerModel(&model.Library{}, nil) return r } diff --git a/persistence/mediafile_repository.go b/persistence/mediafile_repository.go index 6c476a4fe..584b381f8 100644 --- a/persistence/mediafile_repository.go +++ b/persistence/mediafile_repository.go @@ -18,34 +18,34 @@ import ( type mediaFileRepository struct { sqlRepository - sqlRestful } func NewMediaFileRepository(ctx context.Context, db dbx.Builder) *mediaFileRepository { r := &mediaFileRepository{} r.ctx = ctx r.db = db - r.tableName = "media_file" - r.filterMappings = map[string]filterFunc{ + r.registerModel(&model.MediaFile{}, map[string]filterFunc{ "id": idFilter(r.tableName), "title": fullTextFilter, "starred": booleanFilter, - } + }) if conf.Server.PreferSortTags { r.sortMappings = map[string]string{ - "title": "COALESCE(NULLIF(sort_title,''),title)", - "artist": "COALESCE(NULLIF(sort_artist_name,''),order_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc, release_date asc, disc_number asc, track_number asc", - "album": "COALESCE(NULLIF(sort_album_name,''),order_album_name) asc, release_date asc, disc_number asc, track_number asc, COALESCE(NULLIF(sort_artist_name,''),order_artist_name) asc, COALESCE(NULLIF(sort_title,''),title) asc", - "random": r.seededRandomSort(), - "createdAt": "media_file.created_at", + "title": "COALESCE(NULLIF(sort_title,''),title)", + "artist": "COALESCE(NULLIF(sort_artist_name,''),order_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc, release_date asc, disc_number asc, track_number asc", + "album": "COALESCE(NULLIF(sort_album_name,''),order_album_name) asc, release_date asc, disc_number asc, track_number asc, COALESCE(NULLIF(sort_artist_name,''),order_artist_name) asc, COALESCE(NULLIF(sort_title,''),title) asc", + "random": r.seededRandomSort(), + "created_at": "media_file.created_at", + "track_number": "album, release_date, disc_number, track_number", } } else { r.sortMappings = map[string]string{ - "title": "order_title", - "artist": "order_artist_name asc, order_album_name asc, release_date asc, disc_number asc, track_number asc", - "album": "order_album_name asc, release_date asc, disc_number asc, track_number asc, order_artist_name asc, title asc", - "random": r.seededRandomSort(), - "createdAt": "media_file.created_at", + "title": "order_title", + "artist": "order_artist_name asc, order_album_name asc, release_date asc, disc_number asc, track_number asc", + "album": "order_album_name asc, release_date asc, disc_number asc, track_number asc, order_artist_name asc, title asc", + "random": r.seededRandomSort(), + "created_at": "media_file.created_at", + "track_number": "album, release_date, disc_number, track_number", } } return r @@ -209,7 +209,7 @@ func (r *mediaFileRepository) Search(q string, offset int, size int) (model.Medi } func (r *mediaFileRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.CountAll(r.parseRestOptions(options...)) + return r.CountAll(r.parseRestOptions(r.ctx, options...)) } func (r *mediaFileRepository) Read(id string) (interface{}, error) { @@ -217,7 +217,7 @@ func (r *mediaFileRepository) Read(id string) (interface{}, error) { } func (r *mediaFileRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - return r.GetAll(r.parseRestOptions(options...)) + return r.GetAll(r.parseRestOptions(r.ctx, options...)) } func (r *mediaFileRepository) EntityName() string { diff --git a/persistence/player_repository.go b/persistence/player_repository.go index d51613605..381ec2ae8 100644 --- a/persistence/player_repository.go +++ b/persistence/player_repository.go @@ -12,17 +12,15 @@ import ( type playerRepository struct { sqlRepository - sqlRestful } func NewPlayerRepository(ctx context.Context, db dbx.Builder) model.PlayerRepository { r := &playerRepository{} r.ctx = ctx r.db = db - r.tableName = "player" - r.filterMappings = map[string]filterFunc{ + r.registerModel(&model.Player{}, map[string]filterFunc{ "name": containsFilter("player.name"), - } + }) return r } @@ -74,7 +72,7 @@ func (r *playerRepository) addRestriction(sql ...Sqlizer) Sqlizer { } func (r *playerRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.count(r.newRestSelect(), r.parseRestOptions(options...)) + return r.count(r.newRestSelect(), r.parseRestOptions(r.ctx, options...)) } func (r *playerRepository) Read(id string) (interface{}, error) { @@ -85,7 +83,7 @@ func (r *playerRepository) Read(id string) (interface{}, error) { } func (r *playerRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - sel := r.newRestSelect(r.parseRestOptions(options...)) + sel := r.newRestSelect(r.parseRestOptions(r.ctx, options...)) res := model.Players{} err := r.queryAll(sel, &res) return res, err diff --git a/persistence/playlist_repository.go b/persistence/playlist_repository.go index 1cbc9fc65..efc072da5 100644 --- a/persistence/playlist_repository.go +++ b/persistence/playlist_repository.go @@ -19,7 +19,6 @@ import ( type playlistRepository struct { sqlRepository - sqlRestful } type dbPlaylist struct { @@ -51,11 +50,10 @@ func NewPlaylistRepository(ctx context.Context, db dbx.Builder) model.PlaylistRe r := &playlistRepository{} r.ctx = ctx r.db = db - r.tableName = "playlist" - r.filterMappings = map[string]filterFunc{ + r.registerModel(&model.Playlist{}, map[string]filterFunc{ "q": playlistFilter, "smart": smartPlaylistFilter, - } + }) return r } @@ -372,7 +370,7 @@ func (r *playlistRepository) loadTracks(sel SelectBuilder, id string) (model.Pla } func (r *playlistRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.CountAll(r.parseRestOptions(options...)) + return r.CountAll(r.parseRestOptions(r.ctx, options...)) } func (r *playlistRepository) Read(id string) (interface{}, error) { @@ -380,7 +378,7 @@ func (r *playlistRepository) Read(id string) (interface{}, error) { } func (r *playlistRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - return r.GetAll(r.parseRestOptions(options...)) + return r.GetAll(r.parseRestOptions(r.ctx, options...)) } func (r *playlistRepository) EntityName() string { diff --git a/persistence/playlist_track_repository.go b/persistence/playlist_track_repository.go index 6625f9009..f2089edcf 100644 --- a/persistence/playlist_track_repository.go +++ b/persistence/playlist_track_repository.go @@ -13,7 +13,6 @@ import ( type playlistTrackRepository struct { sqlRepository - sqlRestful playlistId string playlist *model.Playlist playlistRepo *playlistRepository @@ -26,6 +25,7 @@ func (r *playlistRepository) Tracks(playlistId string, refreshSmartPlaylist bool p.ctx = r.ctx p.db = r.db p.tableName = "playlist_tracks" + p.registerModel(&model.PlaylistTrack{}, nil) p.sortMappings = map[string]string{ "id": "playlist_tracks.id", "artist": "order_artist_name asc", @@ -51,7 +51,7 @@ func (r *playlistRepository) Tracks(playlistId string, refreshSmartPlaylist bool } func (r *playlistTrackRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.count(Select().Where(Eq{"playlist_id": r.playlistId}), r.parseRestOptions(options...)) + return r.count(Select().Where(Eq{"playlist_id": r.playlistId}), r.parseRestOptions(r.ctx, options...)) } func (r *playlistTrackRepository) Read(id string) (interface{}, error) { @@ -112,7 +112,7 @@ func (r *playlistTrackRepository) GetAlbumIDs(options ...model.QueryOptions) ([] } func (r *playlistTrackRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - return r.GetAll(r.parseRestOptions(options...)) + return r.GetAll(r.parseRestOptions(r.ctx, options...)) } func (r *playlistTrackRepository) EntityName() string { diff --git a/persistence/radio_repository.go b/persistence/radio_repository.go index 636fd891d..781c50241 100644 --- a/persistence/radio_repository.go +++ b/persistence/radio_repository.go @@ -15,17 +15,15 @@ import ( type radioRepository struct { sqlRepository - sqlRestful } func NewRadioRepository(ctx context.Context, db dbx.Builder) model.RadioRepository { r := &radioRepository{} r.ctx = ctx r.db = db - r.tableName = "radio" - r.filterMappings = map[string]filterFunc{ + r.registerModel(&model.Radio{}, map[string]filterFunc{ "name": containsFilter("name"), - } + }) r.sortMappings = map[string]string{ "name": "(name collate nocase), name", } @@ -96,7 +94,7 @@ func (r *radioRepository) Put(radio *model.Radio) error { } func (r *radioRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.CountAll(r.parseRestOptions(options...)) + return r.CountAll(r.parseRestOptions(r.ctx, options...)) } func (r *radioRepository) EntityName() string { @@ -112,7 +110,7 @@ func (r *radioRepository) Read(id string) (interface{}, error) { } func (r *radioRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - return r.GetAll(r.parseRestOptions(options...)) + return r.GetAll(r.parseRestOptions(r.ctx, options...)) } func (r *radioRepository) Save(entity interface{}) (string, error) { diff --git a/persistence/share_repository.go b/persistence/share_repository.go index 2547bcfa5..ca7b971c1 100644 --- a/persistence/share_repository.go +++ b/persistence/share_repository.go @@ -17,14 +17,13 @@ import ( type shareRepository struct { sqlRepository - sqlRestful } func NewShareRepository(ctx context.Context, db dbx.Builder) model.ShareRepository { r := &shareRepository{} r.ctx = ctx r.db = db - r.tableName = "share" + r.registerModel(&model.Share{}, map[string]filterFunc{}) return r } @@ -166,7 +165,7 @@ func (r *shareRepository) CountAll(options ...model.QueryOptions) (int64, error) } func (r *shareRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.CountAll(r.parseRestOptions(options...)) + return r.CountAll(r.parseRestOptions(r.ctx, options...)) } func (r *shareRepository) EntityName() string { @@ -185,7 +184,7 @@ func (r *shareRepository) Read(id string) (interface{}, error) { } func (r *shareRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - sq := r.selectShare(r.parseRestOptions(options...)) + sq := r.selectShare(r.parseRestOptions(r.ctx, options...)) res := model.Shares{} err := r.queryAll(sq, &res) return res, err diff --git a/persistence/sql_base_repository.go b/persistence/sql_base_repository.go index b87d354c8..e50f72cb1 100644 --- a/persistence/sql_base_repository.go +++ b/persistence/sql_base_repository.go @@ -19,11 +19,25 @@ import ( "github.com/pocketbase/dbx" ) +// sqlRepository is the base repository for all SQL repositories. It provides common functions to interact with the DB. +// When creating a new repository using this base, you must: +// +// - Embed this struct. +// - Set ctx and db fields. ctx should be the context passed to the constructor method, usually obtained from the request +// - Call registerModel with the model instance and any possible filters. +// - If the model has a different table name than the default (lowercase of the model name), it should be set manually +// using the tableName field. +// - Sort mappings should be set in the sortMappings field. If the sort field is not in the map, it will be used as is. +// +// All fields in filters and sortMappings must be in snake_case. Only sorts and filters based on real field names or +// defined in the mappings will be allowed. type sqlRepository struct { - ctx context.Context - tableName string - db dbx.Builder - sortMappings map[string]string + ctx context.Context + tableName string + db dbx.Builder + sortMappings map[string]string + filterMappings map[string]filterFunc + isFieldWhiteListed fieldWhiteListedFunc } const invalidUserId = "-1" @@ -44,6 +58,16 @@ func loggedUser(ctx context.Context) *model.User { } } +func (r *sqlRepository) registerModel(instance any, filters map[string]filterFunc) { + if r.tableName == "" { + r.tableName = strings.TrimPrefix(reflect.TypeOf(instance).String(), "*model.") + r.tableName = toSnakeCase(r.tableName) + } + r.tableName = strings.ToLower(r.tableName) + r.isFieldWhiteListed = registerModelWhiteList(instance) + r.filterMappings = filters +} + func (r sqlRepository) getTableName() string { return r.tableName } diff --git a/persistence/sql_base_repository_test.go b/persistence/sql_base_repository_test.go index 8b1900a62..35d802d98 100644 --- a/persistence/sql_base_repository_test.go +++ b/persistence/sql_base_repository_test.go @@ -73,40 +73,61 @@ var _ = Describe("sqlRepository", func() { }) }) - Describe("sortMapping", func() { + Describe("sanitizeSort", func() { BeforeEach(func() { + r.registerModel(&struct { + Field string `structs:"field"` + }{}, nil) r.sortMappings = map[string]string{ - "sort1": "mappedSort1", - "sortTwo": "mappedSort2", - "sort_three": "mappedSort3", + "sort1": "mappedSort1", } }) - It("returns the mapped value when sort key exists", func() { - Expect(r.sortMapping("sort1")).To(Equal("mappedSort1")) - }) + When("sanitizing sort", func() { + It("returns empty if the sort key is not found in the model nor in the mappings", func() { + sort, _ := r.sanitizeSort("unknown", "") + Expect(sort).To(BeEmpty()) + }) - Context("when sort key does not exist", func() { - It("returns the original sort key, snake cased", func() { - Expect(r.sortMapping("NotFoundSort")).To(Equal("not_found_sort")) + It("returns the mapped value when sort key exists", func() { + sort, _ := r.sanitizeSort("sort1", "") + Expect(sort).To(Equal("mappedSort1")) + }) + + It("is case insensitive", func() { + sort, _ := r.sanitizeSort("Sort1", "") + Expect(sort).To(Equal("mappedSort1")) + }) + + It("returns the field if it is a valid field", func() { + sort, _ := r.sanitizeSort("field", "") + Expect(sort).To(Equal("field")) + }) + + It("is case insensitive for fields", func() { + sort, _ := r.sanitizeSort("FIELD", "") + Expect(sort).To(Equal("field")) }) }) + When("sanitizing order", func() { + It("returns 'asc' if order is empty", func() { + _, order := r.sanitizeSort("", "") + Expect(order).To(Equal("")) + }) - Context("when sort key is camel cased", func() { - It("returns the mapped value when camel case sort key exists", func() { - Expect(r.sortMapping("sortTwo")).To(Equal("mappedSort2")) + It("returns 'asc' if order is 'asc'", func() { + _, order := r.sanitizeSort("", "ASC") + Expect(order).To(Equal("asc")) }) - It("returns the mapped value when passing a snake case key", func() { - Expect(r.sortMapping("sort_two")).To(Equal("mappedSort2")) - }) - }) - Context("when sort key is snake cased", func() { - It("returns the mapped value when snake case sort key exists", func() { - Expect(r.sortMapping("sort_three")).To(Equal("mappedSort3")) + It("returns 'desc' if order is 'desc'", func() { + _, order := r.sanitizeSort("", "desc") + Expect(order).To(Equal("desc")) }) - It("returns the mapped value when passing a camel case key", func() { - Expect(r.sortMapping("sortThree")).To(Equal("mappedSort3")) + + It("returns 'asc' if order is unknown", func() { + _, order := r.sanitizeSort("", "something") + Expect(order).To(Equal("asc")) }) }) }) diff --git a/persistence/sql_restful.go b/persistence/sql_restful.go index b802a750d..193ff6563 100644 --- a/persistence/sql_restful.go +++ b/persistence/sql_restful.go @@ -1,61 +1,94 @@ package persistence import ( + "context" "fmt" + "reflect" "strings" + "sync" . "github.com/Masterminds/squirrel" "github.com/deluan/rest" + "github.com/fatih/structs" + "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" ) -type filterFunc = func(field string, value interface{}) Sqlizer +type filterFunc = func(field string, value any) Sqlizer -type sqlRestful struct { - filterMappings map[string]filterFunc -} - -func (r sqlRestful) parseRestFilters(options rest.QueryOptions) Sqlizer { +func (r *sqlRepository) parseRestFilters(ctx context.Context, options rest.QueryOptions) Sqlizer { if len(options.Filters) == 0 { return nil } filters := And{} for f, v := range options.Filters { + // Ignore filters with empty values if v == "" { continue } + // Look for a custom filter function + f = strings.ToLower(f) if ff, ok := r.filterMappings[f]; ok { filters = append(filters, ff(f, v)) - } else if strings.HasSuffix(strings.ToLower(f), "id") { - filters = append(filters, eqFilter(f, v)) - } else { - filters = append(filters, startsWithFilter(f, v)) + continue } + // Ignore invalid filters (not based on a field or filter function) + if r.isFieldWhiteListed != nil && !r.isFieldWhiteListed(f) { + log.Warn(ctx, "Ignoring filter not whitelisted", "filter", f) + continue + } + // For fields ending in "id", use an exact match + if strings.HasSuffix(f, "id") { + filters = append(filters, eqFilter(f, v)) + continue + } + // Default to a "starts with" filter + filters = append(filters, startsWithFilter(f, v)) } return filters } -func (r sqlRestful) parseRestOptions(options ...rest.QueryOptions) model.QueryOptions { +func (r *sqlRepository) parseRestOptions(ctx context.Context, options ...rest.QueryOptions) model.QueryOptions { qo := model.QueryOptions{} if len(options) > 0 { - qo.Sort = options[0].Sort - qo.Order = strings.ToLower(options[0].Order) + qo.Sort, qo.Order = r.sanitizeSort(options[0].Sort, options[0].Order) qo.Max = options[0].Max qo.Offset = options[0].Offset if seed, ok := options[0].Filters["seed"].(string); ok { qo.Seed = seed delete(options[0].Filters, "seed") } - qo.Filters = r.parseRestFilters(options[0]) + qo.Filters = r.parseRestFilters(ctx, options[0]) } return qo } -func eqFilter(field string, value interface{}) Sqlizer { +func (r sqlRepository) sanitizeSort(sort, order string) (string, string) { + if sort != "" { + sort = toSnakeCase(sort) + if mapped, ok := r.sortMappings[sort]; ok { + sort = mapped + } else { + if !r.isFieldWhiteListed(sort) { + log.Warn(r.ctx, "Ignoring sort not whitelisted", "sort", sort) + sort = "" + } + } + } + if order != "" { + order = strings.ToLower(order) + if order != "desc" { + order = "asc" + } + } + return sort, order +} + +func eqFilter(field string, value any) Sqlizer { return Eq{field: value} } -func startsWithFilter(field string, value interface{}) Sqlizer { +func startsWithFilter(field string, value any) Sqlizer { return Like{field: fmt.Sprintf("%s%%", value)} } @@ -65,16 +98,16 @@ func containsFilter(field string) func(string, any) Sqlizer { } } -func booleanFilter(field string, value interface{}) Sqlizer { +func booleanFilter(field string, value any) Sqlizer { v := strings.ToLower(value.(string)) return Eq{field: strings.ToLower(v) == "true"} } -func fullTextFilter(field string, value interface{}) Sqlizer { +func fullTextFilter(_ string, value any) Sqlizer { return fullTextExpr(value.(string)) } -func substringFilter(field string, value interface{}) Sqlizer { +func substringFilter(field string, value any) Sqlizer { parts := strings.Split(value.(string), " ") filters := And{} for _, part := range parts { @@ -83,8 +116,57 @@ func substringFilter(field string, value interface{}) Sqlizer { return filters } -func idFilter(tableName string) func(string, interface{}) Sqlizer { - return func(field string, value interface{}) Sqlizer { +func idFilter(tableName string) func(string, any) Sqlizer { + return func(field string, value any) Sqlizer { return Eq{tableName + ".id": value} } } + +func invalidFilter(ctx context.Context) func(string, any) Sqlizer { + return func(field string, value any) Sqlizer { + log.Warn(ctx, "Invalid filter", "fieldName", field, "value", value) + return Eq{"1": "0"} + } +} + +var ( + whiteList = map[string]map[string]struct{}{} + mutex sync.RWMutex +) + +func registerModelWhiteList(instance any) fieldWhiteListedFunc { + name := reflect.TypeOf(instance).String() + registerFieldWhiteList(name, instance) + return getFieldWhiteListedFunc(name) +} + +func registerFieldWhiteList(name string, instance any) { + mutex.Lock() + defer mutex.Unlock() + if whiteList[name] != nil { + return + } + m := structs.Map(instance) + whiteList[name] = map[string]struct{}{} + for k := range m { + whiteList[name][toSnakeCase(k)] = struct{}{} + } + ma := structs.Map(model.Annotations{}) + for k := range ma { + whiteList[name][toSnakeCase(k)] = struct{}{} + } +} + +type fieldWhiteListedFunc func(field string) bool + +func getFieldWhiteListedFunc(tableName string) fieldWhiteListedFunc { + return func(field string) bool { + mutex.RLock() + defer mutex.RUnlock() + if _, ok := whiteList[tableName]; !ok { + return false + } + _, ok := whiteList[tableName][field] + return ok + } +} diff --git a/persistence/sql_restful_test.go b/persistence/sql_restful_test.go index 637ac3446..6579e9ffa 100644 --- a/persistence/sql_restful_test.go +++ b/persistence/sql_restful_test.go @@ -1,6 +1,8 @@ package persistence import ( + "context" + "github.com/Masterminds/squirrel" "github.com/deluan/rest" . "github.com/onsi/ginkgo/v2" @@ -9,31 +11,31 @@ import ( var _ = Describe("sqlRestful", func() { Describe("parseRestFilters", func() { - var r sqlRestful + var r sqlRepository var options rest.QueryOptions BeforeEach(func() { - r = sqlRestful{} + r = sqlRepository{} }) It("returns nil if filters is empty", func() { options.Filters = nil - Expect(r.parseRestFilters(options)).To(BeNil()) + Expect(r.parseRestFilters(context.Background(), options)).To(BeNil()) }) It("returns a '=' condition for 'id' filter", func() { options.Filters = map[string]interface{}{"id": "123"} - Expect(r.parseRestFilters(options)).To(Equal(squirrel.And{squirrel.Eq{"id": "123"}})) + Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Eq{"id": "123"}})) }) It("returns a 'in' condition for multiples 'id' filters", func() { options.Filters = map[string]interface{}{"id": []string{"123", "456"}} - Expect(r.parseRestFilters(options)).To(Equal(squirrel.And{squirrel.Eq{"id": []string{"123", "456"}}})) + Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Eq{"id": []string{"123", "456"}}})) }) It("returns a 'like' condition for other filters", func() { options.Filters = map[string]interface{}{"name": "joe"} - Expect(r.parseRestFilters(options)).To(Equal(squirrel.And{squirrel.Like{"name": "joe%"}})) + Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Like{"name": "joe%"}})) }) It("uses the custom filter", func() { @@ -43,7 +45,7 @@ var _ = Describe("sqlRestful", func() { }, } options.Filters = map[string]interface{}{"test": 100} - Expect(r.parseRestFilters(options)).To(Equal(squirrel.And{squirrel.Gt{"test": 100}})) + Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Gt{"test": 100}})) }) }) }) diff --git a/persistence/transcoding_repository.go b/persistence/transcoding_repository.go index 17cd0a019..9f8998b80 100644 --- a/persistence/transcoding_repository.go +++ b/persistence/transcoding_repository.go @@ -12,14 +12,13 @@ import ( type transcodingRepository struct { sqlRepository - sqlRestful } func NewTranscodingRepository(ctx context.Context, db dbx.Builder) model.TranscodingRepository { r := &transcodingRepository{} r.ctx = ctx r.db = db - r.tableName = "transcoding" + r.registerModel(&model.Transcoding{}, nil) return r } @@ -47,7 +46,7 @@ func (r *transcodingRepository) Put(t *model.Transcoding) error { } func (r *transcodingRepository) Count(options ...rest.QueryOptions) (int64, error) { - return r.count(Select(), r.parseRestOptions(options...)) + return r.count(Select(), r.parseRestOptions(r.ctx, options...)) } func (r *transcodingRepository) Read(id string) (interface{}, error) { @@ -55,7 +54,7 @@ func (r *transcodingRepository) Read(id string) (interface{}, error) { } func (r *transcodingRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - sel := r.newSelect(r.parseRestOptions(options...)).Columns("*") + sel := r.newSelect(r.parseRestOptions(r.ctx, options...)).Columns("*") res := model.Transcodings{} err := r.queryAll(sel, &res) return res, err diff --git a/persistence/user_repository.go b/persistence/user_repository.go index 6019f8916..34162446d 100644 --- a/persistence/user_repository.go +++ b/persistence/user_repository.go @@ -22,7 +22,6 @@ import ( type userRepository struct { sqlRepository - sqlRestful } var ( @@ -34,7 +33,9 @@ func NewUserRepository(ctx context.Context, db dbx.Builder) model.UserRepository r := &userRepository{} r.ctx = ctx r.db = db - r.tableName = "user" + r.registerModel(&model.User{}, map[string]filterFunc{ + "password": invalidFilter(ctx), + }) once.Do(func() { _ = r.initPasswordEncryptionKey() }) @@ -91,7 +92,7 @@ func (r *userRepository) FindFirstAdmin() (*model.User, error) { } func (r *userRepository) FindByUsername(username string) (*model.User, error) { - sel := r.newSelect().Columns("*").Where(Like{"user_name": username}) + sel := r.newSelect().Columns("*").Where(Expr("user_name = ? COLLATE NOCASE", username)) var usr model.User err := r.queryOne(sel, &usr) return &usr, err @@ -123,10 +124,10 @@ func (r *userRepository) Count(options ...rest.QueryOptions) (int64, error) { if !usr.IsAdmin { return 0, rest.ErrPermissionDenied } - return r.CountAll(r.parseRestOptions(options...)) + return r.CountAll(r.parseRestOptions(r.ctx, options...)) } -func (r *userRepository) Read(id string) (interface{}, error) { +func (r *userRepository) Read(id string) (any, error) { usr := loggedUser(r.ctx) if !usr.IsAdmin && usr.ID != id { return nil, rest.ErrPermissionDenied @@ -138,23 +139,23 @@ func (r *userRepository) Read(id string) (interface{}, error) { return usr, err } -func (r *userRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { +func (r *userRepository) ReadAll(options ...rest.QueryOptions) (any, error) { usr := loggedUser(r.ctx) if !usr.IsAdmin { return nil, rest.ErrPermissionDenied } - return r.GetAll(r.parseRestOptions(options...)) + return r.GetAll(r.parseRestOptions(r.ctx, options...)) } func (r *userRepository) EntityName() string { return "user" } -func (r *userRepository) NewInstance() interface{} { +func (r *userRepository) NewInstance() any { return &model.User{} } -func (r *userRepository) Save(entity interface{}) (string, error) { +func (r *userRepository) Save(entity any) (string, error) { usr := loggedUser(r.ctx) if !usr.IsAdmin { return "", rest.ErrPermissionDenied @@ -170,7 +171,7 @@ func (r *userRepository) Save(entity interface{}) (string, error) { return u.ID, err } -func (r *userRepository) Update(id string, entity interface{}, cols ...string) error { +func (r *userRepository) Update(id string, entity any, _ ...string) error { u := entity.(*model.User) u.ID = id usr := loggedUser(r.ctx) diff --git a/ui/src/album/AlbumSongs.js b/ui/src/album/AlbumSongs.js index 3ad0eb318..9ceef8030 100644 --- a/ui/src/album/AlbumSongs.js +++ b/ui/src/album/AlbumSongs.js @@ -97,12 +97,7 @@ const AlbumSongs = (props) => { const toggleableFields = useMemo(() => { return { trackNumber: isDesktop && ( - + ), title: ( {columns} { {columns} resource={'album'} songQueryParams={{ pagination: { page: 1, perPage: -1 }, - sort: { field: 'releaseDate, discNumber, trackNumber', order: 'ASC' }, + sort: { field: 'trackNumber', order: 'ASC' }, filter: { album_id: props.record.id, release_date: props.releaseDate, @@ -234,7 +234,7 @@ export const ArtistContextMenu = (props) => songQueryParams={{ pagination: { page: 1, perPage: 200 }, sort: { - field: 'album, releaseDate, discNumber, trackNumber', + field: 'trackNumber', order: 'ASC', }, filter: { album_artist_id: props.record.id }, diff --git a/ui/src/common/PlayButton.js b/ui/src/common/PlayButton.js index ddc676004..0df3125a6 100644 --- a/ui/src/common/PlayButton.js +++ b/ui/src/common/PlayButton.js @@ -21,7 +21,7 @@ export const PlayButton = ({ record, size, className }) => { dataProvider .getList('song', { pagination: { page: 1, perPage: -1 }, - sort: { field: 'releaseDate, discNumber, trackNumber', order: 'ASC' }, + sort: { field: 'trackNumber', order: 'ASC' }, filter: { album_id: record.id, release_date: record.releaseDate, diff --git a/ui/src/song/SongList.js b/ui/src/song/SongList.js index 8222896c8..8251ae651 100644 --- a/ui/src/song/SongList.js +++ b/ui/src/song/SongList.js @@ -98,15 +98,7 @@ const SongList = (props) => { const toggleableFields = React.useMemo(() => { return { - album: isDesktop && ( - - ), + album: isDesktop && , artist: , albumArtist: , trackNumber: isDesktop && , @@ -179,8 +171,7 @@ const SongList = (props) => { {columns}