handle Version Negotiation packets in the session

This commit is contained in:
Marten Seemann 2020-06-30 17:13:50 +07:00
parent 6b42c7a045
commit 06ad477b9b
6 changed files with 226 additions and 298 deletions

View file

@ -11,7 +11,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qlog"
)
@ -27,20 +26,15 @@ type client struct {
packetHandlers packetHandlerManager
versionNegotiated utils.AtomicBool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config
config *Config
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialPacketNumber protocol.PacketNumber
initialVersion protocol.VersionNumber
version protocol.VersionNumber
initialPacketNumber protocol.PacketNumber
hasNegotiatedVersion bool
version protocol.VersionNumber
handshakeChan chan struct{}
@ -268,8 +262,9 @@ func (c *client) dial(ctx context.Context) error {
c.config,
c.tlsConf,
c.initialPacketNumber,
c.initialVersion,
c.version,
c.use0RTT,
c.hasNegotiatedVersion,
c.qlogger,
c.logger,
c.version,
@ -280,7 +275,7 @@ func (c *client) dial(ctx context.Context) error {
errorChan := make(chan error, 1)
go func() {
err := c.session.run() // returns as soon as the session is closed
if err != errCloseForRecreating && c.createdPacketConn {
if !errors.Is(err, errCloseForRecreating{}) && c.createdPacketConn {
c.packetHandlers.Destroy()
}
errorChan <- err
@ -298,7 +293,11 @@ func (c *client) dial(ctx context.Context) error {
c.session.shutdown()
return ctx.Err()
case err := <-errorChan:
if err == errCloseForRecreating {
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
}
return err
@ -312,75 +311,9 @@ func (c *client) dial(ctx context.Context) error {
}
func (c *client) handlePacket(p *receivedPacket) {
if wire.IsVersionNegotiationPacket(p.data) {
go c.handleVersionNegotiationPacket(p)
return
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated.Get() {
c.versionNegotiated.Set(true)
}
c.session.handlePacket(p)
}
func (c *client) handleVersionNegotiationPacket(p *receivedPacket) {
c.mutex.Lock()
defer c.mutex.Unlock()
hdr, _, _, err := wire.ParsePacket(p.data, 0)
if err != nil {
if c.qlogger != nil {
c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError)
}
c.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
return
}
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
if c.qlogger != nil {
c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket)
}
c.logger.Debugf("Received a delayed Version Negotiation packet.")
return
}
for _, v := range hdr.SupportedVersions {
if v == c.version {
if c.qlogger != nil {
c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion)
}
// The Version Negotiation packet contains the version that we offered.
// This might be a packet sent by an attacker (or by a terribly broken server implementation).
return
}
}
c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
if c.qlogger != nil {
c.qlogger.ReceivedVersionNegotiationPacket(hdr)
}
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
//nolint:stylecheck
c.session.destroy(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s", c.config.Versions, hdr.SupportedVersions))
c.logger.Debugf("No compatible QUIC version found.")
return
}
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.initialPacketNumber = c.session.closeForRecreating()
}
func (c *client) shutdown() {
c.mutex.Lock()
defer c.mutex.Unlock()

View file

@ -47,6 +47,7 @@ var _ = Describe("Client", func() {
initialPacketNumber protocol.PacketNumber,
initialVersion protocol.VersionNumber,
enable0RTT bool,
hasNegotiatedVersion bool,
qlogger qlog.Tracer,
logger utils.Logger,
v protocol.VersionNumber,
@ -65,16 +66,6 @@ var _ = Describe("Client", func() {
return b.Bytes()
}
composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket {
data, err := wire.ComposeVersionNegotiation(connID, nil, versions)
Expect(err).ToNot(HaveOccurred())
Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue())
return &receivedPacket{
rcvTime: time.Now(),
data: data,
}
}
BeforeEach(func() {
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
@ -169,6 +160,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@ -201,6 +193,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@ -233,6 +226,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@ -271,6 +265,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
enable0RTT bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@ -313,6 +308,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
enable0RTT bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@ -360,6 +356,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@ -403,6 +400,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@ -454,6 +452,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@ -574,6 +573,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber, /* initial version */
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
versionP protocol.VersionNumber,
@ -596,183 +596,58 @@ var _ = Describe("Client", func() {
Expect(conf.Versions).To(Equal(config.Versions))
})
Context("version negotiation", func() {
var origSupportedVersions []protocol.VersionNumber
It("creates a new session after version negotiation", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
BeforeEach(func() {
origSupportedVersions = protocol.SupportedVersions
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{77, 78}...)
})
initialVersion := cl.version
AfterEach(func() {
protocol.SupportedVersions = origSupportedVersions
})
It("returns an error that occurs during version negotiation", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
testErr := errors.New("early handshake error")
newClientSession = func(
conn connection,
_ sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
) quicSession {
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(testErr)
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess
var counter int
newClientSession = func(
_ connection,
_ sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
configP *Config,
_ *tls.Config,
pn protocol.PacketNumber,
version protocol.VersionNumber,
_ bool,
hasNegotiatedVersion bool,
_ qlog.Tracer,
_ utils.Logger,
versionP protocol.VersionNumber,
) quicSession {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().HandshakeComplete().Return(context.Background())
if counter == 0 {
Expect(pn).To(BeZero())
Expect(version).To(Equal(initialVersion))
Expect(hasNegotiatedVersion).To(BeFalse())
sess.EXPECT().run().Return(&errCloseForRecreating{
nextPacketNumber: 109,
nextVersion: 789,
})
} else {
Expect(pn).To(Equal(protocol.PacketNumber(109)))
Expect(version).ToNot(Equal(initialVersion))
Expect(version).To(Equal(protocol.VersionNumber(789)))
Expect(hasNegotiatedVersion).To(BeTrue())
sess.EXPECT().run()
}
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial(
packetConn,
addr,
"localhost:1337",
tlsConf,
config,
)
Expect(err).To(MatchError(testErr))
})
counter++
return sess
}
It("recognizes that a non Version Negotiation packet means that the server accepted the suggested version", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess
cl.config = config
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
Header: wire.Header{
DestConnectionID: connID,
SrcConnectionID: connID,
Version: cl.version,
},
PacketNumberLen: protocol.PacketNumberLen3,
}).Write(buf, protocol.VersionTLS)).To(Succeed())
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
Eventually(cl.versionNegotiated.Get).Should(BeTrue())
})
// Illustrates that adversary that injects a version negotiation packet
// with no supported versions can break a connection.
It("errors if no matching version is found", func() {
sess := NewMockQuicSession(mockCtrl)
done := make(chan struct{})
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
close(done)
})
cl.session = sess
cl.config = &Config{Versions: protocol.SupportedVersions}
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1337})
hdr, _, _, err := wire.ParsePacket(p.data, 0)
Expect(err).ToNot(HaveOccurred())
qlogger.EXPECT().ReceivedVersionNegotiationPacket(hdr)
cl.handlePacket(p)
Eventually(done).Should(BeClosed())
})
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
sess := NewMockQuicSession(mockCtrl)
done := make(chan struct{})
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
close(done)
})
cl.session = sess
v := protocol.VersionNumber(1234)
Expect(v).ToNot(Equal(cl.version))
cl.config = &Config{Versions: protocol.SupportedVersions}
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v}))
Eventually(done).Should(BeClosed())
})
It("changes to the version preferred by the quic.Config", func() {
phm := NewMockPacketHandlerManager(mockCtrl)
cl.packetHandlers = phm
sess := NewMockQuicSession(mockCtrl)
destroyed := make(chan struct{})
sess.EXPECT().closeForRecreating().Do(func() {
close(destroyed)
})
cl.session = sess
versions := []protocol.VersionNumber{1234, 4321}
cl.config = &Config{Versions: versions}
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
cl.handlePacket(composeVersionNegotiationPacket(connID, versions))
Eventually(destroyed).Should(BeClosed())
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
})
It("drops unparseable version negotiation packets", func() {
cl.config = config
ver := cl.version
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
p.data = p.data[:len(p.data)-1]
done := make(chan struct{})
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
close(done)
})
cl.handlePacket(p)
Eventually(done).Should(BeClosed())
Expect(cl.version).To(Equal(ver))
})
It("drops version negotiation packets if any other packet was received before", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess
cl.config = config
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
Header: wire.Header{
DestConnectionID: connID,
SrcConnectionID: connID,
Version: cl.version,
},
PacketNumberLen: protocol.PacketNumberLen3,
}).Write(buf, protocol.VersionTLS)).To(Succeed())
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
ver := cl.version
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234})
done := make(chan struct{})
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
close(done)
})
cl.handlePacket(p)
Eventually(done).Should(BeClosed())
Expect(cl.version).To(Equal(ver))
})
It("drops version negotiation packets that contain the offered version", func() {
cl.config = config
ver := cl.version
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
done := make(chan struct{})
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
close(done)
})
cl.handlePacket(p)
Eventually(done).Should(BeClosed())
Expect(cl.version).To(Equal(ver))
})
gomock.InOrder(
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), initialVersion, gomock.Any(), gomock.Any()),
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionNumber(789), gomock.Any(), gomock.Any()),
)
_, err := DialAddr("localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expect(counter).To(Equal(2))
})
})

View file

@ -225,20 +225,6 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr))
}
// closeForRecreating mocks base method
func (m *MockQuicSession) closeForRecreating() protocol.PacketNumber {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "closeForRecreating")
ret0, _ := ret[0].(protocol.PacketNumber)
return ret0
}
// closeForRecreating indicates an expected call of closeForRecreating
func (mr *MockQuicSessionMockRecorder) closeForRecreating() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForRecreating", reflect.TypeOf((*MockQuicSession)(nil).closeForRecreating))
}
// destroy mocks base method
func (m *MockQuicSession) destroy(arg0 error) {
m.ctrl.T.Helper()

View file

@ -49,7 +49,6 @@ type quicSession interface {
run() error
destroy(error)
shutdown()
closeForRecreating() protocol.PacketNumber
}
// A Listener of QUIC

View file

@ -104,7 +104,19 @@ type closeError struct {
immediate bool
}
var errCloseForRecreating = errors.New("closing session in order to recreate it")
type errCloseForRecreating struct {
nextPacketNumber protocol.PacketNumber
nextVersion protocol.VersionNumber
}
func (errCloseForRecreating) Error() string {
return "closing session in order to recreate it"
}
func (errCloseForRecreating) Is(target error) bool {
_, ok := target.(errCloseForRecreating)
return ok
}
// A Session is a QUIC session
type session struct {
@ -169,6 +181,7 @@ type session struct {
handshakeConfirmed bool
receivedRetry bool
versionNegotiated bool
receivedFirstPacket bool
idleTimeout time.Duration
@ -336,6 +349,7 @@ var newClientSession = func(
initialPacketNumber protocol.PacketNumber,
initialVersion protocol.VersionNumber,
enable0RTT bool,
hasNegotiatedVersion bool,
qlogger qlog.Tracer,
logger utils.Logger,
v protocol.VersionNumber,
@ -352,6 +366,7 @@ var newClientSession = func(
logger: logger,
qlogger: qlogger,
initialVersion: initialVersion,
versionNegotiated: hasNegotiatedVersion,
version: v,
}
s.connIDManager = newConnIDManager(
@ -595,7 +610,7 @@ runLoop:
}
s.handleCloseError(closeErr)
if closeErr.err != errCloseForRecreating && s.qlogger != nil {
if !errors.Is(closeErr.err, errCloseForRecreating{}) && s.qlogger != nil {
if err := s.qlogger.Export(); err != nil {
s.logger.Errorf("exporting qlog failed: %s", err)
}
@ -692,6 +707,11 @@ func (s *session) handleHandshakeComplete() {
}
func (s *session) handlePacketImpl(rp *receivedPacket) bool {
if wire.IsVersionNegotiationPacket(rp.data) {
s.handleVersionNegotiationPacket(rp)
return false
}
var counter uint8
var lastConnID protocol.ConnectionID
var processed bool
@ -888,6 +908,55 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t
return true
}
func (s *session) handleVersionNegotiationPacket(p *receivedPacket) {
if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets
if s.qlogger != nil {
s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket)
}
return
}
hdr, _, _, err := wire.ParsePacket(p.data, 0)
if err != nil {
if s.qlogger != nil {
s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError)
}
s.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
return
}
for _, v := range hdr.SupportedVersions {
if v == s.version {
if s.qlogger != nil {
s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion)
}
// The Version Negotiation packet contains the version that we offered.
// This might be a packet sent by an attacker, or it was corrupted.
return
}
}
s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
if s.qlogger != nil {
s.qlogger.ReceivedVersionNegotiationPacket(hdr)
}
newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, hdr.SupportedVersions)
if !ok {
//nolint:stylecheck
s.destroyImpl(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s.", s.config.Versions, hdr.SupportedVersions))
s.logger.Infof("No compatible QUIC version found.")
return
}
s.logger.Infof("Switching to QUIC version %s.", newVersion)
nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
s.destroyImpl(&errCloseForRecreating{
nextPacketNumber: nextPN,
nextVersion: newVersion,
})
}
func (s *session) handleUnpackedPacket(
packet *unpackedPacket,
rcvTime time.Time,
@ -1190,14 +1259,6 @@ func (s *session) destroyImpl(e error) {
})
}
// closeForRecreating closes the session in order to recreate it immediately afterwards
// It returns the first packet number that should be used in the new session.
func (s *session) closeForRecreating() protocol.PacketNumber {
s.destroy(errCloseForRecreating)
nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
return nextPN
}
func (s *session) closeRemote(e error) {
s.closeOnce.Do(func() {
s.logger.Errorf("Peer closed session with error: %s", e)

View file

@ -487,18 +487,6 @@ var _ = Describe("Session", func() {
Expect(sess.Context().Done()).To(BeClosed())
})
It("closes the session in order to recreate it", func() {
runSession()
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
cryptoSetup.EXPECT().Close()
// don't EXPECT any calls to mconn.Write()
// don't EXPECT any call to qlogger.Export()
sess.closeForRecreating()
Eventually(areSessionsRunning).Should(BeFalse())
expectedRunErr = errCloseForRecreating
})
It("destroys the session", func() {
runSession()
testErr := errors.New("close")
@ -603,6 +591,16 @@ var _ = Describe("Session", func() {
Expect(sess.handlePacketImpl(p)).To(BeFalse())
})
It("drops Version Negotiation packets", func() {
b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, sess.config.Versions)
Expect(err).ToNot(HaveOccurred())
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), qlog.PacketDropUnexpectedPacket)
Expect(sess.handlePacketImpl(&receivedPacket{
data: b,
buffer: getPacketBuffer(),
})).To(BeFalse())
})
It("drops packets for which header decryption fails", func() {
p := getPacket(&wire.ExtendedHeader{
Header: wire.Header{
@ -2035,6 +2033,7 @@ var _ = Describe("Client Session", func() {
42, // initial packet number
protocol.VersionTLS,
false,
false,
qlogger,
utils.DefaultLogger,
protocol.VersionTLS,
@ -2133,6 +2132,81 @@ var _ = Describe("Client Session", func() {
})
})
Context("handling Version Negotiation", func() {
getVNP := func(versions ...protocol.VersionNumber) *receivedPacket {
b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions)
Expect(err).ToNot(HaveOccurred())
return &receivedPacket{
data: b,
buffer: getPacketBuffer(),
}
}
It("closes and returns the right error", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4)
sess.config.Versions = []protocol.VersionNumber{1234, 4321}
errChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
errChan <- sess.run()
}()
sessionRunner.EXPECT().Remove(srcConnID)
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()).Do(func(hdr *wire.Header) {
Expect(hdr.Version).To(BeZero())
Expect(hdr.SupportedVersions).To(And(
ContainElement(protocol.VersionNumber(4321)),
ContainElement(protocol.VersionNumber(1337)),
))
})
cryptoSetup.EXPECT().Close()
Expect(sess.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse())
var err error
Eventually(errChan).Should(Receive(&err))
Expect(err).To(HaveOccurred())
Expect(err).To(BeAssignableToTypeOf(&errCloseForRecreating{}))
recreateErr := err.(*errCloseForRecreating)
Expect(recreateErr.nextVersion).To(Equal(protocol.VersionNumber(4321)))
Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128)))
})
It("it closes when no matching version is found", func() {
errChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
errChan <- sess.run()
}()
sessionRunner.EXPECT().Remove(srcConnID).MaxTimes(1)
gomock.InOrder(
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()),
qlogger.EXPECT().Export(),
)
cryptoSetup.EXPECT().Close()
Expect(sess.handlePacketImpl(getVNP(12345678))).To(BeFalse())
var err error
Eventually(errChan).Should(Receive(&err))
Expect(err).To(HaveOccurred())
Expect(err).ToNot(BeAssignableToTypeOf(&errCloseForRecreating{}))
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found"))
})
It("ignores Version Negotiation packets that offer the current version", func() {
p := getVNP(sess.version)
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion)
Expect(sess.handlePacketImpl(p)).To(BeFalse())
})
It("ignores unparseable Version Negotiation packets", func() {
p := getVNP(sess.version)
p.data = p.data[:len(p.data)-2]
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError)
Expect(sess.handlePacketImpl(p)).To(BeFalse())
})
})
Context("handling Retry", func() {
origDestConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}