mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
182 lines
4.5 KiB
Go
182 lines
4.5 KiB
Go
package quic
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/lucas-clemente/quic-go/protocol"
|
|
"github.com/lucas-clemente/quic-go/qerr"
|
|
"github.com/lucas-clemente/quic-go/utils"
|
|
)
|
|
|
|
type streamsMap struct {
|
|
streams map[protocol.StreamID]*stream
|
|
openStreams []protocol.StreamID
|
|
|
|
highestStreamOpenedByClient protocol.StreamID
|
|
|
|
mutex sync.RWMutex
|
|
newStream newStreamLambda
|
|
maxNumStreams int
|
|
|
|
roundRobinIndex int
|
|
}
|
|
|
|
type streamLambda func(*stream) (bool, error)
|
|
type newStreamLambda func(protocol.StreamID) (*stream, error)
|
|
|
|
var (
|
|
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
|
)
|
|
|
|
func newStreamsMap(newStream newStreamLambda) *streamsMap {
|
|
maxNumStreams := utils.Max(int(float32(protocol.MaxIncomingDynamicStreams)*protocol.MaxStreamsMultiplier), int(protocol.MaxIncomingDynamicStreams))
|
|
|
|
return &streamsMap{
|
|
streams: map[protocol.StreamID]*stream{},
|
|
openStreams: make([]protocol.StreamID, 0, maxNumStreams),
|
|
newStream: newStream,
|
|
maxNumStreams: maxNumStreams,
|
|
}
|
|
}
|
|
|
|
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
|
|
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
|
|
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
|
m.mutex.RLock()
|
|
s, ok := m.streams[id]
|
|
m.mutex.RUnlock()
|
|
if ok {
|
|
return s, nil // s may be nil
|
|
}
|
|
// ... we don't have an existing stream, try opening a new one
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
|
|
s, ok = m.streams[id]
|
|
if ok {
|
|
return s, nil
|
|
}
|
|
if len(m.openStreams) == m.maxNumStreams {
|
|
return nil, qerr.TooManyOpenStreams
|
|
}
|
|
if id%2 == 0 {
|
|
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
|
|
}
|
|
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient {
|
|
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient))
|
|
}
|
|
|
|
s, err := m.newStream(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if id > m.highestStreamOpenedByClient {
|
|
m.highestStreamOpenedByClient = id
|
|
}
|
|
|
|
m.putStream(s)
|
|
return s, nil
|
|
}
|
|
|
|
// OpenStream opens a stream from the server's side
|
|
func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
|
|
panic("OpenStream: not implemented")
|
|
}
|
|
|
|
func (m *streamsMap) Iterate(fn streamLambda) error {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
for _, streamID := range m.openStreams {
|
|
str, ok := m.streams[streamID]
|
|
if !ok {
|
|
return errMapAccess
|
|
}
|
|
if str == nil {
|
|
return fmt.Errorf("BUG: Stream %d is closed, but still in openStreams map", streamID)
|
|
}
|
|
cont, err := fn(str)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !cont {
|
|
break
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
numStreams := len(m.openStreams)
|
|
startIndex := m.roundRobinIndex
|
|
|
|
for i := 0; i < numStreams; i++ {
|
|
streamID := m.openStreams[(i+startIndex)%numStreams]
|
|
str, ok := m.streams[streamID]
|
|
if !ok {
|
|
return errMapAccess
|
|
}
|
|
if str == nil {
|
|
return fmt.Errorf("BUG: Stream %d is closed, but still in openStreams map", streamID)
|
|
}
|
|
cont, err := fn(str)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !cont {
|
|
break
|
|
}
|
|
m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *streamsMap) putStream(s *stream) error {
|
|
id := s.StreamID()
|
|
if _, ok := m.streams[id]; ok {
|
|
return fmt.Errorf("a stream with ID %d already exists", id)
|
|
}
|
|
|
|
m.streams[id] = s
|
|
m.openStreams = append(m.openStreams, id)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Attention: this function must only be called if a mutex has been acquired previously
|
|
func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
|
|
s, ok := m.streams[id]
|
|
if !ok || s == nil {
|
|
return fmt.Errorf("attempted to remove non-existing stream: %d", id)
|
|
}
|
|
|
|
m.streams[id] = nil
|
|
|
|
for i, s := range m.openStreams {
|
|
if s == id {
|
|
// delete the streamID from the openStreams slice
|
|
m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])]
|
|
// adjust round-robin index, if necessary
|
|
if i < m.roundRobinIndex {
|
|
m.roundRobinIndex--
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// NumberOfStreams gets the number of open streams
|
|
func (m *streamsMap) NumberOfStreams() int {
|
|
m.mutex.RLock()
|
|
n := len(m.openStreams)
|
|
m.mutex.RUnlock()
|
|
return n
|
|
}
|