use callbacks for signaling the session status

Instead of exposing a session.handshakeStatus() <-chan error, it's
easier to pass a callback to the session which is called when the
handshake is done.
The removeConnectionID callback is in preparation for IETF QUIC, where a
connection can have multiple connection IDs over its lifetime.
This commit is contained in:
Marten Seemann 2018-05-11 08:40:25 +09:00
parent c7119b2adf
commit 733e2e952b
10 changed files with 295 additions and 174 deletions

View file

@ -37,6 +37,8 @@ type client struct {
initialVersion protocol.VersionNumber
version protocol.VersionNumber
handshakeChan chan struct{}
session packetHandler
logger utils.Logger
@ -105,14 +107,15 @@ func Dial(
}
}
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
srcConnID: srcConnID,
destConnID: destConnID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: version,
logger: utils.DefaultLogger.WithPrefix("client"),
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
srcConnID: srcConnID,
destConnID: destConnID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: version,
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
}
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
@ -243,21 +246,19 @@ func (c *client) dialTLS() error {
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection() error {
var runErr error
errorChan := make(chan struct{})
errorChan := make(chan error, 1)
go func() {
runErr = c.session.run() // returns as soon as the session is closed
close(errorChan)
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close()
}
err := c.session.run() // returns as soon as the session is closed
errorChan <- err
}()
select {
case <-errorChan:
return runErr
case err := <-c.session.handshakeStatus():
case err := <-errorChan:
return err
case <-c.handshakeChan:
// handshake successfully completed
return nil
}
}
@ -438,8 +439,13 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
func (c *client) createNewGQUICSession() (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ packetHandler) { close(c.handshakeChan) },
removeConnectionIDImpl: func(protocol.ConnectionID) {},
}
c.session, err = newClientSession(
c.conn,
runner,
c.hostname,
c.version,
c.destConnID,
@ -458,8 +464,13 @@ func (c *client) createNewTLSSession(
) (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ packetHandler) { close(c.handshakeChan) },
removeConnectionIDImpl: func(protocol.ConnectionID) {},
}
c.session, err = newTLSClientSession(
c.conn,
runner,
c.hostname,
c.version,
c.destConnID,

View file

@ -28,7 +28,7 @@ var _ = Describe("Client", func() {
addr net.Addr
connID protocol.ConnectionID
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, logger utils.Logger) (packetHandler, error)
originalClientSessConstructor func(connection, sessionRunner, string, protocol.VersionNumber, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (packetHandler, error)
)
// generate a packet sent by the server that accepts the QUIC version suggested by the client
@ -48,7 +48,7 @@ var _ = Describe("Client", func() {
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
originalClientSessConstructor = newClientSession
Eventually(areSessionsRunning).Should(BeFalse())
msess, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
msess, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
sess = msess.(*mockSession)
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = newMockPacketConn()
@ -97,6 +97,7 @@ var _ = Describe("Client", func() {
remoteAddrChan := make(chan string)
newClientSession = func(
conn connection,
_ sessionRunner,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
@ -126,6 +127,7 @@ var _ = Describe("Client", func() {
hostnameChan := make(chan string)
newClientSession = func(
_ connection,
_ sessionRunner,
h string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
@ -160,6 +162,7 @@ var _ = Describe("Client", func() {
It("returns after the handshake is complete", func() {
newClientSession = func(
_ connection,
runner sessionRunner,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
@ -169,6 +172,7 @@ var _ = Describe("Client", func() {
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
runner.onHandshakeComplete(sess)
return sess, nil
}
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
@ -180,14 +184,16 @@ var _ = Describe("Client", func() {
Expect(s).ToNot(BeNil())
close(dialed)
}()
close(sess.handshakeChan)
Eventually(dialed).Should(BeClosed())
// make the session run loop return
close(sess.stopRunLoop)
})
It("returns an error that occurs while waiting for the connection to become secure", func() {
testErr := errors.New("early handshake error")
newClientSession = func(
conn connection,
_ sessionRunner,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
@ -208,7 +214,8 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(testErr))
close(done)
}()
sess.handshakeChan <- testErr
sess.closeReason = testErr
close(sess.stopRunLoop)
Eventually(done).Should(BeClosed())
})
@ -269,6 +276,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("error creating session")
newClientSession = func(
_ connection,
_ sessionRunner,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
@ -295,6 +303,7 @@ var _ = Describe("Client", func() {
var conf *Config
newTLSClientSession = func(
connP connection,
_ sessionRunner,
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
@ -344,6 +353,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs during version negotiation", func() {
newClientSession = func(
conn connection,
_ sessionRunner,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
@ -390,9 +400,10 @@ var _ = Describe("Client", func() {
Expect(newVersion).ToNot(Equal(cl.version))
cl.config = &Config{Versions: []protocol.VersionNumber{newVersion}}
sessionChan := make(chan *mockSession)
handshakeChan := make(chan error)
stopRunLoop := make(chan struct{})
newClientSession = func(
_ connection,
_ sessionRunner,
_ string,
_ protocol.VersionNumber,
connectionID protocol.ConnectionID,
@ -406,9 +417,8 @@ var _ = Describe("Client", func() {
negotiatedVersions = negotiatedVersionsP
sess := &mockSession{
connectionID: connectionID,
stopRunLoop: make(chan struct{}),
handshakeChan: handshakeChan,
connectionID: connectionID,
stopRunLoop: stopRunLoop,
}
sessionChan <- sess
return sess, nil
@ -441,7 +451,6 @@ var _ = Describe("Client", func() {
Expect(negotiatedVersions).To(ContainElement(newVersion))
Expect(initialVersion).To(Equal(actualInitialVersion))
close(handshakeChan)
Eventually(established).Should(BeClosed())
})
@ -449,6 +458,7 @@ var _ = Describe("Client", func() {
sessionCounter := uint32(0)
newClientSession = func(
_ connection,
_ sessionRunner,
_ string,
_ protocol.VersionNumber,
connectionID protocol.ConnectionID,
@ -613,6 +623,7 @@ var _ = Describe("Client", func() {
var conf *Config
newClientSession = func(
connP connection,
_ sessionRunner,
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
@ -651,6 +662,7 @@ var _ = Describe("Client", func() {
sessionChan := make(chan *mockSession)
newTLSClientSession = func(
connP connection,
_ sessionRunner,
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,

View file

@ -0,0 +1,55 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: SessionRunner)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockSessionRunner is a mock of SessionRunner interface
type MockSessionRunner struct {
ctrl *gomock.Controller
recorder *MockSessionRunnerMockRecorder
}
// MockSessionRunnerMockRecorder is the mock recorder for MockSessionRunner
type MockSessionRunnerMockRecorder struct {
mock *MockSessionRunner
}
// NewMockSessionRunner creates a new mock instance
func NewMockSessionRunner(ctrl *gomock.Controller) *MockSessionRunner {
mock := &MockSessionRunner{ctrl: ctrl}
mock.recorder = &MockSessionRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder {
return m.recorder
}
// onHandshakeComplete mocks base method
func (m *MockSessionRunner) onHandshakeComplete(arg0 packetHandler) {
m.ctrl.Call(m, "onHandshakeComplete", arg0)
}
// onHandshakeComplete indicates an expected call of onHandshakeComplete
func (mr *MockSessionRunnerMockRecorder) onHandshakeComplete(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHandshakeComplete", reflect.TypeOf((*MockSessionRunner)(nil).onHandshakeComplete), arg0)
}
// removeConnectionID mocks base method
func (m *MockSessionRunner) removeConnectionID(arg0 protocol.ConnectionID) {
m.ctrl.Call(m, "removeConnectionID", arg0)
}
// removeConnectionID indicates an expected call of removeConnectionID
func (mr *MockSessionRunnerMockRecorder) removeConnectionID(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "removeConnectionID", reflect.TypeOf((*MockSessionRunner)(nil).removeConnectionID), arg0)
}

View file

@ -8,9 +8,9 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource"
//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager"
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go"
//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker Unpacker"
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_unpacker_test.go mock_unpacker_test.go"
//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD"
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"
//go:generate sh -c "goimports -w mock*_test.go"

View file

@ -21,13 +21,27 @@ import (
type packetHandler interface {
Session
getCryptoStream() cryptoStreamI
handshakeStatus() <-chan error
handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber
run() error
closeRemote(error)
}
type sessionRunner interface {
onHandshakeComplete(packetHandler)
removeConnectionID(protocol.ConnectionID)
}
type runner struct {
onHandshakeCompleteImpl func(packetHandler)
removeConnectionIDImpl func(protocol.ConnectionID)
}
func (r *runner) onHandshakeComplete(p packetHandler) { r.onHandshakeCompleteImpl(p) }
func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
var _ sessionRunner = &runner{}
// A Listener of QUIC
type server struct {
tlsConf *tls.Config
@ -45,12 +59,14 @@ type server struct {
sessions map[string] /* string(ConnectionID)*/ packetHandler
closed bool
serverError error
serverError error
sessionQueue chan Session
errorChan chan struct{}
sessionRunner sessionRunner
// set as members, so they can be set in the tests
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, logger utils.Logger) (packetHandler, error)
newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error)
deleteClosedSessionsAfter time.Duration
logger utils.Logger
@ -112,6 +128,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
supportsTLS: supportsTLS,
logger: utils.DefaultLogger.WithPrefix("server"),
}
s.sessionRunner = &runner{
onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess },
removeConnectionIDImpl: s.removeConnection,
}
if supportsTLS {
if err := s.setupTLS(); err != nil {
return nil, err
@ -127,7 +147,7 @@ func (s *server) setupTLS() error {
if err != nil {
return err
}
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf, s.logger)
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, cookieHandler, s.tlsConf, s.logger)
if err != nil {
return err
}
@ -148,7 +168,7 @@ func (s *server) setupTLS() error {
}
s.sessions[string(connID)] = sess
s.sessionsMutex.Unlock()
s.runHandshakeAndSession(sess, connID)
go sess.run()
}
}
}()
@ -415,6 +435,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
var err error
session, err = s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionRunner,
version,
hdr.DestConnectionID,
s.scfg,
@ -429,7 +450,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
s.sessions[string(hdr.DestConnectionID)] = session
s.sessionsMutex.Unlock()
s.runHandshakeAndSession(session, hdr.DestConnectionID)
go session.run()
}
session.handlePacket(&receivedPacket{
@ -441,21 +462,6 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
return nil
}
func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) {
go func() {
_ = session.run()
// session.run() returns as soon as the session is closed
s.removeConnection(connID)
}()
go func() {
if err := <-session.handshakeStatus(); err != nil {
return
}
s.sessionQueue <- session
}()
}
func (s *server) removeConnection(id protocol.ConnectionID) {
s.sessionsMutex.Lock()
s.sessions[string(id)] = nil

View file

@ -9,7 +9,6 @@ import (
"reflect"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
@ -22,13 +21,13 @@ import (
)
type mockSession struct {
runner sessionRunner
connectionID protocol.ConnectionID
handledPackets []*receivedPacket
closed bool
closeReason error
closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan error
}
func (s *mockSession) handlePacket(p *receivedPacket) {
@ -67,13 +66,13 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not impl
func (*mockSession) Context() context.Context { panic("not implemented") }
func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") }
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
var _ Session = &mockSession{}
func newMockSession(
_ connection,
runner sessionRunner,
_ protocol.VersionNumber,
connectionID protocol.ConnectionID,
_ *handshake.ServerConfig,
@ -82,9 +81,9 @@ func newMockSession(
_ utils.Logger,
) (packetHandler, error) {
s := mockSession{
connectionID: connectionID,
handshakeChan: make(chan error),
stopRunLoop: make(chan struct{}),
runner: runner,
connectionID: connectionID,
stopRunLoop: make(chan struct{}),
}
return &s, nil
}
@ -181,7 +180,7 @@ var _ = Describe("Server", func() {
It("accepts new TLS sessions", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
sess, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
err = serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
@ -198,9 +197,9 @@ var _ = Describe("Server", func() {
It("only accepts one new TLS sessions for one connection ID", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
sess1, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
sess2, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
err = serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
@ -224,38 +223,45 @@ var _ = Describe("Server", func() {
}).Should(Equal(sess1))
})
It("accepts a session once the connection it is forward secure", func(done Done) {
It("accepts a session once the connection it is forward secure", func() {
var acceptedSess Session
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
acceptedSess, err = serv.Accept()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[string(connID)].(*mockSession)
Consistently(func() Session { return acceptedSess }).Should(BeNil())
close(sess.handshakeChan)
serv.sessionQueue <- sess
Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
close(done)
}, 0.5)
Eventually(done).Should(BeClosed())
})
It("doesn't accept session that error during the handshake", func(done Done) {
var accepted bool
It("doesn't accept sessions that error during the handshake", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.Accept()
accepted = true
close(done)
}()
err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[string(connID)].(*mockSession)
sess.handshakeChan <- errors.New("handshake failed")
Consistently(func() bool { return accepted }).Should(BeFalse())
close(done)
sess.closeReason = errors.New("handshake failed")
close(sess.stopRunLoop)
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
serv.removeConnection(connID)
close(serv.errorChan)
serv.Close()
Eventually(done).Should(BeClosed())
})
It("assigns packets to existing sessions", func() {
@ -268,16 +274,10 @@ var _ = Describe("Server", func() {
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2))
})
It("closes and deletes sessions", func() {
It("deletes sessions", func() {
serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)]).ToNot(BeNil())
// make session.run() return
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{}
serv.sessions[string(connID)] = &mockSession{}
serv.removeConnection(connID)
// The server should now have closed the session, leaving a nil value in the sessions map
Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1))
Expect(serv.sessions[string(connID)]).To(BeNil())
@ -285,14 +285,9 @@ var _ = Describe("Server", func() {
It("deletes nil session entries after a wait time", func() {
serv.deleteClosedSessionsAfter = 25 * time.Millisecond
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions).To(HaveKey(string(connID)))
serv.sessions[string(connID)] = &mockSession{}
// make session.run() return
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{}
serv.removeConnection(connID)
Eventually(func() bool {
serv.sessionsMutex.Lock()
_, ok := serv.sessions[string(connID)]
@ -303,7 +298,7 @@ var _ = Describe("Server", func() {
It("closes sessions and the connection when Close is called", func() {
go serv.serve()
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
serv.sessions[string(connID)] = session
err := serv.Close()
Expect(err).NotTo(HaveOccurred())
@ -353,7 +348,7 @@ var _ = Describe("Server", func() {
}, 0.5)
It("closes all sessions when encountering a connection error", func() {
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
serv.sessions[string(connID)] = session
Expect(serv.sessions[string(connID)].(*mockSession).closed).To(BeFalse())
testErr := errors.New("connection error")

View file

@ -42,7 +42,8 @@ type serverTLS struct {
params *handshake.TransportParameters
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
sessionChan chan<- tlsSession
sessionRunner sessionRunner
sessionChan chan<- tlsSession
logger utils.Logger
}
@ -50,6 +51,7 @@ type serverTLS struct {
func newServerTLS(
conn net.PacketConn,
config *Config,
runner sessionRunner,
cookieHandler *handshake.CookieHandler,
tlsConf *tls.Config,
logger utils.Logger,
@ -72,6 +74,7 @@ func newServerTLS(
config: config,
supportedVersions: config.Versions,
mintConf: mconf,
sessionRunner: runner,
sessionChan: sessionChan,
params: &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
@ -214,6 +217,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
params := <-paramsChan
sess, err := newTLSServerSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionRunner,
hdr.SrcConnectionID,
hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here
protocol.PacketNumber(1), // TODO: use a random packet number here

View file

@ -37,7 +37,7 @@ var _ = Describe("Stateless TLS handling", func() {
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
var err error
server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig(), utils.DefaultLogger)
server, sessionChan, err = newServerTLS(conn, config, nil, nil, testdata.GetTLSConfig(), utils.DefaultLogger)
Expect(err).ToNot(HaveOccurred())
server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
mintReply = bc

View file

@ -73,6 +73,8 @@ type closeError struct {
// A Session is a QUIC session
type session struct {
sessionRunner sessionRunner
destConnID protocol.ConnectionID
srcConnID protocol.ConnectionID
@ -116,11 +118,7 @@ type session struct {
paramsChan <-chan handshake.TransportParameters
// the handshakeEvent channel is passed to the CryptoSetup.
// It receives when it makes sense to try decrypting undecryptable packets.
handshakeEvent <-chan struct{}
// handshakeChan is returned by handshakeStatus.
// It receives any error that might occur during the handshake.
// It is closed when the handshake is complete.
handshakeChan chan error
handshakeEvent <-chan struct{}
handshakeComplete bool
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
@ -151,6 +149,7 @@ var _ streamSender = &session{}
// newSession makes a new session
func newSession(
conn connection,
sessionRunner sessionRunner,
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
scfg *handshake.ServerConfig,
@ -162,6 +161,7 @@ func newSession(
handshakeEvent := make(chan struct{}, 1)
s := &session{
conn: conn,
sessionRunner: sessionRunner,
srcConnID: connectionID,
destConnID: connectionID,
perspective: protocol.PerspectiveServer,
@ -221,6 +221,7 @@ func newSession(
// declare this as a variable, so that we can it mock it in the tests
var newClientSession = func(
conn connection,
sessionRunner sessionRunner,
hostname string,
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
@ -234,6 +235,7 @@ var newClientSession = func(
handshakeEvent := make(chan struct{}, 1)
s := &session{
conn: conn,
sessionRunner: sessionRunner,
srcConnID: connectionID,
destConnID: connectionID,
perspective: protocol.PerspectiveClient,
@ -288,6 +290,7 @@ var newClientSession = func(
func newTLSServerSession(
conn connection,
runner sessionRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
initialPacketNumber protocol.PacketNumber,
@ -302,6 +305,7 @@ func newTLSServerSession(
handshakeEvent := make(chan struct{}, 1)
s := &session{
conn: conn,
sessionRunner: runner,
config: config,
srcConnID: srcConnID,
destConnID: destConnID,
@ -345,6 +349,7 @@ func newTLSServerSession(
// declare this as a variable, such that we can it mock it in the tests
var newTLSClientSession = func(
conn connection,
runner sessionRunner,
hostname string,
v protocol.VersionNumber,
destConnID protocol.ConnectionID,
@ -358,6 +363,7 @@ var newTLSClientSession = func(
handshakeEvent := make(chan struct{}, 1)
s := &session{
conn: conn,
sessionRunner: runner,
config: config,
srcConnID: srcConnID,
destConnID: destConnID,
@ -413,7 +419,6 @@ func (s *session) preSetup() {
}
func (s *session) postSetup() error {
s.handshakeChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
@ -527,13 +532,11 @@ runLoop:
}
}
// only send the error the handshakeChan when the handshake is not completed yet
// otherwise this chan will already be closed
if !s.handshakeComplete {
s.handshakeChan <- closeErr.err
if err := s.handleCloseError(closeErr); err != nil {
s.logger.Infof("Handling close error failed: %s", err)
}
s.handleCloseError(closeErr)
s.logger.Infof("Connection %s closed.", s.srcConnID)
s.sessionRunner.removeConnectionID(s.srcConnID)
return closeErr.err
}
@ -580,6 +583,7 @@ func (s *session) handleHandshakeEvent(completed bool) {
}
s.handshakeComplete = true
s.handshakeEvent = nil // prevent this case from ever being selected again
s.sessionRunner.onHandshakeComplete(s)
// In gQUIC, the server completes the handshake first (after sending the SHLO).
// In TLS 1.3, the client completes the handshake first (after sending the CFIN).
@ -593,7 +597,6 @@ func (s *session) handleHandshakeEvent(completed bool) {
s.queueControlFrame(&wire.PingFrame{})
s.sentPacketHandler.SetHandshakeComplete()
}
close(s.handshakeChan)
}
func (s *session) handlePacketImpl(p *receivedPacket) error {
@ -1239,10 +1242,6 @@ func (s *session) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func (s *session) handshakeStatus() <-chan error {
return s.handshakeChan
}
func (s *session) getCryptoStream() cryptoStreamI {
return s.cryptoStream
}

View file

@ -68,6 +68,7 @@ func areSessionsRunning() bool {
var _ = Describe("Session", func() {
var (
sess *session
sessionRunner *MockSessionRunner
scfg *handshake.ServerConfig
mconn *mockConnection
cryptoSetup *mockCryptoSetup
@ -97,6 +98,7 @@ var _ = Describe("Session", func() {
return cryptoSetup, nil
}
sessionRunner = NewMockSessionRunner(mockCtrl)
mconn = newMockConnection()
certChain := crypto.NewCertChain(testdata.GetTLSConfig())
kex, err := crypto.NewCurve25519KEX()
@ -106,6 +108,7 @@ var _ = Describe("Session", func() {
var pSess Session
pSess, err = newSession(
mconn,
sessionRunner,
protocol.Version39,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
scfg,
@ -159,6 +162,7 @@ var _ = Describe("Session", func() {
}
pSess, err := newSession(
mconn,
sessionRunner,
protocol.Version39,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
scfg,
@ -471,17 +475,15 @@ var _ = Describe("Session", func() {
It("handles CONNECTION_CLOSE frames", func() {
testErr := qerr.Error(qerr.ProofInvalid, "foobar")
streamManager.EXPECT().CloseWithError(testErr)
done := make(chan struct{})
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err).To(MatchError(testErr))
close(done)
}()
err := sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, protocol.EncryptionUnspecified)
Expect(err).NotTo(HaveOccurred())
Eventually(sess.Context().Done()).Should(BeClosed())
Eventually(done).Should(BeClosed())
})
})
@ -510,6 +512,7 @@ var _ = Describe("Session", func() {
It("shuts down without error", func() {
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, ""))
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(HaveLen(1))
@ -522,6 +525,7 @@ var _ = Describe("Session", func() {
It("only closes once", func() {
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, ""))
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
sess.Close(nil)
Eventually(areSessionsRunning).Should(BeFalse())
@ -532,6 +536,7 @@ var _ = Describe("Session", func() {
It("closes streams with proper error", func() {
testErr := errors.New("test error")
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error()))
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(testErr)
Eventually(areSessionsRunning).Should(BeFalse())
Expect(sess.Context().Done()).To(BeClosed())
@ -539,6 +544,7 @@ var _ = Describe("Session", func() {
It("closes the session in order to replace it with another QUIC version", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(errCloseSessionForNewVersion)
Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
@ -546,6 +552,7 @@ var _ = Describe("Session", func() {
It("sends a Public Reset if the client is initiating the head-of-line blocking experiment", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(handshake.ErrHOLExperiment)
Expect(mconn.written).To(HaveLen(1))
Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset
@ -554,6 +561,7 @@ var _ = Describe("Session", func() {
It("sends a Public Reset if the client is initiating the no STOP_WAITING experiment", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(handshake.ErrHOLExperiment)
Expect(mconn.written).To(HaveLen(1))
Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset
@ -562,6 +570,7 @@ var _ = Describe("Session", func() {
It("cancels the context when the run loop exists", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
returned := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -619,20 +628,21 @@ var _ = Describe("Session", func() {
Expect(err).ToNot(HaveOccurred())
})
It("closes when handling a packet fails", func(done Done) {
It("closes when handling a packet fails", func() {
testErr := errors.New("unpack error")
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr)
streamManager.EXPECT().CloseWithError(gomock.Any())
hdr.PacketNumber = 5
var runErr error
done := make(chan struct{})
go func() {
defer GinkgoRecover()
runErr = sess.run()
err := sess.run()
Expect(err).To(MatchError(testErr))
close(done)
}()
sess.handlePacket(&receivedPacket{header: hdr})
Eventually(func() error { return runErr }).Should(MatchError(testErr))
Expect(sess.Context().Done()).To(BeClosed())
close(done)
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Eventually(done).Should(BeClosed())
})
It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() {
@ -886,6 +896,7 @@ var _ = Describe("Session", func() {
Eventually(mconn.written).Should(HaveLen(2))
Consistently(mconn.written).Should(HaveLen(2))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
})
@ -909,6 +920,7 @@ var _ = Describe("Session", func() {
Eventually(mconn.written).Should(HaveLen(1))
Consistently(mconn.written).Should(HaveLen(1))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
})
@ -936,6 +948,7 @@ var _ = Describe("Session", func() {
Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1))
Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
})
@ -958,6 +971,7 @@ var _ = Describe("Session", func() {
sess.scheduleSending()
Eventually(mconn.written).Should(HaveLen(3))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
})
@ -978,6 +992,7 @@ var _ = Describe("Session", func() {
sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1})
Consistently(mconn.written).ShouldNot(Receive())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
})
@ -1019,6 +1034,7 @@ var _ = Describe("Session", func() {
sess.scheduleSending()
Eventually(mconn.written).Should(HaveLen(1))
// make sure that the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
@ -1049,6 +1065,7 @@ var _ = Describe("Session", func() {
sess.scheduleSending()
Eventually(mconn.written).Should(HaveLen(1))
// make sure that the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
@ -1210,19 +1227,18 @@ var _ = Describe("Session", func() {
sph.EXPECT().SentPacket(gomock.Any())
sess.sentPacketHandler = sph
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess.run()
close(done)
}()
Consistently(mconn.written).ShouldNot(Receive())
sess.scheduleSending()
Eventually(mconn.written).Should(Receive())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("sets the timer to the ack timer", func() {
@ -1243,32 +1259,31 @@ var _ = Describe("Session", func() {
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond))
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour))
sess.receivedPacketHandler = rph
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess.run()
close(done)
}()
Eventually(mconn.written).Should(Receive())
// make sure the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
})
It("closes when crypto stream errors", func() {
testErr := errors.New("crypto setup error")
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error()))
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
cryptoSetup.handleErr = testErr
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err).To(MatchError(testErr))
close(done)
}()
Eventually(done).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
Context("sending a Public Reset when receiving undecryptable packets during the handshake", func() {
@ -1303,7 +1318,9 @@ var _ = Describe("Session", func() {
sendUndecryptablePackets()
sess.scheduleSending()
Consistently(mconn.written).Should(HaveLen(0))
Expect(sess.Close(nil)).To(Succeed())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed())
})
@ -1314,6 +1331,8 @@ var _ = Describe("Session", func() {
}()
sendUndecryptablePackets()
Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), 20*time.Millisecond))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed())
})
@ -1327,11 +1346,14 @@ var _ = Describe("Session", func() {
Eventually(func() []*receivedPacket { return sess.undecryptablePackets }).Should(HaveLen(protocol.MaxUndecryptablePackets))
// check that old packets are kept, and the new packets are dropped
Expect(sess.undecryptablePackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Close(nil)).To(Succeed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("sends a Public Reset after a timeout", func() {
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.receivedTooManyUndecrytablePacketsTime).To(BeZero())
go func() {
defer GinkgoRecover()
@ -1359,7 +1381,9 @@ var _ = Describe("Session", func() {
// in reality, this happens when the trial decryption succeeded during the Public Reset timeout
Consistently(mconn.written).ShouldNot(HaveLen(1))
Expect(sess.Context().Done()).ToNot(Receive())
Expect(sess.Close(nil)).To(Succeed())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed())
})
@ -1371,6 +1395,8 @@ var _ = Describe("Session", func() {
}()
sendUndecryptablePackets()
Consistently(sess.undecryptablePackets).Should(BeEmpty())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Close(nil)).To(Succeed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
@ -1387,38 +1413,35 @@ var _ = Describe("Session", func() {
})
It("doesn't do anything when the crypto setup says to decrypt undecryptable packets", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
sess.run()
}()
handshakeChan <- struct{}{}
Consistently(sess.handshakeStatus()).ShouldNot(Receive())
// don't EXPECT any calls to sessionRunner.onHandshakeComplete()
// make sure the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any())
Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("closes the handshakeChan when the handshake completes", func() {
done := make(chan struct{})
It("calls the onHandshakeComplete callback when the handshake completes", func() {
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
sess.run()
}()
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
close(handshakeChan)
Eventually(sess.handshakeStatus()).Should(BeClosed())
Consistently(sess.Context().Done()).ShouldNot(BeClosed())
// make sure the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any())
Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("passes errors to the handshakeChan", func() {
It("passes errors to the session runner", func() {
testErr := errors.New("handshake error")
done := make(chan struct{})
go func() {
@ -1428,19 +1451,17 @@ var _ = Describe("Session", func() {
close(done)
}()
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(testErr)
Expect(sess.handshakeStatus()).To(Receive(Equal(testErr)))
Eventually(done).Should(BeClosed())
})
It("process transport parameters received from the peer", func() {
paramsChan := make(chan handshake.TransportParameters)
sess.paramsChan = paramsChan
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess.run()
close(done)
}()
params := handshake.TransportParameters{
MaxStreams: 123,
@ -1457,8 +1478,9 @@ var _ = Describe("Session", func() {
Eventually(func() protocol.ByteCount { return sess.packer.maxPacketSize }).Should(Equal(protocol.ByteCount(0x42)))
// make the go routine return
streamManager.EXPECT().CloseWithError(gomock.Any())
Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed())
})
Context("keep-alives", func() {
@ -1486,6 +1508,7 @@ var _ = Describe("Session", func() {
// -12 because of the crypto tag. This should be 7 (the frame id for a ping frame).
Expect(data[len(data)-12-1 : len(data)-12]).To(Equal([]byte{0x07}))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
@ -1503,6 +1526,7 @@ var _ = Describe("Session", func() {
}()
Consistently(mconn.written).ShouldNot(Receive())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
@ -1520,6 +1544,7 @@ var _ = Describe("Session", func() {
}()
Consistently(mconn.written).ShouldNot(Receive())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil)
Eventually(done).Should(BeClosed())
@ -1531,23 +1556,33 @@ var _ = Describe("Session", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
})
It("times out due to no network activity", func(done Done) {
It("times out due to no network activity", func() {
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.handshakeComplete = true
sess.lastNetworkActivityTime = time.Now().Add(-time.Hour)
err := sess.run() // Would normally not return
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity.")))
Expect(sess.Context().Done()).To(BeClosed())
close(done)
})
It("times out due to non-completed handshake", func(done Done) {
It("times out due to non-completed handshake", func() {
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second)
err := sess.run() // Would normally not return
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(mconn.written).To(Receive(ContainSubstring("Crypto handshake did not complete in time.")))
Expect(sess.Context().Done()).To(BeClosed())
close(done)
})
It("does not use the idle timeout before the handshake complete", func() {
@ -1556,28 +1591,31 @@ var _ = Describe("Session", func() {
sess.lastNetworkActivityTime = time.Now().Add(-time.Minute)
// the handshake timeout is irrelevant here, since it depends on the time the session was created,
// and not on the last network activity
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_ = sess.run()
close(done)
sess.run()
}()
Consistently(done).ShouldNot(BeClosed())
Consistently(sess.Context().Done()).ShouldNot(BeClosed())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("closes the session due to the idle timeout after handshake", func() {
sessionRunner.EXPECT().onHandshakeComplete(sess)
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.config.IdleTimeout = 0
close(handshakeChan)
errChan := make(chan error)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
errChan <- sess.run() // Would normally not return
err := sess.run()
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
close(done)
}()
var err error
Eventually(errChan).Should(Receive(&err))
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
Eventually(done).Should(BeClosed())
Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity.")))
Expect(sess.Context().Done()).To(BeClosed())
})
})
@ -1679,6 +1717,7 @@ var _ = Describe("Session", func() {
var _ = Describe("Client Session", func() {
var (
sess *session
sessionRunner *MockSessionRunner
mconn *mockConnection
handshakeChan chan<- struct{}
@ -1707,8 +1746,10 @@ var _ = Describe("Client Session", func() {
}
mconn = newMockConnection()
sessionRunner = NewMockSessionRunner(mockCtrl)
sessP, err := newClientSession(
mconn,
sessionRunner,
"hostname",
protocol.Version39,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
@ -1727,19 +1768,18 @@ var _ = Describe("Client Session", func() {
})
It("sends a forward-secure packet when the handshake completes", func() {
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
sess.packer.hasSentPacket = true
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
sess.run()
}()
close(handshakeChan)
Eventually(mconn.written).Should(Receive())
//make sure the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
Context("receiving packets", func() {
@ -1755,20 +1795,19 @@ var _ = Describe("Client Session", func() {
unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
sess.unpacker = unpacker
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
sess.run()
}()
hdr.PacketNumber = 5
hdr.DiversificationNonce = []byte("foobar")
err := sess.handlePacketImpl(&receivedPacket{header: hdr})
Expect(err).ToNot(HaveOccurred())
Expect(cryptoSetup.divNonce).To(Equal(hdr.DiversificationNonce))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
})
})