mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
use stream numbers, not stream ids, in the stream maps
This commit is contained in:
parent
a8633a952c
commit
857e4ae9a9
12 changed files with 524 additions and 467 deletions
|
@ -1,19 +1,5 @@
|
|||
package protocol
|
||||
|
||||
// A StreamID in QUIC
|
||||
type StreamID int64
|
||||
|
||||
// InvalidPacketNumber is a stream ID that is invalid.
|
||||
// The first valid stream ID in QUIC is 0.
|
||||
const InvalidStreamID StreamID = -1
|
||||
|
||||
// StreamNum is the stream number
|
||||
type StreamNum int64
|
||||
|
||||
// MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames
|
||||
// and as the stream count in the transport parameters
|
||||
const MaxStreamCount StreamNum = 1 << 60
|
||||
|
||||
// StreamType encodes if this is a unidirectional or bidirectional stream
|
||||
type StreamType uint8
|
||||
|
||||
|
@ -24,6 +10,49 @@ const (
|
|||
StreamTypeBidi
|
||||
)
|
||||
|
||||
// InvalidPacketNumber is a stream ID that is invalid.
|
||||
// The first valid stream ID in QUIC is 0.
|
||||
const InvalidStreamID StreamID = -1
|
||||
|
||||
// StreamNum is the stream number
|
||||
type StreamNum int64
|
||||
|
||||
const (
|
||||
// InvalidStreamNum is an invalid stream number.
|
||||
InvalidStreamNum = -1
|
||||
// MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames
|
||||
// and as the stream count in the transport parameters
|
||||
MaxStreamCount StreamNum = 1 << 60
|
||||
)
|
||||
|
||||
// StreamID calculates the stream ID.
|
||||
func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID {
|
||||
if s == 0 {
|
||||
return InvalidStreamID
|
||||
}
|
||||
var first StreamID
|
||||
switch stype {
|
||||
case StreamTypeBidi:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 0
|
||||
case PerspectiveServer:
|
||||
first = 1
|
||||
}
|
||||
case StreamTypeUni:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 2
|
||||
case PerspectiveServer:
|
||||
first = 3
|
||||
}
|
||||
}
|
||||
return first + 4*StreamID(s-1)
|
||||
}
|
||||
|
||||
// A StreamID in QUIC
|
||||
type StreamID int64
|
||||
|
||||
// InitiatedBy says if the stream was initiated by the client or by the server
|
||||
func (s StreamID) InitiatedBy() Perspective {
|
||||
if s%2 == 0 {
|
||||
|
@ -45,34 +74,3 @@ func (s StreamID) Type() StreamType {
|
|||
func (s StreamID) StreamNum() StreamNum {
|
||||
return StreamNum(s/4) + 1
|
||||
}
|
||||
|
||||
// MaxStreamID is the highest stream ID that a peer is allowed to open,
|
||||
// when it is allowed to open numStreams.
|
||||
func MaxStreamID(stype StreamType, numStreams StreamNum, pers Perspective) StreamID {
|
||||
if numStreams == 0 {
|
||||
return InvalidStreamID
|
||||
}
|
||||
var first StreamID
|
||||
switch stype {
|
||||
case StreamTypeBidi:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 0
|
||||
case PerspectiveServer:
|
||||
first = 1
|
||||
}
|
||||
case StreamTypeUni:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 2
|
||||
case PerspectiveServer:
|
||||
first = 3
|
||||
}
|
||||
}
|
||||
return first + 4*StreamID(numStreams-1)
|
||||
}
|
||||
|
||||
// FirstStream returns the first valid stream ID
|
||||
func FirstStream(stype StreamType, pers Perspective) StreamID {
|
||||
return MaxStreamID(stype, 1, pers)
|
||||
}
|
||||
|
|
|
@ -24,13 +24,6 @@ var _ = Describe("Stream ID", func() {
|
|||
Expect(StreamID(7).Type()).To(Equal(StreamTypeUni))
|
||||
})
|
||||
|
||||
It("tells the first stream ID", func() {
|
||||
Expect(FirstStream(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(0)))
|
||||
Expect(FirstStream(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(1)))
|
||||
Expect(FirstStream(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(2)))
|
||||
Expect(FirstStream(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3)))
|
||||
})
|
||||
|
||||
It("tells the stream number", func() {
|
||||
Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1))
|
||||
Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1))
|
||||
|
@ -42,26 +35,26 @@ var _ = Describe("Stream ID", func() {
|
|||
Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3))
|
||||
})
|
||||
|
||||
Context("maximum stream IDs", func() {
|
||||
It("doesn't allow any", func() {
|
||||
Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveClient)).To(Equal(InvalidStreamID))
|
||||
Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveServer)).To(Equal(InvalidStreamID))
|
||||
Expect(MaxStreamID(StreamTypeUni, 0, PerspectiveClient)).To(Equal(InvalidStreamID))
|
||||
Expect(MaxStreamID(StreamTypeUni, 0, PerspectiveServer)).To(Equal(InvalidStreamID))
|
||||
Context("converting stream nums to stream IDs", func() {
|
||||
It("handles 0", func() {
|
||||
Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(InvalidStreamID))
|
||||
Expect(StreamNum(0).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(InvalidStreamID))
|
||||
Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(InvalidStreamID))
|
||||
Expect(StreamNum(0).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(InvalidStreamID))
|
||||
})
|
||||
|
||||
It("allows one", func() {
|
||||
Expect(MaxStreamID(StreamTypeBidi, 1, PerspectiveClient)).To(Equal(StreamID(0)))
|
||||
Expect(MaxStreamID(StreamTypeBidi, 1, PerspectiveServer)).To(Equal(StreamID(1)))
|
||||
Expect(MaxStreamID(StreamTypeUni, 1, PerspectiveClient)).To(Equal(StreamID(2)))
|
||||
Expect(MaxStreamID(StreamTypeUni, 1, PerspectiveServer)).To(Equal(StreamID(3)))
|
||||
It("handles the first", func() {
|
||||
Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(0)))
|
||||
Expect(StreamNum(1).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(1)))
|
||||
Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(2)))
|
||||
Expect(StreamNum(1).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3)))
|
||||
})
|
||||
|
||||
It("allows many", func() {
|
||||
Expect(MaxStreamID(StreamTypeBidi, 100, PerspectiveClient)).To(Equal(StreamID(396)))
|
||||
Expect(MaxStreamID(StreamTypeBidi, 100, PerspectiveServer)).To(Equal(StreamID(397)))
|
||||
Expect(MaxStreamID(StreamTypeUni, 100, PerspectiveClient)).To(Equal(StreamID(398)))
|
||||
Expect(MaxStreamID(StreamTypeUni, 100, PerspectiveServer)).To(Equal(StreamID(399)))
|
||||
It("handles others", func() {
|
||||
Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveClient)).To(Equal(StreamID(396)))
|
||||
Expect(StreamNum(100).StreamID(StreamTypeBidi, PerspectiveServer)).To(Equal(StreamID(397)))
|
||||
Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(398)))
|
||||
Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(399)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
134
streams_map.go
134
streams_map.go
|
@ -12,6 +12,27 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type streamError struct {
|
||||
message string
|
||||
nums []protocol.StreamNum
|
||||
}
|
||||
|
||||
func (e streamError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
|
||||
strError, ok := err.(streamError)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
ids := make([]interface{}, len(strError.nums))
|
||||
for i, num := range strError.nums {
|
||||
ids[i] = num.StreamID(stype, pers)
|
||||
}
|
||||
return fmt.Errorf(strError.Error(), ids...)
|
||||
}
|
||||
|
||||
type streamOpenErr struct{ error }
|
||||
|
||||
var _ net.Error = &streamOpenErr{}
|
||||
|
@ -49,112 +70,144 @@ func newStreamsMap(
|
|||
newFlowController: newFlowController,
|
||||
sender: sender,
|
||||
}
|
||||
newBidiStream := func(id protocol.StreamID) streamI {
|
||||
return newStream(id, m.sender, m.newFlowController(id), version)
|
||||
}
|
||||
newUniSendStream := func(id protocol.StreamID) sendStreamI {
|
||||
return newSendStream(id, m.sender, m.newFlowController(id), version)
|
||||
}
|
||||
newUniReceiveStream := func(id protocol.StreamID) receiveStreamI {
|
||||
return newReceiveStream(id, m.sender, m.newFlowController(id), version)
|
||||
}
|
||||
m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
|
||||
protocol.FirstStream(protocol.StreamTypeBidi, perspective),
|
||||
newBidiStream,
|
||||
func(num protocol.StreamNum) streamI {
|
||||
id := num.StreamID(protocol.StreamTypeBidi, perspective)
|
||||
return newStream(id, m.sender, m.newFlowController(id), version)
|
||||
},
|
||||
sender.queueControlFrame,
|
||||
)
|
||||
m.incomingBidiStreams = newIncomingBidiStreamsMap(
|
||||
protocol.FirstStream(protocol.StreamTypeBidi, perspective.Opposite()),
|
||||
protocol.MaxStreamID(protocol.StreamTypeBidi, protocol.StreamNum(maxIncomingBidiStreams), perspective.Opposite()),
|
||||
func(num protocol.StreamNum) streamI {
|
||||
id := num.StreamID(protocol.StreamTypeBidi, perspective.Opposite())
|
||||
return newStream(id, m.sender, m.newFlowController(id), version)
|
||||
},
|
||||
maxIncomingBidiStreams,
|
||||
sender.queueControlFrame,
|
||||
newBidiStream,
|
||||
)
|
||||
m.outgoingUniStreams = newOutgoingUniStreamsMap(
|
||||
protocol.FirstStream(protocol.StreamTypeUni, perspective),
|
||||
newUniSendStream,
|
||||
func(num protocol.StreamNum) sendStreamI {
|
||||
id := num.StreamID(protocol.StreamTypeUni, perspective)
|
||||
return newSendStream(id, m.sender, m.newFlowController(id), version)
|
||||
},
|
||||
sender.queueControlFrame,
|
||||
)
|
||||
m.incomingUniStreams = newIncomingUniStreamsMap(
|
||||
protocol.FirstStream(protocol.StreamTypeUni, perspective.Opposite()),
|
||||
protocol.MaxStreamID(protocol.StreamTypeUni, protocol.StreamNum(maxIncomingUniStreams), perspective.Opposite()),
|
||||
func(num protocol.StreamNum) receiveStreamI {
|
||||
id := num.StreamID(protocol.StreamTypeUni, perspective.Opposite())
|
||||
return newReceiveStream(id, m.sender, m.newFlowController(id), version)
|
||||
},
|
||||
maxIncomingUniStreams,
|
||||
sender.queueControlFrame,
|
||||
newUniReceiveStream,
|
||||
)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *streamsMap) OpenStream() (Stream, error) {
|
||||
return m.outgoingBidiStreams.OpenStream()
|
||||
str, err := m.outgoingBidiStreams.OpenStream()
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
|
||||
}
|
||||
|
||||
func (m *streamsMap) OpenStreamSync() (Stream, error) {
|
||||
return m.outgoingBidiStreams.OpenStreamSync()
|
||||
str, err := m.outgoingBidiStreams.OpenStreamSync()
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
|
||||
}
|
||||
|
||||
func (m *streamsMap) OpenUniStream() (SendStream, error) {
|
||||
return m.outgoingUniStreams.OpenStream()
|
||||
str, err := m.outgoingUniStreams.OpenStream()
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
|
||||
}
|
||||
|
||||
func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
|
||||
return m.outgoingUniStreams.OpenStreamSync()
|
||||
str, err := m.outgoingUniStreams.OpenStreamSync()
|
||||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
|
||||
}
|
||||
|
||||
func (m *streamsMap) AcceptStream() (Stream, error) {
|
||||
return m.incomingBidiStreams.AcceptStream()
|
||||
str, err := m.incomingBidiStreams.AcceptStream()
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
|
||||
}
|
||||
|
||||
func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) {
|
||||
return m.incomingUniStreams.AcceptStream()
|
||||
str, err := m.incomingUniStreams.AcceptStream()
|
||||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
|
||||
}
|
||||
|
||||
func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
num := id.StreamNum()
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
return m.outgoingUniStreams.DeleteStream(id)
|
||||
return m.outgoingUniStreams.DeleteStream(num)
|
||||
}
|
||||
return m.incomingUniStreams.DeleteStream(id)
|
||||
return m.incomingUniStreams.DeleteStream(num)
|
||||
case protocol.StreamTypeBidi:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
return m.outgoingBidiStreams.DeleteStream(id)
|
||||
return m.outgoingBidiStreams.DeleteStream(num)
|
||||
}
|
||||
return m.incomingBidiStreams.DeleteStream(id)
|
||||
return m.incomingBidiStreams.DeleteStream(num)
|
||||
}
|
||||
panic("")
|
||||
}
|
||||
|
||||
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
|
||||
str, err := m.getOrOpenReceiveStream(id)
|
||||
if err != nil {
|
||||
return nil, qerr.Error(qerr.StreamStateError, err.Error())
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
|
||||
num := id.StreamNum()
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
// an outgoing unidirectional stream is a send stream, not a receive stream
|
||||
return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
|
||||
}
|
||||
return m.incomingUniStreams.GetOrOpenStream(id)
|
||||
str, err := m.incomingUniStreams.GetOrOpenStream(num)
|
||||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
|
||||
case protocol.StreamTypeBidi:
|
||||
var str receiveStreamI
|
||||
var err error
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
return m.outgoingBidiStreams.GetStream(id)
|
||||
str, err = m.outgoingBidiStreams.GetStream(num)
|
||||
} else {
|
||||
str, err = m.incomingBidiStreams.GetOrOpenStream(num)
|
||||
}
|
||||
return m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
|
||||
}
|
||||
panic("")
|
||||
}
|
||||
|
||||
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
|
||||
str, err := m.getOrOpenSendStream(id)
|
||||
if err != nil {
|
||||
return nil, qerr.Error(qerr.StreamStateError, err.Error())
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
|
||||
num := id.StreamNum()
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
return m.outgoingUniStreams.GetStream(id)
|
||||
str, err := m.outgoingUniStreams.GetStream(num)
|
||||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
|
||||
}
|
||||
// an incoming unidirectional stream is a receive stream, not a send stream
|
||||
return nil, fmt.Errorf("peer attempted to open send stream %d", id)
|
||||
case protocol.StreamTypeBidi:
|
||||
var str sendStreamI
|
||||
var err error
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
return m.outgoingBidiStreams.GetStream(id)
|
||||
str, err = m.outgoingBidiStreams.GetStream(num)
|
||||
} else {
|
||||
str, err = m.incomingBidiStreams.GetOrOpenStream(num)
|
||||
}
|
||||
return m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
|
||||
}
|
||||
panic("")
|
||||
}
|
||||
|
@ -163,12 +216,11 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error {
|
|||
if f.MaxStreamNum > protocol.MaxStreamCount {
|
||||
return qerr.StreamLimitError
|
||||
}
|
||||
id := protocol.MaxStreamID(f.Type, f.MaxStreamNum, m.perspective)
|
||||
switch id.Type() {
|
||||
switch f.Type {
|
||||
case protocol.StreamTypeUni:
|
||||
m.outgoingUniStreams.SetMaxStream(id)
|
||||
m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
|
||||
case protocol.StreamTypeBidi:
|
||||
m.outgoingBidiStreams.SetMaxStream(id)
|
||||
m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -179,8 +231,8 @@ func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) error {
|
|||
return qerr.StreamLimitError
|
||||
}
|
||||
// Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open.
|
||||
m.outgoingBidiStreams.SetMaxStream(protocol.MaxStreamID(protocol.StreamTypeBidi, p.MaxBidiStreamNum, m.perspective))
|
||||
m.outgoingUniStreams.SetMaxStream(protocol.MaxStreamID(protocol.StreamTypeUni, p.MaxUniStreamNum, m.perspective))
|
||||
m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
|
||||
m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -16,38 +15,39 @@ type incomingBidiStreamsMap struct {
|
|||
mutex sync.RWMutex
|
||||
cond sync.Cond
|
||||
|
||||
streams map[protocol.StreamID]streamI
|
||||
streams map[protocol.StreamNum]streamI
|
||||
// When a stream is deleted before it was accepted, we can't delete it immediately.
|
||||
// We need to wait until the application accepts it, and delete it immediately then.
|
||||
streamsToDelete map[protocol.StreamID]struct{} // used as a set
|
||||
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
|
||||
|
||||
nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream()
|
||||
nextStreamToOpen protocol.StreamID // the highest stream that the peer openend
|
||||
maxStream protocol.StreamID // the highest stream that the peer is allowed to open
|
||||
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
|
||||
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
|
||||
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
|
||||
maxNumStreams uint64 // maximum number of streams
|
||||
|
||||
newStream func(protocol.StreamID) streamI
|
||||
newStream func(protocol.StreamNum) streamI
|
||||
queueMaxStreamID func(*wire.MaxStreamsFrame)
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
|
||||
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func newIncomingBidiStreamsMap(
|
||||
nextStreamToAccept protocol.StreamID,
|
||||
initialMaxStreamID protocol.StreamID,
|
||||
maxNumStreams uint64,
|
||||
newStream func(protocol.StreamNum) streamI,
|
||||
maxStreams uint64,
|
||||
queueControlFrame func(wire.Frame),
|
||||
newStream func(protocol.StreamID) streamI,
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID,
|
||||
) *incomingBidiStreamsMap {
|
||||
m := &incomingBidiStreamsMap{
|
||||
streams: make(map[protocol.StreamID]streamI),
|
||||
streamsToDelete: make(map[protocol.StreamID]struct{}),
|
||||
nextStreamToAccept: nextStreamToAccept,
|
||||
nextStreamToOpen: nextStreamToAccept,
|
||||
maxStream: initialMaxStreamID,
|
||||
maxNumStreams: maxNumStreams,
|
||||
streams: make(map[protocol.StreamNum]streamI),
|
||||
streamsToDelete: make(map[protocol.StreamNum]struct{}),
|
||||
maxStream: protocol.StreamNum(maxStreams),
|
||||
maxNumStreams: maxStreams,
|
||||
newStream: newStream,
|
||||
nextStreamToOpen: 1,
|
||||
nextStreamToAccept: 1,
|
||||
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
|
||||
// streamNumToID: streamNumToID,
|
||||
}
|
||||
m.cond.L = &m.mutex
|
||||
return m
|
||||
|
@ -57,45 +57,48 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) {
|
|||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var id protocol.StreamID
|
||||
var num protocol.StreamNum
|
||||
var str streamI
|
||||
for {
|
||||
id = m.nextStreamToAccept
|
||||
num = m.nextStreamToAccept
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
str, ok = m.streams[id]
|
||||
str, ok = m.streams[num]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.cond.Wait()
|
||||
}
|
||||
m.nextStreamToAccept += 4
|
||||
m.nextStreamToAccept++
|
||||
// If this stream was completed before being accepted, we can delete it now.
|
||||
if _, ok := m.streamsToDelete[id]; ok {
|
||||
delete(m.streamsToDelete, id)
|
||||
if err := m.deleteStream(id); err != nil {
|
||||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
delete(m.streamsToDelete, num)
|
||||
if err := m.deleteStream(num); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
|
||||
func (m *incomingBidiStreamsMap) GetOrOpenStream(num protocol.StreamNum) (streamI, error) {
|
||||
m.mutex.RLock()
|
||||
if id > m.maxStream {
|
||||
if num > m.maxStream {
|
||||
m.mutex.RUnlock()
|
||||
return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream)
|
||||
return nil, streamError{
|
||||
message: "peer tried to open stream %d (current limit: %d)",
|
||||
nums: []protocol.StreamNum{num, m.maxStream},
|
||||
}
|
||||
// if the id is smaller than the highest we accepted
|
||||
}
|
||||
// if the num is smaller than the highest we accepted
|
||||
// * this stream exists in the map, and we can return it, or
|
||||
// * this stream was already closed, then we can return the nil
|
||||
if id < m.nextStreamToOpen {
|
||||
if num < m.nextStreamToOpen {
|
||||
var s streamI
|
||||
// If the stream was already queued for deletion, and is just waiting to be accepted, don't return it.
|
||||
if _, ok := m.streamsToDelete[id]; !ok {
|
||||
s = m.streams[id]
|
||||
if _, ok := m.streamsToDelete[num]; !ok {
|
||||
s = m.streams[num]
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
return s, nil
|
||||
|
@ -106,46 +109,52 @@ func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI,
|
|||
// no need to check the two error conditions from above again
|
||||
// * maxStream can only increase, so if the id was valid before, it definitely is valid now
|
||||
// * highestStream is only modified by this function
|
||||
for newID := m.nextStreamToOpen; newID <= id; newID += 4 {
|
||||
m.streams[newID] = m.newStream(newID)
|
||||
for newNum := m.nextStreamToOpen; newNum <= num; newNum++ {
|
||||
m.streams[newNum] = m.newStream(newNum)
|
||||
m.cond.Signal()
|
||||
}
|
||||
m.nextStreamToOpen = id + 4
|
||||
s := m.streams[id]
|
||||
m.nextStreamToOpen = num + 1
|
||||
s := m.streams[num]
|
||||
m.mutex.Unlock()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
func (m *incomingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.deleteStream(id)
|
||||
return m.deleteStream(num)
|
||||
}
|
||||
|
||||
func (m *incomingBidiStreamsMap) deleteStream(id protocol.StreamID) error {
|
||||
if _, ok := m.streams[id]; !ok {
|
||||
return fmt.Errorf("Tried to delete unknown stream %d", id)
|
||||
func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error {
|
||||
if _, ok := m.streams[num]; !ok {
|
||||
return streamError{
|
||||
message: "Tried to delete unknown stream %d",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
}
|
||||
|
||||
// Don't delete this stream yet, if it was not yet accepted.
|
||||
// Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted.
|
||||
if id >= m.nextStreamToAccept {
|
||||
if _, ok := m.streamsToDelete[id]; ok {
|
||||
return fmt.Errorf("Tried to delete stream %d multiple times", id)
|
||||
if num >= m.nextStreamToAccept {
|
||||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
return streamError{
|
||||
message: "Tried to delete stream %d multiple times",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
m.streamsToDelete[id] = struct{}{}
|
||||
}
|
||||
m.streamsToDelete[num] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(m.streams, id)
|
||||
delete(m.streams, num)
|
||||
// queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream
|
||||
if m.maxNumStreams > uint64(len(m.streams)) {
|
||||
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
|
||||
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
|
||||
m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1
|
||||
m.queueMaxStreamID(&wire.MaxStreamsFrame{
|
||||
Type: protocol.StreamTypeBidi,
|
||||
MaxStreamNum: m.maxStream.StreamNum(),
|
||||
MaxStreamNum: m.maxStream,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -14,38 +13,39 @@ type incomingItemsMap struct {
|
|||
mutex sync.RWMutex
|
||||
cond sync.Cond
|
||||
|
||||
streams map[protocol.StreamID]item
|
||||
streams map[protocol.StreamNum]item
|
||||
// When a stream is deleted before it was accepted, we can't delete it immediately.
|
||||
// We need to wait until the application accepts it, and delete it immediately then.
|
||||
streamsToDelete map[protocol.StreamID]struct{} // used as a set
|
||||
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
|
||||
|
||||
nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream()
|
||||
nextStreamToOpen protocol.StreamID // the highest stream that the peer openend
|
||||
maxStream protocol.StreamID // the highest stream that the peer is allowed to open
|
||||
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
|
||||
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
|
||||
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
|
||||
maxNumStreams uint64 // maximum number of streams
|
||||
|
||||
newStream func(protocol.StreamID) item
|
||||
newStream func(protocol.StreamNum) item
|
||||
queueMaxStreamID func(*wire.MaxStreamsFrame)
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
|
||||
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func newIncomingItemsMap(
|
||||
nextStreamToAccept protocol.StreamID,
|
||||
initialMaxStreamID protocol.StreamID,
|
||||
maxNumStreams uint64,
|
||||
newStream func(protocol.StreamNum) item,
|
||||
maxStreams uint64,
|
||||
queueControlFrame func(wire.Frame),
|
||||
newStream func(protocol.StreamID) item,
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID,
|
||||
) *incomingItemsMap {
|
||||
m := &incomingItemsMap{
|
||||
streams: make(map[protocol.StreamID]item),
|
||||
streamsToDelete: make(map[protocol.StreamID]struct{}),
|
||||
nextStreamToAccept: nextStreamToAccept,
|
||||
nextStreamToOpen: nextStreamToAccept,
|
||||
maxStream: initialMaxStreamID,
|
||||
maxNumStreams: maxNumStreams,
|
||||
streams: make(map[protocol.StreamNum]item),
|
||||
streamsToDelete: make(map[protocol.StreamNum]struct{}),
|
||||
maxStream: protocol.StreamNum(maxStreams),
|
||||
maxNumStreams: maxStreams,
|
||||
newStream: newStream,
|
||||
nextStreamToOpen: 1,
|
||||
nextStreamToAccept: 1,
|
||||
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
|
||||
// streamNumToID: streamNumToID,
|
||||
}
|
||||
m.cond.L = &m.mutex
|
||||
return m
|
||||
|
@ -55,45 +55,48 @@ func (m *incomingItemsMap) AcceptStream() (item, error) {
|
|||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var id protocol.StreamID
|
||||
var num protocol.StreamNum
|
||||
var str item
|
||||
for {
|
||||
id = m.nextStreamToAccept
|
||||
num = m.nextStreamToAccept
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
str, ok = m.streams[id]
|
||||
str, ok = m.streams[num]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.cond.Wait()
|
||||
}
|
||||
m.nextStreamToAccept += 4
|
||||
m.nextStreamToAccept++
|
||||
// If this stream was completed before being accepted, we can delete it now.
|
||||
if _, ok := m.streamsToDelete[id]; ok {
|
||||
delete(m.streamsToDelete, id)
|
||||
if err := m.deleteStream(id); err != nil {
|
||||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
delete(m.streamsToDelete, num)
|
||||
if err := m.deleteStream(num); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) {
|
||||
func (m *incomingItemsMap) GetOrOpenStream(num protocol.StreamNum) (item, error) {
|
||||
m.mutex.RLock()
|
||||
if id > m.maxStream {
|
||||
if num > m.maxStream {
|
||||
m.mutex.RUnlock()
|
||||
return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream)
|
||||
return nil, streamError{
|
||||
message: "peer tried to open stream %d (current limit: %d)",
|
||||
nums: []protocol.StreamNum{num, m.maxStream},
|
||||
}
|
||||
// if the id is smaller than the highest we accepted
|
||||
}
|
||||
// if the num is smaller than the highest we accepted
|
||||
// * this stream exists in the map, and we can return it, or
|
||||
// * this stream was already closed, then we can return the nil
|
||||
if id < m.nextStreamToOpen {
|
||||
if num < m.nextStreamToOpen {
|
||||
var s item
|
||||
// If the stream was already queued for deletion, and is just waiting to be accepted, don't return it.
|
||||
if _, ok := m.streamsToDelete[id]; !ok {
|
||||
s = m.streams[id]
|
||||
if _, ok := m.streamsToDelete[num]; !ok {
|
||||
s = m.streams[num]
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
return s, nil
|
||||
|
@ -104,46 +107,52 @@ func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) {
|
|||
// no need to check the two error conditions from above again
|
||||
// * maxStream can only increase, so if the id was valid before, it definitely is valid now
|
||||
// * highestStream is only modified by this function
|
||||
for newID := m.nextStreamToOpen; newID <= id; newID += 4 {
|
||||
m.streams[newID] = m.newStream(newID)
|
||||
for newNum := m.nextStreamToOpen; newNum <= num; newNum++ {
|
||||
m.streams[newNum] = m.newStream(newNum)
|
||||
m.cond.Signal()
|
||||
}
|
||||
m.nextStreamToOpen = id + 4
|
||||
s := m.streams[id]
|
||||
m.nextStreamToOpen = num + 1
|
||||
s := m.streams[num]
|
||||
m.mutex.Unlock()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error {
|
||||
func (m *incomingItemsMap) DeleteStream(num protocol.StreamNum) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.deleteStream(id)
|
||||
return m.deleteStream(num)
|
||||
}
|
||||
|
||||
func (m *incomingItemsMap) deleteStream(id protocol.StreamID) error {
|
||||
if _, ok := m.streams[id]; !ok {
|
||||
return fmt.Errorf("Tried to delete unknown stream %d", id)
|
||||
func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error {
|
||||
if _, ok := m.streams[num]; !ok {
|
||||
return streamError{
|
||||
message: "Tried to delete unknown stream %d",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
}
|
||||
|
||||
// Don't delete this stream yet, if it was not yet accepted.
|
||||
// Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted.
|
||||
if id >= m.nextStreamToAccept {
|
||||
if _, ok := m.streamsToDelete[id]; ok {
|
||||
return fmt.Errorf("Tried to delete stream %d multiple times", id)
|
||||
if num >= m.nextStreamToAccept {
|
||||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
return streamError{
|
||||
message: "Tried to delete stream %d multiple times",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
m.streamsToDelete[id] = struct{}{}
|
||||
}
|
||||
m.streamsToDelete[num] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(m.streams, id)
|
||||
delete(m.streams, num)
|
||||
// queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream
|
||||
if m.maxNumStreams > uint64(len(m.streams)) {
|
||||
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
|
||||
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
|
||||
m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1
|
||||
m.queueMaxStreamID(&wire.MaxStreamsFrame{
|
||||
Type: streamTypeGeneric,
|
||||
MaxStreamNum: m.maxStream.StreamNum(),
|
||||
MaxStreamNum: m.maxStream,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -2,7 +2,6 @@ package quic
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -13,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
type mockGenericStream struct {
|
||||
id protocol.StreamID
|
||||
num protocol.StreamNum
|
||||
|
||||
closed bool
|
||||
closeErr error
|
||||
|
@ -26,64 +25,65 @@ func (s *mockGenericStream) closeForShutdown(err error) {
|
|||
|
||||
var _ = Describe("Streams Map (incoming)", func() {
|
||||
const (
|
||||
firstNewStream protocol.StreamID = 2
|
||||
maxNumStreams uint64 = 5
|
||||
initialMaxStream protocol.StreamID = firstNewStream + 4*protocol.StreamID(maxNumStreams-1)
|
||||
)
|
||||
|
||||
var (
|
||||
m *incomingItemsMap
|
||||
newItem func(id protocol.StreamID) item
|
||||
newItemCounter int
|
||||
mockSender *MockStreamSender
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
newItemCounter = 0
|
||||
newItem = func(id protocol.StreamID) item {
|
||||
newItemCounter++
|
||||
return &mockGenericStream{id: id}
|
||||
}
|
||||
mockSender = NewMockStreamSender(mockCtrl)
|
||||
m = newIncomingItemsMap(firstNewStream, initialMaxStream, maxNumStreams, mockSender.queueControlFrame, newItem)
|
||||
m = newIncomingItemsMap(
|
||||
func(num protocol.StreamNum) item {
|
||||
newItemCounter++
|
||||
return &mockGenericStream{num: num}
|
||||
},
|
||||
maxNumStreams,
|
||||
mockSender.queueControlFrame,
|
||||
)
|
||||
})
|
||||
|
||||
It("opens all streams up to the id on GetOrOpenStream", func() {
|
||||
_, err := m.GetOrOpenStream(firstNewStream + 4*4)
|
||||
_, err := m.GetOrOpenStream(4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(newItemCounter).To(Equal(5))
|
||||
Expect(newItemCounter).To(Equal(4))
|
||||
})
|
||||
|
||||
It("starts opening streams at the right position", func() {
|
||||
// like the test above, but with 2 calls to GetOrOpenStream
|
||||
_, err := m.GetOrOpenStream(firstNewStream + 4)
|
||||
_, err := m.GetOrOpenStream(2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(newItemCounter).To(Equal(2))
|
||||
_, err = m.GetOrOpenStream(firstNewStream + 4*4)
|
||||
_, err = m.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(newItemCounter).To(Equal(5))
|
||||
})
|
||||
|
||||
It("accepts streams in the right order", func() {
|
||||
_, err := m.GetOrOpenStream(firstNewStream + 4) // open stream 20 and 24
|
||||
_, err := m.GetOrOpenStream(2) // open streams 1 and 2
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
str, err = m.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||
})
|
||||
|
||||
It("allows opening the maximum stream ID", func() {
|
||||
str, err := m.GetOrOpenStream(initialMaxStream)
|
||||
str, err := m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(initialMaxStream))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
})
|
||||
|
||||
It("errors when trying to get a stream ID higher than the maximum", func() {
|
||||
_, err := m.GetOrOpenStream(initialMaxStream + 4)
|
||||
Expect(err).To(MatchError(fmt.Errorf("peer tried to open stream %d (current limit: %d)", initialMaxStream+4, initialMaxStream)))
|
||||
_, err := m.GetOrOpenStream(6)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)"))
|
||||
})
|
||||
|
||||
It("blocks AcceptStream until a new stream is available", func() {
|
||||
|
@ -95,30 +95,12 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
strChan <- str
|
||||
}()
|
||||
Consistently(strChan).ShouldNot(Receive())
|
||||
str, err := m.GetOrOpenStream(firstNewStream)
|
||||
str, err := m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
var acceptedStr item
|
||||
Eventually(strChan).Should(Receive(&acceptedStr))
|
||||
Expect(acceptedStr.(*mockGenericStream).id).To(Equal(firstNewStream))
|
||||
})
|
||||
|
||||
It("works with stream 0", func() {
|
||||
m = newIncomingItemsMap(0, 1000, 1000, mockSender.queueControlFrame, newItem)
|
||||
strChan := make(chan item)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := m.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strChan <- str
|
||||
}()
|
||||
Consistently(strChan).ShouldNot(Receive())
|
||||
str, err := m.GetOrOpenStream(0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(BeZero())
|
||||
var acceptedStr item
|
||||
Eventually(strChan).Should(Receive(&acceptedStr))
|
||||
Expect(acceptedStr.(*mockGenericStream).id).To(BeZero())
|
||||
Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
})
|
||||
|
||||
It("unblocks AcceptStream when it is closed", func() {
|
||||
|
@ -143,9 +125,9 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
})
|
||||
|
||||
It("closes all streams when CloseWithError is called", func() {
|
||||
str1, err := m.GetOrOpenStream(firstNewStream)
|
||||
str1, err := m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str2, err := m.GetOrOpenStream(firstNewStream + 8)
|
||||
str2, err := m.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
testErr := errors.New("test err")
|
||||
m.CloseWithError(testErr)
|
||||
|
@ -157,37 +139,37 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
|
||||
It("deletes streams", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
_, err := m.GetOrOpenStream(firstNewStream)
|
||||
_, err := m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream))
|
||||
Expect(m.DeleteStream(firstNewStream)).To(Succeed())
|
||||
str, err = m.GetOrOpenStream(firstNewStream)
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
Expect(m.DeleteStream(1)).To(Succeed())
|
||||
str, err = m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
})
|
||||
|
||||
It("waits until a stream is accepted before actually deleting it", func() {
|
||||
_, err := m.GetOrOpenStream(firstNewStream + 4)
|
||||
_, err := m.GetOrOpenStream(2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.DeleteStream(firstNewStream + 4)).To(Succeed())
|
||||
Expect(m.DeleteStream(2)).To(Succeed())
|
||||
str, err := m.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
str, err = m.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||
})
|
||||
|
||||
It("doesn't return a stream queued for deleting from GetOrOpenStream", func() {
|
||||
str, err := m.GetOrOpenStream(firstNewStream)
|
||||
str, err := m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
Expect(m.DeleteStream(firstNewStream)).To(Succeed())
|
||||
str, err = m.GetOrOpenStream(firstNewStream)
|
||||
Expect(m.DeleteStream(1)).To(Succeed())
|
||||
str, err = m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
||||
|
@ -199,12 +181,13 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
|
||||
It("errors when deleting a non-existing stream", func() {
|
||||
err := m.DeleteStream(1337)
|
||||
Expect(err).To(MatchError("Tried to delete unknown stream 1337"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown stream 1337"))
|
||||
})
|
||||
|
||||
It("sends MAX_STREAMS frames when streams are deleted", func() {
|
||||
// open a bunch of streams
|
||||
_, err := m.GetOrOpenStream(firstNewStream + 4*4)
|
||||
_, err := m.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// accept all streams
|
||||
for i := 0; i < 5; i++ {
|
||||
|
@ -214,10 +197,10 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||
Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1)))
|
||||
})
|
||||
Expect(m.DeleteStream(firstNewStream + 2*4)).To(Succeed())
|
||||
Expect(m.DeleteStream(3)).To(Succeed())
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||
Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2)))
|
||||
})
|
||||
Expect(m.DeleteStream(firstNewStream + 3*4)).To(Succeed())
|
||||
Expect(m.DeleteStream(4)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -16,38 +15,39 @@ type incomingUniStreamsMap struct {
|
|||
mutex sync.RWMutex
|
||||
cond sync.Cond
|
||||
|
||||
streams map[protocol.StreamID]receiveStreamI
|
||||
streams map[protocol.StreamNum]receiveStreamI
|
||||
// When a stream is deleted before it was accepted, we can't delete it immediately.
|
||||
// We need to wait until the application accepts it, and delete it immediately then.
|
||||
streamsToDelete map[protocol.StreamID]struct{} // used as a set
|
||||
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
|
||||
|
||||
nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream()
|
||||
nextStreamToOpen protocol.StreamID // the highest stream that the peer openend
|
||||
maxStream protocol.StreamID // the highest stream that the peer is allowed to open
|
||||
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
|
||||
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
|
||||
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
|
||||
maxNumStreams uint64 // maximum number of streams
|
||||
|
||||
newStream func(protocol.StreamID) receiveStreamI
|
||||
newStream func(protocol.StreamNum) receiveStreamI
|
||||
queueMaxStreamID func(*wire.MaxStreamsFrame)
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
|
||||
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func newIncomingUniStreamsMap(
|
||||
nextStreamToAccept protocol.StreamID,
|
||||
initialMaxStreamID protocol.StreamID,
|
||||
maxNumStreams uint64,
|
||||
newStream func(protocol.StreamNum) receiveStreamI,
|
||||
maxStreams uint64,
|
||||
queueControlFrame func(wire.Frame),
|
||||
newStream func(protocol.StreamID) receiveStreamI,
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID,
|
||||
) *incomingUniStreamsMap {
|
||||
m := &incomingUniStreamsMap{
|
||||
streams: make(map[protocol.StreamID]receiveStreamI),
|
||||
streamsToDelete: make(map[protocol.StreamID]struct{}),
|
||||
nextStreamToAccept: nextStreamToAccept,
|
||||
nextStreamToOpen: nextStreamToAccept,
|
||||
maxStream: initialMaxStreamID,
|
||||
maxNumStreams: maxNumStreams,
|
||||
streams: make(map[protocol.StreamNum]receiveStreamI),
|
||||
streamsToDelete: make(map[protocol.StreamNum]struct{}),
|
||||
maxStream: protocol.StreamNum(maxStreams),
|
||||
maxNumStreams: maxStreams,
|
||||
newStream: newStream,
|
||||
nextStreamToOpen: 1,
|
||||
nextStreamToAccept: 1,
|
||||
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
|
||||
// streamNumToID: streamNumToID,
|
||||
}
|
||||
m.cond.L = &m.mutex
|
||||
return m
|
||||
|
@ -57,45 +57,48 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) {
|
|||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var id protocol.StreamID
|
||||
var num protocol.StreamNum
|
||||
var str receiveStreamI
|
||||
for {
|
||||
id = m.nextStreamToAccept
|
||||
num = m.nextStreamToAccept
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
str, ok = m.streams[id]
|
||||
str, ok = m.streams[num]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.cond.Wait()
|
||||
}
|
||||
m.nextStreamToAccept += 4
|
||||
m.nextStreamToAccept++
|
||||
// If this stream was completed before being accepted, we can delete it now.
|
||||
if _, ok := m.streamsToDelete[id]; ok {
|
||||
delete(m.streamsToDelete, id)
|
||||
if err := m.deleteStream(id); err != nil {
|
||||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
delete(m.streamsToDelete, num)
|
||||
if err := m.deleteStream(num); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveStreamI, error) {
|
||||
func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receiveStreamI, error) {
|
||||
m.mutex.RLock()
|
||||
if id > m.maxStream {
|
||||
if num > m.maxStream {
|
||||
m.mutex.RUnlock()
|
||||
return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream)
|
||||
return nil, streamError{
|
||||
message: "peer tried to open stream %d (current limit: %d)",
|
||||
nums: []protocol.StreamNum{num, m.maxStream},
|
||||
}
|
||||
// if the id is smaller than the highest we accepted
|
||||
}
|
||||
// if the num is smaller than the highest we accepted
|
||||
// * this stream exists in the map, and we can return it, or
|
||||
// * this stream was already closed, then we can return the nil
|
||||
if id < m.nextStreamToOpen {
|
||||
if num < m.nextStreamToOpen {
|
||||
var s receiveStreamI
|
||||
// If the stream was already queued for deletion, and is just waiting to be accepted, don't return it.
|
||||
if _, ok := m.streamsToDelete[id]; !ok {
|
||||
s = m.streams[id]
|
||||
if _, ok := m.streamsToDelete[num]; !ok {
|
||||
s = m.streams[num]
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
return s, nil
|
||||
|
@ -106,46 +109,52 @@ func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveSt
|
|||
// no need to check the two error conditions from above again
|
||||
// * maxStream can only increase, so if the id was valid before, it definitely is valid now
|
||||
// * highestStream is only modified by this function
|
||||
for newID := m.nextStreamToOpen; newID <= id; newID += 4 {
|
||||
m.streams[newID] = m.newStream(newID)
|
||||
for newNum := m.nextStreamToOpen; newNum <= num; newNum++ {
|
||||
m.streams[newNum] = m.newStream(newNum)
|
||||
m.cond.Signal()
|
||||
}
|
||||
m.nextStreamToOpen = id + 4
|
||||
s := m.streams[id]
|
||||
m.nextStreamToOpen = num + 1
|
||||
s := m.streams[num]
|
||||
m.mutex.Unlock()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
func (m *incomingUniStreamsMap) DeleteStream(num protocol.StreamNum) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.deleteStream(id)
|
||||
return m.deleteStream(num)
|
||||
}
|
||||
|
||||
func (m *incomingUniStreamsMap) deleteStream(id protocol.StreamID) error {
|
||||
if _, ok := m.streams[id]; !ok {
|
||||
return fmt.Errorf("Tried to delete unknown stream %d", id)
|
||||
func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error {
|
||||
if _, ok := m.streams[num]; !ok {
|
||||
return streamError{
|
||||
message: "Tried to delete unknown stream %d",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
}
|
||||
|
||||
// Don't delete this stream yet, if it was not yet accepted.
|
||||
// Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted.
|
||||
if id >= m.nextStreamToAccept {
|
||||
if _, ok := m.streamsToDelete[id]; ok {
|
||||
return fmt.Errorf("Tried to delete stream %d multiple times", id)
|
||||
if num >= m.nextStreamToAccept {
|
||||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
return streamError{
|
||||
message: "Tried to delete stream %d multiple times",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
m.streamsToDelete[id] = struct{}{}
|
||||
}
|
||||
m.streamsToDelete[num] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(m.streams, id)
|
||||
delete(m.streams, num)
|
||||
// queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream
|
||||
if m.maxNumStreams > uint64(len(m.streams)) {
|
||||
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
|
||||
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
|
||||
m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1
|
||||
m.queueMaxStreamID(&wire.MaxStreamsFrame{
|
||||
Type: protocol.StreamTypeUni,
|
||||
MaxStreamNum: m.maxStream.StreamNum(),
|
||||
MaxStreamNum: m.maxStream,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -5,11 +5,9 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
|
@ -17,27 +15,26 @@ type outgoingBidiStreamsMap struct {
|
|||
mutex sync.RWMutex
|
||||
cond sync.Cond
|
||||
|
||||
streams map[protocol.StreamID]streamI
|
||||
streams map[protocol.StreamNum]streamI
|
||||
|
||||
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
|
||||
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
|
||||
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
|
||||
|
||||
newStream func(protocol.StreamID) streamI
|
||||
newStream func(protocol.StreamNum) streamI
|
||||
queueStreamIDBlocked func(*wire.StreamsBlockedFrame)
|
||||
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func newOutgoingBidiStreamsMap(
|
||||
nextStream protocol.StreamID,
|
||||
newStream func(protocol.StreamID) streamI,
|
||||
newStream func(protocol.StreamNum) streamI,
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *outgoingBidiStreamsMap {
|
||||
m := &outgoingBidiStreamsMap{
|
||||
streams: make(map[protocol.StreamID]streamI),
|
||||
nextStream: nextStream,
|
||||
maxStream: protocol.InvalidStreamID,
|
||||
streams: make(map[protocol.StreamNum]streamI),
|
||||
maxStream: protocol.InvalidStreamNum,
|
||||
nextStream: 1,
|
||||
newStream: newStream,
|
||||
queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
|
||||
}
|
||||
|
@ -83,8 +80,8 @@ func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) {
|
|||
if m.nextStream > m.maxStream {
|
||||
if !m.blockedSent {
|
||||
var streamNum protocol.StreamNum
|
||||
if m.maxStream != protocol.InvalidStreamID {
|
||||
streamNum = m.maxStream.StreamNum()
|
||||
if m.maxStream != protocol.InvalidStreamNum {
|
||||
streamNum = m.maxStream
|
||||
}
|
||||
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
|
||||
Type: protocol.StreamTypeBidi,
|
||||
|
@ -96,36 +93,42 @@ func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) {
|
|||
}
|
||||
s := m.newStream(m.nextStream)
|
||||
m.streams[m.nextStream] = s
|
||||
m.nextStream += 4
|
||||
m.nextStream++
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *outgoingBidiStreamsMap) GetStream(id protocol.StreamID) (streamI, error) {
|
||||
func (m *outgoingBidiStreamsMap) GetStream(num protocol.StreamNum) (streamI, error) {
|
||||
m.mutex.RLock()
|
||||
if id >= m.nextStream {
|
||||
if num >= m.nextStream {
|
||||
m.mutex.RUnlock()
|
||||
return nil, qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id))
|
||||
return nil, streamError{
|
||||
message: "peer attempted to open stream %d",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
s := m.streams[id]
|
||||
}
|
||||
s := m.streams[num]
|
||||
m.mutex.RUnlock()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *outgoingBidiStreamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
func (m *outgoingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if _, ok := m.streams[id]; !ok {
|
||||
return fmt.Errorf("Tried to delete unknown stream %d", id)
|
||||
if _, ok := m.streams[num]; !ok {
|
||||
return streamError{
|
||||
message: "Tried to delete unknown stream %d",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
delete(m.streams, id)
|
||||
}
|
||||
delete(m.streams, num)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) {
|
||||
func (m *outgoingBidiStreamsMap) SetMaxStream(num protocol.StreamNum) {
|
||||
m.mutex.Lock()
|
||||
if id > m.maxStream {
|
||||
m.maxStream = id
|
||||
if num > m.maxStream {
|
||||
m.maxStream = num
|
||||
m.blockedSent = false
|
||||
m.cond.Broadcast()
|
||||
}
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
|
@ -15,27 +13,26 @@ type outgoingItemsMap struct {
|
|||
mutex sync.RWMutex
|
||||
cond sync.Cond
|
||||
|
||||
streams map[protocol.StreamID]item
|
||||
streams map[protocol.StreamNum]item
|
||||
|
||||
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
|
||||
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
|
||||
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
|
||||
|
||||
newStream func(protocol.StreamID) item
|
||||
newStream func(protocol.StreamNum) item
|
||||
queueStreamIDBlocked func(*wire.StreamsBlockedFrame)
|
||||
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func newOutgoingItemsMap(
|
||||
nextStream protocol.StreamID,
|
||||
newStream func(protocol.StreamID) item,
|
||||
newStream func(protocol.StreamNum) item,
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *outgoingItemsMap {
|
||||
m := &outgoingItemsMap{
|
||||
streams: make(map[protocol.StreamID]item),
|
||||
nextStream: nextStream,
|
||||
maxStream: protocol.InvalidStreamID,
|
||||
streams: make(map[protocol.StreamNum]item),
|
||||
maxStream: protocol.InvalidStreamNum,
|
||||
nextStream: 1,
|
||||
newStream: newStream,
|
||||
queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
|
||||
}
|
||||
|
@ -81,8 +78,8 @@ func (m *outgoingItemsMap) openStreamImpl() (item, error) {
|
|||
if m.nextStream > m.maxStream {
|
||||
if !m.blockedSent {
|
||||
var streamNum protocol.StreamNum
|
||||
if m.maxStream != protocol.InvalidStreamID {
|
||||
streamNum = m.maxStream.StreamNum()
|
||||
if m.maxStream != protocol.InvalidStreamNum {
|
||||
streamNum = m.maxStream
|
||||
}
|
||||
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
|
||||
Type: streamTypeGeneric,
|
||||
|
@ -94,36 +91,42 @@ func (m *outgoingItemsMap) openStreamImpl() (item, error) {
|
|||
}
|
||||
s := m.newStream(m.nextStream)
|
||||
m.streams[m.nextStream] = s
|
||||
m.nextStream += 4
|
||||
m.nextStream++
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *outgoingItemsMap) GetStream(id protocol.StreamID) (item, error) {
|
||||
func (m *outgoingItemsMap) GetStream(num protocol.StreamNum) (item, error) {
|
||||
m.mutex.RLock()
|
||||
if id >= m.nextStream {
|
||||
if num >= m.nextStream {
|
||||
m.mutex.RUnlock()
|
||||
return nil, qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id))
|
||||
return nil, streamError{
|
||||
message: "peer attempted to open stream %d",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
s := m.streams[id]
|
||||
}
|
||||
s := m.streams[num]
|
||||
m.mutex.RUnlock()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *outgoingItemsMap) DeleteStream(id protocol.StreamID) error {
|
||||
func (m *outgoingItemsMap) DeleteStream(num protocol.StreamNum) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if _, ok := m.streams[id]; !ok {
|
||||
return fmt.Errorf("Tried to delete unknown stream %d", id)
|
||||
if _, ok := m.streams[num]; !ok {
|
||||
return streamError{
|
||||
message: "Tried to delete unknown stream %d",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
delete(m.streams, id)
|
||||
}
|
||||
delete(m.streams, num)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) {
|
||||
func (m *outgoingItemsMap) SetMaxStream(num protocol.StreamNum) {
|
||||
m.mutex.Lock()
|
||||
if id > m.maxStream {
|
||||
m.maxStream = id
|
||||
if num > m.maxStream {
|
||||
m.maxStream = num
|
||||
m.blockedSent = false
|
||||
m.cond.Broadcast()
|
||||
}
|
||||
|
|
|
@ -5,27 +5,24 @@ import (
|
|||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Streams Map (outgoing)", func() {
|
||||
const firstNewStream protocol.StreamID = 3
|
||||
|
||||
var (
|
||||
m *outgoingItemsMap
|
||||
newItem func(id protocol.StreamID) item
|
||||
newItem func(num protocol.StreamNum) item
|
||||
mockSender *MockStreamSender
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
newItem = func(id protocol.StreamID) item {
|
||||
return &mockGenericStream{id: id}
|
||||
newItem = func(num protocol.StreamNum) item {
|
||||
return &mockGenericStream{num: num}
|
||||
}
|
||||
mockSender = NewMockStreamSender(mockCtrl)
|
||||
m = newOutgoingItemsMap(firstNewStream, newItem, mockSender.queueControlFrame)
|
||||
m = newOutgoingItemsMap(newItem, mockSender.queueControlFrame)
|
||||
})
|
||||
|
||||
Context("no stream ID limit", func() {
|
||||
|
@ -36,10 +33,10 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
|||
It("opens streams", func() {
|
||||
str, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
str, err = m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||
})
|
||||
|
||||
It("doesn't open streams after it has been closed", func() {
|
||||
|
@ -52,38 +49,40 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
|||
It("gets streams", func() {
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.GetStream(firstNewStream)
|
||||
str, err := m.GetStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
})
|
||||
|
||||
It("errors when trying to get a stream that has not yet been opened", func() {
|
||||
_, err := m.GetStream(firstNewStream)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.StreamStateError, "peer attempted to open stream 3")))
|
||||
_, err := m.GetStream(1)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(streamError).TestError()).To(MatchError("peer attempted to open stream 1"))
|
||||
})
|
||||
|
||||
It("deletes streams", func() {
|
||||
_, err := m.OpenStream() // opens firstNewStream
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = m.DeleteStream(firstNewStream)
|
||||
Expect(m.DeleteStream(1)).To(Succeed())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.GetStream(firstNewStream)
|
||||
str, err := m.GetStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
})
|
||||
|
||||
It("errors when deleting a non-existing stream", func() {
|
||||
err := m.DeleteStream(1337)
|
||||
Expect(err).To(MatchError("Tried to delete unknown stream 1337"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown stream 1337"))
|
||||
})
|
||||
|
||||
It("errors when deleting a stream twice", func() {
|
||||
_, err := m.OpenStream() // opens firstNewStream
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = m.DeleteStream(firstNewStream)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = m.DeleteStream(firstNewStream)
|
||||
Expect(err).To(MatchError("Tried to delete unknown stream 3"))
|
||||
Expect(m.DeleteStream(1)).To(Succeed())
|
||||
err = m.DeleteStream(1)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown stream 1"))
|
||||
})
|
||||
|
||||
It("closes all streams when CloseWithError is called", func() {
|
||||
|
@ -114,31 +113,12 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
|||
defer GinkgoRecover()
|
||||
str, err := m.OpenStreamSync()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
m.SetMaxStream(firstNewStream)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("works with stream 0", func() {
|
||||
m = newOutgoingItemsMap(0, newItem, mockSender.queueControlFrame)
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||
Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeZero())
|
||||
})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := m.OpenStreamSync()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(BeZero())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
m.SetMaxStream(0)
|
||||
m.SetMaxStream(1)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -159,17 +139,17 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
|||
})
|
||||
|
||||
It("doesn't reduce the stream limit", func() {
|
||||
m.SetMaxStream(firstNewStream + 4)
|
||||
m.SetMaxStream(firstNewStream)
|
||||
m.SetMaxStream(2)
|
||||
m.SetMaxStream(1)
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).id).To(Equal(firstNewStream + 4))
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||
})
|
||||
|
||||
It("queues a STREAM_ID_BLOCKED frame if no stream can be opened", func() {
|
||||
m.SetMaxStream(firstNewStream + 5*4)
|
||||
m.SetMaxStream(6)
|
||||
// open the 6 allowed streams
|
||||
for i := 0; i < 6; i++ {
|
||||
_, err := m.OpenStream()
|
||||
|
@ -185,7 +165,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
|||
})
|
||||
|
||||
It("only sends one STREAM_ID_BLOCKED frame for one stream ID", func() {
|
||||
m.SetMaxStream(firstNewStream)
|
||||
m.SetMaxStream(1)
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||
Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
|
|
@ -5,11 +5,9 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
|
@ -17,27 +15,26 @@ type outgoingUniStreamsMap struct {
|
|||
mutex sync.RWMutex
|
||||
cond sync.Cond
|
||||
|
||||
streams map[protocol.StreamID]sendStreamI
|
||||
streams map[protocol.StreamNum]sendStreamI
|
||||
|
||||
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
|
||||
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
|
||||
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
|
||||
|
||||
newStream func(protocol.StreamID) sendStreamI
|
||||
newStream func(protocol.StreamNum) sendStreamI
|
||||
queueStreamIDBlocked func(*wire.StreamsBlockedFrame)
|
||||
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func newOutgoingUniStreamsMap(
|
||||
nextStream protocol.StreamID,
|
||||
newStream func(protocol.StreamID) sendStreamI,
|
||||
newStream func(protocol.StreamNum) sendStreamI,
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *outgoingUniStreamsMap {
|
||||
m := &outgoingUniStreamsMap{
|
||||
streams: make(map[protocol.StreamID]sendStreamI),
|
||||
nextStream: nextStream,
|
||||
maxStream: protocol.InvalidStreamID,
|
||||
streams: make(map[protocol.StreamNum]sendStreamI),
|
||||
maxStream: protocol.InvalidStreamNum,
|
||||
nextStream: 1,
|
||||
newStream: newStream,
|
||||
queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
|
||||
}
|
||||
|
@ -83,8 +80,8 @@ func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) {
|
|||
if m.nextStream > m.maxStream {
|
||||
if !m.blockedSent {
|
||||
var streamNum protocol.StreamNum
|
||||
if m.maxStream != protocol.InvalidStreamID {
|
||||
streamNum = m.maxStream.StreamNum()
|
||||
if m.maxStream != protocol.InvalidStreamNum {
|
||||
streamNum = m.maxStream
|
||||
}
|
||||
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
|
||||
Type: protocol.StreamTypeUni,
|
||||
|
@ -96,36 +93,42 @@ func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) {
|
|||
}
|
||||
s := m.newStream(m.nextStream)
|
||||
m.streams[m.nextStream] = s
|
||||
m.nextStream += 4
|
||||
m.nextStream++
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *outgoingUniStreamsMap) GetStream(id protocol.StreamID) (sendStreamI, error) {
|
||||
func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, error) {
|
||||
m.mutex.RLock()
|
||||
if id >= m.nextStream {
|
||||
if num >= m.nextStream {
|
||||
m.mutex.RUnlock()
|
||||
return nil, qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id))
|
||||
return nil, streamError{
|
||||
message: "peer attempted to open stream %d",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
s := m.streams[id]
|
||||
}
|
||||
s := m.streams[num]
|
||||
m.mutex.RUnlock()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *outgoingUniStreamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if _, ok := m.streams[id]; !ok {
|
||||
return fmt.Errorf("Tried to delete unknown stream %d", id)
|
||||
if _, ok := m.streams[num]; !ok {
|
||||
return streamError{
|
||||
message: "Tried to delete unknown stream %d",
|
||||
nums: []protocol.StreamNum{num},
|
||||
}
|
||||
delete(m.streams, id)
|
||||
}
|
||||
delete(m.streams, num)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) {
|
||||
func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) {
|
||||
m.mutex.Lock()
|
||||
if id > m.maxStream {
|
||||
m.maxStream = id
|
||||
if num > m.maxStream {
|
||||
m.maxStream = num
|
||||
m.blockedSent = false
|
||||
m.cond.Broadcast()
|
||||
}
|
||||
|
|
|
@ -17,6 +17,14 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func (e streamError) TestError() error {
|
||||
nums := make([]interface{}, len(e.nums))
|
||||
for i, num := range e.nums {
|
||||
nums[i] = num
|
||||
}
|
||||
return fmt.Errorf(e.message, nums...)
|
||||
}
|
||||
|
||||
type streamMapping struct {
|
||||
firstIncomingBidiStream protocol.StreamID
|
||||
firstIncomingUniStream protocol.StreamID
|
||||
|
@ -221,7 +229,7 @@ var _ = Describe("Streams Map", func() {
|
|||
It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
|
||||
id := ids.firstOutgoingBidiStream + 5*4
|
||||
_, err := m.GetOrOpenSendStream(id)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id))))
|
||||
Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open stream %d", id)))
|
||||
})
|
||||
|
||||
It("gets an outgoing unidirectional stream", func() {
|
||||
|
@ -237,7 +245,7 @@ var _ = Describe("Streams Map", func() {
|
|||
It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
|
||||
id := ids.firstOutgoingUniStream + 5*4
|
||||
_, err := m.GetOrOpenSendStream(id)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id))))
|
||||
Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open stream %d", id)))
|
||||
})
|
||||
|
||||
It("gets an incoming bidirectional stream", func() {
|
||||
|
@ -250,7 +258,7 @@ var _ = Describe("Streams Map", func() {
|
|||
It("errors when trying to get an incoming unidirectional stream", func() {
|
||||
id := ids.firstIncomingUniStream
|
||||
_, err := m.GetOrOpenSendStream(id)
|
||||
Expect(err).To(MatchError(fmt.Errorf("peer attempted to open send stream %d", id)))
|
||||
Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open send stream %d", id)))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -268,7 +276,7 @@ var _ = Describe("Streams Map", func() {
|
|||
It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
|
||||
id := ids.firstOutgoingBidiStream + 5*4
|
||||
_, err := m.GetOrOpenReceiveStream(id)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.StreamStateError, fmt.Sprintf("peer attempted to open stream %d", id))))
|
||||
Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open stream %d", id)))
|
||||
})
|
||||
|
||||
It("gets an incoming bidirectional stream", func() {
|
||||
|
@ -288,37 +296,44 @@ var _ = Describe("Streams Map", func() {
|
|||
It("errors when trying to get an outgoing unidirectional stream", func() {
|
||||
id := ids.firstOutgoingUniStream
|
||||
_, err := m.GetOrOpenReceiveStream(id)
|
||||
Expect(err).To(MatchError(fmt.Errorf("peer attempted to open receive stream %d", id)))
|
||||
Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open receive stream %d", id)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("updating stream ID limits", func() {
|
||||
It("processes the parameter for outgoing streams, as a server", func() {
|
||||
for _, p := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} {
|
||||
pers := p
|
||||
|
||||
It(fmt.Sprintf("processes the parameter for outgoing streams, as a %s", pers), func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
m.perspective = protocol.PerspectiveServer
|
||||
m.perspective = pers
|
||||
_, err := m.OpenStream()
|
||||
expectTooManyStreamsError(err)
|
||||
Expect(m.UpdateLimits(&handshake.TransportParameters{
|
||||
MaxBidiStreamNum: 5,
|
||||
MaxUniStreamNum: 5,
|
||||
MaxUniStreamNum: 8,
|
||||
})).To(Succeed())
|
||||
Expect(m.outgoingBidiStreams.maxStream).To(Equal(protocol.StreamID(17)))
|
||||
Expect(m.outgoingUniStreams.maxStream).To(Equal(protocol.StreamID(19)))
|
||||
})
|
||||
|
||||
It("processes the parameter for outgoing streams, as a client", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
m.perspective = protocol.PerspectiveClient
|
||||
_, err := m.OpenUniStream()
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2)
|
||||
// test we can only 5 bidirectional streams
|
||||
for i := 0; i < 5; i++ {
|
||||
str, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i)))
|
||||
}
|
||||
_, err = m.OpenStream()
|
||||
expectTooManyStreamsError(err)
|
||||
// test we can only 8 unidirectional streams
|
||||
for i := 0; i < 8; i++ {
|
||||
str, err := m.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i)))
|
||||
}
|
||||
_, err = m.OpenUniStream()
|
||||
expectTooManyStreamsError(err)
|
||||
Expect(m.UpdateLimits(&handshake.TransportParameters{
|
||||
MaxBidiStreamNum: 5,
|
||||
MaxUniStreamNum: 5,
|
||||
})).To(Succeed())
|
||||
Expect(m.outgoingBidiStreams.maxStream).To(Equal(protocol.StreamID(16)))
|
||||
Expect(m.outgoingUniStreams.maxStream).To(Equal(protocol.StreamID(18)))
|
||||
})
|
||||
}
|
||||
|
||||
It("rejects parameters with too large unidirectional stream counts", func() {
|
||||
Expect(m.UpdateLimits(&handshake.TransportParameters{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue