use callbacks for signaling the session status

Instead of exposing a session.handshakeStatus() <-chan error, it's
easier to pass a callback to the session which is called when the
handshake is done.
The removeConnectionID callback is in preparation for IETF QUIC, where a
connection can have multiple connection IDs over its lifetime.
This commit is contained in:
Marten Seemann 2018-05-11 08:40:25 +09:00
parent c7119b2adf
commit 733e2e952b
10 changed files with 295 additions and 174 deletions

View file

@ -9,7 +9,6 @@ import (
"reflect"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
@ -22,13 +21,13 @@ import (
)
type mockSession struct {
runner sessionRunner
connectionID protocol.ConnectionID
handledPackets []*receivedPacket
closed bool
closeReason error
closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan error
}
func (s *mockSession) handlePacket(p *receivedPacket) {
@ -67,13 +66,13 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not impl
func (*mockSession) Context() context.Context { panic("not implemented") }
func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") }
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
var _ Session = &mockSession{}
func newMockSession(
_ connection,
runner sessionRunner,
_ protocol.VersionNumber,
connectionID protocol.ConnectionID,
_ *handshake.ServerConfig,
@ -82,9 +81,9 @@ func newMockSession(
_ utils.Logger,
) (packetHandler, error) {
s := mockSession{
connectionID: connectionID,
handshakeChan: make(chan error),
stopRunLoop: make(chan struct{}),
runner: runner,
connectionID: connectionID,
stopRunLoop: make(chan struct{}),
}
return &s, nil
}
@ -181,7 +180,7 @@ var _ = Describe("Server", func() {
It("accepts new TLS sessions", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
sess, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
err = serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
@ -198,9 +197,9 @@ var _ = Describe("Server", func() {
It("only accepts one new TLS sessions for one connection ID", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
sess1, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
sess2, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
err = serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
@ -224,38 +223,45 @@ var _ = Describe("Server", func() {
}).Should(Equal(sess1))
})
It("accepts a session once the connection it is forward secure", func(done Done) {
It("accepts a session once the connection it is forward secure", func() {
var acceptedSess Session
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
acceptedSess, err = serv.Accept()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[string(connID)].(*mockSession)
Consistently(func() Session { return acceptedSess }).Should(BeNil())
close(sess.handshakeChan)
serv.sessionQueue <- sess
Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
close(done)
}, 0.5)
Eventually(done).Should(BeClosed())
})
It("doesn't accept session that error during the handshake", func(done Done) {
var accepted bool
It("doesn't accept sessions that error during the handshake", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.Accept()
accepted = true
close(done)
}()
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[string(connID)].(*mockSession)
sess.handshakeChan <- errors.New("handshake failed")
Consistently(func() bool { return accepted }).Should(BeFalse())
close(done)
sess.closeReason = errors.New("handshake failed")
close(sess.stopRunLoop)
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
serv.removeConnection(connID)
close(serv.errorChan)
serv.Close()
Eventually(done).Should(BeClosed())
})
It("assigns packets to existing sessions", func() {
@ -268,16 +274,10 @@ var _ = Describe("Server", func() {
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2))
})
It("closes and deletes sessions", func() {
It("deletes sessions", func() {
serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)]).ToNot(BeNil())
// make session.run() return
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{}
serv.sessions[string(connID)] = &mockSession{}
serv.removeConnection(connID)
// The server should now have closed the session, leaving a nil value in the sessions map
Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1))
Expect(serv.sessions[string(connID)]).To(BeNil())
@ -285,14 +285,9 @@ var _ = Describe("Server", func() {
It("deletes nil session entries after a wait time", func() {
serv.deleteClosedSessionsAfter = 25 * time.Millisecond
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions).To(HaveKey(string(connID)))
serv.sessions[string(connID)] = &mockSession{}
// make session.run() return
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{}
serv.removeConnection(connID)
Eventually(func() bool {
serv.sessionsMutex.Lock()
_, ok := serv.sessions[string(connID)]
@ -303,7 +298,7 @@ var _ = Describe("Server", func() {
It("closes sessions and the connection when Close is called", func() {
go serv.serve()
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
serv.sessions[string(connID)] = session
err := serv.Close()
Expect(err).NotTo(HaveOccurred())
@ -353,7 +348,7 @@ var _ = Describe("Server", func() {
}, 0.5)
It("closes all sessions when encountering a connection error", func() {
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
serv.sessions[string(connID)] = session
Expect(serv.sessions[string(connID)].(*mockSession).closed).To(BeFalse())
testErr := errors.New("connection error")