From c69992cae42837b6543aed75feffe22014e85191 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 17 Apr 2016 10:46:31 +0700 Subject: [PATCH] parse RST_STREAM frames --- frames/rst_stream_frame.go | 50 +++++++++++++++++++++++++++++++++ frames/rst_stream_frame_test.go | 24 ++++++++++++++++ session.go | 11 +++++++- utils/utils.go | 31 ++++++++++++++++++++ utils/utils_test.go | 15 ++++++++++ 5 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 frames/rst_stream_frame.go create mode 100644 frames/rst_stream_frame_test.go diff --git a/frames/rst_stream_frame.go b/frames/rst_stream_frame.go new file mode 100644 index 00000000..d1b089ed --- /dev/null +++ b/frames/rst_stream_frame.go @@ -0,0 +1,50 @@ +package frames + +import ( + "bytes" + "errors" + + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/utils" +) + +// A RstStreamFrame in QUIC +type RstStreamFrame struct { + StreamID protocol.StreamID + ByteOffset uint64 + ErrorCode uint32 +} + +//Write writes a RST_STREAM frame +func (f *RstStreamFrame) Write(b *bytes.Buffer) error { + return errors.New("RstStreamFrame: Write not yet implemented") +} + +// ParseRstStreamFrame parses a RST_STREAM frame +func ParseRstStreamFrame(r *bytes.Reader) (*RstStreamFrame, error) { + frame := &RstStreamFrame{} + + // read the TypeByte + _, err := r.ReadByte() + if err != nil { + return nil, err + } + + sid, err := utils.ReadUint32(r) + if err != nil { + return nil, err + } + frame.StreamID = protocol.StreamID(sid) + + frame.ByteOffset, err = utils.ReadUint64(r) + if err != nil { + return nil, err + } + + frame.ErrorCode, err = utils.ReadUint32(r) + if err != nil { + return nil, err + } + + return frame, nil +} diff --git a/frames/rst_stream_frame_test.go b/frames/rst_stream_frame_test.go new file mode 100644 index 00000000..7bab6a54 --- /dev/null +++ b/frames/rst_stream_frame_test.go @@ -0,0 +1,24 @@ +package frames + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("RstStreamFrame", func() { + Context("rst stream frames", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x01, 0xEF, 0xBE, 0xAD, 0xDE, 0x44, 0x33, 0x22, 0x11, 0xAD, 0xFB, 0xCA, 0xDE, 0x34, 0x12, 0x37, 0x13}) + frame, err := ParseRstStreamFrame(b) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xDEADBEEF))) + Expect(frame.ByteOffset).To(Equal(uint64(0xDECAFBAD11223344))) + Expect(frame.ErrorCode).To(Equal(uint32(0x13371234))) + }) + }) + }) +}) diff --git a/session.go b/session.go index 52941119..ed3513b6 100644 --- a/session.go +++ b/session.go @@ -101,7 +101,7 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub case 0x0: // PAD return nil case 0x01: - err = errors.New("unimplemented: RST_STREAM") + err = s.handleRstStreamFrame(r) case 0x02: err = s.handleConnectionCloseFrame(r) case 0x03: @@ -198,6 +198,15 @@ func (s *Session) handleStopWaitingFrame(r *bytes.Reader, publicHeader *PublicHe return nil } +func (s *Session) handleRstStreamFrame(r *bytes.Reader) error { + frame, err := frames.ParseRstStreamFrame(r) + if err != nil { + return err + } + fmt.Printf("%#v\n", frame) + return nil +} + // SendFrames sends a number of frames to the client func (s *Session) SendFrames(frames []frames.Frame) error { var framesData bytes.Buffer diff --git a/utils/utils.go b/utils/utils.go index 26d3a093..d96d8370 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -18,6 +18,37 @@ func ReadUintN(b io.ByteReader, length uint8) (uint64, error) { return res, nil } +// ReadUint64 reads a uint64 +func ReadUint64(b io.ByteReader) (uint64, error) { + var b1, b2, b3, b4, b5, b6, b7, b8 uint8 + var err error + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b4, err = b.ReadByte(); err != nil { + return 0, err + } + if b5, err = b.ReadByte(); err != nil { + return 0, err + } + if b6, err = b.ReadByte(); err != nil { + return 0, err + } + if b7, err = b.ReadByte(); err != nil { + return 0, err + } + if b8, err = b.ReadByte(); err != nil { + return 0, err + } + return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil +} + // ReadUint32 reads a uint32 func ReadUint32(b io.ByteReader) (uint32, error) { var b1, b2, b3, b4 uint8 diff --git a/utils/utils_test.go b/utils/utils_test.go index 3446a563..d105e747 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -53,6 +53,21 @@ var _ = Describe("Utils", func() { }) }) + Context("ReadUint64", func() { + It("reads a little endian", func() { + b := []byte{0x12, 0x35, 0xAB, 0xFF, 0xEF, 0xBE, 0xAD, 0xDE} + val, err := ReadUint64(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(0xDEADBEEFFFAB3512))) + }) + + It("throws an error if less than 8 bytes are passed", func() { + b := []byte{0x13, 0x34, 0xEA, 0x00, 0x14, 0xAA} + _, err := ReadUint64(bytes.NewReader(b)) + Expect(err).To(HaveOccurred()) + }) + }) + Context("WriteUint16", func() { It("outputs 2 bytes", func() { b := &bytes.Buffer{}