immediately delete a stream when it is completed

By introducing a callback to the stream, which the stream calls as soon
as it is completed, we can get rid of checking every single open stream
if it is completed.
This commit is contained in:
Marten Seemann 2017-12-25 16:32:29 +07:00
parent 843a0786fc
commit 8a3f807a12
12 changed files with 199 additions and 423 deletions

View file

@ -7,22 +7,16 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Streams Map", func() {
var (
m *streamsMap
finishedStreams map[protocol.StreamID]*gomock.Call
)
var m *streamsMap
newStream := func(id protocol.StreamID) streamI {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(id).AnyTimes()
c := str.EXPECT().finished().Return(false).AnyTimes()
finishedStreams[id] = c
return str
}
@ -30,20 +24,8 @@ var _ = Describe("Streams Map", func() {
m = newStreamsMap(newStream, p, v)
}
BeforeEach(func() {
finishedStreams = make(map[protocol.StreamID]*gomock.Call)
})
AfterEach(func() {
Expect(m.openStreams).To(HaveLen(len(m.streams)))
})
deleteStream := func(id protocol.StreamID) {
str := m.streams[id]
Expect(str).ToNot(BeNil())
finishedStreams[id].Return(true)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed())
}
Context("getting and creating streams", func() {
@ -521,176 +503,63 @@ var _ = Describe("Streams Map", func() {
})
})
Context("DoS mitigation, iterating and deleting", func() {
Context("Ranging", func() {
It("ranges over all open streams", func() {
setNewStreamsMap(protocol.PerspectiveServer, protocol.VersionWhatever)
var callbackCalledForStream []protocol.StreamID
callback := func(str streamI) {
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
sort.Slice(callbackCalledForStream, func(i, j int) bool {
return callbackCalledForStream[i] < callbackCalledForStream[j]
})
}
Expect(m.streams).To(BeEmpty())
// create 5 streams, ids 4 to 8
callbackCalledForStream = callbackCalledForStream[:0]
for i := 4; i <= 8; i++ {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
err := m.putStream(str)
Expect(err).NotTo(HaveOccurred())
}
// execute the callback for all streams
m.Range(callback)
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
})
})
Context("deleting streams", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
})
closeStream := func(id protocol.StreamID) {
str := m.streams[id]
ExpectWithOffset(1, str).ToNot(BeNil())
finishedStreams[id].Return(true)
}
Context("deleting streams", func() {
Context("as a server", func() {
BeforeEach(func() {
m.UpdateMaxStreamLimit(100)
for i := 1; i <= 5; i++ {
if i%2 == 1 {
_, err := m.openRemoteStream(protocol.StreamID(i))
Expect(err).ToNot(HaveOccurred())
} else {
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
}
}
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) // 2 and 4
Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) // 1, 3 and 5
})
It("does not delete streams with Close()", func() {
str, err := m.GetOrOpenStream(55)
Expect(err).ToNot(HaveOccurred())
str.(*MockStreamI).EXPECT().Close()
str.Close()
err = m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(55)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
})
It("removes the first stream", func() {
closeStream(1)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a stream in the middle", func() {
closeStream(3)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a client-initiated stream", func() {
closeStream(2)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 3, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
Expect(m.numIncomingStreams).To(BeEquivalentTo(3))
})
It("removes a stream at the end", func() {
closeStream(5)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes all streams", func() {
for i := 1; i <= 5; i++ {
closeStream(protocol.StreamID(i))
}
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(BeEmpty())
Expect(m.openStreams).To(BeEmpty())
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeZero())
})
})
Context("as a client", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames)
m.UpdateMaxStreamLimit(100)
for i := 1; i <= 5; i++ {
if i%2 == 0 {
_, err := m.openRemoteStream(protocol.StreamID(i))
Expect(err).ToNot(HaveOccurred())
} else {
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
}
}
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 2, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) // 3, 5 and 7
Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) // 2 and 4
})
It("removes a stream that we initiated", func() {
closeStream(3)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a stream that the server initiated", func() {
closeStream(2)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
It("removes all streams", func() {
closeStream(3)
closeStream(2)
closeStream(5)
closeStream(4)
closeStream(7)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(BeEmpty())
Expect(m.openStreams).To(BeEmpty())
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeZero())
})
})
It("deletes an incoming stream", func() {
_, err := m.GetOrOpenStream(5) // open stream 3 and 5
Expect(err).ToNot(HaveOccurred())
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
err = m.DeleteStream(3)
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(1))
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
Context("Ranging", func() {
// create 5 streams, ids 4 to 8
var callbackCalledForStream []protocol.StreamID
callback := func(str streamI) {
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
sort.Slice(callbackCalledForStream, func(i, j int) bool { return callbackCalledForStream[i] < callbackCalledForStream[j] })
}
It("deletes an outgoing stream", func() {
m.UpdateMaxStreamLimit(10000)
_, err := m.OpenStream() // open stream 2
Expect(err).ToNot(HaveOccurred())
_, err = m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
err = m.DeleteStream(2)
Expect(err).ToNot(HaveOccurred())
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
})
BeforeEach(func() {
callbackCalledForStream = callbackCalledForStream[:0]
for i := 4; i <= 8; i++ {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
err := m.putStream(str)
Expect(err).NotTo(HaveOccurred())
}
})
It("ranges over all open streams", func() {
m.Range(callback)
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
})
It("errors when the stream doesn't exist", func() {
err := m.DeleteStream(1337)
Expect(err).To(MatchError(errMapAccess))
})
})
})