mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
immediately delete a stream when it is completed
By introducing a callback to the stream, which the stream calls as soon as it is completed, we can get rid of checking every single open stream if it is completed.
This commit is contained in:
parent
843a0786fc
commit
8a3f807a12
12 changed files with 199 additions and 423 deletions
|
@ -169,18 +169,6 @@ func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Ca
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0)
|
||||
}
|
||||
|
||||
// finished mocks base method
|
||||
func (m *MockStreamI) finished() bool {
|
||||
ret := m.ctrl.Call(m, "finished")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// finished indicates an expected call of finished
|
||||
func (mr *MockStreamIMockRecorder) finished() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "finished", reflect.TypeOf((*MockStreamI)(nil).finished))
|
||||
}
|
||||
|
||||
// getWindowUpdate mocks base method
|
||||
func (m *MockStreamI) getWindowUpdate() protocol.ByteCount {
|
||||
ret := m.ctrl.Call(m, "getWindowUpdate")
|
||||
|
|
|
@ -55,6 +55,16 @@ func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gom
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0)
|
||||
}
|
||||
|
||||
// onStreamCompleted mocks base method
|
||||
func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) {
|
||||
m.ctrl.Call(m, "onStreamCompleted", arg0)
|
||||
}
|
||||
|
||||
// onStreamCompleted indicates an expected call of onStreamCompleted
|
||||
func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0)
|
||||
}
|
||||
|
||||
// queueControlFrame mocks base method
|
||||
func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) {
|
||||
m.ctrl.Call(m, "queueControlFrame", arg0)
|
||||
|
|
|
@ -148,6 +148,7 @@ func (s *receiveStream) Read(p []byte) (int, error) {
|
|||
s.frameQueue.Pop()
|
||||
s.finRead = frame.FinBit
|
||||
if frame.FinBit {
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
return bytesRead, io.EOF
|
||||
}
|
||||
}
|
||||
|
@ -219,6 +220,7 @@ func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
|||
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
|
||||
}
|
||||
s.signalRead()
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -259,14 +261,6 @@ func (s *receiveStream) closeForShutdown(err error) {
|
|||
s.signalRead()
|
||||
}
|
||||
|
||||
func (s *receiveStream) finished() bool {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
return s.closedForShutdown || // if the stream was abruptly closed for shutting down
|
||||
s.finRead || s.resetRemotely
|
||||
}
|
||||
|
||||
func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
|
||||
return s.flowController.GetWindowUpdate()
|
||||
}
|
||||
|
|
|
@ -33,7 +33,6 @@ var _ = Describe("Receive Stream", func() {
|
|||
|
||||
timeout := scaleDuration(250 * time.Millisecond)
|
||||
strWithTimeout = gbytes.TimeoutReader(str, timeout)
|
||||
strWithTimeout = str
|
||||
})
|
||||
|
||||
It("gets stream id", func() {
|
||||
|
@ -320,12 +319,12 @@ var _ = Describe("Receive Stream", func() {
|
|||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
frame := wire.StreamFrame{
|
||||
str.handleStreamFrame(&wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
FinBit: true,
|
||||
}
|
||||
str.handleStreamFrame(&frame)
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
|
@ -354,6 +353,7 @@ var _ = Describe("Receive Stream", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleStreamFrame(&frame2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
|
@ -368,31 +368,30 @@ var _ = Describe("Receive Stream", func() {
|
|||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
frame := wire.StreamFrame{
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD},
|
||||
Data: []byte{0xde, 0xad},
|
||||
FinBit: true,
|
||||
}
|
||||
err := str.handleStreamFrame(&frame)
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(n).To(Equal(2))
|
||||
Expect(b[:n]).To(Equal([]byte{0xDE, 0xAD}))
|
||||
Expect(b[:n]).To(Equal([]byte{0xde, 0xad}))
|
||||
})
|
||||
|
||||
It("handles immediate FINs", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
frame := wire.StreamFrame{
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{},
|
||||
FinBit: true,
|
||||
}
|
||||
err := str.handleStreamFrame(&frame)
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
|
@ -405,6 +404,7 @@ var _ = Describe("Receive Stream", func() {
|
|||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
str.CloseRemote(0)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 8)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
|
@ -486,6 +486,7 @@ var _ = Describe("Receive Stream", func() {
|
|||
FinBit: true,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
_, err = strWithTimeout.Read(make([]byte, 100))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
err = str.CancelRead(1234)
|
||||
|
@ -526,11 +527,13 @@ var _ = Describe("Receive Stream", func() {
|
|||
close(done)
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleRstStreamFrame(rst)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't allow further calls to Read", func() {
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -549,6 +552,7 @@ var _ = Describe("Receive Stream", func() {
|
|||
})
|
||||
|
||||
It("ignores duplicate RST_STREAM frames", func() {
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -580,6 +584,7 @@ var _ = Describe("Receive Stream", func() {
|
|||
close(readReturned)
|
||||
}()
|
||||
Consistently(readReturned).ShouldNot(BeClosed())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(readReturned).Should(BeClosed())
|
||||
|
@ -587,8 +592,11 @@ var _ = Describe("Receive Stream", func() {
|
|||
|
||||
It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
|
||||
gomock.InOrder(
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)),
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
|
||||
mockSender.EXPECT().onStreamCompleted(streamID),
|
||||
)
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
readReturned := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -637,43 +645,4 @@ var _ = Describe("Receive Stream", func() {
|
|||
Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("saying if it is finished", func() {
|
||||
finishReading := func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{FinBit: true})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 100)
|
||||
_, err = strWithTimeout.Read(b)
|
||||
ExpectWithOffset(0, err).To(MatchError(io.EOF))
|
||||
}
|
||||
|
||||
It("is finished after it is closed for shutdown", func() {
|
||||
str.closeForShutdown(errors.New("testErr"))
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is finished if it is only closed for reading", func() {
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
finishReading()
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
// the stream still needs to stay alive until we receive the final offset
|
||||
// (either by receiving a STREAM frame with FIN, or a RST_STREAM)
|
||||
It("is not finished after CancelRead", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
err := str.CancelRead(123)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.finished()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is finished after receiving a RST_STREAM frame", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(3), true)
|
||||
err := str.handleRstStreamFrame(&wire.RstStreamFrame{ByteOffset: 3})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -154,6 +154,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFr
|
|||
}
|
||||
if frame.FinBit {
|
||||
s.finSent = true
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
} else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream
|
||||
if isBlocked, offset := s.flowController.IsBlocked(); isBlocked {
|
||||
s.sender.queueControlFrame(&wire.StreamBlockedFrame{
|
||||
|
@ -231,6 +232,7 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr
|
|||
})
|
||||
// TODO(#991): cancel retransmissions for this stream
|
||||
s.ctxCancel()
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -289,14 +291,6 @@ func (s *sendStream) closeForShutdown(err error) {
|
|||
s.ctxCancel()
|
||||
}
|
||||
|
||||
func (s *sendStream) finished() bool {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
return s.closedForShutdown || // if the stream was abruptly closed for shutting down
|
||||
s.finSent || s.canceledWrite
|
||||
}
|
||||
|
||||
func (s *sendStream) getWriteOffset() protocol.ByteCount {
|
||||
return s.writeOffset
|
||||
}
|
||||
|
|
|
@ -247,6 +247,7 @@ var _ = Describe("Send Stream", func() {
|
|||
|
||||
It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
|
||||
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
|
||||
// don't EXPECT a call to mockFC.IsBlocked
|
||||
|
@ -389,6 +390,7 @@ var _ = Describe("Send Stream", func() {
|
|||
|
||||
It("allows FIN", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.Close()
|
||||
f, hasMoreData := str.popStreamFrame(1000)
|
||||
Expect(f).ToNot(BeNil())
|
||||
|
@ -409,6 +411,7 @@ var _ = Describe("Send Stream", func() {
|
|||
Expect(f).ToNot(BeNil())
|
||||
Expect(f.Data).To(Equal([]byte("foo")))
|
||||
Expect(f.FinBit).To(BeFalse())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
f, _ = str.popStreamFrame(100)
|
||||
Expect(f.Data).To(Equal([]byte("bar")))
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
|
@ -423,6 +426,7 @@ var _ = Describe("Send Stream", func() {
|
|||
|
||||
It("doesn't allow FIN twice", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.Close()
|
||||
f, _ := str.popStreamFrame(1000)
|
||||
Expect(f).ToNot(BeNil())
|
||||
|
@ -513,6 +517,7 @@ var _ = Describe("Send Stream", func() {
|
|||
ByteOffset: 1234,
|
||||
ErrorCode: 9876,
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.writeOffset = 1234
|
||||
err := str.CancelWrite(9876)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -520,6 +525,7 @@ var _ = Describe("Send Stream", func() {
|
|||
|
||||
It("unblocks Write", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
|
@ -544,6 +550,7 @@ var _ = Describe("Send Stream", func() {
|
|||
|
||||
It("cancels the context", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
Expect(str.Context().Done()).ToNot(BeClosed())
|
||||
str.CancelWrite(1234)
|
||||
Expect(str.Context().Done()).To(BeClosed())
|
||||
|
@ -551,6 +558,7 @@ var _ = Describe("Send Stream", func() {
|
|||
|
||||
It("doesn't allow further calls to Write", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
err := str.CancelWrite(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = strWithTimeout.Write([]byte("foobar"))
|
||||
|
@ -559,6 +567,7 @@ var _ = Describe("Send Stream", func() {
|
|||
|
||||
It("only cancels once", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
err := str.CancelWrite(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.CancelWrite(4321)
|
||||
|
@ -580,6 +589,7 @@ var _ = Describe("Send Stream", func() {
|
|||
StreamID: streamID,
|
||||
ErrorCode: errorCodeStopping,
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 101,
|
||||
|
@ -600,6 +610,7 @@ var _ = Describe("Send Stream", func() {
|
|||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 123,
|
||||
|
@ -609,6 +620,7 @@ var _ = Describe("Send Stream", func() {
|
|||
|
||||
It("doesn't allow further calls to Write", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 123,
|
||||
|
@ -621,32 +633,4 @@ var _ = Describe("Send Stream", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("saying if it is finished", func() {
|
||||
It("is finished after it is closed for shutdown", func() {
|
||||
str.closeForShutdown(errors.New("testErr"))
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is finished after Close()", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
str.Close()
|
||||
f, _ := str.popStreamFrame(1000)
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is finished after CancelWrite", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
err := str.CancelWrite(123)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is finished after receiving a STOP_SENDING (and sending a RST_STREAM)", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID})
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
10
session.go
10
session.go
|
@ -408,10 +408,6 @@ runLoop:
|
|||
if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout {
|
||||
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
|
||||
}
|
||||
|
||||
if err := s.streamsMap.DeleteClosedStreams(); err != nil {
|
||||
s.closeLocal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// only send the error the handshakeChan when the handshake is not completed yet
|
||||
|
@ -950,6 +946,12 @@ func (s *session) onHasStreamData(id protocol.StreamID) {
|
|||
s.scheduleSending()
|
||||
}
|
||||
|
||||
func (s *session) onStreamCompleted(id protocol.StreamID) {
|
||||
if err := s.streamsMap.DeleteStream(id); err != nil {
|
||||
s.Close(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) LocalAddr() net.Addr {
|
||||
return s.conn.LocalAddr()
|
||||
}
|
||||
|
|
|
@ -122,7 +122,7 @@ var _ = Describe("Session", func() {
|
|||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
sess = pSess.(*session)
|
||||
Expect(sess.streamsMap.openStreams).To(BeEmpty())
|
||||
Expect(sess.streamsMap.streams).To(BeEmpty())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -195,9 +195,6 @@ var _ = Describe("Session", func() {
|
|||
sess.streamsMap.newStream = func(id protocol.StreamID) streamI {
|
||||
str := NewMockStreamI(mockCtrl)
|
||||
str.EXPECT().StreamID().Return(id).AnyTimes()
|
||||
if id == 1 {
|
||||
str.EXPECT().finished().AnyTimes()
|
||||
}
|
||||
return str
|
||||
}
|
||||
})
|
||||
|
@ -247,9 +244,9 @@ var _ = Describe("Session", func() {
|
|||
return str
|
||||
}
|
||||
sess.handleStreamFrame(f1)
|
||||
numOpenStreams := len(sess.streamsMap.openStreams)
|
||||
numOpenStreams := len(sess.streamsMap.streams)
|
||||
sess.handleStreamFrame(f2)
|
||||
Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams))
|
||||
Expect(sess.streamsMap.streams).To(HaveLen(numOpenStreams))
|
||||
})
|
||||
|
||||
It("ignores STREAM frames for closed streams", func() {
|
||||
|
@ -329,8 +326,7 @@ var _ = Describe("Session", func() {
|
|||
It("ignores the error when the stream is not known", func() {
|
||||
str, err := sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.(*MockStreamI).EXPECT().finished().Return(true)
|
||||
sess.streamsMap.DeleteClosedStreams()
|
||||
sess.onStreamCompleted(3)
|
||||
str, err = sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
|
@ -411,9 +407,7 @@ var _ = Describe("Session", func() {
|
|||
It("ignores MAX_STREAM_DATA frames for a closed stream", func() {
|
||||
str, err := sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.(*MockStreamI).EXPECT().finished().Return(true)
|
||||
err = sess.streamsMap.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sess.onStreamCompleted(3)
|
||||
str, err = sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
|
@ -457,8 +451,7 @@ var _ = Describe("Session", func() {
|
|||
It("ignores STOP_SENDING frames for a closed stream", func() {
|
||||
str, err := sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.(*MockStreamI).EXPECT().finished().Return(true)
|
||||
err = sess.streamsMap.DeleteClosedStreams()
|
||||
sess.onStreamCompleted(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err = sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -1403,12 +1396,7 @@ var _ = Describe("Session", func() {
|
|||
It("returns a nil-value (not an interface with value nil) for closed streams", func() {
|
||||
str, err := sess.GetOrOpenStream(9)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.Close()
|
||||
str.(*stream).closeForShutdown(nil)
|
||||
Expect(str.(*stream).finished()).To(BeTrue())
|
||||
err = sess.streamsMap.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil())
|
||||
sess.onStreamCompleted(9)
|
||||
str, err = sess.GetOrOpenStream(9)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
|
@ -1425,31 +1413,6 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("counting streams", func() {
|
||||
It("errors when too many streams are opened", func() {
|
||||
for i := 0; i < protocol.MaxIncomingStreams; i++ {
|
||||
_, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
_, err := sess.GetOrOpenStream(protocol.StreamID(301))
|
||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||
})
|
||||
|
||||
It("does not error when many streams are opened and closed", func() {
|
||||
for i := 2; i <= 1000; i++ {
|
||||
s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.Close()).To(Succeed())
|
||||
f, _ := s.(*stream).popStreamFrame(1000) // trigger "sending" of the FIN bit
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
s.(*stream).CloseRemote(0)
|
||||
_, err = s.Read([]byte("a"))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
sess.streamsMap.DeleteClosedStreams()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("ignoring errors", func() {
|
||||
It("ignores duplicate acks", func() {
|
||||
sess.sentPacketHandler.SentPacket(&ackhandler.Packet{
|
||||
|
@ -1522,7 +1485,7 @@ var _ = Describe("Client Session", func() {
|
|||
)
|
||||
sess = sessP.(*session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.streamsMap.openStreams).To(BeEmpty())
|
||||
Expect(sess.streamsMap.streams).To(BeEmpty())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
|
65
stream.go
65
stream.go
|
@ -2,6 +2,7 @@ package quic
|
|||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
|
@ -19,8 +20,34 @@ type streamSender interface {
|
|||
queueControlFrame(wire.Frame)
|
||||
onHasWindowUpdate(protocol.StreamID)
|
||||
onHasStreamData(protocol.StreamID)
|
||||
onStreamCompleted(protocol.StreamID)
|
||||
}
|
||||
|
||||
// Each of the both stream halves gets its own uniStreamSender.
|
||||
// This is necessary in order to keep track when both halves have been completed.
|
||||
type uniStreamSender struct {
|
||||
streamSender
|
||||
onStreamCompletedImpl func()
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
|
||||
s.streamSender.queueControlFrame(f)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) {
|
||||
s.streamSender.onHasWindowUpdate(id)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
|
||||
s.streamSender.onHasStreamData(id)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
|
||||
s.onStreamCompletedImpl()
|
||||
}
|
||||
|
||||
var _ streamSender = &uniStreamSender{}
|
||||
|
||||
type streamI interface {
|
||||
Stream
|
||||
|
||||
|
@ -28,7 +55,6 @@ type streamI interface {
|
|||
handleRstStreamFrame(*wire.RstStreamFrame) error
|
||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
finished() bool
|
||||
closeForShutdown(error)
|
||||
// methods needed for flow control
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
|
@ -42,6 +68,11 @@ type stream struct {
|
|||
receiveStream
|
||||
sendStream
|
||||
|
||||
completedMutex sync.Mutex
|
||||
sender streamSender
|
||||
receiveStreamCompleted bool
|
||||
sendStreamCompleted bool
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
|
@ -72,10 +103,28 @@ func newStream(streamID protocol.StreamID,
|
|||
flowController flowcontrol.StreamFlowController,
|
||||
version protocol.VersionNumber,
|
||||
) *stream {
|
||||
return &stream{
|
||||
sendStream: *newSendStream(streamID, sender, flowController, version),
|
||||
receiveStream: *newReceiveStream(streamID, sender, flowController),
|
||||
s := &stream{sender: sender}
|
||||
senderForSendStream := &uniStreamSender{
|
||||
streamSender: sender,
|
||||
onStreamCompletedImpl: func() {
|
||||
s.completedMutex.Lock()
|
||||
s.sendStreamCompleted = true
|
||||
s.checkIfCompleted()
|
||||
s.completedMutex.Unlock()
|
||||
},
|
||||
}
|
||||
s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version)
|
||||
senderForReceiveStream := &uniStreamSender{
|
||||
streamSender: sender,
|
||||
onStreamCompletedImpl: func() {
|
||||
s.completedMutex.Lock()
|
||||
s.receiveStreamCompleted = true
|
||||
s.checkIfCompleted()
|
||||
s.completedMutex.Unlock()
|
||||
},
|
||||
}
|
||||
s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController)
|
||||
return s
|
||||
}
|
||||
|
||||
// need to define StreamID() here, since both receiveStream and readStream have a StreamID()
|
||||
|
@ -120,6 +169,10 @@ func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) finished() bool {
|
||||
return s.sendStream.finished() && s.receiveStream.finished()
|
||||
// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed.
|
||||
// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed.
|
||||
func (s *stream) checkIfCompleted() {
|
||||
if s.sendStreamCompleted && s.receiveStreamCompleted {
|
||||
s.sender.onStreamCompleted(s.StreamID())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
|
@ -72,6 +71,7 @@ var _ = Describe("Stream", func() {
|
|||
ByteOffset: 1000,
|
||||
ErrorCode: errorCodeStoppingGQUIC,
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
|
||||
str.writeOffset = 1000
|
||||
f := &wire.RstStreamFrame{
|
||||
|
@ -189,26 +189,21 @@ var _ = Describe("Stream", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("saying if it is finished", func() {
|
||||
It("is finished when both the send and the receive side are finished", func() {
|
||||
str.receiveStream.closeForShutdown(errors.New("shutdown"))
|
||||
Expect(str.receiveStream.finished()).To(BeTrue())
|
||||
Expect(str.sendStream.finished()).To(BeFalse())
|
||||
Expect(str.finished()).To(BeFalse())
|
||||
Context("completing", func() {
|
||||
It("is not completed when only the receive side is completed", func() {
|
||||
// don't EXPECT a call to mockSender.onStreamCompleted()
|
||||
str.receiveStream.sender.onStreamCompleted(streamID)
|
||||
})
|
||||
|
||||
It("is not finished when the receive side is finished", func() {
|
||||
str.sendStream.closeForShutdown(errors.New("shutdown"))
|
||||
Expect(str.receiveStream.finished()).To(BeFalse())
|
||||
Expect(str.sendStream.finished()).To(BeTrue())
|
||||
Expect(str.finished()).To(BeFalse())
|
||||
It("is not completed when only the send side is completed", func() {
|
||||
// don't EXPECT a call to mockSender.onStreamCompleted()
|
||||
str.sendStream.sender.onStreamCompleted(streamID)
|
||||
})
|
||||
|
||||
It("is not finished when the send side is finished", func() {
|
||||
str.closeForShutdown(errors.New("shutdown"))
|
||||
Expect(str.receiveStream.finished()).To(BeTrue())
|
||||
Expect(str.sendStream.finished()).To(BeTrue())
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
It("is completed when both sides are completed", func() {
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.sendStream.sender.onStreamCompleted(streamID)
|
||||
str.receiveStream.sender.onStreamCompleted(streamID)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -16,9 +16,6 @@ type streamsMap struct {
|
|||
perspective protocol.Perspective
|
||||
|
||||
streams map[protocol.StreamID]streamI
|
||||
// needed for round-robin scheduling
|
||||
openStreams []protocol.StreamID
|
||||
roundRobinIndex int
|
||||
|
||||
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||
highestStreamOpenedByPeer protocol.StreamID
|
||||
|
@ -51,7 +48,6 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro
|
|||
sm := streamsMap{
|
||||
perspective: pers,
|
||||
streams: make(map[protocol.StreamID]streamI),
|
||||
openStreams: make([]protocol.StreamID, 0),
|
||||
newStream: newStream,
|
||||
maxIncomingStreams: maxIncomingStreams,
|
||||
}
|
||||
|
@ -99,7 +95,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
|
|||
s, ok := m.streams[id]
|
||||
m.mutex.RUnlock()
|
||||
if ok {
|
||||
return s, nil // s may be nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ... we don't have an existing stream
|
||||
|
@ -212,48 +208,19 @@ func (m *streamsMap) AcceptStream() (streamI, error) {
|
|||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) DeleteClosedStreams() error {
|
||||
func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var numDeletedStreams int
|
||||
// for every closed stream, the streamID is replaced by 0 in the openStreams slice
|
||||
for i, streamID := range m.openStreams {
|
||||
str, ok := m.streams[streamID]
|
||||
if !ok {
|
||||
return errMapAccess
|
||||
}
|
||||
if !str.finished() {
|
||||
continue
|
||||
}
|
||||
numDeletedStreams++
|
||||
m.openStreams[i] = 0
|
||||
if m.streamInitiatedBy(streamID) == m.perspective {
|
||||
m.numOutgoingStreams--
|
||||
} else {
|
||||
m.numIncomingStreams--
|
||||
}
|
||||
delete(m.streams, streamID)
|
||||
_, ok := m.streams[id]
|
||||
if !ok {
|
||||
return errMapAccess
|
||||
}
|
||||
|
||||
if numDeletedStreams == 0 {
|
||||
return nil
|
||||
delete(m.streams, id)
|
||||
if m.streamInitiatedBy(id) == m.perspective {
|
||||
m.numOutgoingStreams--
|
||||
} else {
|
||||
m.numIncomingStreams--
|
||||
}
|
||||
|
||||
// remove all 0s (representing closed streams) from the openStreams slice
|
||||
// and adjust the roundRobinIndex
|
||||
var j int
|
||||
for i, id := range m.openStreams {
|
||||
if i != j {
|
||||
m.openStreams[j] = m.openStreams[i]
|
||||
}
|
||||
if id != 0 {
|
||||
j++
|
||||
} else if j < m.roundRobinIndex {
|
||||
m.roundRobinIndex--
|
||||
}
|
||||
}
|
||||
m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams]
|
||||
m.openStreamOrErrCond.Signal()
|
||||
return nil
|
||||
}
|
||||
|
@ -264,28 +231,16 @@ func (m *streamsMap) Range(cb func(s streamI)) {
|
|||
defer m.mutex.RUnlock()
|
||||
|
||||
for _, s := range m.streams {
|
||||
if s != nil {
|
||||
cb(s)
|
||||
}
|
||||
cb(s)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
|
||||
str, ok := m.streams[streamID]
|
||||
if !ok {
|
||||
return true, errMapAccess
|
||||
}
|
||||
return fn(str)
|
||||
}
|
||||
|
||||
func (m *streamsMap) putStream(s streamI) error {
|
||||
id := s.StreamID()
|
||||
if _, ok := m.streams[id]; ok {
|
||||
return fmt.Errorf("a stream with ID %d already exists", id)
|
||||
}
|
||||
|
||||
m.streams[id] = s
|
||||
m.openStreams = append(m.openStreams, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -295,8 +250,8 @@ func (m *streamsMap) CloseWithError(err error) {
|
|||
m.closeErr = err
|
||||
m.nextStreamOrErrCond.Broadcast()
|
||||
m.openStreamOrErrCond.Broadcast()
|
||||
for _, s := range m.openStreams {
|
||||
m.streams[s].closeForShutdown(err)
|
||||
for _, s := range m.streams {
|
||||
s.closeForShutdown(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,22 +7,16 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Streams Map", func() {
|
||||
var (
|
||||
m *streamsMap
|
||||
finishedStreams map[protocol.StreamID]*gomock.Call
|
||||
)
|
||||
var m *streamsMap
|
||||
|
||||
newStream := func(id protocol.StreamID) streamI {
|
||||
str := NewMockStreamI(mockCtrl)
|
||||
str.EXPECT().StreamID().Return(id).AnyTimes()
|
||||
c := str.EXPECT().finished().Return(false).AnyTimes()
|
||||
finishedStreams[id] = c
|
||||
return str
|
||||
}
|
||||
|
||||
|
@ -30,20 +24,8 @@ var _ = Describe("Streams Map", func() {
|
|||
m = newStreamsMap(newStream, p, v)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
finishedStreams = make(map[protocol.StreamID]*gomock.Call)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(m.openStreams).To(HaveLen(len(m.streams)))
|
||||
})
|
||||
|
||||
deleteStream := func(id protocol.StreamID) {
|
||||
str := m.streams[id]
|
||||
Expect(str).ToNot(BeNil())
|
||||
finishedStreams[id].Return(true)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed())
|
||||
}
|
||||
|
||||
Context("getting and creating streams", func() {
|
||||
|
@ -521,176 +503,63 @@ var _ = Describe("Streams Map", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("DoS mitigation, iterating and deleting", func() {
|
||||
Context("Ranging", func() {
|
||||
It("ranges over all open streams", func() {
|
||||
setNewStreamsMap(protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||
var callbackCalledForStream []protocol.StreamID
|
||||
callback := func(str streamI) {
|
||||
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
|
||||
sort.Slice(callbackCalledForStream, func(i, j int) bool {
|
||||
return callbackCalledForStream[i] < callbackCalledForStream[j]
|
||||
})
|
||||
}
|
||||
|
||||
Expect(m.streams).To(BeEmpty())
|
||||
// create 5 streams, ids 4 to 8
|
||||
callbackCalledForStream = callbackCalledForStream[:0]
|
||||
for i := 4; i <= 8; i++ {
|
||||
str := NewMockStreamI(mockCtrl)
|
||||
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
|
||||
err := m.putStream(str)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
// execute the callback for all streams
|
||||
m.Range(callback)
|
||||
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("deleting streams", func() {
|
||||
BeforeEach(func() {
|
||||
setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
|
||||
})
|
||||
|
||||
closeStream := func(id protocol.StreamID) {
|
||||
str := m.streams[id]
|
||||
ExpectWithOffset(1, str).ToNot(BeNil())
|
||||
finishedStreams[id].Return(true)
|
||||
}
|
||||
|
||||
Context("deleting streams", func() {
|
||||
Context("as a server", func() {
|
||||
BeforeEach(func() {
|
||||
m.UpdateMaxStreamLimit(100)
|
||||
for i := 1; i <= 5; i++ {
|
||||
if i%2 == 1 {
|
||||
_, err := m.openRemoteStream(protocol.StreamID(i))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
} else {
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) // 2 and 4
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) // 1, 3 and 5
|
||||
})
|
||||
|
||||
It("does not delete streams with Close()", func() {
|
||||
str, err := m.GetOrOpenStream(55)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.(*MockStreamI).EXPECT().Close()
|
||||
str.Close()
|
||||
err = m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err = m.GetOrOpenStream(55)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("removes the first stream", func() {
|
||||
closeStream(1)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.openStreams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
})
|
||||
|
||||
It("removes a stream in the middle", func() {
|
||||
closeStream(3)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
})
|
||||
|
||||
It("removes a client-initiated stream", func() {
|
||||
closeStream(2)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 3, 4, 5}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(3))
|
||||
})
|
||||
|
||||
It("removes a stream at the end", func() {
|
||||
closeStream(5)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.openStreams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
})
|
||||
|
||||
It("removes all streams", func() {
|
||||
for i := 1; i <= 5; i++ {
|
||||
closeStream(protocol.StreamID(i))
|
||||
}
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(BeEmpty())
|
||||
Expect(m.openStreams).To(BeEmpty())
|
||||
Expect(m.numOutgoingStreams).To(BeZero())
|
||||
Expect(m.numIncomingStreams).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Context("as a client", func() {
|
||||
BeforeEach(func() {
|
||||
setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames)
|
||||
m.UpdateMaxStreamLimit(100)
|
||||
for i := 1; i <= 5; i++ {
|
||||
if i%2 == 0 {
|
||||
_, err := m.openRemoteStream(protocol.StreamID(i))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
} else {
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 2, 5, 4, 7}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) // 3, 5 and 7
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) // 2 and 4
|
||||
})
|
||||
|
||||
It("removes a stream that we initiated", func() {
|
||||
closeStream(3)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 5, 4, 7}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
})
|
||||
|
||||
It("removes a stream that the server initiated", func() {
|
||||
closeStream(2)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.openStreams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 5, 4, 7}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
It("removes all streams", func() {
|
||||
closeStream(3)
|
||||
closeStream(2)
|
||||
closeStream(5)
|
||||
closeStream(4)
|
||||
closeStream(7)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(BeEmpty())
|
||||
Expect(m.openStreams).To(BeEmpty())
|
||||
Expect(m.numOutgoingStreams).To(BeZero())
|
||||
Expect(m.numIncomingStreams).To(BeZero())
|
||||
})
|
||||
})
|
||||
It("deletes an incoming stream", func() {
|
||||
_, err := m.GetOrOpenStream(5) // open stream 3 and 5
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
err = m.DeleteStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(HaveLen(1))
|
||||
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
Context("Ranging", func() {
|
||||
// create 5 streams, ids 4 to 8
|
||||
var callbackCalledForStream []protocol.StreamID
|
||||
callback := func(str streamI) {
|
||||
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
|
||||
sort.Slice(callbackCalledForStream, func(i, j int) bool { return callbackCalledForStream[i] < callbackCalledForStream[j] })
|
||||
}
|
||||
It("deletes an outgoing stream", func() {
|
||||
m.UpdateMaxStreamLimit(10000)
|
||||
_, err := m.OpenStream() // open stream 2
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
err = m.DeleteStream(2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
callbackCalledForStream = callbackCalledForStream[:0]
|
||||
for i := 4; i <= 8; i++ {
|
||||
str := NewMockStreamI(mockCtrl)
|
||||
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
|
||||
err := m.putStream(str)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
It("ranges over all open streams", func() {
|
||||
m.Range(callback)
|
||||
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
|
||||
})
|
||||
It("errors when the stream doesn't exist", func() {
|
||||
err := m.DeleteStream(1337)
|
||||
Expect(err).To(MatchError(errMapAccess))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue