diff --git a/fuzzing/frames/fuzz.go b/fuzzing/frames/fuzz.go index 1df2c81f..cd0409bc 100644 --- a/fuzzing/frames/fuzz.go +++ b/fuzzing/frames/fuzz.go @@ -33,7 +33,7 @@ func Fuzz(data []byte) int { encLevel := toEncLevel(data[0]) data = data[PrefixLen:] - parser := wire.NewFrameParser(version) + parser := wire.NewFrameParser(true, version) parser.SetAckDelayExponent(protocol.DefaultAckDelayExponent) r := bytes.NewReader(data) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index 5d2ae038..a858989e 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -13,12 +13,17 @@ import ( type frameParser struct { ackDelayExponent uint8 + supportsDatagrams bool + version protocol.VersionNumber } // NewFrameParser creates a new frame parser. -func NewFrameParser(v protocol.VersionNumber) FrameParser { - return &frameParser{version: v} +func NewFrameParser(supportsDatagrams bool, v protocol.VersionNumber) FrameParser { + return &frameParser{ + supportsDatagrams: supportsDatagrams, + version: v, + } } // ParseNextFrame parses the next frame @@ -88,7 +93,11 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protoc case 0x1e: frame, err = parseHandshakeDoneFrame(r, p.version) case 0x30, 0x31: - frame, err = parseDatagramFrame(r, p.version) + if p.supportsDatagrams { + frame, err = parseDatagramFrame(r, p.version) + break + } + fallthrough default: err = errors.New("unknown frame type") } diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 47cacdd4..800af60f 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -18,7 +18,7 @@ var _ = Describe("Frame parsing", func() { BeforeEach(func() { buf = &bytes.Buffer{} - parser = NewFrameParser(versionIETFFrames) + parser = NewFrameParser(true, versionIETFFrames) }) It("returns nil if there's nothing more to read", func() { @@ -280,6 +280,24 @@ var _ = Describe("Frame parsing", func() { Expect(frame).To(Equal(f)) }) + It("unpacks DATAGRAM frames", func() { + f := &DatagramFrame{Data: []byte("foobar")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when DATAGRAM frames are not supported", func() { + parser = NewFrameParser(false, versionIETFFrames) + f := &DatagramFrame{Data: []byte("foobar")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + _, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).To(MatchError("FRAME_ENCODING_ERROR (frame type: 0x30): unknown frame type")) + }) + It("errors on invalid type", func() { _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) Expect(err).To(MatchError("FRAME_ENCODING_ERROR (frame type: 0x42): unknown frame type")) diff --git a/packet_packer_test.go b/packet_packer_test.go index 92b883d3..1b01722c 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -617,7 +617,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(secondPayloadByte).To(Equal(byte(0))) // ... followed by the PING - frameParser := wire.NewFrameParser(packer.version) + frameParser := wire.NewFrameParser(false, packer.version) frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) @@ -654,7 +654,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(firstPayloadByte).To(Equal(byte(0))) // ... followed by the STREAM frame - frameParser := wire.NewFrameParser(packer.version) + frameParser := wire.NewFrameParser(true, packer.version) frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) @@ -1166,7 +1166,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(secondPayloadByte).To(Equal(byte(0))) // ... followed by the PING - frameParser := wire.NewFrameParser(packer.version) + frameParser := wire.NewFrameParser(false, packer.version) frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) diff --git a/server_test.go b/server_test.go index f95a5907..cc9902a5 100644 --- a/server_test.go +++ b/server_test.go @@ -521,7 +521,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) Expect(err).ToNot(HaveOccurred()) - f, err := wire.NewFrameParser(hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) + f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := f.(*wire.ConnectionCloseFrame) diff --git a/session.go b/session.go index 24b165a1..761e91ad 100644 --- a/session.go +++ b/session.go @@ -477,7 +477,7 @@ var newClientSession = func( func (s *session) preSetup() { s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue(s.version) - s.frameParser = wire.NewFrameParser(s.version) + s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version) s.rttStats = &utils.RTTStats{} s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.InitialMaxData, diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 376e1811..d59045d0 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -39,7 +39,7 @@ var _ = Describe("Streams Map (incoming)", func() { checkFrameSerialization := func(f wire.Frame) { b := &bytes.Buffer{} ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed()) - frame, err := wire.NewFrameParser(protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) + frame, err := wire.NewFrameParser(false, protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) ExpectWithOffset(1, err).ToNot(HaveOccurred()) Expect(f).To(Equal(frame)) }