diff --git a/fuzzing/header/fuzz.go b/fuzzing/header/fuzz.go index 7c561ea1..fde15d40 100644 --- a/fuzzing/header/fuzz.go +++ b/fuzzing/header/fuzz.go @@ -42,6 +42,11 @@ func Fuzz(data []byte) int { panic("inconsistent 0-RTT packet detection") } + if !wire.IsLongHeaderPacket(data[0]) { + wire.ParseShortHeader(data, connIDLen) + return 1 + } + var extHdr *wire.ExtendedHeader // Parse the extended header, if this is not a Retry packet. if hdr.Type == protocol.PacketTypeRetry { @@ -58,9 +63,6 @@ func Fuzz(data []byte) int { if hdr.IsLongHeader && hdr.Length > 16383 { return 1 } - if !hdr.IsLongHeader { - return 1 - } b := &bytes.Buffer{} if err := extHdr.Write(b, version); err != nil { // We are able to parse packets with connection IDs longer than 20 bytes, diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 27cb5730..455576d7 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -35,6 +35,9 @@ var _ = Describe("0-RTT", func() { RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { for len(data) > 0 { + if !wire.IsLongHeaderPacket(data[0]) { + break + } hdr, _, rest, err := wire.ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) if hdr.Type == protocol.PacketType0RTT { @@ -347,14 +350,19 @@ var _ = Describe("0-RTT", func() { proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { - hdr, _, _, err := wire.ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type == protocol.PacketType0RTT { - atomic.AddUint32(&num0RTTPackets, 1) + if wire.IsLongHeaderPacket(data[0]) { + hdr, _, _, err := wire.ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type == protocol.PacketType0RTT { + atomic.AddUint32(&num0RTTPackets, 1) + } } return rtt / 2 }, DropPacket: func(_ quicproxy.Direction, data []byte) bool { + if !wire.IsLongHeaderPacket(data[0]) { + return false + } hdr, _, _, err := wire.ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) if hdr.Type == protocol.PacketType0RTT { diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 310546f2..3c1d0aef 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -46,7 +46,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool if h.IsLongHeader { reservedBitsValid, err = h.parseLongHeader(b, v) } else { - reservedBitsValid, err = h.parseShortHeader(b, v) + panic("parsed a short header packet") } if err != nil { return false, err @@ -65,21 +65,6 @@ func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumb return true, nil } -func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { - h.KeyPhase = protocol.KeyPhaseZero - if h.typeByte&0x4 > 0 { - h.KeyPhase = protocol.KeyPhaseOne - } - - if err := h.readPacketNumber(b); err != nil { - return false, err - } - if h.typeByte&0x18 != 0 { - return false, nil - } - return true, nil -} - func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 switch h.PacketNumberLen { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 87d854b7..d25ec3e5 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -485,147 +485,6 @@ var _ = Describe("Header Parsing", func() { }) }) - Context("Short Headers", func() { - It("reads a Short Header with a 8 byte connection ID", func() { - connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}) - data := append([]byte{0x40}, connID.Bytes()...) - data = append(data, 0x42) // packet number - Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) - - hdr, pdata, rest, err := ParsePacket(data, 8) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(hdr.DestConnectionID).To(Equal(connID)) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) - Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeZero()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) - Expect(hdr.ParsedLen()).To(BeEquivalentTo(len(data) - 1)) - Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 1)) - Expect(pdata).To(Equal(data)) - Expect(rest).To(BeEmpty()) - }) - - It("errors if 0x40 is not set", func() { - connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}) - data := append([]byte{0x0}, connID.Bytes()...) - _, _, _, err := ParsePacket(data, 8) - Expect(err).To(MatchError("not a QUIC packet")) - }) - - It("errors if the 4th or 5th bit are set", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) - data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID.Bytes()...) - data = append(data, 0x42) // packet number - hdr, _, _, err := ParsePacket(data, 5) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) - Expect(err).To(MatchError(ErrInvalidReservedBits)) - Expect(extHdr).ToNot(BeNil()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) - }) - - It("reads a Short Header with a 5 byte connection ID", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) - data := append([]byte{0x40}, connID.Bytes()...) - data = append(data, 0x42) // packet number - hdr, pdata, rest, err := ParsePacket(data, 5) - Expect(err).ToNot(HaveOccurred()) - Expect(pdata).To(HaveLen(len(data))) - Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(hdr.DestConnectionID).To(Equal(connID)) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) - Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeZero()) - Expect(rest).To(BeEmpty()) - }) - - It("reads the Key Phase Bit", func() { - data := []byte{ - 0x40 ^ 0x4, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID - } - data = append(data, 11) // packet number - hdr, _, _, err := ParsePacket(data, 6) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsLongHeader).To(BeFalse()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseOne)) - Expect(b.Len()).To(BeZero()) - }) - - It("reads a header with a 2 byte packet number", func() { - data := []byte{ - 0x40 | 0x1, - 0xde, 0xad, 0xbe, 0xef, // connection ID - } - data = append(data, []byte{0x13, 0x37}...) // packet number - hdr, _, _, err := ParsePacket(data, 4) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.IsLongHeader).To(BeFalse()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) - Expect(b.Len()).To(BeZero()) - }) - - It("reads a header with a 3 byte packet number", func() { - data := []byte{ - 0x40 | 0x2, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID - } - data = append(data, []byte{0x99, 0xbe, 0xef}...) // packet number - hdr, _, _, err := ParsePacket(data, 10) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.IsLongHeader).To(BeFalse()) - Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef))) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3)) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOF, when parsing the header", func() { - data := []byte{ - 0x40 ^ 0x2, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID - } - for i := 0; i < len(data); i++ { - data = data[:i] - _, _, _, err := ParsePacket(data, 8) - Expect(err).To(Equal(io.EOF)) - } - }) - - It("errors on EOF, when parsing the extended header", func() { - data := []byte{ - 0x40 ^ 0x3, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID - } - hdrLen := len(data) - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number - for i := hdrLen; i < len(data); i++ { - data = data[:i] - hdr, _, _, err := ParsePacket(data, 6) - Expect(err).ToNot(HaveOccurred()) - _, err = hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) - Expect(err).To(Equal(io.EOF)) - } - }) - }) - It("distinguishes long and short header packets", func() { Expect(IsLongHeaderPacket(0x40)).To(BeFalse()) Expect(IsLongHeaderPacket(0x80 ^ 0x40 ^ 0x12)).To(BeTrue()) diff --git a/packet_packer_test.go b/packet_packer_test.go index 87087a43..b7be5c49 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -24,7 +24,7 @@ import ( var _ = Describe("Packet packer", func() { const maxPacketSize protocol.ByteCount = 1357 - const version = protocol.VersionTLS + const version = protocol.Version1 var ( packer *packetPacker @@ -39,24 +39,28 @@ var _ = Describe("Packet packer", func() { ) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - parsePacket := func(data []byte) []*wire.ExtendedHeader { - var hdrs []*wire.ExtendedHeader + parsePacket := func(data []byte) (hdrs []*wire.ExtendedHeader, more []byte) { for len(data) > 0 { - hdr, payload, rest, err := wire.ParsePacket(data, connID.Len()) + if !wire.IsLongHeaderPacket(data[0]) { + break + } + hdr, _, more, err := wire.ParsePacket(data, connID.Len()) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(r, version) Expect(err).ToNot(HaveOccurred()) - if extHdr.IsLongHeader { - ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() - len(rest) + int(extHdr.PacketNumberLen))) - ExpectWithOffset(1, extHdr.Length+protocol.ByteCount(extHdr.PacketNumberLen)).To(BeNumerically(">=", 4)) - } else { - ExpectWithOffset(1, len(payload)+int(extHdr.PacketNumberLen)).To(BeNumerically(">=", 4)) - } - data = rest + ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() - len(more) + int(extHdr.PacketNumberLen))) + ExpectWithOffset(1, extHdr.Length+protocol.ByteCount(extHdr.PacketNumberLen)).To(BeNumerically(">=", 4)) + data = more hdrs = append(hdrs, extHdr) } - return hdrs + return hdrs, data + } + + parseShortHeaderPacket := func(data []byte) { + l, _, pnLen, _, err := wire.ParseShortHeader(data, connID.Len()) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, len(data)-l+int(pnLen)).To(BeNumerically(">=", 4)) } appendFrames := func(fs, frames []ackhandler.Frame) ([]ackhandler.Frame, protocol.ByteCount) { @@ -484,10 +488,11 @@ var _ = Describe("Packet packer", func() { Expect(ccf.IsApplicationError).To(BeTrue()) Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) Expect(ccf.ReasonPhrase).To(Equal("test error")) - hdrs := parsePacket(p.buffer.Data) + hdrs, more := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) Expect(hdrs[1].Type).To(Equal(protocol.PacketType0RTT)) + Expect(more).To(BeEmpty()) }) }) @@ -693,13 +698,11 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added buffer.Data = buffer.Data[:buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(buffer.Data, packer.getDestConnID().Len()) - Expect(err).ToNot(HaveOccurred()) data := buffer.Data - r := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(r, packer.version) + l, _, pnLen, _, err := wire.ParseShortHeader(data, connID.Len()) Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + r := bytes.NewReader(data[l:]) + Expect(pnLen).To(Equal(protocol.PacketNumberLen1)) Expect(r.Len()).To(Equal(4 - 1 /* packet number length */)) // the first byte of the payload should be a PADDING frame... firstPayloadByte, err := r.ReadByte() @@ -974,9 +977,10 @@ var _ = Describe("Packet packer", func() { Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) Expect(p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - hdrs := parsePacket(p.buffer.Data) + hdrs, more := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(1)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(more).To(BeEmpty()) }) It("packs a maximum size Handshake packet", func() { @@ -1033,10 +1037,11 @@ var _ = Describe("Packet packer", func() { Expect(p.longHdrPackets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) Expect(p.longHdrPackets[1].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) - hdrs := parsePacket(p.buffer.Data) + hdrs, more := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) Expect(hdrs[1].Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(more).To(BeEmpty()) }) It("packs a coalesced packet with Initial / super short Handshake, and pads it", func() { @@ -1066,10 +1071,11 @@ var _ = Describe("Packet packer", func() { Expect(p.longHdrPackets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) Expect(p.longHdrPackets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - hdrs := parsePacket(p.buffer.Data) + hdrs, more := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) Expect(hdrs[1].Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(more).To(BeEmpty()) }) It("packs a coalesced packet with super short Initial / super short Handshake, and pads it", func() { @@ -1095,10 +1101,11 @@ var _ = Describe("Packet packer", func() { Expect(p.longHdrPackets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) Expect(p.longHdrPackets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - hdrs := parsePacket(p.buffer.Data) + hdrs, more := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) Expect(hdrs[1].Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(more).To(BeEmpty()) }) It("packs a coalesced packet with Initial / super short 1-RTT, and pads it", func() { @@ -1128,10 +1135,11 @@ var _ = Describe("Packet packer", func() { Expect(p.shortHdrPacket).ToNot(BeNil()) Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) Expect(p.shortHdrPacket.Frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - hdrs := parsePacket(p.buffer.Data) - Expect(hdrs).To(HaveLen(2)) + hdrs, more := parsePacket(p.buffer.Data) + Expect(hdrs).To(HaveLen(1)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) - Expect(hdrs[1].IsLongHeader).To(BeFalse()) + Expect(more).ToNot(BeEmpty()) + parseShortHeaderPacket(more) }) It("packs a coalesced packet with Initial / 0-RTT, and pads it", func() { @@ -1164,10 +1172,11 @@ var _ = Describe("Packet packer", func() { Expect(p.longHdrPackets[1].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) Expect(p.longHdrPackets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) - hdrs := parsePacket(p.buffer.Data) + hdrs, more := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) Expect(hdrs[1].Type).To(Equal(protocol.PacketType0RTT)) + Expect(more).To(BeEmpty()) }) It("packs a coalesced packet with Handshake / 1-RTT", func() { @@ -1197,13 +1206,13 @@ var _ = Describe("Packet packer", func() { Expect(p.shortHdrPacket).ToNot(BeNil()) Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) Expect(p.shortHdrPacket.Frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) - hdr, _, rest, err := wire.ParsePacket(p.buffer.Data, 0) + hdr, _, more, err := wire.ParsePacket(p.buffer.Data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) - hdr, _, rest, err = wire.ParsePacket(rest, 0) + hdr, _, more, err = wire.ParsePacket(more, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(rest).To(BeEmpty()) + Expect(more).To(BeEmpty()) }) It("doesn't add a coalesced packet if the remaining size is smaller than MaxCoalescedPacketSize", func() {