implement a round-robin iterate function in StreamsMap

ref #207
This commit is contained in:
Marten Seemann 2016-08-06 14:13:52 +07:00
parent 64b0e03234
commit b0d116ad5a
2 changed files with 121 additions and 2 deletions

View file

@ -1,6 +1,7 @@
package quic
import (
"errors"
"fmt"
"sync"
@ -11,10 +12,16 @@ type streamsMap struct {
streams map[protocol.StreamID]*stream
openStreams []protocol.StreamID
mutex sync.RWMutex
roundRobinIndex int
}
type streamLambda func(*stream) (bool, error)
var (
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
)
func newStreamsMap() *streamsMap {
maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
return &streamsMap{
@ -49,6 +56,31 @@ func (m *streamsMap) Iterate(fn streamLambda) error {
return nil
}
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
m.mutex.Lock()
defer m.mutex.Unlock()
numStreams := len(m.openStreams)
startIndex := m.roundRobinIndex
for i := 0; i < numStreams; i++ {
streamID := m.openStreams[(i+startIndex)%numStreams]
str, ok := m.streams[streamID]
if !ok {
return errMapAccess
}
cont, err := fn(str)
m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams
if err != nil {
return err
}
if !cont {
break
}
}
return nil
}
func (m *streamsMap) PutStream(s *stream) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -77,6 +109,11 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
if s == id {
// delete the streamID from the openStreams slice
m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])]
// adjust round-robin index, if necessary
if i < m.roundRobinIndex {
m.roundRobinIndex--
}
break
}
}