mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
use callbacks for signaling the session status
Instead of exposing a session.handshakeStatus() <-chan error, it's easier to pass a callback to the session which is called when the handshake is done. The removeConnectionID callback is in preparation for IETF QUIC, where a connection can have multiple connection IDs over its lifetime.
This commit is contained in:
parent
c7119b2adf
commit
733e2e952b
10 changed files with 295 additions and 174 deletions
|
@ -9,7 +9,6 @@ import (
|
|||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
@ -22,13 +21,13 @@ import (
|
|||
)
|
||||
|
||||
type mockSession struct {
|
||||
runner sessionRunner
|
||||
connectionID protocol.ConnectionID
|
||||
handledPackets []*receivedPacket
|
||||
closed bool
|
||||
closeReason error
|
||||
closedRemote bool
|
||||
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
|
||||
handshakeChan chan error
|
||||
}
|
||||
|
||||
func (s *mockSession) handlePacket(p *receivedPacket) {
|
||||
|
@ -67,13 +66,13 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not impl
|
|||
func (*mockSession) Context() context.Context { panic("not implemented") }
|
||||
func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") }
|
||||
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
|
||||
func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
|
||||
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
|
||||
|
||||
var _ Session = &mockSession{}
|
||||
|
||||
func newMockSession(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
_ *handshake.ServerConfig,
|
||||
|
@ -82,9 +81,9 @@ func newMockSession(
|
|||
_ utils.Logger,
|
||||
) (packetHandler, error) {
|
||||
s := mockSession{
|
||||
connectionID: connectionID,
|
||||
handshakeChan: make(chan error),
|
||||
stopRunLoop: make(chan struct{}),
|
||||
runner: runner,
|
||||
connectionID: connectionID,
|
||||
stopRunLoop: make(chan struct{}),
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
@ -181,7 +180,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("accepts new TLS sessions", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
sess, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -198,9 +197,9 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("only accepts one new TLS sessions for one connection ID", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
sess1, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
sess2, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -224,38 +223,45 @@ var _ = Describe("Server", func() {
|
|||
}).Should(Equal(sess1))
|
||||
})
|
||||
|
||||
It("accepts a session once the connection it is forward secure", func(done Done) {
|
||||
It("accepts a session once the connection it is forward secure", func() {
|
||||
var acceptedSess Session
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
acceptedSess, err = serv.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[string(connID)].(*mockSession)
|
||||
Consistently(func() Session { return acceptedSess }).Should(BeNil())
|
||||
close(sess.handshakeChan)
|
||||
serv.sessionQueue <- sess
|
||||
Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
|
||||
close(done)
|
||||
}, 0.5)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't accept session that error during the handshake", func(done Done) {
|
||||
var accepted bool
|
||||
It("doesn't accept sessions that error during the handshake", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.Accept()
|
||||
accepted = true
|
||||
close(done)
|
||||
}()
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[string(connID)].(*mockSession)
|
||||
sess.handshakeChan <- errors.New("handshake failed")
|
||||
Consistently(func() bool { return accepted }).Should(BeFalse())
|
||||
close(done)
|
||||
sess.closeReason = errors.New("handshake failed")
|
||||
close(sess.stopRunLoop)
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
// make the go routine return
|
||||
serv.removeConnection(connID)
|
||||
close(serv.errorChan)
|
||||
serv.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("assigns packets to existing sessions", func() {
|
||||
|
@ -268,16 +274,10 @@ var _ = Describe("Server", func() {
|
|||
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("closes and deletes sessions", func() {
|
||||
It("deletes sessions", func() {
|
||||
serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Expect(serv.sessions[string(connID)]).ToNot(BeNil())
|
||||
// make session.run() return
|
||||
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{}
|
||||
serv.sessions[string(connID)] = &mockSession{}
|
||||
serv.removeConnection(connID)
|
||||
// The server should now have closed the session, leaving a nil value in the sessions map
|
||||
Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1))
|
||||
Expect(serv.sessions[string(connID)]).To(BeNil())
|
||||
|
@ -285,14 +285,9 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("deletes nil session entries after a wait time", func() {
|
||||
serv.deleteClosedSessionsAfter = 25 * time.Millisecond
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Expect(serv.sessions).To(HaveKey(string(connID)))
|
||||
serv.sessions[string(connID)] = &mockSession{}
|
||||
// make session.run() return
|
||||
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{}
|
||||
serv.removeConnection(connID)
|
||||
Eventually(func() bool {
|
||||
serv.sessionsMutex.Lock()
|
||||
_, ok := serv.sessions[string(connID)]
|
||||
|
@ -303,7 +298,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("closes sessions and the connection when Close is called", func() {
|
||||
go serv.serve()
|
||||
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
|
||||
session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
|
||||
serv.sessions[string(connID)] = session
|
||||
err := serv.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -353,7 +348,7 @@ var _ = Describe("Server", func() {
|
|||
}, 0.5)
|
||||
|
||||
It("closes all sessions when encountering a connection error", func() {
|
||||
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
|
||||
session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
|
||||
serv.sessions[string(connID)] = session
|
||||
Expect(serv.sessions[string(connID)].(*mockSession).closed).To(BeFalse())
|
||||
testErr := errors.New("connection error")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue