fix deadlock on concurrent http3.Server.Serve and Close calls (#3387)

This commit is contained in:
Marten Seemann 2022-04-25 12:10:39 +01:00 committed by GitHub
parent a6a9b2494b
commit 21eda36971
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 49 deletions

View file

@ -160,12 +160,11 @@ type Server struct {
mutex sync.RWMutex mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo listeners map[*quic.EarlyListener]listenerInfo
closed utils.AtomicBool closed bool
altSvcHeader string altSvcHeader string
loggerOnce sync.Once logger utils.Logger
logger utils.Logger
} }
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
@ -203,55 +202,54 @@ func (s *Server) Serve(conn net.PacketConn) error {
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config // Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
// and use it to construct a http3-friendly QUIC listener. // and use it to construct a http3-friendly QUIC listener.
// Closing the server does close the listener. // Closing the server does close the listener.
func (s *Server) ServeListener(listener quic.EarlyListener) error { func (s *Server) ServeListener(ln quic.EarlyListener) error {
return s.serveImpl(func() (quic.EarlyListener, error) { return listener, nil })
}
func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
return s.serveImpl(func() (quic.EarlyListener, error) {
baseConf := ConfigureTLSConfig(tlsConf)
quicConf := s.QuicConfig
if quicConf == nil {
quicConf = &quic.Config{}
} else {
quicConf = s.QuicConfig.Clone()
}
if s.EnableDatagrams {
quicConf.EnableDatagrams = true
}
var ln quic.EarlyListener
var err error
if conn == nil {
ln, err = quicListenAddr(s.Addr, baseConf, quicConf)
} else {
ln, err = quicListen(conn, baseConf, quicConf)
}
if err != nil {
return nil, err
}
return ln, nil
})
}
func (s *Server) serveImpl(startListener func() (quic.EarlyListener, error)) error {
if s.closed.Get() {
return http.ErrServerClosed
}
if s.Server == nil { if s.Server == nil {
return errors.New("use of http3.Server without http.Server") return errors.New("use of http3.Server without http.Server")
} }
s.loggerOnce.Do(func() {
s.logger = utils.DefaultLogger.WithPrefix("server")
})
ln, err := startListener() if err := s.addListener(&ln); err != nil {
return err
}
err := s.serveListener(ln)
s.removeListener(&ln)
return err
}
func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
if s.Server == nil {
return errors.New("use of http3.Server without http.Server")
}
baseConf := ConfigureTLSConfig(tlsConf)
quicConf := s.QuicConfig
if quicConf == nil {
quicConf = &quic.Config{}
} else {
quicConf = s.QuicConfig.Clone()
}
if s.EnableDatagrams {
quicConf.EnableDatagrams = true
}
var ln quic.EarlyListener
var err error
if conn == nil {
ln, err = quicListenAddr(s.Addr, baseConf, quicConf)
} else {
ln, err = quicListen(conn, baseConf, quicConf)
}
if err != nil { if err != nil {
return err return err
} }
s.addListener(&ln) if err := s.addListener(&ln); err != nil {
defer s.removeListener(&ln) return err
}
err = s.serveListener(ln)
s.removeListener(&ln)
return err
}
func (s *Server) serveListener(ln quic.EarlyListener) error {
for { for {
conn, err := ln.Accept(context.Background()) conn, err := ln.Accept(context.Background())
if err != nil { if err != nil {
@ -327,8 +325,16 @@ func (s *Server) generateAltSvcHeader() {
// We store a pointer to interface in the map set. This is safe because we only // We store a pointer to interface in the map set. This is safe because we only
// call trackListener via Serve and can track+defer untrack the same pointer to // call trackListener via Serve and can track+defer untrack the same pointer to
// local variable there. We never need to compare a Listener from another caller. // local variable there. We never need to compare a Listener from another caller.
func (s *Server) addListener(l *quic.EarlyListener) { func (s *Server) addListener(l *quic.EarlyListener) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
return http.ErrServerClosed
}
if s.logger == nil {
s.logger = utils.DefaultLogger.WithPrefix("server")
}
if s.listeners == nil { if s.listeners == nil {
s.listeners = make(map[*quic.EarlyListener]listenerInfo) s.listeners = make(map[*quic.EarlyListener]listenerInfo)
} }
@ -341,8 +347,7 @@ func (s *Server) addListener(l *quic.EarlyListener) {
s.listeners[l] = listenerInfo{} s.listeners[l] = listenerInfo{}
} }
s.generateAltSvcHeader() s.generateAltSvcHeader()
return nil
s.mutex.Unlock()
} }
func (s *Server) removeListener(l *quic.EarlyListener) { func (s *Server) removeListener(l *quic.EarlyListener) {
@ -553,11 +558,11 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) Close() error { func (s *Server) Close() error {
s.closed.Set(true)
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
s.closed = true
var err error var err error
for ln := range s.listeners { for ln := range s.listeners {
if cerr := (*ln).Close(); cerr != nil && err == nil { if cerr := (*ln).Close(); cerr != nil && err == nil {

View file

@ -9,6 +9,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"runtime"
"sync/atomic" "sync/atomic"
"time" "time"
@ -775,6 +776,22 @@ var _ = Describe("Server", func() {
Expect(serv.ListenAndServe()).To(MatchError(http.ErrServerClosed)) Expect(serv.ListenAndServe()).To(MatchError(http.ErrServerClosed))
}) })
It("handles concurrent Serve and Close", func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
c, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
s.Serve(c)
}()
runtime.Gosched()
s.Close()
Eventually(done).Should(BeClosed())
})
Context("ConfigureTLSConfig", func() { Context("ConfigureTLSConfig", func() {
var tlsConf *tls.Config var tlsConf *tls.Config
var ch *tls.ClientHelloInfo var ch *tls.ClientHelloInfo