From 4d28d534cc6245727390499b38c96017addbefa2 Mon Sep 17 00:00:00 2001 From: Deluan Date: Fri, 17 May 2024 15:45:34 -0400 Subject: [PATCH] Refactor random.WeightedChooser, unsing generics --- core/external_metadata.go | 8 +-- utils/random/weighted_random_chooser.go | 49 +++++++++--------- utils/random/weighted_random_chooser_test.go | 52 ++++++++++++++------ 3 files changed, 68 insertions(+), 41 deletions(-) diff --git a/core/external_metadata.go b/core/external_metadata.go index 6ff78dfcc..c95f044c9 100644 --- a/core/external_metadata.go +++ b/core/external_metadata.go @@ -267,8 +267,8 @@ func (e *externalMetadata) SimilarSongs(ctx context.Context, id string, count in return nil, ctx.Err() } - weightedSongs := random.NewWeightedRandomChooser() - addArtist := func(a model.Artist, weightedSongs *random.WeightedChooser, count, artistWeight int) error { + weightedSongs := random.NewWeightedChooser[model.MediaFile]() + addArtist := func(a model.Artist, weightedSongs *random.WeightedChooser[model.MediaFile], count, artistWeight int) error { if utils.IsCtxDone(ctx) { log.Warn(ctx, "SimilarSongs call canceled", ctx.Err()) return ctx.Err() @@ -302,12 +302,12 @@ func (e *externalMetadata) SimilarSongs(ctx context.Context, id string, count in var similarSongs model.MediaFiles for len(similarSongs) < count && weightedSongs.Size() > 0 { - s, err := weightedSongs.GetAndRemove() + s, err := weightedSongs.Pick() if err != nil { log.Warn(ctx, "Error getting weighted song", err) continue } - similarSongs = append(similarSongs, s.(model.MediaFile)) + similarSongs = append(similarSongs, s) } return similarSongs, nil diff --git a/utils/random/weighted_random_chooser.go b/utils/random/weighted_random_chooser.go index ca6dbeadb..0ae5c8562 100644 --- a/utils/random/weighted_random_chooser.go +++ b/utils/random/weighted_random_chooser.go @@ -2,42 +2,46 @@ package random import ( "errors" + "slices" ) -type WeightedChooser struct { - entries []interface{} +// WeightedChooser allows to randomly choose an entry based on their weights +// (higher weight = higher chance of being chosen). Based on the subtraction method described in +// https://eli.thegreenplace.net/2010/01/22/weighted-random-generation-in-python/ +type WeightedChooser[T any] struct { + entries []T weights []int totalWeight int } -func NewWeightedRandomChooser() *WeightedChooser { - return &WeightedChooser{} +func NewWeightedChooser[T any]() *WeightedChooser[T] { + return &WeightedChooser[T]{} } -func (w *WeightedChooser) Add(value interface{}, weight int) { +func (w *WeightedChooser[T]) Add(value T, weight int) { w.entries = append(w.entries, value) w.weights = append(w.weights, weight) w.totalWeight += weight } -// GetAndRemove choose a random entry based on their weights, and removes it from the list -func (w *WeightedChooser) GetAndRemove() (interface{}, error) { +// Pick choose a random entry based on their weights, and removes it from the list +func (w *WeightedChooser[T]) Pick() (T, error) { + var empty T if w.totalWeight == 0 { - return nil, errors.New("cannot choose from zero weight") + return empty, errors.New("cannot choose from zero weight") } i, err := w.weightedChoice() if err != nil { - return nil, err + return empty, err } entry := w.entries[i] - w.Remove(i) + _ = w.Remove(i) return entry, nil } -// Based on https://eli.thegreenplace.net/2010/01/22/weighted-random-generation-in-python/ -func (w *WeightedChooser) weightedChoice() (int, error) { - if w.totalWeight == 0 { - return 0, errors.New("no choices available") +func (w *WeightedChooser[T]) weightedChoice() (int, error) { + if len(w.entries) == 0 { + return 0, errors.New("cannot choose from empty list") } rnd := Int64(w.totalWeight) for i, weight := range w.weights { @@ -49,17 +53,18 @@ func (w *WeightedChooser) weightedChoice() (int, error) { return 0, errors.New("internal error - code should not reach this point") } -func (w *WeightedChooser) Remove(i int) { +func (w *WeightedChooser[T]) Remove(i int) error { + if i < 0 || i >= len(w.entries) { + return errors.New("index out of bounds") + } + w.totalWeight -= w.weights[i] - w.weights[i] = w.weights[len(w.weights)-1] - w.weights = w.weights[:len(w.weights)-1] - - w.entries[i] = w.entries[len(w.entries)-1] - w.entries[len(w.entries)-1] = nil - w.entries = w.entries[:len(w.entries)-1] + w.weights = slices.Delete(w.weights, i, i+1) + w.entries = slices.Delete(w.entries, i, i+1) + return nil } -func (w *WeightedChooser) Size() int { +func (w *WeightedChooser[T]) Size() int { return len(w.entries) } diff --git a/utils/random/weighted_random_chooser_test.go b/utils/random/weighted_random_chooser_test.go index c5b3b46ce..026ee92cd 100644 --- a/utils/random/weighted_random_chooser_test.go +++ b/utils/random/weighted_random_chooser_test.go @@ -5,35 +5,57 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("WeightedRandomChooser", func() { - var w *WeightedChooser +var _ = Describe("WeightedChooser", func() { + var w *WeightedChooser[int] BeforeEach(func() { - w = NewWeightedRandomChooser() + w = NewWeightedChooser[int]() for i := 0; i < 10; i++ { - w.Add(i, i) + w.Add(i, i+1) } }) - It("removes a random item", func() { + It("selects and removes a random item", func() { Expect(w.Size()).To(Equal(10)) - _, err := w.GetAndRemove() + _, err := w.Pick() Expect(err).ToNot(HaveOccurred()) Expect(w.Size()).To(Equal(9)) }) + It("removes items", func() { + Expect(w.Size()).To(Equal(10)) + for i := 0; i < 10; i++ { + Expect(w.Remove(0)).To(Succeed()) + } + Expect(w.Size()).To(Equal(0)) + }) + + It("returns error if trying to remove an invalid index", func() { + Expect(w.Size()).To(Equal(10)) + Expect(w.Remove(-1)).ToNot(Succeed()) + Expect(w.Remove(10000)).ToNot(Succeed()) + Expect(w.Size()).To(Equal(10)) + }) + It("returns the sole item", func() { - w = NewWeightedRandomChooser() - w.Add("a", 1) - Expect(w.GetAndRemove()).To(Equal("a")) + ws := NewWeightedChooser[string]() + ws.Add("a", 1) + Expect(ws.Pick()).To(Equal("a")) + }) + + It("returns all items from the list", func() { + for i := 0; i < 10; i++ { + Expect(w.Pick()).To(BeElementOf(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) + } + Expect(w.Size()).To(Equal(0)) }) It("fails when trying to choose from empty set", func() { - w = NewWeightedRandomChooser() - w.Add("a", 1) - w.Add("b", 1) - Expect(w.GetAndRemove()).To(BeElementOf("a", "b")) - Expect(w.GetAndRemove()).To(BeElementOf("a", "b")) - _, err := w.GetAndRemove() + w = NewWeightedChooser[int]() + w.Add(1, 1) + w.Add(2, 1) + Expect(w.Pick()).To(BeElementOf(1, 2)) + Expect(w.Pick()).To(BeElementOf(1, 2)) + _, err := w.Pick() Expect(err).To(HaveOccurred()) })