diff --git a/utils/singleton/singleton.go b/utils/singleton/singleton.go index e1202e19c..f18fa3f14 100644 --- a/utils/singleton/singleton.go +++ b/utils/singleton/singleton.go @@ -3,49 +3,42 @@ package singleton import ( "fmt" "reflect" + "sync" "github.com/navidrome/navidrome/log" ) var ( - instances = make(map[string]any) - getOrCreateC = make(chan entry) + instances = make(map[string]any) + lock sync.RWMutex ) -type entry struct { - f func() any - object any - resultC chan any -} - // GetInstance returns an existing instance of object. If it is not yet created, calls `constructor`, stores the // result for future calls and return it func GetInstance[T any](constructor func() T) T { - var t T - e := entry{ - object: t, - f: func() any { - return constructor() - }, - resultC: make(chan any), - } - getOrCreateC <- e - v := <-e.resultC - return v.(T) -} + var v T + name := reflect.TypeOf(v).String() -func init() { - go func() { - for { - e := <-getOrCreateC - name := reflect.TypeOf(e.object).String() - v, created := instances[name] - if !created { - v = e.f() - log.Trace("Created new singleton", "type", name, "instance", fmt.Sprintf("%+v", v)) - instances[name] = v - } - e.resultC <- v - } + v, available := func() (T, bool) { + lock.RLock() + defer lock.RUnlock() + v, available := instances[name].(T) + return v, available }() + + if available { + return v + } + + lock.Lock() + defer lock.Unlock() + v, available = instances[name].(T) + if available { + return v + } + + v = constructor() + log.Trace("Created new singleton", "type", name, "instance", fmt.Sprintf("%+v", v)) + instances[name] = v + return v } diff --git a/utils/singleton/singleton_test.go b/utils/singleton/singleton_test.go index b187731a9..a49bb64eb 100644 --- a/utils/singleton/singleton_test.go +++ b/utils/singleton/singleton_test.go @@ -19,15 +19,15 @@ func TestSingleton(t *testing.T) { var _ = Describe("GetInstance", func() { type T struct{ id string } - var numInstances int + var numInstancesCreated int constructor := func() *T { - numInstances++ + numInstancesCreated++ return &T{id: uuid.NewString()} } It("calls the constructor to create a new instance", func() { instance := singleton.GetInstance(constructor) - Expect(numInstances).To(Equal(1)) + Expect(numInstancesCreated).To(Equal(1)) Expect(instance).To(BeAssignableToTypeOf(&T{})) }) @@ -36,24 +36,24 @@ var _ = Describe("GetInstance", func() { newInstance := singleton.GetInstance(constructor) Expect(newInstance.id).To(Equal(instance.id)) - Expect(numInstances).To(Equal(1)) + Expect(numInstancesCreated).To(Equal(1)) }) It("makes a distinction between a type and its pointer", func() { instance := singleton.GetInstance(constructor) newInstance := singleton.GetInstance(func() T { - numInstances++ + numInstancesCreated++ return T{id: uuid.NewString()} }) Expect(instance).To(BeAssignableToTypeOf(&T{})) Expect(newInstance).To(BeAssignableToTypeOf(T{})) Expect(newInstance.id).ToNot(Equal(instance.id)) - Expect(numInstances).To(Equal(2)) + Expect(numInstancesCreated).To(Equal(2)) }) It("only calls the constructor once when called concurrently", func() { - const maxCalls = 8000 + const maxCalls = 80000 var numCalls int32 start := sync.WaitGroup{} start.Add(1) @@ -61,12 +61,12 @@ var _ = Describe("GetInstance", func() { prepare.Add(maxCalls) done := sync.WaitGroup{} done.Add(maxCalls) - numInstances = 0 + numInstancesCreated = 0 for i := 0; i < maxCalls; i++ { go func() { start.Wait() singleton.GetInstance(func() struct{ I int } { - numInstances++ + numInstancesCreated++ return struct{ I int }{I: 1} }) atomic.AddInt32(&numCalls, 1) @@ -79,6 +79,6 @@ var _ = Describe("GetInstance", func() { done.Wait() Expect(numCalls).To(Equal(int32(maxCalls))) - Expect(numInstances).To(Equal(1)) + Expect(numInstancesCreated).To(Equal(1)) }) })