mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
parent
10f8156951
commit
3b4feedb2c
2 changed files with 158 additions and 61 deletions
|
@ -21,11 +21,14 @@ type streamsMap struct {
|
||||||
highestStreamOpenedByClient protocol.StreamID
|
highestStreamOpenedByClient protocol.StreamID
|
||||||
streamsOpenedAfterLastGarbageCollect int
|
streamsOpenedAfterLastGarbageCollect int
|
||||||
|
|
||||||
newStream newStreamLambda
|
newStream newStreamLambda
|
||||||
maxOpenOutgoingStreams uint32
|
|
||||||
maxIncomingStreams uint32
|
|
||||||
|
|
||||||
roundRobinIndex int
|
maxOutgoingStreams uint32
|
||||||
|
numOutgoingStreams uint32
|
||||||
|
maxIncomingStreams uint32
|
||||||
|
numIncomingStreams uint32
|
||||||
|
|
||||||
|
roundRobinIndex uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamLambda func(*stream) (bool, error)
|
type streamLambda func(*stream) (bool, error)
|
||||||
|
@ -62,7 +65,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
if ok {
|
if ok {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
if uint32(len(m.openStreams)) == m.connectionParameters.GetMaxIncomingStreams() {
|
if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() {
|
||||||
return nil, qerr.TooManyOpenStreams
|
return nil, qerr.TooManyOpenStreams
|
||||||
}
|
}
|
||||||
if id%2 == 0 {
|
if id%2 == 0 {
|
||||||
|
@ -76,11 +79,13 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
m.numIncomingStreams++
|
||||||
|
|
||||||
if id > m.highestStreamOpenedByClient {
|
if id > m.highestStreamOpenedByClient {
|
||||||
m.highestStreamOpenedByClient = id
|
m.highestStreamOpenedByClient = id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maybe trigger garbage collection of streams map
|
||||||
m.streamsOpenedAfterLastGarbageCollect++
|
m.streamsOpenedAfterLastGarbageCollect++
|
||||||
if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 {
|
if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 {
|
||||||
m.garbageCollectClosedStreams()
|
m.garbageCollectClosedStreams()
|
||||||
|
@ -92,7 +97,28 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
|
|
||||||
// OpenStream opens a stream from the server's side
|
// OpenStream opens a stream from the server's side
|
||||||
func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
|
func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
panic("OpenStream: not implemented")
|
if id%2 == 1 {
|
||||||
|
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
_, ok := m.streams[id]
|
||||||
|
if ok {
|
||||||
|
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is already open", id))
|
||||||
|
}
|
||||||
|
if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() {
|
||||||
|
return nil, qerr.TooManyOpenStreams
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := m.newStream(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.numOutgoingStreams++
|
||||||
|
|
||||||
|
m.putStream(s)
|
||||||
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) Iterate(fn streamLambda) error {
|
func (m *streamsMap) Iterate(fn streamLambda) error {
|
||||||
|
@ -118,7 +144,7 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
numStreams := len(m.openStreams)
|
numStreams := uint32(len(m.openStreams))
|
||||||
startIndex := m.roundRobinIndex
|
startIndex := m.roundRobinIndex
|
||||||
|
|
||||||
for _, i := range []protocol.StreamID{1, 3} {
|
for _, i := range []protocol.StreamID{1, 3} {
|
||||||
|
@ -131,7 +157,7 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numStreams; i++ {
|
for i := uint32(0); i < numStreams; i++ {
|
||||||
streamID := m.openStreams[(i+startIndex)%numStreams]
|
streamID := m.openStreams[(i+startIndex)%numStreams]
|
||||||
|
|
||||||
if streamID == 1 || streamID == 3 {
|
if streamID == 1 || streamID == 3 {
|
||||||
|
@ -181,13 +207,18 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
m.streams[id] = nil
|
m.streams[id] = nil
|
||||||
|
if id%2 == 0 {
|
||||||
|
m.numOutgoingStreams--
|
||||||
|
} else {
|
||||||
|
m.numIncomingStreams--
|
||||||
|
}
|
||||||
|
|
||||||
for i, s := range m.openStreams {
|
for i, s := range m.openStreams {
|
||||||
if s == id {
|
if s == id {
|
||||||
// delete the streamID from the openStreams slice
|
// delete the streamID from the openStreams slice
|
||||||
m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])]
|
m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])]
|
||||||
// adjust round-robin index, if necessary
|
// adjust round-robin index, if necessary
|
||||||
if i < m.roundRobinIndex {
|
if uint32(i) < m.roundRobinIndex {
|
||||||
m.roundRobinIndex--
|
m.roundRobinIndex--
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
@ -204,7 +235,10 @@ func (m *streamsMap) garbageCollectClosedStreams() {
|
||||||
if str != nil {
|
if str != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient {
|
|
||||||
|
// server-side streams can be gargage collected immediately
|
||||||
|
// client-side streams need to be kept as nils in the streams map for a bit longer, in order to prevent a client from reopening closed streams
|
||||||
|
if id%2 == 0 || id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient {
|
||||||
delete(m.streams, id)
|
delete(m.streams, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
type mockConnectionParametersManager struct {
|
type mockConnectionParametersManager struct {
|
||||||
maxIncomingStreams uint32
|
maxIncomingStreams uint32
|
||||||
|
maxOutgoingStreams uint32
|
||||||
idleTime time.Duration
|
idleTime time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,7 +36,7 @@ func (m *mockConnectionParametersManager) GetReceiveStreamFlowControlWindow() pr
|
||||||
func (m *mockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
func (m *mockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
||||||
return math.MaxUint64
|
return math.MaxUint64
|
||||||
}
|
}
|
||||||
func (m *mockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { panic("not implemented") }
|
func (m *mockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { return m.maxOutgoingStreams }
|
||||||
func (m *mockConnectionParametersManager) GetMaxIncomingStreams() uint32 { return m.maxIncomingStreams }
|
func (m *mockConnectionParametersManager) GetMaxIncomingStreams() uint32 { return m.maxIncomingStreams }
|
||||||
func (m *mockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
|
func (m *mockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
|
||||||
return m.idleTime
|
return m.idleTime
|
||||||
|
@ -53,6 +54,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
cpm = &mockConnectionParametersManager{
|
cpm = &mockConnectionParametersManager{
|
||||||
maxIncomingStreams: 75,
|
maxIncomingStreams: 75,
|
||||||
|
maxOutgoingStreams: 60,
|
||||||
}
|
}
|
||||||
m = newStreamsMap(nil, cpm)
|
m = newStreamsMap(nil, cpm)
|
||||||
})
|
})
|
||||||
|
@ -68,57 +70,118 @@ var _ = Describe("Streams Map", func() {
|
||||||
s, err := m.GetOrOpenStream(5)
|
s, err := m.GetOrOpenStream(5)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||||
|
Expect(m.numIncomingStreams).To(Equal(uint32(1)))
|
||||||
|
Expect(m.numOutgoingStreams).To(BeZero())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects streams with even IDs", func() {
|
Context("client-side streams", func() {
|
||||||
_, err := m.GetOrOpenStream(6)
|
It("rejects streams with even IDs", func() {
|
||||||
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
|
_, err := m.GetOrOpenStream(6)
|
||||||
})
|
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
|
||||||
|
|
||||||
It("gets existing streams", func() {
|
|
||||||
s, err := m.GetOrOpenStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
s, err = m.GetOrOpenStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns nil for closed streams", func() {
|
|
||||||
s, err := m.GetOrOpenStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
err = m.RemoveStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
s, err = m.GetOrOpenStream(5)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(s).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("panics on OpenStream", func() {
|
|
||||||
Expect(func() { m.OpenStream(0) }).To(Panic())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("counting streams", func() {
|
|
||||||
var maxNumStreams int
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
maxNumStreams = int(cpm.GetMaxIncomingStreams())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors when too many streams are opened", func() {
|
It("gets existing streams", func() {
|
||||||
for i := 0; i < maxNumStreams; i++ {
|
s, err := m.GetOrOpenStream(5)
|
||||||
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(err).NotTo(HaveOccurred())
|
s, err = m.GetOrOpenStream(5)
|
||||||
}
|
Expect(err).NotTo(HaveOccurred())
|
||||||
_, err := m.GetOrOpenStream(protocol.StreamID(2*maxNumStreams + 2))
|
Expect(s.StreamID()).To(Equal(protocol.StreamID(5)))
|
||||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
Expect(m.numIncomingStreams).To(Equal(uint32(1)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("does not error when many streams are opened and closed", func() {
|
It("returns nil for closed streams", func() {
|
||||||
for i := 2; i < 10*maxNumStreams; i++ {
|
s, err := m.GetOrOpenStream(5)
|
||||||
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(err).NotTo(HaveOccurred())
|
err = m.RemoveStream(5)
|
||||||
m.RemoveStream(protocol.StreamID(i*2 + 1))
|
Expect(err).NotTo(HaveOccurred())
|
||||||
}
|
s, err = m.GetOrOpenStream(5)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(s).To(BeNil())
|
||||||
|
Expect(m.numIncomingStreams).To(BeZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("counting streams", func() {
|
||||||
|
var maxNumStreams int
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
maxNumStreams = int(cpm.GetMaxIncomingStreams())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when too many streams are opened", func() {
|
||||||
|
for i := 0; i < maxNumStreams; i++ {
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(2*maxNumStreams + 2))
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does not error when many streams are opened and closed", func() {
|
||||||
|
for i := 2; i < 10*maxNumStreams; i++ {
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
m.RemoveStream(protocol.StreamID(i*2 + 1))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("server-side streams", func() {
|
||||||
|
It("rejects streams with odd IDs", func() {
|
||||||
|
_, err := m.OpenStream(5)
|
||||||
|
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("opens a new stream", func() {
|
||||||
|
s, err := m.OpenStream(6)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(s).ToNot(BeNil())
|
||||||
|
Expect(s.StreamID()).To(Equal(protocol.StreamID(6)))
|
||||||
|
Expect(m.numIncomingStreams).To(BeZero())
|
||||||
|
Expect(m.numOutgoingStreams).To(Equal(uint32(1)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns an error for already openend streams", func() {
|
||||||
|
_, err := m.OpenStream(4)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = m.OpenStream(4)
|
||||||
|
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 4, which is already open"))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("counting streams", func() {
|
||||||
|
var maxNumStreams int
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
maxNumStreams = int(cpm.GetMaxOutgoingStreams())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when too many streams are opened", func() {
|
||||||
|
for i := 1; i <= maxNumStreams; i++ {
|
||||||
|
_, err := m.OpenStream(protocol.StreamID(2 * i))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
}
|
||||||
|
_, err := m.OpenStream(protocol.StreamID(2*maxNumStreams + 10))
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does not error when many streams are opened and closed", func() {
|
||||||
|
for i := 2; i < 10*maxNumStreams; i++ {
|
||||||
|
_, err := m.OpenStream(protocol.StreamID(2*i + 2))
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
m.RemoveStream(protocol.StreamID(2 * i))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("allows many server- and client-side streams at the same time", func() {
|
||||||
|
for i := 1; i < int(cpm.GetMaxOutgoingStreams()); i++ {
|
||||||
|
_, err := m.OpenStream(protocol.StreamID(2 * i))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
for i := 0; i < int(cpm.GetMaxIncomingStreams()); i++ {
|
||||||
|
_, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -343,7 +406,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(numIterations).To(Equal(5))
|
Expect(numIterations).To(Equal(5))
|
||||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6}))
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6}))
|
||||||
Expect(m.roundRobinIndex).To(Equal(3))
|
Expect(m.roundRobinIndex).To(Equal(uint32(3)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("picks up at the index+1 where it last stopped", func() {
|
It("picks up at the index+1 where it last stopped", func() {
|
||||||
|
@ -359,7 +422,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(numIterations).To(Equal(2))
|
Expect(numIterations).To(Equal(2))
|
||||||
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5}))
|
Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5}))
|
||||||
Expect(m.roundRobinIndex).To(Equal(2))
|
Expect(m.roundRobinIndex).To(Equal(uint32(2)))
|
||||||
numIterations = 0
|
numIterations = 0
|
||||||
lambdaCalledForStream = lambdaCalledForStream[:0]
|
lambdaCalledForStream = lambdaCalledForStream[:0]
|
||||||
fn2 := func(str *stream) (bool, error) {
|
fn2 := func(str *stream) (bool, error) {
|
||||||
|
@ -379,19 +442,19 @@ var _ = Describe("Streams Map", func() {
|
||||||
It("adjust the RoundRobinIndex when deleting an element in front", func() {
|
It("adjust the RoundRobinIndex when deleting an element in front", func() {
|
||||||
m.roundRobinIndex = 3 // stream 7
|
m.roundRobinIndex = 3 // stream 7
|
||||||
m.RemoveStream(5)
|
m.RemoveStream(5)
|
||||||
Expect(m.roundRobinIndex).To(Equal(2))
|
Expect(m.roundRobinIndex).To(Equal(uint32(2)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() {
|
It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() {
|
||||||
m.roundRobinIndex = 1 // stream 5
|
m.roundRobinIndex = 1 // stream 5
|
||||||
m.RemoveStream(7)
|
m.RemoveStream(7)
|
||||||
Expect(m.roundRobinIndex).To(Equal(1))
|
Expect(m.roundRobinIndex).To(Equal(uint32(1)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() {
|
It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() {
|
||||||
m.roundRobinIndex = 3 // stream 7
|
m.roundRobinIndex = 3 // stream 7
|
||||||
m.RemoveStream(7)
|
m.RemoveStream(7)
|
||||||
Expect(m.roundRobinIndex).To(Equal(3))
|
Expect(m.roundRobinIndex).To(Equal(uint32(3)))
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("Prioritizing crypto- and header streams", func() {
|
Context("Prioritizing crypto- and header streams", func() {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue