Replace sync.WaitGroup with more appropriate errgroup.Group

This commit is contained in:
Deluan 2024-05-10 15:27:07 -04:00
parent c0885b55db
commit bb96d455f8

View file

@ -6,7 +6,6 @@ import (
"net/http"
"reflect"
"strings"
"sync"
"time"
"github.com/deluan/sanitize"
@ -15,6 +14,7 @@ import (
"github.com/navidrome/navidrome/server/public"
"github.com/navidrome/navidrome/server/subsonic/responses"
"github.com/navidrome/navidrome/utils/req"
"golang.org/x/sync/errgroup"
)
type searchParams struct {
@ -42,45 +42,39 @@ func (api *Router) getSearchParams(r *http.Request) (*searchParams, error) {
type searchFunc[T any] func(q string, offset int, size int) (T, error)
func callSearch[T any](ctx context.Context, wg *sync.WaitGroup, s searchFunc[T], q string, offset, size int, result *T) {
defer wg.Done()
if size == 0 {
return
}
done := make(chan struct{})
go func() {
func callSearch[T any](ctx context.Context, s searchFunc[T], q string, offset, size int, result *T) func() error {
return func() error {
if size == 0 {
return nil
}
typ := strings.TrimPrefix(reflect.TypeOf(*result).String(), "model.")
var err error
start := time.Now()
*result, err = s(q, offset, size)
if err != nil {
log.Error(ctx, "Error searching "+typ, "query", q, err)
log.Error(ctx, "Error searching "+typ, "query", q, "elapsed", time.Since(start), err)
} else {
log.Trace(ctx, "Search for "+typ+" completed", "query", q, "elapsed", time.Since(start))
}
done <- struct{}{}
}()
select {
case <-done:
case <-ctx.Done():
return nil
}
}
func (api *Router) searchAll(ctx context.Context, sp *searchParams) (mediaFiles model.MediaFiles, albums model.Albums, artists model.Artists) {
start := time.Now()
q := sanitize.Accents(strings.ToLower(strings.TrimSuffix(sp.query, "*")))
wg := &sync.WaitGroup{}
wg.Add(3)
go callSearch(ctx, wg, api.ds.MediaFile(ctx).Search, q, sp.songOffset, sp.songCount, &mediaFiles)
go callSearch(ctx, wg, api.ds.Album(ctx).Search, q, sp.albumOffset, sp.albumCount, &albums)
go callSearch(ctx, wg, api.ds.Artist(ctx).Search, q, sp.artistOffset, sp.artistCount, &artists)
wg.Wait()
if ctx.Err() == nil {
// Run searches in parallel
g, ctx := errgroup.WithContext(ctx)
g.Go(callSearch(ctx, api.ds.MediaFile(ctx).Search, q, sp.songOffset, sp.songCount, &mediaFiles))
g.Go(callSearch(ctx, api.ds.Album(ctx).Search, q, sp.albumOffset, sp.albumCount, &albums))
g.Go(callSearch(ctx, api.ds.Artist(ctx).Search, q, sp.artistOffset, sp.artistCount, &artists))
err := g.Wait()
if err == nil {
log.Debug(ctx, fmt.Sprintf("Search resulted in %d songs, %d albums and %d artists",
len(mediaFiles), len(albums), len(artists)), "query", sp.query, "elapsedTime", time.Since(start))
} else {
log.Warn(ctx, "Search was interrupted", ctx.Err(), "query", sp.query, "elapsedTime", time.Since(start))
log.Warn(ctx, "Search was interrupted", "query", sp.query, "elapsedTime", time.Since(start), err)
}
return mediaFiles, albums, artists
}