mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 04:37:36 +03:00
don't close established connections on Listener.Close, when using a Transport (#4072)
* don't close established connections on Listener.Close * only close once
This commit is contained in:
parent
ef800d6f71
commit
dda63b90eb
12 changed files with 136 additions and 151 deletions
|
@ -14,7 +14,7 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("Stream deadline tests", func() {
|
||||
setup := func() (*quic.Listener, quic.Stream, quic.Stream) {
|
||||
setup := func() (serverStr, clientStr quic.Stream, close func()) {
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strChan := make(chan quic.SendStream)
|
||||
|
@ -36,19 +36,21 @@ var _ = Describe("Stream deadline tests", func() {
|
|||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientStr, err := conn.OpenStream()
|
||||
clientStr, err = conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var serverStr quic.Stream
|
||||
Eventually(strChan).Should(Receive(&serverStr))
|
||||
return server, serverStr, clientStr
|
||||
return serverStr, clientStr, func() {
|
||||
Expect(server.Close()).To(Succeed())
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
}
|
||||
}
|
||||
|
||||
Context("read deadlines", func() {
|
||||
It("completes a transfer when the deadline is set", func() {
|
||||
server, serverStr, clientStr := setup()
|
||||
defer server.Close()
|
||||
serverStr, clientStr, closeFn := setup()
|
||||
defer closeFn()
|
||||
|
||||
const timeout = time.Millisecond
|
||||
done := make(chan struct{})
|
||||
|
@ -82,8 +84,8 @@ var _ = Describe("Stream deadline tests", func() {
|
|||
})
|
||||
|
||||
It("completes a transfer when the deadline is set concurrently", func() {
|
||||
server, serverStr, clientStr := setup()
|
||||
defer server.Close()
|
||||
serverStr, clientStr, closeFn := setup()
|
||||
defer closeFn()
|
||||
|
||||
const timeout = time.Millisecond
|
||||
go func() {
|
||||
|
@ -132,8 +134,8 @@ var _ = Describe("Stream deadline tests", func() {
|
|||
|
||||
Context("write deadlines", func() {
|
||||
It("completes a transfer when the deadline is set", func() {
|
||||
server, serverStr, clientStr := setup()
|
||||
defer server.Close()
|
||||
serverStr, clientStr, closeFn := setup()
|
||||
defer closeFn()
|
||||
|
||||
const timeout = time.Millisecond
|
||||
done := make(chan struct{})
|
||||
|
@ -165,8 +167,8 @@ var _ = Describe("Stream deadline tests", func() {
|
|||
})
|
||||
|
||||
It("completes a transfer when the deadline is set concurrently", func() {
|
||||
server, serverStr, clientStr := setup()
|
||||
defer server.Close()
|
||||
serverStr, clientStr, closeFn := setup()
|
||||
defer closeFn()
|
||||
|
||||
const timeout = time.Millisecond
|
||||
readDone := make(chan struct{})
|
||||
|
|
|
@ -152,13 +152,14 @@ var _ = Describe("Handshake tests", func() {
|
|||
Context("Certificate validation", func() {
|
||||
It("accepts the certificate", func() {
|
||||
runServer(getTLSConfig())
|
||||
_, err := quic.DialAddr(
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.CloseWithError(0, "")
|
||||
})
|
||||
|
||||
It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
|
||||
|
@ -187,6 +188,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(server.Addr()).To(Equal(local))
|
||||
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
|
||||
|
@ -196,13 +198,14 @@ var _ = Describe("Handshake tests", func() {
|
|||
|
||||
It("works with a long certificate chain", func() {
|
||||
runServer(getTLSConfigWithLongCertChain())
|
||||
_, err := quic.DialAddr(
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.CloseWithError(0, "")
|
||||
})
|
||||
|
||||
It("errors if the server name doesn't match", func() {
|
||||
|
|
|
@ -52,22 +52,23 @@ var _ = Describe("TLS session resumption", func() {
|
|||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ClientSessionCache = cache
|
||||
conn, err := quic.DialAddr(
|
||||
conn1, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.CloseWithError(0, "")
|
||||
var sessionKey string
|
||||
Eventually(puts).Should(Receive(&sessionKey))
|
||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
serverConn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
conn, err = quic.DialAddr(
|
||||
conn2, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
|
@ -75,11 +76,12 @@ var _ = Describe("TLS session resumption", func() {
|
|||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(gets).To(Receive(Equal(sessionKey)))
|
||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue())
|
||||
Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())
|
||||
|
||||
serverConn, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
|
||||
conn2.CloseWithError(0, "")
|
||||
})
|
||||
|
||||
It("doesn't use session resumption, if the config disables it", func() {
|
||||
|
@ -94,15 +96,16 @@ var _ = Describe("TLS session resumption", func() {
|
|||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ClientSessionCache = cache
|
||||
conn, err := quic.DialAddr(
|
||||
conn1, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.CloseWithError(0, "")
|
||||
Consistently(puts).ShouldNot(Receive())
|
||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
@ -110,14 +113,15 @@ var _ = Describe("TLS session resumption", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
conn, err = quic.DialAddr(
|
||||
conn2, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
defer conn2.CloseWithError(0, "")
|
||||
|
||||
serverConn, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -142,7 +146,7 @@ var _ = Describe("TLS session resumption", func() {
|
|||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ClientSessionCache = cache
|
||||
conn, err := quic.DialAddr(
|
||||
conn1, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
|
@ -150,7 +154,8 @@ var _ = Describe("TLS session resumption", func() {
|
|||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Consistently(puts).ShouldNot(Receive())
|
||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
defer conn1.CloseWithError(0, "")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
@ -158,14 +163,15 @@ var _ = Describe("TLS session resumption", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
conn, err = quic.DialAddr(
|
||||
conn2, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
defer conn2.CloseWithError(0, "")
|
||||
|
||||
serverConn, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -185,11 +185,13 @@ var _ = Describe("Timeout tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
serverConnChan := make(chan quic.Connection, 1)
|
||||
serverConnClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConnChan <- conn
|
||||
conn.AcceptStream(context.Background()) // blocks until the connection is closed
|
||||
close(serverConnClosed)
|
||||
}()
|
||||
|
@ -240,7 +242,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
Consistently(serverConnClosed).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
Expect(server.Close()).To(Succeed())
|
||||
(<-serverConnChan).CloseWithError(0, "")
|
||||
Eventually(serverConnClosed).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -266,11 +268,13 @@ var _ = Describe("Timeout tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
serverConnChan := make(chan quic.Connection, 1)
|
||||
serverConnClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConnChan <- conn
|
||||
<-conn.Context().Done() // block until the connection is closed
|
||||
close(serverConnClosed)
|
||||
}()
|
||||
|
@ -309,7 +313,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
Consistently(serverConnClosed).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
Expect(server.Close()).To(Succeed())
|
||||
(<-serverConnChan).CloseWithError(0, "")
|
||||
Eventually(serverConnClosed).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
@ -325,11 +329,13 @@ var _ = Describe("Timeout tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
serverConnChan := make(chan quic.Connection, 1)
|
||||
serverConnClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConnChan <- conn
|
||||
conn.AcceptStream(context.Background()) // blocks until the connection is closed
|
||||
close(serverConnClosed)
|
||||
}()
|
||||
|
@ -370,7 +376,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
_, err = str.Write([]byte("foobar"))
|
||||
checkTimeoutError(err)
|
||||
|
||||
Expect(server.Close()).To(Succeed())
|
||||
(<-serverConnChan).CloseWithError(0, "")
|
||||
Eventually(serverConnClosed).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
|
|
@ -142,5 +142,6 @@ var _ = Describe("Unidirectional Streams", func() {
|
|||
runReceivingPeer(client)
|
||||
<-done1
|
||||
<-done2
|
||||
client.CloseWithError(0, "")
|
||||
})
|
||||
})
|
||||
|
|
|
@ -186,42 +186,6 @@ func (c *PacketHandlerManagerCloseCall) DoAndReturn(f func(error)) *PacketHandle
|
|||
return c
|
||||
}
|
||||
|
||||
// CloseServer mocks base method.
|
||||
func (m *MockPacketHandlerManager) CloseServer() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "CloseServer")
|
||||
}
|
||||
|
||||
// CloseServer indicates an expected call of CloseServer.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *PacketHandlerManagerCloseServerCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
||||
return &PacketHandlerManagerCloseServerCall{Call: call}
|
||||
}
|
||||
|
||||
// PacketHandlerManagerCloseServerCall wrap *gomock.Call
|
||||
type PacketHandlerManagerCloseServerCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *PacketHandlerManagerCloseServerCall) Return() *PacketHandlerManagerCloseServerCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *PacketHandlerManagerCloseServerCall) Do(f func()) *PacketHandlerManagerCloseServerCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *PacketHandlerManagerCloseServerCall) DoAndReturn(f func()) *PacketHandlerManagerCloseServerCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -220,23 +220,6 @@ func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (
|
|||
return handler, ok
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) CloseServer() {
|
||||
h.mutex.Lock()
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range h.handlers {
|
||||
if handler.getPerspective() == protocol.PerspectiveServer {
|
||||
wg.Add(1)
|
||||
go func(handler packetHandler) {
|
||||
// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||||
handler.shutdown()
|
||||
wg.Done()
|
||||
}(handler)
|
||||
}
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Close(e error) {
|
||||
h.mutex.Lock()
|
||||
|
||||
|
|
|
@ -159,23 +159,6 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("closes the server", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
for i := 0; i < 10; i++ {
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
if i%2 == 0 {
|
||||
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
|
||||
} else {
|
||||
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
|
||||
conn.EXPECT().shutdown()
|
||||
}
|
||||
b := make([]byte, 12)
|
||||
rand.Read(b)
|
||||
m.Add(protocol.ParseConnectionID(b), conn)
|
||||
}
|
||||
m.CloseServer()
|
||||
})
|
||||
|
||||
It("closes", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
testErr := errors.New("shutdown")
|
||||
|
|
62
server.go
62
server.go
|
@ -34,7 +34,6 @@ type packetHandlerManager interface {
|
|||
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
|
||||
Close(error)
|
||||
CloseServer()
|
||||
connRunner
|
||||
}
|
||||
|
||||
|
@ -61,8 +60,6 @@ type rejectedPacket struct {
|
|||
|
||||
// A Listener of QUIC
|
||||
type baseServer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
disableVersionNegotiation bool
|
||||
acceptEarlyConns bool
|
||||
|
||||
|
@ -104,10 +101,11 @@ type baseServer struct {
|
|||
protocol.VersionNumber,
|
||||
) quicConn
|
||||
|
||||
serverError error
|
||||
errorChan chan struct{}
|
||||
closed bool
|
||||
running chan struct{} // closed as soon as run() returns
|
||||
closeOnce sync.Once
|
||||
errorChan chan struct{} // is closed when the server is closed
|
||||
closeErr error
|
||||
running chan struct{} // closed as soon as run() returns
|
||||
|
||||
versionNegotiationQueue chan receivedPacket
|
||||
invalidTokenQueue chan rejectedPacket
|
||||
connectionRefusedQueue chan rejectedPacket
|
||||
|
@ -132,7 +130,10 @@ func (l *Listener) Accept(ctx context.Context) (Connection, error) {
|
|||
return l.baseServer.Accept(ctx)
|
||||
}
|
||||
|
||||
// Close the server. All active connections will be closed.
|
||||
// Close closes the listener.
|
||||
// Accept will return ErrServerClosed as soon as all connections in the accept queue have been accepted.
|
||||
// QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error.
|
||||
// Closing the listener doesn't have any effect on already established connections.
|
||||
func (l *Listener) Close() error {
|
||||
return l.baseServer.Close()
|
||||
}
|
||||
|
@ -321,38 +322,25 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
|
|||
atomic.AddInt32(&s.connQueueLen, -1)
|
||||
return conn, nil
|
||||
case <-s.errorChan:
|
||||
return nil, s.serverError
|
||||
return nil, s.closeErr
|
||||
}
|
||||
}
|
||||
|
||||
// Close the server
|
||||
func (s *baseServer) Close() error {
|
||||
s.mutex.Lock()
|
||||
if s.closed {
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
if s.serverError == nil {
|
||||
s.serverError = ErrServerClosed
|
||||
}
|
||||
s.closed = true
|
||||
close(s.errorChan)
|
||||
s.mutex.Unlock()
|
||||
|
||||
<-s.running
|
||||
s.onClose()
|
||||
s.close(ErrServerClosed, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *baseServer) setCloseError(e error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.closed {
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
s.serverError = e
|
||||
close(s.errorChan)
|
||||
func (s *baseServer) close(e error, notifyOnClose bool) {
|
||||
s.closeOnce.Do(func() {
|
||||
s.closeErr = e
|
||||
close(s.errorChan)
|
||||
|
||||
<-s.running
|
||||
if notifyOnClose {
|
||||
s.onClose()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Addr returns the server's network address
|
||||
|
@ -701,8 +689,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
|||
func (s *baseServer) handleNewConn(conn quicConn) {
|
||||
connCtx := conn.Context()
|
||||
if s.acceptEarlyConns {
|
||||
// wait until the early connection is ready (or the handshake fails)
|
||||
// wait until the early connection is ready, the handshake fails, or the server is closed
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
||||
return
|
||||
case <-conn.earlyConnReady():
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
|
@ -710,6 +701,9 @@ func (s *baseServer) handleNewConn(conn quicConn) {
|
|||
} else {
|
||||
// wait until the handshake is complete (or fails)
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
||||
return
|
||||
case <-conn.HandshakeComplete():
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
|
|
|
@ -326,6 +326,8 @@ var _ = Describe("Server", func() {
|
|||
// make sure we're using a server-generated connection ID
|
||||
Eventually(run).Should(BeClosed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
// shutdown
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
})
|
||||
|
||||
It("sends a Version Negotiation Packet for unsupported versions", func() {
|
||||
|
@ -527,6 +529,8 @@ var _ = Describe("Server", func() {
|
|||
// make sure we're using a server-generated connection ID
|
||||
Eventually(run).Should(BeClosed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
// shutdown
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
})
|
||||
|
||||
It("drops packets if the receive queue is full", func() {
|
||||
|
@ -565,6 +569,8 @@ var _ = Describe("Server", func() {
|
|||
conn.EXPECT().run().MaxTimes(1)
|
||||
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
|
||||
// shutdown
|
||||
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
|
||||
return conn
|
||||
}
|
||||
|
||||
|
@ -956,30 +962,69 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
Context("accepting connections", func() {
|
||||
It("returns Accept when an error occurs", func() {
|
||||
testErr := errors.New("test err")
|
||||
|
||||
It("returns Accept when closed", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := serv.Accept(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(err).To(MatchError(ErrServerClosed))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
serv.setCloseError(testErr)
|
||||
serv.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
serv.onClose() // shutdown
|
||||
})
|
||||
|
||||
It("returns immediately, if an error occurred before", func() {
|
||||
testErr := errors.New("test err")
|
||||
serv.setCloseError(testErr)
|
||||
serv.Close()
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := serv.Accept(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(err).To(MatchError(ErrServerClosed))
|
||||
}
|
||||
serv.onClose() // shutdown
|
||||
})
|
||||
|
||||
It("closes connection that are still handshaking after Close", func() {
|
||||
serv.Close()
|
||||
|
||||
destroyed := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ *protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
conf *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
_ bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().handlePacket(gomock.Any())
|
||||
conn.EXPECT().destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}).Do(func(error) { close(destroyed) })
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().run().MaxTimes(1)
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.handleInitialImpl(
|
||||
receivedPacket{buffer: getPacketBuffer()},
|
||||
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
|
||||
)
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns when the context is canceled", func() {
|
||||
|
@ -1343,10 +1388,7 @@ var _ = Describe("Server", func() {
|
|||
serv.connHandler = phm
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
||||
tr.Close()
|
||||
})
|
||||
AfterEach(func() { tr.Close() })
|
||||
|
||||
It("passes packets to existing connections", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
|
@ -1425,6 +1467,8 @@ var _ = Describe("Server", func() {
|
|||
conn.EXPECT().earlyConnReady()
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
close(called)
|
||||
// shutdown
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
return conn
|
||||
}
|
||||
|
||||
|
|
|
@ -275,7 +275,8 @@ func (t *Transport) runSendQueue() {
|
|||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying connection and waits until listen has returned.
|
||||
// Close closes the underlying connection.
|
||||
// If any listener was started, it will be closed as well.
|
||||
// It is invalid to start new listeners or connections after that.
|
||||
func (t *Transport) Close() error {
|
||||
t.close(errors.New("closing"))
|
||||
|
@ -294,7 +295,6 @@ func (t *Transport) Close() error {
|
|||
}
|
||||
|
||||
func (t *Transport) closeServer() {
|
||||
t.handlerMap.CloseServer()
|
||||
t.mutex.Lock()
|
||||
t.server = nil
|
||||
if t.isSingleUse {
|
||||
|
@ -322,7 +322,7 @@ func (t *Transport) close(e error) {
|
|||
t.handlerMap.Close(e)
|
||||
}
|
||||
if t.server != nil {
|
||||
t.server.setCloseError(e)
|
||||
t.server.close(e, false)
|
||||
}
|
||||
t.closed = true
|
||||
}
|
||||
|
|
|
@ -114,7 +114,6 @@ var _ = Describe("Transport", func() {
|
|||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
|
||||
phm.EXPECT().CloseServer()
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
|
||||
// shutdown
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue