mirror of
https://github.com/navidrome/navidrome.git
synced 2025-04-03 20:47:35 +03:00
Fix writeEvents race condition.
This required removing the compress middleware from the /events route.
This commit is contained in:
parent
83ae2ba3e6
commit
1c7fb74a1d
6 changed files with 91 additions and 195 deletions
|
@ -31,16 +31,16 @@ import (
|
|||
func CreateServer(musicFolder string) *server.Server {
|
||||
sqlDB := db.Db()
|
||||
dataStore := persistence.New(sqlDB)
|
||||
serverServer := server.New(dataStore)
|
||||
broker := events.GetBroker()
|
||||
serverServer := server.New(dataStore, broker)
|
||||
return serverServer
|
||||
}
|
||||
|
||||
func CreateNativeAPIRouter() *nativeapi.Router {
|
||||
sqlDB := db.Db()
|
||||
dataStore := persistence.New(sqlDB)
|
||||
broker := events.GetBroker()
|
||||
share := core.NewShare(dataStore)
|
||||
router := nativeapi.New(dataStore, broker, share)
|
||||
router := nativeapi.New(dataStore, share)
|
||||
return router
|
||||
}
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
-s -r "(\.go$$|\.cpp$$|\.h$$|navidrome.toml|resources|token_received.html)" -R "(^ui|^data|^db/migration)" -- go run -tags netgo .
|
||||
-s -r "(\.go$$|\.cpp$$|\.h$$|navidrome.toml|resources|token_received.html)" -R "(^ui|^data|^db/migration)" -- go run -race -tags netgo .
|
||||
|
|
|
@ -3,7 +3,6 @@ package events
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
@ -93,38 +92,35 @@ func (b *broker) prepareMessage(ctx context.Context, event Event) message {
|
|||
return msg
|
||||
}
|
||||
|
||||
var errWriteTimeOut = errors.New("write timeout")
|
||||
|
||||
// writeEvent writes a message to the given io.Writer, formatted as a Server-Sent Event.
|
||||
// If the writer is an http.Flusher, it flushes the data immediately instead of buffering it.
|
||||
// The function waits for the message to be written or times out after the specified timeout.
|
||||
func writeEvent(w io.Writer, event message, timeout time.Duration) error {
|
||||
// Create a context with a timeout based on the event's sender context.
|
||||
ctx, cancel := context.WithTimeout(event.senderCtx, timeout)
|
||||
defer cancel()
|
||||
func writeEvent(ctx context.Context, w io.Writer, event message, timeout time.Duration) error {
|
||||
if err := setWriteTimeout(w, timeout); err != nil {
|
||||
log.Debug(ctx, "Error setting write timeout", err)
|
||||
}
|
||||
|
||||
// Create a channel to signal the completion of writing.
|
||||
errC := make(chan error, 1)
|
||||
|
||||
// Start a goroutine to write the event and optionally flush the writer.
|
||||
go func() {
|
||||
_, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data)
|
||||
|
||||
// If the writer is an http.Flusher, flush the data immediately.
|
||||
if flusher, ok := w.(http.Flusher); ok && flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// Signal that writing is complete.
|
||||
errC <- err
|
||||
}()
|
||||
|
||||
// Wait for either the write completion or the context to time out.
|
||||
select {
|
||||
case err := <-errC:
|
||||
_, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data)
|
||||
if err != nil {
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return errWriteTimeOut
|
||||
}
|
||||
|
||||
// If the writer is an http.Flusher, flush the data immediately.
|
||||
if flusher, ok := w.(http.Flusher); ok && flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setWriteTimeout(rw io.Writer, timeout time.Duration) error {
|
||||
for {
|
||||
switch t := rw.(type) {
|
||||
case interface{ SetWriteDeadline(time.Time) error }:
|
||||
return t.SetWriteDeadline(time.Now().Add(timeout))
|
||||
case interface{ Unwrap() http.ResponseWriter }:
|
||||
rw = t.Unwrap()
|
||||
default:
|
||||
return fmt.Errorf("%T - %w", rw, http.ErrNotSupported)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -160,9 +156,9 @@ func (b *broker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
log.Trace(ctx, "Sending event to client", "event", *event, "client", c.String())
|
||||
if err := writeEvent(w, *event, writeTimeOut); errors.Is(err, errWriteTimeOut) {
|
||||
log.Debug(ctx, "Timeout sending event to client", "event", *event, "client", c.String())
|
||||
return
|
||||
err := writeEvent(ctx, w, *event, writeTimeOut)
|
||||
if err != nil {
|
||||
log.Debug(ctx, "Error sending event to client", "event", *event, "client", c.String(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,7 @@
|
|||
package events
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -63,126 +58,4 @@ var _ = Describe("Broker", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("writeEvent", func() {
|
||||
var (
|
||||
timeout time.Duration
|
||||
buffer *bytes.Buffer
|
||||
event message
|
||||
senderCtx context.Context
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
buffer = &bytes.Buffer{}
|
||||
senderCtx, cancel = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancel)
|
||||
})
|
||||
|
||||
Context("with an HTTP flusher", func() {
|
||||
var flusher *fakeFlusher
|
||||
|
||||
BeforeEach(func() {
|
||||
flusher = &fakeFlusher{Writer: buffer}
|
||||
event = message{
|
||||
senderCtx: senderCtx,
|
||||
id: 1,
|
||||
event: "test",
|
||||
data: "testdata",
|
||||
}
|
||||
})
|
||||
|
||||
Context("when the write completes before the timeout", func() {
|
||||
BeforeEach(func() {
|
||||
timeout = 1 * time.Second
|
||||
})
|
||||
It("should successfully write the event", func() {
|
||||
err := writeEvent(flusher, event, timeout)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(buffer.String()).To(Equal(fmt.Sprintf("id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data)))
|
||||
Expect(flusher.flushed.Load()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when the write does not complete before the timeout", func() {
|
||||
BeforeEach(func() {
|
||||
timeout = 1 * time.Millisecond
|
||||
flusher.delay = 2 * time.Second
|
||||
})
|
||||
|
||||
It("should return an errWriteTimeOut error", func() {
|
||||
err := writeEvent(flusher, event, timeout)
|
||||
Expect(err).To(MatchError(errWriteTimeOut))
|
||||
Expect(flusher.flushed.Load()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("without an HTTP flusher", func() {
|
||||
var writer *fakeWriter
|
||||
|
||||
BeforeEach(func() {
|
||||
writer = &fakeWriter{Writer: buffer}
|
||||
event = message{
|
||||
senderCtx: senderCtx,
|
||||
id: 1,
|
||||
event: "test",
|
||||
data: "testdata",
|
||||
}
|
||||
})
|
||||
|
||||
Context("when the write completes before the timeout", func() {
|
||||
BeforeEach(func() {
|
||||
timeout = 1 * time.Second
|
||||
})
|
||||
|
||||
It("should successfully write the event", func() {
|
||||
err := writeEvent(writer, event, timeout)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(writer.done.Load).Should(BeTrue())
|
||||
Expect(buffer.String()).To(Equal(fmt.Sprintf("id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when the write does not complete before the timeout", func() {
|
||||
BeforeEach(func() {
|
||||
timeout = 1 * time.Millisecond
|
||||
writer.delay = 2 * time.Second
|
||||
})
|
||||
|
||||
It("should return an errWriteTimeOut error", func() {
|
||||
err := writeEvent(writer, event, timeout)
|
||||
Expect(err).To(MatchError(errWriteTimeOut))
|
||||
Expect(writer.done.Load()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
type fakeWriter struct {
|
||||
io.Writer
|
||||
delay time.Duration
|
||||
done atomic.Bool
|
||||
}
|
||||
|
||||
func (f *fakeWriter) Write(p []byte) (n int, err error) {
|
||||
time.Sleep(f.delay)
|
||||
f.done.Store(true)
|
||||
return f.Writer.Write(p)
|
||||
}
|
||||
|
||||
type fakeFlusher struct {
|
||||
io.Writer
|
||||
delay time.Duration
|
||||
flushed atomic.Bool
|
||||
}
|
||||
|
||||
func (f *fakeFlusher) Write(p []byte) (n int, err error) {
|
||||
time.Sleep(f.delay)
|
||||
return f.Writer.Write(p)
|
||||
}
|
||||
|
||||
func (f *fakeFlusher) Flush() {
|
||||
f.flushed.Store(true)
|
||||
}
|
||||
|
|
|
@ -10,18 +10,16 @@ import (
|
|||
"github.com/navidrome/navidrome/core"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/server"
|
||||
"github.com/navidrome/navidrome/server/events"
|
||||
)
|
||||
|
||||
type Router struct {
|
||||
http.Handler
|
||||
ds model.DataStore
|
||||
broker events.Broker
|
||||
share core.Share
|
||||
ds model.DataStore
|
||||
share core.Share
|
||||
}
|
||||
|
||||
func New(ds model.DataStore, broker events.Broker, share core.Share) *Router {
|
||||
r := &Router{ds: ds, broker: broker, share: share}
|
||||
func New(ds model.DataStore, share core.Share) *Router {
|
||||
r := &Router{ds: ds, share: share}
|
||||
r.Handler = r.routes()
|
||||
return r
|
||||
}
|
||||
|
@ -55,10 +53,6 @@ func (n *Router) routes() http.Handler {
|
|||
r.Get("/keepalive/*", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"response":"ok", "id":"keepalive"}`))
|
||||
})
|
||||
|
||||
if conf.Server.DevActivityPanel {
|
||||
r.Handle("/events", n.broker)
|
||||
}
|
||||
})
|
||||
|
||||
return r
|
||||
|
|
|
@ -19,21 +19,25 @@ import (
|
|||
"github.com/navidrome/navidrome/core/auth"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/server/events"
|
||||
"github.com/navidrome/navidrome/ui"
|
||||
. "github.com/navidrome/navidrome/utils/gg"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
router *chi.Mux
|
||||
router chi.Router
|
||||
ds model.DataStore
|
||||
appRoot string
|
||||
broker events.Broker
|
||||
}
|
||||
|
||||
func New(ds model.DataStore) *Server {
|
||||
s := &Server{ds: ds}
|
||||
func New(ds model.DataStore, broker events.Broker) *Server {
|
||||
s := &Server{ds: ds, broker: broker}
|
||||
auth.Init(s.ds)
|
||||
initialSetup(ds)
|
||||
s.initRoutes()
|
||||
s.mountAuthenticationRoutes()
|
||||
s.mountRootRedirector()
|
||||
checkFfmpegInstallation()
|
||||
checkExternalCredentials()
|
||||
return s
|
||||
|
@ -131,24 +135,52 @@ func (s *Server) initRoutes() {
|
|||
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Use(secureMiddleware())
|
||||
r.Use(corsHandler())
|
||||
r.Use(middleware.RequestID)
|
||||
if conf.Server.ReverseProxyWhitelist == "" {
|
||||
r.Use(middleware.RealIP)
|
||||
middlewares := chi.Middlewares{
|
||||
secureMiddleware(),
|
||||
corsHandler(),
|
||||
middleware.RequestID,
|
||||
}
|
||||
if conf.Server.ReverseProxyWhitelist == "" {
|
||||
middlewares = append(middlewares, middleware.RealIP)
|
||||
}
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(compressMiddleware())
|
||||
r.Use(middleware.Heartbeat("/ping"))
|
||||
r.Use(serverAddressMiddleware)
|
||||
r.Use(clientUniqueIDMiddleware)
|
||||
r.Use(loggerInjector)
|
||||
r.Use(requestLogger)
|
||||
r.Use(robotsTXT(ui.BuildAssets()))
|
||||
r.Use(authHeaderMapper)
|
||||
r.Use(jwtVerifier)
|
||||
|
||||
r.Route(path.Join(conf.Server.BasePath, "/auth"), func(r chi.Router) {
|
||||
middlewares = append(middlewares,
|
||||
middleware.Recoverer,
|
||||
middleware.Heartbeat("/ping"),
|
||||
robotsTXT(ui.BuildAssets()),
|
||||
serverAddressMiddleware,
|
||||
clientUniqueIDMiddleware,
|
||||
)
|
||||
|
||||
// Mount the Native API /events endpoint with all middlewares, except the compress and request logger,
|
||||
// adding the authentication middlewares
|
||||
if conf.Server.DevActivityPanel {
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(middlewares...)
|
||||
r.Use(loggerInjector)
|
||||
r.Use(authHeaderMapper)
|
||||
r.Use(jwtVerifier)
|
||||
r.Use(Authenticator(s.ds))
|
||||
r.Use(JWTRefresher)
|
||||
r.Handle(path.Join(conf.Server.BasePath, consts.URLPathNativeAPI, "events"), s.broker)
|
||||
})
|
||||
}
|
||||
|
||||
// Configure the router with the default middlewares
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(middlewares...)
|
||||
r.Use(compressMiddleware())
|
||||
r.Use(loggerInjector)
|
||||
r.Use(requestLogger)
|
||||
r.Use(authHeaderMapper)
|
||||
r.Use(jwtVerifier)
|
||||
s.router = r
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) mountAuthenticationRoutes() chi.Router {
|
||||
r := s.router
|
||||
return r.Route(path.Join(conf.Server.BasePath, "/auth"), func(r chi.Router) {
|
||||
if conf.Server.AuthRequestLimit > 0 {
|
||||
log.Info("Login rate limit set", "requestLimit", conf.Server.AuthRequestLimit,
|
||||
"windowLength", conf.Server.AuthWindowLength)
|
||||
|
@ -162,7 +194,11 @@ func (s *Server) initRoutes() {
|
|||
}
|
||||
r.Post("/createAdmin", createAdmin(s.ds))
|
||||
})
|
||||
}
|
||||
|
||||
// Serve UI app assets
|
||||
func (s *Server) mountRootRedirector() {
|
||||
r := s.router
|
||||
// Redirect root to UI URL
|
||||
r.Get("/*", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, s.appRoot+"/", http.StatusFound)
|
||||
|
@ -170,11 +206,8 @@ func (s *Server) initRoutes() {
|
|||
r.Get(s.appRoot, func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, s.appRoot+"/", http.StatusFound)
|
||||
})
|
||||
|
||||
s.router = r
|
||||
}
|
||||
|
||||
// Serve UI app assets
|
||||
func (s *Server) frontendAssetsHandler() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue