From 993d71fd567b479a2ad1bffca8a56a47fbed14ff Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 9 Jul 2020 11:21:49 +0700 Subject: [PATCH] move the SupportedVersions slice out of the wire.Header --- internal/mocks/connection_tracer.go | 8 +-- internal/wire/header.go | 27 +-------- internal/wire/header_test.go | 41 -------------- internal/wire/version_negotiation.go | 26 +++++++++ internal/wire/version_negotiation_test.go | 69 +++++++++++++++++++++-- logging/interface.go | 2 +- qlog/qlog.go | 10 ++-- qlog/qlog_test.go | 10 ++-- server_test.go | 5 +- session.go | 12 ++-- session_test.go | 6 +- 11 files changed, 119 insertions(+), 97 deletions(-) diff --git a/internal/mocks/connection_tracer.go b/internal/mocks/connection_tracer.go index 5b0c7e30..a5fd0cf4 100644 --- a/internal/mocks/connection_tracer.go +++ b/internal/mocks/connection_tracer.go @@ -184,15 +184,15 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 int } // ReceivedVersionNegotiationPacket mocks base method -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header) { +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []protocol.VersionNumber) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0) + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) } // ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) } // SentPacket mocks base method diff --git a/internal/wire/header.go b/internal/wire/header.go index 0ec54bad..d6e70d36 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -56,8 +56,7 @@ type Header struct { Length protocol.ByteCount - Token []byte - SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet + Token []byte parsedLen protocol.ByteCount // how many bytes were read while parsing this header } @@ -155,8 +154,8 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { if err != nil { return err } - if h.Version == 0 { - return h.parseVersionNegotiationPacket(b) + if h.Version == 0 { // version negotiation packet + return nil } // If we don't understand the version, we have no idea how to interpret the rest of the bytes if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { @@ -209,26 +208,6 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { return nil } -func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error { - if b.Len() == 0 { - //nolint:stylecheck - return errors.New("Version Negotiation packet has empty version list") - } - if b.Len()%4 != 0 { - //nolint:stylecheck - return errors.New("Version Negotiation packet has a version list with an invalid length") - } - h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4) - for i := 0; b.Len() > 0; i++ { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return err - } - h.SupportedVersions[i] = protocol.VersionNumber(v) - } - return nil -} - // ParsedLen returns the number of bytes that were consumed when parsing the header func (h *Header) ParsedLen() protocol.ByteCount { return h.parsedLen diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 12f0f521..5a90947a 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -111,47 +111,6 @@ var _ = Describe("Header Parsing", func() { }) }) - Context("Version Negotiation Packets", func() { - It("parses", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} - destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} - versions := []protocol.VersionNumber{0x22334455, 0x33445566} - vnp, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) - Expect(err).ToNot(HaveOccurred()) - Expect(IsVersionNegotiationPacket(vnp)).To(BeTrue()) - hdr, _, rest, err := ParsePacket(vnp, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.Version).To(BeZero()) - for _, v := range versions { - Expect(hdr.SupportedVersions).To(ContainElement(v)) - } - Expect(rest).To(BeEmpty()) - }) - - It("errors if it contains versions of the wrong length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []protocol.VersionNumber{0x22334455, 0x33445566} - data, err := ComposeVersionNegotiation(connID, connID, versions) - Expect(err).ToNot(HaveOccurred()) - _, _, _, err = ParsePacket(data[:len(data)-2], 0) - Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) - }) - - It("errors if the version list is empty", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []protocol.VersionNumber{0x22334455} - data, err := ComposeVersionNegotiation(connID, connID, versions) - Expect(err).ToNot(HaveOccurred()) - // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number - data = data[:len(data)-8] - _, _, _, err = ParsePacket(data, 0) - Expect(err).To(MatchError("Version Negotiation packet has empty version list")) - }) - }) - Context("Long Headers", func() { It("parses a Long Header", func() { destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index 4a6b323f..bcae87d1 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -3,11 +3,37 @@ package wire import ( "bytes" "crypto/rand" + "errors" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) +// ParseVersionNegotiationPacket parses a Version Negotiation packet. +func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []protocol.VersionNumber, error) { + hdr, err := parseHeader(b, 0) + if err != nil { + return nil, nil, err + } + if b.Len() == 0 { + //nolint:stylecheck + return nil, nil, errors.New("Version Negotiation packet has empty version list") + } + if b.Len()%4 != 0 { + //nolint:stylecheck + return nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") + } + versions := make([]protocol.VersionNumber, b.Len()/4) + for i := 0; b.Len() > 0; i++ { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, nil, err + } + versions[i] = protocol.VersionNumber(v) + } + return hdr, versions, nil +} + // ComposeVersionNegotiation composes a Version Negotiation func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) ([]byte, error) { greasedVersions := protocol.GetGreasedVersions(versions) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 1d93c389..d3a0b7e6 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -1,29 +1,86 @@ package wire import ( + "bytes" + "encoding/binary" + "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Version Negotiation Packets", func() { - It("writes", func() { + It("parses a Version Negotiation packet", func() { + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} + versions := []protocol.VersionNumber{0x22334455, 0x33445566} + data := []byte{0x80, 0, 0, 0, 0} + data = append(data, uint8(len(destConnID))) + data = append(data, destConnID...) + data = append(data, uint8(len(srcConnID))) + data = append(data, srcConnID...) + for _, v := range versions { + data = append(data, []byte{0, 0, 0, 0}...) + binary.BigEndian.PutUint32(data[len(data)-4:], uint32(v)) + } + Expect(IsVersionNegotiationPacket(data)).To(BeTrue()) + hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(destConnID)) + Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) + Expect(hdr.IsLongHeader).To(BeTrue()) + Expect(hdr.Version).To(BeZero()) + Expect(supportedVersions).To(Equal(versions)) + }) + + It("errors if it contains versions of the wrong length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + versions := []protocol.VersionNumber{0x22334455, 0x33445566} + data, err := ComposeVersionNegotiation(connID, connID, versions) + Expect(err).ToNot(HaveOccurred()) + _, _, err = ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) + Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) + }) + + It("errors if the version list is empty", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + versions := []protocol.VersionNumber{0x22334455} + data, err := ComposeVersionNegotiation(connID, connID, versions) + Expect(err).ToNot(HaveOccurred()) + // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number + data = data[:len(data)-8] + _, _, err = ParseVersionNegotiationPacket(bytes.NewReader(data)) + Expect(err).To(MatchError("Version Negotiation packet has empty version list")) + }) + + It("adds a reserved version", func() { srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{1001, 1003} data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(err).ToNot(HaveOccurred()) Expect(data[0] & 0x80).ToNot(BeZero()) - hdr, _, rest, err := ParsePacket(data, 4) + hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) Expect(hdr.Version).To(BeZero()) // the supported versions should include one reserved version number - Expect(hdr.SupportedVersions).To(HaveLen(len(versions) + 1)) - for _, version := range versions { - Expect(hdr.SupportedVersions).To(ContainElement(version)) + Expect(supportedVersions).To(HaveLen(len(versions) + 1)) + for _, v := range versions { + Expect(supportedVersions).To(ContainElement(v)) } - Expect(rest).To(BeEmpty()) + var reservedVersion protocol.VersionNumber + versionLoop: + for _, ver := range supportedVersions { + for _, v := range versions { + if v == ver { + continue versionLoop + } + } + reservedVersion = ver + } + Expect(reservedVersion).ToNot(BeZero()) + Expect(reservedVersion&0x0f0f0f0f == 0x0a0a0a0a).To(BeTrue()) // check that it's a greased version number }) }) diff --git a/logging/interface.go b/logging/interface.go index 03bb10a0..cdde3a34 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -81,7 +81,7 @@ type ConnectionTracer interface { SentTransportParameters(*TransportParameters) ReceivedTransportParameters(*TransportParameters) SentPacket(hdr *ExtendedHeader, packetSize ByteCount, ack *AckFrame, frames []Frame) - ReceivedVersionNegotiationPacket(*Header) + ReceivedVersionNegotiationPacket(*Header, []VersionNumber) ReceivedRetry(*Header) ReceivedPacket(hdr *ExtendedHeader, packetSize ByteCount, frames []Frame) ReceivedStatelessReset(token *[16]byte) diff --git a/qlog/qlog.go b/qlog/qlog.go index 0608379e..e0fb0d48 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -253,15 +253,15 @@ func (t *connectionTracer) ReceivedRetry(hdr *wire.Header) { t.mutex.Unlock() } -func (t *connectionTracer) ReceivedVersionNegotiationPacket(hdr *wire.Header) { - versions := make([]versionNumber, len(hdr.SupportedVersions)) - for i, v := range hdr.SupportedVersions { - versions[i] = versionNumber(v) +func (t *connectionTracer) ReceivedVersionNegotiationPacket(hdr *wire.Header, versions []logging.VersionNumber) { + ver := make([]versionNumber, len(versions)) + for i, v := range versions { + ver[i] = versionNumber(v) } t.mutex.Lock() t.recordEvent(time.Now(), &eventVersionNegotiationReceived{ Header: *transformHeader(hdr), - SupportedVersions: versions, + SupportedVersions: ver, }) t.mutex.Unlock() } diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index abbaaa97..9893162e 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -379,12 +379,12 @@ var _ = Describe("Tracing", func() { It("records a received Version Negotiation packet", func() { tracer.ReceivedVersionNegotiationPacket( &logging.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, - SupportedVersions: []protocol.VersionNumber{0xdeadbeef, 0xdecafbad}, + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, }, + []protocol.VersionNumber{0xdeadbeef, 0xdecafbad}, ) entry := exportAndParseSingle() Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) diff --git a/server_test.go b/server_test.go index e870b6c1..569d1497 100644 --- a/server_test.go +++ b/server_test.go @@ -385,10 +385,11 @@ var _ = Describe("Server", func() { Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) Expect(wire.IsVersionNegotiationPacket(write.data)).To(BeTrue()) - hdr := parseHeader(write.data) + hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(write.data)) + Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(srcConnID)) Expect(hdr.SrcConnectionID).To(Equal(destConnID)) - Expect(hdr.SupportedVersions).ToNot(ContainElement(protocol.VersionNumber(0x42))) + Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42))) }) It("replies with a Retry packet, if a Token is required", func() { diff --git a/session.go b/session.go index e35b49cd..d6fee2c7 100644 --- a/session.go +++ b/session.go @@ -916,7 +916,7 @@ func (s *session) handleVersionNegotiationPacket(p *receivedPacket) { return } - hdr, _, _, err := wire.ParsePacket(p.data, 0) + hdr, supportedVersions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(p.data)) if err != nil { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), logging.PacketDropHeaderParseError) @@ -925,7 +925,7 @@ func (s *session) handleVersionNegotiationPacket(p *receivedPacket) { return } - for _, v := range hdr.SupportedVersions { + for _, v := range supportedVersions { if v == s.version { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), logging.PacketDropUnexpectedVersion) @@ -936,14 +936,14 @@ func (s *session) handleVersionNegotiationPacket(p *receivedPacket) { } } - s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions) + s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) if s.tracer != nil { - s.tracer.ReceivedVersionNegotiationPacket(hdr) + s.tracer.ReceivedVersionNegotiationPacket(hdr, supportedVersions) } - newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, hdr.SupportedVersions) + newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) if !ok { //nolint:stylecheck - s.destroyImpl(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s.", s.config.Versions, hdr.SupportedVersions)) + s.destroyImpl(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s.", s.config.Versions, supportedVersions)) s.logger.Infof("No compatible QUIC version found.") return } diff --git a/session_test.go b/session_test.go index 89612fc4..915999b1 100644 --- a/session_test.go +++ b/session_test.go @@ -2154,9 +2154,9 @@ var _ = Describe("Client Session", func() { errChan <- sess.run() }() sessionRunner.EXPECT().Remove(srcConnID) - tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()).Do(func(hdr *wire.Header) { + tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()).Do(func(hdr *wire.Header, versions []logging.VersionNumber) { Expect(hdr.Version).To(BeZero()) - Expect(hdr.SupportedVersions).To(And( + Expect(versions).To(And( ContainElement(protocol.VersionNumber(4321)), ContainElement(protocol.VersionNumber(1337)), )) @@ -2181,7 +2181,7 @@ var _ = Describe("Client Session", func() { }() sessionRunner.EXPECT().Remove(srcConnID).MaxTimes(1) gomock.InOrder( - tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()), + tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()), tracer.EXPECT().Close(), ) cryptoSetup.EXPECT().Close()