diff --git a/closed_session.go b/closed_session.go index 872e3701..96a22182 100644 --- a/closed_session.go +++ b/closed_session.go @@ -3,58 +3,44 @@ package quic import ( "sync" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) -type closedSession interface { - destroy() -} - // A closedLocalSession is a session that we closed locally. // When receiving packets for such a session, we need to retransmit the packet containing the CONNECTION_CLOSE frame, // with an exponential backoff. -type closedBaseSession struct { +type closedLocalSession struct { + conn connection + connClosePacket []byte + closeOnce sync.Once closeChan chan struct{} // is closed when the session is closed or destroyed - receivedPackets <-chan *receivedPacket -} - -func (s *closedBaseSession) destroy() { - s.closeOnce.Do(func() { - close(s.closeChan) - }) -} - -func newClosedBaseSession(receivedPackets <-chan *receivedPacket) closedBaseSession { - return closedBaseSession{ - receivedPackets: receivedPackets, - closeChan: make(chan struct{}), - } -} - -type closedLocalSession struct { - closedBaseSession - - conn connection - connClosePacket []byte + receivedPackets chan *receivedPacket counter uint64 // number of packets received + perspective protocol.Perspective + logger utils.Logger } +var _ packetHandler = &closedLocalSession{} + // newClosedLocalSession creates a new closedLocalSession and runs it. func newClosedLocalSession( conn connection, - receivedPackets <-chan *receivedPacket, connClosePacket []byte, + perspective protocol.Perspective, logger utils.Logger, -) closedSession { +) packetHandler { s := &closedLocalSession{ - closedBaseSession: newClosedBaseSession(receivedPackets), - conn: conn, - connClosePacket: connClosePacket, - logger: logger, + conn: conn, + connClosePacket: connClosePacket, + perspective: perspective, + logger: logger, + closeChan: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, 64), } go s.run() return s @@ -64,14 +50,21 @@ func (s *closedLocalSession) run() { for { select { case p := <-s.receivedPackets: - s.handlePacket(p) + s.handlePacketImpl(p) case <-s.closeChan: return } } } -func (s *closedLocalSession) handlePacket(_ *receivedPacket) { +func (s *closedLocalSession) handlePacket(p *receivedPacket) { + select { + case s.receivedPackets <- p: + default: + } +} + +func (s *closedLocalSession) handlePacketImpl(_ *receivedPacket) { s.counter++ // exponential backoff // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving @@ -86,29 +79,35 @@ func (s *closedLocalSession) handlePacket(_ *receivedPacket) { } } +func (s *closedLocalSession) Close() error { + s.destroy(nil) + return nil +} + +func (s *closedLocalSession) destroy(error) { + s.closeOnce.Do(func() { + close(s.closeChan) + }) +} + +func (s *closedLocalSession) getPerspective() protocol.Perspective { + return s.perspective +} + // A closedRemoteSession is a session that was closed remotely. // For such a session, we might receive reordered packets that were sent before the CONNECTION_CLOSE. // We can just ignore those packets. type closedRemoteSession struct { - closedBaseSession + perspective protocol.Perspective } -var _ closedSession = &closedRemoteSession{} +var _ packetHandler = &closedRemoteSession{} -func newClosedRemoteSession(receivedPackets <-chan *receivedPacket) closedSession { - s := &closedRemoteSession{ - closedBaseSession: newClosedBaseSession(receivedPackets), - } - go s.run() - return s +func newClosedRemoteSession(pers protocol.Perspective) packetHandler { + return &closedRemoteSession{perspective: pers} } -func (s *closedRemoteSession) run() { - for { - select { - case <-s.receivedPackets: // discard packets - case <-s.closeChan: - return - } - } -} +func (s *closedRemoteSession) handlePacket(*receivedPacket) {} +func (s *closedRemoteSession) Close() error { return nil } +func (s *closedRemoteSession) destroy(error) {} +func (s *closedRemoteSession) getPerspective() protocol.Perspective { return s.perspective } diff --git a/closed_session_test.go b/closed_session_test.go index 674c2f18..9cabfeec 100644 --- a/closed_session_test.go +++ b/closed_session_test.go @@ -1,30 +1,40 @@ package quic import ( + "errors" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("closed local session", func() { +var _ = Describe("Closed local session", func() { var ( - sess closedSession - mconn *mockConnection - receivedPackets chan *receivedPacket + sess packetHandler + mconn *mockConnection ) BeforeEach(func() { mconn = newMockConnection() - receivedPackets = make(chan *receivedPacket, 10) - sess = newClosedLocalSession(mconn, receivedPackets, []byte("close"), utils.DefaultLogger) + sess = newClosedLocalSession(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger) + }) + + AfterEach(func() { + Eventually(areClosedSessionsRunning).Should(BeFalse()) + }) + + It("tells its perspective", func() { + Expect(sess.getPerspective()).To(Equal(protocol.PerspectiveClient)) + // stop the session + Expect(sess.Close()).To(Succeed()) }) It("repeats the packet containing the CONNECTION_CLOSE frame", func() { for i := 1; i <= 20; i++ { - receivedPackets <- &receivedPacket{} + sess.handlePacket(&receivedPacket{}) if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { Eventually(mconn.written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE } else { @@ -32,40 +42,12 @@ var _ = Describe("closed local session", func() { } } // stop the session - sess.destroy() - Eventually(areClosedSessionsRunning).Should(BeFalse()) + Expect(sess.Close()).To(Succeed()) }) It("destroys sessions", func() { Expect(areClosedSessionsRunning()).To(BeTrue()) - sess.destroy() - Eventually(areClosedSessionsRunning).Should(BeFalse()) - }) -}) - -var _ = Describe("closed remote session", func() { - var ( - sess closedSession - receivedPackets chan *receivedPacket - ) - - BeforeEach(func() { - receivedPackets = make(chan *receivedPacket, 10) - sess = newClosedRemoteSession(receivedPackets) - }) - - It("discards packets", func() { - for i := 0; i < 1000; i++ { - receivedPackets <- &receivedPacket{} - } - // stop the session - sess.destroy() - Eventually(areClosedSessionsRunning).Should(BeFalse()) - }) - - It("destroys sessions", func() { - Expect(areClosedSessionsRunning()).To(BeTrue()) - sess.destroy() + sess.destroy(errors.New("destroy")) Eventually(areClosedSessionsRunning).Should(BeFalse()) }) }) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 3793d270..8d057df6 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -122,6 +122,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) } +// ReplaceWithClosed mocks base method +func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) +} + +// ReplaceWithClosed indicates an expected call of ReplaceWithClosed +func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1) +} + // Retire mocks base method func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index 9bb54a05..d9f5d3f4 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -70,6 +70,18 @@ func (mr *MockSessionRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockSessionRunner)(nil).RemoveResetToken), arg0) } +// ReplaceWithClosed mocks base method +func (m *MockSessionRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) +} + +// ReplaceWithClosed indicates an expected call of ReplaceWithClosed +func (mr *MockSessionRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockSessionRunner)(nil).ReplaceWithClosed), arg0, arg1) +} + // Retire mocks base method func (m *MockSessionRunner) Retire(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/packet_handler_map.go b/packet_handler_map.go index 61cf9ecd..74eaf8ea 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -84,6 +84,19 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { }) } +func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { + h.mutex.Lock() + h.handlers[string(id)] = handler + h.mutex.Unlock() + + time.AfterFunc(h.deleteRetiredSessionsAfter, func() { + h.mutex.Lock() + handler.Close() + delete(h.handlers, string(id)) + h.mutex.Unlock() + }) +} + func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler) { h.mutex.Lock() h.resetTokens[token] = handler diff --git a/quic_suite_test.go b/quic_suite_test.go index c1a2f5fb..1798fff7 100644 --- a/quic_suite_test.go +++ b/quic_suite_test.go @@ -1,9 +1,6 @@ package quic import ( - "bytes" - "runtime/pprof" - "strings" "sync" "github.com/golang/mock/gomock" @@ -27,21 +24,6 @@ var _ = BeforeEach(func() { connMuxerOnce = *new(sync.Once) }) -func areSessionsRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*session).run") -} - -func areClosedSessionsRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*closedLocalSession).run") || - strings.Contains(b.String(), "quic-go.(*closedRemoteSession).run") -} - var _ = AfterEach(func() { mockCtrl.Finish() - Eventually(areSessionsRunning).Should(BeFalse()) - Eventually(areClosedSessionsRunning).Should(BeFalse()) }) diff --git a/server.go b/server.go index 0f0e2e79..fdd9f48d 100644 --- a/server.go +++ b/server.go @@ -37,6 +37,7 @@ type packetHandlerManager interface { Add(protocol.ConnectionID, packetHandler) Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) + ReplaceWithClosed(protocol.ConnectionID, packetHandler) AddResetToken([16]byte, packetHandler) RemoveResetToken([16]byte) GetStatelessResetToken(protocol.ConnectionID) [16]byte @@ -59,6 +60,7 @@ type quicSession interface { type sessionRunner interface { Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) + ReplaceWithClosed(protocol.ConnectionID, packetHandler) AddResetToken([16]byte, packetHandler) RemoveResetToken([16]byte) } diff --git a/session.go b/session.go index a6552915..f7d73ae5 100644 --- a/session.go +++ b/session.go @@ -87,7 +87,7 @@ func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeC type closeError struct { err error remote bool - immediate bool + sendClose bool } var errCloseForRecreating = errors.New("closing session in order to recreate it") @@ -131,9 +131,7 @@ type session struct { receivedPackets chan *receivedPacket sendingScheduled chan struct{} - closeOnce sync.Once - closedSessionMutex sync.Mutex - closedSession closedSession + closeOnce sync.Once // closeChan is used to notify the run loop that it should terminate closeChan chan closeError @@ -901,17 +899,12 @@ func (s *session) closeLocal(e error) { } else { s.logger.Errorf("Closing session with error: %s", e) } - s.closeChan <- closeError{err: e, remote: false} + s.closeChan <- closeError{err: e, sendClose: true, remote: false} }) } // destroy closes the session without sending the error on the wire func (s *session) destroy(e error) { - s.closedSessionMutex.Lock() - if s.closedSession != nil { - s.closedSession.destroy() - } - s.closedSessionMutex.Unlock() s.destroyImpl(e) <-s.ctx.Done() } @@ -924,7 +917,7 @@ func (s *session) destroyImpl(e error) { s.logger.Errorf("Destroying session %s with error: %s", s.destConnID, e) } s.sessionRunner.Remove(s.srcConnID) - s.closeChan <- closeError{err: e, immediate: true, remote: false} + s.closeChan <- closeError{err: e, sendClose: false, remote: false} }) } @@ -939,6 +932,7 @@ func (s *session) closeForRecreating() protocol.PacketNumber { func (s *session) closeRemote(e error) { s.closeOnce.Do(func() { s.logger.Errorf("Peer closed session with error: %s", e) + s.sessionRunner.ReplaceWithClosed(s.srcConnID, newClosedRemoteSession(s.perspective)) s.closeChan <- closeError{err: e, remote: true} }) } @@ -970,24 +964,19 @@ func (s *session) handleCloseError(closeErr closeError) { s.streamsMap.CloseWithError(quicErr) - if closeErr.immediate { + if !closeErr.sendClose { return } - s.sessionRunner.Retire(s.srcConnID) // If this is a remote close we're done here if closeErr.remote { - s.closedSessionMutex.Lock() - s.closedSession = newClosedRemoteSession(s.receivedPackets) - s.closedSessionMutex.Unlock() return } connClosePacket, err := s.sendConnectionClose(quicErr) if err != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } - s.closedSessionMutex.Lock() - s.closedSession = newClosedLocalSession(s.conn, s.receivedPackets, connClosePacket, s.logger) - s.closedSessionMutex.Unlock() + cs := newClosedLocalSession(s.conn, connClosePacket, s.perspective, s.logger) + s.sessionRunner.ReplaceWithClosed(s.srcConnID, cs) } func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { diff --git a/session_test.go b/session_test.go index 56e1c03e..d36c478f 100644 --- a/session_test.go +++ b/session_test.go @@ -7,6 +7,8 @@ import ( "crypto/tls" "errors" "net" + "runtime/pprof" + "strings" "time" . "github.com/onsi/ginkgo" @@ -56,6 +58,18 @@ func (m *mockConnection) LocalAddr() net.Addr { return m.localAddr } func (m *mockConnection) RemoteAddr() net.Addr { return m.remoteAddr } func (*mockConnection) Close() error { panic("not implemented") } +func areSessionsRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*session).run") +} + +func areClosedSessionsRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*closedLocalSession).run") +} + var _ = Describe("Session", func() { var ( sess *session @@ -77,7 +91,17 @@ var _ = Describe("Session", func() { } } + expectReplaceWithClosed := func() { + sessionRunner.EXPECT().ReplaceWithClosed(sess.srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) + Expect(s.Close()).To(Succeed()) + Eventually(areClosedSessionsRunning).Should(BeFalse()) + }) + } + BeforeEach(func() { + Eventually(areSessionsRunning).Should(BeFalse()) + sessionRunner = NewMockSessionRunner(mockCtrl) mconn = newMockConnection() tokenGenerator, err := handshake.NewTokenGenerator() @@ -104,9 +128,7 @@ var _ = Describe("Session", func() { }) AfterEach(func() { - if sess.closedSession != nil { - sess.closedSession.destroy() - } + Eventually(areSessionsRunning).Should(BeFalse()) }) Context("frame handling", func() { @@ -323,7 +345,9 @@ var _ = Describe("Session", func() { It("handles CONNECTION_CLOSE frames, with a transport error code", func() { testErr := qerr.Error(qerr.StreamLimitError, "foobar") streamManager.EXPECT().CloseWithError(testErr) - sessionRunner.EXPECT().Retire(gomock.Any()) + sessionRunner.EXPECT().ReplaceWithClosed(sess.srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) + }) cryptoSetup.EXPECT().Close() go func() { @@ -342,7 +366,9 @@ var _ = Describe("Session", func() { It("handles CONNECTION_CLOSE frames, with an application error code", func() { testErr := qerr.ApplicationError(0x1337, "foobar") streamManager.EXPECT().CloseWithError(testErr) - sessionRunner.EXPECT().Retire(gomock.Any()) + sessionRunner.EXPECT().ReplaceWithClosed(sess.srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) + }) cryptoSetup.EXPECT().Close() go func() { @@ -390,7 +416,7 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.NoError, "")) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{raw: []byte("connection close")}, nil) Expect(sess.Close()).To(Succeed()) @@ -402,7 +428,7 @@ var _ = Describe("Session", func() { It("only closes once", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.NoError, "")) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) @@ -415,7 +441,7 @@ var _ = Describe("Session", func() { It("closes streams with proper error", func() { testErr := errors.New("test error") streamManager.EXPECT().CloseWithError(qerr.Error(0x1337, testErr.Error())) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.CloseWithError(0x1337, testErr.Error()) @@ -446,7 +472,7 @@ var _ = Describe("Session", func() { It("cancels the context when the run loop exists", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) returned := make(chan struct{}) @@ -542,7 +568,7 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() sess.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, @@ -567,7 +593,7 @@ var _ = Describe("Session", func() { Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ProtocolViolation)) close(done) }() - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() sess.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, @@ -587,7 +613,7 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) runErr <- sess.run() }() - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() sess.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, @@ -614,7 +640,7 @@ var _ = Describe("Session", func() { Expect(err).To(MatchError("PROTOCOL_VIOLATION: empty packet")) close(done) }() - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() sess.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, @@ -815,7 +841,7 @@ var _ = Describe("Session", func() { AfterEach(func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() cryptoSetup.EXPECT().Close() sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) @@ -917,7 +943,7 @@ var _ = Describe("Session", func() { AfterEach(func() { // make the go routine return packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() cryptoSetup.EXPECT().Close() Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1034,7 +1060,7 @@ var _ = Describe("Session", func() { sess.scheduleSending() Eventually(mconn.written).Should(Receive()) // make the go routine return - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() @@ -1068,7 +1094,7 @@ var _ = Describe("Session", func() { Eventually(mconn.written).Should(Receive()) // make sure the go routine returns packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() sess.Close() @@ -1092,7 +1118,7 @@ var _ = Describe("Session", func() { Eventually(handshakeCtx.Done()).Should(BeClosed()) // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() Expect(sess.Close()).To(Succeed()) @@ -1102,7 +1128,7 @@ var _ = Describe("Session", func() { It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { packer.EXPECT().PackPacket().AnyTimes() streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() go func() { @@ -1136,7 +1162,7 @@ var _ = Describe("Session", func() { Eventually(done).Should(BeClosed()) // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() Expect(sess.Close()).To(Succeed()) @@ -1152,7 +1178,7 @@ var _ = Describe("Session", func() { close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() Expect(sess.Close()).To(Succeed()) @@ -1170,7 +1196,7 @@ var _ = Describe("Session", func() { close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() Expect(sess.CloseWithError(0x1337, testErr.Error())).To(Succeed()) @@ -1187,7 +1213,7 @@ var _ = Describe("Session", func() { Expect(err.Error()).To(ContainSubstring("transport parameter")) }() streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() sess.processTransportParameters([]byte("invalid")) @@ -1215,7 +1241,7 @@ var _ = Describe("Session", func() { // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() sess.Close() @@ -1234,7 +1260,7 @@ var _ = Describe("Session", func() { AfterEach(func() { // make the go routine return - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() @@ -1342,7 +1368,7 @@ var _ = Describe("Session", func() { }() Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() cryptoSetup.EXPECT().Close() sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1381,7 +1407,7 @@ var _ = Describe("Session", func() { Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() cryptoSetup.EXPECT().Close() sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1483,6 +1509,13 @@ var _ = Describe("Client Session", func() { } } + expectReplaceWithClosed := func() { + sessionRunner.EXPECT().ReplaceWithClosed(sess.srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s.Close()).To(Succeed()) + Eventually(areClosedSessionsRunning).Should(BeFalse()) + }) + } + BeforeEach(func() { quicConf = populateClientConfig(&Config{}, true) }) @@ -1514,12 +1547,6 @@ var _ = Describe("Client Session", func() { sess.cryptoStreamHandler = cryptoSetup }) - AfterEach(func() { - if sess.closedSession != nil { - sess.closedSession.destroy() - } - }) - It("changes the connection ID when receiving the first packet from the server", func() { unpacker := NewMockUnpacker(mockCtrl) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { @@ -1549,7 +1576,7 @@ var _ = Describe("Client Session", func() { }, []byte{0}))).To(BeTrue()) // make sure the go routine returns packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() cryptoSetup.EXPECT().Close() Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1631,7 +1658,7 @@ var _ = Describe("Client Session", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("transport parameter")) }() - sessionRunner.EXPECT().Retire(gomock.Any()) + expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() sess.processTransportParameters([]byte("invalid")) @@ -1734,6 +1761,7 @@ var _ = Describe("Client Session", func() { // Illustrates that an injected Initial with a CONNECTION_CLOSE frame causes // the connection to immediately break down It("fails on Initial-level CONNECTION_CLOSE frame", func() { + sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()) connCloseFrame := testutils.ComposeConnCloseFrame() initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, []wire.Frame{connCloseFrame}) Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue())