mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
handle Version Negotiation packets in the session
This commit is contained in:
parent
6b42c7a045
commit
06ad477b9b
6 changed files with 226 additions and 298 deletions
89
client.go
89
client.go
|
@ -11,7 +11,6 @@ import (
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
"github.com/lucas-clemente/quic-go/qlog"
|
"github.com/lucas-clemente/quic-go/qlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -27,20 +26,15 @@ type client struct {
|
||||||
|
|
||||||
packetHandlers packetHandlerManager
|
packetHandlers packetHandlerManager
|
||||||
|
|
||||||
versionNegotiated utils.AtomicBool // has the server accepted our version
|
|
||||||
receivedVersionNegotiationPacket bool
|
|
||||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
|
||||||
|
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
config *Config
|
config *Config
|
||||||
|
|
||||||
srcConnID protocol.ConnectionID
|
srcConnID protocol.ConnectionID
|
||||||
destConnID protocol.ConnectionID
|
destConnID protocol.ConnectionID
|
||||||
|
|
||||||
initialPacketNumber protocol.PacketNumber
|
initialPacketNumber protocol.PacketNumber
|
||||||
|
hasNegotiatedVersion bool
|
||||||
initialVersion protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
version protocol.VersionNumber
|
|
||||||
|
|
||||||
handshakeChan chan struct{}
|
handshakeChan chan struct{}
|
||||||
|
|
||||||
|
@ -268,8 +262,9 @@ func (c *client) dial(ctx context.Context) error {
|
||||||
c.config,
|
c.config,
|
||||||
c.tlsConf,
|
c.tlsConf,
|
||||||
c.initialPacketNumber,
|
c.initialPacketNumber,
|
||||||
c.initialVersion,
|
c.version,
|
||||||
c.use0RTT,
|
c.use0RTT,
|
||||||
|
c.hasNegotiatedVersion,
|
||||||
c.qlogger,
|
c.qlogger,
|
||||||
c.logger,
|
c.logger,
|
||||||
c.version,
|
c.version,
|
||||||
|
@ -280,7 +275,7 @@ func (c *client) dial(ctx context.Context) error {
|
||||||
errorChan := make(chan error, 1)
|
errorChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err := c.session.run() // returns as soon as the session is closed
|
err := c.session.run() // returns as soon as the session is closed
|
||||||
if err != errCloseForRecreating && c.createdPacketConn {
|
if !errors.Is(err, errCloseForRecreating{}) && c.createdPacketConn {
|
||||||
c.packetHandlers.Destroy()
|
c.packetHandlers.Destroy()
|
||||||
}
|
}
|
||||||
errorChan <- err
|
errorChan <- err
|
||||||
|
@ -298,7 +293,11 @@ func (c *client) dial(ctx context.Context) error {
|
||||||
c.session.shutdown()
|
c.session.shutdown()
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case err := <-errorChan:
|
case err := <-errorChan:
|
||||||
if err == errCloseForRecreating {
|
var recreateErr *errCloseForRecreating
|
||||||
|
if errors.As(err, &recreateErr) {
|
||||||
|
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||||
|
c.version = recreateErr.nextVersion
|
||||||
|
c.hasNegotiatedVersion = true
|
||||||
return c.dial(ctx)
|
return c.dial(ctx)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -312,75 +311,9 @@ func (c *client) dial(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handlePacket(p *receivedPacket) {
|
func (c *client) handlePacket(p *receivedPacket) {
|
||||||
if wire.IsVersionNegotiationPacket(p.data) {
|
|
||||||
go c.handleVersionNegotiationPacket(p)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// this is the first packet we are receiving
|
|
||||||
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
|
||||||
if !c.versionNegotiated.Get() {
|
|
||||||
c.versionNegotiated.Set(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.session.handlePacket(p)
|
c.session.handlePacket(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handleVersionNegotiationPacket(p *receivedPacket) {
|
|
||||||
c.mutex.Lock()
|
|
||||||
defer c.mutex.Unlock()
|
|
||||||
|
|
||||||
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
|
||||||
if err != nil {
|
|
||||||
if c.qlogger != nil {
|
|
||||||
c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError)
|
|
||||||
}
|
|
||||||
c.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// ignore delayed / duplicated version negotiation packets
|
|
||||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
|
|
||||||
if c.qlogger != nil {
|
|
||||||
c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket)
|
|
||||||
}
|
|
||||||
c.logger.Debugf("Received a delayed Version Negotiation packet.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range hdr.SupportedVersions {
|
|
||||||
if v == c.version {
|
|
||||||
if c.qlogger != nil {
|
|
||||||
c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion)
|
|
||||||
}
|
|
||||||
// The Version Negotiation packet contains the version that we offered.
|
|
||||||
// This might be a packet sent by an attacker (or by a terribly broken server implementation).
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
|
|
||||||
if c.qlogger != nil {
|
|
||||||
c.qlogger.ReceivedVersionNegotiationPacket(hdr)
|
|
||||||
}
|
|
||||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
|
||||||
if !ok {
|
|
||||||
//nolint:stylecheck
|
|
||||||
c.session.destroy(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s", c.config.Versions, hdr.SupportedVersions))
|
|
||||||
c.logger.Debugf("No compatible QUIC version found.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.receivedVersionNegotiationPacket = true
|
|
||||||
c.negotiatedVersions = hdr.SupportedVersions
|
|
||||||
|
|
||||||
// switch to negotiated version
|
|
||||||
c.initialVersion = c.version
|
|
||||||
c.version = newVersion
|
|
||||||
|
|
||||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
|
||||||
c.initialPacketNumber = c.session.closeForRecreating()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) shutdown() {
|
func (c *client) shutdown() {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
241
client_test.go
241
client_test.go
|
@ -47,6 +47,7 @@ var _ = Describe("Client", func() {
|
||||||
initialPacketNumber protocol.PacketNumber,
|
initialPacketNumber protocol.PacketNumber,
|
||||||
initialVersion protocol.VersionNumber,
|
initialVersion protocol.VersionNumber,
|
||||||
enable0RTT bool,
|
enable0RTT bool,
|
||||||
|
hasNegotiatedVersion bool,
|
||||||
qlogger qlog.Tracer,
|
qlogger qlog.Tracer,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
v protocol.VersionNumber,
|
v protocol.VersionNumber,
|
||||||
|
@ -65,16 +66,6 @@ var _ = Describe("Client", func() {
|
||||||
return b.Bytes()
|
return b.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket {
|
|
||||||
data, err := wire.ComposeVersionNegotiation(connID, nil, versions)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue())
|
|
||||||
return &receivedPacket{
|
|
||||||
rcvTime: time.Now(),
|
|
||||||
data: data,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
|
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
|
||||||
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
||||||
|
@ -169,6 +160,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.PacketNumber,
|
_ protocol.PacketNumber,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ qlog.Tracer,
|
_ qlog.Tracer,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
|
@ -201,6 +193,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.PacketNumber,
|
_ protocol.PacketNumber,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ qlog.Tracer,
|
_ qlog.Tracer,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
|
@ -233,6 +226,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.PacketNumber,
|
_ protocol.PacketNumber,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ qlog.Tracer,
|
_ qlog.Tracer,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
|
@ -271,6 +265,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.PacketNumber,
|
_ protocol.PacketNumber,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
enable0RTT bool,
|
enable0RTT bool,
|
||||||
|
_ bool,
|
||||||
_ qlog.Tracer,
|
_ qlog.Tracer,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
|
@ -313,6 +308,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.PacketNumber,
|
_ protocol.PacketNumber,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
enable0RTT bool,
|
enable0RTT bool,
|
||||||
|
_ bool,
|
||||||
_ qlog.Tracer,
|
_ qlog.Tracer,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
|
@ -360,6 +356,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.PacketNumber,
|
_ protocol.PacketNumber,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ qlog.Tracer,
|
_ qlog.Tracer,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
|
@ -403,6 +400,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.PacketNumber,
|
_ protocol.PacketNumber,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ qlog.Tracer,
|
_ qlog.Tracer,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
|
@ -454,6 +452,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.PacketNumber,
|
_ protocol.PacketNumber,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ qlog.Tracer,
|
_ qlog.Tracer,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
|
@ -574,6 +573,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.PacketNumber,
|
_ protocol.PacketNumber,
|
||||||
_ protocol.VersionNumber, /* initial version */
|
_ protocol.VersionNumber, /* initial version */
|
||||||
_ bool,
|
_ bool,
|
||||||
|
_ bool,
|
||||||
_ qlog.Tracer,
|
_ qlog.Tracer,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
versionP protocol.VersionNumber,
|
versionP protocol.VersionNumber,
|
||||||
|
@ -596,183 +596,58 @@ var _ = Describe("Client", func() {
|
||||||
Expect(conf.Versions).To(Equal(config.Versions))
|
Expect(conf.Versions).To(Equal(config.Versions))
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("version negotiation", func() {
|
It("creates a new session after version negotiation", func() {
|
||||||
var origSupportedVersions []protocol.VersionNumber
|
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
|
||||||
|
manager.EXPECT().Destroy()
|
||||||
|
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||||
|
|
||||||
BeforeEach(func() {
|
initialVersion := cl.version
|
||||||
origSupportedVersions = protocol.SupportedVersions
|
|
||||||
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{77, 78}...)
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
var counter int
|
||||||
protocol.SupportedVersions = origSupportedVersions
|
newClientSession = func(
|
||||||
})
|
_ connection,
|
||||||
|
_ sessionRunner,
|
||||||
It("returns an error that occurs during version negotiation", func() {
|
_ protocol.ConnectionID,
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
_ protocol.ConnectionID,
|
||||||
manager.EXPECT().Add(connID, gomock.Any())
|
configP *Config,
|
||||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
|
_ *tls.Config,
|
||||||
|
pn protocol.PacketNumber,
|
||||||
testErr := errors.New("early handshake error")
|
version protocol.VersionNumber,
|
||||||
newClientSession = func(
|
_ bool,
|
||||||
conn connection,
|
hasNegotiatedVersion bool,
|
||||||
_ sessionRunner,
|
_ qlog.Tracer,
|
||||||
_ protocol.ConnectionID,
|
_ utils.Logger,
|
||||||
_ protocol.ConnectionID,
|
versionP protocol.VersionNumber,
|
||||||
_ *Config,
|
) quicSession {
|
||||||
_ *tls.Config,
|
sess := NewMockQuicSession(mockCtrl)
|
||||||
_ protocol.PacketNumber,
|
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
||||||
_ protocol.VersionNumber,
|
if counter == 0 {
|
||||||
_ bool,
|
Expect(pn).To(BeZero())
|
||||||
_ qlog.Tracer,
|
Expect(version).To(Equal(initialVersion))
|
||||||
_ utils.Logger,
|
Expect(hasNegotiatedVersion).To(BeFalse())
|
||||||
_ protocol.VersionNumber,
|
sess.EXPECT().run().Return(&errCloseForRecreating{
|
||||||
) quicSession {
|
nextPacketNumber: 109,
|
||||||
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
|
nextVersion: 789,
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
})
|
||||||
sess.EXPECT().run().Return(testErr)
|
} else {
|
||||||
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
Expect(pn).To(Equal(protocol.PacketNumber(109)))
|
||||||
return sess
|
Expect(version).ToNot(Equal(initialVersion))
|
||||||
|
Expect(version).To(Equal(protocol.VersionNumber(789)))
|
||||||
|
Expect(hasNegotiatedVersion).To(BeTrue())
|
||||||
|
sess.EXPECT().run()
|
||||||
}
|
}
|
||||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
|
counter++
|
||||||
_, err := Dial(
|
return sess
|
||||||
packetConn,
|
}
|
||||||
addr,
|
|
||||||
"localhost:1337",
|
|
||||||
tlsConf,
|
|
||||||
config,
|
|
||||||
)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes that a non Version Negotiation packet means that the server accepted the suggested version", func() {
|
gomock.InOrder(
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), initialVersion, gomock.Any(), gomock.Any()),
|
||||||
sess.EXPECT().handlePacket(gomock.Any())
|
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionNumber(789), gomock.Any(), gomock.Any()),
|
||||||
cl.session = sess
|
)
|
||||||
cl.config = config
|
_, err := DialAddr("localhost:7890", tlsConf, config)
|
||||||
buf := &bytes.Buffer{}
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect((&wire.ExtendedHeader{
|
Expect(counter).To(Equal(2))
|
||||||
Header: wire.Header{
|
|
||||||
DestConnectionID: connID,
|
|
||||||
SrcConnectionID: connID,
|
|
||||||
Version: cl.version,
|
|
||||||
},
|
|
||||||
PacketNumberLen: protocol.PacketNumberLen3,
|
|
||||||
}).Write(buf, protocol.VersionTLS)).To(Succeed())
|
|
||||||
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
|
|
||||||
Eventually(cl.versionNegotiated.Get).Should(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
// Illustrates that adversary that injects a version negotiation packet
|
|
||||||
// with no supported versions can break a connection.
|
|
||||||
It("errors if no matching version is found", func() {
|
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
|
||||||
done := make(chan struct{})
|
|
||||||
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
|
|
||||||
close(done)
|
|
||||||
})
|
|
||||||
cl.session = sess
|
|
||||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
|
||||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1337})
|
|
||||||
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(hdr)
|
|
||||||
cl.handlePacket(p)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
|
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
|
||||||
done := make(chan struct{})
|
|
||||||
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
|
|
||||||
close(done)
|
|
||||||
})
|
|
||||||
cl.session = sess
|
|
||||||
v := protocol.VersionNumber(1234)
|
|
||||||
Expect(v).ToNot(Equal(cl.version))
|
|
||||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
|
||||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
|
|
||||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v}))
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("changes to the version preferred by the quic.Config", func() {
|
|
||||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
|
||||||
cl.packetHandlers = phm
|
|
||||||
|
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
|
||||||
destroyed := make(chan struct{})
|
|
||||||
sess.EXPECT().closeForRecreating().Do(func() {
|
|
||||||
close(destroyed)
|
|
||||||
})
|
|
||||||
cl.session = sess
|
|
||||||
versions := []protocol.VersionNumber{1234, 4321}
|
|
||||||
cl.config = &Config{Versions: versions}
|
|
||||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
|
|
||||||
cl.handlePacket(composeVersionNegotiationPacket(connID, versions))
|
|
||||||
Eventually(destroyed).Should(BeClosed())
|
|
||||||
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("drops unparseable version negotiation packets", func() {
|
|
||||||
cl.config = config
|
|
||||||
ver := cl.version
|
|
||||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
|
|
||||||
p.data = p.data[:len(p.data)-1]
|
|
||||||
done := make(chan struct{})
|
|
||||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
|
|
||||||
close(done)
|
|
||||||
})
|
|
||||||
cl.handlePacket(p)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(cl.version).To(Equal(ver))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("drops version negotiation packets if any other packet was received before", func() {
|
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
|
||||||
sess.EXPECT().handlePacket(gomock.Any())
|
|
||||||
cl.session = sess
|
|
||||||
cl.config = config
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
Expect((&wire.ExtendedHeader{
|
|
||||||
Header: wire.Header{
|
|
||||||
DestConnectionID: connID,
|
|
||||||
SrcConnectionID: connID,
|
|
||||||
Version: cl.version,
|
|
||||||
},
|
|
||||||
PacketNumberLen: protocol.PacketNumberLen3,
|
|
||||||
}).Write(buf, protocol.VersionTLS)).To(Succeed())
|
|
||||||
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
|
|
||||||
|
|
||||||
ver := cl.version
|
|
||||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234})
|
|
||||||
done := make(chan struct{})
|
|
||||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
|
|
||||||
close(done)
|
|
||||||
})
|
|
||||||
cl.handlePacket(p)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(cl.version).To(Equal(ver))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("drops version negotiation packets that contain the offered version", func() {
|
|
||||||
cl.config = config
|
|
||||||
ver := cl.version
|
|
||||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
|
|
||||||
done := make(chan struct{})
|
|
||||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
|
|
||||||
close(done)
|
|
||||||
})
|
|
||||||
cl.handlePacket(p)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(cl.version).To(Equal(ver))
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -225,20 +225,6 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr))
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeForRecreating mocks base method
|
|
||||||
func (m *MockQuicSession) closeForRecreating() protocol.PacketNumber {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "closeForRecreating")
|
|
||||||
ret0, _ := ret[0].(protocol.PacketNumber)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeForRecreating indicates an expected call of closeForRecreating
|
|
||||||
func (mr *MockQuicSessionMockRecorder) closeForRecreating() *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForRecreating", reflect.TypeOf((*MockQuicSession)(nil).closeForRecreating))
|
|
||||||
}
|
|
||||||
|
|
||||||
// destroy mocks base method
|
// destroy mocks base method
|
||||||
func (m *MockQuicSession) destroy(arg0 error) {
|
func (m *MockQuicSession) destroy(arg0 error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -49,7 +49,6 @@ type quicSession interface {
|
||||||
run() error
|
run() error
|
||||||
destroy(error)
|
destroy(error)
|
||||||
shutdown()
|
shutdown()
|
||||||
closeForRecreating() protocol.PacketNumber
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// A Listener of QUIC
|
// A Listener of QUIC
|
||||||
|
|
81
session.go
81
session.go
|
@ -104,7 +104,19 @@ type closeError struct {
|
||||||
immediate bool
|
immediate bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var errCloseForRecreating = errors.New("closing session in order to recreate it")
|
type errCloseForRecreating struct {
|
||||||
|
nextPacketNumber protocol.PacketNumber
|
||||||
|
nextVersion protocol.VersionNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
func (errCloseForRecreating) Error() string {
|
||||||
|
return "closing session in order to recreate it"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (errCloseForRecreating) Is(target error) bool {
|
||||||
|
_, ok := target.(errCloseForRecreating)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
// A Session is a QUIC session
|
// A Session is a QUIC session
|
||||||
type session struct {
|
type session struct {
|
||||||
|
@ -169,6 +181,7 @@ type session struct {
|
||||||
handshakeConfirmed bool
|
handshakeConfirmed bool
|
||||||
|
|
||||||
receivedRetry bool
|
receivedRetry bool
|
||||||
|
versionNegotiated bool
|
||||||
receivedFirstPacket bool
|
receivedFirstPacket bool
|
||||||
|
|
||||||
idleTimeout time.Duration
|
idleTimeout time.Duration
|
||||||
|
@ -336,6 +349,7 @@ var newClientSession = func(
|
||||||
initialPacketNumber protocol.PacketNumber,
|
initialPacketNumber protocol.PacketNumber,
|
||||||
initialVersion protocol.VersionNumber,
|
initialVersion protocol.VersionNumber,
|
||||||
enable0RTT bool,
|
enable0RTT bool,
|
||||||
|
hasNegotiatedVersion bool,
|
||||||
qlogger qlog.Tracer,
|
qlogger qlog.Tracer,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
v protocol.VersionNumber,
|
v protocol.VersionNumber,
|
||||||
|
@ -352,6 +366,7 @@ var newClientSession = func(
|
||||||
logger: logger,
|
logger: logger,
|
||||||
qlogger: qlogger,
|
qlogger: qlogger,
|
||||||
initialVersion: initialVersion,
|
initialVersion: initialVersion,
|
||||||
|
versionNegotiated: hasNegotiatedVersion,
|
||||||
version: v,
|
version: v,
|
||||||
}
|
}
|
||||||
s.connIDManager = newConnIDManager(
|
s.connIDManager = newConnIDManager(
|
||||||
|
@ -595,7 +610,7 @@ runLoop:
|
||||||
}
|
}
|
||||||
|
|
||||||
s.handleCloseError(closeErr)
|
s.handleCloseError(closeErr)
|
||||||
if closeErr.err != errCloseForRecreating && s.qlogger != nil {
|
if !errors.Is(closeErr.err, errCloseForRecreating{}) && s.qlogger != nil {
|
||||||
if err := s.qlogger.Export(); err != nil {
|
if err := s.qlogger.Export(); err != nil {
|
||||||
s.logger.Errorf("exporting qlog failed: %s", err)
|
s.logger.Errorf("exporting qlog failed: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -692,6 +707,11 @@ func (s *session) handleHandshakeComplete() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handlePacketImpl(rp *receivedPacket) bool {
|
func (s *session) handlePacketImpl(rp *receivedPacket) bool {
|
||||||
|
if wire.IsVersionNegotiationPacket(rp.data) {
|
||||||
|
s.handleVersionNegotiationPacket(rp)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
var counter uint8
|
var counter uint8
|
||||||
var lastConnID protocol.ConnectionID
|
var lastConnID protocol.ConnectionID
|
||||||
var processed bool
|
var processed bool
|
||||||
|
@ -888,6 +908,55 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *session) handleVersionNegotiationPacket(p *receivedPacket) {
|
||||||
|
if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
|
||||||
|
s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets
|
||||||
|
if s.qlogger != nil {
|
||||||
|
s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
||||||
|
if err != nil {
|
||||||
|
if s.qlogger != nil {
|
||||||
|
s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError)
|
||||||
|
}
|
||||||
|
s.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range hdr.SupportedVersions {
|
||||||
|
if v == s.version {
|
||||||
|
if s.qlogger != nil {
|
||||||
|
s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion)
|
||||||
|
}
|
||||||
|
// The Version Negotiation packet contains the version that we offered.
|
||||||
|
// This might be a packet sent by an attacker, or it was corrupted.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||||
|
if s.qlogger != nil {
|
||||||
|
s.qlogger.ReceivedVersionNegotiationPacket(hdr)
|
||||||
|
}
|
||||||
|
newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, hdr.SupportedVersions)
|
||||||
|
if !ok {
|
||||||
|
//nolint:stylecheck
|
||||||
|
s.destroyImpl(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s.", s.config.Versions, hdr.SupportedVersions))
|
||||||
|
s.logger.Infof("No compatible QUIC version found.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Infof("Switching to QUIC version %s.", newVersion)
|
||||||
|
nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
|
||||||
|
s.destroyImpl(&errCloseForRecreating{
|
||||||
|
nextPacketNumber: nextPN,
|
||||||
|
nextVersion: newVersion,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *session) handleUnpackedPacket(
|
func (s *session) handleUnpackedPacket(
|
||||||
packet *unpackedPacket,
|
packet *unpackedPacket,
|
||||||
rcvTime time.Time,
|
rcvTime time.Time,
|
||||||
|
@ -1190,14 +1259,6 @@ func (s *session) destroyImpl(e error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeForRecreating closes the session in order to recreate it immediately afterwards
|
|
||||||
// It returns the first packet number that should be used in the new session.
|
|
||||||
func (s *session) closeForRecreating() protocol.PacketNumber {
|
|
||||||
s.destroy(errCloseForRecreating)
|
|
||||||
nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
|
|
||||||
return nextPN
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) closeRemote(e error) {
|
func (s *session) closeRemote(e error) {
|
||||||
s.closeOnce.Do(func() {
|
s.closeOnce.Do(func() {
|
||||||
s.logger.Errorf("Peer closed session with error: %s", e)
|
s.logger.Errorf("Peer closed session with error: %s", e)
|
||||||
|
|
|
@ -487,18 +487,6 @@ var _ = Describe("Session", func() {
|
||||||
Expect(sess.Context().Done()).To(BeClosed())
|
Expect(sess.Context().Done()).To(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("closes the session in order to recreate it", func() {
|
|
||||||
runSession()
|
|
||||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
|
||||||
sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
|
||||||
cryptoSetup.EXPECT().Close()
|
|
||||||
// don't EXPECT any calls to mconn.Write()
|
|
||||||
// don't EXPECT any call to qlogger.Export()
|
|
||||||
sess.closeForRecreating()
|
|
||||||
Eventually(areSessionsRunning).Should(BeFalse())
|
|
||||||
expectedRunErr = errCloseForRecreating
|
|
||||||
})
|
|
||||||
|
|
||||||
It("destroys the session", func() {
|
It("destroys the session", func() {
|
||||||
runSession()
|
runSession()
|
||||||
testErr := errors.New("close")
|
testErr := errors.New("close")
|
||||||
|
@ -603,6 +591,16 @@ var _ = Describe("Session", func() {
|
||||||
Expect(sess.handlePacketImpl(p)).To(BeFalse())
|
Expect(sess.handlePacketImpl(p)).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("drops Version Negotiation packets", func() {
|
||||||
|
b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, sess.config.Versions)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), qlog.PacketDropUnexpectedPacket)
|
||||||
|
Expect(sess.handlePacketImpl(&receivedPacket{
|
||||||
|
data: b,
|
||||||
|
buffer: getPacketBuffer(),
|
||||||
|
})).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
It("drops packets for which header decryption fails", func() {
|
It("drops packets for which header decryption fails", func() {
|
||||||
p := getPacket(&wire.ExtendedHeader{
|
p := getPacket(&wire.ExtendedHeader{
|
||||||
Header: wire.Header{
|
Header: wire.Header{
|
||||||
|
@ -2035,6 +2033,7 @@ var _ = Describe("Client Session", func() {
|
||||||
42, // initial packet number
|
42, // initial packet number
|
||||||
protocol.VersionTLS,
|
protocol.VersionTLS,
|
||||||
false,
|
false,
|
||||||
|
false,
|
||||||
qlogger,
|
qlogger,
|
||||||
utils.DefaultLogger,
|
utils.DefaultLogger,
|
||||||
protocol.VersionTLS,
|
protocol.VersionTLS,
|
||||||
|
@ -2133,6 +2132,81 @@ var _ = Describe("Client Session", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("handling Version Negotiation", func() {
|
||||||
|
getVNP := func(versions ...protocol.VersionNumber) *receivedPacket {
|
||||||
|
b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return &receivedPacket{
|
||||||
|
data: b,
|
||||||
|
buffer: getPacketBuffer(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
It("closes and returns the right error", func() {
|
||||||
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
||||||
|
sess.sentPacketHandler = sph
|
||||||
|
sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4)
|
||||||
|
sess.config.Versions = []protocol.VersionNumber{1234, 4321}
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||||
|
errChan <- sess.run()
|
||||||
|
}()
|
||||||
|
sessionRunner.EXPECT().Remove(srcConnID)
|
||||||
|
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()).Do(func(hdr *wire.Header) {
|
||||||
|
Expect(hdr.Version).To(BeZero())
|
||||||
|
Expect(hdr.SupportedVersions).To(And(
|
||||||
|
ContainElement(protocol.VersionNumber(4321)),
|
||||||
|
ContainElement(protocol.VersionNumber(1337)),
|
||||||
|
))
|
||||||
|
})
|
||||||
|
cryptoSetup.EXPECT().Close()
|
||||||
|
Expect(sess.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse())
|
||||||
|
var err error
|
||||||
|
Eventually(errChan).Should(Receive(&err))
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err).To(BeAssignableToTypeOf(&errCloseForRecreating{}))
|
||||||
|
recreateErr := err.(*errCloseForRecreating)
|
||||||
|
Expect(recreateErr.nextVersion).To(Equal(protocol.VersionNumber(4321)))
|
||||||
|
Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("it closes when no matching version is found", func() {
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||||
|
errChan <- sess.run()
|
||||||
|
}()
|
||||||
|
sessionRunner.EXPECT().Remove(srcConnID).MaxTimes(1)
|
||||||
|
gomock.InOrder(
|
||||||
|
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()),
|
||||||
|
qlogger.EXPECT().Export(),
|
||||||
|
)
|
||||||
|
cryptoSetup.EXPECT().Close()
|
||||||
|
Expect(sess.handlePacketImpl(getVNP(12345678))).To(BeFalse())
|
||||||
|
var err error
|
||||||
|
Eventually(errChan).Should(Receive(&err))
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err).ToNot(BeAssignableToTypeOf(&errCloseForRecreating{}))
|
||||||
|
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("ignores Version Negotiation packets that offer the current version", func() {
|
||||||
|
p := getVNP(sess.version)
|
||||||
|
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion)
|
||||||
|
Expect(sess.handlePacketImpl(p)).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("ignores unparseable Version Negotiation packets", func() {
|
||||||
|
p := getVNP(sess.version)
|
||||||
|
p.data = p.data[:len(p.data)-2]
|
||||||
|
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError)
|
||||||
|
Expect(sess.handlePacketImpl(p)).To(BeFalse())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("handling Retry", func() {
|
Context("handling Retry", func() {
|
||||||
origDestConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
origDestConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue