mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 04:37:36 +03:00
use callbacks for signaling the session status
Instead of exposing a session.handshakeStatus() <-chan error, it's easier to pass a callback to the session which is called when the handshake is done. The removeConnectionID callback is in preparation for IETF QUIC, where a connection can have multiple connection IDs over its lifetime.
This commit is contained in:
parent
c7119b2adf
commit
733e2e952b
10 changed files with 295 additions and 174 deletions
47
client.go
47
client.go
|
@ -37,6 +37,8 @@ type client struct {
|
|||
initialVersion protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
|
||||
handshakeChan chan struct{}
|
||||
|
||||
session packetHandler
|
||||
|
||||
logger utils.Logger
|
||||
|
@ -105,14 +107,15 @@ func Dial(
|
|||
}
|
||||
}
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: version,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: version,
|
||||
handshakeChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
@ -243,21 +246,19 @@ func (c *client) dialTLS() error {
|
|||
// - any other error that might occur
|
||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
var runErr error
|
||||
errorChan := make(chan struct{})
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
runErr = c.session.run() // returns as soon as the session is closed
|
||||
close(errorChan)
|
||||
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||
c.conn.Close()
|
||||
}
|
||||
err := c.session.run() // returns as soon as the session is closed
|
||||
errorChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case err := <-c.session.handshakeStatus():
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
case <-c.handshakeChan:
|
||||
// handshake successfully completed
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -438,8 +439,13 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
func (c *client) createNewGQUICSession() (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ packetHandler) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: func(protocol.ConnectionID) {},
|
||||
}
|
||||
c.session, err = newClientSession(
|
||||
c.conn,
|
||||
runner,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.destConnID,
|
||||
|
@ -458,8 +464,13 @@ func (c *client) createNewTLSSession(
|
|||
) (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ packetHandler) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: func(protocol.ConnectionID) {},
|
||||
}
|
||||
c.session, err = newTLSClientSession(
|
||||
c.conn,
|
||||
runner,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.destConnID,
|
||||
|
|
|
@ -28,7 +28,7 @@ var _ = Describe("Client", func() {
|
|||
addr net.Addr
|
||||
connID protocol.ConnectionID
|
||||
|
||||
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, logger utils.Logger) (packetHandler, error)
|
||||
originalClientSessConstructor func(connection, sessionRunner, string, protocol.VersionNumber, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (packetHandler, error)
|
||||
)
|
||||
|
||||
// generate a packet sent by the server that accepts the QUIC version suggested by the client
|
||||
|
@ -48,7 +48,7 @@ var _ = Describe("Client", func() {
|
|||
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
||||
originalClientSessConstructor = newClientSession
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
msess, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
|
||||
msess, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
|
||||
sess = msess.(*mockSession)
|
||||
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||
packetConn = newMockPacketConn()
|
||||
|
@ -97,6 +97,7 @@ var _ = Describe("Client", func() {
|
|||
remoteAddrChan := make(chan string)
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ sessionRunner,
|
||||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
|
@ -126,6 +127,7 @@ var _ = Describe("Client", func() {
|
|||
hostnameChan := make(chan string)
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
h string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
|
@ -160,6 +162,7 @@ var _ = Describe("Client", func() {
|
|||
It("returns after the handshake is complete", func() {
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
|
@ -169,6 +172,7 @@ var _ = Describe("Client", func() {
|
|||
_ []protocol.VersionNumber,
|
||||
_ utils.Logger,
|
||||
) (packetHandler, error) {
|
||||
runner.onHandshakeComplete(sess)
|
||||
return sess, nil
|
||||
}
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
|
||||
|
@ -180,14 +184,16 @@ var _ = Describe("Client", func() {
|
|||
Expect(s).ToNot(BeNil())
|
||||
close(dialed)
|
||||
}()
|
||||
close(sess.handshakeChan)
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
// make the session run loop return
|
||||
close(sess.stopRunLoop)
|
||||
})
|
||||
|
||||
It("returns an error that occurs while waiting for the connection to become secure", func() {
|
||||
testErr := errors.New("early handshake error")
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ sessionRunner,
|
||||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
|
@ -208,7 +214,8 @@ var _ = Describe("Client", func() {
|
|||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sess.handshakeChan <- testErr
|
||||
sess.closeReason = testErr
|
||||
close(sess.stopRunLoop)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -269,6 +276,7 @@ var _ = Describe("Client", func() {
|
|||
testErr := errors.New("error creating session")
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
|
@ -295,6 +303,7 @@ var _ = Describe("Client", func() {
|
|||
var conf *Config
|
||||
newTLSClientSession = func(
|
||||
connP connection,
|
||||
_ sessionRunner,
|
||||
hostnameP string,
|
||||
versionP protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
|
@ -344,6 +353,7 @@ var _ = Describe("Client", func() {
|
|||
It("returns an error that occurs during version negotiation", func() {
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ sessionRunner,
|
||||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
|
@ -390,9 +400,10 @@ var _ = Describe("Client", func() {
|
|||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
cl.config = &Config{Versions: []protocol.VersionNumber{newVersion}}
|
||||
sessionChan := make(chan *mockSession)
|
||||
handshakeChan := make(chan error)
|
||||
stopRunLoop := make(chan struct{})
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
|
@ -406,9 +417,8 @@ var _ = Describe("Client", func() {
|
|||
negotiatedVersions = negotiatedVersionsP
|
||||
|
||||
sess := &mockSession{
|
||||
connectionID: connectionID,
|
||||
stopRunLoop: make(chan struct{}),
|
||||
handshakeChan: handshakeChan,
|
||||
connectionID: connectionID,
|
||||
stopRunLoop: stopRunLoop,
|
||||
}
|
||||
sessionChan <- sess
|
||||
return sess, nil
|
||||
|
@ -441,7 +451,6 @@ var _ = Describe("Client", func() {
|
|||
Expect(negotiatedVersions).To(ContainElement(newVersion))
|
||||
Expect(initialVersion).To(Equal(actualInitialVersion))
|
||||
|
||||
close(handshakeChan)
|
||||
Eventually(established).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -449,6 +458,7 @@ var _ = Describe("Client", func() {
|
|||
sessionCounter := uint32(0)
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
|
@ -613,6 +623,7 @@ var _ = Describe("Client", func() {
|
|||
var conf *Config
|
||||
newClientSession = func(
|
||||
connP connection,
|
||||
_ sessionRunner,
|
||||
hostnameP string,
|
||||
versionP protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
|
@ -651,6 +662,7 @@ var _ = Describe("Client", func() {
|
|||
sessionChan := make(chan *mockSession)
|
||||
newTLSClientSession = func(
|
||||
connP connection,
|
||||
_ sessionRunner,
|
||||
hostnameP string,
|
||||
versionP protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
|
|
55
mock_session_runner_test.go
Normal file
55
mock_session_runner_test.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go (interfaces: SessionRunner)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockSessionRunner is a mock of SessionRunner interface
|
||||
type MockSessionRunner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSessionRunnerMockRecorder
|
||||
}
|
||||
|
||||
// MockSessionRunnerMockRecorder is the mock recorder for MockSessionRunner
|
||||
type MockSessionRunnerMockRecorder struct {
|
||||
mock *MockSessionRunner
|
||||
}
|
||||
|
||||
// NewMockSessionRunner creates a new mock instance
|
||||
func NewMockSessionRunner(ctrl *gomock.Controller) *MockSessionRunner {
|
||||
mock := &MockSessionRunner{ctrl: ctrl}
|
||||
mock.recorder = &MockSessionRunnerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// onHandshakeComplete mocks base method
|
||||
func (m *MockSessionRunner) onHandshakeComplete(arg0 packetHandler) {
|
||||
m.ctrl.Call(m, "onHandshakeComplete", arg0)
|
||||
}
|
||||
|
||||
// onHandshakeComplete indicates an expected call of onHandshakeComplete
|
||||
func (mr *MockSessionRunnerMockRecorder) onHandshakeComplete(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHandshakeComplete", reflect.TypeOf((*MockSessionRunner)(nil).onHandshakeComplete), arg0)
|
||||
}
|
||||
|
||||
// removeConnectionID mocks base method
|
||||
func (m *MockSessionRunner) removeConnectionID(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.Call(m, "removeConnectionID", arg0)
|
||||
}
|
||||
|
||||
// removeConnectionID indicates an expected call of removeConnectionID
|
||||
func (mr *MockSessionRunnerMockRecorder) removeConnectionID(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "removeConnectionID", reflect.TypeOf((*MockSessionRunner)(nil).removeConnectionID), arg0)
|
||||
}
|
|
@ -8,9 +8,9 @@ package quic
|
|||
//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager"
|
||||
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker Unpacker"
|
||||
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_unpacker_test.go mock_unpacker_test.go"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
|
||||
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
|
||||
//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"
|
||||
//go:generate sh -c "goimports -w mock*_test.go"
|
||||
|
|
48
server.go
48
server.go
|
@ -21,13 +21,27 @@ import (
|
|||
type packetHandler interface {
|
||||
Session
|
||||
getCryptoStream() cryptoStreamI
|
||||
handshakeStatus() <-chan error
|
||||
handlePacket(*receivedPacket)
|
||||
GetVersion() protocol.VersionNumber
|
||||
run() error
|
||||
closeRemote(error)
|
||||
}
|
||||
|
||||
type sessionRunner interface {
|
||||
onHandshakeComplete(packetHandler)
|
||||
removeConnectionID(protocol.ConnectionID)
|
||||
}
|
||||
|
||||
type runner struct {
|
||||
onHandshakeCompleteImpl func(packetHandler)
|
||||
removeConnectionIDImpl func(protocol.ConnectionID)
|
||||
}
|
||||
|
||||
func (r *runner) onHandshakeComplete(p packetHandler) { r.onHandshakeCompleteImpl(p) }
|
||||
func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
|
||||
|
||||
var _ sessionRunner = &runner{}
|
||||
|
||||
// A Listener of QUIC
|
||||
type server struct {
|
||||
tlsConf *tls.Config
|
||||
|
@ -45,12 +59,14 @@ type server struct {
|
|||
sessions map[string] /* string(ConnectionID)*/ packetHandler
|
||||
closed bool
|
||||
|
||||
serverError error
|
||||
serverError error
|
||||
|
||||
sessionQueue chan Session
|
||||
errorChan chan struct{}
|
||||
|
||||
sessionRunner sessionRunner
|
||||
// set as members, so they can be set in the tests
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, logger utils.Logger) (packetHandler, error)
|
||||
newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error)
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
|
||||
logger utils.Logger
|
||||
|
@ -112,6 +128,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
|||
supportsTLS: supportsTLS,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
}
|
||||
s.sessionRunner = &runner{
|
||||
onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess },
|
||||
removeConnectionIDImpl: s.removeConnection,
|
||||
}
|
||||
if supportsTLS {
|
||||
if err := s.setupTLS(); err != nil {
|
||||
return nil, err
|
||||
|
@ -127,7 +147,7 @@ func (s *server) setupTLS() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf, s.logger)
|
||||
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, cookieHandler, s.tlsConf, s.logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -148,7 +168,7 @@ func (s *server) setupTLS() error {
|
|||
}
|
||||
s.sessions[string(connID)] = sess
|
||||
s.sessionsMutex.Unlock()
|
||||
s.runHandshakeAndSession(sess, connID)
|
||||
go sess.run()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -415,6 +435,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
|
|||
var err error
|
||||
session, err = s.newSession(
|
||||
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
||||
s.sessionRunner,
|
||||
version,
|
||||
hdr.DestConnectionID,
|
||||
s.scfg,
|
||||
|
@ -429,7 +450,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
|
|||
s.sessions[string(hdr.DestConnectionID)] = session
|
||||
s.sessionsMutex.Unlock()
|
||||
|
||||
s.runHandshakeAndSession(session, hdr.DestConnectionID)
|
||||
go session.run()
|
||||
}
|
||||
|
||||
session.handlePacket(&receivedPacket{
|
||||
|
@ -441,21 +462,6 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) {
|
||||
go func() {
|
||||
_ = session.run()
|
||||
// session.run() returns as soon as the session is closed
|
||||
s.removeConnection(connID)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := <-session.handshakeStatus(); err != nil {
|
||||
return
|
||||
}
|
||||
s.sessionQueue <- session
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *server) removeConnection(id protocol.ConnectionID) {
|
||||
s.sessionsMutex.Lock()
|
||||
s.sessions[string(id)] = nil
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
@ -22,13 +21,13 @@ import (
|
|||
)
|
||||
|
||||
type mockSession struct {
|
||||
runner sessionRunner
|
||||
connectionID protocol.ConnectionID
|
||||
handledPackets []*receivedPacket
|
||||
closed bool
|
||||
closeReason error
|
||||
closedRemote bool
|
||||
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
|
||||
handshakeChan chan error
|
||||
}
|
||||
|
||||
func (s *mockSession) handlePacket(p *receivedPacket) {
|
||||
|
@ -67,13 +66,13 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not impl
|
|||
func (*mockSession) Context() context.Context { panic("not implemented") }
|
||||
func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") }
|
||||
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
|
||||
func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
|
||||
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
|
||||
|
||||
var _ Session = &mockSession{}
|
||||
|
||||
func newMockSession(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
_ *handshake.ServerConfig,
|
||||
|
@ -82,9 +81,9 @@ func newMockSession(
|
|||
_ utils.Logger,
|
||||
) (packetHandler, error) {
|
||||
s := mockSession{
|
||||
connectionID: connectionID,
|
||||
handshakeChan: make(chan error),
|
||||
stopRunLoop: make(chan struct{}),
|
||||
runner: runner,
|
||||
connectionID: connectionID,
|
||||
stopRunLoop: make(chan struct{}),
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
@ -181,7 +180,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("accepts new TLS sessions", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
sess, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -198,9 +197,9 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("only accepts one new TLS sessions for one connection ID", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
sess1, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
sess2, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -224,38 +223,45 @@ var _ = Describe("Server", func() {
|
|||
}).Should(Equal(sess1))
|
||||
})
|
||||
|
||||
It("accepts a session once the connection it is forward secure", func(done Done) {
|
||||
It("accepts a session once the connection it is forward secure", func() {
|
||||
var acceptedSess Session
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
acceptedSess, err = serv.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[string(connID)].(*mockSession)
|
||||
Consistently(func() Session { return acceptedSess }).Should(BeNil())
|
||||
close(sess.handshakeChan)
|
||||
serv.sessionQueue <- sess
|
||||
Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
|
||||
close(done)
|
||||
}, 0.5)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't accept session that error during the handshake", func(done Done) {
|
||||
var accepted bool
|
||||
It("doesn't accept sessions that error during the handshake", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.Accept()
|
||||
accepted = true
|
||||
close(done)
|
||||
}()
|
||||
err := serv.handlePacket(nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[string(connID)].(*mockSession)
|
||||
sess.handshakeChan <- errors.New("handshake failed")
|
||||
Consistently(func() bool { return accepted }).Should(BeFalse())
|
||||
close(done)
|
||||
sess.closeReason = errors.New("handshake failed")
|
||||
close(sess.stopRunLoop)
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
// make the go routine return
|
||||
serv.removeConnection(connID)
|
||||
close(serv.errorChan)
|
||||
serv.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("assigns packets to existing sessions", func() {
|
||||
|
@ -268,16 +274,10 @@ var _ = Describe("Server", func() {
|
|||
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("closes and deletes sessions", func() {
|
||||
It("deletes sessions", func() {
|
||||
serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Expect(serv.sessions[string(connID)]).ToNot(BeNil())
|
||||
// make session.run() return
|
||||
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{}
|
||||
serv.sessions[string(connID)] = &mockSession{}
|
||||
serv.removeConnection(connID)
|
||||
// The server should now have closed the session, leaving a nil value in the sessions map
|
||||
Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1))
|
||||
Expect(serv.sessions[string(connID)]).To(BeNil())
|
||||
|
@ -285,14 +285,9 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("deletes nil session entries after a wait time", func() {
|
||||
serv.deleteClosedSessionsAfter = 25 * time.Millisecond
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serv.handlePacket(nil, append(firstPacket, nullAEAD.Seal(nil, nil, 0, firstPacket)...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Expect(serv.sessions).To(HaveKey(string(connID)))
|
||||
serv.sessions[string(connID)] = &mockSession{}
|
||||
// make session.run() return
|
||||
serv.sessions[string(connID)].(*mockSession).stopRunLoop <- struct{}{}
|
||||
serv.removeConnection(connID)
|
||||
Eventually(func() bool {
|
||||
serv.sessionsMutex.Lock()
|
||||
_, ok := serv.sessions[string(connID)]
|
||||
|
@ -303,7 +298,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("closes sessions and the connection when Close is called", func() {
|
||||
go serv.serve()
|
||||
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
|
||||
session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
|
||||
serv.sessions[string(connID)] = session
|
||||
err := serv.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -353,7 +348,7 @@ var _ = Describe("Server", func() {
|
|||
}, 0.5)
|
||||
|
||||
It("closes all sessions when encountering a connection error", func() {
|
||||
session, _ := newMockSession(nil, 0, connID, nil, nil, nil, nil)
|
||||
session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil)
|
||||
serv.sessions[string(connID)] = session
|
||||
Expect(serv.sessions[string(connID)].(*mockSession).closed).To(BeFalse())
|
||||
testErr := errors.New("connection error")
|
||||
|
|
|
@ -42,7 +42,8 @@ type serverTLS struct {
|
|||
params *handshake.TransportParameters
|
||||
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
|
||||
|
||||
sessionChan chan<- tlsSession
|
||||
sessionRunner sessionRunner
|
||||
sessionChan chan<- tlsSession
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
@ -50,6 +51,7 @@ type serverTLS struct {
|
|||
func newServerTLS(
|
||||
conn net.PacketConn,
|
||||
config *Config,
|
||||
runner sessionRunner,
|
||||
cookieHandler *handshake.CookieHandler,
|
||||
tlsConf *tls.Config,
|
||||
logger utils.Logger,
|
||||
|
@ -72,6 +74,7 @@ func newServerTLS(
|
|||
config: config,
|
||||
supportedVersions: config.Versions,
|
||||
mintConf: mconf,
|
||||
sessionRunner: runner,
|
||||
sessionChan: sessionChan,
|
||||
params: &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
|
@ -214,6 +217,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
|
|||
params := <-paramsChan
|
||||
sess, err := newTLSServerSession(
|
||||
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
||||
s.sessionRunner,
|
||||
hdr.SrcConnectionID,
|
||||
hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here
|
||||
protocol.PacketNumber(1), // TODO: use a random packet number here
|
||||
|
|
|
@ -37,7 +37,7 @@ var _ = Describe("Stateless TLS handling", func() {
|
|||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
var err error
|
||||
server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig(), utils.DefaultLogger)
|
||||
server, sessionChan, err = newServerTLS(conn, config, nil, nil, testdata.GetTLSConfig(), utils.DefaultLogger)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
|
||||
mintReply = bc
|
||||
|
|
31
session.go
31
session.go
|
@ -73,6 +73,8 @@ type closeError struct {
|
|||
|
||||
// A Session is a QUIC session
|
||||
type session struct {
|
||||
sessionRunner sessionRunner
|
||||
|
||||
destConnID protocol.ConnectionID
|
||||
srcConnID protocol.ConnectionID
|
||||
|
||||
|
@ -116,11 +118,7 @@ type session struct {
|
|||
paramsChan <-chan handshake.TransportParameters
|
||||
// the handshakeEvent channel is passed to the CryptoSetup.
|
||||
// It receives when it makes sense to try decrypting undecryptable packets.
|
||||
handshakeEvent <-chan struct{}
|
||||
// handshakeChan is returned by handshakeStatus.
|
||||
// It receives any error that might occur during the handshake.
|
||||
// It is closed when the handshake is complete.
|
||||
handshakeChan chan error
|
||||
handshakeEvent <-chan struct{}
|
||||
handshakeComplete bool
|
||||
|
||||
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
|
||||
|
@ -151,6 +149,7 @@ var _ streamSender = &session{}
|
|||
// newSession makes a new session
|
||||
func newSession(
|
||||
conn connection,
|
||||
sessionRunner sessionRunner,
|
||||
v protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
scfg *handshake.ServerConfig,
|
||||
|
@ -162,6 +161,7 @@ func newSession(
|
|||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
sessionRunner: sessionRunner,
|
||||
srcConnID: connectionID,
|
||||
destConnID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
|
@ -221,6 +221,7 @@ func newSession(
|
|||
// declare this as a variable, so that we can it mock it in the tests
|
||||
var newClientSession = func(
|
||||
conn connection,
|
||||
sessionRunner sessionRunner,
|
||||
hostname string,
|
||||
v protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
|
@ -234,6 +235,7 @@ var newClientSession = func(
|
|||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
sessionRunner: sessionRunner,
|
||||
srcConnID: connectionID,
|
||||
destConnID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
|
@ -288,6 +290,7 @@ var newClientSession = func(
|
|||
|
||||
func newTLSServerSession(
|
||||
conn connection,
|
||||
runner sessionRunner,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
|
@ -302,6 +305,7 @@ func newTLSServerSession(
|
|||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
sessionRunner: runner,
|
||||
config: config,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
|
@ -345,6 +349,7 @@ func newTLSServerSession(
|
|||
// declare this as a variable, such that we can it mock it in the tests
|
||||
var newTLSClientSession = func(
|
||||
conn connection,
|
||||
runner sessionRunner,
|
||||
hostname string,
|
||||
v protocol.VersionNumber,
|
||||
destConnID protocol.ConnectionID,
|
||||
|
@ -358,6 +363,7 @@ var newTLSClientSession = func(
|
|||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
sessionRunner: runner,
|
||||
config: config,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
|
@ -413,7 +419,6 @@ func (s *session) preSetup() {
|
|||
}
|
||||
|
||||
func (s *session) postSetup() error {
|
||||
s.handshakeChan = make(chan error, 1)
|
||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||
s.closeChan = make(chan closeError, 1)
|
||||
s.sendingScheduled = make(chan struct{}, 1)
|
||||
|
@ -527,13 +532,11 @@ runLoop:
|
|||
}
|
||||
}
|
||||
|
||||
// only send the error the handshakeChan when the handshake is not completed yet
|
||||
// otherwise this chan will already be closed
|
||||
if !s.handshakeComplete {
|
||||
s.handshakeChan <- closeErr.err
|
||||
if err := s.handleCloseError(closeErr); err != nil {
|
||||
s.logger.Infof("Handling close error failed: %s", err)
|
||||
}
|
||||
s.handleCloseError(closeErr)
|
||||
s.logger.Infof("Connection %s closed.", s.srcConnID)
|
||||
s.sessionRunner.removeConnectionID(s.srcConnID)
|
||||
return closeErr.err
|
||||
}
|
||||
|
||||
|
@ -580,6 +583,7 @@ func (s *session) handleHandshakeEvent(completed bool) {
|
|||
}
|
||||
s.handshakeComplete = true
|
||||
s.handshakeEvent = nil // prevent this case from ever being selected again
|
||||
s.sessionRunner.onHandshakeComplete(s)
|
||||
|
||||
// In gQUIC, the server completes the handshake first (after sending the SHLO).
|
||||
// In TLS 1.3, the client completes the handshake first (after sending the CFIN).
|
||||
|
@ -593,7 +597,6 @@ func (s *session) handleHandshakeEvent(completed bool) {
|
|||
s.queueControlFrame(&wire.PingFrame{})
|
||||
s.sentPacketHandler.SetHandshakeComplete()
|
||||
}
|
||||
close(s.handshakeChan)
|
||||
}
|
||||
|
||||
func (s *session) handlePacketImpl(p *receivedPacket) error {
|
||||
|
@ -1239,10 +1242,6 @@ func (s *session) RemoteAddr() net.Addr {
|
|||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *session) handshakeStatus() <-chan error {
|
||||
return s.handshakeChan
|
||||
}
|
||||
|
||||
func (s *session) getCryptoStream() cryptoStreamI {
|
||||
return s.cryptoStream
|
||||
}
|
||||
|
|
177
session_test.go
177
session_test.go
|
@ -68,6 +68,7 @@ func areSessionsRunning() bool {
|
|||
var _ = Describe("Session", func() {
|
||||
var (
|
||||
sess *session
|
||||
sessionRunner *MockSessionRunner
|
||||
scfg *handshake.ServerConfig
|
||||
mconn *mockConnection
|
||||
cryptoSetup *mockCryptoSetup
|
||||
|
@ -97,6 +98,7 @@ var _ = Describe("Session", func() {
|
|||
return cryptoSetup, nil
|
||||
}
|
||||
|
||||
sessionRunner = NewMockSessionRunner(mockCtrl)
|
||||
mconn = newMockConnection()
|
||||
certChain := crypto.NewCertChain(testdata.GetTLSConfig())
|
||||
kex, err := crypto.NewCurve25519KEX()
|
||||
|
@ -106,6 +108,7 @@ var _ = Describe("Session", func() {
|
|||
var pSess Session
|
||||
pSess, err = newSession(
|
||||
mconn,
|
||||
sessionRunner,
|
||||
protocol.Version39,
|
||||
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
|
||||
scfg,
|
||||
|
@ -159,6 +162,7 @@ var _ = Describe("Session", func() {
|
|||
}
|
||||
pSess, err := newSession(
|
||||
mconn,
|
||||
sessionRunner,
|
||||
protocol.Version39,
|
||||
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
|
||||
scfg,
|
||||
|
@ -471,17 +475,15 @@ var _ = Describe("Session", func() {
|
|||
It("handles CONNECTION_CLOSE frames", func() {
|
||||
testErr := qerr.Error(qerr.ProofInvalid, "foobar")
|
||||
streamManager.EXPECT().CloseWithError(testErr)
|
||||
done := make(chan struct{})
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
err := sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, protocol.EncryptionUnspecified)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -510,6 +512,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
It("shuts down without error", func() {
|
||||
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, ""))
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
|
@ -522,6 +525,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
It("only closes once", func() {
|
||||
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, ""))
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
sess.Close(nil)
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
|
@ -532,6 +536,7 @@ var _ = Describe("Session", func() {
|
|||
It("closes streams with proper error", func() {
|
||||
testErr := errors.New("test error")
|
||||
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error()))
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(testErr)
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
|
@ -539,6 +544,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
It("closes the session in order to replace it with another QUIC version", func() {
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(errCloseSessionForNewVersion)
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
|
||||
|
@ -546,6 +552,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
It("sends a Public Reset if the client is initiating the head-of-line blocking experiment", func() {
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(handshake.ErrHOLExperiment)
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset
|
||||
|
@ -554,6 +561,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
It("sends a Public Reset if the client is initiating the no STOP_WAITING experiment", func() {
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(handshake.ErrHOLExperiment)
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect((<-mconn.written)[0] & 0x02).ToNot(BeZero()) // Public Reset
|
||||
|
@ -562,6 +570,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
It("cancels the context when the run loop exists", func() {
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
returned := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -619,20 +628,21 @@ var _ = Describe("Session", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("closes when handling a packet fails", func(done Done) {
|
||||
It("closes when handling a packet fails", func() {
|
||||
testErr := errors.New("unpack error")
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr)
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
hdr.PacketNumber = 5
|
||||
var runErr error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
runErr = sess.run()
|
||||
err := sess.run()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sess.handlePacket(&receivedPacket{header: hdr})
|
||||
Eventually(func() error { return runErr }).Should(MatchError(testErr))
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
close(done)
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() {
|
||||
|
@ -886,6 +896,7 @@ var _ = Describe("Session", func() {
|
|||
Eventually(mconn.written).Should(HaveLen(2))
|
||||
Consistently(mconn.written).Should(HaveLen(2))
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -909,6 +920,7 @@ var _ = Describe("Session", func() {
|
|||
Eventually(mconn.written).Should(HaveLen(1))
|
||||
Consistently(mconn.written).Should(HaveLen(1))
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -936,6 +948,7 @@ var _ = Describe("Session", func() {
|
|||
Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1))
|
||||
Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2))
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -958,6 +971,7 @@ var _ = Describe("Session", func() {
|
|||
sess.scheduleSending()
|
||||
Eventually(mconn.written).Should(HaveLen(3))
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -978,6 +992,7 @@ var _ = Describe("Session", func() {
|
|||
sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1})
|
||||
Consistently(mconn.written).ShouldNot(Receive())
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -1019,6 +1034,7 @@ var _ = Describe("Session", func() {
|
|||
sess.scheduleSending()
|
||||
Eventually(mconn.written).Should(HaveLen(1))
|
||||
// make sure that the go routine returns
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -1049,6 +1065,7 @@ var _ = Describe("Session", func() {
|
|||
sess.scheduleSending()
|
||||
Eventually(mconn.written).Should(HaveLen(1))
|
||||
// make sure that the go routine returns
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -1210,19 +1227,18 @@ var _ = Describe("Session", func() {
|
|||
sph.EXPECT().SentPacket(gomock.Any())
|
||||
sess.sentPacketHandler = sph
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
Consistently(mconn.written).ShouldNot(Receive())
|
||||
sess.scheduleSending()
|
||||
Eventually(mconn.written).Should(Receive())
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sets the timer to the ack timer", func() {
|
||||
|
@ -1243,32 +1259,31 @@ var _ = Describe("Session", func() {
|
|||
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond))
|
||||
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour))
|
||||
sess.receivedPacketHandler = rph
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
Eventually(mconn.written).Should(Receive())
|
||||
// make sure the go routine returns
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
It("closes when crypto stream errors", func() {
|
||||
testErr := errors.New("crypto setup error")
|
||||
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error()))
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
cryptoSetup.handleErr = testErr
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("sending a Public Reset when receiving undecryptable packets during the handshake", func() {
|
||||
|
@ -1303,7 +1318,9 @@ var _ = Describe("Session", func() {
|
|||
sendUndecryptablePackets()
|
||||
sess.scheduleSending()
|
||||
Consistently(mconn.written).Should(HaveLen(0))
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -1314,6 +1331,8 @@ var _ = Describe("Session", func() {
|
|||
}()
|
||||
sendUndecryptablePackets()
|
||||
Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), 20*time.Millisecond))
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
@ -1327,11 +1346,14 @@ var _ = Describe("Session", func() {
|
|||
Eventually(func() []*receivedPacket { return sess.undecryptablePackets }).Should(HaveLen(protocol.MaxUndecryptablePackets))
|
||||
// check that old packets are kept, and the new packets are dropped
|
||||
Expect(sess.undecryptablePackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends a Public Reset after a timeout", func() {
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
Expect(sess.receivedTooManyUndecrytablePacketsTime).To(BeZero())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -1359,7 +1381,9 @@ var _ = Describe("Session", func() {
|
|||
// in reality, this happens when the trial decryption succeeded during the Public Reset timeout
|
||||
Consistently(mconn.written).ShouldNot(HaveLen(1))
|
||||
Expect(sess.Context().Done()).ToNot(Receive())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -1371,6 +1395,8 @@ var _ = Describe("Session", func() {
|
|||
}()
|
||||
sendUndecryptablePackets()
|
||||
Consistently(sess.undecryptablePackets).Should(BeEmpty())
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
@ -1387,38 +1413,35 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
|
||||
It("doesn't do anything when the crypto setup says to decrypt undecryptable packets", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
sess.run()
|
||||
}()
|
||||
handshakeChan <- struct{}{}
|
||||
Consistently(sess.handshakeStatus()).ShouldNot(Receive())
|
||||
// don't EXPECT any calls to sessionRunner.onHandshakeComplete()
|
||||
// make sure the go routine returns
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the handshakeChan when the handshake completes", func() {
|
||||
done := make(chan struct{})
|
||||
It("calls the onHandshakeComplete callback when the handshake completes", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
sess.run()
|
||||
}()
|
||||
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
|
||||
close(handshakeChan)
|
||||
Eventually(sess.handshakeStatus()).Should(BeClosed())
|
||||
Consistently(sess.Context().Done()).ShouldNot(BeClosed())
|
||||
// make sure the go routine returns
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("passes errors to the handshakeChan", func() {
|
||||
It("passes errors to the session runner", func() {
|
||||
testErr := errors.New("handshake error")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -1428,19 +1451,17 @@ var _ = Describe("Session", func() {
|
|||
close(done)
|
||||
}()
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(testErr)
|
||||
Expect(sess.handshakeStatus()).To(Receive(Equal(testErr)))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("process transport parameters received from the peer", func() {
|
||||
paramsChan := make(chan handshake.TransportParameters)
|
||||
sess.paramsChan = paramsChan
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
params := handshake.TransportParameters{
|
||||
MaxStreams: 123,
|
||||
|
@ -1457,8 +1478,9 @@ var _ = Describe("Session", func() {
|
|||
Eventually(func() protocol.ByteCount { return sess.packer.maxPacketSize }).Should(Equal(protocol.ByteCount(0x42)))
|
||||
// make the go routine return
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("keep-alives", func() {
|
||||
|
@ -1486,6 +1508,7 @@ var _ = Describe("Session", func() {
|
|||
// -12 because of the crypto tag. This should be 7 (the frame id for a ping frame).
|
||||
Expect(data[len(data)-12-1 : len(data)-12]).To(Equal([]byte{0x07}))
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -1503,6 +1526,7 @@ var _ = Describe("Session", func() {
|
|||
}()
|
||||
Consistently(mconn.written).ShouldNot(Receive())
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -1520,6 +1544,7 @@ var _ = Describe("Session", func() {
|
|||
}()
|
||||
Consistently(mconn.written).ShouldNot(Receive())
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -1531,23 +1556,33 @@ var _ = Describe("Session", func() {
|
|||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
})
|
||||
|
||||
It("times out due to no network activity", func(done Done) {
|
||||
It("times out due to no network activity", func() {
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.handshakeComplete = true
|
||||
sess.lastNetworkActivityTime = time.Now().Add(-time.Hour)
|
||||
err := sess.run() // Would normally not return
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity.")))
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("times out due to non-completed handshake", func(done Done) {
|
||||
It("times out due to non-completed handshake", func() {
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second)
|
||||
err := sess.run() // Would normally not return
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(mconn.written).To(Receive(ContainSubstring("Crypto handshake did not complete in time.")))
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("does not use the idle timeout before the handshake complete", func() {
|
||||
|
@ -1556,28 +1591,31 @@ var _ = Describe("Session", func() {
|
|||
sess.lastNetworkActivityTime = time.Now().Add(-time.Minute)
|
||||
// the handshake timeout is irrelevant here, since it depends on the time the session was created,
|
||||
// and not on the last network activity
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_ = sess.run()
|
||||
close(done)
|
||||
sess.run()
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
Consistently(sess.Context().Done()).ShouldNot(BeClosed())
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.Close(nil)
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the session due to the idle timeout after handshake", func() {
|
||||
sessionRunner.EXPECT().onHandshakeComplete(sess)
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
sess.config.IdleTimeout = 0
|
||||
close(handshakeChan)
|
||||
errChan := make(chan error)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
errChan <- sess.run() // Would normally not return
|
||||
err := sess.run()
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
|
||||
close(done)
|
||||
}()
|
||||
var err error
|
||||
Eventually(errChan).Should(Receive(&err))
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity.")))
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -1679,6 +1717,7 @@ var _ = Describe("Session", func() {
|
|||
var _ = Describe("Client Session", func() {
|
||||
var (
|
||||
sess *session
|
||||
sessionRunner *MockSessionRunner
|
||||
mconn *mockConnection
|
||||
handshakeChan chan<- struct{}
|
||||
|
||||
|
@ -1707,8 +1746,10 @@ var _ = Describe("Client Session", func() {
|
|||
}
|
||||
|
||||
mconn = newMockConnection()
|
||||
sessionRunner = NewMockSessionRunner(mockCtrl)
|
||||
sessP, err := newClientSession(
|
||||
mconn,
|
||||
sessionRunner,
|
||||
"hostname",
|
||||
protocol.Version39,
|
||||
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
|
||||
|
@ -1727,19 +1768,18 @@ var _ = Describe("Client Session", func() {
|
|||
})
|
||||
|
||||
It("sends a forward-secure packet when the handshake completes", func() {
|
||||
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
|
||||
sess.packer.hasSentPacket = true
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
sess.run()
|
||||
}()
|
||||
close(handshakeChan)
|
||||
Eventually(mconn.written).Should(Receive())
|
||||
//make sure the go routine returns
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("receiving packets", func() {
|
||||
|
@ -1755,20 +1795,19 @@ var _ = Describe("Client Session", func() {
|
|||
unpacker := NewMockUnpacker(mockCtrl)
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
|
||||
sess.unpacker = unpacker
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
sess.run()
|
||||
}()
|
||||
hdr.PacketNumber = 5
|
||||
hdr.DiversificationNonce = []byte("foobar")
|
||||
err := sess.handlePacketImpl(&receivedPacket{header: hdr})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cryptoSetup.divNonce).To(Equal(hdr.DiversificationNonce))
|
||||
// make the go routine return
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue