use callbacks for signaling the session status

Instead of exposing a session.handshakeStatus() <-chan error, it's
easier to pass a callback to the session which is called when the
handshake is done.
The removeConnectionID callback is in preparation for IETF QUIC, where a
connection can have multiple connection IDs over its lifetime.
This commit is contained in:
Marten Seemann 2018-05-11 08:40:25 +09:00
parent c7119b2adf
commit 733e2e952b
10 changed files with 295 additions and 174 deletions

View file

@ -37,6 +37,8 @@ type client struct {
initialVersion protocol.VersionNumber initialVersion protocol.VersionNumber
version protocol.VersionNumber version protocol.VersionNumber
handshakeChan chan struct{}
session packetHandler session packetHandler
logger utils.Logger logger utils.Logger
@ -112,6 +114,7 @@ func Dial(
tlsConf: tlsConf, tlsConf: tlsConf,
config: clientConfig, config: clientConfig,
version: version, version: version,
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"), logger: utils.DefaultLogger.WithPrefix("client"),
} }
@ -243,21 +246,19 @@ func (c *client) dialTLS() error {
// - any other error that might occur // - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC) // - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection() error { func (c *client) establishSecureConnection() error {
var runErr error errorChan := make(chan error, 1)
errorChan := make(chan struct{})
go func() { go func() {
runErr = c.session.run() // returns as soon as the session is closed err := c.session.run() // returns as soon as the session is closed
close(errorChan) errorChan <- err
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close()
}
}() }()
select { select {
case <-errorChan: case err := <-errorChan:
return runErr
case err := <-c.session.handshakeStatus():
return err return err
case <-c.handshakeChan:
// handshake successfully completed
return nil
} }
} }
@ -438,8 +439,13 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
func (c *client) createNewGQUICSession() (err error) { func (c *client) createNewGQUICSession() (err error) {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ packetHandler) { close(c.handshakeChan) },
removeConnectionIDImpl: func(protocol.ConnectionID) {},
}
c.session, err = newClientSession( c.session, err = newClientSession(
c.conn, c.conn,
runner,
c.hostname, c.hostname,
c.version, c.version,
c.destConnID, c.destConnID,
@ -458,8 +464,13 @@ func (c *client) createNewTLSSession(
) (err error) { ) (err error) {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ packetHandler) { close(c.handshakeChan) },
removeConnectionIDImpl: func(protocol.ConnectionID) {},
}
c.session, err = newTLSClientSession( c.session, err = newTLSClientSession(
c.conn, c.conn,
runner,
c.hostname, c.hostname,
c.version, c.version,
c.destConnID, c.destConnID,

View file

@ -28,7 +28,7 @@ var _ = Describe("Client", func() {
addr net.Addr addr net.Addr
connID protocol.ConnectionID connID protocol.ConnectionID
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, logger utils.Logger) (packetHandler, error) originalClientSessConstructor func(connection, sessionRunner, string, protocol.VersionNumber, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (packetHandler, error)
) )
// generate a packet sent by the server that accepts the QUIC version suggested by the client // generate a packet sent by the server that accepts the QUIC version suggested by the client
@ -48,7 +48,7 @@ var _ = Describe("Client", func() {
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
originalClientSessConstructor = newClientSession originalClientSessConstructor = newClientSession
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
msess, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil) msess, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
sess = msess.(*mockSession) sess = msess.(*mockSession)
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = newMockPacketConn() packetConn = newMockPacketConn()
@ -97,6 +97,7 @@ var _ = Describe("Client", func() {
remoteAddrChan := make(chan string) remoteAddrChan := make(chan string)
newClientSession = func( newClientSession = func(
conn connection, conn connection,
_ sessionRunner,
_ string, _ string,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ protocol.ConnectionID, _ protocol.ConnectionID,
@ -126,6 +127,7 @@ var _ = Describe("Client", func() {
hostnameChan := make(chan string) hostnameChan := make(chan string)
newClientSession = func( newClientSession = func(
_ connection, _ connection,
_ sessionRunner,
h string, h string,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ protocol.ConnectionID, _ protocol.ConnectionID,
@ -160,6 +162,7 @@ var _ = Describe("Client", func() {
It("returns after the handshake is complete", func() { It("returns after the handshake is complete", func() {
newClientSession = func( newClientSession = func(
_ connection, _ connection,
runner sessionRunner,
_ string, _ string,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ protocol.ConnectionID, _ protocol.ConnectionID,
@ -169,6 +172,7 @@ var _ = Describe("Client", func() {
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
) (packetHandler, error) { ) (packetHandler, error) {
runner.onHandshakeComplete(sess)
return sess, nil return sess, nil
} }
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
@ -180,14 +184,16 @@ var _ = Describe("Client", func() {
Expect(s).ToNot(BeNil()) Expect(s).ToNot(BeNil())
close(dialed) close(dialed)
}() }()
close(sess.handshakeChan)
Eventually(dialed).Should(BeClosed()) Eventually(dialed).Should(BeClosed())
// make the session run loop return
close(sess.stopRunLoop)
}) })
It("returns an error that occurs while waiting for the connection to become secure", func() { It("returns an error that occurs while waiting for the connection to become secure", func() {
testErr := errors.New("early handshake error") testErr := errors.New("early handshake error")
newClientSession = func( newClientSession = func(
conn connection, conn connection,
_ sessionRunner,
_ string, _ string,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ protocol.ConnectionID, _ protocol.ConnectionID,
@ -208,7 +214,8 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done) close(done)
}() }()
sess.handshakeChan <- testErr sess.closeReason = testErr
close(sess.stopRunLoop)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -269,6 +276,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("error creating session") testErr := errors.New("error creating session")
newClientSession = func( newClientSession = func(
_ connection, _ connection,
_ sessionRunner,
_ string, _ string,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ protocol.ConnectionID, _ protocol.ConnectionID,
@ -295,6 +303,7 @@ var _ = Describe("Client", func() {
var conf *Config var conf *Config
newTLSClientSession = func( newTLSClientSession = func(
connP connection, connP connection,
_ sessionRunner,
hostnameP string, hostnameP string,
versionP protocol.VersionNumber, versionP protocol.VersionNumber,
_ protocol.ConnectionID, _ protocol.ConnectionID,
@ -344,6 +353,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs during version negotiation", func() { It("returns an error that occurs during version negotiation", func() {
newClientSession = func( newClientSession = func(
conn connection, conn connection,
_ sessionRunner,
_ string, _ string,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ protocol.ConnectionID, _ protocol.ConnectionID,
@ -390,9 +400,10 @@ var _ = Describe("Client", func() {
Expect(newVersion).ToNot(Equal(cl.version)) Expect(newVersion).ToNot(Equal(cl.version))
cl.config = &Config{Versions: []protocol.VersionNumber{newVersion}} cl.config = &Config{Versions: []protocol.VersionNumber{newVersion}}
sessionChan := make(chan *mockSession) sessionChan := make(chan *mockSession)
handshakeChan := make(chan error) stopRunLoop := make(chan struct{})
newClientSession = func( newClientSession = func(
_ connection, _ connection,
_ sessionRunner,
_ string, _ string,
_ protocol.VersionNumber, _ protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
@ -407,8 +418,7 @@ var _ = Describe("Client", func() {
sess := &mockSession{ sess := &mockSession{
connectionID: connectionID, connectionID: connectionID,
stopRunLoop: make(chan struct{}), stopRunLoop: stopRunLoop,
handshakeChan: handshakeChan,
} }
sessionChan <- sess sessionChan <- sess
return sess, nil return sess, nil
@ -441,7 +451,6 @@ var _ = Describe("Client", func() {
Expect(negotiatedVersions).To(ContainElement(newVersion)) Expect(negotiatedVersions).To(ContainElement(newVersion))
Expect(initialVersion).To(Equal(actualInitialVersion)) Expect(initialVersion).To(Equal(actualInitialVersion))
close(handshakeChan)
Eventually(established).Should(BeClosed()) Eventually(established).Should(BeClosed())
}) })
@ -449,6 +458,7 @@ var _ = Describe("Client", func() {
sessionCounter := uint32(0) sessionCounter := uint32(0)
newClientSession = func( newClientSession = func(
_ connection, _ connection,
_ sessionRunner,
_ string, _ string,
_ protocol.VersionNumber, _ protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
@ -613,6 +623,7 @@ var _ = Describe("Client", func() {
var conf *Config var conf *Config
newClientSession = func( newClientSession = func(
connP connection, connP connection,
_ sessionRunner,
hostnameP string, hostnameP string,
versionP protocol.VersionNumber, versionP protocol.VersionNumber,
_ protocol.ConnectionID, _ protocol.ConnectionID,
@ -651,6 +662,7 @@ var _ = Describe("Client", func() {
sessionChan := make(chan *mockSession) sessionChan := make(chan *mockSession)
newTLSClientSession = func( newTLSClientSession = func(
connP connection, connP connection,
_ sessionRunner,
hostnameP string, hostnameP string,
versionP protocol.VersionNumber, versionP protocol.VersionNumber,
_ protocol.ConnectionID, _ protocol.ConnectionID,

View file

@ -0,0 +1,55 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: SessionRunner)
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockSessionRunner is a mock of SessionRunner interface
type MockSessionRunner struct {
ctrl *gomock.Controller
recorder *MockSessionRunnerMockRecorder
}
// MockSessionRunnerMockRecorder is the mock recorder for MockSessionRunner
type MockSessionRunnerMockRecorder struct {
mock *MockSessionRunner
}
// NewMockSessionRunner creates a new mock instance
func NewMockSessionRunner(ctrl *gomock.Controller) *MockSessionRunner {
mock := &MockSessionRunner{ctrl: ctrl}
mock.recorder = &MockSessionRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder {
return m.recorder
}
// onHandshakeComplete mocks base method
func (m *MockSessionRunner) onHandshakeComplete(arg0 packetHandler) {
m.ctrl.Call(m, "onHandshakeComplete", arg0)
}
// onHandshakeComplete indicates an expected call of onHandshakeComplete
func (mr *MockSessionRunnerMockRecorder) onHandshakeComplete(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHandshakeComplete", reflect.TypeOf((*MockSessionRunner)(nil).onHandshakeComplete), arg0)
}
// removeConnectionID mocks base method
func (m *MockSessionRunner) removeConnectionID(arg0 protocol.ConnectionID) {
m.ctrl.Call(m, "removeConnectionID", arg0)
}
// removeConnectionID indicates an expected call of removeConnectionID
func (mr *MockSessionRunnerMockRecorder) removeConnectionID(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "removeConnectionID", reflect.TypeOf((*MockSessionRunner)(nil).removeConnectionID), arg0)
}

View file

@ -8,9 +8,9 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource" //go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource"
//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager" //go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager"
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go"
//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker Unpacker" //go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker Unpacker"
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_unpacker_test.go mock_unpacker_test.go"
//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD"
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"
//go:generate sh -c "goimports -w mock*_test.go" //go:generate sh -c "goimports -w mock*_test.go"

View file

@ -21,13 +21,27 @@ import (
type packetHandler interface { type packetHandler interface {
Session Session
getCryptoStream() cryptoStreamI getCryptoStream() cryptoStreamI
handshakeStatus() <-chan error
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber GetVersion() protocol.VersionNumber
run() error run() error
closeRemote(error) closeRemote(error)
} }
type sessionRunner interface {
onHandshakeComplete(packetHandler)
removeConnectionID(protocol.ConnectionID)
}
type runner struct {
onHandshakeCompleteImpl func(packetHandler)
removeConnectionIDImpl func(protocol.ConnectionID)
}
func (r *runner) onHandshakeComplete(p packetHandler) { r.onHandshakeCompleteImpl(p) }
func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
var _ sessionRunner = &runner{}
// A Listener of QUIC // A Listener of QUIC
type server struct { type server struct {
tlsConf *tls.Config tlsConf *tls.Config
@ -46,11 +60,13 @@ type server struct {
closed bool closed bool
serverError error serverError error
sessionQueue chan Session sessionQueue chan Session
errorChan chan struct{} errorChan chan struct{}
sessionRunner sessionRunner
// set as members, so they can be set in the tests // set as members, so they can be set in the tests
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, logger utils.Logger) (packetHandler, error) newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error)
deleteClosedSessionsAfter time.Duration deleteClosedSessionsAfter time.Duration
logger utils.Logger logger utils.Logger
@ -112,6 +128,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
supportsTLS: supportsTLS, supportsTLS: supportsTLS,
logger: utils.DefaultLogger.WithPrefix("server"), logger: utils.DefaultLogger.WithPrefix("server"),
} }
s.sessionRunner = &runner{
onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess },
removeConnectionIDImpl: s.removeConnection,
}
if supportsTLS { if supportsTLS {
if err := s.setupTLS(); err != nil { if err := s.setupTLS(); err != nil {
return nil, err return nil, err
@ -127,7 +147,7 @@ func (s *server) setupTLS() error {
if err != nil { if err != nil {
return err return err
} }
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf, s.logger) serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, cookieHandler, s.tlsConf, s.logger)
if err != nil { if err != nil {
return err return err
} }
@ -148,7 +168,7 @@ func (s *server) setupTLS() error {
} }
s.sessions[string(connID)] = sess s.sessions[string(connID)] = sess
s.sessionsMutex.Unlock() s.sessionsMutex.Unlock()
s.runHandshakeAndSession(sess, connID) go sess.run()
} }
} }
}() }()
@ -415,6 +435,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
var err error var err error
session, err = s.newSession( session, err = s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr}, &conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionRunner,
version, version,
hdr.DestConnectionID, hdr.DestConnectionID,
s.scfg, s.scfg,
@ -429,7 +450,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
s.sessions[string(hdr.DestConnectionID)] = session s.sessions[string(hdr.DestConnectionID)] = session
s.sessionsMutex.Unlock() s.sessionsMutex.Unlock()
s.runHandshakeAndSession(session, hdr.DestConnectionID) go session.run()
} }
session.handlePacket(&receivedPacket{ session.handlePacket(&receivedPacket{
@ -441,21 +462,6 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
return nil return nil
} }
func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) {
go func() {
_ = session.run()
// session.run() returns as soon as the session is closed
s.removeConnection(connID)
}()
go func() {
if err := <-session.handshakeStatus(); err != nil {
return
}
s.sessionQueue <- session
}()
}
func (s *server) removeConnection(id protocol.ConnectionID) { func (s *server) removeConnection(id protocol.ConnectionID) {
s.sessionsMutex.Lock() s.sessionsMutex.Lock()
s.sessions[string(id)] = nil s.sessions[string(id)] = nil

View file

@ -9,7 +9,6 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
@ -22,13 +21,13 @@ import (
) )
type mockSession struct { type mockSession struct {
runner sessionRunner
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
handledPackets []*receivedPacket handledPackets []*receivedPacket
closed bool closed bool
closeReason error closeReason error
closedRemote bool closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan error
} }
func (s *mockSession) handlePacket(p *receivedPacket) { func (s *mockSession) handlePacket(p *receivedPacket) {
@ -67,13 +66,13 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not impl
func (*mockSession) Context() context.Context { panic("not implemented") } func (*mockSession) Context() context.Context { panic("not implemented") }
func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") } func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") }
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") } func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
var _ Session = &mockSession{} var _ Session = &mockSession{}
func newMockSession( func newMockSession(
_ connection, _ connection,
runner sessionRunner,
_ protocol.VersionNumber, _ protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
_ *handshake.ServerConfig, _ *handshake.ServerConfig,
@ -82,8 +81,8 @@ func newMockSession(
_ utils.Logger, _ utils.Logger,
) (packetHandler, error) { ) (packetHandler, error) {
s := mockSession{ s := mockSession{
runner: runner,
connectionID: connectionID, connectionID: connectionID,
handshakeChan: make(chan error),
stopRunLoop: make(chan struct{}), stopRunLoop: make(chan struct{}),
} }
return &s, nil return &s, nil
@ -181,7 +180,7 @@ var _ = Describe("Server", func() {
It("accepts new TLS sessions", func() { It("accepts new TLS sessions", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) sess, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = serv.setupTLS() err = serv.setupTLS()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -198,9 +197,9 @@ var _ = Describe("Server", func() {
It("only accepts one new TLS sessions for one connection ID", func() { It("only accepts one new TLS sessions for one connection ID", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) sess1, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) sess2, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = serv.setupTLS() err = serv.setupTLS()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -224,38 +223,45 @@ var _ = Describe("Server", func() {
}).Should(Equal(sess1)) }).Should(Equal(sess1))
}) })
It("accepts a session once the connection it is forward secure", func(done Done) { It("accepts a session once the connection it is forward secure", func() {
var acceptedSess Session var acceptedSess Session
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
var err error var err error
acceptedSess, err = serv.Accept() acceptedSess, err = serv.Accept()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
close(done)
}() }()
err := serv.handlePacket(nil, firstPacket) err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[string(connID)].(*mockSession) sess := serv.sessions[string(connID)].(*mockSession)
Consistently(func() Session { return acceptedSess }).Should(BeNil()) Consistently(func() Session { return acceptedSess }).Should(BeNil())
close(sess.handshakeChan) serv.sessionQueue <- sess
Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
close(done) Eventually(done).Should(BeClosed())
}, 0.5) })
It("doesn't accept session that error during the handshake", func(done Done) { It("doesn't accept sessions that error during the handshake", func() {
var accepted bool done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
serv.Accept() serv.Accept()
accepted = true close(done)
}() }()
err := serv.handlePacket(nil, firstPacket) err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[string(connID)].(*mockSession) sess := serv.sessions[string(connID)].(*mockSession)
sess.handshakeChan <- errors.New("handshake failed") sess.closeReason = errors.New("handshake failed")
Consistently(func() bool { return accepted }).Should(BeFalse()) close(sess.stopRunLoop)
close(done) Consistently(done).ShouldNot(BeClosed())
// make the go routine return
serv.removeConnection(connID)
close(serv.errorChan)
serv.Close()
Eventually(done).Should(BeClosed())
}) })
It("assigns packets to existing sessions", func() { It("assigns packets to existing sessions", func() {
@ -268,16 +274,10 @@ var _ = Describe("Server", func() {
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2)) Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2))
}) })
It("closes and deletes sessions", func() { It("deletes sessions", func() {
serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever) serv.sessions[string(connID)] = &mockSession{}
Expect(err).ToNot(HaveOccurred()) serv.removeConnection(connID)
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)]).ToNot(BeNil())
// make session.run() return
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{}
// The server should now have closed the session, leaving a nil value in the sessions map // The server should now have closed the session, leaving a nil value in the sessions map
Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1)) Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1))
Expect(serv.sessions[string(connID)]).To(BeNil()) Expect(serv.sessions[string(connID)]).To(BeNil())
@ -285,14 +285,9 @@ var _ = Describe("Server", func() {
It("deletes nil session entries after a wait time", func() { It("deletes nil session entries after a wait time", func() {
serv.deleteClosedSessionsAfter = 25 * time.Millisecond serv.deleteClosedSessionsAfter = 25 * time.Millisecond
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever) serv.sessions[string(connID)] = &mockSession{}
Expect(err).ToNot(HaveOccurred())
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions).To(HaveKey(string(connID)))
// make session.run() return // make session.run() return
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{} serv.removeConnection(connID)
Eventually(func() bool { Eventually(func() bool {
serv.sessionsMutex.Lock() serv.sessionsMutex.Lock()
_, ok := serv.sessions[string(connID)] _, ok := serv.sessions[string(connID)]
@ -303,7 +298,7 @@ var _ = Describe("Server", func() {
It("closes sessions and the connection when Close is called", func() { It("closes sessions and the connection when Close is called", func() {
go serv.serve() go serv.serve()
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil) session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
serv.sessions[string(connID)] = session serv.sessions[string(connID)] = session
err := serv.Close() err := serv.Close()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -353,7 +348,7 @@ var _ = Describe("Server", func() {
}, 0.5) }, 0.5)
It("closes all sessions when encountering a connection error", func() { It("closes all sessions when encountering a connection error", func() {
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil) session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
serv.sessions[string(connID)] = session serv.sessions[string(connID)] = session
Expect(serv.sessions[string(connID)].(*mockSession).closed).To(BeFalse()) Expect(serv.sessions[string(connID)].(*mockSession).closed).To(BeFalse())
testErr := errors.New("connection error") testErr := errors.New("connection error")

View file

@ -42,6 +42,7 @@ type serverTLS struct {
params *handshake.TransportParameters params *handshake.TransportParameters
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
sessionRunner sessionRunner
sessionChan chan<- tlsSession sessionChan chan<- tlsSession
logger utils.Logger logger utils.Logger
@ -50,6 +51,7 @@ type serverTLS struct {
func newServerTLS( func newServerTLS(
conn net.PacketConn, conn net.PacketConn,
config *Config, config *Config,
runner sessionRunner,
cookieHandler *handshake.CookieHandler, cookieHandler *handshake.CookieHandler,
tlsConf *tls.Config, tlsConf *tls.Config,
logger utils.Logger, logger utils.Logger,
@ -72,6 +74,7 @@ func newServerTLS(
config: config, config: config,
supportedVersions: config.Versions, supportedVersions: config.Versions,
mintConf: mconf, mintConf: mconf,
sessionRunner: runner,
sessionChan: sessionChan, sessionChan: sessionChan,
params: &handshake.TransportParameters{ params: &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
@ -214,6 +217,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
params := <-paramsChan params := <-paramsChan
sess, err := newTLSServerSession( sess, err := newTLSServerSession(
&conn{pconn: s.conn, currentAddr: remoteAddr}, &conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionRunner,
hdr.SrcConnectionID, hdr.SrcConnectionID,
hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here
protocol.PacketNumber(1), // TODO: use a random packet number here protocol.PacketNumber(1), // TODO: use a random packet number here

View file

@ -37,7 +37,7 @@ var _ = Describe("Stateless TLS handling", func() {
Versions: []protocol.VersionNumber{protocol.VersionTLS}, Versions: []protocol.VersionNumber{protocol.VersionTLS},
} }
var err error var err error
server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig(), utils.DefaultLogger) server, sessionChan, err = newServerTLS(conn, config, nil, nil, testdata.GetTLSConfig(), utils.DefaultLogger)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
mintReply = bc mintReply = bc

View file

@ -73,6 +73,8 @@ type closeError struct {
// A Session is a QUIC session // A Session is a QUIC session
type session struct { type session struct {
sessionRunner sessionRunner
destConnID protocol.ConnectionID destConnID protocol.ConnectionID
srcConnID protocol.ConnectionID srcConnID protocol.ConnectionID
@ -117,10 +119,6 @@ type session struct {
// the handshakeEvent channel is passed to the CryptoSetup. // the handshakeEvent channel is passed to the CryptoSetup.
// It receives when it makes sense to try decrypting undecryptable packets. // It receives when it makes sense to try decrypting undecryptable packets.
handshakeEvent <-chan struct{} handshakeEvent <-chan struct{}
// handshakeChan is returned by handshakeStatus.
// It receives any error that might occur during the handshake.
// It is closed when the handshake is complete.
handshakeChan chan error
handshakeComplete bool handshakeComplete bool
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
@ -151,6 +149,7 @@ var _ streamSender = &session{}
// newSession makes a new session // newSession makes a new session
func newSession( func newSession(
conn connection, conn connection,
sessionRunner sessionRunner,
v protocol.VersionNumber, v protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
scfg *handshake.ServerConfig, scfg *handshake.ServerConfig,
@ -162,6 +161,7 @@ func newSession(
handshakeEvent := make(chan struct{}, 1) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: sessionRunner,
srcConnID: connectionID, srcConnID: connectionID,
destConnID: connectionID, destConnID: connectionID,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
@ -221,6 +221,7 @@ func newSession(
// declare this as a variable, so that we can it mock it in the tests // declare this as a variable, so that we can it mock it in the tests
var newClientSession = func( var newClientSession = func(
conn connection, conn connection,
sessionRunner sessionRunner,
hostname string, hostname string,
v protocol.VersionNumber, v protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
@ -234,6 +235,7 @@ var newClientSession = func(
handshakeEvent := make(chan struct{}, 1) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: sessionRunner,
srcConnID: connectionID, srcConnID: connectionID,
destConnID: connectionID, destConnID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
@ -288,6 +290,7 @@ var newClientSession = func(
func newTLSServerSession( func newTLSServerSession(
conn connection, conn connection,
runner sessionRunner,
destConnID protocol.ConnectionID, destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID, srcConnID protocol.ConnectionID,
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
@ -302,6 +305,7 @@ func newTLSServerSession(
handshakeEvent := make(chan struct{}, 1) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: runner,
config: config, config: config,
srcConnID: srcConnID, srcConnID: srcConnID,
destConnID: destConnID, destConnID: destConnID,
@ -345,6 +349,7 @@ func newTLSServerSession(
// declare this as a variable, such that we can it mock it in the tests // declare this as a variable, such that we can it mock it in the tests
var newTLSClientSession = func( var newTLSClientSession = func(
conn connection, conn connection,
runner sessionRunner,
hostname string, hostname string,
v protocol.VersionNumber, v protocol.VersionNumber,
destConnID protocol.ConnectionID, destConnID protocol.ConnectionID,
@ -358,6 +363,7 @@ var newTLSClientSession = func(
handshakeEvent := make(chan struct{}, 1) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: runner,
config: config, config: config,
srcConnID: srcConnID, srcConnID: srcConnID,
destConnID: destConnID, destConnID: destConnID,
@ -413,7 +419,6 @@ func (s *session) preSetup() {
} }
func (s *session) postSetup() error { func (s *session) postSetup() error {
s.handshakeChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
@ -527,13 +532,11 @@ runLoop:
} }
} }
// only send the error the handshakeChan when the handshake is not completed yet if err := s.handleCloseError(closeErr); err != nil {
// otherwise this chan will already be closed s.logger.Infof("Handling close error failed: %s", err)
if !s.handshakeComplete {
s.handshakeChan <- closeErr.err
} }
s.handleCloseError(closeErr)
s.logger.Infof("Connection %s closed.", s.srcConnID) s.logger.Infof("Connection %s closed.", s.srcConnID)
s.sessionRunner.removeConnectionID(s.srcConnID)
return closeErr.err return closeErr.err
} }
@ -580,6 +583,7 @@ func (s *session) handleHandshakeEvent(completed bool) {
} }
s.handshakeComplete = true s.handshakeComplete = true
s.handshakeEvent = nil // prevent this case from ever being selected again s.handshakeEvent = nil // prevent this case from ever being selected again
s.sessionRunner.onHandshakeComplete(s)
// In gQUIC, the server completes the handshake first (after sending the SHLO). // In gQUIC, the server completes the handshake first (after sending the SHLO).
// In TLS 1.3, the client completes the handshake first (after sending the CFIN). // In TLS 1.3, the client completes the handshake first (after sending the CFIN).
@ -593,7 +597,6 @@ func (s *session) handleHandshakeEvent(completed bool) {
s.queueControlFrame(&wire.PingFrame{}) s.queueControlFrame(&wire.PingFrame{})
s.sentPacketHandler.SetHandshakeComplete() s.sentPacketHandler.SetHandshakeComplete()
} }
close(s.handshakeChan)
} }
func (s *session) handlePacketImpl(p *receivedPacket) error { func (s *session) handlePacketImpl(p *receivedPacket) error {
@ -1239,10 +1242,6 @@ func (s *session) RemoteAddr() net.Addr {
return s.conn.RemoteAddr() return s.conn.RemoteAddr()
} }
func (s *session) handshakeStatus() <-chan error {
return s.handshakeChan
}
func (s *session) getCryptoStream() cryptoStreamI { func (s *session) getCryptoStream() cryptoStreamI {
return s.cryptoStream return s.cryptoStream
} }

View file

@ -68,6 +68,7 @@ func areSessionsRunning() bool {
var _ = Describe("Session", func() { var _ = Describe("Session", func() {
var ( var (
sess *session sess *session
sessionRunner *MockSessionRunner
scfg *handshake.ServerConfig scfg *handshake.ServerConfig
mconn *mockConnection mconn *mockConnection
cryptoSetup *mockCryptoSetup cryptoSetup *mockCryptoSetup
@ -97,6 +98,7 @@ var _ = Describe("Session", func() {
return cryptoSetup, nil return cryptoSetup, nil
} }
sessionRunner = NewMockSessionRunner(mockCtrl)
mconn = newMockConnection() mconn = newMockConnection()
certChain := crypto.NewCertChain(testdata.GetTLSConfig()) certChain := crypto.NewCertChain(testdata.GetTLSConfig())
kex, err := crypto.NewCurve25519KEX() kex, err := crypto.NewCurve25519KEX()
@ -106,6 +108,7 @@ var _ = Describe("Session", func() {
var pSess Session var pSess Session
pSess, err = newSession( pSess, err = newSession(
mconn, mconn,
sessionRunner,
protocol.Version39, protocol.Version39,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
scfg, scfg,
@ -159,6 +162,7 @@ var _ = Describe("Session", func() {
} }
pSess, err := newSession( pSess, err := newSession(
mconn, mconn,
sessionRunner,
protocol.Version39, protocol.Version39,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
scfg, scfg,
@ -471,17 +475,15 @@ var _ = Describe("Session", func() {
It("handles CONNECTION_CLOSE frames", func() { It("handles CONNECTION_CLOSE frames", func() {
testErr := qerr.Error(qerr.ProofInvalid, "foobar") testErr := qerr.Error(qerr.ProofInvalid, "foobar")
streamManager.EXPECT().CloseWithError(testErr) streamManager.EXPECT().CloseWithError(testErr)
done := make(chan struct{}) sessionRunner.EXPECT().removeConnectionID(gomock.Any())
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := sess.run() err := sess.run()
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done)
}() }()
err := sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, protocol.EncryptionUnspecified) err := sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, protocol.EncryptionUnspecified)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
Eventually(done).Should(BeClosed())
}) })
}) })
@ -510,6 +512,7 @@ var _ = Describe("Session", func() {
It("shuts down without error", func() { It("shuts down without error", func() {
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, ""))
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
@ -522,6 +525,7 @@ var _ = Describe("Session", func() {
It("only closes once", func() { It("only closes once", func() {
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, ""))
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil) sess.Close(nil)
sess.Close(nil) sess.Close(nil)
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
@ -532,6 +536,7 @@ var _ = Describe("Session", func() {
It("closes streams with proper error", func() { It("closes streams with proper error", func() {
testErr := errors.New("test error") testErr := errors.New("test error")
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error()))
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(testErr) sess.Close(testErr)
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
Expect(sess.Context().Done()).To(BeClosed()) Expect(sess.Context().Done()).To(BeClosed())
@ -539,6 +544,7 @@ var _ = Describe("Session", func() {
It("closes the session in order to replace it with another QUIC version", func() { It("closes the session in order to replace it with another QUIC version", func() {
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(errCloseSessionForNewVersion) sess.Close(errCloseSessionForNewVersion)
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
@ -546,6 +552,7 @@ var _ = Describe("Session", func() {
It("sends a Public Reset if the client is initiating the head-of-line blocking experiment", func() { It("sends a Public Reset if the client is initiating the head-of-line blocking experiment", func() {
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(handshake.ErrHOLExperiment) sess.Close(handshake.ErrHOLExperiment)
Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset
@ -554,6 +561,7 @@ var _ = Describe("Session", func() {
It("sends a Public Reset if the client is initiating the no STOP_WAITING experiment", func() { It("sends a Public Reset if the client is initiating the no STOP_WAITING experiment", func() {
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(handshake.ErrHOLExperiment) sess.Close(handshake.ErrHOLExperiment)
Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset
@ -562,6 +570,7 @@ var _ = Describe("Session", func() {
It("cancels the context when the run loop exists", func() { It("cancels the context when the run loop exists", func() {
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
returned := make(chan struct{}) returned := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -619,20 +628,21 @@ var _ = Describe("Session", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("closes when handling a packet fails", func(done Done) { It("closes when handling a packet fails", func() {
testErr := errors.New("unpack error") testErr := errors.New("unpack error")
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
hdr.PacketNumber = 5 hdr.PacketNumber = 5
var runErr error done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
runErr = sess.run() err := sess.run()
Expect(err).To(MatchError(testErr))
close(done)
}() }()
sess.handlePacket(&receivedPacket{header: hdr}) sess.handlePacket(&receivedPacket{header: hdr})
Eventually(func() error { return runErr }).Should(MatchError(testErr)) sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Context().Done()).To(BeClosed()) Eventually(done).Should(BeClosed())
close(done)
}) })
It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() { It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() {
@ -886,6 +896,7 @@ var _ = Describe("Session", func() {
Eventually(mconn.written).Should(HaveLen(2)) Eventually(mconn.written).Should(HaveLen(2))
Consistently(mconn.written).Should(HaveLen(2)) Consistently(mconn.written).Should(HaveLen(2))
// make the go routine return // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -909,6 +920,7 @@ var _ = Describe("Session", func() {
Eventually(mconn.written).Should(HaveLen(1)) Eventually(mconn.written).Should(HaveLen(1))
Consistently(mconn.written).Should(HaveLen(1)) Consistently(mconn.written).Should(HaveLen(1))
// make the go routine return // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -936,6 +948,7 @@ var _ = Describe("Session", func() {
Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1)) Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1))
Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2)) Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2))
// make the go routine return // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -958,6 +971,7 @@ var _ = Describe("Session", func() {
sess.scheduleSending() sess.scheduleSending()
Eventually(mconn.written).Should(HaveLen(3)) Eventually(mconn.written).Should(HaveLen(3))
// make the go routine return // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -978,6 +992,7 @@ var _ = Describe("Session", func() {
sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1})
Consistently(mconn.written).ShouldNot(Receive()) Consistently(mconn.written).ShouldNot(Receive())
// make the go routine return // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -1019,6 +1034,7 @@ var _ = Describe("Session", func() {
sess.scheduleSending() sess.scheduleSending()
Eventually(mconn.written).Should(HaveLen(1)) Eventually(mconn.written).Should(HaveLen(1))
// make sure that the go routine returns // make sure that the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -1049,6 +1065,7 @@ var _ = Describe("Session", func() {
sess.scheduleSending() sess.scheduleSending()
Eventually(mconn.written).Should(HaveLen(1)) Eventually(mconn.written).Should(HaveLen(1))
// make sure that the go routine returns // make sure that the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -1210,19 +1227,18 @@ var _ = Describe("Session", func() {
sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SentPacket(gomock.Any())
sess.sentPacketHandler = sph sess.sentPacketHandler = sph
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess.run() sess.run()
close(done)
}() }()
Consistently(mconn.written).ShouldNot(Receive()) Consistently(mconn.written).ShouldNot(Receive())
sess.scheduleSending() sess.scheduleSending()
Eventually(mconn.written).Should(Receive()) Eventually(mconn.written).Should(Receive())
// make the go routine return // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
It("sets the timer to the ack timer", func() { It("sets the timer to the ack timer", func() {
@ -1243,32 +1259,31 @@ var _ = Describe("Session", func() {
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond))
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)) rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour))
sess.receivedPacketHandler = rph sess.receivedPacketHandler = rph
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess.run() sess.run()
close(done)
}() }()
Eventually(mconn.written).Should(Receive()) Eventually(mconn.written).Should(Receive())
// make sure the go routine returns // make sure the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
}) })
It("closes when crypto stream errors", func() { It("closes when crypto stream errors", func() {
testErr := errors.New("crypto setup error") testErr := errors.New("crypto setup error")
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error()))
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
cryptoSetup.handleErr = testErr cryptoSetup.handleErr = testErr
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := sess.run() err := sess.run()
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done)
}() }()
Eventually(done).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
Context("sending a Public Reset when receiving undecryptable packets during the handshake", func() { Context("sending a Public Reset when receiving undecryptable packets during the handshake", func() {
@ -1303,7 +1318,9 @@ var _ = Describe("Session", func() {
sendUndecryptablePackets() sendUndecryptablePackets()
sess.scheduleSending() sess.scheduleSending()
Consistently(mconn.written).Should(HaveLen(0)) Consistently(mconn.written).Should(HaveLen(0))
Expect(sess.Close(nil)).To(Succeed()) // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
@ -1314,6 +1331,8 @@ var _ = Describe("Session", func() {
}() }()
sendUndecryptablePackets() sendUndecryptablePackets()
Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), 20*time.Millisecond)) Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), 20*time.Millisecond))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
@ -1327,11 +1346,14 @@ var _ = Describe("Session", func() {
Eventually(func() []*receivedPacket { return sess.undecryptablePackets }).Should(HaveLen(protocol.MaxUndecryptablePackets)) Eventually(func() []*receivedPacket { return sess.undecryptablePackets }).Should(HaveLen(protocol.MaxUndecryptablePackets))
// check that old packets are kept, and the new packets are dropped // check that old packets are kept, and the new packets are dropped
Expect(sess.undecryptablePackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(sess.undecryptablePackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Close(nil)).To(Succeed()) Expect(sess.Close(nil)).To(Succeed())
Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
It("sends a Public Reset after a timeout", func() { It("sends a Public Reset after a timeout", func() {
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.receivedTooManyUndecrytablePacketsTime).To(BeZero()) Expect(sess.receivedTooManyUndecrytablePacketsTime).To(BeZero())
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -1359,7 +1381,9 @@ var _ = Describe("Session", func() {
// in reality, this happens when the trial decryption succeeded during the Public Reset timeout // in reality, this happens when the trial decryption succeeded during the Public Reset timeout
Consistently(mconn.written).ShouldNot(HaveLen(1)) Consistently(mconn.written).ShouldNot(HaveLen(1))
Expect(sess.Context().Done()).ToNot(Receive()) Expect(sess.Context().Done()).ToNot(Receive())
Expect(sess.Close(nil)).To(Succeed()) // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
@ -1371,6 +1395,8 @@ var _ = Describe("Session", func() {
}() }()
sendUndecryptablePackets() sendUndecryptablePackets()
Consistently(sess.undecryptablePackets).Should(BeEmpty()) Consistently(sess.undecryptablePackets).Should(BeEmpty())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Close(nil)).To(Succeed()) Expect(sess.Close(nil)).To(Succeed())
Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
@ -1387,38 +1413,35 @@ var _ = Describe("Session", func() {
}) })
It("doesn't do anything when the crypto setup says to decrypt undecryptable packets", func() { It("doesn't do anything when the crypto setup says to decrypt undecryptable packets", func() {
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := sess.run() sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
}() }()
handshakeChan <- struct{}{} handshakeChan <- struct{}{}
Consistently(sess.handshakeStatus()).ShouldNot(Receive()) // don't EXPECT any calls to sessionRunner.onHandshakeComplete()
// make sure the go routine returns // make sure the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
Expect(sess.Close(nil)).To(Succeed()) Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
It("closes the handshakeChan when the handshake completes", func() { It("calls the onHandshakeComplete callback when the handshake completes", func() {
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := sess.run() sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
}() }()
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
close(handshakeChan) close(handshakeChan)
Eventually(sess.handshakeStatus()).Should(BeClosed()) Consistently(sess.Context().Done()).ShouldNot(BeClosed())
// make sure the go routine returns // make sure the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
Expect(sess.Close(nil)).To(Succeed()) Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
It("passes errors to the handshakeChan", func() { It("passes errors to the session runner", func() {
testErr := errors.New("handshake error") testErr := errors.New("handshake error")
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -1428,19 +1451,17 @@ var _ = Describe("Session", func() {
close(done) close(done)
}() }()
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(testErr) sess.Close(testErr)
Expect(sess.handshakeStatus()).To(Receive(Equal(testErr)))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("process transport parameters received from the peer", func() { It("process transport parameters received from the peer", func() {
paramsChan := make(chan handshake.TransportParameters) paramsChan := make(chan handshake.TransportParameters)
sess.paramsChan = paramsChan sess.paramsChan = paramsChan
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess.run() sess.run()
close(done)
}() }()
params := handshake.TransportParameters{ params := handshake.TransportParameters{
MaxStreams: 123, MaxStreams: 123,
@ -1457,8 +1478,9 @@ var _ = Describe("Session", func() {
Eventually(func() protocol.ByteCount { return sess.packer.maxPacketSize }).Should(Equal(protocol.ByteCount(0x42))) Eventually(func() protocol.ByteCount { return sess.packer.maxPacketSize }).Should(Equal(protocol.ByteCount(0x42)))
// make the go routine return // make the go routine return
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
Expect(sess.Close(nil)).To(Succeed()) sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Eventually(done).Should(BeClosed()) sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed())
}) })
Context("keep-alives", func() { Context("keep-alives", func() {
@ -1486,6 +1508,7 @@ var _ = Describe("Session", func() {
// -12 because of the crypto tag. This should be 7 (the frame id for a ping frame). // -12 because of the crypto tag. This should be 7 (the frame id for a ping frame).
Expect(data[len(data)-12-1 : len(data)-12]).To(Equal([]byte{0x07})) Expect(data[len(data)-12-1 : len(data)-12]).To(Equal([]byte{0x07}))
// make the go routine return // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -1503,6 +1526,7 @@ var _ = Describe("Session", func() {
}() }()
Consistently(mconn.written).ShouldNot(Receive()) Consistently(mconn.written).ShouldNot(Receive())
// make the go routine return // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -1520,6 +1544,7 @@ var _ = Describe("Session", func() {
}() }()
Consistently(mconn.written).ShouldNot(Receive()) Consistently(mconn.written).ShouldNot(Receive())
// make the go routine return // make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sess.Close(nil) sess.Close(nil)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -1531,23 +1556,33 @@ var _ = Describe("Session", func() {
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
}) })
It("times out due to no network activity", func(done Done) { It("times out due to no network activity", func() {
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.handshakeComplete = true sess.handshakeComplete = true
sess.lastNetworkActivityTime = time.Now().Add(-time.Hour) sess.lastNetworkActivityTime = time.Now().Add(-time.Hour)
err := sess.run() // Would normally not return done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity.")))
Expect(sess.Context().Done()).To(BeClosed())
close(done) close(done)
}()
Eventually(done).Should(BeClosed())
Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity.")))
}) })
It("times out due to non-completed handshake", func(done Done) { It("times out due to non-completed handshake", func() {
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second)
err := sess.run() // Would normally not return done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := sess.run()
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
Expect(mconn.written).To(Receive(ContainSubstring("Crypto handshake did not complete in time.")))
Expect(sess.Context().Done()).To(BeClosed())
close(done) close(done)
}()
Eventually(done).Should(BeClosed())
Expect(mconn.written).To(Receive(ContainSubstring("Crypto handshake did not complete in time.")))
}) })
It("does not use the idle timeout before the handshake complete", func() { It("does not use the idle timeout before the handshake complete", func() {
@ -1556,28 +1591,31 @@ var _ = Describe("Session", func() {
sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) sess.lastNetworkActivityTime = time.Now().Add(-time.Minute)
// the handshake timeout is irrelevant here, since it depends on the time the session was created, // the handshake timeout is irrelevant here, since it depends on the time the session was created,
// and not on the last network activity // and not on the last network activity
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_ = sess.run() sess.run()
close(done)
}() }()
Consistently(done).ShouldNot(BeClosed()) Consistently(sess.Context().Done()).ShouldNot(BeClosed())
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.Close(nil)
Eventually(sess.Context().Done()).Should(BeClosed())
}) })
It("closes the session due to the idle timeout after handshake", func() { It("closes the session due to the idle timeout after handshake", func() {
sessionRunner.EXPECT().onHandshakeComplete(sess)
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.config.IdleTimeout = 0 sess.config.IdleTimeout = 0
close(handshakeChan) close(handshakeChan)
errChan := make(chan error) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
errChan <- sess.run() // Would normally not return err := sess.run()
}()
var err error
Eventually(errChan).Should(Receive(&err))
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity."))) Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity.")))
Expect(sess.Context().Done()).To(BeClosed())
}) })
}) })
@ -1679,6 +1717,7 @@ var _ = Describe("Session", func() {
var _ = Describe("Client Session", func() { var _ = Describe("Client Session", func() {
var ( var (
sess *session sess *session
sessionRunner *MockSessionRunner
mconn *mockConnection mconn *mockConnection
handshakeChan chan<- struct{} handshakeChan chan<- struct{}
@ -1707,8 +1746,10 @@ var _ = Describe("Client Session", func() {
} }
mconn = newMockConnection() mconn = newMockConnection()
sessionRunner = NewMockSessionRunner(mockCtrl)
sessP, err := newClientSession( sessP, err := newClientSession(
mconn, mconn,
sessionRunner,
"hostname", "hostname",
protocol.Version39, protocol.Version39,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
@ -1727,19 +1768,18 @@ var _ = Describe("Client Session", func() {
}) })
It("sends a forward-secure packet when the handshake completes", func() { It("sends a forward-secure packet when the handshake completes", func() {
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
sess.packer.hasSentPacket = true sess.packer.hasSentPacket = true
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := sess.run() sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
}() }()
close(handshakeChan) close(handshakeChan)
Eventually(mconn.written).Should(Receive()) Eventually(mconn.written).Should(Receive())
//make sure the go routine returns //make sure the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Close(nil)).To(Succeed()) Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
Context("receiving packets", func() { Context("receiving packets", func() {
@ -1755,20 +1795,19 @@ var _ = Describe("Client Session", func() {
unpacker := NewMockUnpacker(mockCtrl) unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
sess.unpacker = unpacker sess.unpacker = unpacker
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := sess.run() sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
}() }()
hdr.PacketNumber = 5 hdr.PacketNumber = 5
hdr.DiversificationNonce = []byte("foobar") hdr.DiversificationNonce = []byte("foobar")
err := sess.handlePacketImpl(&receivedPacket{header: hdr}) err := sess.handlePacketImpl(&receivedPacket{header: hdr})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cryptoSetup.divNonce).To(Equal(hdr.DiversificationNonce)) Expect(cryptoSetup.divNonce).To(Equal(hdr.DiversificationNonce))
// make the go routine return
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Close(nil)).To(Succeed()) Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
}) })
}) })