diff --git a/client.go b/client.go index 00123f91..c66a155c 100644 --- a/client.go +++ b/client.go @@ -244,18 +244,27 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e c.version, c.connectionID, c.config.TLSConfig, - c.closeCallback, c.cryptoChangeCallback, - negotiatedVersions) + negotiatedVersions, + ) if err != nil { return err } - go c.session.run() + go func() { + // session.run() returns as soon as the session is closed + err := c.session.run() + if err == errCloseSessionForNewVersion { + return + } + + c.mutex.Lock() + c.listenErr = err + c.connStateChangeOrErrCond.Signal() + c.mutex.Unlock() + + utils.Infof("Connection %x closed.", c.connectionID) + c.conn.Close() + }() return nil } - -func (c *client) closeCallback(_ protocol.ConnectionID) { - utils.Infof("Connection %x closed.", c.connectionID) - c.conn.Close() -} diff --git a/server.go b/server.go index af4b7e6b..dc1fb1fa 100644 --- a/server.go +++ b/server.go @@ -18,7 +18,7 @@ import ( type packetHandler interface { Session handlePacket(*receivedPacket) - run() + run() error } // A Listener of QUIC @@ -34,7 +34,7 @@ type server struct { sessionsMutex sync.RWMutex deleteClosedSessionsAfter time.Duration - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) } var _ Listener = &server{} @@ -182,16 +182,22 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet version, hdr.ConnectionID, s.scfg, - s.closeCallback, s.cryptoChangeCallback, ) if err != nil { return err } - go session.run() s.sessionsMutex.Lock() s.sessions[hdr.ConnectionID] = session s.sessionsMutex.Unlock() + + go func() { + // session.run() returns as soon as the session is closed + _ = session.run() + + s.removeConnection(hdr.ConnectionID) + }() + if s.config.ConnState != nil { go s.config.ConnState(session, ConnStateVersionNegotiated) } @@ -221,7 +227,7 @@ func (s *server) cryptoChangeCallback(session Session, isForwardSecure bool) { } } -func (s *server) closeCallback(id protocol.ConnectionID) { +func (s *server) removeConnection(id protocol.ConnectionID) { s.sessionsMutex.Lock() s.sessions[id] = nil s.sessionsMutex.Unlock() diff --git a/server_test.go b/server_test.go index 06512d8d..42de5ada 100644 --- a/server_test.go +++ b/server_test.go @@ -21,13 +21,17 @@ type mockSession struct { packetCount int closed bool closeReason error + stopRunLoop chan struct{} // run returns as soon as this channel receives a value } func (s *mockSession) handlePacket(*receivedPacket) { s.packetCount++ } -func (s *mockSession) run() {} +func (s *mockSession) run() error { + <-s.stopRunLoop + return s.closeReason +} func (s *mockSession) Close(e error) error { s.closeReason = e s.closed = true @@ -51,9 +55,10 @@ func (s *mockSession) RemoteAddr() net.Addr { var _ Session = &mockSession{} -func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) { +func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) { return &mockSession{ connectionID: connectionID, + stopRunLoop: make(chan struct{}), }, nil } @@ -174,9 +179,10 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions[connID]).ToNot(BeNil()) - serv.closeCallback(connID) + // make session.run() return + serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{} // The server should now have closed the session, leaving a nil value in the sessions map - Expect(serv.sessions).To(HaveLen(1)) + Consistently(func() map[protocol.ConnectionID]packetHandler { return serv.sessions }).Should(HaveLen(1)) Expect(serv.sessions[connID]).To(BeNil()) }) @@ -186,8 +192,9 @@ var _ = Describe("Server", func() { err := serv.handlePacket(nil, nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...)) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - serv.closeCallback(connID) Expect(serv.sessions).To(HaveKey(connID)) + // make session.run() return + serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{} Eventually(func() bool { serv.sessionsMutex.Lock() _, ok := serv.sessions[connID] diff --git a/session.go b/session.go index 39057ce6..c8fba901 100644 --- a/session.go +++ b/session.go @@ -39,16 +39,12 @@ var ( // Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that type cryptoChangeCallback func(session Session, isForwardSecure bool) -// closeCallback is called when a session is closed -type closeCallback func(id protocol.ConnectionID) - // A Session is a QUIC session type session struct { connectionID protocol.ConnectionID perspective protocol.Perspective version protocol.VersionNumber - closeCallback closeCallback cryptoChangeCallback cryptoChangeCallback conn connection @@ -73,6 +69,8 @@ type session struct { // closeChan is used to notify the run loop that it should terminate. // If the value is not nil, the error is sent as a CONNECTION_CLOSE. closeChan chan *qerr.QuicError + // the error this session was closed with + closeErr error runClosed chan struct{} closed uint32 // atomic bool @@ -103,14 +101,13 @@ type session struct { var _ Session = &session{} // newSession makes a new session -func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) { +func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) { s := &session{ conn: conn, connectionID: connectionID, perspective: protocol.PerspectiveServer, version: v, - closeCallback: closeCallback, cryptoChangeCallback: cryptoChangeCallback, connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v), } @@ -136,14 +133,13 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return s, err } -func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) { +func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) { s := &session{ conn: conn, connectionID: connectionID, perspective: protocol.PerspectiveClient, version: v, - closeCallback: closeCallback, cryptoChangeCallback: cryptoChangeCallback, connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v), } @@ -193,7 +189,7 @@ func (s *session) setup() { } // run the session main loop -func (s *session) run() { +func (s *session) run() error { // Start the crypto stream handler go func() { if err := s.cryptoSetup.HandleCryptoStream(); err != nil { @@ -272,8 +268,8 @@ runLoop: s.garbageCollectStreams() } - s.closeCallback(s.connectionID) s.runClosed <- struct{}{} + return s.closeErr } func (s *session) maybeResetTimer() { @@ -507,13 +503,11 @@ func (s *session) closeImpl(e error, remoteClose bool) error { if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { return errSessionAlreadyClosed } + s.closeErr = e if e == errCloseSessionForNewVersion { s.streamsMap.CloseWithError(e) s.closeStreamsWithError(e) - // when the run loop exits, it will call the closeCallback - // replace it with an noop function to make sure this doesn't have any effect - s.closeCallback = func(protocol.ConnectionID) {} s.closeChan <- nil return nil } diff --git a/session_test.go b/session_test.go index 419d825b..90b33d11 100644 --- a/session_test.go +++ b/session_test.go @@ -121,12 +121,11 @@ func areSessionsRunning() bool { var _ = Describe("Session", func() { var ( - sess *session - clientSess *session - closeCallbackCalled bool - scfg *handshake.ServerConfig - mconn *mockConnection - cpm *mockConnectionParametersManager + sess *session + clientSess *session + scfg *handshake.ServerConfig + mconn *mockConnection + cpm *mockConnectionParametersManager ) BeforeEach(func() { @@ -135,8 +134,6 @@ var _ = Describe("Session", func() { mconn = &mockConnection{ remoteAddr: &net.UDPAddr{}, } - closeCallbackCalled = false - certChain := crypto.NewCertChain(testdata.GetTLSConfig()) kex, err := crypto.NewCurve25519KEX() Expect(err).NotTo(HaveOccurred()) @@ -147,7 +144,6 @@ var _ = Describe("Session", func() { protocol.Version35, 0, scfg, - func(protocol.ConnectionID) { closeCallbackCalled = true }, func(Session, bool) {}, ) Expect(err).NotTo(HaveOccurred()) @@ -163,7 +159,6 @@ var _ = Describe("Session", func() { protocol.Version35, 0, nil, - func(protocol.ConnectionID) { closeCallbackCalled = true }, func(Session, bool) {}, nil, ) @@ -183,7 +178,6 @@ var _ = Describe("Session", func() { protocol.VersionWhatever, 0, scfg, - func(protocol.ConnectionID) { closeCallbackCalled = true }, func(Session, bool) {}, ) Expect(err).ToNot(HaveOccurred()) @@ -199,7 +193,6 @@ var _ = Describe("Session", func() { protocol.VersionWhatever, 0, scfg, - func(protocol.ConnectionID) { closeCallbackCalled = true }, func(Session, bool) {}, ) Expect(err).ToNot(HaveOccurred()) @@ -642,7 +635,6 @@ var _ = Describe("Session", func() { Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) - Expect(closeCallbackCalled).To(BeTrue()) Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close() }) @@ -660,7 +652,6 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) sess.Close(testErr) Eventually(areSessionsRunning).Should(BeFalse()) - Expect(closeCallbackCalled).To(BeTrue()) n, err := s.Read([]byte{0}) Expect(n).To(BeZero()) Expect(err.Error()).To(ContainSubstring(testErr.Error())) @@ -672,7 +663,6 @@ var _ = Describe("Session", func() { It("closes the session in order to replace it with another QUIC version", func() { sess.Close(errCloseSessionForNewVersion) - Expect(closeCallbackCalled).To(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse()) Expect(atomic.LoadUint32(&sess.closed) != 0).To(BeTrue()) Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent @@ -680,7 +670,6 @@ var _ = Describe("Session", func() { It("sends a Public Reset if the client is initiating the head-of-line blocking experiment", func() { sess.Close(handshake.ErrHOLExperiment) - Expect(closeCallbackCalled).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written[0][0] & 0x02).ToNot(BeZero()) // Public Reset Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close() @@ -1315,7 +1304,6 @@ var _ = Describe("Session", func() { sess.lastNetworkActivityTime = time.Now().Add(-time.Hour) sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) - Expect(closeCallbackCalled).To(BeTrue()) Expect(sess.runClosed).To(Receive()) close(done) }) @@ -1324,7 +1312,6 @@ var _ = Describe("Session", func() { sess.sessionCreationTime = time.Now().Add(-time.Hour) sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time.")) - Expect(closeCallbackCalled).To(BeTrue()) Expect(sess.runClosed).To(Receive()) close(done) }) @@ -1335,7 +1322,6 @@ var _ = Describe("Session", func() { sess.packer.connectionParameters = sess.connectionParameters sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) - Expect(closeCallbackCalled).To(BeTrue()) Expect(sess.runClosed).To(Receive()) close(done) }) @@ -1348,7 +1334,6 @@ var _ = Describe("Session", func() { sess.packer.connectionParameters = sess.connectionParameters sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) - Expect(closeCallbackCalled).To(BeTrue()) Expect(sess.runClosed).To(Receive()) close(done) })