fix deadlock on concurrent http3.Server.Serve and Close calls (#3387)

This commit is contained in:
Marten Seemann 2022-04-25 12:10:39 +01:00 committed by GitHub
parent a6a9b2494b
commit 21eda36971
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 49 deletions

View file

@ -160,11 +160,10 @@ type Server struct {
mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo
closed utils.AtomicBool
closed bool
altSvcHeader string
loggerOnce sync.Once
logger utils.Logger
}
@ -203,12 +202,24 @@ func (s *Server) Serve(conn net.PacketConn) error {
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
// and use it to construct a http3-friendly QUIC listener.
// Closing the server does close the listener.
func (s *Server) ServeListener(listener quic.EarlyListener) error {
return s.serveImpl(func() (quic.EarlyListener, error) { return listener, nil })
func (s *Server) ServeListener(ln quic.EarlyListener) error {
if s.Server == nil {
return errors.New("use of http3.Server without http.Server")
}
if err := s.addListener(&ln); err != nil {
return err
}
err := s.serveListener(ln)
s.removeListener(&ln)
return err
}
func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
return s.serveImpl(func() (quic.EarlyListener, error) {
if s.Server == nil {
return errors.New("use of http3.Server without http.Server")
}
baseConf := ConfigureTLSConfig(tlsConf)
quicConf := s.QuicConfig
if quicConf == nil {
@ -227,31 +238,18 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
} else {
ln, err = quicListen(conn, baseConf, quicConf)
}
if err != nil {
return nil, err
}
return ln, nil
})
}
func (s *Server) serveImpl(startListener func() (quic.EarlyListener, error)) error {
if s.closed.Get() {
return http.ErrServerClosed
}
if s.Server == nil {
return errors.New("use of http3.Server without http.Server")
}
s.loggerOnce.Do(func() {
s.logger = utils.DefaultLogger.WithPrefix("server")
})
ln, err := startListener()
if err != nil {
return err
}
s.addListener(&ln)
defer s.removeListener(&ln)
if err := s.addListener(&ln); err != nil {
return err
}
err = s.serveListener(ln)
s.removeListener(&ln)
return err
}
func (s *Server) serveListener(ln quic.EarlyListener) error {
for {
conn, err := ln.Accept(context.Background())
if err != nil {
@ -327,8 +325,16 @@ func (s *Server) generateAltSvcHeader() {
// We store a pointer to interface in the map set. This is safe because we only
// call trackListener via Serve and can track+defer untrack the same pointer to
// local variable there. We never need to compare a Listener from another caller.
func (s *Server) addListener(l *quic.EarlyListener) {
func (s *Server) addListener(l *quic.EarlyListener) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
return http.ErrServerClosed
}
if s.logger == nil {
s.logger = utils.DefaultLogger.WithPrefix("server")
}
if s.listeners == nil {
s.listeners = make(map[*quic.EarlyListener]listenerInfo)
}
@ -341,8 +347,7 @@ func (s *Server) addListener(l *quic.EarlyListener) {
s.listeners[l] = listenerInfo{}
}
s.generateAltSvcHeader()
s.mutex.Unlock()
return nil
}
func (s *Server) removeListener(l *quic.EarlyListener) {
@ -553,11 +558,11 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) Close() error {
s.closed.Set(true)
s.mutex.Lock()
defer s.mutex.Unlock()
s.closed = true
var err error
for ln := range s.listeners {
if cerr := (*ln).Close(); cerr != nil && err == nil {

View file

@ -9,6 +9,7 @@ import (
"io"
"net"
"net/http"
"runtime"
"sync/atomic"
"time"
@ -775,6 +776,22 @@ var _ = Describe("Server", func() {
Expect(serv.ListenAndServe()).To(MatchError(http.ErrServerClosed))
})
It("handles concurrent Serve and Close", func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
c, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
s.Serve(c)
}()
runtime.Gosched()
s.Close()
Eventually(done).Should(BeClosed())
})
Context("ConfigureTLSConfig", func() {
var tlsConf *tls.Config
var ch *tls.ClientHelloInfo