diff --git a/streams_map.go b/streams_map.go index 41d24615..d9ea8d63 100644 --- a/streams_map.go +++ b/streams_map.go @@ -1,6 +1,7 @@ package quic import ( + "errors" "fmt" "sync" @@ -11,10 +12,16 @@ type streamsMap struct { streams map[protocol.StreamID]*stream openStreams []protocol.StreamID mutex sync.RWMutex + + roundRobinIndex int } type streamLambda func(*stream) (bool, error) +var ( + errMapAccess = errors.New("streamsMap: Error accessing the streams map") +) + func newStreamsMap() *streamsMap { maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier) return &streamsMap{ @@ -49,6 +56,31 @@ func (m *streamsMap) Iterate(fn streamLambda) error { return nil } +func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + numStreams := len(m.openStreams) + startIndex := m.roundRobinIndex + + for i := 0; i < numStreams; i++ { + streamID := m.openStreams[(i+startIndex)%numStreams] + str, ok := m.streams[streamID] + if !ok { + return errMapAccess + } + cont, err := fn(str) + m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams + if err != nil { + return err + } + if !cont { + break + } + } + return nil +} + func (m *streamsMap) PutStream(s *stream) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -77,6 +109,11 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error { if s == id { // delete the streamID from the openStreams slice m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])] + // adjust round-robin index, if necessary + if i < m.roundRobinIndex { + m.roundRobinIndex-- + } + break } } diff --git a/streams_map_test.go b/streams_map_test.go index 4c752d19..31d60dc3 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -116,8 +116,8 @@ var _ = Describe("Streams Map", func() { }) }) - Context("Lambda", func() { - // create 5 streams, ids 1 to 3 + Context("Iterate", func() { + // create 3 streams, ids 1 to 3 BeforeEach(func() { for i := 1; i <= 3; i++ { err := m.PutStream(&stream{streamID: protocol.StreamID(i)}) @@ -166,4 +166,86 @@ var _ = Describe("Streams Map", func() { Expect(numIterations).To(Equal(1)) }) }) + + Context("RoundRobinIterate", func() { + // create 5 streams, ids 1 to 5 + var lambdaCalledForStream []protocol.StreamID + var numIterations int + + BeforeEach(func() { + lambdaCalledForStream = lambdaCalledForStream[:0] + numIterations = 0 + for i := 1; i <= 5; i++ { + err := m.PutStream(&stream{streamID: protocol.StreamID(i)}) + Expect(err).NotTo(HaveOccurred()) + } + }) + + It("executes the lambda exactly once for every stream", func() { + fn := func(str *stream) (bool, error) { + lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) + numIterations++ + return true, nil + } + err := m.RoundRobinIterate(fn) + Expect(err).ToNot(HaveOccurred()) + Expect(numIterations).To(Equal(5)) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5})) + Expect(m.roundRobinIndex).To(BeZero()) + }) + + It("goes around once when starting in the middle", func() { + fn := func(str *stream) (bool, error) { + lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) + numIterations++ + return true, nil + } + m.roundRobinIndex = 3 + err := m.RoundRobinIterate(fn) + Expect(err).ToNot(HaveOccurred()) + Expect(numIterations).To(Equal(5)) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 1, 2, 3})) + Expect(m.roundRobinIndex).To(Equal(3)) + }) + + It("picks up at the index where it last stopped", func() { + fn := func(str *stream) (bool, error) { + lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) + numIterations++ + if str.StreamID() == 2 || str.StreamID() == 4 { + return false, nil + } + return true, nil + } + err := m.RoundRobinIterate(fn) + Expect(err).ToNot(HaveOccurred()) + Expect(numIterations).To(Equal(2)) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 2})) + Expect(m.roundRobinIndex).To(Equal(2)) + numIterations = 0 + lambdaCalledForStream = lambdaCalledForStream[:0] + err = m.RoundRobinIterate(fn) + Expect(err).ToNot(HaveOccurred()) + Expect(numIterations).To(Equal(2)) + Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{3, 4})) + }) + + It("adjust the RoundRobinIndex when deleting an element in front", func() { + m.roundRobinIndex = 3 // stream 4 + m.RemoveStream(2) + Expect(m.roundRobinIndex).To(Equal(2)) + }) + + It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() { + m.roundRobinIndex = 1 // stream 2 + m.RemoveStream(4) + Expect(m.roundRobinIndex).To(Equal(1)) + }) + + It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() { + m.roundRobinIndex = 3 // stream 4 + m.RemoveStream(4) + Expect(m.roundRobinIndex).To(Equal(3)) + }) + }) })