use separate streamsMaps for gQUIC and IETF QUIC

This is a lot of duplicate code for now, but it will make moving towards
the new stream ID mapping in IETF QUIC (and unidirectional streams) much
easier.
This commit is contained in:
Marten Seemann 2018-01-04 10:26:02 +07:00
parent 69437a0e78
commit a20e94ee16
5 changed files with 873 additions and 89 deletions

View file

@ -28,6 +28,18 @@ type streamGetter interface {
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
}
type streamManager interface {
GetOrOpenStream(protocol.StreamID) (streamI, error)
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
OpenStream() (Stream, error)
OpenStreamSync() (Stream, error)
AcceptStream() (Stream, error)
DeleteStream(protocol.StreamID) error
UpdateLimits(*handshake.TransportParameters)
CloseWithError(error)
}
type receivedPacket struct {
remoteAddr net.Addr
header *wire.Header
@ -310,7 +322,11 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
if s.version.UsesTLS() {
s.streamsMap = newStreamsMap(s.newStream, s.perspective)
} else {
s.streamsMap = newStreamsMapLegacy(s.newStream, s.perspective)
}
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
s.packer = newPacketPacker(s.connectionID,
initialPacketNumber,

View file

@ -12,18 +12,6 @@ import (
"github.com/lucas-clemente/quic-go/qerr"
)
type streamManager interface {
GetOrOpenStream(protocol.StreamID) (streamI, error)
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
OpenStream() (Stream, error)
OpenStreamSync() (Stream, error)
AcceptStream() (Stream, error)
DeleteStream(protocol.StreamID) error
UpdateLimits(*handshake.TransportParameters)
CloseWithError(error)
}
type streamsMap struct {
mutex sync.RWMutex
@ -49,12 +37,11 @@ type streamsMap struct {
var _ streamManager = &streamsMap{}
type streamLambda func(streamI) (bool, error)
type newStreamLambda func(protocol.StreamID) streamI
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap {
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) streamManager {
// add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32(
@ -72,13 +59,6 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro
nextClientInitiatedStream := protocol.StreamID(1)
nextServerInitiatedStream := protocol.StreamID(2)
if !ver.UsesTLS() {
nextServerInitiatedStream = 2
nextClientInitiatedStream = 3
if pers == protocol.PerspectiveServer {
sm.highestStreamOpenedByPeer = 1
}
}
if pers == protocol.PerspectiveServer {
sm.nextStreamToOpen = nextServerInitiatedStream
sm.nextStreamToAccept = nextClientInitiatedStream

257
streams_map_legacy.go Normal file
View file

@ -0,0 +1,257 @@
package quic
import (
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type streamsMapLegacy struct {
mutex sync.RWMutex
perspective protocol.Perspective
streams map[protocol.StreamID]streamI
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond
openStreamOrErrCond sync.Cond
closeErr error
nextStreamToAccept protocol.StreamID
newStream newStreamLambda
numOutgoingStreams uint32
numIncomingStreams uint32
maxIncomingStreams uint32
maxOutgoingStreams uint32
}
var _ streamManager = &streamsMapLegacy{}
func newStreamsMapLegacy(newStream newStreamLambda, pers protocol.Perspective) streamManager {
// add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32(
maxStreams+protocol.MaxStreamsMinimumIncrement,
uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
)
sm := streamsMapLegacy{
perspective: pers,
streams: make(map[protocol.StreamID]streamI),
newStream: newStream,
maxIncomingStreams: maxIncomingStreams,
}
sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamOrErrCond.L = &sm.mutex
nextServerInitiatedStream := protocol.StreamID(2)
nextClientInitiatedStream := protocol.StreamID(3)
if pers == protocol.PerspectiveServer {
sm.highestStreamOpenedByPeer = 1
}
if pers == protocol.PerspectiveServer {
sm.nextStreamToOpen = nextServerInitiatedStream
sm.nextStreamToAccept = nextClientInitiatedStream
} else {
sm.nextStreamToOpen = nextClientInitiatedStream
sm.nextStreamToAccept = nextServerInitiatedStream
}
return &sm
}
// getStreamPerspective says which side should initiate a stream
func (m *streamsMapLegacy) streamInitiatedBy(id protocol.StreamID) protocol.Perspective {
if id%2 == 0 {
return protocol.PerspectiveServer
}
return protocol.PerspectiveClient
}
func (m *streamsMapLegacy) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
// every bidirectional stream is also a receive stream
return m.GetOrOpenStream(id)
}
func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
// every bidirectional stream is also a send stream
return m.GetOrOpenStream(id)
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
func (m *streamsMapLegacy) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
m.mutex.RLock()
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil
}
// ... we don't have an existing stream
m.mutex.Lock()
defer m.mutex.Unlock()
// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
s, ok = m.streams[id]
if ok {
return s, nil
}
if m.perspective == m.streamInitiatedBy(id) {
if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already
return nil, nil
}
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
}
if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already
return nil, nil
}
for sid := m.highestStreamOpenedByPeer + 2; sid <= id; sid += 2 {
if _, err := m.openRemoteStream(sid); err != nil {
return nil, err
}
}
m.nextStreamOrErrCond.Broadcast()
return m.streams[id], nil
}
func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, error) {
if m.numIncomingStreams >= m.maxIncomingStreams {
return nil, qerr.TooManyOpenStreams
}
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
}
m.numIncomingStreams++
if id > m.highestStreamOpenedByPeer {
m.highestStreamOpenedByPeer = id
}
s := m.newStream(id)
m.putStream(s)
return s, nil
}
func (m *streamsMapLegacy) openStreamImpl() (streamI, error) {
if m.numOutgoingStreams >= m.maxOutgoingStreams {
return nil, qerr.TooManyOpenStreams
}
m.numOutgoingStreams++
s := m.newStream(m.nextStreamToOpen)
m.putStream(s)
m.nextStreamToOpen += 2
return s, nil
}
// OpenStream opens the next available stream
func (m *streamsMapLegacy) OpenStream() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.closeErr != nil {
return nil, m.closeErr
}
return m.openStreamImpl()
}
func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
for {
if m.closeErr != nil {
return nil, m.closeErr
}
str, err := m.openStreamImpl()
if err == nil {
return str, err
}
if err != nil && err != qerr.TooManyOpenStreams {
return nil, err
}
m.openStreamOrErrCond.Wait()
}
}
// AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened
func (m *streamsMapLegacy) AcceptStream() (Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var str streamI
for {
var ok bool
if m.closeErr != nil {
return nil, m.closeErr
}
str, ok = m.streams[m.nextStreamToAccept]
if ok {
break
}
m.nextStreamOrErrCond.Wait()
}
m.nextStreamToAccept += 2
return str, nil
}
func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
_, ok := m.streams[id]
if !ok {
return errMapAccess
}
delete(m.streams, id)
if m.streamInitiatedBy(id) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
m.openStreamOrErrCond.Signal()
return nil
}
func (m *streamsMapLegacy) putStream(s streamI) error {
id := s.StreamID()
if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id)
}
m.streams[id] = s
return nil
}
func (m *streamsMapLegacy) CloseWithError(err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.closeErr = err
m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast()
for _, s := range m.streams {
s.closeForShutdown(err)
}
}
// TODO(#952): this won't be needed when gQUIC supports stateless handshakes
func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) {
m.mutex.Lock()
m.maxOutgoingStreams = params.MaxStreams
for id, str := range m.streams {
str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: id,
ByteOffset: params.StreamFlowControlWindow,
})
}
m.mutex.Unlock()
m.openStreamOrErrCond.Broadcast()
}

549
streams_map_legacy_test.go Normal file
View file

@ -0,0 +1,549 @@
package quic
import (
"errors"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Streams Map (for gQUIC)", func() {
var m *streamsMapLegacy
newStream := func(id protocol.StreamID) streamI {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(id).AnyTimes()
return str
}
setNewStreamsMap := func(p protocol.Perspective) {
m = newStreamsMapLegacy(newStream, p).(*streamsMapLegacy)
}
deleteStream := func(id protocol.StreamID) {
ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed())
}
Context("getting and creating streams", func() {
Context("as a server", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer)
})
Context("client-side streams", func() {
It("gets new streams", func() {
s, err := m.GetOrOpenStream(3)
Expect(err).NotTo(HaveOccurred())
Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(Equal(protocol.StreamID(3)))
Expect(m.streams).To(HaveLen(1))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
Expect(m.numOutgoingStreams).To(BeZero())
})
It("rejects streams with even IDs", func() {
_, err := m.GetOrOpenStream(6)
Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 6"))
})
It("rejects streams with even IDs, which are lower thatn the highest client-side stream", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
_, err = m.GetOrOpenStream(4)
Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 4"))
})
It("gets existing streams", func() {
s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
numStreams := m.numIncomingStreams
s, err = m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
Expect(m.numIncomingStreams).To(Equal(numStreams))
})
It("returns nil for closed streams", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
deleteStream(5)
s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(s).To(BeNil())
})
It("opens skipped streams", func() {
_, err := m.GetOrOpenStream(7)
Expect(err).NotTo(HaveOccurred())
Expect(m.streams).To(HaveKey(protocol.StreamID(3)))
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
Expect(m.streams).To(HaveKey(protocol.StreamID(7)))
})
It("doesn't reopen an already closed stream", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
deleteStream(5)
Expect(err).ToNot(HaveOccurred())
str, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
})
Context("counting streams", func() {
It("errors when too many streams are opened", func() {
for i := uint32(0); i < m.maxIncomingStreams; i++ {
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred())
}
_, err := m.GetOrOpenStream(protocol.StreamID(2*m.maxIncomingStreams + 3))
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("errors when too many streams are opened implicitely", func() {
_, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 3))
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("does not error when many streams are opened and closed", func() {
for i := uint32(2); i < 10*m.maxIncomingStreams; i++ {
str, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred())
deleteStream(str.StreamID())
}
})
})
})
Context("server-side streams", func() {
It("doesn't allow opening streams before receiving the transport parameters", func() {
_, err := m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("opens a stream 2 first", func() {
m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000})
s, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(Equal(protocol.StreamID(2)))
Expect(m.numIncomingStreams).To(BeZero())
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
})
It("returns the error when the streamsMap was closed", func() {
testErr := errors.New("test error")
m.CloseWithError(testErr)
_, err := m.OpenStream()
Expect(err).To(MatchError(testErr))
})
It("doesn't reopen an already closed stream", func() {
m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000})
str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(2)))
deleteStream(2)
Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(2)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
})
Context("counting streams", func() {
const maxOutgoingStreams = 50
BeforeEach(func() {
m.UpdateLimits(&handshake.TransportParameters{MaxStreams: maxOutgoingStreams})
})
It("errors when too many streams are opened", func() {
for i := 1; i <= maxOutgoingStreams; i++ {
_, err := m.OpenStream()
Expect(err).NotTo(HaveOccurred())
}
_, err := m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("does not error when many streams are opened and closed", func() {
for i := 2; i < 10*maxOutgoingStreams; i++ {
str, err := m.OpenStream()
Expect(err).NotTo(HaveOccurred())
deleteStream(str.StreamID())
}
})
It("allows many server- and client-side streams at the same time", func() {
for i := 1; i < maxOutgoingStreams; i++ {
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
}
for i := 0; i < maxOutgoingStreams; i++ {
_, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1))
Expect(err).ToNot(HaveOccurred())
}
})
})
Context("opening streams synchronously", func() {
const maxOutgoingStreams = 10
BeforeEach(func() {
m.UpdateLimits(&handshake.TransportParameters{MaxStreams: maxOutgoingStreams})
})
openMaxNumStreams := func() {
for i := 1; i <= maxOutgoingStreams; i++ {
_, err := m.OpenStream()
Expect(err).NotTo(HaveOccurred())
}
_, err := m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
}
It("waits until another stream is closed", func() {
openMaxNumStreams()
var str Stream
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
str, err = m.OpenStreamSync()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
deleteStream(6)
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxOutgoingStreams + 2)))
})
It("stops waiting when an error is registered", func() {
testErr := errors.New("test error")
openMaxNumStreams()
for _, str := range m.streams {
str.(*MockStreamI).EXPECT().closeForShutdown(testErr)
}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync()
Expect(err).To(MatchError(testErr))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
m.CloseWithError(testErr)
Eventually(done).Should(BeClosed())
})
It("immediately returns when OpenStreamSync is called after an error was registered", func() {
testErr := errors.New("test error")
m.CloseWithError(testErr)
_, err := m.OpenStreamSync()
Expect(err).To(MatchError(testErr))
})
})
})
Context("accepting streams", func() {
It("does nothing if no stream is opened", func() {
var accepted bool
go func() {
_, _ = m.AcceptStream()
accepted = true
}()
Consistently(func() bool { return accepted }).Should(BeFalse())
})
It("starts with stream 3", func() {
var str Stream
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
_, err := m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
})
It("returns an implicitly opened stream, if a stream number is skipped", func() {
var str Stream
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
_, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
})
It("returns to multiple accepts", func() {
var str1, str2 Stream
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
str1, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
close(done1)
}()
go func() {
defer GinkgoRecover()
var err error
str2, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
close(done2)
}()
_, err := m.GetOrOpenStream(5) // opens stream 3 and 5
Expect(err).ToNot(HaveOccurred())
Eventually(done1).Should(BeClosed())
Eventually(done2).Should(BeClosed())
Expect(str1.StreamID()).ToNot(Equal(str2.StreamID()))
Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(3 + 5))
})
It("waits until a new stream is available", func() {
var str Stream
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
_, err := m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
})
It("returns multiple streams on subsequent Accept calls, if available", func() {
var str Stream
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
_, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(5)))
})
It("blocks after accepting a stream", func() {
_, err := m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, _ = m.AcceptStream()
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
str.(*MockStreamI).EXPECT().closeForShutdown(gomock.Any())
m.CloseWithError(errors.New("shut down"))
Eventually(done).Should(BeClosed())
})
It("stops waiting when an error is registered", func() {
testErr := errors.New("testErr")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.AcceptStream()
Expect(err).To(MatchError(testErr))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
m.CloseWithError(testErr)
Eventually(done).Should(BeClosed())
})
It("immediately returns when Accept is called after an error was registered", func() {
testErr := errors.New("testErr")
m.CloseWithError(testErr)
_, err := m.AcceptStream()
Expect(err).To(MatchError(testErr))
})
})
})
Context("as a client", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveClient)
m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000})
})
Context("server-side streams", func() {
It("rejects streams with odd IDs", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5"))
})
It("rejects streams with odds IDs, which are lower than the highest server-side stream", func() {
_, err := m.GetOrOpenStream(6)
Expect(err).NotTo(HaveOccurred())
_, err = m.GetOrOpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5"))
})
It("gets new streams", func() {
s, err := m.GetOrOpenStream(2)
Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(2)))
Expect(m.streams).To(HaveLen(1))
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
It("opens skipped streams", func() {
_, err := m.GetOrOpenStream(6)
Expect(err).NotTo(HaveOccurred())
Expect(m.streams).To(HaveKey(protocol.StreamID(2)))
Expect(m.streams).To(HaveKey(protocol.StreamID(4)))
Expect(m.streams).To(HaveKey(protocol.StreamID(6)))
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeEquivalentTo(3))
})
It("doesn't reopen an already closed stream", func() {
str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
deleteStream(3)
Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
})
})
Context("client-side streams", func() {
It("starts with stream 3", func() {
s, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(BeEquivalentTo(3))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
Expect(m.numIncomingStreams).To(BeZero())
})
It("opens multiple streams", func() {
s1, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
s2, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(s2.StreamID()).To(Equal(s1.StreamID() + 2))
})
It("doesn't reopen an already closed stream", func() {
_, err := m.GetOrOpenStream(4)
Expect(err).ToNot(HaveOccurred())
deleteStream(4)
Expect(err).ToNot(HaveOccurred())
str, err := m.GetOrOpenStream(4)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
})
})
Context("accepting streams", func() {
It("accepts stream 2 first", func() {
var str Stream
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
_, err := m.GetOrOpenStream(2)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(2)))
})
})
})
})
Context("deleting streams", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer)
})
It("deletes an incoming stream", func() {
_, err := m.GetOrOpenStream(5) // open stream 3 and 5
Expect(err).ToNot(HaveOccurred())
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
err = m.DeleteStream(3)
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(1))
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
It("deletes an outgoing stream", func() {
m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000})
_, err := m.OpenStream() // open stream 2
Expect(err).ToNot(HaveOccurred())
_, err = m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
err = m.DeleteStream(2)
Expect(err).ToNot(HaveOccurred())
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
})
It("errors when the stream doesn't exist", func() {
err := m.DeleteStream(1337)
Expect(err).To(MatchError(errMapAccess))
})
})
It("sets the flow control limit", func() {
setNewStreamsMap(protocol.PerspectiveServer)
_, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
m.streams[3].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: 3,
ByteOffset: 321,
})
m.streams[5].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: 5,
ByteOffset: 321,
})
m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321})
})
})

View file

@ -3,6 +3,7 @@ package quic
import (
"errors"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
@ -12,7 +13,7 @@ import (
. "github.com/onsi/gomega"
)
var _ = Describe("Streams Map", func() {
var _ = Describe("Streams Map (for IETF QUIC)", func() {
var m *streamsMap
newStream := func(id protocol.StreamID) streamI {
@ -21,8 +22,8 @@ var _ = Describe("Streams Map", func() {
return str
}
setNewStreamsMap := func(p protocol.Perspective, v protocol.VersionNumber) {
m = newStreamsMap(newStream, p, v)
setNewStreamsMap := func(p protocol.Perspective) {
m = newStreamsMap(newStream, p).(*streamsMap)
}
deleteStream := func(id protocol.StreamID) {
@ -32,15 +33,15 @@ var _ = Describe("Streams Map", func() {
Context("getting and creating streams", func() {
Context("as a server", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
setNewStreamsMap(protocol.PerspectiveServer)
})
Context("client-side streams", func() {
It("gets new streams", func() {
s, err := m.GetOrOpenStream(3)
s, err := m.GetOrOpenStream(1)
Expect(err).NotTo(HaveOccurred())
Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(Equal(protocol.StreamID(3)))
Expect(s.StreamID()).To(Equal(protocol.StreamID(1)))
Expect(m.streams).To(HaveLen(1))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
Expect(m.numOutgoingStreams).To(BeZero())
@ -264,8 +265,7 @@ var _ = Describe("Streams Map", func() {
Consistently(func() bool { return accepted }).Should(BeFalse())
})
It("starts with stream 1, if the crypto stream is stream 0", func() {
setNewStreamsMap(protocol.PerspectiveServer, versionIETFFrames)
It("starts with stream 1", func() {
var str Stream
done := make(chan struct{})
go func() {
@ -281,7 +281,7 @@ var _ = Describe("Streams Map", func() {
Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
})
It("starts with stream 3, if the crypto stream is stream 1", func() {
It("returns an implicitly opened stream, if a stream number is skipped", func() {
var str Stream
done := make(chan struct{})
go func() {
@ -294,23 +294,7 @@ var _ = Describe("Streams Map", func() {
_, err := m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
})
It("returns an implicitly opened stream, if a stream number is skipped", func() {
var str Stream
done := make(chan struct{})
go func() {
defer GinkgoRecover()
var err error
str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
_, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
})
It("returns to multiple accepts", func() {
@ -331,12 +315,12 @@ var _ = Describe("Streams Map", func() {
Expect(err).ToNot(HaveOccurred())
close(done2)
}()
_, err := m.GetOrOpenStream(5) // opens stream 3 and 5
_, err := m.GetOrOpenStream(3) // opens stream 1 and 3
Expect(err).ToNot(HaveOccurred())
Eventually(done1).Should(BeClosed())
Eventually(done2).Should(BeClosed())
Expect(str1.StreamID()).ToNot(Equal(str2.StreamID()))
Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(3 + 5))
Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(1 + 3))
})
It("waits until a new stream is available", func() {
@ -350,10 +334,10 @@ var _ = Describe("Streams Map", func() {
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
_, err := m.GetOrOpenStream(3)
_, err := m.GetOrOpenStream(1)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
})
It("returns multiple streams on subsequent Accept calls, if available", func() {
@ -366,39 +350,46 @@ var _ = Describe("Streams Map", func() {
Expect(err).ToNot(HaveOccurred())
close(done)
}()
_, err := m.GetOrOpenStream(5)
_, err := m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(5)))
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
})
It("blocks after accepting a stream", func() {
var accepted bool
_, err := m.GetOrOpenStream(3)
_, err := m.GetOrOpenStream(1)
Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, _ = m.AcceptStream()
accepted = true
close(done)
}()
Consistently(func() bool { return accepted }).Should(BeFalse())
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
str.(*MockStreamI).EXPECT().closeForShutdown(gomock.Any())
m.CloseWithError(errors.New("shut down"))
Eventually(done).Should(BeClosed())
})
It("stops waiting when an error is registered", func() {
testErr := errors.New("testErr")
var acceptErr error
done := make(chan struct{})
go func() {
_, acceptErr = m.AcceptStream()
defer GinkgoRecover()
_, err := m.AcceptStream()
Expect(err).To(MatchError(testErr))
close(done)
}()
Consistently(func() error { return acceptErr }).ShouldNot(HaveOccurred())
Consistently(done).ShouldNot(BeClosed())
m.CloseWithError(testErr)
Eventually(func() error { return acceptErr }).Should(MatchError(testErr))
Eventually(done).Should(BeClosed())
})
It("immediately returns when Accept is called after an error was registered", func() {
@ -412,7 +403,7 @@ var _ = Describe("Streams Map", func() {
Context("as a client", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames)
setNewStreamsMap(protocol.PerspectiveClient)
m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000})
})
@ -451,18 +442,18 @@ var _ = Describe("Streams Map", func() {
It("doesn't reopen an already closed stream", func() {
str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
deleteStream(3)
Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
deleteStream(1)
Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(3)
str, err = m.GetOrOpenStream(1)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
})
})
Context("client-side streams", func() {
It("starts with stream 1, if the crypto stream is stream 0", func() {
setNewStreamsMap(protocol.PerspectiveClient, versionIETFFrames)
It("starts with stream 1", func() {
setNewStreamsMap(protocol.PerspectiveClient)
m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000})
s, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
@ -472,15 +463,6 @@ var _ = Describe("Streams Map", func() {
Expect(m.numIncomingStreams).To(BeZero())
})
It("starts with stream 3, if the crypto stream is stream 1", func() {
s, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(BeEquivalentTo(3))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
Expect(m.numIncomingStreams).To(BeZero())
})
It("opens multiple streams", func() {
s1, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
@ -522,17 +504,17 @@ var _ = Describe("Streams Map", func() {
Context("deleting streams", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
setNewStreamsMap(protocol.PerspectiveServer)
})
It("deletes an incoming stream", func() {
_, err := m.GetOrOpenStream(5) // open stream 3 and 5
_, err := m.GetOrOpenStream(3) // open stream 1 and 3
Expect(err).ToNot(HaveOccurred())
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
err = m.DeleteStream(3)
err = m.DeleteStream(1)
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(1))
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
Expect(m.streams).To(HaveKey(protocol.StreamID(3)))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
@ -555,15 +537,15 @@ var _ = Describe("Streams Map", func() {
})
It("sets the flow control limit", func() {
setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
_, err := m.GetOrOpenStream(5)
setNewStreamsMap(protocol.PerspectiveServer)
_, err := m.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
m.streams[3].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: 3,
m.streams[1].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: 1,
ByteOffset: 321,
})
m.streams[5].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: 5,
m.streams[3].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
StreamID: 3,
ByteOffset: 321,
})
m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321})