diff --git a/integrationtests/self/deadline_test.go b/integrationtests/self/deadline_test.go index b165aff0..43e3b247 100644 --- a/integrationtests/self/deadline_test.go +++ b/integrationtests/self/deadline_test.go @@ -14,7 +14,7 @@ import ( ) var _ = Describe("Stream deadline tests", func() { - setup := func() (*quic.Listener, quic.Stream, quic.Stream) { + setup := func() (serverStr, clientStr quic.Stream, close func()) { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) strChan := make(chan quic.SendStream) @@ -36,19 +36,21 @@ var _ = Describe("Stream deadline tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) - clientStr, err := conn.OpenStream() + clientStr, err = conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream Expect(err).ToNot(HaveOccurred()) - var serverStr quic.Stream Eventually(strChan).Should(Receive(&serverStr)) - return server, serverStr, clientStr + return serverStr, clientStr, func() { + Expect(server.Close()).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + } } Context("read deadlines", func() { It("completes a transfer when the deadline is set", func() { - server, serverStr, clientStr := setup() - defer server.Close() + serverStr, clientStr, closeFn := setup() + defer closeFn() const timeout = time.Millisecond done := make(chan struct{}) @@ -82,8 +84,8 @@ var _ = Describe("Stream deadline tests", func() { }) It("completes a transfer when the deadline is set concurrently", func() { - server, serverStr, clientStr := setup() - defer server.Close() + serverStr, clientStr, closeFn := setup() + defer closeFn() const timeout = time.Millisecond go func() { @@ -132,8 +134,8 @@ var _ = Describe("Stream deadline tests", func() { Context("write deadlines", func() { It("completes a transfer when the deadline is set", func() { - server, serverStr, clientStr := setup() - defer server.Close() + serverStr, clientStr, closeFn := setup() + defer closeFn() const timeout = time.Millisecond done := make(chan struct{}) @@ -165,8 +167,8 @@ var _ = Describe("Stream deadline tests", func() { }) It("completes a transfer when the deadline is set concurrently", func() { - server, serverStr, clientStr := setup() - defer server.Close() + serverStr, clientStr, closeFn := setup() + defer closeFn() const timeout = time.Millisecond readDone := make(chan struct{}) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index d4f65a12..141487ab 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -152,13 +152,14 @@ var _ = Describe("Handshake tests", func() { Context("Certificate validation", func() { It("accepts the certificate", func() { runServer(getTLSConfig()) - _, err := quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + conn.CloseWithError(0, "") }) It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() { @@ -187,6 +188,7 @@ var _ = Describe("Handshake tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") Eventually(done).Should(BeClosed()) Expect(server.Addr()).To(Equal(local)) Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port)) @@ -196,13 +198,14 @@ var _ = Describe("Handshake tests", func() { It("works with a long certificate chain", func() { runServer(getTLSConfigWithLongCertChain()) - _, err := quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + conn.CloseWithError(0, "") }) It("errors if the server name doesn't match", func() { diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index bedf0470..23f241be 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -52,22 +52,23 @@ var _ = Describe("TLS session resumption", func() { cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache - conn, err := quic.DialAddr( + conn1, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn1.CloseWithError(0, "") var sessionKey string Eventually(puts).Should(Receive(&sessionKey)) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse()) serverConn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) - conn, err = quic.DialAddr( + conn2, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, @@ -75,11 +76,12 @@ var _ = Describe("TLS session resumption", func() { ) Expect(err).ToNot(HaveOccurred()) Expect(gets).To(Receive(Equal(sessionKey))) - Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue()) + Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue()) serverConn, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue()) + conn2.CloseWithError(0, "") }) It("doesn't use session resumption, if the config disables it", func() { @@ -94,15 +96,16 @@ var _ = Describe("TLS session resumption", func() { cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache - conn, err := quic.DialAddr( + conn1, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn1.CloseWithError(0, "") Consistently(puts).ShouldNot(Receive()) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse()) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -110,14 +113,15 @@ var _ = Describe("TLS session resumption", func() { Expect(err).ToNot(HaveOccurred()) Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) - conn, err = quic.DialAddr( + conn2, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse()) + defer conn2.CloseWithError(0, "") serverConn, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) @@ -142,7 +146,7 @@ var _ = Describe("TLS session resumption", func() { cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache - conn, err := quic.DialAddr( + conn1, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, @@ -150,7 +154,8 @@ var _ = Describe("TLS session resumption", func() { ) Expect(err).ToNot(HaveOccurred()) Consistently(puts).ShouldNot(Receive()) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse()) + defer conn1.CloseWithError(0, "") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -158,14 +163,15 @@ var _ = Describe("TLS session resumption", func() { Expect(err).ToNot(HaveOccurred()) Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) - conn, err = quic.DialAddr( + conn2, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse()) + defer conn2.CloseWithError(0, "") serverConn, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index dd18c4ea..da4c6883 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -185,11 +185,13 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer server.Close() + serverConnChan := make(chan quic.Connection, 1) serverConnClosed := make(chan struct{}) go func() { defer GinkgoRecover() conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) + serverConnChan <- conn conn.AcceptStream(context.Background()) // blocks until the connection is closed close(serverConnClosed) }() @@ -240,7 +242,7 @@ var _ = Describe("Timeout tests", func() { Consistently(serverConnClosed).ShouldNot(BeClosed()) // make the go routine return - Expect(server.Close()).To(Succeed()) + (<-serverConnChan).CloseWithError(0, "") Eventually(serverConnClosed).Should(BeClosed()) }) @@ -266,11 +268,13 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() + serverConnChan := make(chan quic.Connection, 1) serverConnClosed := make(chan struct{}) go func() { defer GinkgoRecover() conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) + serverConnChan <- conn <-conn.Context().Done() // block until the connection is closed close(serverConnClosed) }() @@ -309,7 +313,7 @@ var _ = Describe("Timeout tests", func() { Consistently(serverConnClosed).ShouldNot(BeClosed()) // make the go routine return - Expect(server.Close()).To(Succeed()) + (<-serverConnChan).CloseWithError(0, "") Eventually(serverConnClosed).Should(BeClosed()) }) }) @@ -325,11 +329,13 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer server.Close() + serverConnChan := make(chan quic.Connection, 1) serverConnClosed := make(chan struct{}) go func() { defer GinkgoRecover() conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) + serverConnChan <- conn conn.AcceptStream(context.Background()) // blocks until the connection is closed close(serverConnClosed) }() @@ -370,7 +376,7 @@ var _ = Describe("Timeout tests", func() { _, err = str.Write([]byte("foobar")) checkTimeoutError(err) - Expect(server.Close()).To(Succeed()) + (<-serverConnChan).CloseWithError(0, "") Eventually(serverConnClosed).Should(BeClosed()) }) diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index a2fe4e50..d47df35f 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -142,5 +142,6 @@ var _ = Describe("Unidirectional Streams", func() { runReceivingPeer(client) <-done1 <-done2 + client.CloseWithError(0, "") }) }) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index fafd43c9..e3f1e00c 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -186,42 +186,6 @@ func (c *PacketHandlerManagerCloseCall) DoAndReturn(f func(error)) *PacketHandle return c } -// CloseServer mocks base method. -func (m *MockPacketHandlerManager) CloseServer() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CloseServer") -} - -// CloseServer indicates an expected call of CloseServer. -func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *PacketHandlerManagerCloseServerCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) - return &PacketHandlerManagerCloseServerCall{Call: call} -} - -// PacketHandlerManagerCloseServerCall wrap *gomock.Call -type PacketHandlerManagerCloseServerCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *PacketHandlerManagerCloseServerCall) Return() *PacketHandlerManagerCloseServerCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *PacketHandlerManagerCloseServerCall) Do(f func()) *PacketHandlerManagerCloseServerCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *PacketHandlerManagerCloseServerCall) DoAndReturn(f func()) *PacketHandlerManagerCloseServerCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // Get mocks base method. func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) { m.ctrl.T.Helper() diff --git a/packet_handler_map.go b/packet_handler_map.go index 006eadf9..ba62b1e5 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -220,23 +220,6 @@ func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) ( return handler, ok } -func (h *packetHandlerMap) CloseServer() { - h.mutex.Lock() - var wg sync.WaitGroup - for _, handler := range h.handlers { - if handler.getPerspective() == protocol.PerspectiveServer { - wg.Add(1) - go func(handler packetHandler) { - // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - handler.shutdown() - wg.Done() - }(handler) - } - } - h.mutex.Unlock() - wg.Wait() -} - func (h *packetHandlerMap) Close(e error) { h.mutex.Lock() diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 24cef871..d40c395f 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -159,23 +159,6 @@ var _ = Describe("Packet Handler Map", func() { Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse()) }) - It("closes the server", func() { - m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) - for i := 0; i < 10; i++ { - conn := NewMockPacketHandler(mockCtrl) - if i%2 == 0 { - conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient) - } else { - conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) - conn.EXPECT().shutdown() - } - b := make([]byte, 12) - rand.Read(b) - m.Add(protocol.ParseConnectionID(b), conn) - } - m.CloseServer() - }) - It("closes", func() { m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) testErr := errors.New("shutdown") diff --git a/server.go b/server.go index 8353cdeb..f7c638b9 100644 --- a/server.go +++ b/server.go @@ -34,7 +34,6 @@ type packetHandlerManager interface { GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool Close(error) - CloseServer() connRunner } @@ -61,8 +60,6 @@ type rejectedPacket struct { // A Listener of QUIC type baseServer struct { - mutex sync.Mutex - disableVersionNegotiation bool acceptEarlyConns bool @@ -104,10 +101,11 @@ type baseServer struct { protocol.VersionNumber, ) quicConn - serverError error - errorChan chan struct{} - closed bool - running chan struct{} // closed as soon as run() returns + closeOnce sync.Once + errorChan chan struct{} // is closed when the server is closed + closeErr error + running chan struct{} // closed as soon as run() returns + versionNegotiationQueue chan receivedPacket invalidTokenQueue chan rejectedPacket connectionRefusedQueue chan rejectedPacket @@ -132,7 +130,10 @@ func (l *Listener) Accept(ctx context.Context) (Connection, error) { return l.baseServer.Accept(ctx) } -// Close the server. All active connections will be closed. +// Close closes the listener. +// Accept will return ErrServerClosed as soon as all connections in the accept queue have been accepted. +// QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error. +// Closing the listener doesn't have any effect on already established connections. func (l *Listener) Close() error { return l.baseServer.Close() } @@ -321,38 +322,25 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) { atomic.AddInt32(&s.connQueueLen, -1) return conn, nil case <-s.errorChan: - return nil, s.serverError + return nil, s.closeErr } } -// Close the server func (s *baseServer) Close() error { - s.mutex.Lock() - if s.closed { - s.mutex.Unlock() - return nil - } - if s.serverError == nil { - s.serverError = ErrServerClosed - } - s.closed = true - close(s.errorChan) - s.mutex.Unlock() - - <-s.running - s.onClose() + s.close(ErrServerClosed, true) return nil } -func (s *baseServer) setCloseError(e error) { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.closed { - return - } - s.closed = true - s.serverError = e - close(s.errorChan) +func (s *baseServer) close(e error, notifyOnClose bool) { + s.closeOnce.Do(func() { + s.closeErr = e + close(s.errorChan) + + <-s.running + if notifyOnClose { + s.onClose() + } + }) } // Addr returns the server's network address @@ -701,8 +689,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error func (s *baseServer) handleNewConn(conn quicConn) { connCtx := conn.Context() if s.acceptEarlyConns { - // wait until the early connection is ready (or the handshake fails) + // wait until the early connection is ready, the handshake fails, or the server is closed select { + case <-s.errorChan: + conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}) + return case <-conn.earlyConnReady(): case <-connCtx.Done(): return @@ -710,6 +701,9 @@ func (s *baseServer) handleNewConn(conn quicConn) { } else { // wait until the handshake is complete (or fails) select { + case <-s.errorChan: + conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}) + return case <-conn.HandshakeComplete(): case <-connCtx.Done(): return diff --git a/server_test.go b/server_test.go index 6b250c33..eeea3b8c 100644 --- a/server_test.go +++ b/server_test.go @@ -326,6 +326,8 @@ var _ = Describe("Server", func() { // make sure we're using a server-generated connection ID Eventually(run).Should(BeClosed()) Eventually(done).Should(BeClosed()) + // shutdown + conn.EXPECT().destroy(gomock.Any()) }) It("sends a Version Negotiation Packet for unsupported versions", func() { @@ -527,6 +529,8 @@ var _ = Describe("Server", func() { // make sure we're using a server-generated connection ID Eventually(run).Should(BeClosed()) Eventually(done).Should(BeClosed()) + // shutdown + conn.EXPECT().destroy(gomock.Any()) }) It("drops packets if the receive queue is full", func() { @@ -565,6 +569,8 @@ var _ = Describe("Server", func() { conn.EXPECT().run().MaxTimes(1) conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) + // shutdown + conn.EXPECT().destroy(gomock.Any()).MaxTimes(1) return conn } @@ -956,30 +962,69 @@ var _ = Describe("Server", func() { }) Context("accepting connections", func() { - It("returns Accept when an error occurs", func() { - testErr := errors.New("test err") - + It("returns Accept when closed", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() _, err := serv.Accept(context.Background()) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError(ErrServerClosed)) close(done) }() - serv.setCloseError(testErr) + serv.Close() Eventually(done).Should(BeClosed()) - serv.onClose() // shutdown }) It("returns immediately, if an error occurred before", func() { - testErr := errors.New("test err") - serv.setCloseError(testErr) + serv.Close() for i := 0; i < 3; i++ { _, err := serv.Accept(context.Background()) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError(ErrServerClosed)) } - serv.onClose() // shutdown + }) + + It("closes connection that are still handshaking after Close", func() { + serv.Close() + + destroyed := make(chan struct{}) + serv.newConn = func( + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ ConnectionIDGenerator, + _ protocol.StatelessResetToken, + conf *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ *logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + conn := NewMockQUICConn(mockCtrl) + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}).Do(func(error) { close(destroyed) }) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) + conn.EXPECT().run().MaxTimes(1) + conn.EXPECT().Context().Return(context.Background()) + return conn + } + phm.EXPECT().Get(gomock.Any()) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + _, ok := fn() + return ok + }) + serv.handleInitialImpl( + receivedPacket{buffer: getPacketBuffer()}, + &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, + ) + Eventually(destroyed).Should(BeClosed()) }) It("returns when the context is canceled", func() { @@ -1343,10 +1388,7 @@ var _ = Describe("Server", func() { serv.connHandler = phm }) - AfterEach(func() { - phm.EXPECT().CloseServer().MaxTimes(1) - tr.Close() - }) + AfterEach(func() { tr.Close() }) It("passes packets to existing connections", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) @@ -1425,6 +1467,8 @@ var _ = Describe("Server", func() { conn.EXPECT().earlyConnReady() conn.EXPECT().Context().Return(context.Background()) close(called) + // shutdown + conn.EXPECT().destroy(gomock.Any()) return conn } diff --git a/transport.go b/transport.go index 60d44a43..38d35238 100644 --- a/transport.go +++ b/transport.go @@ -275,7 +275,8 @@ func (t *Transport) runSendQueue() { } } -// Close closes the underlying connection and waits until listen has returned. +// Close closes the underlying connection. +// If any listener was started, it will be closed as well. // It is invalid to start new listeners or connections after that. func (t *Transport) Close() error { t.close(errors.New("closing")) @@ -294,7 +295,6 @@ func (t *Transport) Close() error { } func (t *Transport) closeServer() { - t.handlerMap.CloseServer() t.mutex.Lock() t.server = nil if t.isSingleUse { @@ -322,7 +322,7 @@ func (t *Transport) close(e error) { t.handlerMap.Close(e) } if t.server != nil { - t.server.setCloseError(e) + t.server.close(e, false) } t.closed = true } diff --git a/transport_test.go b/transport_test.go index 5fa63213..c93d1da9 100644 --- a/transport_test.go +++ b/transport_test.go @@ -114,7 +114,6 @@ var _ = Describe("Transport", func() { phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm - phm.EXPECT().CloseServer() Expect(ln.Close()).To(Succeed()) // shutdown