don't close established connections on Listener.Close, when using a Transport (#4072)

* don't close established connections on Listener.Close

* only close once
This commit is contained in:
Marten Seemann 2023-10-27 13:10:13 +07:00 committed by GitHub
parent ef800d6f71
commit dda63b90eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 136 additions and 151 deletions

View file

@ -14,7 +14,7 @@ import (
)
var _ = Describe("Stream deadline tests", func() {
setup := func() (*quic.Listener, quic.Stream, quic.Stream) {
setup := func() (serverStr, clientStr quic.Stream, close func()) {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
strChan := make(chan quic.SendStream)
@ -36,19 +36,21 @@ var _ = Describe("Stream deadline tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
clientStr, err := conn.OpenStream()
clientStr, err = conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
Expect(err).ToNot(HaveOccurred())
var serverStr quic.Stream
Eventually(strChan).Should(Receive(&serverStr))
return server, serverStr, clientStr
return serverStr, clientStr, func() {
Expect(server.Close()).To(Succeed())
Expect(conn.CloseWithError(0, "")).To(Succeed())
}
}
Context("read deadlines", func() {
It("completes a transfer when the deadline is set", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()
const timeout = time.Millisecond
done := make(chan struct{})
@ -82,8 +84,8 @@ var _ = Describe("Stream deadline tests", func() {
})
It("completes a transfer when the deadline is set concurrently", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()
const timeout = time.Millisecond
go func() {
@ -132,8 +134,8 @@ var _ = Describe("Stream deadline tests", func() {
Context("write deadlines", func() {
It("completes a transfer when the deadline is set", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()
const timeout = time.Millisecond
done := make(chan struct{})
@ -165,8 +167,8 @@ var _ = Describe("Stream deadline tests", func() {
})
It("completes a transfer when the deadline is set concurrently", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()
const timeout = time.Millisecond
readDone := make(chan struct{})

View file

@ -152,13 +152,14 @@ var _ = Describe("Handshake tests", func() {
Context("Certificate validation", func() {
It("accepts the certificate", func() {
runServer(getTLSConfig())
_, err := quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
})
It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
@ -187,6 +188,7 @@ var _ = Describe("Handshake tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
Eventually(done).Should(BeClosed())
Expect(server.Addr()).To(Equal(local))
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
@ -196,13 +198,14 @@ var _ = Describe("Handshake tests", func() {
It("works with a long certificate chain", func() {
runServer(getTLSConfigWithLongCertChain())
_, err := quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
})
It("errors if the server name doesn't match", func() {

View file

@ -52,22 +52,23 @@ var _ = Describe("TLS session resumption", func() {
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn1.CloseWithError(0, "")
var sessionKey string
Eventually(puts).Should(Receive(&sessionKey))
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
serverConn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
conn, err = quic.DialAddr(
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
@ -75,11 +76,12 @@ var _ = Describe("TLS session resumption", func() {
)
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive(Equal(sessionKey)))
Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())
serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
conn2.CloseWithError(0, "")
})
It("doesn't use session resumption, if the config disables it", func() {
@ -94,15 +96,16 @@ var _ = Describe("TLS session resumption", func() {
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn1.CloseWithError(0, "")
Consistently(puts).ShouldNot(Receive())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
@ -110,14 +113,15 @@ var _ = Describe("TLS session resumption", func() {
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
conn, err = quic.DialAddr(
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn2.CloseWithError(0, "")
serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
@ -142,7 +146,7 @@ var _ = Describe("TLS session resumption", func() {
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
@ -150,7 +154,8 @@ var _ = Describe("TLS session resumption", func() {
)
Expect(err).ToNot(HaveOccurred())
Consistently(puts).ShouldNot(Receive())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn1.CloseWithError(0, "")
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
@ -158,14 +163,15 @@ var _ = Describe("TLS session resumption", func() {
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
conn, err = quic.DialAddr(
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn2.CloseWithError(0, "")
serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())

View file

@ -185,11 +185,13 @@ var _ = Describe("Timeout tests", func() {
Expect(err).ToNot(HaveOccurred())
defer server.Close()
serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
conn.AcceptStream(context.Background()) // blocks until the connection is closed
close(serverConnClosed)
}()
@ -240,7 +242,7 @@ var _ = Describe("Timeout tests", func() {
Consistently(serverConnClosed).ShouldNot(BeClosed())
// make the go routine return
Expect(server.Close()).To(Succeed())
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})
@ -266,11 +268,13 @@ var _ = Describe("Timeout tests", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
<-conn.Context().Done() // block until the connection is closed
close(serverConnClosed)
}()
@ -309,7 +313,7 @@ var _ = Describe("Timeout tests", func() {
Consistently(serverConnClosed).ShouldNot(BeClosed())
// make the go routine return
Expect(server.Close()).To(Succeed())
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})
})
@ -325,11 +329,13 @@ var _ = Describe("Timeout tests", func() {
Expect(err).ToNot(HaveOccurred())
defer server.Close()
serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
conn.AcceptStream(context.Background()) // blocks until the connection is closed
close(serverConnClosed)
}()
@ -370,7 +376,7 @@ var _ = Describe("Timeout tests", func() {
_, err = str.Write([]byte("foobar"))
checkTimeoutError(err)
Expect(server.Close()).To(Succeed())
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})

View file

@ -142,5 +142,6 @@ var _ = Describe("Unidirectional Streams", func() {
runReceivingPeer(client)
<-done1
<-done2
client.CloseWithError(0, "")
})
})

View file

@ -186,42 +186,6 @@ func (c *PacketHandlerManagerCloseCall) DoAndReturn(f func(error)) *PacketHandle
return c
}
// CloseServer mocks base method.
func (m *MockPacketHandlerManager) CloseServer() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "CloseServer")
}
// CloseServer indicates an expected call of CloseServer.
func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *PacketHandlerManagerCloseServerCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
return &PacketHandlerManagerCloseServerCall{Call: call}
}
// PacketHandlerManagerCloseServerCall wrap *gomock.Call
type PacketHandlerManagerCloseServerCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *PacketHandlerManagerCloseServerCall) Return() *PacketHandlerManagerCloseServerCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *PacketHandlerManagerCloseServerCall) Do(f func()) *PacketHandlerManagerCloseServerCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *PacketHandlerManagerCloseServerCall) DoAndReturn(f func()) *PacketHandlerManagerCloseServerCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Get mocks base method.
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
m.ctrl.T.Helper()

View file

@ -220,23 +220,6 @@ func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (
return handler, ok
}
func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
var wg sync.WaitGroup
for _, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(handler packetHandler) {
// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
handler.shutdown()
wg.Done()
}(handler)
}
}
h.mutex.Unlock()
wg.Wait()
}
func (h *packetHandlerMap) Close(e error) {
h.mutex.Lock()

View file

@ -159,23 +159,6 @@ var _ = Describe("Packet Handler Map", func() {
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})
It("closes the server", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
for i := 0; i < 10; i++ {
conn := NewMockPacketHandler(mockCtrl)
if i%2 == 0 {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
} else {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
conn.EXPECT().shutdown()
}
b := make([]byte, 12)
rand.Read(b)
m.Add(protocol.ParseConnectionID(b), conn)
}
m.CloseServer()
})
It("closes", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
testErr := errors.New("shutdown")

View file

@ -34,7 +34,6 @@ type packetHandlerManager interface {
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
Close(error)
CloseServer()
connRunner
}
@ -61,8 +60,6 @@ type rejectedPacket struct {
// A Listener of QUIC
type baseServer struct {
mutex sync.Mutex
disableVersionNegotiation bool
acceptEarlyConns bool
@ -104,10 +101,11 @@ type baseServer struct {
protocol.VersionNumber,
) quicConn
serverError error
errorChan chan struct{}
closed bool
running chan struct{} // closed as soon as run() returns
closeOnce sync.Once
errorChan chan struct{} // is closed when the server is closed
closeErr error
running chan struct{} // closed as soon as run() returns
versionNegotiationQueue chan receivedPacket
invalidTokenQueue chan rejectedPacket
connectionRefusedQueue chan rejectedPacket
@ -132,7 +130,10 @@ func (l *Listener) Accept(ctx context.Context) (Connection, error) {
return l.baseServer.Accept(ctx)
}
// Close the server. All active connections will be closed.
// Close closes the listener.
// Accept will return ErrServerClosed as soon as all connections in the accept queue have been accepted.
// QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error.
// Closing the listener doesn't have any effect on already established connections.
func (l *Listener) Close() error {
return l.baseServer.Close()
}
@ -321,38 +322,25 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
atomic.AddInt32(&s.connQueueLen, -1)
return conn, nil
case <-s.errorChan:
return nil, s.serverError
return nil, s.closeErr
}
}
// Close the server
func (s *baseServer) Close() error {
s.mutex.Lock()
if s.closed {
s.mutex.Unlock()
return nil
}
if s.serverError == nil {
s.serverError = ErrServerClosed
}
s.closed = true
close(s.errorChan)
s.mutex.Unlock()
<-s.running
s.onClose()
s.close(ErrServerClosed, true)
return nil
}
func (s *baseServer) setCloseError(e error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
return
}
s.closed = true
s.serverError = e
close(s.errorChan)
func (s *baseServer) close(e error, notifyOnClose bool) {
s.closeOnce.Do(func() {
s.closeErr = e
close(s.errorChan)
<-s.running
if notifyOnClose {
s.onClose()
}
})
}
// Addr returns the server's network address
@ -701,8 +689,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
func (s *baseServer) handleNewConn(conn quicConn) {
connCtx := conn.Context()
if s.acceptEarlyConns {
// wait until the early connection is ready (or the handshake fails)
// wait until the early connection is ready, the handshake fails, or the server is closed
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
return
case <-conn.earlyConnReady():
case <-connCtx.Done():
return
@ -710,6 +701,9 @@ func (s *baseServer) handleNewConn(conn quicConn) {
} else {
// wait until the handshake is complete (or fails)
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
return
case <-conn.HandshakeComplete():
case <-connCtx.Done():
return

View file

@ -326,6 +326,8 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().destroy(gomock.Any())
})
It("sends a Version Negotiation Packet for unsupported versions", func() {
@ -527,6 +529,8 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().destroy(gomock.Any())
})
It("drops packets if the receive queue is full", func() {
@ -565,6 +569,8 @@ var _ = Describe("Server", func() {
conn.EXPECT().run().MaxTimes(1)
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
// shutdown
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
return conn
}
@ -956,30 +962,69 @@ var _ = Describe("Server", func() {
})
Context("accepting connections", func() {
It("returns Accept when an error occurs", func() {
testErr := errors.New("test err")
It("returns Accept when closed", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(ErrServerClosed))
close(done)
}()
serv.setCloseError(testErr)
serv.Close()
Eventually(done).Should(BeClosed())
serv.onClose() // shutdown
})
It("returns immediately, if an error occurred before", func() {
testErr := errors.New("test err")
serv.setCloseError(testErr)
serv.Close()
for i := 0; i < 3; i++ {
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(ErrServerClosed))
}
serv.onClose() // shutdown
})
It("closes connection that are still handshaking after Close", func() {
serv.Close()
destroyed := make(chan struct{})
serv.newConn = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
conf *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}).Do(func(error) { close(destroyed) })
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().run().MaxTimes(1)
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
)
Eventually(destroyed).Should(BeClosed())
})
It("returns when the context is canceled", func() {
@ -1343,10 +1388,7 @@ var _ = Describe("Server", func() {
serv.connHandler = phm
})
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
tr.Close()
})
AfterEach(func() { tr.Close() })
It("passes packets to existing connections", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
@ -1425,6 +1467,8 @@ var _ = Describe("Server", func() {
conn.EXPECT().earlyConnReady()
conn.EXPECT().Context().Return(context.Background())
close(called)
// shutdown
conn.EXPECT().destroy(gomock.Any())
return conn
}

View file

@ -275,7 +275,8 @@ func (t *Transport) runSendQueue() {
}
}
// Close closes the underlying connection and waits until listen has returned.
// Close closes the underlying connection.
// If any listener was started, it will be closed as well.
// It is invalid to start new listeners or connections after that.
func (t *Transport) Close() error {
t.close(errors.New("closing"))
@ -294,7 +295,6 @@ func (t *Transport) Close() error {
}
func (t *Transport) closeServer() {
t.handlerMap.CloseServer()
t.mutex.Lock()
t.server = nil
if t.isSingleUse {
@ -322,7 +322,7 @@ func (t *Transport) close(e error) {
t.handlerMap.Close(e)
}
if t.server != nil {
t.server.setCloseError(e)
t.server.close(e, false)
}
t.closed = true
}

View file

@ -114,7 +114,6 @@ var _ = Describe("Transport", func() {
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
phm.EXPECT().CloseServer()
Expect(ln.Close()).To(Succeed())
// shutdown