mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 21:57:36 +03:00
remove the closeCallback from the session
The closeCallback was run when a session was closed, i.e. after the run loop of the session stopped. Instead of explicitely calling this callback from the session, the caller of session.run() can just execute the code after session.run() returns.
This commit is contained in:
parent
5f25ffc795
commit
96e49b0c31
5 changed files with 52 additions and 51 deletions
25
client.go
25
client.go
|
@ -244,18 +244,27 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
|
|||
c.version,
|
||||
c.connectionID,
|
||||
c.config.TLSConfig,
|
||||
c.closeCallback,
|
||||
c.cryptoChangeCallback,
|
||||
negotiatedVersions)
|
||||
negotiatedVersions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go c.session.run()
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
err := c.session.run()
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.listenErr = err
|
||||
c.connStateChangeOrErrCond.Signal()
|
||||
c.mutex.Unlock()
|
||||
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) closeCallback(_ protocol.ConnectionID) {
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
}
|
||||
|
|
16
server.go
16
server.go
|
@ -18,7 +18,7 @@ import (
|
|||
type packetHandler interface {
|
||||
Session
|
||||
handlePacket(*receivedPacket)
|
||||
run()
|
||||
run() error
|
||||
}
|
||||
|
||||
// A Listener of QUIC
|
||||
|
@ -34,7 +34,7 @@ type server struct {
|
|||
sessionsMutex sync.RWMutex
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error)
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error)
|
||||
}
|
||||
|
||||
var _ Listener = &server{}
|
||||
|
@ -182,16 +182,22 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
|||
version,
|
||||
hdr.ConnectionID,
|
||||
s.scfg,
|
||||
s.closeCallback,
|
||||
s.cryptoChangeCallback,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go session.run()
|
||||
s.sessionsMutex.Lock()
|
||||
s.sessions[hdr.ConnectionID] = session
|
||||
s.sessionsMutex.Unlock()
|
||||
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
_ = session.run()
|
||||
|
||||
s.removeConnection(hdr.ConnectionID)
|
||||
}()
|
||||
|
||||
if s.config.ConnState != nil {
|
||||
go s.config.ConnState(session, ConnStateVersionNegotiated)
|
||||
}
|
||||
|
@ -221,7 +227,7 @@ func (s *server) cryptoChangeCallback(session Session, isForwardSecure bool) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *server) closeCallback(id protocol.ConnectionID) {
|
||||
func (s *server) removeConnection(id protocol.ConnectionID) {
|
||||
s.sessionsMutex.Lock()
|
||||
s.sessions[id] = nil
|
||||
s.sessionsMutex.Unlock()
|
||||
|
|
|
@ -21,13 +21,17 @@ type mockSession struct {
|
|||
packetCount int
|
||||
closed bool
|
||||
closeReason error
|
||||
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
|
||||
}
|
||||
|
||||
func (s *mockSession) handlePacket(*receivedPacket) {
|
||||
s.packetCount++
|
||||
}
|
||||
|
||||
func (s *mockSession) run() {}
|
||||
func (s *mockSession) run() error {
|
||||
<-s.stopRunLoop
|
||||
return s.closeReason
|
||||
}
|
||||
func (s *mockSession) Close(e error) error {
|
||||
s.closeReason = e
|
||||
s.closed = true
|
||||
|
@ -51,9 +55,10 @@ func (s *mockSession) RemoteAddr() net.Addr {
|
|||
|
||||
var _ Session = &mockSession{}
|
||||
|
||||
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
|
||||
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
|
||||
return &mockSession{
|
||||
connectionID: connectionID,
|
||||
stopRunLoop: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -174,9 +179,10 @@ var _ = Describe("Server", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Expect(serv.sessions[connID]).ToNot(BeNil())
|
||||
serv.closeCallback(connID)
|
||||
// make session.run() return
|
||||
serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{}
|
||||
// The server should now have closed the session, leaving a nil value in the sessions map
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Consistently(func() map[protocol.ConnectionID]packetHandler { return serv.sessions }).Should(HaveLen(1))
|
||||
Expect(serv.sessions[connID]).To(BeNil())
|
||||
})
|
||||
|
||||
|
@ -186,8 +192,9 @@ var _ = Describe("Server", func() {
|
|||
err := serv.handlePacket(nil, nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
serv.closeCallback(connID)
|
||||
Expect(serv.sessions).To(HaveKey(connID))
|
||||
// make session.run() return
|
||||
serv.sessions[connID].(*mockSession).stopRunLoop <- struct{}{}
|
||||
Eventually(func() bool {
|
||||
serv.sessionsMutex.Lock()
|
||||
_, ok := serv.sessions[connID]
|
||||
|
|
20
session.go
20
session.go
|
@ -39,16 +39,12 @@ var (
|
|||
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
||||
type cryptoChangeCallback func(session Session, isForwardSecure bool)
|
||||
|
||||
// closeCallback is called when a session is closed
|
||||
type closeCallback func(id protocol.ConnectionID)
|
||||
|
||||
// A Session is a QUIC session
|
||||
type session struct {
|
||||
connectionID protocol.ConnectionID
|
||||
perspective protocol.Perspective
|
||||
version protocol.VersionNumber
|
||||
|
||||
closeCallback closeCallback
|
||||
cryptoChangeCallback cryptoChangeCallback
|
||||
|
||||
conn connection
|
||||
|
@ -73,6 +69,8 @@ type session struct {
|
|||
// closeChan is used to notify the run loop that it should terminate.
|
||||
// If the value is not nil, the error is sent as a CONNECTION_CLOSE.
|
||||
closeChan chan *qerr.QuicError
|
||||
// the error this session was closed with
|
||||
closeErr error
|
||||
runClosed chan struct{}
|
||||
closed uint32 // atomic bool
|
||||
|
||||
|
@ -103,14 +101,13 @@ type session struct {
|
|||
var _ Session = &session{}
|
||||
|
||||
// newSession makes a new session
|
||||
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
|
||||
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
version: v,
|
||||
|
||||
closeCallback: closeCallback,
|
||||
cryptoChangeCallback: cryptoChangeCallback,
|
||||
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v),
|
||||
}
|
||||
|
@ -136,14 +133,13 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
|||
return s, err
|
||||
}
|
||||
|
||||
func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) {
|
||||
func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) {
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
|
||||
closeCallback: closeCallback,
|
||||
cryptoChangeCallback: cryptoChangeCallback,
|
||||
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v),
|
||||
}
|
||||
|
@ -193,7 +189,7 @@ func (s *session) setup() {
|
|||
}
|
||||
|
||||
// run the session main loop
|
||||
func (s *session) run() {
|
||||
func (s *session) run() error {
|
||||
// Start the crypto stream handler
|
||||
go func() {
|
||||
if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
|
||||
|
@ -272,8 +268,8 @@ runLoop:
|
|||
s.garbageCollectStreams()
|
||||
}
|
||||
|
||||
s.closeCallback(s.connectionID)
|
||||
s.runClosed <- struct{}{}
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
func (s *session) maybeResetTimer() {
|
||||
|
@ -507,13 +503,11 @@ func (s *session) closeImpl(e error, remoteClose bool) error {
|
|||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return errSessionAlreadyClosed
|
||||
}
|
||||
s.closeErr = e
|
||||
|
||||
if e == errCloseSessionForNewVersion {
|
||||
s.streamsMap.CloseWithError(e)
|
||||
s.closeStreamsWithError(e)
|
||||
// when the run loop exits, it will call the closeCallback
|
||||
// replace it with an noop function to make sure this doesn't have any effect
|
||||
s.closeCallback = func(protocol.ConnectionID) {}
|
||||
s.closeChan <- nil
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -121,12 +121,11 @@ func areSessionsRunning() bool {
|
|||
|
||||
var _ = Describe("Session", func() {
|
||||
var (
|
||||
sess *session
|
||||
clientSess *session
|
||||
closeCallbackCalled bool
|
||||
scfg *handshake.ServerConfig
|
||||
mconn *mockConnection
|
||||
cpm *mockConnectionParametersManager
|
||||
sess *session
|
||||
clientSess *session
|
||||
scfg *handshake.ServerConfig
|
||||
mconn *mockConnection
|
||||
cpm *mockConnectionParametersManager
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
|
@ -135,8 +134,6 @@ var _ = Describe("Session", func() {
|
|||
mconn = &mockConnection{
|
||||
remoteAddr: &net.UDPAddr{},
|
||||
}
|
||||
closeCallbackCalled = false
|
||||
|
||||
certChain := crypto.NewCertChain(testdata.GetTLSConfig())
|
||||
kex, err := crypto.NewCurve25519KEX()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -147,7 +144,6 @@ var _ = Describe("Session", func() {
|
|||
protocol.Version35,
|
||||
0,
|
||||
scfg,
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(Session, bool) {},
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -163,7 +159,6 @@ var _ = Describe("Session", func() {
|
|||
protocol.Version35,
|
||||
0,
|
||||
nil,
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(Session, bool) {},
|
||||
nil,
|
||||
)
|
||||
|
@ -183,7 +178,6 @@ var _ = Describe("Session", func() {
|
|||
protocol.VersionWhatever,
|
||||
0,
|
||||
scfg,
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(Session, bool) {},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -199,7 +193,6 @@ var _ = Describe("Session", func() {
|
|||
protocol.VersionWhatever,
|
||||
0,
|
||||
scfg,
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(Session, bool) {},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -642,7 +635,6 @@ var _ = Describe("Session", func() {
|
|||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
|
||||
})
|
||||
|
||||
|
@ -660,7 +652,6 @@ var _ = Describe("Session", func() {
|
|||
Expect(err).NotTo(HaveOccurred())
|
||||
sess.Close(testErr)
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
n, err := s.Read([]byte{0})
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
|
||||
|
@ -672,7 +663,6 @@ var _ = Describe("Session", func() {
|
|||
|
||||
It("closes the session in order to replace it with another QUIC version", func() {
|
||||
sess.Close(errCloseSessionForNewVersion)
|
||||
Expect(closeCallbackCalled).To(BeFalse())
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(atomic.LoadUint32(&sess.closed) != 0).To(BeTrue())
|
||||
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
|
||||
|
@ -680,7 +670,6 @@ var _ = Describe("Session", func() {
|
|||
|
||||
It("sends a Public Reset if the client is initiating the head-of-line blocking experiment", func() {
|
||||
sess.Close(handshake.ErrHOLExperiment)
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect(mconn.written[0][0] & 0x02).ToNot(BeZero()) // Public Reset
|
||||
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
|
||||
|
@ -1315,7 +1304,6 @@ var _ = Describe("Session", func() {
|
|||
sess.lastNetworkActivityTime = time.Now().Add(-time.Hour)
|
||||
sess.run() // Would normally not return
|
||||
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(sess.runClosed).To(Receive())
|
||||
close(done)
|
||||
})
|
||||
|
@ -1324,7 +1312,6 @@ var _ = Describe("Session", func() {
|
|||
sess.sessionCreationTime = time.Now().Add(-time.Hour)
|
||||
sess.run() // Would normally not return
|
||||
Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(sess.runClosed).To(Receive())
|
||||
close(done)
|
||||
})
|
||||
|
@ -1335,7 +1322,6 @@ var _ = Describe("Session", func() {
|
|||
sess.packer.connectionParameters = sess.connectionParameters
|
||||
sess.run() // Would normally not return
|
||||
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(sess.runClosed).To(Receive())
|
||||
close(done)
|
||||
})
|
||||
|
@ -1348,7 +1334,6 @@ var _ = Describe("Session", func() {
|
|||
sess.packer.connectionParameters = sess.connectionParameters
|
||||
sess.run() // Would normally not return
|
||||
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(sess.runClosed).To(Receive())
|
||||
close(done)
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue