implement GetOrOpenStream in streamsMap

This commit is contained in:
Lucas Clemente 2016-08-08 14:31:33 +02:00
parent 77580dbf96
commit 65663c3314
5 changed files with 105 additions and 12 deletions

View file

@ -6,27 +6,34 @@ import (
"sync"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
const (
maxNumStreams = int(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
)
type streamsMap struct {
streams map[protocol.StreamID]*stream
openStreams []protocol.StreamID
mutex sync.RWMutex
newStream newStreamLambda
roundRobinIndex int
}
type streamLambda func(*stream) (bool, error)
type newStreamLambda func(protocol.StreamID) (*stream, error)
var (
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
)
func newStreamsMap() *streamsMap {
maxNumStreams := uint32(float32(protocol.MaxStreamsPerConnection) * protocol.MaxStreamsMultiplier)
func newStreamsMap(newStream newStreamLambda) *streamsMap {
return &streamsMap{
streams: map[protocol.StreamID]*stream{},
openStreams: make([]protocol.StreamID, 0, maxNumStreams),
newStream: newStream,
}
}
@ -40,6 +47,33 @@ func (m *streamsMap) GetStream(id protocol.StreamID) (*stream, bool) {
return s, true
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
m.mutex.RLock()
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil // s may be nil
}
// ... we don't have an existing stream, try opening a new one
m.mutex.Lock()
defer m.mutex.Unlock()
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
s, ok = m.streams[id]
if ok {
return s, nil
}
if len(m.openStreams) == maxNumStreams {
return nil, qerr.TooManyOpenStreams
}
s, err := m.newStream(id)
if err != nil {
return nil, err
}
m.putStreamImpl(s)
return s, nil
}
func (m *streamsMap) Iterate(fn streamLambda) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -94,7 +128,10 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
func (m *streamsMap) PutStream(s *stream) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.putStreamImpl(s)
}
func (m *streamsMap) putStreamImpl(s *stream) error {
id := s.StreamID()
if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id)