diff --git a/client.go b/client.go index 5d542c57..1ceb071a 100644 --- a/client.go +++ b/client.go @@ -11,7 +11,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qlog" ) @@ -27,20 +26,15 @@ type client struct { packetHandlers packetHandlerManager - versionNegotiated utils.AtomicBool // has the server accepted our version - receivedVersionNegotiationPacket bool - negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet - tlsConf *tls.Config config *Config srcConnID protocol.ConnectionID destConnID protocol.ConnectionID - initialPacketNumber protocol.PacketNumber - - initialVersion protocol.VersionNumber - version protocol.VersionNumber + initialPacketNumber protocol.PacketNumber + hasNegotiatedVersion bool + version protocol.VersionNumber handshakeChan chan struct{} @@ -268,8 +262,9 @@ func (c *client) dial(ctx context.Context) error { c.config, c.tlsConf, c.initialPacketNumber, - c.initialVersion, + c.version, c.use0RTT, + c.hasNegotiatedVersion, c.qlogger, c.logger, c.version, @@ -280,7 +275,7 @@ func (c *client) dial(ctx context.Context) error { errorChan := make(chan error, 1) go func() { err := c.session.run() // returns as soon as the session is closed - if err != errCloseForRecreating && c.createdPacketConn { + if !errors.Is(err, errCloseForRecreating{}) && c.createdPacketConn { c.packetHandlers.Destroy() } errorChan <- err @@ -298,7 +293,11 @@ func (c *client) dial(ctx context.Context) error { c.session.shutdown() return ctx.Err() case err := <-errorChan: - if err == errCloseForRecreating { + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + c.initialPacketNumber = recreateErr.nextPacketNumber + c.version = recreateErr.nextVersion + c.hasNegotiatedVersion = true return c.dial(ctx) } return err @@ -312,75 +311,9 @@ func (c *client) dial(ctx context.Context) error { } func (c *client) handlePacket(p *receivedPacket) { - if wire.IsVersionNegotiationPacket(p.data) { - go c.handleVersionNegotiationPacket(p) - return - } - - // this is the first packet we are receiving - // since it is not a Version Negotiation Packet, this means the server supports the suggested version - if !c.versionNegotiated.Get() { - c.versionNegotiated.Set(true) - } - c.session.handlePacket(p) } -func (c *client) handleVersionNegotiationPacket(p *receivedPacket) { - c.mutex.Lock() - defer c.mutex.Unlock() - - hdr, _, _, err := wire.ParsePacket(p.data, 0) - if err != nil { - if c.qlogger != nil { - c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError) - } - c.logger.Debugf("Error parsing Version Negotiation packet: %s", err) - return - } - - // ignore delayed / duplicated version negotiation packets - if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() { - if c.qlogger != nil { - c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket) - } - c.logger.Debugf("Received a delayed Version Negotiation packet.") - return - } - - for _, v := range hdr.SupportedVersions { - if v == c.version { - if c.qlogger != nil { - c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion) - } - // The Version Negotiation packet contains the version that we offered. - // This might be a packet sent by an attacker (or by a terribly broken server implementation). - return - } - } - - c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions) - if c.qlogger != nil { - c.qlogger.ReceivedVersionNegotiationPacket(hdr) - } - newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) - if !ok { - //nolint:stylecheck - c.session.destroy(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s", c.config.Versions, hdr.SupportedVersions)) - c.logger.Debugf("No compatible QUIC version found.") - return - } - c.receivedVersionNegotiationPacket = true - c.negotiatedVersions = hdr.SupportedVersions - - // switch to negotiated version - c.initialVersion = c.version - c.version = newVersion - - c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) - c.initialPacketNumber = c.session.closeForRecreating() -} - func (c *client) shutdown() { c.mutex.Lock() defer c.mutex.Unlock() diff --git a/client_test.go b/client_test.go index 5ce40dc6..6742ba06 100644 --- a/client_test.go +++ b/client_test.go @@ -47,6 +47,7 @@ var _ = Describe("Client", func() { initialPacketNumber protocol.PacketNumber, initialVersion protocol.VersionNumber, enable0RTT bool, + hasNegotiatedVersion bool, qlogger qlog.Tracer, logger utils.Logger, v protocol.VersionNumber, @@ -65,16 +66,6 @@ var _ = Describe("Client", func() { return b.Bytes() } - composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket { - data, err := wire.ComposeVersionNegotiation(connID, nil, versions) - Expect(err).ToNot(HaveOccurred()) - Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue()) - return &receivedPacket{ - rcvTime: time.Now(), - data: data, - } - } - BeforeEach(func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} @@ -169,6 +160,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ protocol.VersionNumber, _ bool, + _ bool, _ qlog.Tracer, _ utils.Logger, _ protocol.VersionNumber, @@ -201,6 +193,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ protocol.VersionNumber, _ bool, + _ bool, _ qlog.Tracer, _ utils.Logger, _ protocol.VersionNumber, @@ -233,6 +226,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ protocol.VersionNumber, _ bool, + _ bool, _ qlog.Tracer, _ utils.Logger, _ protocol.VersionNumber, @@ -271,6 +265,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ protocol.VersionNumber, enable0RTT bool, + _ bool, _ qlog.Tracer, _ utils.Logger, _ protocol.VersionNumber, @@ -313,6 +308,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ protocol.VersionNumber, enable0RTT bool, + _ bool, _ qlog.Tracer, _ utils.Logger, _ protocol.VersionNumber, @@ -360,6 +356,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ protocol.VersionNumber, _ bool, + _ bool, _ qlog.Tracer, _ utils.Logger, _ protocol.VersionNumber, @@ -403,6 +400,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ protocol.VersionNumber, _ bool, + _ bool, _ qlog.Tracer, _ utils.Logger, _ protocol.VersionNumber, @@ -454,6 +452,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ protocol.VersionNumber, _ bool, + _ bool, _ qlog.Tracer, _ utils.Logger, _ protocol.VersionNumber, @@ -574,6 +573,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ protocol.VersionNumber, /* initial version */ _ bool, + _ bool, _ qlog.Tracer, _ utils.Logger, versionP protocol.VersionNumber, @@ -596,183 +596,58 @@ var _ = Describe("Client", func() { Expect(conf.Versions).To(Equal(config.Versions)) }) - Context("version negotiation", func() { - var origSupportedVersions []protocol.VersionNumber + It("creates a new session after version negotiation", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(connID, gomock.Any()).Times(2) + manager.EXPECT().Destroy() + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - BeforeEach(func() { - origSupportedVersions = protocol.SupportedVersions - protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{77, 78}...) - }) + initialVersion := cl.version - AfterEach(func() { - protocol.SupportedVersions = origSupportedVersions - }) - - It("returns an error that occurs during version negotiation", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(connID, gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) - - testErr := errors.New("early handshake error") - newClientSession = func( - conn connection, - _ sessionRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ protocol.VersionNumber, - _ bool, - _ qlog.Tracer, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicSession { - Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed()) - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run().Return(testErr) - sess.EXPECT().HandshakeComplete().Return(context.Background()) - return sess + var counter int + newClientSession = func( + _ connection, + _ sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + configP *Config, + _ *tls.Config, + pn protocol.PacketNumber, + version protocol.VersionNumber, + _ bool, + hasNegotiatedVersion bool, + _ qlog.Tracer, + _ utils.Logger, + versionP protocol.VersionNumber, + ) quicSession { + sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().HandshakeComplete().Return(context.Background()) + if counter == 0 { + Expect(pn).To(BeZero()) + Expect(version).To(Equal(initialVersion)) + Expect(hasNegotiatedVersion).To(BeFalse()) + sess.EXPECT().run().Return(&errCloseForRecreating{ + nextPacketNumber: 109, + nextVersion: 789, + }) + } else { + Expect(pn).To(Equal(protocol.PacketNumber(109))) + Expect(version).ToNot(Equal(initialVersion)) + Expect(version).To(Equal(protocol.VersionNumber(789))) + Expect(hasNegotiatedVersion).To(BeTrue()) + sess.EXPECT().run() } - qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any()) - _, err := Dial( - packetConn, - addr, - "localhost:1337", - tlsConf, - config, - ) - Expect(err).To(MatchError(testErr)) - }) + counter++ + return sess + } - It("recognizes that a non Version Negotiation packet means that the server accepted the suggested version", func() { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()) - cl.session = sess - cl.config = config - buf := &bytes.Buffer{} - Expect((&wire.ExtendedHeader{ - Header: wire.Header{ - DestConnectionID: connID, - SrcConnectionID: connID, - Version: cl.version, - }, - PacketNumberLen: protocol.PacketNumberLen3, - }).Write(buf, protocol.VersionTLS)).To(Succeed()) - cl.handlePacket(&receivedPacket{data: buf.Bytes()}) - Eventually(cl.versionNegotiated.Get).Should(BeTrue()) - }) - - // Illustrates that adversary that injects a version negotiation packet - // with no supported versions can break a connection. - It("errors if no matching version is found", func() { - sess := NewMockQuicSession(mockCtrl) - done := make(chan struct{}) - sess.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found.")) - close(done) - }) - cl.session = sess - cl.config = &Config{Versions: protocol.SupportedVersions} - p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1337}) - hdr, _, _, err := wire.ParsePacket(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - qlogger.EXPECT().ReceivedVersionNegotiationPacket(hdr) - cl.handlePacket(p) - Eventually(done).Should(BeClosed()) - }) - - It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { - sess := NewMockQuicSession(mockCtrl) - done := make(chan struct{}) - sess.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found.")) - close(done) - }) - cl.session = sess - v := protocol.VersionNumber(1234) - Expect(v).ToNot(Equal(cl.version)) - cl.config = &Config{Versions: protocol.SupportedVersions} - qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()) - cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v})) - Eventually(done).Should(BeClosed()) - }) - - It("changes to the version preferred by the quic.Config", func() { - phm := NewMockPacketHandlerManager(mockCtrl) - cl.packetHandlers = phm - - sess := NewMockQuicSession(mockCtrl) - destroyed := make(chan struct{}) - sess.EXPECT().closeForRecreating().Do(func() { - close(destroyed) - }) - cl.session = sess - versions := []protocol.VersionNumber{1234, 4321} - cl.config = &Config{Versions: versions} - qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()) - cl.handlePacket(composeVersionNegotiationPacket(connID, versions)) - Eventually(destroyed).Should(BeClosed()) - Expect(cl.version).To(Equal(protocol.VersionNumber(1234))) - }) - - It("drops unparseable version negotiation packets", func() { - cl.config = config - ver := cl.version - p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver}) - p.data = p.data[:len(p.data)-1] - done := make(chan struct{}) - qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) { - close(done) - }) - cl.handlePacket(p) - Eventually(done).Should(BeClosed()) - Expect(cl.version).To(Equal(ver)) - }) - - It("drops version negotiation packets if any other packet was received before", func() { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()) - cl.session = sess - cl.config = config - buf := &bytes.Buffer{} - Expect((&wire.ExtendedHeader{ - Header: wire.Header{ - DestConnectionID: connID, - SrcConnectionID: connID, - Version: cl.version, - }, - PacketNumberLen: protocol.PacketNumberLen3, - }).Write(buf, protocol.VersionTLS)).To(Succeed()) - cl.handlePacket(&receivedPacket{data: buf.Bytes()}) - - ver := cl.version - p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234}) - done := make(chan struct{}) - qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) { - close(done) - }) - cl.handlePacket(p) - Eventually(done).Should(BeClosed()) - Expect(cl.version).To(Equal(ver)) - }) - - It("drops version negotiation packets that contain the offered version", func() { - cl.config = config - ver := cl.version - p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver}) - done := make(chan struct{}) - qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) { - close(done) - }) - cl.handlePacket(p) - Eventually(done).Should(BeClosed()) - Expect(cl.version).To(Equal(ver)) - }) + gomock.InOrder( + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), initialVersion, gomock.Any(), gomock.Any()), + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionNumber(789), gomock.Any(), gomock.Any()), + ) + _, err := DialAddr("localhost:7890", tlsConf, config) + Expect(err).ToNot(HaveOccurred()) + Expect(counter).To(Equal(2)) }) }) diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 23e4f9c4..e92cbc75 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -225,20 +225,6 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr)) } -// closeForRecreating mocks base method -func (m *MockQuicSession) closeForRecreating() protocol.PacketNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "closeForRecreating") - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// closeForRecreating indicates an expected call of closeForRecreating -func (mr *MockQuicSessionMockRecorder) closeForRecreating() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForRecreating", reflect.TypeOf((*MockQuicSession)(nil).closeForRecreating)) -} - // destroy mocks base method func (m *MockQuicSession) destroy(arg0 error) { m.ctrl.T.Helper() diff --git a/server.go b/server.go index cb035e86..35f6fdb7 100644 --- a/server.go +++ b/server.go @@ -49,7 +49,6 @@ type quicSession interface { run() error destroy(error) shutdown() - closeForRecreating() protocol.PacketNumber } // A Listener of QUIC diff --git a/session.go b/session.go index 1caa1da8..e9ff537e 100644 --- a/session.go +++ b/session.go @@ -104,7 +104,19 @@ type closeError struct { immediate bool } -var errCloseForRecreating = errors.New("closing session in order to recreate it") +type errCloseForRecreating struct { + nextPacketNumber protocol.PacketNumber + nextVersion protocol.VersionNumber +} + +func (errCloseForRecreating) Error() string { + return "closing session in order to recreate it" +} + +func (errCloseForRecreating) Is(target error) bool { + _, ok := target.(errCloseForRecreating) + return ok +} // A Session is a QUIC session type session struct { @@ -169,6 +181,7 @@ type session struct { handshakeConfirmed bool receivedRetry bool + versionNegotiated bool receivedFirstPacket bool idleTimeout time.Duration @@ -336,6 +349,7 @@ var newClientSession = func( initialPacketNumber protocol.PacketNumber, initialVersion protocol.VersionNumber, enable0RTT bool, + hasNegotiatedVersion bool, qlogger qlog.Tracer, logger utils.Logger, v protocol.VersionNumber, @@ -352,6 +366,7 @@ var newClientSession = func( logger: logger, qlogger: qlogger, initialVersion: initialVersion, + versionNegotiated: hasNegotiatedVersion, version: v, } s.connIDManager = newConnIDManager( @@ -595,7 +610,7 @@ runLoop: } s.handleCloseError(closeErr) - if closeErr.err != errCloseForRecreating && s.qlogger != nil { + if !errors.Is(closeErr.err, errCloseForRecreating{}) && s.qlogger != nil { if err := s.qlogger.Export(); err != nil { s.logger.Errorf("exporting qlog failed: %s", err) } @@ -692,6 +707,11 @@ func (s *session) handleHandshakeComplete() { } func (s *session) handlePacketImpl(rp *receivedPacket) bool { + if wire.IsVersionNegotiationPacket(rp.data) { + s.handleVersionNegotiationPacket(rp) + return false + } + var counter uint8 var lastConnID protocol.ConnectionID var processed bool @@ -888,6 +908,55 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t return true } +func (s *session) handleVersionNegotiationPacket(p *receivedPacket) { + if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets + s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets + if s.qlogger != nil { + s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket) + } + return + } + + hdr, _, _, err := wire.ParsePacket(p.data, 0) + if err != nil { + if s.qlogger != nil { + s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError) + } + s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) + return + } + + for _, v := range hdr.SupportedVersions { + if v == s.version { + if s.qlogger != nil { + s.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion) + } + // The Version Negotiation packet contains the version that we offered. + // This might be a packet sent by an attacker, or it was corrupted. + return + } + } + + s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions) + if s.qlogger != nil { + s.qlogger.ReceivedVersionNegotiationPacket(hdr) + } + newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, hdr.SupportedVersions) + if !ok { + //nolint:stylecheck + s.destroyImpl(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s.", s.config.Versions, hdr.SupportedVersions)) + s.logger.Infof("No compatible QUIC version found.") + return + } + + s.logger.Infof("Switching to QUIC version %s.", newVersion) + nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial) + s.destroyImpl(&errCloseForRecreating{ + nextPacketNumber: nextPN, + nextVersion: newVersion, + }) +} + func (s *session) handleUnpackedPacket( packet *unpackedPacket, rcvTime time.Time, @@ -1190,14 +1259,6 @@ func (s *session) destroyImpl(e error) { }) } -// closeForRecreating closes the session in order to recreate it immediately afterwards -// It returns the first packet number that should be used in the new session. -func (s *session) closeForRecreating() protocol.PacketNumber { - s.destroy(errCloseForRecreating) - nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial) - return nextPN -} - func (s *session) closeRemote(e error) { s.closeOnce.Do(func() { s.logger.Errorf("Peer closed session with error: %s", e) diff --git a/session_test.go b/session_test.go index b452d09b..49a111e5 100644 --- a/session_test.go +++ b/session_test.go @@ -487,18 +487,6 @@ var _ = Describe("Session", func() { Expect(sess.Context().Done()).To(BeClosed()) }) - It("closes the session in order to recreate it", func() { - runSession() - streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - cryptoSetup.EXPECT().Close() - // don't EXPECT any calls to mconn.Write() - // don't EXPECT any call to qlogger.Export() - sess.closeForRecreating() - Eventually(areSessionsRunning).Should(BeFalse()) - expectedRunErr = errCloseForRecreating - }) - It("destroys the session", func() { runSession() testErr := errors.New("close") @@ -603,6 +591,16 @@ var _ = Describe("Session", func() { Expect(sess.handlePacketImpl(p)).To(BeFalse()) }) + It("drops Version Negotiation packets", func() { + b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, sess.config.Versions) + Expect(err).ToNot(HaveOccurred()) + qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), qlog.PacketDropUnexpectedPacket) + Expect(sess.handlePacketImpl(&receivedPacket{ + data: b, + buffer: getPacketBuffer(), + })).To(BeFalse()) + }) + It("drops packets for which header decryption fails", func() { p := getPacket(&wire.ExtendedHeader{ Header: wire.Header{ @@ -2035,6 +2033,7 @@ var _ = Describe("Client Session", func() { 42, // initial packet number protocol.VersionTLS, false, + false, qlogger, utils.DefaultLogger, protocol.VersionTLS, @@ -2133,6 +2132,81 @@ var _ = Describe("Client Session", func() { }) }) + Context("handling Version Negotiation", func() { + getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { + b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions) + Expect(err).ToNot(HaveOccurred()) + return &receivedPacket{ + data: b, + buffer: getPacketBuffer(), + } + } + + It("closes and returns the right error", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sess.sentPacketHandler = sph + sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4) + sess.config.Versions = []protocol.VersionNumber{1234, 4321} + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + errChan <- sess.run() + }() + sessionRunner.EXPECT().Remove(srcConnID) + qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()).Do(func(hdr *wire.Header) { + Expect(hdr.Version).To(BeZero()) + Expect(hdr.SupportedVersions).To(And( + ContainElement(protocol.VersionNumber(4321)), + ContainElement(protocol.VersionNumber(1337)), + )) + }) + cryptoSetup.EXPECT().Close() + Expect(sess.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse()) + var err error + Eventually(errChan).Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err).To(BeAssignableToTypeOf(&errCloseForRecreating{})) + recreateErr := err.(*errCloseForRecreating) + Expect(recreateErr.nextVersion).To(Equal(protocol.VersionNumber(4321))) + Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128))) + }) + + It("it closes when no matching version is found", func() { + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + errChan <- sess.run() + }() + sessionRunner.EXPECT().Remove(srcConnID).MaxTimes(1) + gomock.InOrder( + qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()), + qlogger.EXPECT().Export(), + ) + cryptoSetup.EXPECT().Close() + Expect(sess.handlePacketImpl(getVNP(12345678))).To(BeFalse()) + var err error + Eventually(errChan).Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err).ToNot(BeAssignableToTypeOf(&errCloseForRecreating{})) + Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found")) + }) + + It("ignores Version Negotiation packets that offer the current version", func() { + p := getVNP(sess.version) + qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion) + Expect(sess.handlePacketImpl(p)).To(BeFalse()) + }) + + It("ignores unparseable Version Negotiation packets", func() { + p := getVNP(sess.version) + p.data = p.data[:len(p.data)-2] + qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError) + Expect(sess.handlePacketImpl(p)).To(BeFalse()) + }) + }) + Context("handling Retry", func() { origDestConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}