improve garbage collection of stream

This removes the Iterate method of the streamsMap, and moves the garbage
collection of closed streams from the session to the streamsMap.
This commit is contained in:
Marten Seemann 2017-10-03 04:04:21 +08:00
parent 474b036474
commit 709eec55b5
6 changed files with 161 additions and 250 deletions

View file

@ -66,7 +66,7 @@ var _ = Describe("Packet packer", func() {
cryptoStream = &stream{} cryptoStream = &stream{}
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, nil) streamsMap := newStreamsMap(nil, nil, protocol.PerspectiveServer, nil)
streamsMap.streams[1] = cryptoStream streamsMap.streams[1] = cryptoStream
streamsMap.openStreams = []protocol.StreamID{1} streamsMap.openStreams = []protocol.StreamID{1}
streamFramer = newStreamFramer(streamsMap, nil) streamFramer = newStreamFramer(streamsMap, nil)

View file

@ -242,7 +242,7 @@ func (s *session) setup(
} }
s.flowControlManager = flowcontrol.NewFlowControlManager(s.connParams, s.rttStats) s.flowControlManager = flowcontrol.NewFlowControlManager(s.connParams, s.rttStats)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connParams) s.streamsMap = newStreamsMap(s.newStream, s.flowControlManager.RemoveStream, s.perspective, s.connParams)
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
s.packer = newPacketPacker(s.connectionID, s.packer = newPacketPacker(s.connectionID,
s.cryptoSetup, s.cryptoSetup,
@ -347,7 +347,10 @@ runLoop:
if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() {
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
} }
s.garbageCollectStreams()
if err := s.streamsMap.DeleteClosedStreams(); err != nil {
s.closeLocal(err)
}
} }
// only send the error the handshakeChan when the handshake is not completed yet // only send the error the handshakeChan when the handshake is not completed yet
@ -794,22 +797,6 @@ func (s *session) newStream(id protocol.StreamID) *stream {
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager) return newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager)
} }
// garbageCollectStreams goes through all streams and removes EOF'ed streams
// from the streams map.
func (s *session) garbageCollectStreams() {
s.streamsMap.Iterate(func(str *stream) (bool, error) {
id := str.StreamID()
if str.finished() {
err := s.streamsMap.RemoveStream(id)
if err != nil {
return false, err
}
s.flowControlManager.RemoveStream(id)
}
return true, nil
})
}
func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {
utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber)
return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0)) return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0))

View file

@ -330,68 +330,7 @@ var _ = Describe("Session", func() {
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
}) })
It("does not delete streams with Close()", func() {
str, err := sess.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
str.Close()
sess.garbageCollectStreams()
str, err = sess.streamsMap.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
})
It("does not delete streams with FIN bit", func() {
sess.handleStreamFrame(&wire.StreamFrame{
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
numOpenStreams := len(sess.streamsMap.openStreams)
str, _ := sess.streamsMap.GetOrOpenStream(5)
Expect(str).ToNot(BeNil())
p := make([]byte, 4)
_, err := str.Read(p)
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
sess.garbageCollectStreams()
Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams))
str, _ = sess.streamsMap.GetOrOpenStream(5)
Expect(str).ToNot(BeNil())
})
It("deletes streams with FIN bit & close", func() {
sess.handleStreamFrame(&wire.StreamFrame{
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
numOpenStreams := len(sess.streamsMap.openStreams)
str, _ := sess.streamsMap.GetOrOpenStream(5)
Expect(str).ToNot(BeNil())
p := make([]byte, 4)
_, err := str.Read(p)
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
sess.garbageCollectStreams()
Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams))
str, _ = sess.streamsMap.GetOrOpenStream(5)
Expect(str).ToNot(BeNil())
// We still need to close the stream locally
str.Close()
// ... and simulate that we actually the FIN
str.sentFin()
sess.garbageCollectStreams()
Expect(len(sess.streamsMap.openStreams)).To(BeNumerically("<", numOpenStreams))
str, err = sess.streamsMap.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(str).To(BeNil())
// flow controller should have been notified
_, err = sess.flowControlManager.SendWindowSize(5)
Expect(err).To(MatchError("Error accessing the flowController map."))
})
It("cancels streams with error", func() { It("cancels streams with error", func() {
sess.garbageCollectStreams()
testErr := errors.New("test") testErr := errors.New("test")
sess.handleStreamFrame(&wire.StreamFrame{ sess.handleStreamFrame(&wire.StreamFrame{
StreamID: 5, StreamID: 5,
@ -406,10 +345,6 @@ var _ = Describe("Session", func() {
sess.handleCloseError(closeError{err: testErr, remote: true}) sess.handleCloseError(closeError{err: testErr, remote: true})
_, err = str.Read(p) _, err = str.Read(p)
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error()))) Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error())))
sess.garbageCollectStreams()
str, err = sess.streamsMap.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(str).To(BeNil())
}) })
It("cancels empty streams with error", func() { It("cancels empty streams with error", func() {
@ -421,10 +356,6 @@ var _ = Describe("Session", func() {
sess.handleCloseError(closeError{err: testErr, remote: true}) sess.handleCloseError(closeError{err: testErr, remote: true})
_, err = str.Read([]byte{0}) _, err = str.Read([]byte{0})
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error()))) Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error())))
sess.garbageCollectStreams()
str, err = sess.streamsMap.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(str).To(BeNil())
}) })
It("informs the FlowControlManager about new streams", func() { It("informs the FlowControlManager about new streams", func() {
@ -447,7 +378,8 @@ var _ = Describe("Session", func() {
Expect(err).To(MatchError(io.EOF)) Expect(err).To(MatchError(io.EOF))
str.Close() str.Close()
str.sentFin() str.sentFin()
sess.garbageCollectStreams() err = sess.streamsMap.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
str, _ = sess.streamsMap.GetOrOpenStream(5) str, _ = sess.streamsMap.GetOrOpenStream(5)
Expect(str).To(BeNil()) // make sure the stream is gone Expect(str).To(BeNil()) // make sure the stream is gone
err = sess.handleStreamFrame(&wire.StreamFrame{ err = sess.handleStreamFrame(&wire.StreamFrame{
@ -472,7 +404,8 @@ var _ = Describe("Session", func() {
Expect(err).To(MatchError(io.EOF)) Expect(err).To(MatchError(io.EOF))
str.Close() str.Close()
str.sentFin() str.sentFin()
sess.garbageCollectStreams() err = sess.streamsMap.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
str, _ = sess.streamsMap.GetOrOpenStream(2) str, _ = sess.streamsMap.GetOrOpenStream(2)
Expect(str).To(BeNil()) // make sure the stream is gone Expect(str).To(BeNil()) // make sure the stream is gone
err = sess.handleStreamFrame(&wire.StreamFrame{ err = sess.handleStreamFrame(&wire.StreamFrame{
@ -638,23 +571,17 @@ var _ = Describe("Session", func() {
Expect(str).ToNot(BeNil()) Expect(str).ToNot(BeNil())
}) })
It("errors when receiving a WindowUpdateFrame for a closed stream", func() { It("ignores WINDOW_UPDATEs for a closed stream", func() {
sess.handleStreamFrame(&wire.StreamFrame{StreamID: 5}) str, err := sess.GetOrOpenStream(5)
err := sess.streamsMap.RemoveStream(5)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sess.garbageCollectStreams() str.Close()
err = sess.handleWindowUpdateFrame(&wire.WindowUpdateFrame{ str.(*stream).Cancel(nil)
StreamID: 5, Expect(str.(*stream).finished()).To(BeTrue())
ByteOffset: 1337, err = sess.streamsMap.DeleteClosedStreams()
})
Expect(err).To(MatchError(errWindowUpdateOnClosedStream))
})
It("ignores errors when receiving a WindowUpdateFrame for a closed stream", func() {
sess.handleStreamFrame(&wire.StreamFrame{StreamID: 5})
err := sess.streamsMap.RemoveStream(5)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sess.garbageCollectStreams() str, err = sess.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
err = sess.handleFrames([]wire.Frame{&wire.WindowUpdateFrame{ err = sess.handleFrames([]wire.Frame{&wire.WindowUpdateFrame{
StreamID: 5, StreamID: 5,
ByteOffset: 1337, ByteOffset: 1337,
@ -1219,7 +1146,8 @@ var _ = Describe("Session", func() {
str.(*stream).sentFin() str.(*stream).sentFin()
str.Close() str.Close()
str.(*stream).RegisterRemoteError(nil) str.(*stream).RegisterRemoteError(nil)
sess.garbageCollectStreams() err = sess.streamsMap.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
_, err = sess.flowControlManager.SendWindowSize(5) _, err = sess.flowControlManager.SendWindowSize(5)
Expect(err).To(MatchError("Error accessing the flowController map.")) Expect(err).To(MatchError("Error accessing the flowController map."))
sph.retransmissionQueue = []*ackhandler.Packet{{ sph.retransmissionQueue = []*ackhandler.Packet{{
@ -1605,13 +1533,16 @@ var _ = Describe("Session", func() {
}) })
It("returns a nil-value (not an interface with value nil) for closed streams", func() { It("returns a nil-value (not an interface with value nil) for closed streams", func() {
_, err := sess.GetOrOpenStream(9)
Expect(err).ToNot(HaveOccurred())
sess.streamsMap.RemoveStream(9)
sess.garbageCollectStreams()
Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil())
str, err := sess.GetOrOpenStream(9) str, err := sess.GetOrOpenStream(9)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str.Close()
str.(*stream).Cancel(nil)
Expect(str.(*stream).finished()).To(BeTrue())
err = sess.streamsMap.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil())
str, err = sess.GetOrOpenStream(9)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil()) Expect(str).To(BeNil())
// make sure that the returned value is a plain nil, not an Stream with value nil // make sure that the returned value is a plain nil, not an Stream with value nil
_, ok := str.(Stream) _, ok := str.(Stream)
@ -1650,7 +1581,7 @@ var _ = Describe("Session", func() {
s.(*stream).CloseRemote(0) s.(*stream).CloseRemote(0)
_, err = s.Read([]byte("a")) _, err = s.Read([]byte("a"))
Expect(err).To(MatchError(io.EOF)) Expect(err).To(MatchError(io.EOF))
sess.garbageCollectStreams() sess.streamsMap.DeleteClosedStreams()
} }
}) })
}) })

View file

@ -37,7 +37,7 @@ var _ = Describe("Stream Framer", func() {
stream1 = &stream{streamID: id1} stream1 = &stream{streamID: id1}
stream2 = &stream{streamID: id2} stream2 = &stream{streamID: id2}
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, nil) streamsMap = newStreamsMap(nil, nil, protocol.PerspectiveServer, nil)
streamsMap.putStream(stream1) streamsMap.putStream(stream1)
streamsMap.putStream(stream2) streamsMap.putStream(stream2)

View file

@ -29,26 +29,27 @@ type streamsMap struct {
closeErr error closeErr error
nextStreamToAccept protocol.StreamID nextStreamToAccept protocol.StreamID
newStream newStreamLambda newStream newStreamLambda
removeStreamCallback removeStreamCallback
numOutgoingStreams uint32 numOutgoingStreams uint32
numIncomingStreams uint32 numIncomingStreams uint32
} }
type streamLambda func(*stream) (bool, error) type streamLambda func(*stream) (bool, error)
type removeStreamCallback func(protocol.StreamID)
type newStreamLambda func(protocol.StreamID) *stream type newStreamLambda func(protocol.StreamID) *stream
var ( var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
errMapAccess = errors.New("streamsMap: Error accessing the streams map")
)
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connParams handshake.ParamsNegotiator) *streamsMap { func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective, connParams handshake.ParamsNegotiator) *streamsMap {
sm := streamsMap{ sm := streamsMap{
perspective: pers, perspective: pers,
streams: map[protocol.StreamID]*stream{}, streams: make(map[protocol.StreamID]*stream),
openStreams: make([]protocol.StreamID, 0), openStreams: make([]protocol.StreamID, 0),
newStream: newStream, newStream: newStream,
connParams: connParams, removeStreamCallback: removeStreamCallback,
connParams: connParams,
} }
sm.nextStreamOrErrCond.L = &sm.mutex sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex
@ -216,21 +217,50 @@ func (m *streamsMap) AcceptStream() (*stream, error) {
return str, nil return str, nil
} }
func (m *streamsMap) Iterate(fn streamLambda) error { func (m *streamsMap) DeleteClosedStreams() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
openStreams := append([]protocol.StreamID{}, m.openStreams...) var numDeletedStreams int
// for every closed stream, the streamID is replaced by 0 in the openStreams slice
for _, streamID := range openStreams { for i, streamID := range m.openStreams {
cont, err := m.iterateFunc(streamID, fn) str, ok := m.streams[streamID]
if err != nil { if !ok {
return err return errMapAccess
} }
if !cont { if !str.finished() {
break continue
}
m.removeStreamCallback(streamID)
numDeletedStreams++
m.openStreams[i] = 0
if streamID%2 == 0 {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
delete(m.streams, streamID)
}
if numDeletedStreams == 0 {
return nil
}
// remove all 0s (representing closed streams) from the openStreams slice
// and adjust the roundRobinIndex
var j int
for i, id := range m.openStreams {
if i != j {
m.openStreams[j] = m.openStreams[i]
}
if id != 0 {
j++
} else if uint32(j) < m.roundRobinIndex {
m.roundRobinIndex--
} }
} }
m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams]
m.openStreamOrErrCond.Signal()
return nil return nil
} }
@ -291,36 +321,6 @@ func (m *streamsMap) putStream(s *stream) error {
return nil return nil
} }
// Attention: this function must only be called if a mutex has been acquired previously
func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
s, ok := m.streams[id]
if !ok || s == nil {
return fmt.Errorf("attempted to remove non-existing stream: %d", id)
}
if id%2 == 0 {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
for i, s := range m.openStreams {
if s == id {
// delete the streamID from the openStreams slice
m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])]
// adjust round-robin index, if necessary
if uint32(i) < m.roundRobinIndex {
m.roundRobinIndex--
}
break
}
}
delete(m.streams, id)
m.openStreamOrErrCond.Signal()
return nil
}
func (m *streamsMap) CloseWithError(err error) { func (m *streamsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View file

@ -27,16 +27,26 @@ var _ = Describe("Streams Map", func() {
mockPn.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams)) mockPn.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams))
mockPn.EXPECT().GetMaxIncomingStreams().AnyTimes().Return(uint32(maxIncomingStreams)) mockPn.EXPECT().GetMaxIncomingStreams().AnyTimes().Return(uint32(maxIncomingStreams))
m = newStreamsMap(nil, p, mockPn) newStream := func(id protocol.StreamID) *stream {
m.newStream = func(id protocol.StreamID) *stream { return newStream(id, func() {}, nil, nil)
return newStream(id, nil, nil, nil)
} }
removeStreamCallback := func(protocol.StreamID) {}
m = newStreamsMap(newStream, removeStreamCallback, p, mockPn)
} }
AfterEach(func() { AfterEach(func() {
Expect(m.openStreams).To(HaveLen(len(m.streams))) Expect(m.openStreams).To(HaveLen(len(m.streams)))
}) })
deleteStream := func(id protocol.StreamID) {
str := m.streams[id]
Expect(str).ToNot(BeNil())
str.cancelled.Set(true)
Expect(str.finished()).To(BeTrue())
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
}
Context("getting and creating streams", func() { Context("getting and creating streams", func() {
Context("as a server", func() { Context("as a server", func() {
BeforeEach(func() { BeforeEach(func() {
@ -77,8 +87,7 @@ var _ = Describe("Streams Map", func() {
It("returns nil for closed streams", func() { It("returns nil for closed streams", func() {
_, err := m.GetOrOpenStream(5) _, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = m.RemoveStream(5) deleteStream(5)
Expect(err).NotTo(HaveOccurred())
s, err := m.GetOrOpenStream(5) s, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(s).To(BeNil()) Expect(s).To(BeNil())
@ -95,7 +104,7 @@ var _ = Describe("Streams Map", func() {
It("doesn't reopen an already closed stream", func() { It("doesn't reopen an already closed stream", func() {
_, err := m.GetOrOpenStream(5) _, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = m.RemoveStream(5) deleteStream(5)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := m.GetOrOpenStream(5) str, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -119,9 +128,9 @@ var _ = Describe("Streams Map", func() {
It("does not error when many streams are opened and closed", func() { It("does not error when many streams are opened and closed", func() {
for i := 2; i < 10*maxIncomingStreams; i++ { for i := 2; i < 10*maxIncomingStreams; i++ {
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) str, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
m.RemoveStream(protocol.StreamID(i*2 + 1)) deleteStream(str.StreamID())
} }
}) })
}) })
@ -148,7 +157,7 @@ var _ = Describe("Streams Map", func() {
str, err := m.OpenStream() str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) Expect(str.StreamID()).To(Equal(protocol.StreamID(2)))
err = m.RemoveStream(2) deleteStream(2)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(2) str, err = m.GetOrOpenStream(2)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -169,7 +178,7 @@ var _ = Describe("Streams Map", func() {
for i := 2; i < 10*maxOutgoingStreams; i++ { for i := 2; i < 10*maxOutgoingStreams; i++ {
str, err := m.OpenStream() str, err := m.OpenStream()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
m.RemoveStream(str.StreamID()) deleteStream(str.StreamID())
} }
}) })
@ -208,8 +217,7 @@ var _ = Describe("Streams Map", func() {
}() }()
Consistently(func() bool { return returned }).Should(BeFalse()) Consistently(func() bool { return returned }).Should(BeFalse())
err := m.RemoveStream(6) deleteStream(6)
Expect(err).ToNot(HaveOccurred())
Eventually(func() bool { return returned }).Should(BeTrue()) Eventually(func() bool { return returned }).Should(BeTrue())
Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxOutgoingStreams + 2))) Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxOutgoingStreams + 2)))
}) })
@ -404,7 +412,7 @@ var _ = Describe("Streams Map", func() {
str, err := m.OpenStream() str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
err = m.RemoveStream(1) deleteStream(1)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(1) str, err = m.GetOrOpenStream(1)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -433,7 +441,7 @@ var _ = Describe("Streams Map", func() {
It("doesn't reopen an already closed stream", func() { It("doesn't reopen an already closed stream", func() {
_, err := m.GetOrOpenStream(4) _, err := m.GetOrOpenStream(4)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = m.RemoveStream(4) deleteStream(4)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := m.GetOrOpenStream(4) str, err := m.GetOrOpenStream(4)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -464,6 +472,14 @@ var _ = Describe("Streams Map", func() {
setNewStreamsMap(protocol.PerspectiveServer) setNewStreamsMap(protocol.PerspectiveServer)
}) })
closeStream := func(id protocol.StreamID) {
str := m.streams[id]
Expect(str).ToNot(BeNil())
Expect(str.finished()).To(BeFalse())
str.cancelled.Set(true)
Expect(str.finished()).To(BeTrue())
}
Context("deleting streams", func() { Context("deleting streams", func() {
BeforeEach(func() { BeforeEach(func() {
for i := 1; i <= 5; i++ { for i := 1; i <= 5; i++ {
@ -473,27 +489,36 @@ var _ = Describe("Streams Map", func() {
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5})) Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
}) })
It("errors when removing non-existing stream", func() { It("does not delete streams with Close()", func() {
err := m.RemoveStream(1337) str, err := m.GetOrOpenStream(55)
Expect(err).To(MatchError("attempted to remove non-existing stream: 1337")) Expect(err).ToNot(HaveOccurred())
str.Close()
err = m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(55)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
}) })
It("removes the first stream", func() { It("removes the first stream", func() {
err := m.RemoveStream(1) closeStream(1)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4)) Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5})) Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5}))
}) })
It("removes a stream in the middle", func() { It("removes a stream in the middle", func() {
err := m.RemoveStream(3) closeStream(3)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4)) Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5})) Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5}))
}) })
It("removes a stream at the end", func() { It("removes a stream at the end", func() {
err := m.RemoveStream(5) closeStream(5)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4)) Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4})) Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4}))
@ -501,64 +526,15 @@ var _ = Describe("Streams Map", func() {
It("removes all streams", func() { It("removes all streams", func() {
for i := 1; i <= 5; i++ { for i := 1; i <= 5; i++ {
err := m.RemoveStream(protocol.StreamID(i)) closeStream(protocol.StreamID(i))
Expect(err).ToNot(HaveOccurred())
} }
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(BeEmpty())
Expect(m.openStreams).To(BeEmpty()) Expect(m.openStreams).To(BeEmpty())
}) })
}) })
Context("Iterate", func() {
// create 3 streams, ids 1 to 3
BeforeEach(func() {
for i := 1; i <= 3; i++ {
err := m.putStream(&stream{streamID: protocol.StreamID(i)})
Expect(err).NotTo(HaveOccurred())
}
})
It("executes the lambda exactly once for every stream", func() {
var numIterations int
callbackCalled := make(map[protocol.StreamID]bool)
fn := func(str *stream) (bool, error) {
callbackCalled[str.StreamID()] = true
numIterations++
return true, nil
}
err := m.Iterate(fn)
Expect(err).ToNot(HaveOccurred())
Expect(callbackCalled).To(HaveKey(protocol.StreamID(1)))
Expect(callbackCalled).To(HaveKey(protocol.StreamID(2)))
Expect(callbackCalled).To(HaveKey(protocol.StreamID(3)))
Expect(numIterations).To(Equal(3))
})
It("stops iterating when the callback returns false", func() {
var numIterations int
fn := func(str *stream) (bool, error) {
numIterations++
return false, nil
}
err := m.Iterate(fn)
Expect(err).ToNot(HaveOccurred())
// due to map access randomization, we don't know for which stream the callback was executed
// but it must only be executed once
Expect(numIterations).To(Equal(1))
})
It("returns the error, if the lambda returns one", func() {
var numIterations int
expectedError := errors.New("test")
fn := func(str *stream) (bool, error) {
numIterations++
return true, expectedError
}
err := m.Iterate(fn)
Expect(err).To(MatchError(expectedError))
Expect(numIterations).To(Equal(1))
})
})
Context("RoundRobinIterate", func() { Context("RoundRobinIterate", func() {
// create 5 streams, ids 4 to 8 // create 5 streams, ids 4 to 8
var lambdaCalledForStream []protocol.StreamID var lambdaCalledForStream []protocol.StreamID
@ -630,22 +606,39 @@ var _ = Describe("Streams Map", func() {
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{6, 7})) Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{6, 7}))
}) })
It("adjust the RoundRobinIndex when deleting an element in front", func() { Context("adjusting the RoundRobinIndex when deleting streams", func() {
m.roundRobinIndex = 3 // stream 7 /*
m.RemoveStream(5) Index: 0 1 2 3 4
Expect(m.roundRobinIndex).To(Equal(uint32(2))) StreamID: [ 4, 5, 6, 7, 8 ]
}) */
It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() { It("adjusts when deleting an element in front", func() {
m.roundRobinIndex = 1 // stream 5 m.roundRobinIndex = 3 // stream 7
m.RemoveStream(7) deleteStream(5)
Expect(m.roundRobinIndex).To(BeEquivalentTo(1)) Expect(m.roundRobinIndex).To(Equal(uint32(2)))
}) })
It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() { It("doesn't adjust when deleting an element at the back", func() {
m.roundRobinIndex = 3 // stream 7 m.roundRobinIndex = 1 // stream 5
m.RemoveStream(7) deleteStream(7)
Expect(m.roundRobinIndex).To(Equal(uint32(3))) Expect(m.roundRobinIndex).To(BeEquivalentTo(1))
})
It("doesn't adjust when deleting the element it is pointing to", func() {
m.roundRobinIndex = 3 // stream 7
deleteStream(7)
Expect(m.roundRobinIndex).To(Equal(uint32(3)))
})
It("adjusts when deleting multiple elements", func() {
m.roundRobinIndex = 3 // stream 7
closeStream(5)
closeStream(6)
closeStream(8)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.roundRobinIndex).To(BeEquivalentTo(1))
})
}) })
Context("Prioritizing crypto- and header streams", func() { Context("Prioritizing crypto- and header streams", func() {