diff --git a/stream.go b/stream.go index 479a6bfd..289964e3 100644 --- a/stream.go +++ b/stream.go @@ -14,6 +14,8 @@ type Stream struct { CurrentFrame *frames.StreamFrame ReadPosInFrame int WriteOffset uint64 + ReadOffset uint64 + frameQueue []*frames.StreamFrame // TODO: replace with heap } // NewStream creates a new Stream @@ -30,14 +32,9 @@ func (s *Stream) Read(p []byte) (int, error) { bytesRead := 0 for bytesRead < len(p) { if s.CurrentFrame == nil { - select { - case s.CurrentFrame = <-s.StreamFrames: - default: - if bytesRead == 0 { - s.CurrentFrame = <-s.StreamFrames - } else { - return bytesRead, nil - } + s.CurrentFrame = s.getNextFrameInOrder(bytesRead == 0) + if s.CurrentFrame == nil { + return bytesRead, nil } s.ReadPosInFrame = 0 } @@ -45,6 +42,7 @@ func (s *Stream) Read(p []byte) (int, error) { copy(p[bytesRead:], s.CurrentFrame.Data[s.ReadPosInFrame:]) s.ReadPosInFrame += m bytesRead += m + s.ReadOffset += uint64(m) if s.ReadPosInFrame >= len(s.CurrentFrame.Data) { s.CurrentFrame = nil } @@ -53,6 +51,44 @@ func (s *Stream) Read(p []byte) (int, error) { return bytesRead, nil } +func (s *Stream) getNextFrameInOrder(wait bool) *frames.StreamFrame { + // First, check the queue + for i, f := range s.frameQueue { + if f.Offset == s.ReadOffset { + // Move last element into position i + s.frameQueue[i] = s.frameQueue[len(s.frameQueue)-1] + s.frameQueue = s.frameQueue[:len(s.frameQueue)-1] + return f + } + } + + // TODO: Handle error and break while(true) loop + for { + var nextFrameFromChannel *frames.StreamFrame + if wait { + nextFrameFromChannel = <-s.StreamFrames + } else { + select { + case nextFrameFromChannel = <-s.StreamFrames: + default: + return nil + } + } + + if nextFrameFromChannel.Offset == s.ReadOffset { + return nextFrameFromChannel + } + + // Discard if we already know it + if nextFrameFromChannel.Offset < s.ReadOffset { + continue + } + + // Append to queue + s.frameQueue = append(s.frameQueue, nextFrameFromChannel) + } +} + // ReadByte implements io.ByteReader func (s *Stream) ReadByte() (byte, error) { // TODO: Optimize diff --git a/stream_test.go b/stream_test.go index 833b21b0..bef9a197 100644 --- a/stream_test.go +++ b/stream_test.go @@ -116,18 +116,135 @@ var _ = Describe("Stream", func() { Expect(n).To(Equal(2)) }) - PIt("rejects StreamFrames with wrong Offsets", func() { + It("handles StreamFrames in wrong order", func() { + frame1 := frames.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + frame2 := frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + stream := NewStream(nil, 1337) + stream.AddStreamFrame(&frame1) + stream.AddStreamFrame(&frame2) + b := make([]byte, 4) + n, err := stream.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("handles duplicate StreamFrames", func() { + frame1 := frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + frame2 := frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + frame3 := frames.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + } + stream := NewStream(nil, 1337) + stream.AddStreamFrame(&frame1) + stream.AddStreamFrame(&frame2) + stream.AddStreamFrame(&frame3) + b := make([]byte, 4) + n, err := stream.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + It("discards unneeded stream frames", func() { frame1 := frames.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, } frame2 := frames.StreamFrame{ Offset: 1, + Data: []byte{0x42, 0x24}, + } + frame3 := frames.StreamFrame{ + Offset: 2, Data: []byte{0xBE, 0xEF}, } stream := NewStream(nil, 1337) stream.AddStreamFrame(&frame1) - err := stream.AddStreamFrame(&frame2) - Expect(err).To(HaveOccurred()) + stream.AddStreamFrame(&frame2) + stream.AddStreamFrame(&frame3) + b := make([]byte, 4) + n, err := stream.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + }) + + Context("getting next stream frame", func() { + It("gets next frame", func() { + stream := NewStream(nil, 1337) + stream.AddStreamFrame(&frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + }) + f := stream.getNextFrameInOrder(true) + Expect(f.Data).To(Equal([]byte{0xDE, 0xAD})) + }) + + It("waits for next frame", func() { + stream := NewStream(nil, 1337) + var b bool + go func() { + time.Sleep(time.Millisecond) + b = true + stream.AddStreamFrame(&frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + }) + }() + f := stream.getNextFrameInOrder(true) + Expect(b).To(BeTrue()) + Expect(f.Data).To(Equal([]byte{0xDE, 0xAD})) + }) + + It("queues non-matching stream frames", func() { + stream := NewStream(nil, 1337) + var b bool + stream.AddStreamFrame(&frames.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + }) + go func() { + time.Sleep(time.Millisecond) + b = true + stream.AddStreamFrame(&frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + }) + }() + f := stream.getNextFrameInOrder(true) + Expect(b).To(BeTrue()) + Expect(f.Data).To(Equal([]byte{0xDE, 0xAD})) + stream.ReadOffset += 2 + f = stream.getNextFrameInOrder(true) + Expect(f.Data).To(Equal([]byte{0xBE, 0xEF})) + }) + + It("returns nil if non-blocking", func() { + stream := NewStream(nil, 1337) + Expect(stream.getNextFrameInOrder(false)).To(BeNil()) + }) + + It("returns properly if non-blocking", func() { + stream := NewStream(nil, 1337) + stream.AddStreamFrame(&frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + }) + Expect(stream.getNextFrameInOrder(false)).ToNot(BeNil()) + }) }) })