mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 05:07:36 +03:00
parent
64b0e03234
commit
b0d116ad5a
2 changed files with 121 additions and 2 deletions
|
@ -1,6 +1,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -11,10 +12,16 @@ type streamsMap struct {
|
||||||
streams map[protocol.StreamID]*stream
|
streams map[protocol.StreamID]*stream
|
||||||
openStreams []protocol.StreamID
|
openStreams []protocol.StreamID
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
roundRobinIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamLambda func(*stream) (bool, error)
|
type streamLambda func(*stream) (bool, error)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
||||||
|
)
|
||||||
|
|
||||||
func newStreamsMap() *streamsMap {
|
func newStreamsMap() *streamsMap {
|
||||||
maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
|
maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
|
||||||
return &streamsMap{
|
return &streamsMap{
|
||||||
|
@ -49,6 +56,31 @@ func (m *streamsMap) Iterate(fn streamLambda) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
numStreams := len(m.openStreams)
|
||||||
|
startIndex := m.roundRobinIndex
|
||||||
|
|
||||||
|
for i := 0; i < numStreams; i++ {
|
||||||
|
streamID := m.openStreams[(i+startIndex)%numStreams]
|
||||||
|
str, ok := m.streams[streamID]
|
||||||
|
if !ok {
|
||||||
|
return errMapAccess
|
||||||
|
}
|
||||||
|
cont, err := fn(str)
|
||||||
|
m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !cont {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *streamsMap) PutStream(s *stream) error {
|
func (m *streamsMap) PutStream(s *stream) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
@ -77,6 +109,11 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
|
||||||
if s == id {
|
if s == id {
|
||||||
// delete the streamID from the openStreams slice
|
// delete the streamID from the openStreams slice
|
||||||
m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])]
|
m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])]
|
||||||
|
// adjust round-robin index, if necessary
|
||||||
|
if i < m.roundRobinIndex {
|
||||||
|
m.roundRobinIndex--
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -116,8 +116,8 @@ var _ = Describe("Streams Map", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("Lambda", func() {
|
Context("Iterate", func() {
|
||||||
// create 5 streams, ids 1 to 3
|
// create 3 streams, ids 1 to 3
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
for i := 1; i <= 3; i++ {
|
for i := 1; i <= 3; i++ {
|
||||||
err := m.PutStream(&stream{streamID: protocol.StreamID(i)})
|
err := m.PutStream(&stream{streamID: protocol.StreamID(i)})
|
||||||
|
@ -166,4 +166,86 @@ var _ = Describe("Streams Map", func() {
|
||||||
Expect(numIterations).To(Equal(1))
|
Expect(numIterations).To(Equal(1))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("RoundRobinIterate", func() {
|
||||||
|
// create 5 streams, ids 1 to 5
|
||||||
|
var lambdaCalledForStream []protocol.StreamID
|
||||||
|
var numIterations int
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
lambdaCalledForStream = lambdaCalledForStream[:0]
|
||||||
|
numIterations = 0
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
err := m.PutStream(&stream{streamID: protocol.StreamID(i)})
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("executes the lambda exactly once for every stream", func() {
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||||
|
numIterations++
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
err := m.RoundRobinIterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(numIterations).To(Equal(5))
|
||||||
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
|
||||||
|
Expect(m.roundRobinIndex).To(BeZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("goes around once when starting in the middle", func() {
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||||
|
numIterations++
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
m.roundRobinIndex = 3
|
||||||
|
err := m.RoundRobinIterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(numIterations).To(Equal(5))
|
||||||
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 1, 2, 3}))
|
||||||
|
Expect(m.roundRobinIndex).To(Equal(3))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("picks up at the index where it last stopped", func() {
|
||||||
|
fn := func(str *stream) (bool, error) {
|
||||||
|
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
|
||||||
|
numIterations++
|
||||||
|
if str.StreamID() == 2 || str.StreamID() == 4 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
err := m.RoundRobinIterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(numIterations).To(Equal(2))
|
||||||
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{1, 2}))
|
||||||
|
Expect(m.roundRobinIndex).To(Equal(2))
|
||||||
|
numIterations = 0
|
||||||
|
lambdaCalledForStream = lambdaCalledForStream[:0]
|
||||||
|
err = m.RoundRobinIterate(fn)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(numIterations).To(Equal(2))
|
||||||
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{3, 4}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("adjust the RoundRobinIndex when deleting an element in front", func() {
|
||||||
|
m.roundRobinIndex = 3 // stream 4
|
||||||
|
m.RemoveStream(2)
|
||||||
|
Expect(m.roundRobinIndex).To(Equal(2))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() {
|
||||||
|
m.roundRobinIndex = 1 // stream 2
|
||||||
|
m.RemoveStream(4)
|
||||||
|
Expect(m.roundRobinIndex).To(Equal(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() {
|
||||||
|
m.roundRobinIndex = 3 // stream 4
|
||||||
|
m.RemoveStream(4)
|
||||||
|
Expect(m.roundRobinIndex).To(Equal(3))
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue