From 98218d045e2ebc848714c8aac46c78f6febc11cf Mon Sep 17 00:00:00 2001
From: Guilherme Souza <32180229+gqgs@users.noreply.github.com>
Date: Sat, 18 May 2024 15:10:53 -0300
Subject: [PATCH] Deterministic pagination in random albums sort (#1841)

* Deterministic pagination in random albums sort

* Reseed on first random page

* Add unit tests

* Use rand in Subsonic API

* Use different seeds per user on SEEDEDRAND() SQLite3 function

* Small refactor

* Fix id mismatch

* Add seeded random to media_file (subsonic endpoint `getRandomSongs`)

* Refactor

* Remove unneeded import

* Guard the hasher seed map with a mutex

---------

Co-authored-by: Deluan
---
 db/db.go                            | 11 +++++-
 persistence/album_repository.go     |  5 ++-
 persistence/mediafile_repository.go |  5 ++-
 persistence/sql_base_repository.go  | 13 +++++++
 server/subsonic/filter/filters.go   |  4 +-
 utils/hasher/hasher.go              | 60 +++++++++++++++++++++++++++++
 utils/hasher/hasher_test.go         | 36 +++++++++++++++++
 7 files changed, 126 insertions(+), 8 deletions(-)
 create mode 100644 utils/hasher/hasher.go
 create mode 100644 utils/hasher/hasher_test.go

diff --git a/db/db.go b/db/db.go
index b8fd3a36f..cf0ce2cfb 100644
--- a/db/db.go
+++ b/db/db.go
@@ -5,10 +5,11 @@ import (
 	"embed"
 	"fmt"
 
-	_ "github.com/mattn/go-sqlite3"
+	"github.com/mattn/go-sqlite3"
 	"github.com/navidrome/navidrome/conf"
 	_ "github.com/navidrome/navidrome/db/migrations"
 	"github.com/navidrome/navidrome/log"
+	"github.com/navidrome/navidrome/utils/hasher"
 	"github.com/navidrome/navidrome/utils/singleton"
 	"github.com/pressly/goose/v3"
 )
@@ -25,13 +26,19 @@ const migrationsFolder = "migrations"
 
 func Db() *sql.DB {
 	return singleton.GetInstance(func() *sql.DB {
+		sql.Register(Driver+"_custom", &sqlite3.SQLiteDriver{
+			ConnectHook: func(conn *sqlite3.SQLiteConn) error {
+				return conn.RegisterFunc("SEEDEDRAND", hasher.HashFunc(), false)
+			},
+		})
+
 		Path = conf.Server.DbPath
 		if Path == ":memory:" {
 			Path = "file::memory:?cache=shared&_foreign_keys=on"
 			conf.Server.DbPath = Path
 		}
 		log.Debug("Opening DataBase", "dbPath", Path, "driver", Driver)
-		instance, err := sql.Open(Driver, Path)
+		instance, err := sql.Open(Driver+"_custom", Path)
 		if err != nil {
 			panic(err)
 		}
diff --git a/persistence/album_repository.go b/persistence/album_repository.go
index 862852d7e..c820fc13a 100644
--- a/persistence/album_repository.go
+++ b/persistence/album_repository.go
@@ -75,7 +75,7 @@ func NewAlbumRepository(ctx context.Context, db dbx.Builder) model.AlbumReposito
 			"artist":         "compilation asc, COALESCE(NULLIF(sort_album_artist_name,''),order_album_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc",
 			"albumArtist":    "compilation asc, COALESCE(NULLIF(sort_album_artist_name,''),order_album_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc",
 			"max_year":       "coalesce(nullif(original_date,''), cast(max_year as text)), release_date, name, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc",
-			"random":         "RANDOM()",
+			"random":         r.seededRandomSort(),
 			"recently_added": recentlyAddedSort(),
 		}
 	} else {
@@ -84,7 +84,7 @@ func NewAlbumRepository(ctx context.Context, db dbx.Builder) model.AlbumReposito
 			"artist":         "compilation asc, order_album_artist_name asc, order_album_name asc",
 			"albumArtist":    "compilation asc, order_album_artist_name asc, order_album_name asc",
 			"max_year":       "coalesce(nullif(original_date,''), cast(max_year as text)), release_date, name, order_album_name asc",
-			"random":         "RANDOM()",
+			"random":         r.seededRandomSort(),
 			"recently_added": recentlyAddedSort(),
 		}
 	}
@@ -180,6 +180,7 @@ func (r *albumRepository) GetAll(options ...model.QueryOptions) (model.Albums, e
 }
 
 func (r *albumRepository) GetAllWithoutGenres(options ...model.QueryOptions) (model.Albums, error) {
+	r.resetSeededRandom(options)
 	sq := r.selectAlbum(options...)
 	var dba dbAlbums
 	err := r.queryAll(sq, &dba)
diff --git a/persistence/mediafile_repository.go b/persistence/mediafile_repository.go
index 5c018f34a..6c476a4fe 100644
--- a/persistence/mediafile_repository.go
+++ b/persistence/mediafile_repository.go
@@ -36,7 +36,7 @@ func NewMediaFileRepository(ctx context.Context, db dbx.Builder) *mediaFileRepos
 			"title":     "COALESCE(NULLIF(sort_title,''),title)",
 			"artist":    "COALESCE(NULLIF(sort_artist_name,''),order_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc, release_date asc, disc_number asc, track_number asc",
 			"album":     "COALESCE(NULLIF(sort_album_name,''),order_album_name) asc, release_date asc, disc_number asc, track_number asc, COALESCE(NULLIF(sort_artist_name,''),order_artist_name) asc, COALESCE(NULLIF(sort_title,''),title) asc",
-			"random":    "RANDOM()",
+			"random":    r.seededRandomSort(),
 			"createdAt": "media_file.created_at",
 		}
 	} else {
@@ -44,7 +44,7 @@ func NewMediaFileRepository(ctx context.Context, db dbx.Builder) *mediaFileRepos
 			"title":     "order_title",
 			"artist":    "order_artist_name asc, order_album_name asc, release_date asc, disc_number asc, track_number asc",
 			"album":     "order_album_name asc, release_date asc, disc_number asc, track_number asc, order_artist_name asc, title asc",
-			"random":    "RANDOM()",
+			"random":    r.seededRandomSort(),
 			"createdAt": "media_file.created_at",
 		}
 	}
@@ -102,6 +102,7 @@ func (r *mediaFileRepository) Get(id string) (*model.MediaFile, error) {
 }
 
 func (r *mediaFileRepository) GetAll(options ...model.QueryOptions) (model.MediaFiles, error) {
+	r.resetSeededRandom(options)
 	sq := r.selectMediaFile(options...)
 	res := model.MediaFiles{}
 	err := r.queryAll(sq, &res, options...)
diff --git a/persistence/sql_base_repository.go b/persistence/sql_base_repository.go
index d5282516d..da5ac6a3d 100644
--- a/persistence/sql_base_repository.go
+++ b/persistence/sql_base_repository.go
@@ -14,6 +14,7 @@ import (
 	"github.com/navidrome/navidrome/log"
 	"github.com/navidrome/navidrome/model"
 	"github.com/navidrome/navidrome/model/request"
+	"github.com/navidrome/navidrome/utils/hasher"
 	"github.com/pocketbase/dbx"
 )
 
@@ -137,6 +138,18 @@ func (r sqlRepository) applyFilters(sq SelectBuilder, options ...model.QueryOpti
 	return sq
 }
 
+func (r sqlRepository) seededRandomSort() string {
+	u, _ := request.UserFrom(r.ctx)
+	return fmt.Sprintf("SEEDEDRAND('%s', id)", r.tableName+u.ID)
+}
+
+func (r sqlRepository) resetSeededRandom(options []model.QueryOptions) {
+	if len(options) > 0 && options[0].Offset == 0 && options[0].Sort == "random" {
+		u, _ := request.UserFrom(r.ctx)
+		hasher.Reseed(r.tableName + u.ID)
+	}
+}
+
 func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
 	query, args, err := r.toSQL(sq)
 	if err != nil {
diff --git a/server/subsonic/filter/filters.go b/server/subsonic/filter/filters.go
index fca482f33..87fb4804e 100644
--- a/server/subsonic/filter/filters.go
+++ b/server/subsonic/filter/filters.go
@@ -24,7 +24,7 @@ func AlbumsByFrequent() Options {
 }
 
 func AlbumsByRandom() Options {
-	return Options{Sort: "random()"}
+	return Options{Sort: "random"}
 }
 
 func AlbumsByName() Options {
@@ -100,7 +100,7 @@ func SongsByAlbum(albumId string) Options {
 
 func SongsByRandom(genre string, fromYear, toYear int) Options {
 	options := Options{
-		Sort: "random()",
+		Sort: "random",
 	}
 	ff := squirrel.And{}
 	if genre != "" {
diff --git a/utils/hasher/hasher.go b/utils/hasher/hasher.go
new file mode 100644
index 000000000..78566913a
--- /dev/null
+++ b/utils/hasher/hasher.go
@@ -0,0 +1,60 @@
+package hasher
+
+import (
+	"hash/maphash"
+	"sync"
+)
+
+var instance = NewHasher()
+
+// Reseed generates a new seed for the given id on the shared instance
+func Reseed(id string) {
+	instance.Reseed(id)
+}
+
+// HashFunc returns the hash function of the shared instance
+func HashFunc() func(id, str string) uint64 {
+	return instance.HashFunc()
+}
+
+type hasher struct {
+	// mu guards seeds: the hash function runs on every SQLite connection and
+	// Reseed runs on request goroutines, so the map is accessed concurrently
+	mu    sync.Mutex
+	seeds map[string]maphash.Seed
+}
+
+func NewHasher() *hasher {
+	h := new(hasher)
+	h.seeds = make(map[string]maphash.Seed)
+	return h
+}
+
+// Reseed generates a new seed for the given id
+func (h *hasher) Reseed(id string) {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.seeds[id] = maphash.MakeSeed()
+}
+
+// seedFor returns the seed for the given id, creating it on first use
+func (h *hasher) seedFor(id string) maphash.Seed {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	seed, ok := h.seeds[id]
+	if !ok {
+		seed = maphash.MakeSeed()
+		h.seeds[id] = seed
+	}
+	return seed
+}
+
+// HashFunc returns a function that hashes a string using the seed for the given id
+func (h *hasher) HashFunc() func(id, str string) uint64 {
+	return func(id, str string) uint64 {
+		var hash maphash.Hash
+		hash.SetSeed(h.seedFor(id))
+		_, _ = hash.WriteString(str)
+		return hash.Sum64()
+	}
+}
diff --git a/utils/hasher/hasher_test.go b/utils/hasher/hasher_test.go
new file mode 100644
index 000000000..3a1f9dfde
--- /dev/null
+++ b/utils/hasher/hasher_test.go
@@ -0,0 +1,36 @@
+package hasher_test
+
+import (
+	"github.com/navidrome/navidrome/utils/hasher"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("HashFunc", func() {
+	const input = "123e4567e89b12d3a456426614174000"
+
+	It("hashes the input and returns the sum", func() {
+		hashFunc := hasher.HashFunc()
+		sum := hashFunc("1", input)
+		Expect(sum > 0).To(BeTrue())
+	})
+
+	It("hashes the input, reseeds and returns a different sum", func() {
+		hashFunc := hasher.HashFunc()
+		sum := hashFunc("1", input)
+		hasher.Reseed("1")
+		sum2 := hashFunc("1", input)
+		Expect(sum).NotTo(Equal(sum2))
+	})
+
+	It("keeps different hashes for different ids", func() {
+		hashFunc := hasher.HashFunc()
+		sum := hashFunc("1", input)
+		sum2 := hashFunc("2", input)
+
+		Expect(sum).NotTo(Equal(sum2))
+
+		Expect(sum).To(Equal(hashFunc("1", input)))
+		Expect(sum2).To(Equal(hashFunc("2", input)))
+	})
+})