From 733e2e952b2d6c3db79cd308b935c69abc3cc89d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 11 May 2018 08:40:25 +0900 Subject: [PATCH] use callbacks for signaling the session status Instead of exposing a session.handshakeStatus() <-chan error, it's easier to pass a callback to the session which is called when the handshake is done. The removeConnectionID callback is in preparation for IETF QUIC, where a connection can have multiple connection IDs over its lifetime. --- client.go | 47 ++++++---- client_test.go | 30 ++++-- mock_session_runner_test.go | 55 +++++++++++ mockgen.go | 4 +- server.go | 48 +++++----- server_test.go | 69 +++++++------- server_tls.go | 6 +- server_tls_test.go | 2 +- session.go | 31 +++---- session_test.go | 177 ++++++++++++++++++++++-------------- 10 files changed, 295 insertions(+), 174 deletions(-) create mode 100644 mock_session_runner_test.go diff --git a/client.go b/client.go index e9f57540..3c6aefb3 100644 --- a/client.go +++ b/client.go @@ -37,6 +37,8 @@ type client struct { initialVersion protocol.VersionNumber version protocol.VersionNumber + handshakeChan chan struct{} + session packetHandler logger utils.Logger @@ -105,14 +107,15 @@ func Dial( } } c := &client{ - conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - srcConnID: srcConnID, - destConnID: destConnID, - hostname: hostname, - tlsConf: tlsConf, - config: clientConfig, - version: version, - logger: utils.DefaultLogger.WithPrefix("client"), + conn: &conn{pconn: pconn, currentAddr: remoteAddr}, + srcConnID: srcConnID, + destConnID: destConnID, + hostname: hostname, + tlsConf: tlsConf, + config: clientConfig, + version: version, + handshakeChan: make(chan struct{}), + logger: utils.DefaultLogger.WithPrefix("client"), } c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) @@ -243,21 +246,19 @@ func (c *client) dialTLS() error { // - any other error that might occur // - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC) func (c *client) establishSecureConnection() error { - var runErr error - errorChan := make(chan struct{}) + errorChan := make(chan error, 1) + go func() { - runErr = c.session.run() // returns as soon as the session is closed - close(errorChan) - if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { - c.conn.Close() - } + err := c.session.run() // returns as soon as the session is closed + errorChan <- err }() select { - case <-errorChan: - return runErr - case err := <-c.session.handshakeStatus(): + case err := <-errorChan: return err + case <-c.handshakeChan: + // handshake successfully completed + return nil } } @@ -438,8 +439,13 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { func (c *client) createNewGQUICSession() (err error) { c.mutex.Lock() defer c.mutex.Unlock() + runner := &runner{ + onHandshakeCompleteImpl: func(_ packetHandler) { close(c.handshakeChan) }, + removeConnectionIDImpl: func(protocol.ConnectionID) {}, + } c.session, err = newClientSession( c.conn, + runner, c.hostname, c.version, c.destConnID, @@ -458,8 +464,13 @@ func (c *client) createNewTLSSession( ) (err error) { c.mutex.Lock() defer c.mutex.Unlock() + runner := &runner{ + onHandshakeCompleteImpl: func(_ packetHandler) { close(c.handshakeChan) }, + removeConnectionIDImpl: func(protocol.ConnectionID) {}, + } c.session, err = newTLSClientSession( c.conn, + runner, c.hostname, c.version, c.destConnID, diff --git a/client_test.go b/client_test.go index 5e7a1f77..5eca0f66 100644 --- a/client_test.go +++ b/client_test.go @@ -28,7 +28,7 @@ var _ = Describe("Client", func() { addr net.Addr connID protocol.ConnectionID - originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, logger utils.Logger) (packetHandler, error) + originalClientSessConstructor func(connection, sessionRunner, string, protocol.VersionNumber, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (packetHandler, error) ) // generate a packet sent by the server that accepts the QUIC version suggested by the client @@ -48,7 +48,7 @@ var _ = Describe("Client", func() { connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) - msess, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil) + msess, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil) sess = msess.(*mockSession) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} packetConn = newMockPacketConn() @@ -97,6 +97,7 @@ var _ = Describe("Client", func() { remoteAddrChan := make(chan string) newClientSession = func( conn connection, + _ sessionRunner, _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, @@ -126,6 +127,7 @@ var _ = Describe("Client", func() { hostnameChan := make(chan string) newClientSession = func( _ connection, + _ sessionRunner, h string, _ protocol.VersionNumber, _ protocol.ConnectionID, @@ -160,6 +162,7 @@ var _ = Describe("Client", func() { It("returns after the handshake is complete", func() { newClientSession = func( _ connection, + runner sessionRunner, _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, @@ -169,6 +172,7 @@ var _ = Describe("Client", func() { _ []protocol.VersionNumber, _ utils.Logger, ) (packetHandler, error) { + runner.onHandshakeComplete(sess) return sess, nil } packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) @@ -180,14 +184,16 @@ var _ = Describe("Client", func() { Expect(s).ToNot(BeNil()) close(dialed) }() - close(sess.handshakeChan) Eventually(dialed).Should(BeClosed()) + // make the session run loop return + close(sess.stopRunLoop) }) It("returns an error that occurs while waiting for the connection to become secure", func() { testErr := errors.New("early handshake error") newClientSession = func( conn connection, + _ sessionRunner, _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, @@ -208,7 +214,8 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(testErr)) close(done) }() - sess.handshakeChan <- testErr + sess.closeReason = testErr + close(sess.stopRunLoop) Eventually(done).Should(BeClosed()) }) @@ -269,6 +276,7 @@ var _ = Describe("Client", func() { testErr := errors.New("error creating session") newClientSession = func( _ connection, + _ sessionRunner, _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, @@ -295,6 +303,7 @@ var _ = Describe("Client", func() { var conf *Config newTLSClientSession = func( connP connection, + _ sessionRunner, hostnameP string, versionP protocol.VersionNumber, _ protocol.ConnectionID, @@ -344,6 +353,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs during version negotiation", func() { newClientSession = func( conn connection, + _ sessionRunner, _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, @@ -390,9 +400,10 @@ var _ = Describe("Client", func() { Expect(newVersion).ToNot(Equal(cl.version)) cl.config = &Config{Versions: []protocol.VersionNumber{newVersion}} sessionChan := make(chan *mockSession) - handshakeChan := make(chan error) + stopRunLoop := make(chan struct{}) newClientSession = func( _ connection, + _ sessionRunner, _ string, _ protocol.VersionNumber, connectionID protocol.ConnectionID, @@ -406,9 +417,8 @@ var _ = Describe("Client", func() { negotiatedVersions = negotiatedVersionsP sess := &mockSession{ - connectionID: connectionID, - stopRunLoop: make(chan struct{}), - handshakeChan: handshakeChan, + connectionID: connectionID, + stopRunLoop: stopRunLoop, } sessionChan <- sess return sess, nil @@ -441,7 +451,6 @@ var _ = Describe("Client", func() { Expect(negotiatedVersions).To(ContainElement(newVersion)) Expect(initialVersion).To(Equal(actualInitialVersion)) - close(handshakeChan) Eventually(established).Should(BeClosed()) }) @@ -449,6 +458,7 @@ var _ = Describe("Client", func() { sessionCounter := uint32(0) newClientSession = func( _ connection, + _ sessionRunner, _ string, _ protocol.VersionNumber, connectionID protocol.ConnectionID, @@ -613,6 +623,7 @@ var _ = Describe("Client", func() { var conf *Config newClientSession = func( connP connection, + _ sessionRunner, hostnameP string, versionP protocol.VersionNumber, _ protocol.ConnectionID, @@ -651,6 +662,7 @@ var _ = Describe("Client", func() { sessionChan := make(chan *mockSession) newTLSClientSession = func( connP connection, + _ sessionRunner, hostnameP string, versionP protocol.VersionNumber, _ protocol.ConnectionID, diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go new file mode 100644 index 00000000..7f49d885 --- /dev/null +++ b/mock_session_runner_test.go @@ -0,0 +1,55 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: SessionRunner) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockSessionRunner is a mock of SessionRunner interface +type MockSessionRunner struct { + ctrl *gomock.Controller + recorder *MockSessionRunnerMockRecorder +} + +// MockSessionRunnerMockRecorder is the mock recorder for MockSessionRunner +type MockSessionRunnerMockRecorder struct { + mock *MockSessionRunner +} + +// NewMockSessionRunner creates a new mock instance +func NewMockSessionRunner(ctrl *gomock.Controller) *MockSessionRunner { + mock := &MockSessionRunner{ctrl: ctrl} + mock.recorder = &MockSessionRunnerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder { + return m.recorder +} + +// onHandshakeComplete mocks base method +func (m *MockSessionRunner) onHandshakeComplete(arg0 packetHandler) { + m.ctrl.Call(m, "onHandshakeComplete", arg0) +} + +// onHandshakeComplete indicates an expected call of onHandshakeComplete +func (mr *MockSessionRunnerMockRecorder) onHandshakeComplete(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHandshakeComplete", reflect.TypeOf((*MockSessionRunner)(nil).onHandshakeComplete), arg0) +} + +// removeConnectionID mocks base method +func (m *MockSessionRunner) removeConnectionID(arg0 protocol.ConnectionID) { + m.ctrl.Call(m, "removeConnectionID", arg0) +} + +// removeConnectionID indicates an expected call of removeConnectionID +func (mr *MockSessionRunnerMockRecorder) removeConnectionID(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "removeConnectionID", reflect.TypeOf((*MockSessionRunner)(nil).removeConnectionID), arg0) +} diff --git a/mockgen.go b/mockgen.go index 65f38546..833f29f8 100644 --- a/mockgen.go +++ b/mockgen.go @@ -8,9 +8,9 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource" //go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager" -//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go" //go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker Unpacker" -//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_unpacker_test.go mock_unpacker_test.go" //go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD" +//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner" +//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'" //go:generate sh -c "goimports -w mock*_test.go" diff --git a/server.go b/server.go index 02a9d077..160bc380 100644 --- a/server.go +++ b/server.go @@ -21,13 +21,27 @@ import ( type packetHandler interface { Session getCryptoStream() cryptoStreamI - handshakeStatus() <-chan error handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber run() error closeRemote(error) } +type sessionRunner interface { + onHandshakeComplete(packetHandler) + removeConnectionID(protocol.ConnectionID) +} + +type runner struct { + onHandshakeCompleteImpl func(packetHandler) + removeConnectionIDImpl func(protocol.ConnectionID) +} + +func (r *runner) onHandshakeComplete(p packetHandler) { r.onHandshakeCompleteImpl(p) } +func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) } + +var _ sessionRunner = &runner{} + // A Listener of QUIC type server struct { tlsConf *tls.Config @@ -45,12 +59,14 @@ type server struct { sessions map[string] /* string(ConnectionID)*/ packetHandler closed bool - serverError error + serverError error + sessionQueue chan Session errorChan chan struct{} + sessionRunner sessionRunner // set as members, so they can be set in the tests - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, logger utils.Logger) (packetHandler, error) + newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error) deleteClosedSessionsAfter time.Duration logger utils.Logger @@ -112,6 +128,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, supportsTLS: supportsTLS, logger: utils.DefaultLogger.WithPrefix("server"), } + s.sessionRunner = &runner{ + onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess }, + removeConnectionIDImpl: s.removeConnection, + } if supportsTLS { if err := s.setupTLS(); err != nil { return nil, err @@ -127,7 +147,7 @@ func (s *server) setupTLS() error { if err != nil { return err } - serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf, s.logger) + serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, cookieHandler, s.tlsConf, s.logger) if err != nil { return err } @@ -148,7 +168,7 @@ func (s *server) setupTLS() error { } s.sessions[string(connID)] = sess s.sessionsMutex.Unlock() - s.runHandshakeAndSession(sess, connID) + go sess.run() } } }() @@ -415,6 +435,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd var err error session, err = s.newSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, + s.sessionRunner, version, hdr.DestConnectionID, s.scfg, @@ -429,7 +450,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd s.sessions[string(hdr.DestConnectionID)] = session s.sessionsMutex.Unlock() - s.runHandshakeAndSession(session, hdr.DestConnectionID) + go session.run() } session.handlePacket(&receivedPacket{ @@ -441,21 +462,6 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd return nil } -func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) { - go func() { - _ = session.run() - // session.run() returns as soon as the session is closed - s.removeConnection(connID) - }() - - go func() { - if err := <-session.handshakeStatus(); err != nil { - return - } - s.sessionQueue <- session - }() -} - func (s *server) removeConnection(id protocol.ConnectionID) { s.sessionsMutex.Lock() s.sessions[string(id)] = nil diff --git a/server_test.go b/server_test.go index 0642bd72..6e29e46a 100644 --- a/server_test.go +++ b/server_test.go @@ -9,7 +9,6 @@ import ( "reflect" "time" - "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" @@ -22,13 +21,13 @@ import ( ) type mockSession struct { + runner sessionRunner connectionID protocol.ConnectionID handledPackets []*receivedPacket closed bool closeReason error closedRemote bool stopRunLoop chan struct{} // run returns as soon as this channel receives a value - handshakeChan chan error } func (s *mockSession) handlePacket(p *receivedPacket) { @@ -67,13 +66,13 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not impl func (*mockSession) Context() context.Context { panic("not implemented") } func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") } func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } -func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan } func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") } var _ Session = &mockSession{} func newMockSession( _ connection, + runner sessionRunner, _ protocol.VersionNumber, connectionID protocol.ConnectionID, _ *handshake.ServerConfig, @@ -82,9 +81,9 @@ func newMockSession( _ utils.Logger, ) (packetHandler, error) { s := mockSession{ - connectionID: connectionID, - handshakeChan: make(chan error), - stopRunLoop: make(chan struct{}), + runner: runner, + connectionID: connectionID, + stopRunLoop: make(chan struct{}), } return &s, nil } @@ -181,7 +180,7 @@ var _ = Describe("Server", func() { It("accepts new TLS sessions", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) + sess, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) err = serv.setupTLS() Expect(err).ToNot(HaveOccurred()) @@ -198,9 +197,9 @@ var _ = Describe("Server", func() { It("only accepts one new TLS sessions for one connection ID", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) + sess1, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) - sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) + sess2, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) err = serv.setupTLS() Expect(err).ToNot(HaveOccurred()) @@ -224,38 +223,45 @@ var _ = Describe("Server", func() { }).Should(Equal(sess1)) }) - It("accepts a session once the connection it is forward secure", func(done Done) { + It("accepts a session once the connection it is forward secure", func() { var acceptedSess Session + done := make(chan struct{}) go func() { defer GinkgoRecover() var err error acceptedSess, err = serv.Accept() Expect(err).ToNot(HaveOccurred()) + close(done) }() err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) sess := serv.sessions[string(connID)].(*mockSession) Consistently(func() Session { return acceptedSess }).Should(BeNil()) - close(sess.handshakeChan) + serv.sessionQueue <- sess Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) - close(done) - }, 0.5) + Eventually(done).Should(BeClosed()) + }) - It("doesn't accept session that error during the handshake", func(done Done) { - var accepted bool + It("doesn't accept sessions that error during the handshake", func() { + done := make(chan struct{}) go func() { defer GinkgoRecover() serv.Accept() - accepted = true + close(done) }() err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) sess := serv.sessions[string(connID)].(*mockSession) - sess.handshakeChan <- errors.New("handshake failed") - Consistently(func() bool { return accepted }).Should(BeFalse()) - close(done) + sess.closeReason = errors.New("handshake failed") + close(sess.stopRunLoop) + Consistently(done).ShouldNot(BeClosed()) + // make the go routine return + serv.removeConnection(connID) + close(serv.errorChan) + serv.Close() + Eventually(done).Should(BeClosed()) }) It("assigns packets to existing sessions", func() { @@ -268,16 +274,10 @@ var _ = Describe("Server", func() { Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2)) }) - It("closes and deletes sessions", func() { + It("deletes sessions", func() { serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test - nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...)) - Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[string(connID)]).ToNot(BeNil()) - // make session.run() return - serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{} + serv.sessions[string(connID)] = &mockSession{} + serv.removeConnection(connID) // The server should now have closed the session, leaving a nil value in the sessions map Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1)) Expect(serv.sessions[string(connID)]).To(BeNil()) @@ -285,14 +285,9 @@ var _ = Describe("Server", func() { It("deletes nil session entries after a wait time", func() { serv.deleteClosedSessionsAfter = 25 * time.Millisecond - nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...)) - Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions).To(HaveKey(string(connID))) + serv.sessions[string(connID)] = &mockSession{} // make session.run() return - serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{} + serv.removeConnection(connID) Eventually(func() bool { serv.sessionsMutex.Lock() _, ok := serv.sessions[string(connID)] @@ -303,7 +298,7 @@ var _ = Describe("Server", func() { It("closes sessions and the connection when Close is called", func() { go serv.serve() - session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil) + session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil) serv.sessions[string(connID)] = session err := serv.Close() Expect(err).NotTo(HaveOccurred()) @@ -353,7 +348,7 @@ var _ = Describe("Server", func() { }, 0.5) It("closes all sessions when encountering a connection error", func() { - session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil) + session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil) serv.sessions[string(connID)] = session Expect(serv.sessions[string(connID)].(*mockSession).closed).To(BeFalse()) testErr := errors.New("connection error") diff --git a/server_tls.go b/server_tls.go index 7424a409..5d1f4ea1 100644 --- a/server_tls.go +++ b/server_tls.go @@ -42,7 +42,8 @@ type serverTLS struct { params *handshake.TransportParameters newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) - sessionChan chan<- tlsSession + sessionRunner sessionRunner + sessionChan chan<- tlsSession logger utils.Logger } @@ -50,6 +51,7 @@ type serverTLS struct { func newServerTLS( conn net.PacketConn, config *Config, + runner sessionRunner, cookieHandler *handshake.CookieHandler, tlsConf *tls.Config, logger utils.Logger, @@ -72,6 +74,7 @@ func newServerTLS( config: config, supportedVersions: config.Versions, mintConf: mconf, + sessionRunner: runner, sessionChan: sessionChan, params: &handshake.TransportParameters{ StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, @@ -214,6 +217,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, params := <-paramsChan sess, err := newTLSServerSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, + s.sessionRunner, hdr.SrcConnectionID, hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here protocol.PacketNumber(1), // TODO: use a random packet number here diff --git a/server_tls_test.go b/server_tls_test.go index a03954b8..4d4e764d 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -37,7 +37,7 @@ var _ = Describe("Stateless TLS handling", func() { Versions: []protocol.VersionNumber{protocol.VersionTLS}, } var err error - server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig(), utils.DefaultLogger) + server, sessionChan, err = newServerTLS(conn, config, nil, nil, testdata.GetTLSConfig(), utils.DefaultLogger) Expect(err).ToNot(HaveOccurred()) server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { mintReply = bc diff --git a/session.go b/session.go index 9cfa1f30..d542672f 100644 --- a/session.go +++ b/session.go @@ -73,6 +73,8 @@ type closeError struct { // A Session is a QUIC session type session struct { + sessionRunner sessionRunner + destConnID protocol.ConnectionID srcConnID protocol.ConnectionID @@ -116,11 +118,7 @@ type session struct { paramsChan <-chan handshake.TransportParameters // the handshakeEvent channel is passed to the CryptoSetup. // It receives when it makes sense to try decrypting undecryptable packets. - handshakeEvent <-chan struct{} - // handshakeChan is returned by handshakeStatus. - // It receives any error that might occur during the handshake. - // It is closed when the handshake is complete. - handshakeChan chan error + handshakeEvent <-chan struct{} handshakeComplete bool receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this @@ -151,6 +149,7 @@ var _ streamSender = &session{} // newSession makes a new session func newSession( conn connection, + sessionRunner sessionRunner, v protocol.VersionNumber, connectionID protocol.ConnectionID, scfg *handshake.ServerConfig, @@ -162,6 +161,7 @@ func newSession( handshakeEvent := make(chan struct{}, 1) s := &session{ conn: conn, + sessionRunner: sessionRunner, srcConnID: connectionID, destConnID: connectionID, perspective: protocol.PerspectiveServer, @@ -221,6 +221,7 @@ func newSession( // declare this as a variable, so that we can it mock it in the tests var newClientSession = func( conn connection, + sessionRunner sessionRunner, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, @@ -234,6 +235,7 @@ var newClientSession = func( handshakeEvent := make(chan struct{}, 1) s := &session{ conn: conn, + sessionRunner: sessionRunner, srcConnID: connectionID, destConnID: connectionID, perspective: protocol.PerspectiveClient, @@ -288,6 +290,7 @@ var newClientSession = func( func newTLSServerSession( conn connection, + runner sessionRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, initialPacketNumber protocol.PacketNumber, @@ -302,6 +305,7 @@ func newTLSServerSession( handshakeEvent := make(chan struct{}, 1) s := &session{ conn: conn, + sessionRunner: runner, config: config, srcConnID: srcConnID, destConnID: destConnID, @@ -345,6 +349,7 @@ func newTLSServerSession( // declare this as a variable, such that we can it mock it in the tests var newTLSClientSession = func( conn connection, + runner sessionRunner, hostname string, v protocol.VersionNumber, destConnID protocol.ConnectionID, @@ -358,6 +363,7 @@ var newTLSClientSession = func( handshakeEvent := make(chan struct{}, 1) s := &session{ conn: conn, + sessionRunner: runner, config: config, srcConnID: srcConnID, destConnID: destConnID, @@ -413,7 +419,6 @@ func (s *session) preSetup() { } func (s *session) postSetup() error { - s.handshakeChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) @@ -527,13 +532,11 @@ runLoop: } } - // only send the error the handshakeChan when the handshake is not completed yet - // otherwise this chan will already be closed - if !s.handshakeComplete { - s.handshakeChan <- closeErr.err + if err := s.handleCloseError(closeErr); err != nil { + s.logger.Infof("Handling close error failed: %s", err) } - s.handleCloseError(closeErr) s.logger.Infof("Connection %s closed.", s.srcConnID) + s.sessionRunner.removeConnectionID(s.srcConnID) return closeErr.err } @@ -580,6 +583,7 @@ func (s *session) handleHandshakeEvent(completed bool) { } s.handshakeComplete = true s.handshakeEvent = nil // prevent this case from ever being selected again + s.sessionRunner.onHandshakeComplete(s) // In gQUIC, the server completes the handshake first (after sending the SHLO). // In TLS 1.3, the client completes the handshake first (after sending the CFIN). @@ -593,7 +597,6 @@ func (s *session) handleHandshakeEvent(completed bool) { s.queueControlFrame(&wire.PingFrame{}) s.sentPacketHandler.SetHandshakeComplete() } - close(s.handshakeChan) } func (s *session) handlePacketImpl(p *receivedPacket) error { @@ -1239,10 +1242,6 @@ func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *session) handshakeStatus() <-chan error { - return s.handshakeChan -} - func (s *session) getCryptoStream() cryptoStreamI { return s.cryptoStream } diff --git a/session_test.go b/session_test.go index 571421df..b4d9b5d8 100644 --- a/session_test.go +++ b/session_test.go @@ -68,6 +68,7 @@ func areSessionsRunning() bool { var _ = Describe("Session", func() { var ( sess *session + sessionRunner *MockSessionRunner scfg *handshake.ServerConfig mconn *mockConnection cryptoSetup *mockCryptoSetup @@ -97,6 +98,7 @@ var _ = Describe("Session", func() { return cryptoSetup, nil } + sessionRunner = NewMockSessionRunner(mockCtrl) mconn = newMockConnection() certChain := crypto.NewCertChain(testdata.GetTLSConfig()) kex, err := crypto.NewCurve25519KEX() @@ -106,6 +108,7 @@ var _ = Describe("Session", func() { var pSess Session pSess, err = newSession( mconn, + sessionRunner, protocol.Version39, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, scfg, @@ -159,6 +162,7 @@ var _ = Describe("Session", func() { } pSess, err := newSession( mconn, + sessionRunner, protocol.Version39, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, scfg, @@ -471,17 +475,15 @@ var _ = Describe("Session", func() { It("handles CONNECTION_CLOSE frames", func() { testErr := qerr.Error(qerr.ProofInvalid, "foobar") streamManager.EXPECT().CloseWithError(testErr) - done := make(chan struct{}) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) go func() { defer GinkgoRecover() err := sess.run() Expect(err).To(MatchError(testErr)) - close(done) }() err := sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) Eventually(sess.Context().Done()).Should(BeClosed()) - Eventually(done).Should(BeClosed()) }) }) @@ -510,6 +512,7 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(nil) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) @@ -522,6 +525,7 @@ var _ = Describe("Session", func() { It("only closes once", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(nil) sess.Close(nil) Eventually(areSessionsRunning).Should(BeFalse()) @@ -532,6 +536,7 @@ var _ = Describe("Session", func() { It("closes streams with proper error", func() { testErr := errors.New("test error") streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(testErr) Eventually(areSessionsRunning).Should(BeFalse()) Expect(sess.Context().Done()).To(BeClosed()) @@ -539,6 +544,7 @@ var _ = Describe("Session", func() { It("closes the session in order to replace it with another QUIC version", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(errCloseSessionForNewVersion) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent @@ -546,6 +552,7 @@ var _ = Describe("Session", func() { It("sends a Public Reset if the client is initiating the head-of-line blocking experiment", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(handshake.ErrHOLExperiment) Expect(mconn.written).To(HaveLen(1)) Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset @@ -554,6 +561,7 @@ var _ = Describe("Session", func() { It("sends a Public Reset if the client is initiating the no STOP_WAITING experiment", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(handshake.ErrHOLExperiment) Expect(mconn.written).To(HaveLen(1)) Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset @@ -562,6 +570,7 @@ var _ = Describe("Session", func() { It("cancels the context when the run loop exists", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) returned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -619,20 +628,21 @@ var _ = Describe("Session", func() { Expect(err).ToNot(HaveOccurred()) }) - It("closes when handling a packet fails", func(done Done) { + It("closes when handling a packet fails", func() { testErr := errors.New("unpack error") unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) streamManager.EXPECT().CloseWithError(gomock.Any()) hdr.PacketNumber = 5 - var runErr error + done := make(chan struct{}) go func() { defer GinkgoRecover() - runErr = sess.run() + err := sess.run() + Expect(err).To(MatchError(testErr)) + close(done) }() sess.handlePacket(&receivedPacket{header: hdr}) - Eventually(func() error { return runErr }).Should(MatchError(testErr)) - Expect(sess.Context().Done()).To(BeClosed()) - close(done) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + Eventually(done).Should(BeClosed()) }) It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() { @@ -886,6 +896,7 @@ var _ = Describe("Session", func() { Eventually(mconn.written).Should(HaveLen(2)) Consistently(mconn.written).Should(HaveLen(2)) // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) }) @@ -909,6 +920,7 @@ var _ = Describe("Session", func() { Eventually(mconn.written).Should(HaveLen(1)) Consistently(mconn.written).Should(HaveLen(1)) // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) }) @@ -936,6 +948,7 @@ var _ = Describe("Session", func() { Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1)) Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2)) // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) }) @@ -958,6 +971,7 @@ var _ = Describe("Session", func() { sess.scheduleSending() Eventually(mconn.written).Should(HaveLen(3)) // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) }) @@ -978,6 +992,7 @@ var _ = Describe("Session", func() { sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) Consistently(mconn.written).ShouldNot(Receive()) // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) }) @@ -1019,6 +1034,7 @@ var _ = Describe("Session", func() { sess.scheduleSending() Eventually(mconn.written).Should(HaveLen(1)) // make sure that the go routine returns + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) @@ -1049,6 +1065,7 @@ var _ = Describe("Session", func() { sess.scheduleSending() Eventually(mconn.written).Should(HaveLen(1)) // make sure that the go routine returns + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) @@ -1210,19 +1227,18 @@ var _ = Describe("Session", func() { sph.EXPECT().SentPacket(gomock.Any()) sess.sentPacketHandler = sph - done := make(chan struct{}) go func() { defer GinkgoRecover() sess.run() - close(done) }() Consistently(mconn.written).ShouldNot(Receive()) sess.scheduleSending() Eventually(mconn.written).Should(Receive()) // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(nil) - Eventually(done).Should(BeClosed()) + Eventually(sess.Context().Done()).Should(BeClosed()) }) It("sets the timer to the ack timer", func() { @@ -1243,32 +1259,31 @@ var _ = Describe("Session", func() { rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)) sess.receivedPacketHandler = rph - done := make(chan struct{}) + go func() { defer GinkgoRecover() sess.run() - close(done) }() Eventually(mconn.written).Should(Receive()) // make sure the go routine returns + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(nil) - Eventually(done).Should(BeClosed()) + Eventually(sess.Context().Done()).Should(BeClosed()) }) }) It("closes when crypto stream errors", func() { testErr := errors.New("crypto setup error") streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) cryptoSetup.handleErr = testErr - done := make(chan struct{}) go func() { defer GinkgoRecover() err := sess.run() Expect(err).To(MatchError(testErr)) - close(done) }() - Eventually(done).Should(BeClosed()) + Eventually(sess.Context().Done()).Should(BeClosed()) }) Context("sending a Public Reset when receiving undecryptable packets during the handshake", func() { @@ -1303,7 +1318,9 @@ var _ = Describe("Session", func() { sendUndecryptablePackets() sess.scheduleSending() Consistently(mconn.written).Should(HaveLen(0)) - Expect(sess.Close(nil)).To(Succeed()) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close(nil) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1314,6 +1331,8 @@ var _ = Describe("Session", func() { }() sendUndecryptablePackets() Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), 20*time.Millisecond)) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(nil) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1327,11 +1346,14 @@ var _ = Describe("Session", func() { Eventually(func() []*receivedPacket { return sess.undecryptablePackets }).Should(HaveLen(protocol.MaxUndecryptablePackets)) // check that old packets are kept, and the new packets are dropped Expect(sess.undecryptablePackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.Close(nil)).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) It("sends a Public Reset after a timeout", func() { + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.receivedTooManyUndecrytablePacketsTime).To(BeZero()) go func() { defer GinkgoRecover() @@ -1359,7 +1381,9 @@ var _ = Describe("Session", func() { // in reality, this happens when the trial decryption succeeded during the Public Reset timeout Consistently(mconn.written).ShouldNot(HaveLen(1)) Expect(sess.Context().Done()).ToNot(Receive()) - Expect(sess.Close(nil)).To(Succeed()) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close(nil) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1371,6 +1395,8 @@ var _ = Describe("Session", func() { }() sendUndecryptablePackets() Consistently(sess.undecryptablePackets).Should(BeEmpty()) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.Close(nil)).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1387,38 +1413,35 @@ var _ = Describe("Session", func() { }) It("doesn't do anything when the crypto setup says to decrypt undecryptable packets", func() { - done := make(chan struct{}) go func() { defer GinkgoRecover() - err := sess.run() - Expect(err).ToNot(HaveOccurred()) - close(done) + sess.run() }() handshakeChan <- struct{}{} - Consistently(sess.handshakeStatus()).ShouldNot(Receive()) + // don't EXPECT any calls to sessionRunner.onHandshakeComplete() // make sure the go routine returns + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) Expect(sess.Close(nil)).To(Succeed()) - Eventually(done).Should(BeClosed()) + Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("closes the handshakeChan when the handshake completes", func() { - done := make(chan struct{}) + It("calls the onHandshakeComplete callback when the handshake completes", func() { go func() { defer GinkgoRecover() - err := sess.run() - Expect(err).ToNot(HaveOccurred()) - close(done) + sess.run() }() + sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) close(handshakeChan) - Eventually(sess.handshakeStatus()).Should(BeClosed()) + Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make sure the go routine returns + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) Expect(sess.Close(nil)).To(Succeed()) - Eventually(done).Should(BeClosed()) + Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("passes errors to the handshakeChan", func() { + It("passes errors to the session runner", func() { testErr := errors.New("handshake error") done := make(chan struct{}) go func() { @@ -1428,19 +1451,17 @@ var _ = Describe("Session", func() { close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.Close(testErr) - Expect(sess.handshakeStatus()).To(Receive(Equal(testErr))) Eventually(done).Should(BeClosed()) }) It("process transport parameters received from the peer", func() { paramsChan := make(chan handshake.TransportParameters) sess.paramsChan = paramsChan - done := make(chan struct{}) go func() { defer GinkgoRecover() sess.run() - close(done) }() params := handshake.TransportParameters{ MaxStreams: 123, @@ -1457,8 +1478,9 @@ var _ = Describe("Session", func() { Eventually(func() protocol.ByteCount { return sess.packer.maxPacketSize }).Should(Equal(protocol.ByteCount(0x42))) // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) - Expect(sess.Close(nil)).To(Succeed()) - Eventually(done).Should(BeClosed()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close(nil) + Eventually(sess.Context().Done()).Should(BeClosed()) }) Context("keep-alives", func() { @@ -1486,6 +1508,7 @@ var _ = Describe("Session", func() { // -12 because of the crypto tag. This should be 7 (the frame id for a ping frame). Expect(data[len(data)-12-1 : len(data)-12]).To(Equal([]byte{0x07})) // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) @@ -1503,6 +1526,7 @@ var _ = Describe("Session", func() { }() Consistently(mconn.written).ShouldNot(Receive()) // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) @@ -1520,6 +1544,7 @@ var _ = Describe("Session", func() { }() Consistently(mconn.written).ShouldNot(Receive()) // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) sess.Close(nil) Eventually(done).Should(BeClosed()) @@ -1531,23 +1556,33 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) }) - It("times out due to no network activity", func(done Done) { + It("times out due to no network activity", func() { + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.handshakeComplete = true sess.lastNetworkActivityTime = time.Now().Add(-time.Hour) - err := sess.run() // Would normally not return - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) + close(done) + }() + Eventually(done).Should(BeClosed()) Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity."))) - Expect(sess.Context().Done()).To(BeClosed()) - close(done) }) - It("times out due to non-completed handshake", func(done Done) { + It("times out due to non-completed handshake", func() { + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) - err := sess.run() // Would normally not return - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout)) + close(done) + }() + Eventually(done).Should(BeClosed()) Expect(mconn.written).To(Receive(ContainSubstring("Crypto handshake did not complete in time."))) - Expect(sess.Context().Done()).To(BeClosed()) - close(done) }) It("does not use the idle timeout before the handshake complete", func() { @@ -1556,28 +1591,31 @@ var _ = Describe("Session", func() { sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) // the handshake timeout is irrelevant here, since it depends on the time the session was created, // and not on the last network activity - done := make(chan struct{}) go func() { defer GinkgoRecover() - _ = sess.run() - close(done) + sess.run() }() - Consistently(done).ShouldNot(BeClosed()) + Consistently(sess.Context().Done()).ShouldNot(BeClosed()) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close(nil) + Eventually(sess.Context().Done()).Should(BeClosed()) }) It("closes the session due to the idle timeout after handshake", func() { + sessionRunner.EXPECT().onHandshakeComplete(sess) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.config.IdleTimeout = 0 close(handshakeChan) - errChan := make(chan error) + done := make(chan struct{}) go func() { defer GinkgoRecover() - errChan <- sess.run() // Would normally not return + err := sess.run() + Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) + close(done) }() - var err error - Eventually(errChan).Should(Receive(&err)) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) + Eventually(done).Should(BeClosed()) Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity."))) - Expect(sess.Context().Done()).To(BeClosed()) }) }) @@ -1679,6 +1717,7 @@ var _ = Describe("Session", func() { var _ = Describe("Client Session", func() { var ( sess *session + sessionRunner *MockSessionRunner mconn *mockConnection handshakeChan chan<- struct{} @@ -1707,8 +1746,10 @@ var _ = Describe("Client Session", func() { } mconn = newMockConnection() + sessionRunner = NewMockSessionRunner(mockCtrl) sessP, err := newClientSession( mconn, + sessionRunner, "hostname", protocol.Version39, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, @@ -1727,19 +1768,18 @@ var _ = Describe("Client Session", func() { }) It("sends a forward-secure packet when the handshake completes", func() { + sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) sess.packer.hasSentPacket = true - done := make(chan struct{}) go func() { defer GinkgoRecover() - err := sess.run() - Expect(err).ToNot(HaveOccurred()) - close(done) + sess.run() }() close(handshakeChan) Eventually(mconn.written).Should(Receive()) //make sure the go routine returns + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.Close(nil)).To(Succeed()) - Eventually(done).Should(BeClosed()) + Eventually(sess.Context().Done()).Should(BeClosed()) }) Context("receiving packets", func() { @@ -1755,20 +1795,19 @@ var _ = Describe("Client Session", func() { unpacker := NewMockUnpacker(mockCtrl) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) sess.unpacker = unpacker - done := make(chan struct{}) go func() { defer GinkgoRecover() - err := sess.run() - Expect(err).ToNot(HaveOccurred()) - close(done) + sess.run() }() hdr.PacketNumber = 5 hdr.DiversificationNonce = []byte("foobar") err := sess.handlePacketImpl(&receivedPacket{header: hdr}) Expect(err).ToNot(HaveOccurred()) Expect(cryptoSetup.divNonce).To(Equal(hdr.DiversificationNonce)) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.Close(nil)).To(Succeed()) - Eventually(done).Should(BeClosed()) + Eventually(sess.Context().Done()).Should(BeClosed()) }) }) })