diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index 0f5123ca..70b0264c 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -79,7 +79,7 @@ func init() { b.RecordValue("transfer rate [MB/s]", float64(dataLen)/1e6/runtime.Seconds()) ln.Close() - sess.Close(nil) + sess.Close() }, samples) }) } diff --git a/client.go b/client.go index 63efd46e..551800e4 100644 --- a/client.go +++ b/client.go @@ -336,8 +336,8 @@ func (c *client) establishSecureConnection(ctx context.Context) error { select { case <-ctx.Done(): - // The session sending a PeerGoingAway error to the server. - c.session.Close(nil) + // The session will send a PeerGoingAway error to the server. + c.session.Close() return ctx.Err() case err := <-errorChan: return err @@ -366,7 +366,7 @@ func (c *client) handlePacketImpl(p *receivedPacket) error { // version negotiation packets have no payload if err := c.handleVersionNegotiationPacket(p.header); err != nil { - c.session.Close(err) + c.session.destroy(err) } return nil } @@ -474,7 +474,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { } c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) - c.session.Close(errCloseSessionForNewVersion) + c.session.destroy(errCloseSessionForNewVersion) return nil } @@ -526,13 +526,13 @@ func (c *client) createNewTLSSession( return err } -func (c *client) Close(err error) error { +func (c *client) Close() error { c.mutex.Lock() defer c.mutex.Unlock() if c.session == nil { return nil } - return c.session.Close(err) + return c.session.Close() } func (c *client) GetVersion() protocol.VersionNumber { diff --git a/client_multiplexer.go b/client_multiplexer.go index 9784e01b..5a19b338 100644 --- a/client_multiplexer.go +++ b/client_multiplexer.go @@ -93,7 +93,7 @@ func (m *clientMultiplexer) listen(c net.PacketConn, p *connManager) { n, addr, err := c.ReadFrom(data) if err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { - p.manager.Close(err) + p.manager.Close() } return } diff --git a/client_multiplexer_test.go b/client_multiplexer_test.go index 56038841..ab3fe5e8 100644 --- a/client_multiplexer_test.go +++ b/client_multiplexer_test.go @@ -39,7 +39,7 @@ var _ = Describe("Client Multiplexer", func() { conn.dataToRead <- getPacket(connID) Eventually(handledPacket).Should(BeClosed()) // makes the listen go routine return - packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() + packetHandler.EXPECT().Close().AnyTimes() close(conn.dataToRead) }) @@ -85,8 +85,8 @@ var _ = Describe("Client Multiplexer", func() { Eventually(handledPacket2).Should(BeClosed()) // makes the listen go routine return - packetHandler1.EXPECT().Close(gomock.Any()).AnyTimes() - packetHandler2.EXPECT().Close(gomock.Any()).AnyTimes() + packetHandler1.EXPECT().Close().AnyTimes() + packetHandler2.EXPECT().Close().AnyTimes() close(conn.dataToRead) }) @@ -114,11 +114,10 @@ var _ = Describe("Client Multiplexer", func() { It("closes the packet handlers when reading from the conn fails", func() { conn := newMockPacketConn() - testErr := errors.New("test error") - conn.readErr = testErr + conn.readErr = errors.New("test error") done := make(chan struct{}) packetHandler := NewMockQuicSession(mockCtrl) - packetHandler.EXPECT().Close(testErr).Do(func(error) { + packetHandler.EXPECT().Close().Do(func() { close(done) }) getClientMultiplexer().AddConn(conn, 8) diff --git a/client_test.go b/client_test.go index 329689bb..e79bfc06 100644 --- a/client_test.go +++ b/client_test.go @@ -86,7 +86,7 @@ var _ = Describe("Client", func() { AfterEach(func() { if s, ok := cl.session.(*session); ok { - s.Close(nil) + s.Close() } Eventually(areSessionsRunning).Should(BeFalse()) }) @@ -254,7 +254,7 @@ var _ = Describe("Client", func() { close(dialed) }() Consistently(dialed).ShouldNot(BeClosed()) - sess.EXPECT().Close(nil) + sess.EXPECT().Close() cancel() Eventually(dialed).Should(BeClosed()) }) @@ -493,7 +493,7 @@ var _ = Describe("Client", func() { sess1 := NewMockQuicSession(mockCtrl) run1 := make(chan struct{}) sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion) - sess1.EXPECT().Close(errCloseSessionForNewVersion).Do(func(error) { close(run1) }) + sess1.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) { close(run1) }) sess2 := NewMockQuicSession(mockCtrl) sess2.EXPECT().run() sessionChan := make(chan *MockQuicSession, 2) @@ -538,7 +538,7 @@ var _ = Describe("Client", func() { sess1 := NewMockQuicSession(mockCtrl) run1 := make(chan struct{}) sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion) - sess1.EXPECT().Close(errCloseSessionForNewVersion).Do(func(error) { close(run1) }) + sess1.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) { close(run1) }) sess2 := NewMockQuicSession(mockCtrl) sess2.EXPECT().run() sessionChan := make(chan *MockQuicSession, 2) @@ -578,7 +578,7 @@ var _ = Describe("Client", func() { It("errors if no matching version is found", func() { sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().Close(gomock.Any()) + sess.EXPECT().destroy(qerr.InvalidVersion) cl.session = sess cl.config = &Config{Versions: protocol.SupportedVersions} cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1})) @@ -586,7 +586,7 @@ var _ = Describe("Client", func() { It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().Close(gomock.Any()) + sess.EXPECT().destroy(qerr.InvalidVersion) cl.session = sess v := protocol.VersionNumber(1234) Expect(v).ToNot(Equal(cl.version)) @@ -597,7 +597,7 @@ var _ = Describe("Client", func() { It("changes to the version preferred by the quic.Config", func() { mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any()) sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().Close(errCloseSessionForNewVersion) + sess.EXPECT().destroy(errCloseSessionForNewVersion) cl.session = sess versions := []protocol.VersionNumber{1234, 4321} cl.config = &Config{Versions: versions} diff --git a/h2quic/client.go b/h2quic/client.go index af4ad11a..9ab7a65e 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -280,11 +280,14 @@ func (c *client) CloseWithError(e error) error { if c.session == nil { return nil } - return c.session.Close(e) + return c.session.CloseWithError(e) } func (c *client) Close() error { - return c.CloseWithError(nil) + if c.session == nil { + return nil + } + return c.session.Close() } // copied from net/transport.go diff --git a/h2quic/server.go b/h2quic/server.go index 0f3ad6dc..2b76b0c4 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -127,7 +127,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { func (s *Server) handleHeaderStream(session streamCreator) { stream, err := session.AcceptStream() if err != nil { - session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) + session.CloseWithError(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) return } @@ -143,7 +143,7 @@ func (s *Server) handleHeaderStream(session streamCreator) { if _, ok := err.(*qerr.QuicError); !ok { s.logger.Errorf("error handling h2 request: %s", err.Error()) } - session.Close(err) + session.CloseWithError(err) return } } @@ -246,7 +246,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, } if s.CloseAfterFirstRequest { time.Sleep(100 * time.Millisecond) - session.Close(nil) + session.Close() } }() diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 28597594..69a8c96c 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -61,8 +61,7 @@ func (s *mockSession) OpenStreamSync() (quic.Stream, error) { } return s.OpenStream() } -func (s *mockSession) Close(e error) error { - s.closedWithError = e +func (s *mockSession) Close() error { s.ctxCancel() if !s.closed { close(s.blockOpenStreamChan) @@ -70,6 +69,10 @@ func (s *mockSession) Close(e error) error { s.closed = true return nil } +func (s *mockSession) CloseWithError(e error) error { + s.closedWithError = e + return s.Close() +} func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index cfa9ec6d..991154f2 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -52,7 +52,7 @@ var _ = Describe("Connection ID lengths tests", func() { conf, ) Expect(err).ToNot(HaveOccurred()) - defer cl.Close(nil) + defer cl.Close() str, err := cl.AcceptStream() Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(str) diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 38af3b33..dd83c9dd 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -61,7 +61,7 @@ var _ = Describe("Handshake drop tests", func() { defer GinkgoRecover() sess, err := ln.Accept() Expect(err).ToNot(HaveOccurred()) - defer sess.Close(nil) + defer sess.Close() str, err := sess.AcceptStream() Expect(err).ToNot(HaveOccurred()) b := make([]byte, 6) @@ -83,8 +83,8 @@ var _ = Describe("Handshake drop tests", func() { var serverSession quic.Session Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession)) - sess.Close(nil) - serverSession.Close(nil) + sess.Close() + serverSession.Close() }, } @@ -117,8 +117,8 @@ var _ = Describe("Handshake drop tests", func() { var serverSession quic.Session Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession)) - sess.Close(nil) - serverSession.Close(nil) + sess.Close() + serverSession.Close() }, } @@ -141,8 +141,8 @@ var _ = Describe("Handshake drop tests", func() { var serverSession quic.Session Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession)) // both server and client accepted a session. Close now. - sess.Close(nil) - serverSession.Close(nil) + sess.Close() + serverSession.Close() }, } diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index c0b842d3..40cca5b9 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -74,7 +74,7 @@ var _ = Describe("non-zero RTT", func() { data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(testserver.PRData)) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) } diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index a7b3b565..64039698 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -109,7 +109,7 @@ var _ = Describe("Bidirectional streams", func() { sess, err := server.Accept() Expect(err).ToNot(HaveOccurred()) runSendingPeer(sess) - sess.Close(nil) + sess.Close() }() client, err := quic.DialAddr(serverAddr, nil, qconf) diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index 562ab716..ba3d7e4d 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -78,7 +78,7 @@ var _ = Describe("Unidirectional Streams", func() { sess, err = server.Accept() Expect(err).ToNot(HaveOccurred()) runReceivingPeer(sess) - sess.Close(nil) + sess.Close() }() client, err := quic.DialAddr(serverAddr, nil, qconf) diff --git a/interface.go b/interface.go index b1f19e52..3eac4c7b 100644 --- a/interface.go +++ b/interface.go @@ -145,8 +145,10 @@ type Session interface { LocalAddr() net.Addr // RemoteAddr returns the address of the peer. RemoteAddr() net.Addr - // Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent. - Close(error) error + // Close the connection. + io.Closer + // Close the connection with an error. + CloseWithError(error) error // The context is cancelled when the session is closed. // Warning: This API should not be considered stable and might change soon. Context() context.Context diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 6bb1b7b4..5cd722cb 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -45,13 +45,15 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom } // Close mocks base method -func (m *MockPacketHandlerManager) Close(arg0 error) { - m.ctrl.Call(m, "Close", arg0) +func (m *MockPacketHandlerManager) Close() error { + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 } // Close indicates an expected call of Close -func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0) +func (mr *MockPacketHandlerManagerMockRecorder) Close() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close)) } // Get mocks base method diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 4362efe6..eb3e42e5 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -64,15 +64,27 @@ func (mr *MockQuicSessionMockRecorder) AcceptUniStream() *gomock.Call { } // Close mocks base method -func (m *MockQuicSession) Close(arg0 error) error { - ret := m.ctrl.Call(m, "Close", arg0) +func (m *MockQuicSession) Close() error { + ret := m.ctrl.Call(m, "Close") ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close -func (mr *MockQuicSessionMockRecorder) Close(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockQuicSession)(nil).Close), arg0) +func (mr *MockQuicSessionMockRecorder) Close() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockQuicSession)(nil).Close)) +} + +// CloseWithError mocks base method +func (m *MockQuicSession) CloseWithError(arg0 error) error { + ret := m.ctrl.Call(m, "CloseWithError", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseWithError indicates an expected call of CloseWithError +func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicSession)(nil).CloseWithError), arg0) } // ConnectionState mocks base method @@ -197,6 +209,16 @@ func (mr *MockQuicSessionMockRecorder) closeRemote(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeRemote", reflect.TypeOf((*MockQuicSession)(nil).closeRemote), arg0) } +// destroy mocks base method +func (m *MockQuicSession) destroy(arg0 error) { + m.ctrl.Call(m, "destroy", arg0) +} + +// destroy indicates an expected call of destroy +func (mr *MockQuicSessionMockRecorder) destroy(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicSession)(nil).destroy), arg0) +} + // getCryptoStream mocks base method func (m *MockQuicSession) getCryptoStream() cryptoStreamI { ret := m.ctrl.Call(m, "getCryptoStream") diff --git a/packet_handler_map.go b/packet_handler_map.go index b8e5038f..82a3a73b 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -54,11 +54,11 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { }) } -func (h *packetHandlerMap) Close(err error) { +func (h *packetHandlerMap) Close() error { h.mutex.Lock() if h.closed { h.mutex.Unlock() - return + return nil } h.closed = true @@ -68,11 +68,12 @@ func (h *packetHandlerMap) Close(err error) { wg.Add(1) go func(handler packetHandler) { // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - _ = handler.Close(err) + _ = handler.Close() wg.Done() }(handler) } } h.mutex.Unlock() wg.Wait() + return nil } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 380132ea..18ee9d9a 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -1,7 +1,6 @@ package quic import ( - "errors" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -46,13 +45,12 @@ var _ = Describe("Packet Handler Map", func() { }) It("closes", func() { - testErr := errors.New("test error") sess1 := NewMockQuicSession(mockCtrl) - sess1.EXPECT().Close(testErr) + sess1.EXPECT().Close() sess2 := NewMockQuicSession(mockCtrl) - sess2.EXPECT().Close(testErr) + sess2.EXPECT().Close() handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1) handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2) - handler.Close(testErr) + handler.Close() }) }) diff --git a/server.go b/server.go index 8b560c99..4cf54c4e 100644 --- a/server.go +++ b/server.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "net" "time" @@ -19,15 +20,15 @@ import ( // packetHandler handles packets type packetHandler interface { handlePacket(*receivedPacket) - Close(error) error GetVersion() protocol.VersionNumber + io.Closer } type packetHandlerManager interface { Add(protocol.ConnectionID, packetHandler) Get(protocol.ConnectionID) (packetHandler, bool) Remove(protocol.ConnectionID) - Close(error) + io.Closer } type quicSession interface { @@ -36,6 +37,7 @@ type quicSession interface { getCryptoStream() cryptoStreamI GetVersion() protocol.VersionNumber run() error + destroy(error) closeRemote(error) } @@ -294,7 +296,7 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { - s.sessionHandler.Close(nil) + s.sessionHandler.Close() err := s.conn.Close() <-s.errorChan // wait for serve() to return return err diff --git a/server_test.go b/server_test.go index 00fe8e52..7c86014d 100644 --- a/server_test.go +++ b/server_test.go @@ -220,7 +220,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Consistently(done).ShouldNot(BeClosed()) // make the go routine return - sessionHandler.EXPECT().Close(nil) + sessionHandler.EXPECT().Close() close(serv.errorChan) serv.Close() Eventually(done).Should(BeClosed()) @@ -242,7 +242,7 @@ var _ = Describe("Server", func() { serv.serve() }() // close the server - sessionHandler.EXPECT().Close(nil).AnyTimes() + sessionHandler.EXPECT().Close().AnyTimes() Expect(serv.Close()).To(Succeed()) Expect(conn.closed).To(BeTrue()) }) @@ -279,7 +279,7 @@ var _ = Describe("Server", func() { It("errors when encountering a connection error", func() { testErr := errors.New("connection error") conn.readErr = testErr - sessionHandler.EXPECT().Close(nil) + sessionHandler.EXPECT().Close() done := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/session.go b/session.go index 347e8170..4ed42d28 100644 --- a/session.go +++ b/session.go @@ -67,8 +67,9 @@ var ( ) type closeError struct { - err error - remote bool + err error + remote bool + sendClose bool } // A Session is a QUIC session @@ -441,7 +442,11 @@ func (s *session) run() error { go func() { if err := s.cryptoStreamHandler.HandleCryptoStream(); err != nil { - s.Close(err) + if err == handshake.ErrCloseSessionForRetry { + s.destroy(err) + } else { + s.closeLocal(err) + } } }() @@ -825,9 +830,17 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt return nil } +// closeLocal closes the session and send a CONNECTION_CLOSE containing the error func (s *session) closeLocal(e error) { s.closeOnce.Do(func() { - s.closeChan <- closeError{err: e, remote: false} + s.closeChan <- closeError{err: e, sendClose: true, remote: false} + }) +} + +// destroy closes the session without sending the error on the wire +func (s *session) destroy(e error) { + s.closeOnce.Do(func() { + s.closeChan <- closeError{err: e, sendClose: false, remote: false} }) } @@ -837,9 +850,13 @@ func (s *session) closeRemote(e error) { }) } -// Close the connection. If err is nil it will be set to qerr.PeerGoingAway. +// Close the connection. It sends a qerr.PeerGoingAway. // It waits until the run loop has stopped before returning -func (s *session) Close(e error) error { +func (s *session) Close() error { + return s.CloseWithError(nil) +} + +func (s *session) CloseWithError(e error) error { s.closeLocal(e) <-s.ctx.Done() return nil @@ -865,7 +882,7 @@ func (s *session) handleCloseError(closeErr closeError) error { s.cryptoStream.closeForShutdown(quicErr) s.streamsMap.CloseWithError(quicErr) - if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry { + if !closeErr.sendClose { return nil } @@ -1243,7 +1260,7 @@ func (s *session) onHasStreamData(id protocol.StreamID) { func (s *session) onStreamCompleted(id protocol.StreamID) { if err := s.streamsMap.DeleteStream(id); err != nil { - s.Close(err) + s.closeLocal(err) } } diff --git a/session_test.go b/session_test.go index f6225692..cc5820d7 100644 --- a/session_test.go +++ b/session_test.go @@ -513,7 +513,7 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) buf := &bytes.Buffer{} @@ -526,8 +526,8 @@ var _ = Describe("Session", func() { It("only closes once", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) - sess.Close(nil) + sess.Close() + sess.Close() Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) Expect(sess.Context().Done()).To(BeClosed()) @@ -537,7 +537,7 @@ var _ = Describe("Session", func() { testErr := errors.New("test error") streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(testErr) + sess.CloseWithError(testErr) Eventually(areSessionsRunning).Should(BeFalse()) Expect(sess.Context().Done()).To(BeClosed()) }) @@ -545,7 +545,7 @@ var _ = Describe("Session", func() { It("closes the session in order to replace it with another QUIC version", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(errCloseSessionForNewVersion) + sess.destroy(errCloseSessionForNewVersion) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent }) @@ -553,8 +553,8 @@ var _ = Describe("Session", func() { It("sends a Public Reset if the client is initiating the no STOP_WAITING experiment", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(handshake.ErrNSTPExperiment) - Expect(mconn.written).To(HaveLen(1)) + sess.closeLocal(handshake.ErrNSTPExperiment) + Eventually(mconn.written).Should(HaveLen(1)) Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset Expect(sess.Context().Done()).To(BeClosed()) }) @@ -571,7 +571,7 @@ var _ = Describe("Session", func() { close(returned) }() Consistently(returned).ShouldNot(BeClosed()) - sess.Close(nil) + sess.Close() Eventually(returned).Should(BeClosed()) }) }) @@ -911,7 +911,7 @@ var _ = Describe("Session", func() { Consistently(mconn.written).Should(HaveLen(2)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) @@ -935,7 +935,7 @@ var _ = Describe("Session", func() { Consistently(mconn.written).Should(HaveLen(1)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) @@ -963,7 +963,7 @@ var _ = Describe("Session", func() { Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) @@ -986,7 +986,7 @@ var _ = Describe("Session", func() { Eventually(mconn.written).Should(HaveLen(3)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1007,7 +1007,7 @@ var _ = Describe("Session", func() { Consistently(mconn.written).ShouldNot(Receive()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) }) @@ -1050,7 +1050,7 @@ var _ = Describe("Session", func() { // make sure that the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1081,7 +1081,7 @@ var _ = Describe("Session", func() { // make sure that the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) }) @@ -1252,7 +1252,7 @@ var _ = Describe("Session", func() { // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1283,7 +1283,7 @@ var _ = Describe("Session", func() { // make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) }) @@ -1335,7 +1335,7 @@ var _ = Describe("Session", func() { Consistently(mconn.written).Should(HaveLen(0)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1348,7 +1348,7 @@ var _ = Describe("Session", func() { Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), 20*time.Millisecond)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1363,7 +1363,7 @@ var _ = Describe("Session", func() { Expect(sess.undecryptablePackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - Expect(sess.Close(nil)).To(Succeed()) + Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1398,7 +1398,7 @@ var _ = Describe("Session", func() { Expect(sess.Context().Done()).ToNot(Receive()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1412,7 +1412,7 @@ var _ = Describe("Session", func() { Consistently(sess.undecryptablePackets).Should(BeEmpty()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - Expect(sess.Close(nil)).To(Succeed()) + Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1437,7 +1437,7 @@ var _ = Describe("Session", func() { // make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) - Expect(sess.Close(nil)).To(Succeed()) + Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1452,10 +1452,24 @@ var _ = Describe("Session", func() { // make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) - Expect(sess.Close(nil)).To(Succeed()) + Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) + It("doesn't return a run error when closing", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + Expect(sess.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + It("passes errors to the session runner", func() { testErr := errors.New("handshake error") done := make(chan struct{}) @@ -1467,7 +1481,7 @@ var _ = Describe("Session", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(testErr) + Expect(sess.CloseWithError(testErr)).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -1494,7 +1508,7 @@ var _ = Describe("Session", func() { // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1525,7 +1539,7 @@ var _ = Describe("Session", func() { // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1543,7 +1557,7 @@ var _ = Describe("Session", func() { // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1561,7 +1575,7 @@ var _ = Describe("Session", func() { // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(done).Should(BeClosed()) }) }) @@ -1602,7 +1616,7 @@ var _ = Describe("Session", func() { It("does not use the idle timeout before the handshake complete", func() { sess.config.IdleTimeout = 9999 * time.Second - defer sess.Close(nil) + defer sess.Close() sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) // the handshake timeout is irrelevant here, since it depends on the time the session was created, // and not on the last network activity @@ -1613,7 +1627,7 @@ var _ = Describe("Session", func() { Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close(nil) + sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1793,7 +1807,7 @@ var _ = Describe("Client Session", func() { Eventually(mconn.written).Should(Receive()) //make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - Expect(sess.Close(nil)).To(Succeed()) + Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1828,7 +1842,7 @@ var _ = Describe("Client Session", func() { Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7})) // make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - Expect(sess.Close(nil)).To(Succeed()) + Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1856,7 +1870,7 @@ var _ = Describe("Client Session", func() { Expect(cryptoSetup.divNonce).To(Equal(hdr.DiversificationNonce)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - Expect(sess.Close(nil)).To(Succeed()) + Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) })