mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
handle Version Negotiation packets in the session
This commit is contained in:
parent
6b42c7a045
commit
06ad477b9b
6 changed files with 226 additions and 298 deletions
89
client.go
89
client.go
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qlog"
|
||||
)
|
||||
|
||||
|
@ -27,20 +26,15 @@ type client struct {
|
|||
|
||||
packetHandlers packetHandlerManager
|
||||
|
||||
versionNegotiated utils.AtomicBool // has the server accepted our version
|
||||
receivedVersionNegotiationPacket bool
|
||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
initialPacketNumber protocol.PacketNumber
|
||||
|
||||
initialVersion protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
initialPacketNumber protocol.PacketNumber
|
||||
hasNegotiatedVersion bool
|
||||
version protocol.VersionNumber
|
||||
|
||||
handshakeChan chan struct{}
|
||||
|
||||
|
@ -268,8 +262,9 @@ func (c *client) dial(ctx context.Context) error {
|
|||
c.config,
|
||||
c.tlsConf,
|
||||
c.initialPacketNumber,
|
||||
c.initialVersion,
|
||||
c.version,
|
||||
c.use0RTT,
|
||||
c.hasNegotiatedVersion,
|
||||
c.qlogger,
|
||||
c.logger,
|
||||
c.version,
|
||||
|
@ -280,7 +275,7 @@ func (c *client) dial(ctx context.Context) error {
|
|||
errorChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := c.session.run() // returns as soon as the session is closed
|
||||
if err != errCloseForRecreating && c.createdPacketConn {
|
||||
if !errors.Is(err, errCloseForRecreating{}) && c.createdPacketConn {
|
||||
c.packetHandlers.Destroy()
|
||||
}
|
||||
errorChan <- err
|
||||
|
@ -298,7 +293,11 @@ func (c *client) dial(ctx context.Context) error {
|
|||
c.session.shutdown()
|
||||
return ctx.Err()
|
||||
case err := <-errorChan:
|
||||
if err == errCloseForRecreating {
|
||||
var recreateErr *errCloseForRecreating
|
||||
if errors.As(err, &recreateErr) {
|
||||
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||
c.version = recreateErr.nextVersion
|
||||
c.hasNegotiatedVersion = true
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
|
@ -312,75 +311,9 @@ func (c *client) dial(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (c *client) handlePacket(p *receivedPacket) {
|
||||
if wire.IsVersionNegotiationPacket(p.data) {
|
||||
go c.handleVersionNegotiationPacket(p)
|
||||
return
|
||||
}
|
||||
|
||||
// this is the first packet we are receiving
|
||||
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||
if !c.versionNegotiated.Get() {
|
||||
c.versionNegotiated.Set(true)
|
||||
}
|
||||
|
||||
c.session.handlePacket(p)
|
||||
}
|
||||
|
||||
func (c *client) handleVersionNegotiationPacket(p *receivedPacket) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
||||
if err != nil {
|
||||
if c.qlogger != nil {
|
||||
c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError)
|
||||
}
|
||||
c.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
|
||||
if c.qlogger != nil {
|
||||
c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket)
|
||||
}
|
||||
c.logger.Debugf("Received a delayed Version Negotiation packet.")
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range hdr.SupportedVersions {
|
||||
if v == c.version {
|
||||
if c.qlogger != nil {
|
||||
c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion)
|
||||
}
|
||||
// The Version Negotiation packet contains the version that we offered.
|
||||
// This might be a packet sent by an attacker (or by a terribly broken server implementation).
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||
if c.qlogger != nil {
|
||||
c.qlogger.ReceivedVersionNegotiationPacket(hdr)
|
||||
}
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
//nolint:stylecheck
|
||||
c.session.destroy(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s", c.config.Versions, hdr.SupportedVersions))
|
||||
c.logger.Debugf("No compatible QUIC version found.")
|
||||
return
|
||||
}
|
||||
c.receivedVersionNegotiationPacket = true
|
||||
c.negotiatedVersions = hdr.SupportedVersions
|
||||
|
||||
// switch to negotiated version
|
||||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||
c.initialPacketNumber = c.session.closeForRecreating()
|
||||
}
|
||||
|
||||
func (c *client) shutdown() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
|
241
client_test.go
241
client_test.go
|
@ -47,6 +47,7 @@ var _ = Describe("Client", func() {
|
|||
initialPacketNumber protocol.PacketNumber,
|
||||
initialVersion protocol.VersionNumber,
|
||||
enable0RTT bool,
|
||||
hasNegotiatedVersion bool,
|
||||
qlogger qlog.Tracer,
|
||||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
|
@ -65,16 +66,6 @@ var _ = Describe("Client", func() {
|
|||
return b.Bytes()
|
||||
}
|
||||
|
||||
composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket {
|
||||
data, err := wire.ComposeVersionNegotiation(connID, nil, versions)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue())
|
||||
return &receivedPacket{
|
||||
rcvTime: time.Now(),
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
|
||||
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
||||
|
@ -169,6 +160,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -201,6 +193,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -233,6 +226,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -271,6 +265,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
enable0RTT bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -313,6 +308,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
enable0RTT bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -360,6 +356,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -403,6 +400,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -454,6 +452,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -574,6 +573,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber, /* initial version */
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
versionP protocol.VersionNumber,
|
||||
|
@ -596,183 +596,58 @@ var _ = Describe("Client", func() {
|
|||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
})
|
||||
|
||||
Context("version negotiation", func() {
|
||||
var origSupportedVersions []protocol.VersionNumber
|
||||
It("creates a new session after version negotiation", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
|
||||
manager.EXPECT().Destroy()
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
BeforeEach(func() {
|
||||
origSupportedVersions = protocol.SupportedVersions
|
||||
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{77, 78}...)
|
||||
})
|
||||
initialVersion := cl.version
|
||||
|
||||
AfterEach(func() {
|
||||
protocol.SupportedVersions = origSupportedVersions
|
||||
})
|
||||
|
||||
It("returns an error that occurs during version negotiation", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
testErr := errors.New("early handshake error")
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicSession {
|
||||
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run().Return(testErr)
|
||||
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
||||
return sess
|
||||
var counter int
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
pn protocol.PacketNumber,
|
||||
version protocol.VersionNumber,
|
||||
_ bool,
|
||||
hasNegotiatedVersion bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
versionP protocol.VersionNumber,
|
||||
) quicSession {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
||||
if counter == 0 {
|
||||
Expect(pn).To(BeZero())
|
||||
Expect(version).To(Equal(initialVersion))
|
||||
Expect(hasNegotiatedVersion).To(BeFalse())
|
||||
sess.EXPECT().run().Return(&errCloseForRecreating{
|
||||
nextPacketNumber: 109,
|
||||
nextVersion: 789,
|
||||
})
|
||||
} else {
|
||||
Expect(pn).To(Equal(protocol.PacketNumber(109)))
|
||||
Expect(version).ToNot(Equal(initialVersion))
|
||||
Expect(version).To(Equal(protocol.VersionNumber(789)))
|
||||
Expect(hasNegotiatedVersion).To(BeTrue())
|
||||
sess.EXPECT().run()
|
||||
}
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
|
||||
_, err := Dial(
|
||||
packetConn,
|
||||
addr,
|
||||
"localhost:1337",
|
||||
tlsConf,
|
||||
config,
|
||||
)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
counter++
|
||||
return sess
|
||||
}
|
||||
|
||||
It("recognizes that a non Version Negotiation packet means that the server accepted the suggested version", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
cl.session = sess
|
||||
cl.config = config
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
Version: cl.version,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
}).Write(buf, protocol.VersionTLS)).To(Succeed())
|
||||
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
|
||||
Eventually(cl.versionNegotiated.Get).Should(BeTrue())
|
||||
})
|
||||
|
||||
// Illustrates that adversary that injects a version negotiation packet
|
||||
// with no supported versions can break a connection.
|
||||
It("errors if no matching version is found", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
done := make(chan struct{})
|
||||
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
|
||||
close(done)
|
||||
})
|
||||
cl.session = sess
|
||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1337})
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(hdr)
|
||||
cl.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
done := make(chan struct{})
|
||||
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
|
||||
close(done)
|
||||
})
|
||||
cl.session = sess
|
||||
v := protocol.VersionNumber(1234)
|
||||
Expect(v).ToNot(Equal(cl.version))
|
||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v}))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("changes to the version preferred by the quic.Config", func() {
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
cl.packetHandlers = phm
|
||||
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
destroyed := make(chan struct{})
|
||||
sess.EXPECT().closeForRecreating().Do(func() {
|
||||
close(destroyed)
|
||||
})
|
||||
cl.session = sess
|
||||
versions := []protocol.VersionNumber{1234, 4321}
|
||||
cl.config = &Config{Versions: versions}
|
||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, versions))
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
|
||||
})
|
||||
|
||||
It("drops unparseable version negotiation packets", func() {
|
||||
cl.config = config
|
||||
ver := cl.version
|
||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
|
||||
p.data = p.data[:len(p.data)-1]
|
||||
done := make(chan struct{})
|
||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
|
||||
close(done)
|
||||
})
|
||||
cl.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(ver))
|
||||
})
|
||||
|
||||
It("drops version negotiation packets if any other packet was received before", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
cl.session = sess
|
||||
cl.config = config
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
Version: cl.version,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
}).Write(buf, protocol.VersionTLS)).To(Succeed())
|
||||
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
|
||||
|
||||
ver := cl.version
|
||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234})
|
||||
done := make(chan struct{})
|
||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
|
||||
close(done)
|
||||
})
|
||||
cl.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(ver))
|
||||
})
|
||||
|
||||
It("drops version negotiation packets that contain the offered version", func() {
|
||||
cl.config = config
|
||||
ver := cl.version
|
||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
|
||||
done := make(chan struct{})
|
||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
|
||||
close(done)
|
||||
})
|
||||
cl.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(ver))
|
||||
})
|
||||
gomock.InOrder(
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), initialVersion, gomock.Any(), gomock.Any()),
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionNumber(789), gomock.Any(), gomock.Any()),
|
||||
)
|
||||
_, err := DialAddr("localhost:7890", tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(counter).To(Equal(2))
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -225,20 +225,6 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr))
|
||||
}
|
||||
|
||||
// closeForRecreating mocks base method
|
||||
func (m *MockQuicSession) closeForRecreating() protocol.PacketNumber {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "closeForRecreating")
|
||||
ret0, _ := ret[0].(protocol.PacketNumber)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// closeForRecreating indicates an expected call of closeForRecreating
|
||||
func (mr *MockQuicSessionMockRecorder) closeForRecreating() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForRecreating", reflect.TypeOf((*MockQuicSession)(nil).closeForRecreating))
|
||||
}
|
||||
|
||||
// destroy mocks base method
|
||||
func (m *MockQuicSession) destroy(arg0 error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -49,7 +49,6 @@ type quicSession interface {
|
|||
run() error
|
||||
destroy(error)
|
||||
shutdown()
|
||||
closeForRecreating() protocol.PacketNumber
|
||||
}
|
||||
|
||||
// A Listener of QUIC
|
||||
|
|
81
session.go
81
session.go
|
@ -104,7 +104,19 @@ type closeError struct {
|
|||
immediate bool
|
||||
}
|
||||
|
||||
var errCloseForRecreating = errors.New("closing session in order to recreate it")
|
||||
type errCloseForRecreating struct {
|
||||
nextPacketNumber protocol.PacketNumber
|
||||
nextVersion protocol.VersionNumber
|
||||
}
|
||||
|
||||
func (errCloseForRecreating) Error() string {
|
||||
return "closing session in order to recreate it"
|
||||
}
|
||||
|
||||
func (errCloseForRecreating) Is(target error) bool {
|
||||
_, ok := target.(errCloseForRecreating)
|
||||
return ok
|
||||
}
|
||||
|
||||
// A Session is a QUIC session
|
||||
type session struct {
|
||||
|
@ -169,6 +181,7 @@ type session struct {
|
|||
handshakeConfirmed bool
|
||||
|
||||
receivedRetry bool
|
||||
versionNegotiated bool
|
||||
receivedFirstPacket bool
|
||||
|
||||
idleTimeout time.Duration
|
||||
|
@ -336,6 +349,7 @@ var newClientSession = func(
|
|||
initialPacketNumber protocol.PacketNumber,
|
||||
initialVersion protocol.VersionNumber,
|
||||
enable0RTT bool,
|
||||
hasNegotiatedVersion bool,
|
||||
qlogger qlog.Tracer,
|
||||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
|
@ -352,6 +366,7 @@ var newClientSession = func(
|
|||
logger: logger,
|
||||
qlogger: qlogger,
|
||||
initialVersion: initialVersion,
|
||||
versionNegotiated: hasNegotiatedVersion,
|
||||
version: v,
|
||||
}
|
||||
s.connIDManager = newConnIDManager(
|
||||
|
@ -595,7 +610,7 @@ runLoop:
|
|||
}
|
||||
|
||||
s.handleCloseError(closeErr)
|
||||
if closeErr.err != errCloseForRecreating && s.qlogger != nil {
|
||||
if !errors.Is(closeErr.err, errCloseForRecreating{}) && s.qlogger != nil {
|
||||
if err := s.qlogger.Export(); err != nil {
|
||||
s.logger.Errorf("exporting qlog failed: %s", err)
|
||||
}
|
||||
|
@ -692,6 +707,11 @@ func (s *session) handleHandshakeComplete() {
|
|||
}
|
||||
|
||||
func (s *session) handlePacketImpl(rp *receivedPacket) bool {
|
||||
if wire.IsVersionNegotiationPacket(rp.data) {
|
||||
s.handleVersionNegotiationPacket(rp)
|
||||
return false
|
||||
}
|
||||
|
||||
var counter uint8
|
||||
var lastConnID protocol.ConnectionID
|
||||
var processed bool
|
||||
|
@ -888,6 +908,55 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t
|
|||
return true
|
||||
}
|
||||
|
||||
func (s *session) handleVersionNegotiationPacket(p *receivedPacket) {
|
||||
if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
|
||||
s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets
|
||||
if s.qlogger != nil {
|
||||
s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
||||
if err != nil {
|
||||
if s.qlogger != nil {
|
||||
s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError)
|
||||
}
|
||||
s.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range hdr.SupportedVersions {
|
||||
if v == s.version {
|
||||
if s.qlogger != nil {
|
||||
s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion)
|
||||
}
|
||||
// The Version Negotiation packet contains the version that we offered.
|
||||
// This might be a packet sent by an attacker, or it was corrupted.
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||
if s.qlogger != nil {
|
||||
s.qlogger.ReceivedVersionNegotiationPacket(hdr)
|
||||
}
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
//nolint:stylecheck
|
||||
s.destroyImpl(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s.", s.config.Versions, hdr.SupportedVersions))
|
||||
s.logger.Infof("No compatible QUIC version found.")
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Infof("Switching to QUIC version %s.", newVersion)
|
||||
nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
|
||||
s.destroyImpl(&errCloseForRecreating{
|
||||
nextPacketNumber: nextPN,
|
||||
nextVersion: newVersion,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *session) handleUnpackedPacket(
|
||||
packet *unpackedPacket,
|
||||
rcvTime time.Time,
|
||||
|
@ -1190,14 +1259,6 @@ func (s *session) destroyImpl(e error) {
|
|||
})
|
||||
}
|
||||
|
||||
// closeForRecreating closes the session in order to recreate it immediately afterwards
|
||||
// It returns the first packet number that should be used in the new session.
|
||||
func (s *session) closeForRecreating() protocol.PacketNumber {
|
||||
s.destroy(errCloseForRecreating)
|
||||
nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
|
||||
return nextPN
|
||||
}
|
||||
|
||||
func (s *session) closeRemote(e error) {
|
||||
s.closeOnce.Do(func() {
|
||||
s.logger.Errorf("Peer closed session with error: %s", e)
|
||||
|
|
|
@ -487,18 +487,6 @@ var _ = Describe("Session", func() {
|
|||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the session in order to recreate it", func() {
|
||||
runSession()
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
|
||||
cryptoSetup.EXPECT().Close()
|
||||
// don't EXPECT any calls to mconn.Write()
|
||||
// don't EXPECT any call to qlogger.Export()
|
||||
sess.closeForRecreating()
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
expectedRunErr = errCloseForRecreating
|
||||
})
|
||||
|
||||
It("destroys the session", func() {
|
||||
runSession()
|
||||
testErr := errors.New("close")
|
||||
|
@ -603,6 +591,16 @@ var _ = Describe("Session", func() {
|
|||
Expect(sess.handlePacketImpl(p)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("drops Version Negotiation packets", func() {
|
||||
b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, sess.config.Versions)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), qlog.PacketDropUnexpectedPacket)
|
||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
||||
data: b,
|
||||
buffer: getPacketBuffer(),
|
||||
})).To(BeFalse())
|
||||
})
|
||||
|
||||
It("drops packets for which header decryption fails", func() {
|
||||
p := getPacket(&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
|
@ -2035,6 +2033,7 @@ var _ = Describe("Client Session", func() {
|
|||
42, // initial packet number
|
||||
protocol.VersionTLS,
|
||||
false,
|
||||
false,
|
||||
qlogger,
|
||||
utils.DefaultLogger,
|
||||
protocol.VersionTLS,
|
||||
|
@ -2133,6 +2132,81 @@ var _ = Describe("Client Session", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("handling Version Negotiation", func() {
|
||||
getVNP := func(versions ...protocol.VersionNumber) *receivedPacket {
|
||||
b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return &receivedPacket{
|
||||
data: b,
|
||||
buffer: getPacketBuffer(),
|
||||
}
|
||||
}
|
||||
|
||||
It("closes and returns the right error", func() {
|
||||
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
||||
sess.sentPacketHandler = sph
|
||||
sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4)
|
||||
sess.config.Versions = []protocol.VersionNumber{1234, 4321}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
errChan <- sess.run()
|
||||
}()
|
||||
sessionRunner.EXPECT().Remove(srcConnID)
|
||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()).Do(func(hdr *wire.Header) {
|
||||
Expect(hdr.Version).To(BeZero())
|
||||
Expect(hdr.SupportedVersions).To(And(
|
||||
ContainElement(protocol.VersionNumber(4321)),
|
||||
ContainElement(protocol.VersionNumber(1337)),
|
||||
))
|
||||
})
|
||||
cryptoSetup.EXPECT().Close()
|
||||
Expect(sess.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse())
|
||||
var err error
|
||||
Eventually(errChan).Should(Receive(&err))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err).To(BeAssignableToTypeOf(&errCloseForRecreating{}))
|
||||
recreateErr := err.(*errCloseForRecreating)
|
||||
Expect(recreateErr.nextVersion).To(Equal(protocol.VersionNumber(4321)))
|
||||
Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128)))
|
||||
})
|
||||
|
||||
It("it closes when no matching version is found", func() {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
errChan <- sess.run()
|
||||
}()
|
||||
sessionRunner.EXPECT().Remove(srcConnID).MaxTimes(1)
|
||||
gomock.InOrder(
|
||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()),
|
||||
qlogger.EXPECT().Export(),
|
||||
)
|
||||
cryptoSetup.EXPECT().Close()
|
||||
Expect(sess.handlePacketImpl(getVNP(12345678))).To(BeFalse())
|
||||
var err error
|
||||
Eventually(errChan).Should(Receive(&err))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err).ToNot(BeAssignableToTypeOf(&errCloseForRecreating{}))
|
||||
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found"))
|
||||
})
|
||||
|
||||
It("ignores Version Negotiation packets that offer the current version", func() {
|
||||
p := getVNP(sess.version)
|
||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion)
|
||||
Expect(sess.handlePacketImpl(p)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("ignores unparseable Version Negotiation packets", func() {
|
||||
p := getVNP(sess.version)
|
||||
p.data = p.data[:len(p.data)-2]
|
||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError)
|
||||
Expect(sess.handlePacketImpl(p)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("handling Retry", func() {
|
||||
origDestConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue