mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 05:07:36 +03:00
implement handling of MAX_STREAM_ID frames
This commit is contained in:
parent
cd4bcda458
commit
46e20ce8c9
11 changed files with 350 additions and 53 deletions
|
@ -10,6 +10,7 @@ import (
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
wire "github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockStreamManager is a mock of StreamManager interface
|
// MockStreamManager is a mock of StreamManager interface
|
||||||
|
@ -96,6 +97,18 @@ func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandleMaxStreamIDFrame mocks base method
|
||||||
|
func (m *MockStreamManager) HandleMaxStreamIDFrame(arg0 *wire.MaxStreamIDFrame) error {
|
||||||
|
ret := m.ctrl.Call(m, "HandleMaxStreamIDFrame", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleMaxStreamIDFrame indicates an expected call of HandleMaxStreamIDFrame
|
||||||
|
func (mr *MockStreamManagerMockRecorder) HandleMaxStreamIDFrame(arg0 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamIDFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamIDFrame), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
// OpenStream mocks base method
|
// OpenStream mocks base method
|
||||||
func (m *MockStreamManager) OpenStream() (Stream, error) {
|
func (m *MockStreamManager) OpenStream() (Stream, error) {
|
||||||
ret := m.ctrl.Call(m, "OpenStream")
|
ret := m.ctrl.Call(m, "OpenStream")
|
||||||
|
|
|
@ -37,6 +37,7 @@ type streamManager interface {
|
||||||
AcceptStream() (Stream, error)
|
AcceptStream() (Stream, error)
|
||||||
DeleteStream(protocol.StreamID) error
|
DeleteStream(protocol.StreamID) error
|
||||||
UpdateLimits(*handshake.TransportParameters)
|
UpdateLimits(*handshake.TransportParameters)
|
||||||
|
HandleMaxStreamIDFrame(*wire.MaxStreamIDFrame) error
|
||||||
CloseWithError(error)
|
CloseWithError(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -563,6 +564,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
|
||||||
s.handleMaxDataFrame(frame)
|
s.handleMaxDataFrame(frame)
|
||||||
case *wire.MaxStreamDataFrame:
|
case *wire.MaxStreamDataFrame:
|
||||||
err = s.handleMaxStreamDataFrame(frame)
|
err = s.handleMaxStreamDataFrame(frame)
|
||||||
|
case *wire.MaxStreamIDFrame:
|
||||||
|
err = s.handleMaxStreamIDFrame(frame)
|
||||||
case *wire.BlockedFrame:
|
case *wire.BlockedFrame:
|
||||||
case *wire.StreamBlockedFrame:
|
case *wire.StreamBlockedFrame:
|
||||||
case *wire.StopSendingFrame:
|
case *wire.StopSendingFrame:
|
||||||
|
@ -634,6 +637,10 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *session) handleMaxStreamIDFrame(frame *wire.MaxStreamIDFrame) error {
|
||||||
|
return s.streamsMap.HandleMaxStreamIDFrame(frame)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
||||||
if frame.StreamID == s.version.CryptoStreamID() {
|
if frame.StreamID == s.version.CryptoStreamID() {
|
||||||
return errors.New("Received RST_STREAM frame for the crypto stream")
|
return errors.New("Received RST_STREAM frame for the crypto stream")
|
||||||
|
|
|
@ -357,6 +357,23 @@ var _ = Describe("Session", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("handling MAX_STREAM_ID frames", func() {
|
||||||
|
It("passes the frame to the streamsMap", func() {
|
||||||
|
f := &wire.MaxStreamIDFrame{StreamID: 10}
|
||||||
|
streamManager.EXPECT().HandleMaxStreamIDFrame(f)
|
||||||
|
err := sess.handleMaxStreamIDFrame(f)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns errors", func() {
|
||||||
|
f := &wire.MaxStreamIDFrame{StreamID: 10}
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
streamManager.EXPECT().HandleMaxStreamIDFrame(f).Return(testErr)
|
||||||
|
err := sess.handleMaxStreamIDFrame(f)
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("handling STOP_SENDING frames", func() {
|
Context("handling STOP_SENDING frames", func() {
|
||||||
It("passes the frame to the stream", func() {
|
It("passes the frame to the stream", func() {
|
||||||
f := &wire.StopSendingFrame{
|
f := &wire.StopSendingFrame{
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
type streamType int
|
type streamType int
|
||||||
|
@ -102,8 +103,7 @@ func (m *streamsMap) OpenStream() (Stream, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) OpenStreamSync() (Stream, error) {
|
func (m *streamsMap) OpenStreamSync() (Stream, error) {
|
||||||
// TODO(#1150): implement limits
|
return m.outgoingBidiStreams.OpenStreamSync()
|
||||||
return m.OpenStream()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) OpenUniStream() (SendStream, error) {
|
func (m *streamsMap) OpenUniStream() (SendStream, error) {
|
||||||
|
@ -111,8 +111,7 @@ func (m *streamsMap) OpenUniStream() (SendStream, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
|
func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
|
||||||
// TODO(#1150): implement limits
|
return m.outgoingUniStreams.OpenStreamSync()
|
||||||
return m.outgoingUniStreams.OpenStream()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) AcceptStream() (Stream, error) {
|
func (m *streamsMap) AcceptStream() (Stream, error) {
|
||||||
|
@ -170,8 +169,23 @@ func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) UpdateLimits(*handshake.TransportParameters) {
|
func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
|
||||||
// TODO(#1150): implement limits
|
id := f.StreamID
|
||||||
|
switch m.getStreamType(id) {
|
||||||
|
case streamTypeOutgoingBidi:
|
||||||
|
m.outgoingBidiStreams.SetMaxStream(id)
|
||||||
|
return nil
|
||||||
|
case streamTypeOutgoingUni:
|
||||||
|
m.outgoingUniStreams.SetMaxStream(id)
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) {
|
||||||
|
m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamID)
|
||||||
|
m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) CloseWithError(err error) {
|
func (m *streamsMap) CloseWithError(err error) {
|
||||||
|
|
|
@ -256,3 +256,8 @@ func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) {
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
m.openStreamOrErrCond.Broadcast()
|
m.openStreamOrErrCond.Broadcast()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// should never be called, since MAX_STREAM_ID frames can only be unpacked for IETF QUIC
|
||||||
|
func (m *streamsMapLegacy) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
|
||||||
|
return errors.New("gQUIC doesn't have MAX_STREAM_ID frames")
|
||||||
|
}
|
||||||
|
|
|
@ -546,4 +546,8 @@ var _ = Describe("Streams Map (for gQUIC)", func() {
|
||||||
})
|
})
|
||||||
m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321})
|
m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("doesn't accept MAX_STREAM_ID frames", func() {
|
||||||
|
Expect(m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{})).ToNot(Succeed())
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -14,30 +14,57 @@ import (
|
||||||
|
|
||||||
type outgoingBidiStreamsMap struct {
|
type outgoingBidiStreamsMap struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
cond sync.Cond
|
||||||
|
|
||||||
streams map[protocol.StreamID]streamI
|
streams map[protocol.StreamID]streamI
|
||||||
|
|
||||||
nextStream protocol.StreamID
|
nextStream protocol.StreamID
|
||||||
|
maxStream protocol.StreamID
|
||||||
newStream func(protocol.StreamID) streamI
|
newStream func(protocol.StreamID) streamI
|
||||||
|
|
||||||
closeErr error
|
closeErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOutgoingBidiStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) streamI) *outgoingBidiStreamsMap {
|
func newOutgoingBidiStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) streamI) *outgoingBidiStreamsMap {
|
||||||
return &outgoingBidiStreamsMap{
|
m := &outgoingBidiStreamsMap{
|
||||||
streams: make(map[protocol.StreamID]streamI),
|
streams: make(map[protocol.StreamID]streamI),
|
||||||
nextStream: nextStream,
|
nextStream: nextStream,
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
}
|
}
|
||||||
|
m.cond.L = &m.mutex
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) {
|
func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.openStreamImpl()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
for {
|
||||||
|
str, err := m.openStreamImpl()
|
||||||
|
if err == nil {
|
||||||
|
return str, err
|
||||||
|
}
|
||||||
|
if err != nil && err != qerr.TooManyOpenStreams {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.cond.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) {
|
||||||
if m.closeErr != nil {
|
if m.closeErr != nil {
|
||||||
return nil, m.closeErr
|
return nil, m.closeErr
|
||||||
}
|
}
|
||||||
|
if m.nextStream > m.maxStream {
|
||||||
|
return nil, qerr.TooManyOpenStreams
|
||||||
|
}
|
||||||
s := m.newStream(m.nextStream)
|
s := m.newStream(m.nextStream)
|
||||||
m.streams[m.nextStream] = s
|
m.streams[m.nextStream] = s
|
||||||
m.nextStream += 4
|
m.nextStream += 4
|
||||||
|
@ -65,8 +92,18 @@ func (m *outgoingBidiStreamsMap) DeleteStream(id protocol.StreamID) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
if id > m.maxStream {
|
||||||
|
m.maxStream = id
|
||||||
|
m.cond.Broadcast()
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *outgoingBidiStreamsMap) CloseWithError(err error) {
|
func (m *outgoingBidiStreamsMap) CloseWithError(err error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
m.closeErr = err
|
m.closeErr = err
|
||||||
|
m.cond.Broadcast()
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,30 +15,57 @@ type item generic.Type
|
||||||
//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream"
|
//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream"
|
||||||
type outgoingItemsMap struct {
|
type outgoingItemsMap struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
cond sync.Cond
|
||||||
|
|
||||||
streams map[protocol.StreamID]item
|
streams map[protocol.StreamID]item
|
||||||
|
|
||||||
nextStream protocol.StreamID
|
nextStream protocol.StreamID
|
||||||
|
maxStream protocol.StreamID
|
||||||
newStream func(protocol.StreamID) item
|
newStream func(protocol.StreamID) item
|
||||||
|
|
||||||
closeErr error
|
closeErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOutgoingItemsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) item) *outgoingItemsMap {
|
func newOutgoingItemsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) item) *outgoingItemsMap {
|
||||||
return &outgoingItemsMap{
|
m := &outgoingItemsMap{
|
||||||
streams: make(map[protocol.StreamID]item),
|
streams: make(map[protocol.StreamID]item),
|
||||||
nextStream: nextStream,
|
nextStream: nextStream,
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
}
|
}
|
||||||
|
m.cond.L = &m.mutex
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *outgoingItemsMap) OpenStream() (item, error) {
|
func (m *outgoingItemsMap) OpenStream() (item, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.openStreamImpl()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
for {
|
||||||
|
str, err := m.openStreamImpl()
|
||||||
|
if err == nil {
|
||||||
|
return str, err
|
||||||
|
}
|
||||||
|
if err != nil && err != qerr.TooManyOpenStreams {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.cond.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *outgoingItemsMap) openStreamImpl() (item, error) {
|
||||||
if m.closeErr != nil {
|
if m.closeErr != nil {
|
||||||
return nil, m.closeErr
|
return nil, m.closeErr
|
||||||
}
|
}
|
||||||
|
if m.nextStream > m.maxStream {
|
||||||
|
return nil, qerr.TooManyOpenStreams
|
||||||
|
}
|
||||||
s := m.newStream(m.nextStream)
|
s := m.newStream(m.nextStream)
|
||||||
m.streams[m.nextStream] = s
|
m.streams[m.nextStream] = s
|
||||||
m.nextStream += 4
|
m.nextStream += 4
|
||||||
|
@ -66,8 +93,18 @@ func (m *outgoingItemsMap) DeleteStream(id protocol.StreamID) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
if id > m.maxStream {
|
||||||
|
m.maxStream = id
|
||||||
|
m.cond.Broadcast()
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *outgoingItemsMap) CloseWithError(err error) {
|
func (m *outgoingItemsMap) CloseWithError(err error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
m.closeErr = err
|
m.closeErr = err
|
||||||
|
m.cond.Broadcast()
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,11 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
||||||
m = newOutgoingItemsMap(firstNewStream, newItem)
|
m = newOutgoingItemsMap(firstNewStream, newItem)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("no stream ID limit", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
m.SetMaxStream(0xffffffff)
|
||||||
|
})
|
||||||
|
|
||||||
It("opens streams", func() {
|
It("opens streams", func() {
|
||||||
str, err := m.OpenStream()
|
str, err := m.OpenStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -75,4 +80,50 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
||||||
err = m.DeleteStream(10)
|
err = m.DeleteStream(10)
|
||||||
Expect(err).To(MatchError("Tried to delete unknown stream 10"))
|
Expect(err).To(MatchError("Tried to delete unknown stream 10"))
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with stream ID limits", func() {
|
||||||
|
It("errors when no stream can be opened immediately", func() {
|
||||||
|
_, err := m.OpenStream()
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("blocks until a stream can be opened synchronously", func() {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
str, err := m.OpenStreamSync()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(str).To(Equal(firstNewStream))
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
|
m.SetMaxStream(firstNewStream)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("stops opening synchronously when it is closed", func() {
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
_, err := m.OpenStreamSync()
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
|
m.CloseWithError(testErr)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't reduce the stream limit", func() {
|
||||||
|
m.SetMaxStream(firstNewStream)
|
||||||
|
m.SetMaxStream(firstNewStream - 4)
|
||||||
|
str, err := m.OpenStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(str).To(Equal(firstNewStream))
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -14,30 +14,57 @@ import (
|
||||||
|
|
||||||
type outgoingUniStreamsMap struct {
|
type outgoingUniStreamsMap struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
cond sync.Cond
|
||||||
|
|
||||||
streams map[protocol.StreamID]sendStreamI
|
streams map[protocol.StreamID]sendStreamI
|
||||||
|
|
||||||
nextStream protocol.StreamID
|
nextStream protocol.StreamID
|
||||||
|
maxStream protocol.StreamID
|
||||||
newStream func(protocol.StreamID) sendStreamI
|
newStream func(protocol.StreamID) sendStreamI
|
||||||
|
|
||||||
closeErr error
|
closeErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOutgoingUniStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) sendStreamI) *outgoingUniStreamsMap {
|
func newOutgoingUniStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) sendStreamI) *outgoingUniStreamsMap {
|
||||||
return &outgoingUniStreamsMap{
|
m := &outgoingUniStreamsMap{
|
||||||
streams: make(map[protocol.StreamID]sendStreamI),
|
streams: make(map[protocol.StreamID]sendStreamI),
|
||||||
nextStream: nextStream,
|
nextStream: nextStream,
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
}
|
}
|
||||||
|
m.cond.L = &m.mutex
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) {
|
func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.openStreamImpl()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
for {
|
||||||
|
str, err := m.openStreamImpl()
|
||||||
|
if err == nil {
|
||||||
|
return str, err
|
||||||
|
}
|
||||||
|
if err != nil && err != qerr.TooManyOpenStreams {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.cond.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) {
|
||||||
if m.closeErr != nil {
|
if m.closeErr != nil {
|
||||||
return nil, m.closeErr
|
return nil, m.closeErr
|
||||||
}
|
}
|
||||||
|
if m.nextStream > m.maxStream {
|
||||||
|
return nil, qerr.TooManyOpenStreams
|
||||||
|
}
|
||||||
s := m.newStream(m.nextStream)
|
s := m.newStream(m.nextStream)
|
||||||
m.streams[m.nextStream] = s
|
m.streams[m.nextStream] = s
|
||||||
m.nextStream += 4
|
m.nextStream += 4
|
||||||
|
@ -65,8 +92,18 @@ func (m *outgoingUniStreamsMap) DeleteStream(id protocol.StreamID) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
if id > m.maxStream {
|
||||||
|
m.maxStream = id
|
||||||
|
m.cond.Broadcast()
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *outgoingUniStreamsMap) CloseWithError(err error) {
|
func (m *outgoingUniStreamsMap) CloseWithError(err error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
m.closeErr = err
|
m.closeErr = err
|
||||||
|
m.cond.Broadcast()
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,8 +5,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
|
@ -50,12 +52,20 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() {
|
||||||
Context(perspective.String(), func() {
|
Context(perspective.String(), func() {
|
||||||
var m *streamsMap
|
var m *streamsMap
|
||||||
|
|
||||||
|
allowUnlimitedStreams := func() {
|
||||||
|
m.UpdateLimits(&handshake.TransportParameters{
|
||||||
|
MaxBidiStreamID: 0xffffffff,
|
||||||
|
MaxUniStreamID: 0xffffffff,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
m = newStreamsMap(nil, newFlowController, perspective, versionIETFFrames).(*streamsMap)
|
m = newStreamsMap(nil, newFlowController, perspective, versionIETFFrames).(*streamsMap)
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("opening", func() {
|
Context("opening", func() {
|
||||||
It("opens bidirectional streams", func() {
|
It("opens bidirectional streams", func() {
|
||||||
|
allowUnlimitedStreams()
|
||||||
str, err := m.OpenStream()
|
str, err := m.OpenStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).To(BeAssignableToTypeOf(&stream{}))
|
Expect(str).To(BeAssignableToTypeOf(&stream{}))
|
||||||
|
@ -67,6 +77,7 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("opens unidirectional streams", func() {
|
It("opens unidirectional streams", func() {
|
||||||
|
allowUnlimitedStreams()
|
||||||
str, err := m.OpenUniStream()
|
str, err := m.OpenUniStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
|
Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
|
||||||
|
@ -99,6 +110,10 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("deleting", func() {
|
Context("deleting", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
allowUnlimitedStreams()
|
||||||
|
})
|
||||||
|
|
||||||
It("deletes outgoing bidirectional streams", func() {
|
It("deletes outgoing bidirectional streams", func() {
|
||||||
id := ids.firstOutgoingBidiStream
|
id := ids.firstOutgoingBidiStream
|
||||||
str, err := m.OpenStream()
|
str, err := m.OpenStream()
|
||||||
|
@ -145,6 +160,10 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("getting streams", func() {
|
Context("getting streams", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
allowUnlimitedStreams()
|
||||||
|
})
|
||||||
|
|
||||||
Context("send streams", func() {
|
Context("send streams", func() {
|
||||||
It("gets an outgoing bidirectional stream", func() {
|
It("gets an outgoing bidirectional stream", func() {
|
||||||
// need to open the stream ourselves first
|
// need to open the stream ourselves first
|
||||||
|
@ -231,6 +250,62 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("updating stream ID limits", func() {
|
||||||
|
It("processes the parameter for outgoing bidirectional streams", func() {
|
||||||
|
_, err := m.OpenStream()
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
m.UpdateLimits(&handshake.TransportParameters{
|
||||||
|
MaxBidiStreamID: ids.firstOutgoingBidiStream,
|
||||||
|
})
|
||||||
|
str, err := m.OpenStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("processes the parameter for outgoing bidirectional streams", func() {
|
||||||
|
_, err := m.OpenUniStream()
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
m.UpdateLimits(&handshake.TransportParameters{
|
||||||
|
MaxUniStreamID: ids.firstOutgoingUniStream,
|
||||||
|
})
|
||||||
|
str, err := m.OpenUniStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("handling MAX_STREAM_ID frames", func() {
|
||||||
|
It("processes IDs for outgoing bidirectional streams", func() {
|
||||||
|
_, err := m.OpenStream()
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingBidiStream})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
str, err := m.OpenStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("processes IDs for outgoing bidirectional streams", func() {
|
||||||
|
_, err := m.OpenUniStream()
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingUniStream})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
str, err := m.OpenUniStream()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects IDs for incoming bidirectional streams", func() {
|
||||||
|
err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingBidiStream})
|
||||||
|
Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingBidiStream)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects IDs for incoming unidirectional streams", func() {
|
||||||
|
err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingUniStream})
|
||||||
|
Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingUniStream)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
It("closes", func() {
|
It("closes", func() {
|
||||||
testErr := errors.New("test error")
|
testErr := errors.New("test error")
|
||||||
m.CloseWithError(testErr)
|
m.CloseWithError(testErr)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue