introduce a queue of active streams in the streamFramer

The queue holds all streams that have data to send. When stream.Write or
stream.Close are called, a stream is added to this queue. It is removed
from the queue when all available (at that moment) data was sent.
This way, we don't need the round robin scheduling (which, for every
packet sent, asked every single open stream if it had data) any more.
This commit is contained in:
Marten Seemann 2017-12-20 13:27:12 +07:00
parent a5d3eaea61
commit d9c107b034
14 changed files with 237 additions and 335 deletions

View file

@ -13,7 +13,7 @@ type cryptoStreamI interface {
io.Reader io.Reader
io.Writer io.Writer
handleStreamFrame(*wire.StreamFrame) error handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) *wire.StreamFrame popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error) closeForShutdown(error)
hasDataForWriting() bool hasDataForWriting() bool
setReadOffset(protocol.ByteCount) setReadOffset(protocol.ByteCount)

View file

@ -130,10 +130,11 @@ func (mr *MockCryptoStreamMockRecorder) hasDataForWriting() *gomock.Call {
} }
// popStreamFrame mocks base method // popStreamFrame mocks base method
func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0) ret := m.ctrl.Call(m, "popStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame) ret0, _ := ret[0].(*wire.StreamFrame)
return ret0 ret1, _ := ret[1].(bool)
return ret0, ret1
} }
// popStreamFrame indicates an expected call of popStreamFrame // popStreamFrame indicates an expected call of popStreamFrame

View file

@ -238,10 +238,11 @@ func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.C
} }
// popStreamFrame mocks base method // popStreamFrame mocks base method
func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) {
ret := m.ctrl.Call(m, "popStreamFrame", arg0) ret := m.ctrl.Call(m, "popStreamFrame", arg0)
ret0, _ := ret[0].(*wire.StreamFrame) ret0, _ := ret[0].(*wire.StreamFrame)
return ret0 ret1, _ := ret[1].(bool)
return ret0, ret1
} }
// popStreamFrame indicates an expected call of popStreamFrame // popStreamFrame indicates an expected call of popStreamFrame

View file

@ -65,8 +65,7 @@ var _ = Describe("Packet packer", func() {
mockSender := NewMockStreamSender(mockCtrl) mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version) cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version)
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) streamFramer = newStreamFramer(cryptoStream, nil, versionGQUICFrames)
streamFramer = newStreamFramer(cryptoStream, streamsMap, versionGQUICFrames)
packer = &packetPacker{ packer = &packetPacker{
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},

View file

@ -122,12 +122,12 @@ func (s *sendStream) Write(p []byte) (int, error) {
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
// maxBytes is the maximum length this frame (including frame header) will have. // maxBytes is the maximum length this frame (including frame header) will have.
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFrame { func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.closeForShutdownErr != nil { if s.closeForShutdownErr != nil {
return nil return nil, false
} }
frame := &wire.StreamFrame{ frame := &wire.StreamFrame{
@ -137,11 +137,11 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFra
} }
frameLen := frame.MinLength(s.version) frameLen := frame.MinLength(s.version)
if frameLen >= maxBytes { // a STREAM frame must have at least one byte of data if frameLen >= maxBytes { // a STREAM frame must have at least one byte of data
return nil return nil, s.dataForWriting != nil
} }
frame.Data, frame.FinBit = s.getDataForWriting(maxBytes - frameLen) frame.Data, frame.FinBit = s.getDataForWriting(maxBytes - frameLen)
if len(frame.Data) == 0 && !frame.FinBit { if len(frame.Data) == 0 && !frame.FinBit {
return nil return nil, s.dataForWriting != nil
} }
if frame.FinBit { if frame.FinBit {
s.finSent = true s.finSent = true
@ -153,7 +153,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFra
}) })
} }
} }
return frame return frame, s.dataForWriting != nil
} }
func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {

View file

@ -57,7 +57,7 @@ var _ = Describe("Send Stream", func() {
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
var f *wire.StreamFrame var f *wire.StreamFrame
Eventually(func() *wire.StreamFrame { Eventually(func() *wire.StreamFrame {
f = str.popStreamFrame(1000) f, _ = str.popStreamFrame(1000)
return f return f
}).ShouldNot(BeNil()) }).ShouldNot(BeNil())
Expect(f.Data).To(Equal([]byte("foobar"))) Expect(f.Data).To(Equal([]byte("foobar")))
@ -86,14 +86,14 @@ var _ = Describe("Send Stream", func() {
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
var f *wire.StreamFrame var f *wire.StreamFrame
Eventually(func() *wire.StreamFrame { Eventually(func() *wire.StreamFrame {
f = str.popStreamFrame(3 + frameHeaderLen) f, _ = str.popStreamFrame(3 + frameHeaderLen)
return f return f
}).ShouldNot(BeNil()) }).ShouldNot(BeNil())
Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.Data).To(Equal([]byte("foo")))
Expect(f.FinBit).To(BeFalse()) Expect(f.FinBit).To(BeFalse())
Expect(f.Offset).To(BeZero()) Expect(f.Offset).To(BeZero())
Expect(f.DataLenPresent).To(BeTrue()) Expect(f.DataLenPresent).To(BeTrue())
f = str.popStreamFrame(100) f, _ = str.popStreamFrame(100)
Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.Data).To(Equal([]byte("bar")))
Expect(f.FinBit).To(BeFalse()) Expect(f.FinBit).To(BeFalse())
Expect(f.Offset).To(Equal(protocol.ByteCount(3))) Expect(f.Offset).To(Equal(protocol.ByteCount(3)))
@ -103,7 +103,37 @@ var _ = Describe("Send Stream", func() {
}) })
It("popStreamFrame returns nil if no data is available", func() { It("popStreamFrame returns nil if no data is available", func() {
Expect(str.popStreamFrame(1000)).To(BeNil()) frame, hasMoreData := str.popStreamFrame(1000)
Expect(frame).To(BeNil())
Expect(hasMoreData).To(BeFalse())
})
It("says if it has more data for writing", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2)
mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2)
mockFC.EXPECT().IsNewlyBlocked().Times(2)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(100))
close(done)
}()
var hasMoreData bool
Eventually(func() *wire.StreamFrame {
var frame *wire.StreamFrame
frame, hasMoreData = str.popStreamFrame(50)
return frame
}).ShouldNot(BeNil())
Expect(hasMoreData).To(BeTrue())
frame, hasMoreData := str.popStreamFrame(1000)
Expect(frame).ToNot(BeNil())
Expect(hasMoreData).To(BeFalse())
frame, _ = str.popStreamFrame(1000)
Expect(frame).To(BeNil())
Eventually(done).Should(BeClosed())
}) })
It("copies the slice while writing", func() { It("copies the slice while writing", func() {
@ -121,10 +151,10 @@ var _ = Describe("Send Stream", func() {
Expect(n).To(Equal(3)) Expect(n).To(Equal(3))
}() }()
var frame *wire.StreamFrame var frame *wire.StreamFrame
Eventually(func() *wire.StreamFrame { frame = str.popStreamFrame(frameHeaderSize + 1); return frame }).ShouldNot(BeNil()) Eventually(func() *wire.StreamFrame { frame, _ = str.popStreamFrame(frameHeaderSize + 1); return frame }).ShouldNot(BeNil())
Expect(frame.Data).To(Equal([]byte("f"))) Expect(frame.Data).To(Equal([]byte("f")))
s[1] = 'e' s[1] = 'e'
f := str.popStreamFrame(100) f, _ := str.popStreamFrame(100)
Expect(f).ToNot(BeNil()) Expect(f).ToNot(BeNil())
Expect(f.Data).To(Equal([]byte("oo"))) Expect(f.Data).To(Equal([]byte("oo")))
}) })
@ -168,7 +198,7 @@ var _ = Describe("Send Stream", func() {
}() }()
var f *wire.StreamFrame var f *wire.StreamFrame
Eventually(func() *wire.StreamFrame { Eventually(func() *wire.StreamFrame {
f = str.popStreamFrame(1000) f, _ = str.popStreamFrame(1000)
return f return f
}).ShouldNot(BeNil()) }).ShouldNot(BeNil())
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -191,7 +221,7 @@ var _ = Describe("Send Stream", func() {
Expect(str.Close()).To(Succeed()) Expect(str.Close()).To(Succeed())
var f *wire.StreamFrame var f *wire.StreamFrame
Eventually(func() *wire.StreamFrame { Eventually(func() *wire.StreamFrame {
f = str.popStreamFrame(1000) f, _ = str.popStreamFrame(1000)
return f return f
}).ShouldNot(BeNil()) }).ShouldNot(BeNil())
Expect(f.FinBit).To(BeTrue()) Expect(f.FinBit).To(BeTrue())
@ -236,7 +266,7 @@ var _ = Describe("Send Stream", func() {
}() }()
var frame *wire.StreamFrame var frame *wire.StreamFrame
Eventually(func() *wire.StreamFrame { Eventually(func() *wire.StreamFrame {
frame = str.popStreamFrame(50) frame, _ = str.popStreamFrame(50)
return frame return frame
}).ShouldNot(BeNil()) }).ShouldNot(BeNil())
Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed())
@ -259,11 +289,13 @@ var _ = Describe("Send Stream", func() {
}() }()
var frame *wire.StreamFrame var frame *wire.StreamFrame
Eventually(func() *wire.StreamFrame { Eventually(func() *wire.StreamFrame {
frame = str.popStreamFrame(50) frame, _ = str.popStreamFrame(50)
return frame return frame
}).ShouldNot(BeNil()) }).ShouldNot(BeNil())
Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed())
Expect(str.popStreamFrame(50)).To(BeNil()) frame, hasMoreData := str.popStreamFrame(50)
Expect(frame).To(BeNil())
Expect(hasMoreData).To(BeFalse())
}) })
It("doesn't unblock if the deadline is changed before the first one expires", func() { It("doesn't unblock if the deadline is changed before the first one expires", func() {
@ -315,10 +347,11 @@ var _ = Describe("Send Stream", func() {
It("allows FIN", func() { It("allows FIN", func() {
mockSender.EXPECT().onHasStreamData(streamID) mockSender.EXPECT().onHasStreamData(streamID)
str.Close() str.Close()
f := str.popStreamFrame(1000) f, hasMoreData := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil()) Expect(f).ToNot(BeNil())
Expect(f.Data).To(BeEmpty()) Expect(f.Data).To(BeEmpty())
Expect(f.FinBit).To(BeTrue()) Expect(f.FinBit).To(BeTrue())
Expect(hasMoreData).To(BeFalse())
}) })
It("doesn't send a FIN when there's still data", func() { It("doesn't send a FIN when there's still data", func() {
@ -329,29 +362,32 @@ var _ = Describe("Send Stream", func() {
mockFC.EXPECT().IsNewlyBlocked() mockFC.EXPECT().IsNewlyBlocked()
str.dataForWriting = []byte("foobar") str.dataForWriting = []byte("foobar")
Expect(str.Close()).To(Succeed()) Expect(str.Close()).To(Succeed())
f := str.popStreamFrame(3 + frameHeaderLen) f, _ := str.popStreamFrame(3 + frameHeaderLen)
Expect(f).ToNot(BeNil()) Expect(f).ToNot(BeNil())
Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.Data).To(Equal([]byte("foo")))
Expect(f.FinBit).To(BeFalse()) Expect(f.FinBit).To(BeFalse())
f = str.popStreamFrame(100) f, _ = str.popStreamFrame(100)
Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.Data).To(Equal([]byte("bar")))
Expect(f.FinBit).To(BeTrue()) Expect(f.FinBit).To(BeTrue())
}) })
It("doesn't allow FIN after an error", func() { It("doesn't allow FIN after it is closed for shutdown", func() {
str.closeForShutdown(errors.New("test")) str.closeForShutdown(errors.New("test"))
f := str.popStreamFrame(1000) f, hasMoreData := str.popStreamFrame(1000)
Expect(f).To(BeNil()) Expect(f).To(BeNil())
Expect(hasMoreData).To(BeFalse())
}) })
It("doesn't allow FIN twice", func() { It("doesn't allow FIN twice", func() {
mockSender.EXPECT().onHasStreamData(streamID) mockSender.EXPECT().onHasStreamData(streamID)
str.Close() str.Close()
f := str.popStreamFrame(1000) f, _ := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil()) Expect(f).ToNot(BeNil())
Expect(f.Data).To(BeEmpty()) Expect(f.Data).To(BeEmpty())
Expect(f.FinBit).To(BeTrue()) Expect(f.FinBit).To(BeTrue())
Expect(str.popStreamFrame(1000)).To(BeNil()) f, hasMoreData := str.popStreamFrame(1000)
Expect(f).To(BeNil())
Expect(hasMoreData).To(BeFalse())
}) })
}) })
@ -377,7 +413,12 @@ var _ = Describe("Send Stream", func() {
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done) close(done)
}() }()
Eventually(func() *wire.StreamFrame { return str.popStreamFrame(50) }).ShouldNot(BeNil()) // get a STREAM frame containing some data, but not all Eventually(func() *wire.StreamFrame {
defer GinkgoRecover()
// get a STREAM frame containing some data, but not all
frame, _ := str.popStreamFrame(50)
return frame
}).ShouldNot(BeNil())
str.closeForShutdown(testErr) str.closeForShutdown(testErr)
Expect(str.popStreamFrame(1000)).To(BeNil()) Expect(str.popStreamFrame(1000)).To(BeNil())
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -421,7 +462,7 @@ var _ = Describe("Send Stream", func() {
}() }()
var frame *wire.StreamFrame var frame *wire.StreamFrame
Eventually(func() *wire.StreamFrame { Eventually(func() *wire.StreamFrame {
frame = str.popStreamFrame(50) frame, _ = str.popStreamFrame(50)
return frame return frame
}).ShouldNot(BeNil()) }).ShouldNot(BeNil())
err := str.CancelWrite(1234) err := str.CancelWrite(1234)
@ -519,7 +560,7 @@ var _ = Describe("Send Stream", func() {
It("is finished after Close()", func() { It("is finished after Close()", func() {
mockSender.EXPECT().onHasStreamData(streamID) mockSender.EXPECT().onHasStreamData(streamID)
str.Close() str.Close()
f := str.popStreamFrame(1000) f, _ := str.popStreamFrame(1000)
Expect(f.FinBit).To(BeTrue()) Expect(f.FinBit).To(BeTrue())
Expect(str.finished()).To(BeTrue()) Expect(str.finished()).To(BeTrue())
}) })

View file

@ -961,7 +961,8 @@ func (s *session) onHasWindowUpdate(id protocol.StreamID) {
s.scheduleSending() s.scheduleSending()
} }
func (s *session) onHasStreamData(streamID protocol.StreamID) { func (s *session) onHasStreamData(id protocol.StreamID) {
s.streamFramer.AddActiveStream(id)
s.scheduleSending() s.scheduleSending()
} }

View file

@ -1171,90 +1171,6 @@ var _ = Describe("Session", func() {
Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero()) Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero())
Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x13, 0x37})))) Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x13, 0x37}))))
}) })
Context("bundling of small packets", func() {
It("bundles two small frames of different streams into one packet", func() {
s1, err := sess.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
s2, err := sess.GetOrOpenStream(7)
Expect(err).NotTo(HaveOccurred())
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := s1.Write([]byte("foobar1"))
Expect(err).ToNot(HaveOccurred())
close(done1)
}()
go func() {
defer GinkgoRecover()
s2.Write([]byte("foobar2"))
Expect(err).ToNot(HaveOccurred())
close(done2)
}()
time.Sleep(100 * time.Millisecond) // make sure the both writes are active
sess.scheduleSending()
go sess.run()
defer sess.Close(nil)
Eventually(mconn.written).Should(HaveLen(1))
packet := <-mconn.written
Expect(packet).To(ContainSubstring("foobar1"))
Expect(packet).To(ContainSubstring("foobar2"))
Eventually(done1).Should(BeClosed())
Eventually(done2).Should(BeClosed())
})
It("sends out two big frames in two packets", func() {
s1, err := sess.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
s2, err := sess.GetOrOpenStream(7)
Expect(err).NotTo(HaveOccurred())
go sess.run()
defer sess.Close(nil)
go func() {
defer GinkgoRecover()
s1.Write(bytes.Repeat([]byte{'e'}, 1000))
}()
_, err = s2.Write(bytes.Repeat([]byte{'e'}, 1000))
Expect(err).ToNot(HaveOccurred())
Eventually(mconn.written).Should(HaveLen(2))
})
It("sends out two small frames that are written to long after one another into two packets", func() {
s, err := sess.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
go sess.run()
defer sess.Close(nil)
_, err = s.Write([]byte("foobar1"))
Expect(err).NotTo(HaveOccurred())
Eventually(mconn.written).Should(HaveLen(1))
_, err = s.Write([]byte("foobar2"))
Expect(err).NotTo(HaveOccurred())
Eventually(mconn.written).Should(HaveLen(2))
})
It("sends a queued ACK frame only once", func() {
packetNumber := protocol.PacketNumber(0x1337)
sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
s, err := sess.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
go sess.run()
defer sess.Close(nil)
_, err = s.Write([]byte("foobar1"))
Expect(err).NotTo(HaveOccurred())
Eventually(mconn.written).Should(HaveLen(1))
_, err = s.Write([]byte("foobar2"))
Expect(err).NotTo(HaveOccurred())
Eventually(mconn.written).Should(HaveLen(2))
Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x13, 0x37}))))
Expect(mconn.written).ToNot(Receive(ContainSubstring(string([]byte{0x13, 0x37}))))
})
})
}) })
It("closes when crypto stream errors", func() { It("closes when crypto stream errors", func() {
@ -1577,7 +1493,7 @@ var _ = Describe("Session", func() {
s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(s.Close()).To(Succeed()) Expect(s.Close()).To(Succeed())
f := s.(*stream).popStreamFrame(1000) // trigger "sending" of the FIN bit f, _ := s.(*stream).popStreamFrame(1000) // trigger "sending" of the FIN bit
Expect(f.FinBit).To(BeTrue()) Expect(f.FinBit).To(BeTrue())
s.(*stream).CloseRemote(0) s.(*stream).CloseRemote(0)
_, err = s.Read([]byte("a")) _, err = s.Read([]byte("a"))

View file

@ -27,7 +27,7 @@ type streamI interface {
handleStreamFrame(*wire.StreamFrame) error handleStreamFrame(*wire.StreamFrame) error
handleRstStreamFrame(*wire.RstStreamFrame) error handleRstStreamFrame(*wire.RstStreamFrame) error
handleStopSendingFrame(*wire.StopSendingFrame) handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFrame popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
finished() bool finished() bool
closeForShutdown(error) closeForShutdown(error)
// methods needed for flow control // methods needed for flow control

View file

@ -1,26 +1,33 @@
package quic package quic
import ( import (
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
type streamFramer struct { type streamFramer struct {
streamsMap *streamsMap streamGetter streamGetter
cryptoStream cryptoStreamI cryptoStream cryptoStreamI
version protocol.VersionNumber version protocol.VersionNumber
retransmissionQueue []*wire.StreamFrame retransmissionQueue []*wire.StreamFrame
streamQueueMutex sync.Mutex
activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID
} }
func newStreamFramer( func newStreamFramer(
cryptoStream cryptoStreamI, cryptoStream cryptoStreamI,
streamsMap *streamsMap, streamGetter streamGetter,
v protocol.VersionNumber, v protocol.VersionNumber,
) *streamFramer { ) *streamFramer {
return &streamFramer{ return &streamFramer{
streamsMap: streamsMap, streamGetter: streamGetter,
cryptoStream: cryptoStream, cryptoStream: cryptoStream,
activeStreams: make(map[protocol.StreamID]struct{}),
version: v, version: v,
} }
} }
@ -29,6 +36,18 @@ func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) {
f.retransmissionQueue = append(f.retransmissionQueue, frame) f.retransmissionQueue = append(f.retransmissionQueue, frame)
} }
func (f *streamFramer) AddActiveStream(id protocol.StreamID) {
if id == f.version.CryptoStreamID() { // the crypto stream is handled separately
return
}
f.streamQueueMutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue = append(f.streamQueue, id)
f.activeStreams[id] = struct{}{}
}
f.streamQueueMutex.Unlock()
}
func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame { func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame {
fs, currentLen := f.maybePopFramesForRetransmission(maxLen) fs, currentLen := f.maybePopFramesForRetransmission(maxLen)
return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...) return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...)
@ -38,13 +57,15 @@ func (f *streamFramer) HasFramesForRetransmission() bool {
return len(f.retransmissionQueue) > 0 return len(f.retransmissionQueue) > 0
} }
// TODO: don't need to ask the crypto stream here, just record this information in AddActiveStream
func (f *streamFramer) HasCryptoStreamFrame() bool { func (f *streamFramer) HasCryptoStreamFrame() bool {
return f.cryptoStream.hasDataForWriting() return f.cryptoStream.hasDataForWriting()
} }
// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. // TODO(lclemente): This is somewhat duplicate with the normal path for generating frames.
func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame {
return f.cryptoStream.popStreamFrame(maxLen) frame, _ := f.cryptoStream.popStreamFrame(maxLen)
return frame
} }
func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) {
@ -72,32 +93,37 @@ func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.Byte
return return
} }
func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame) { func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame {
var currentLen protocol.ByteCount var currentLen protocol.ByteCount
var frames []*wire.StreamFrame
fn := func(s streamI) (bool, error) { f.streamQueueMutex.Lock()
if s == nil { // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
return true, nil numActiveStreams := len(f.streamQueue)
for i := 0; i < numActiveStreams; i++ {
if maxTotalLen-currentLen < protocol.MinStreamFrameSize {
break
} }
id := f.streamQueue[0]
maxLen := maxTotalLen - currentLen f.streamQueue = f.streamQueue[1:]
if maxLen < protocol.MinStreamFrameSize { // don't try to add new STREAM frames, if only little space is left in the packet str, err := f.streamGetter.GetOrOpenStream(id)
return false, nil if err != nil { // can happen if the stream completed after it said it had data
delete(f.activeStreams, id)
continue
} }
frame := s.popStreamFrame(maxLen) frame, hasMoreData := str.popStreamFrame(maxTotalLen - currentLen)
if frame == nil { if hasMoreData { // put the stream back in the queue (at the end)
return true, nil f.streamQueue = append(f.streamQueue, id)
} else { // no more data to send. Stream is not active any more
delete(f.activeStreams, id)
} }
res = append(res, frame) if frame == nil { // can happen if the receiveStream was canceled after it said it had data
continue
}
frames = append(frames, frame)
currentLen += frame.MinLength(f.version) + frame.DataLen() currentLen += frame.MinLength(f.version) + frame.DataLen()
if currentLen == maxTotalLen {
return false, nil
} }
return true, nil f.streamQueueMutex.Unlock()
} return frames
f.streamsMap.RoundRobinIterate(fn)
return
} }
// maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified. // maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified.

View file

@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"errors"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -20,15 +21,12 @@ var _ = Describe("Stream Framer", func() {
var ( var (
retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame
framer *streamFramer framer *streamFramer
streamsMap *streamsMap
stream1, stream2 *MockStreamI stream1, stream2 *MockStreamI
streamGetter *MockStreamGetter
) )
setNoData := func(str *MockStreamI) {
str.EXPECT().popStreamFrame(gomock.Any()).AnyTimes()
}
BeforeEach(func() { BeforeEach(func() {
streamGetter = NewMockStreamGetter(mockCtrl)
retransmittedFrame1 = &wire.StreamFrame{ retransmittedFrame1 = &wire.StreamFrame{
StreamID: 5, StreamID: 5,
Data: []byte{0x13, 0x37}, Data: []byte{0x13, 0x37},
@ -43,11 +41,7 @@ var _ = Describe("Stream Framer", func() {
stream2 = NewMockStreamI(mockCtrl) stream2 = NewMockStreamI(mockCtrl)
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) framer = newStreamFramer(nil, streamGetter, versionGQUICFrames)
streamsMap.putStream(stream1)
streamsMap.putStream(stream2)
framer = newStreamFramer(nil, streamsMap, versionGQUICFrames)
}) })
It("says if it has retransmissions", func() { It("says if it has retransmissions", func() {
@ -57,8 +51,6 @@ var _ = Describe("Stream Framer", func() {
}) })
It("sets the DataLenPresent for dequeued retransmitted frames", func() { It("sets the DataLenPresent for dequeued retransmitted frames", func() {
setNoData(stream1)
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(protocol.MaxByteCount) fs := framer.PopStreamFrames(protocol.MaxByteCount)
Expect(fs).To(HaveLen(1)) Expect(fs).To(HaveLen(1))
@ -67,14 +59,10 @@ var _ = Describe("Stream Framer", func() {
Context("Popping", func() { Context("Popping", func() {
It("returns nil when popping an empty framer", func() { It("returns nil when popping an empty framer", func() {
setNoData(stream1)
setNoData(stream2)
Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) Expect(framer.PopStreamFrames(1000)).To(BeEmpty())
}) })
It("pops frames for retransmission", func() { It("pops frames for retransmission", func() {
setNoData(stream1)
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame1)
framer.AddFrameForRetransmission(retransmittedFrame2) framer.AddFrameForRetransmission(retransmittedFrame2)
fs := framer.PopStreamFrames(1000) fs := framer.PopStreamFrames(1000)
@ -93,8 +81,6 @@ var _ = Describe("Stream Framer", func() {
}) })
It("pops frames for retransmission, even if the remaining space in the packet is too small, if the frame doesn't need to be split", func() { It("pops frames for retransmission, even if the remaining space in the packet is too small, if the frame doesn't need to be split", func() {
setNoData(stream1)
setNoData(stream2)
framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1})) Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1}))
@ -107,46 +93,127 @@ var _ = Describe("Stream Framer", func() {
}) })
It("returns normal frames", func() { It("returns normal frames", func() {
setNoData(stream2) streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{ f := &wire.StreamFrame{
StreamID: id1, StreamID: id1,
Data: []byte("foobar"), Data: []byte("foobar"),
Offset: 42, Offset: 42,
} }
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f) stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
framer.AddActiveStream(id1)
fs := framer.PopStreamFrames(1000) fs := framer.PopStreamFrames(1000)
Expect(fs).To(Equal([]*wire.StreamFrame{f})) Expect(fs).To(Equal([]*wire.StreamFrame{f}))
}) })
It("returns multiple normal frames", func() { It("skips a stream that was reported active, but was completed shortly after", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(nil, errors.New("stream was already deleted"))
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
}
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f}))
})
It("skips a stream that was reported active, but doesn't have any data", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
f := &wire.StreamFrame{
StreamID: id2,
Data: []byte("foobar"),
}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(nil, false)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f}))
})
It("pops from a stream multiple times, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil).Times(2)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true)
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
framer.AddActiveStream(id1) // only add it once
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f1}))
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2}))
// no further calls to popStreamFrame, after popStreamFrame said there's no more data
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(BeNil())
})
It("re-queues a stream at the end, if it has enough data", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil).Times(2)
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f11, true)
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f12, false)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
framer.AddActiveStream(id1) // only add it once
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f11})) // first a frame from stream 1
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) // then a frame from stream 2
Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f12})) // then another frame from stream 1
})
It("only dequeues data from each stream once per packet", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")}
f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")}
// both streams have more data, and will be re-queued
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, true)
framer.AddActiveStream(id1)
framer.AddActiveStream(id2)
Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f1, f2}))
})
It("returns multiple normal frames in the order they were reported active", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil)
f1 := &wire.StreamFrame{Data: []byte("foobar")} f1 := &wire.StreamFrame{Data: []byte("foobar")}
f2 := &wire.StreamFrame{Data: []byte("foobaz")} f2 := &wire.StreamFrame{Data: []byte("foobaz")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1) stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2) stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false)
fs := framer.PopStreamFrames(1000) framer.AddActiveStream(id2)
Expect(fs).To(HaveLen(2)) framer.AddActiveStream(id1)
Expect(fs).To(ContainElement(f1)) Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f2, f1}))
Expect(fs).To(ContainElement(f2)) })
It("only asks a stream for data once, even if it was reported active multiple times", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
f := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function
framer.AddActiveStream(id1)
framer.AddActiveStream(id1)
Expect(framer.PopStreamFrames(1000)).To(HaveLen(1))
}) })
It("returns retransmission frames before normal frames", func() { It("returns retransmission frames before normal frames", func() {
setNoData(stream2) streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
framer.AddActiveStream(id1)
f1 := &wire.StreamFrame{Data: []byte("foobar")} f1 := &wire.StreamFrame{Data: []byte("foobar")}
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1) stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false)
framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame1)
fs := framer.PopStreamFrames(1000) fs := framer.PopStreamFrames(1000)
Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, f1})) Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, f1}))
}) })
It("does not pop empty frames", func() { It("does not pop empty frames", func() {
setNoData(stream1)
setNoData(stream2)
fs := framer.PopStreamFrames(500) fs := framer.PopStreamFrames(500)
Expect(fs).To(BeEmpty()) Expect(fs).To(BeEmpty())
}) })
It("pops frames that have the minimum size", func() { It("pops frames that have the minimum size", func() {
stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}) streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}, false)
framer.AddActiveStream(id1)
framer.PopStreamFrames(protocol.MinStreamFrameSize) framer.PopStreamFrames(protocol.MinStreamFrameSize)
}) })
@ -155,35 +222,15 @@ var _ = Describe("Stream Framer", func() {
framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) framer.PopStreamFrames(protocol.MinStreamFrameSize - 1)
}) })
It("uses the round-robin scheduling", func() {
stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&wire.StreamFrame{
StreamID: id1,
Data: []byte("foobar"),
})
stream1.EXPECT().popStreamFrame(gomock.Any()).MaxTimes(1)
stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&wire.StreamFrame{
StreamID: id2,
Data: []byte("foobaz"),
})
stream2.EXPECT().popStreamFrame(gomock.Any()).MaxTimes(1)
fs := framer.PopStreamFrames(protocol.MinStreamFrameSize)
Expect(fs).To(HaveLen(1))
// it doesn't matter here if this data is from stream1 or from stream2...
firstStreamID := fs[0].StreamID
fs = framer.PopStreamFrames(protocol.MinStreamFrameSize)
Expect(fs).To(HaveLen(1))
// ... but the data popped this time has to be from the other stream
Expect(fs[0].StreamID).ToNot(Equal(firstStreamID))
})
It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() { It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() {
streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil)
// pop a frame such that the remaining size is one byte less than the minimum STREAM frame size // pop a frame such that the remaining size is one byte less than the minimum STREAM frame size
f := &wire.StreamFrame{ f := &wire.StreamFrame{
StreamID: id1, StreamID: id1,
Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)), Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)),
} }
stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f) stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f, false)
setNoData(stream2) framer.AddActiveStream(id1)
fs := framer.PopStreamFrames(500) fs := framer.PopStreamFrames(500)
Expect(fs).To(Equal([]*wire.StreamFrame{f})) Expect(fs).To(Equal([]*wire.StreamFrame{f}))
}) })
@ -234,8 +281,6 @@ var _ = Describe("Stream Framer", func() {
}) })
It("only removes a frame from the framer after returning all split parts", func() { It("only removes a frame from the framer after returning all split parts", func() {
setNoData(stream1)
setNoData(stream2)
frameHeaderLen := protocol.ByteCount(4) frameHeaderLen := protocol.ByteCount(4)
frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, int(501-frameHeaderLen))} frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, int(501-frameHeaderLen))}
framer.AddFrameForRetransmission(frame) framer.AddFrameForRetransmission(frame)

View file

@ -140,7 +140,10 @@ var _ = Describe("Stream", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
close(writeReturned) close(writeReturned)
}() }()
Eventually(func() *wire.StreamFrame { return str.popStreamFrame(1000) }).ShouldNot(BeNil()) Eventually(func() *wire.StreamFrame {
frame, _ := str.popStreamFrame(1000)
return frame
}).ShouldNot(BeNil())
Eventually(writeReturned).Should(BeClosed()) Eventually(writeReturned).Should(BeClosed())
mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{
StreamID: streamID, StreamID: streamID,

View file

@ -255,30 +255,6 @@ func (m *streamsMap) DeleteClosedStreams() error {
return nil return nil
} }
// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false
// It uses a round-robin-like scheduling to ensure that every stream is considered fairly
// It prioritizes the the header-stream (StreamID 3)
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
m.mutex.Lock()
defer m.mutex.Unlock()
numStreams := len(m.streams)
startIndex := m.roundRobinIndex
for i := 0; i < numStreams; i++ {
streamID := m.openStreams[(i+startIndex)%numStreams]
cont, err := m.iterateFunc(streamID, fn)
if err != nil {
return err
}
m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams
if !cont {
break
}
}
return nil
}
// Range executes a callback for all streams, in pseudo-random order // Range executes a callback for all streams, in pseudo-random order
func (m *streamsMap) Range(cb func(s streamI)) { func (m *streamsMap) Range(cb func(s streamI)) {
m.mutex.RLock() m.mutex.RLock()

View file

@ -689,112 +689,5 @@ var _ = Describe("Streams Map", func() {
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
}) })
}) })
Context("RoundRobinIterate", func() {
// create 5 streams, ids 4 to 8
var lambdaCalledForStream []protocol.StreamID
var numIterations int
BeforeEach(func() {
lambdaCalledForStream = lambdaCalledForStream[:0]
numIterations = 0
for i := 4; i <= 8; i++ {
err := m.putStream(newStream(protocol.StreamID(i)))
Expect(err).NotTo(HaveOccurred())
}
})
It("executes the lambda exactly once for every stream", func() {
fn := func(str streamI) (bool, error) {
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
numIterations++
return true, nil
}
err := m.RoundRobinIterate(fn)
Expect(err).ToNot(HaveOccurred())
Expect(numIterations).To(Equal(5))
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
Expect(m.roundRobinIndex).To(BeZero())
})
It("goes around once when starting in the middle", func() {
fn := func(str streamI) (bool, error) {
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
numIterations++
return true, nil
}
m.roundRobinIndex = 3 // pointing to stream 7
err := m.RoundRobinIterate(fn)
Expect(err).ToNot(HaveOccurred())
Expect(numIterations).To(Equal(5))
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6}))
Expect(m.roundRobinIndex).To(BeEquivalentTo(3))
})
It("picks up at the index+1 where it last stopped", func() {
fn := func(str streamI) (bool, error) {
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
numIterations++
if str.StreamID() == 5 {
return false, nil
}
return true, nil
}
err := m.RoundRobinIterate(fn)
Expect(err).ToNot(HaveOccurred())
Expect(numIterations).To(Equal(2))
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5}))
Expect(m.roundRobinIndex).To(BeEquivalentTo(2))
numIterations = 0
lambdaCalledForStream = lambdaCalledForStream[:0]
fn2 := func(str streamI) (bool, error) {
lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID())
numIterations++
if str.StreamID() == 7 {
return false, nil
}
return true, nil
}
err = m.RoundRobinIterate(fn2)
Expect(err).ToNot(HaveOccurred())
Expect(numIterations).To(Equal(2))
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{6, 7}))
})
Context("adjusting the RoundRobinIndex when deleting streams", func() {
/*
Index: 0 1 2 3 4
StreamID: [ 4, 5, 6, 7, 8 ]
*/
It("adjusts when deleting an element in front", func() {
m.roundRobinIndex = 3 // stream 7
deleteStream(5)
Expect(m.roundRobinIndex).To(BeEquivalentTo(2))
})
It("doesn't adjust when deleting an element at the back", func() {
m.roundRobinIndex = 1 // stream 5
deleteStream(7)
Expect(m.roundRobinIndex).To(BeEquivalentTo(1))
})
It("doesn't adjust when deleting the element it is pointing to", func() {
m.roundRobinIndex = 3 // stream 7
deleteStream(7)
Expect(m.roundRobinIndex).To(BeEquivalentTo(3))
})
It("adjusts when deleting multiple elements", func() {
m.roundRobinIndex = 3 // stream 7
closeStream(5)
closeStream(6)
closeStream(8)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.roundRobinIndex).To(BeEquivalentTo(1))
})
})
})
}) })
}) })