diff --git a/client.go b/client.go index 3eff5aca..84ad19ac 100644 --- a/client.go +++ b/client.go @@ -287,7 +287,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error { } func (c *client) handlePacket(p *receivedPacket) { - if p.hdr.IsVersionNegotiation() { + if wire.IsVersionNegotiationPacket(p.data) { go c.handleVersionNegotiationPacket(p.hdr) return } diff --git a/client_test.go b/client_test.go index d0df05fa..b3581753 100644 --- a/client_test.go +++ b/client_test.go @@ -60,10 +60,11 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) hdr, _, _, err := wire.ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsVersionNegotiation()).To(BeTrue()) + Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue()) return &receivedPacket{ rcvTime: time.Now(), hdr: hdr, + data: data, } } diff --git a/internal/wire/header.go b/internal/wire/header.go index 9461cccb..740d3789 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -35,6 +35,14 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil } +// IsVersionNegotiationPacket says if this is a version negotiation packet +func IsVersionNegotiationPacket(b []byte) bool { + if len(b) < 5 { + return false + } + return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 +} + var errUnsupportedVersion = errors.New("unsupported version") // The Header is the version independent part of the header @@ -129,7 +137,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { return err } h.Version = protocol.VersionNumber(v) - if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 { + if h.Version != 0 && h.typeByte&0x40 == 0 { return errors.New("not a QUIC packet") } connIDLenByte, err := b.ReadByte() @@ -214,11 +222,6 @@ func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error { return nil } -// IsVersionNegotiation says if this a version negotiation packet -func (h *Header) IsVersionNegotiation() bool { - return h.IsLongHeader && h.Version == 0 -} - // ParsedLen returns the number of bytes that were consumed when parsing the header func (h *Header) ParsedLen() protocol.ByteCount { return h.parsedLen diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index b1c8702f..363c5015 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -94,6 +94,24 @@ var _ = Describe("Header Parsing", func() { }) }) + Context("Identifying Version Negotiation Packets", func() { + It("identifies version negotiation packets", func() { + Expect(IsVersionNegotiationPacket([]byte{0x80 | 0x56, 0, 0, 0, 0})).To(BeTrue()) + Expect(IsVersionNegotiationPacket([]byte{0x56, 0, 0, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 1, 0, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 1, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 1, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 0, 1})).To(BeFalse()) + }) + + It("returns false on EOF", func() { + vnp := []byte{0x80, 0, 0, 0, 0} + for i := range vnp { + Expect(IsVersionNegotiationPacket(vnp[:i])).To(BeFalse()) + } + }) + }) + Context("Version Negotiation Packets", func() { It("parses", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} @@ -101,12 +119,12 @@ var _ = Describe("Header Parsing", func() { versions := []protocol.VersionNumber{0x22334455, 0x33445566} vnp, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(err).ToNot(HaveOccurred()) + Expect(IsVersionNegotiationPacket(vnp)).To(BeTrue()) hdr, _, rest, err := ParsePacket(vnp, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.IsVersionNegotiation()).To(BeTrue()) Expect(hdr.Version).To(BeZero()) for _, v := range versions { Expect(hdr.SupportedVersions).To(ContainElement(v)) @@ -150,12 +168,12 @@ var _ = Describe("Header Parsing", func() { hdrLen := len(data) data = append(data, []byte{0, 0, 0xbe, 0xef}...) // packet number data = append(data, []byte("foobar")...) + Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) hdr, pdata, rest, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(pdata).To(Equal(data)) Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.IsVersionNegotiation()).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) @@ -399,10 +417,11 @@ var _ = Describe("Header Parsing", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} data := append([]byte{0x40}, connID...) data = append(data, 0x42) // packet number + Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) + hdr, pdata, rest, err := ParsePacket(data, 8) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(hdr.IsVersionNegotiation()).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(connID)) b := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(b, versionIETFFrames) diff --git a/server_test.go b/server_test.go index cfa36cf1..b5525fff 100644 --- a/server_test.go +++ b/server_test.go @@ -239,8 +239,8 @@ var _ = Describe("Server", func() { var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) + Expect(wire.IsVersionNegotiationPacket(write.data)).To(BeTrue()) hdr := parseHeader(write.data) - Expect(hdr.IsVersionNegotiation()).To(BeTrue()) Expect(hdr.DestConnectionID).To(Equal(srcConnID)) Expect(hdr.SrcConnectionID).To(Equal(destConnID)) Expect(hdr.SupportedVersions).ToNot(ContainElement(protocol.VersionNumber(0x42)))