uquic/streams_map_test.go
Marten Seemann 759b0d87b1 refactor packet unpacking
This replaces version.UsesMaxDataFrame by version.UsesIETFFrameFormat.
That way, we can have two separate code paths in the unpacker to unpack
either gQUIC frames or IETF frames.
2017-12-08 09:18:54 +07:00

799 lines
25 KiB
Go

package quic
import (
"errors"
"sort"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Streams Map", func() {
// m is the streams map under test. It is (re)created by setNewStreamsMap.
var m *streamsMap

// finishedStreams remembers, per stream ID, the mocked Finished() expectation
// of every stream built by newStream, so that helpers can later change what
// Finished() returns for a particular stream.
var finishedStreams map[protocol.StreamID]*gomock.Call

// newStream is the stream constructor handed to newStreamsMap. It builds a
// mock stream that reports the given ID and whose Finished() returns false.
newStream := func(id protocol.StreamID) streamI {
	mockStr := mocks.NewMockStreamI(mockCtrl)
	mockStr.EXPECT().StreamID().Return(id).AnyTimes()
	finishedStreams[id] = mockStr.EXPECT().Finished().Return(false).AnyTimes()
	return mockStr
}

// setNewStreamsMap replaces m with a fresh streams map for the given
// perspective and version.
setNewStreamsMap := func(pers protocol.Perspective, version protocol.VersionNumber) {
	m = newStreamsMap(newStream, pers, version)
}

// Every spec starts without any recorded Finished() expectations.
BeforeEach(func() {
	finishedStreams = map[protocol.StreamID]*gomock.Call{}
})

// After every spec, openStreams and streams must hold the same number of entries.
AfterEach(func() {
	numStreams := len(m.streams)
	Expect(m.openStreams).To(HaveLen(numStreams))
})
// deleteStream makes the mock stream with the given ID report Finished() ==
// true and then runs DeleteClosedStreams on the map under test.
//
// The stream must currently be in the map and must have been created through
// newStream, since only those streams have an entry in finishedStreams.
// Assertions use an offset of 1 so that a failure is reported at the calling
// spec, matching the closeStream helper.
deleteStream := func(id protocol.StreamID) {
	ExpectWithOffset(1, m.streams[id]).ToNot(BeNil())
	// Guard against calling Return on a missing (nil) *gomock.Call.
	ExpectWithOffset(1, finishedStreams).To(HaveKey(id))
	finishedStreams[id].Return(true)
	ExpectWithOffset(1, m.DeleteClosedStreams()).ToNot(HaveOccurred())
}
Context("getting and creating streams", func() {
Context("as a server", func() {
// Every spec in this Context runs against a fresh gQUIC streams map created
// with the server perspective.
BeforeEach(func() {
	setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
})

// Streams requested through GetOrOpenStream, i.e. streams opened by the peer.
// As the specs below show, only odd stream IDs are accepted here.
Context("client-side streams", func() {
	It("gets new streams", func() {
		s, err := m.GetOrOpenStream(1)
		Expect(err).NotTo(HaveOccurred())
		Expect(s.StreamID()).To(Equal(protocol.StreamID(1)))
		Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
		Expect(m.numOutgoingStreams).To(BeZero())
	})

	It("rejects streams with even IDs", func() {
		_, err := m.GetOrOpenStream(6)
		Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 6"))
	})

	It("rejects streams with even IDs, which are lower thatn the highest client-side stream", func() {
		_, err := m.GetOrOpenStream(5)
		Expect(err).NotTo(HaveOccurred())
		_, err = m.GetOrOpenStream(4)
		Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 4"))
	})

	It("gets existing streams", func() {
		// The first call is only needed for its side effect of opening the
		// stream; its returned stream is not inspected.
		_, err := m.GetOrOpenStream(5)
		Expect(err).NotTo(HaveOccurred())
		numStreams := m.numIncomingStreams
		s, err := m.GetOrOpenStream(5)
		Expect(err).NotTo(HaveOccurred())
		Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
		// Fetching an existing stream must not count as a new incoming stream.
		Expect(m.numIncomingStreams).To(Equal(numStreams))
	})

	It("returns nil for closed streams", func() {
		_, err := m.GetOrOpenStream(5)
		Expect(err).NotTo(HaveOccurred())
		deleteStream(5)
		s, err := m.GetOrOpenStream(5)
		Expect(err).NotTo(HaveOccurred())
		Expect(s).To(BeNil())
	})

	It("opens skipped streams", func() {
		_, err := m.GetOrOpenStream(5)
		Expect(err).NotTo(HaveOccurred())
		Expect(m.streams).To(HaveKey(protocol.StreamID(1)))
		Expect(m.streams).To(HaveKey(protocol.StreamID(3)))
		Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
	})

	It("doesn't reopen an already closed stream", func() {
		_, err := m.GetOrOpenStream(5)
		Expect(err).ToNot(HaveOccurred())
		// deleteStream performs its own assertions; there is no new error
		// to check here.
		deleteStream(5)
		str, err := m.GetOrOpenStream(5)
		Expect(err).ToNot(HaveOccurred())
		Expect(str).To(BeNil())
	})

	Context("counting streams", func() {
		It("errors when too many streams are opened", func() {
			for i := uint32(0); i < m.maxIncomingStreams; i++ {
				_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
				Expect(err).NotTo(HaveOccurred())
			}
			_, err := m.GetOrOpenStream(protocol.StreamID(2*m.maxIncomingStreams + 3))
			Expect(err).To(MatchError(qerr.TooManyOpenStreams))
		})

		It("errors when too many streams are opened implicitely", func() {
			_, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 1))
			Expect(err).To(MatchError(qerr.TooManyOpenStreams))
		})

		It("does not error when many streams are opened and closed", func() {
			for i := uint32(2); i < 10*m.maxIncomingStreams; i++ {
				str, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
				Expect(err).NotTo(HaveOccurred())
				deleteStream(str.StreamID())
			}
		})
	})
})
// Streams opened locally through OpenStream / OpenStreamSync. As the specs
// below show, the first such stream gets ID 2 and none can be opened before
// UpdateMaxStreamLimit has been called.
Context("server-side streams", func() {
	It("doesn't allow opening streams before receiving the transport parameters", func() {
		_, err := m.OpenStream()
		Expect(err).To(MatchError(qerr.TooManyOpenStreams))
	})

	It("opens a stream 2 first", func() {
		m.UpdateMaxStreamLimit(100)
		s, err := m.OpenStream()
		Expect(err).ToNot(HaveOccurred())
		Expect(s).ToNot(BeNil())
		Expect(s.StreamID()).To(Equal(protocol.StreamID(2)))
		Expect(m.numIncomingStreams).To(BeZero())
		Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
	})

	It("returns the error when the streamsMap was closed", func() {
		testErr := errors.New("test error")
		m.CloseWithError(testErr)
		_, err := m.OpenStream()
		Expect(err).To(MatchError(testErr))
	})

	It("doesn't reopen an already closed stream", func() {
		m.UpdateMaxStreamLimit(100)
		str, err := m.OpenStream()
		Expect(err).ToNot(HaveOccurred())
		Expect(str.StreamID()).To(Equal(protocol.StreamID(2)))
		// deleteStream performs its own assertions; there is no new error
		// to check here.
		deleteStream(2)
		str, err = m.GetOrOpenStream(2)
		Expect(err).ToNot(HaveOccurred())
		Expect(str).To(BeNil())
	})

	Context("counting streams", func() {
		const maxOutgoingStreams = 50

		BeforeEach(func() {
			m.UpdateMaxStreamLimit(maxOutgoingStreams)
		})

		It("errors when too many streams are opened", func() {
			for i := 1; i <= maxOutgoingStreams; i++ {
				_, err := m.OpenStream()
				Expect(err).NotTo(HaveOccurred())
			}
			_, err := m.OpenStream()
			Expect(err).To(MatchError(qerr.TooManyOpenStreams))
		})

		It("does not error when many streams are opened and closed", func() {
			for i := 2; i < 10*maxOutgoingStreams; i++ {
				str, err := m.OpenStream()
				Expect(err).NotTo(HaveOccurred())
				deleteStream(str.StreamID())
			}
		})

		It("allows many server- and client-side streams at the same time", func() {
			for i := 1; i < maxOutgoingStreams; i++ {
				_, err := m.OpenStream()
				Expect(err).ToNot(HaveOccurred())
			}
			for i := 0; i < maxOutgoingStreams; i++ {
				_, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1))
				Expect(err).ToNot(HaveOccurred())
			}
		})
	})

	Context("opening streams synchronously", func() {
		const maxOutgoingStreams = 10

		BeforeEach(func() {
			m.UpdateMaxStreamLimit(maxOutgoingStreams)
		})

		// openMaxNumStreams opens streams until the limit is reached and
		// verifies that one more OpenStream call is rejected.
		openMaxNumStreams := func() {
			for i := 1; i <= maxOutgoingStreams; i++ {
				_, err := m.OpenStream()
				Expect(err).NotTo(HaveOccurred())
			}
			_, err := m.OpenStream()
			Expect(err).To(MatchError(qerr.TooManyOpenStreams))
		}

		It("waits until another stream is closed", func() {
			openMaxNumStreams()
			var str streamI
			// done is closed once OpenStreamSync has returned. Closing the
			// channel also orders the write to str before the read below,
			// so the spec is free of data races.
			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				var err error
				str, err = m.OpenStreamSync()
				Expect(err).ToNot(HaveOccurred())
				close(done)
			}()
			Consistently(done).ShouldNot(BeClosed())
			deleteStream(6)
			Eventually(done).Should(BeClosed())
			Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxOutgoingStreams + 2)))
		})

		It("stops waiting when an error is registered", func() {
			testErr := errors.New("test error")
			openMaxNumStreams()
			// CloseWithError is expected to cancel every open stream.
			for _, str := range m.streams {
				str.(*mocks.MockStreamI).EXPECT().Cancel(testErr)
			}
			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				_, err := m.OpenStreamSync()
				Expect(err).To(MatchError(testErr))
				close(done)
			}()
			Consistently(done).ShouldNot(BeClosed())
			m.CloseWithError(testErr)
			Eventually(done).Should(BeClosed())
		})

		It("immediately returns when OpenStreamSync is called after an error was registered", func() {
			testErr := errors.New("test error")
			m.CloseWithError(testErr)
			_, err := m.OpenStreamSync()
			Expect(err).To(MatchError(testErr))
		})
	})
})
// Specs for AcceptStream from the server's perspective.
//
// Results produced on helper goroutines are handed back over channels
// instead of shared variables, so the specs stay clean under the race
// detector.
Context("accepting streams", func() {
	// acceptStreamAsync calls AcceptStream on a new goroutine, asserts that
	// it succeeds and delivers the accepted stream on the returned channel.
	// The channel is buffered so the goroutine never blocks on the send.
	acceptStreamAsync := func() chan streamI {
		strChan := make(chan streamI, 1)
		go func() {
			defer GinkgoRecover()
			str, err := m.AcceptStream()
			Expect(err).ToNot(HaveOccurred())
			strChan <- str
		}()
		return strChan
	}

	It("does nothing if no stream is opened", func() {
		done := make(chan struct{})
		go func() {
			defer GinkgoRecover()
			_, _ = m.AcceptStream()
			close(done)
		}()
		Consistently(done).ShouldNot(BeClosed())
	})

	It("starts with stream 1, if the crypto stream is stream 0", func() {
		setNewStreamsMap(protocol.PerspectiveServer, versionIETFFrames)
		strChan := acceptStreamAsync()
		_, err := m.GetOrOpenStream(1)
		Expect(err).ToNot(HaveOccurred())
		var str streamI
		Eventually(strChan).Should(Receive(&str))
		Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
	})

	It("starts with stream 3, if the crypto stream is stream 1", func() {
		strChan := acceptStreamAsync()
		_, err := m.GetOrOpenStream(3)
		Expect(err).ToNot(HaveOccurred())
		var str streamI
		Eventually(strChan).Should(Receive(&str))
		Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
	})

	It("returns an implicitly opened stream, if a stream number is skipped", func() {
		strChan := acceptStreamAsync()
		_, err := m.GetOrOpenStream(5)
		Expect(err).ToNot(HaveOccurred())
		var str streamI
		Eventually(strChan).Should(Receive(&str))
		Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
	})

	It("returns to multiple accepts", func() {
		strChan1 := acceptStreamAsync()
		strChan2 := acceptStreamAsync()
		_, err := m.GetOrOpenStream(5) // opens stream 3 and 5
		Expect(err).ToNot(HaveOccurred())
		var str1, str2 streamI
		Eventually(strChan1).Should(Receive(&str1))
		Eventually(strChan2).Should(Receive(&str2))
		// Which goroutine gets which stream is not fixed; only require that
		// the two of them got streams 3 and 5 between them.
		Expect(str1.StreamID()).ToNot(Equal(str2.StreamID()))
		Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(3 + 5))
	})

	It("waits a new stream is available", func() {
		strChan := acceptStreamAsync()
		Consistently(strChan).ShouldNot(Receive())
		_, err := m.GetOrOpenStream(3)
		Expect(err).ToNot(HaveOccurred())
		var str streamI
		Eventually(strChan).Should(Receive(&str))
		Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
	})

	It("returns multiple streams on subsequent Accept calls, if available", func() {
		strChan := acceptStreamAsync()
		_, err := m.GetOrOpenStream(5)
		Expect(err).ToNot(HaveOccurred())
		var str streamI
		Eventually(strChan).Should(Receive(&str))
		Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
		// Stream 5 is already open, so the second accept returns directly.
		str, err = m.AcceptStream()
		Expect(err).ToNot(HaveOccurred())
		Expect(str.StreamID()).To(Equal(protocol.StreamID(5)))
	})

	It("blocks after accepting a stream", func() {
		_, err := m.GetOrOpenStream(3)
		Expect(err).ToNot(HaveOccurred())
		str, err := m.AcceptStream()
		Expect(err).ToNot(HaveOccurred())
		Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
		done := make(chan struct{})
		go func() {
			defer GinkgoRecover()
			_, _ = m.AcceptStream()
			close(done)
		}()
		Consistently(done).ShouldNot(BeClosed())
	})

	It("stops waiting when an error is registered", func() {
		testErr := errors.New("testErr")
		errChan := make(chan error, 1)
		go func() {
			defer GinkgoRecover()
			_, err := m.AcceptStream()
			errChan <- err
		}()
		Consistently(errChan).ShouldNot(Receive())
		m.CloseWithError(testErr)
		Eventually(errChan).Should(Receive(MatchError(testErr)))
	})

	It("immediately returns when Accept is called after an error was registered", func() {
		testErr := errors.New("testErr")
		m.CloseWithError(testErr)
		_, err := m.AcceptStream()
		Expect(err).To(MatchError(testErr))
	})
})
})
Context("as a client", func() {
// Every spec in this Context runs against a fresh gQUIC streams map created
// with the client perspective and a stream limit of 100.
BeforeEach(func() {
	setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames)
	m.UpdateMaxStreamLimit(100)
})

// Streams requested through GetOrOpenStream, i.e. streams opened by the peer.
// As the specs below show, only even stream IDs are accepted here.
Context("server-side streams", func() {
	It("rejects streams with odd IDs", func() {
		_, err := m.GetOrOpenStream(5)
		Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5"))
	})

	It("rejects streams with odds IDs, which are lower thatn the highest server-side stream", func() {
		_, err := m.GetOrOpenStream(6)
		Expect(err).NotTo(HaveOccurred())
		_, err = m.GetOrOpenStream(5)
		Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5"))
	})

	It("gets new streams", func() {
		s, err := m.GetOrOpenStream(2)
		Expect(err).NotTo(HaveOccurred())
		Expect(s.StreamID()).To(Equal(protocol.StreamID(2)))
		Expect(m.numOutgoingStreams).To(BeZero())
		Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
	})

	It("opens skipped streams", func() {
		_, err := m.GetOrOpenStream(6)
		Expect(err).NotTo(HaveOccurred())
		Expect(m.streams).To(HaveKey(protocol.StreamID(2)))
		Expect(m.streams).To(HaveKey(protocol.StreamID(4)))
		Expect(m.streams).To(HaveKey(protocol.StreamID(6)))
		Expect(m.numOutgoingStreams).To(BeZero())
		Expect(m.numIncomingStreams).To(BeEquivalentTo(3))
	})

	// Uses stream 4, a stream opened by the peer, so it belongs in this Context.
	It("doesn't reopen an already closed stream", func() {
		_, err := m.GetOrOpenStream(4)
		Expect(err).ToNot(HaveOccurred())
		// deleteStream performs its own assertions; there is no new error
		// to check here.
		deleteStream(4)
		str, err := m.GetOrOpenStream(4)
		Expect(err).ToNot(HaveOccurred())
		Expect(str).To(BeNil())
	})
})

// Streams opened locally through OpenStream; they get odd stream IDs.
Context("client-side streams", func() {
	It("starts with stream 1, if the crypto stream is stream 0", func() {
		setNewStreamsMap(protocol.PerspectiveClient, versionIETFFrames)
		m.UpdateMaxStreamLimit(100)
		s, err := m.OpenStream()
		Expect(err).ToNot(HaveOccurred())
		Expect(s).ToNot(BeNil())
		Expect(s.StreamID()).To(BeEquivalentTo(1))
		Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
		Expect(m.numIncomingStreams).To(BeZero())
	})

	It("starts with stream 3, if the crypto stream is stream 1", func() {
		s, err := m.OpenStream()
		Expect(err).ToNot(HaveOccurred())
		Expect(s).ToNot(BeNil())
		Expect(s.StreamID()).To(BeEquivalentTo(3))
		Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
		Expect(m.numIncomingStreams).To(BeZero())
	})

	It("opens multiple streams", func() {
		s1, err := m.OpenStream()
		Expect(err).ToNot(HaveOccurred())
		s2, err := m.OpenStream()
		Expect(err).ToNot(HaveOccurred())
		Expect(s2.StreamID()).To(Equal(s1.StreamID() + 2))
	})

	// Uses stream 3, a stream we opened ourselves, so it belongs in this Context.
	It("doesn't reopen an already closed stream", func() {
		str, err := m.OpenStream()
		Expect(err).ToNot(HaveOccurred())
		Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
		// deleteStream performs its own assertions; there is no new error
		// to check here.
		deleteStream(3)
		str, err = m.GetOrOpenStream(3)
		Expect(err).ToNot(HaveOccurred())
		Expect(str).To(BeNil())
	})
})

Context("accepting streams", func() {
	It("accepts stream 2 first", func() {
		// The accepted stream is handed back over a buffered channel rather
		// than a shared variable, which avoids a data race with the polling
		// done by Eventually.
		strChan := make(chan streamI, 1)
		go func() {
			defer GinkgoRecover()
			str, err := m.AcceptStream()
			Expect(err).ToNot(HaveOccurred())
			strChan <- str
		}()
		_, err := m.GetOrOpenStream(2)
		Expect(err).ToNot(HaveOccurred())
		var str streamI
		Eventually(strChan).Should(Receive(&str))
		Expect(str.StreamID()).To(Equal(protocol.StreamID(2)))
	})
})
})
})
Context("DoS mitigation, iterating and deleting", func() {
// The specs in this Context start from a server-perspective gQUIC streams map.
BeforeEach(func() {
	setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
})

// closeStream makes the mock stream with the given ID report Finished() ==
// true. Unlike deleteStream it does not call DeleteClosedStreams, so several
// streams can be marked before a single cleanup run. The stream must be in
// the map; a failure is reported at the caller's line.
closeStream := func(id protocol.StreamID) {
	ExpectWithOffset(1, m.streams[id]).ToNot(BeNil())
	finishedCall := finishedStreams[id]
	finishedCall.Return(true)
}
Context("deleting streams", func() {
// Specs for DeleteClosedStreams from the server's perspective.
Context("as a server", func() {
	// Start with five open streams: 1, 3 and 5 opened by the peer, and 2 and
	// 4 opened locally through OpenStream.
	BeforeEach(func() {
		m.UpdateMaxStreamLimit(100)
		for i := 1; i <= 5; i++ {
			if i%2 == 1 {
				_, err := m.openRemoteStream(protocol.StreamID(i))
				Expect(err).ToNot(HaveOccurred())
			} else {
				_, err := m.OpenStream()
				Expect(err).ToNot(HaveOccurred())
			}
		}
		Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
		Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) // 2 and 4
		Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) // 1, 3 and 5
	})

	It("does not delete streams with Close()", func() {
		str, err := m.GetOrOpenStream(55)
		Expect(err).ToNot(HaveOccurred())
		str.(*mocks.MockStreamI).EXPECT().Close()
		str.Close()
		// Finished() still returns false, so the stream has to survive.
		err = m.DeleteClosedStreams()
		Expect(err).ToNot(HaveOccurred())
		str, err = m.GetOrOpenStream(55)
		Expect(err).ToNot(HaveOccurred())
		Expect(str).ToNot(BeNil())
	})

	It("removes the first stream", func() {
		closeStream(1)
		err := m.DeleteClosedStreams()
		Expect(err).ToNot(HaveOccurred())
		Expect(m.openStreams).To(HaveLen(4))
		Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5}))
		Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
		Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
	})

	It("removes a stream in the middle", func() {
		closeStream(3)
		err := m.DeleteClosedStreams()
		Expect(err).ToNot(HaveOccurred())
		Expect(m.streams).To(HaveLen(4))
		Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5}))
		Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
		Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
	})

	// Stream 2 was opened locally through OpenStream in the BeforeEach, so
	// removing it lowers the outgoing count, not the incoming one.
	It("removes a server-initiated stream", func() {
		closeStream(2)
		err := m.DeleteClosedStreams()
		Expect(err).ToNot(HaveOccurred())
		Expect(m.streams).To(HaveLen(4))
		Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 3, 4, 5}))
		Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
		Expect(m.numIncomingStreams).To(BeEquivalentTo(3))
	})

	It("removes a stream at the end", func() {
		closeStream(5)
		err := m.DeleteClosedStreams()
		Expect(err).ToNot(HaveOccurred())
		Expect(m.openStreams).To(HaveLen(4))
		Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4}))
		Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
		Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
	})

	It("removes all streams", func() {
		for i := 1; i <= 5; i++ {
			closeStream(protocol.StreamID(i))
		}
		err := m.DeleteClosedStreams()
		Expect(err).ToNot(HaveOccurred())
		Expect(m.streams).To(BeEmpty())
		Expect(m.openStreams).To(BeEmpty())
		Expect(m.numOutgoingStreams).To(BeZero())
		Expect(m.numIncomingStreams).To(BeZero())
	})
})
Context("as a client", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames)
m.UpdateMaxStreamLimit(100)
for i := 1; i <= 5; i++ {
if i%2 == 0 {
_, err := m.openRemoteStream(protocol.StreamID(i))
Expect(err).ToNot(HaveOccurred())
} else {
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
}
}
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 2, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) // 3, 5 and 7
Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) // 2 and 4
})
It("removes a stream that we initiated", func() {
closeStream(3)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a stream that the server initiated", func() {
closeStream(2)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
It("removes all streams", func() {
closeStream(3)
closeStream(2)
closeStream(5)
closeStream(4)
closeStream(7)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(BeEmpty())
Expect(m.openStreams).To(BeEmpty())
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeZero())
})
})
})
// Specs for Range, which invokes a callback for the streams in the map.
Context("Ranging", func() {
	// visited collects the IDs of all streams the callback was invoked for.
	// It is re-sorted after every insertion, so the assertion does not
	// depend on the order in which Range visits the streams.
	var visited []protocol.StreamID
	record := func(str streamI) {
		visited = append(visited, str.StreamID())
		sort.Slice(visited, func(a, b int) bool { return visited[a] < visited[b] })
	}

	// create 5 streams, ids 4 to 8
	BeforeEach(func() {
		visited = visited[:0]
		for id := protocol.StreamID(4); id <= 8; id++ {
			err := m.putStream(&stream{streamID: id})
			Expect(err).NotTo(HaveOccurred())
		}
	})

	It("ranges over all open streams", func() {
		m.Range(record)
		Expect(visited).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
	})
})
// Specs for RoundRobinIterate and for how roundRobinIndex is maintained when
// streams are deleted.
Context("RoundRobinIterate", func() {
	var (
		// visitedIDs lists, in call order, the streams the callback saw.
		visitedIDs []protocol.StreamID
		// numCalls counts the callback invocations.
		numCalls int
	)

	// visitAll is an iteration callback that records the stream and always
	// returns true.
	visitAll := func(str streamI) (bool, error) {
		visitedIDs = append(visitedIDs, str.StreamID())
		numCalls++
		return true, nil
	}

	// visitUntil builds an iteration callback that records the stream and
	// returns false for the stream with ID stopAt, true for all others.
	visitUntil := func(stopAt protocol.StreamID) func(streamI) (bool, error) {
		return func(str streamI) (bool, error) {
			visitedIDs = append(visitedIDs, str.StreamID())
			numCalls++
			return str.StreamID() != stopAt, nil
		}
	}

	// create 5 streams, ids 4 to 8
	BeforeEach(func() {
		visitedIDs = visitedIDs[:0]
		numCalls = 0
		for id := protocol.StreamID(4); id <= 8; id++ {
			err := m.putStream(newStream(id))
			Expect(err).NotTo(HaveOccurred())
		}
	})

	It("executes the lambda exactly once for every stream", func() {
		Expect(m.RoundRobinIterate(visitAll)).ToNot(HaveOccurred())
		Expect(numCalls).To(Equal(5))
		Expect(visitedIDs).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
		Expect(m.roundRobinIndex).To(BeZero())
	})

	It("goes around once when starting in the middle", func() {
		m.roundRobinIndex = 3 // pointing to stream 7
		Expect(m.RoundRobinIterate(visitAll)).ToNot(HaveOccurred())
		Expect(numCalls).To(Equal(5))
		Expect(visitedIDs).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6}))
		Expect(m.roundRobinIndex).To(BeEquivalentTo(3))
	})

	It("picks up at the index+1 where it last stopped", func() {
		// First pass: the callback returns false at stream 5.
		Expect(m.RoundRobinIterate(visitUntil(5))).ToNot(HaveOccurred())
		Expect(numCalls).To(Equal(2))
		Expect(visitedIDs).To(Equal([]protocol.StreamID{4, 5}))
		Expect(m.roundRobinIndex).To(BeEquivalentTo(2))

		// Second pass: starts at stream 6, the callback returns false at stream 7.
		numCalls = 0
		visitedIDs = visitedIDs[:0]
		Expect(m.RoundRobinIterate(visitUntil(7))).ToNot(HaveOccurred())
		Expect(numCalls).To(Equal(2))
		Expect(visitedIDs).To(Equal([]protocol.StreamID{6, 7}))
	})

	Context("adjusting the RoundRobinIndex when deleting streams", func() {
		/*
		   Index:    0  1  2  3  4
		   StreamID: [ 4, 5, 6, 7, 8 ]
		*/

		It("adjusts when deleting an element in front", func() {
			m.roundRobinIndex = 3 // stream 7
			deleteStream(5)
			Expect(m.roundRobinIndex).To(BeEquivalentTo(2))
		})

		It("doesn't adjust when deleting an element at the back", func() {
			m.roundRobinIndex = 1 // stream 5
			deleteStream(7)
			Expect(m.roundRobinIndex).To(BeEquivalentTo(1))
		})

		It("doesn't adjust when deleting the element it is pointing to", func() {
			m.roundRobinIndex = 3 // stream 7
			deleteStream(7)
			Expect(m.roundRobinIndex).To(BeEquivalentTo(3))
		})

		It("adjusts when deleting multiple elements", func() {
			m.roundRobinIndex = 3 // stream 7
			for _, id := range []protocol.StreamID{5, 6, 8} {
				closeStream(id)
			}
			Expect(m.DeleteClosedStreams()).ToNot(HaveOccurred())
			// Two of the removed streams (5 and 6) sat in front of index 3.
			Expect(m.roundRobinIndex).To(BeEquivalentTo(1))
		})
	})
})
})
})