remove the closeCallback from the session

The closeCallback was run when a session was closed, i.e. after the run
loop of the session stopped. Instead of explicitely calling this callback
from the session, the caller of session.run() can just execute the code
after session.run() returns.
This commit is contained in:
Marten Seemann 2017-04-27 19:10:14 +07:00
parent 5f25ffc795
commit 96e49b0c31
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
5 changed files with 52 additions and 51 deletions

View file

@ -244,18 +244,27 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
c.version,
c.connectionID,
c.config.TLSConfig,
c.closeCallback,
c.cryptoChangeCallback,
negotiatedVersions)
negotiatedVersions,
)
if err != nil {
return err
}
go c.session.run()
go func() {
// session.run() returns as soon as the session is closed
err := c.session.run()
if err == errCloseSessionForNewVersion {
return
}
c.mutex.Lock()
c.listenErr = err
c.connStateChangeOrErrCond.Signal()
c.mutex.Unlock()
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}()
return nil
}
func (c *client) closeCallback(_ protocol.ConnectionID) {
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}

View file

@ -18,7 +18,7 @@ import (
type packetHandler interface {
Session
handlePacket(*receivedPacket)
run()
run() error
}
// A Listener of QUIC
@ -34,7 +34,7 @@ type server struct {
sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error)
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error)
}
var _ Listener = &server{}
@ -182,16 +182,22 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
version,
hdr.ConnectionID,
s.scfg,
s.closeCallback,
s.cryptoChangeCallback,
)
if err != nil {
return err
}
go session.run()
s.sessionsMutex.Lock()
s.sessions[hdr.ConnectionID] = session
s.sessionsMutex.Unlock()
go func() {
// session.run() returns as soon as the session is closed
_ = session.run()
s.removeConnection(hdr.ConnectionID)
}()
if s.config.ConnState != nil {
go s.config.ConnState(session, ConnStateVersionNegotiated)
}
@ -221,7 +227,7 @@ func (s *server) cryptoChangeCallback(session Session, isForwardSecure bool) {
}
}
func (s *server) closeCallback(id protocol.ConnectionID) {
func (s *server) removeConnection(id protocol.ConnectionID) {
s.sessionsMutex.Lock()
s.sessions[id] = nil
s.sessionsMutex.Unlock()

View file

@ -21,13 +21,17 @@ type mockSession struct {
packetCount int
closed bool
closeReason error
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
}
func (s *mockSession) handlePacket(*receivedPacket) {
s.packetCount++
}
func (s *mockSession) run() {}
func (s *mockSession) run() error {
<-s.stopRunLoop
return s.closeReason
}
func (s *mockSession) Close(e error) error {
s.closeReason = e
s.closed = true
@ -51,9 +55,10 @@ func (s *mockSession) RemoteAddr() net.Addr {
var _ Session = &mockSession{}
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
return &mockSession{
connectionID: connectionID,
stopRunLoop: make(chan struct{}),
}, nil
}
@ -174,9 +179,10 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[connID]).ToNot(BeNil())
serv.closeCallback(connID)
// make session.run() return
serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{}
// The server should now have closed the session, leaving a nil value in the sessions map
Expect(serv.sessions).To(HaveLen(1))
Consistently(func() map[protocol.ConnectionID]packetHandler { return serv.sessions }).Should(HaveLen(1))
Expect(serv.sessions[connID]).To(BeNil())
})
@ -186,8 +192,9 @@ var _ = Describe("Server", func() {
err := serv.handlePacket(nil, nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
serv.closeCallback(connID)
Expect(serv.sessions).To(HaveKey(connID))
// make session.run() return
serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{}
Eventually(func() bool {
serv.sessionsMutex.Lock()
_, ok := serv.sessions[connID]

View file

@ -39,16 +39,12 @@ var (
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
type cryptoChangeCallback func(session Session, isForwardSecure bool)
// closeCallback is called when a session is closed
type closeCallback func(id protocol.ConnectionID)
// A Session is a QUIC session
type session struct {
connectionID protocol.ConnectionID
perspective protocol.Perspective
version protocol.VersionNumber
closeCallback closeCallback
cryptoChangeCallback cryptoChangeCallback
conn connection
@ -73,6 +69,8 @@ type session struct {
// closeChan is used to notify the run loop that it should terminate.
// If the value is not nil, the error is sent as a CONNECTION_CLOSE.
closeChan chan *qerr.QuicError
// the error this session was closed with
closeErr error
runClosed chan struct{}
closed uint32 // atomic bool
@ -103,14 +101,13 @@ type session struct {
var _ Session = &session{}
// newSession makes a new session
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
s := &session{
conn: conn,
connectionID: connectionID,
perspective: protocol.PerspectiveServer,
version: v,
closeCallback: closeCallback,
cryptoChangeCallback: cryptoChangeCallback,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v),
}
@ -136,14 +133,13 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
return s, err
}
func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) {
func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) {
s := &session{
conn: conn,
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
closeCallback: closeCallback,
cryptoChangeCallback: cryptoChangeCallback,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v),
}
@ -193,7 +189,7 @@ func (s *session) setup() {
}
// run the session main loop
func (s *session) run() {
func (s *session) run() error {
// Start the crypto stream handler
go func() {
if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
@ -272,8 +268,8 @@ runLoop:
s.garbageCollectStreams()
}
s.closeCallback(s.connectionID)
s.runClosed <- struct{}{}
return s.closeErr
}
func (s *session) maybeResetTimer() {
@ -507,13 +503,11 @@ func (s *session) closeImpl(e error, remoteClose bool) error {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return errSessionAlreadyClosed
}
s.closeErr = e
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
// when the run loop exits, it will call the closeCallback
// replace it with an noop function to make sure this doesn't have any effect
s.closeCallback = func(protocol.ConnectionID) {}
s.closeChan <- nil
return nil
}

View file

@ -121,12 +121,11 @@ func areSessionsRunning() bool {
var _ = Describe("Session", func() {
var (
sess *session
clientSess *session
closeCallbackCalled bool
scfg *handshake.ServerConfig
mconn *mockConnection
cpm *mockConnectionParametersManager
sess *session
clientSess *session
scfg *handshake.ServerConfig
mconn *mockConnection
cpm *mockConnectionParametersManager
)
BeforeEach(func() {
@ -135,8 +134,6 @@ var _ = Describe("Session", func() {
mconn = &mockConnection{
remoteAddr: &net.UDPAddr{},
}
closeCallbackCalled = false
certChain := crypto.NewCertChain(testdata.GetTLSConfig())
kex, err := crypto.NewCurve25519KEX()
Expect(err).NotTo(HaveOccurred())
@ -147,7 +144,6 @@ var _ = Describe("Session", func() {
protocol.Version35,
0,
scfg,
func(protocol.ConnectionID) { closeCallbackCalled = true },
func(Session, bool) {},
)
Expect(err).NotTo(HaveOccurred())
@ -163,7 +159,6 @@ var _ = Describe("Session", func() {
protocol.Version35,
0,
nil,
func(protocol.ConnectionID) { closeCallbackCalled = true },
func(Session, bool) {},
nil,
)
@ -183,7 +178,6 @@ var _ = Describe("Session", func() {
protocol.VersionWhatever,
0,
scfg,
func(protocol.ConnectionID) { closeCallbackCalled = true },
func(Session, bool) {},
)
Expect(err).ToNot(HaveOccurred())
@ -199,7 +193,6 @@ var _ = Describe("Session", func() {
protocol.VersionWhatever,
0,
scfg,
func(protocol.ConnectionID) { closeCallbackCalled = true },
func(Session, bool) {},
)
Expect(err).ToNot(HaveOccurred())
@ -642,7 +635,6 @@ var _ = Describe("Session", func() {
Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
Expect(closeCallbackCalled).To(BeTrue())
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
})
@ -660,7 +652,6 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred())
sess.Close(testErr)
Eventually(areSessionsRunning).Should(BeFalse())
Expect(closeCallbackCalled).To(BeTrue())
n, err := s.Read([]byte{0})
Expect(n).To(BeZero())
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
@ -672,7 +663,6 @@ var _ = Describe("Session", func() {
It("closes the session in order to replace it with another QUIC version", func() {
sess.Close(errCloseSessionForNewVersion)
Expect(closeCallbackCalled).To(BeFalse())
Eventually(areSessionsRunning).Should(BeFalse())
Expect(atomic.LoadUint32(&sess.closed) != 0).To(BeTrue())
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
@ -680,7 +670,6 @@ var _ = Describe("Session", func() {
It("sends a Public Reset if the client is initiating the head-of-line blocking experiment", func() {
sess.Close(handshake.ErrHOLExperiment)
Expect(closeCallbackCalled).To(BeTrue())
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0][0] & 0x02).ToNot(BeZero()) // Public Reset
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
@ -1315,7 +1304,6 @@ var _ = Describe("Session", func() {
sess.lastNetworkActivityTime = time.Now().Add(-time.Hour)
sess.run() // Would normally not return
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(sess.runClosed).To(Receive())
close(done)
})
@ -1324,7 +1312,6 @@ var _ = Describe("Session", func() {
sess.sessionCreationTime = time.Now().Add(-time.Hour)
sess.run() // Would normally not return
Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(sess.runClosed).To(Receive())
close(done)
})
@ -1335,7 +1322,6 @@ var _ = Describe("Session", func() {
sess.packer.connectionParameters = sess.connectionParameters
sess.run() // Would normally not return
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(sess.runClosed).To(Receive())
close(done)
})
@ -1348,7 +1334,6 @@ var _ = Describe("Session", func() {
sess.packer.connectionParameters = sess.connectionParameters
sess.run() // Would normally not return
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(sess.runClosed).To(Receive())
close(done)
})