immediately delete a stream when it is completed

By introducing a callback to the stream, which the stream calls as soon
as it is completed, we can get rid of checking every single open stream
if it is completed.
This commit is contained in:
Marten Seemann 2017-12-25 16:32:29 +07:00
parent 843a0786fc
commit 8a3f807a12
12 changed files with 199 additions and 423 deletions

View file

@ -169,18 +169,6 @@ func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0)
}
// finished mocks base method
func (m *MockStreamI) finished() bool {
ret := m.ctrl.Call(m, "finished")
ret0, _ := ret[0].(bool)
return ret0
}
// finished indicates an expected call of finished
func (mr *MockStreamIMockRecorder) finished() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "finished", reflect.TypeOf((*MockStreamI)(nil).finished))
}
// getWindowUpdate mocks base method
func (m *MockStreamI) getWindowUpdate() protocol.ByteCount {
ret := m.ctrl.Call(m, "getWindowUpdate")

View file

@ -55,6 +55,16 @@ func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0)
}
// onStreamCompleted mocks base method
func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) {
m.ctrl.Call(m, "onStreamCompleted", arg0)
}
// onStreamCompleted indicates an expected call of onStreamCompleted
func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0)
}
// queueControlFrame mocks base method
func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) {
m.ctrl.Call(m, "queueControlFrame", arg0)

View file

@ -148,6 +148,7 @@ func (s *receiveStream) Read(p []byte) (int, error) {
s.frameQueue.Pop()
s.finRead = frame.FinBit
if frame.FinBit {
s.sender.onStreamCompleted(s.streamID)
return bytesRead, io.EOF
}
}
@ -219,6 +220,7 @@ func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
}
s.signalRead()
s.sender.onStreamCompleted(s.streamID)
return nil
}
@ -259,14 +261,6 @@ func (s *receiveStream) closeForShutdown(err error) {
s.signalRead()
}
func (s *receiveStream) finished() bool {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.closedForShutdown || // if the stream was abruptly closed for shutting down
s.finRead || s.resetRemotely
}
func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate()
}

View file

@ -33,7 +33,6 @@ var _ = Describe("Receive Stream", func() {
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = gbytes.TimeoutReader(str, timeout)
strWithTimeout = str
})
It("gets stream id", func() {
@ -320,12 +319,12 @@ var _ = Describe("Receive Stream", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().HasWindowUpdate()
frame := wire.StreamFrame{
str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
FinBit: true,
}
str.handleStreamFrame(&frame)
})
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
@ -354,6 +353,7 @@ var _ = Describe("Receive Stream", func() {
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
@ -368,31 +368,30 @@ var _ = Describe("Receive Stream", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().HasWindowUpdate()
frame := wire.StreamFrame{
err := str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
Data: []byte{0xde, 0xad},
FinBit: true,
}
err := str.handleStreamFrame(&frame)
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(2))
Expect(b[:n]).To(Equal([]byte{0xDE, 0xAD}))
Expect(b[:n]).To(Equal([]byte{0xde, 0xad}))
})
It("handles immediate FINs", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
frame := wire.StreamFrame{
err := str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{},
FinBit: true,
}
err := str.handleStreamFrame(&frame)
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
@ -405,6 +404,7 @@ var _ = Describe("Receive Stream", func() {
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
str.CloseRemote(0)
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 8)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
@ -486,6 +486,7 @@ var _ = Describe("Receive Stream", func() {
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
_, err = strWithTimeout.Read(make([]byte, 100))
Expect(err).To(MatchError(io.EOF))
err = str.CancelRead(1234)
@ -526,11 +527,13 @@ var _ = Describe("Receive Stream", func() {
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleRstStreamFrame(rst)
Eventually(done).Should(BeClosed())
})
It("doesn't allow further calls to Read", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
@ -549,6 +552,7 @@ var _ = Describe("Receive Stream", func() {
})
It("ignores duplicate RST_STREAM frames", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
@ -580,6 +584,7 @@ var _ = Describe("Receive Stream", func() {
close(readReturned)
}()
Consistently(readReturned).ShouldNot(BeClosed())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
Eventually(readReturned).Should(BeClosed())
@ -587,8 +592,11 @@ var _ = Describe("Receive Stream", func() {
It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
gomock.InOrder(
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)),
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
mockSender.EXPECT().onStreamCompleted(streamID),
)
mockFC.EXPECT().HasWindowUpdate().Times(2)
readReturned := make(chan struct{})
go func() {
@ -637,43 +645,4 @@ var _ = Describe("Receive Stream", func() {
Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100)))
})
})
Context("saying if it is finished", func() {
finishReading := func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
err := str.handleStreamFrame(&wire.StreamFrame{FinBit: true})
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 100)
_, err = strWithTimeout.Read(b)
ExpectWithOffset(0, err).To(MatchError(io.EOF))
}
It("is finished after it is closed for shutdown", func() {
str.closeForShutdown(errors.New("testErr"))
Expect(str.finished()).To(BeTrue())
})
It("is finished if it is only closed for reading", func() {
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
finishReading()
Expect(str.finished()).To(BeTrue())
})
// the stream still needs to stay alive until we receive the final offset
// (either by receiving a STREAM frame with FIN, or a RST_STREAM)
It("is not finished after CancelRead", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
err := str.CancelRead(123)
Expect(err).ToNot(HaveOccurred())
Expect(str.finished()).To(BeFalse())
})
It("is finished after receiving a RST_STREAM frame", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(3), true)
err := str.handleRstStreamFrame(&wire.RstStreamFrame{ByteOffset: 3})
Expect(err).ToNot(HaveOccurred())
Expect(str.finished()).To(BeTrue())
})
})
})

View file

@ -154,6 +154,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFr
}
if frame.FinBit {
s.finSent = true
s.sender.onStreamCompleted(s.streamID)
} else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream
if isBlocked, offset := s.flowController.IsBlocked(); isBlocked {
s.sender.queueControlFrame(&wire.StreamBlockedFrame{
@ -231,6 +232,7 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr
})
// TODO(#991): cancel retransmissions for this stream
s.ctxCancel()
s.sender.onStreamCompleted(s.streamID)
return nil
}
@ -289,14 +291,6 @@ func (s *sendStream) closeForShutdown(err error) {
s.ctxCancel()
}
func (s *sendStream) finished() bool {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.closedForShutdown || // if the stream was abruptly closed for shutting down
s.finSent || s.canceledWrite
}
func (s *sendStream) getWriteOffset() protocol.ByteCount {
return s.writeOffset
}

View file

@ -247,6 +247,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() {
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
// don't EXPECT a call to mockFC.IsBlocked
@ -389,6 +390,7 @@ var _ = Describe("Send Stream", func() {
It("allows FIN", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
str.Close()
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
@ -409,6 +411,7 @@ var _ = Describe("Send Stream", func() {
Expect(f).ToNot(BeNil())
Expect(f.Data).To(Equal([]byte("foo")))
Expect(f.FinBit).To(BeFalse())
mockSender.EXPECT().onStreamCompleted(streamID)
f, _ = str.popStreamFrame(100)
Expect(f.Data).To(Equal([]byte("bar")))
Expect(f.FinBit).To(BeTrue())
@ -423,6 +426,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't allow FIN twice", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
str.Close()
f, _ := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
@ -513,6 +517,7 @@ var _ = Describe("Send Stream", func() {
ByteOffset: 1234,
ErrorCode: 9876,
})
mockSender.EXPECT().onStreamCompleted(streamID)
str.writeOffset = 1234
err := str.CancelWrite(9876)
Expect(err).ToNot(HaveOccurred())
@ -520,6 +525,7 @@ var _ = Describe("Send Stream", func() {
It("unblocks Write", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(gomock.Any())
@ -544,6 +550,7 @@ var _ = Describe("Send Stream", func() {
It("cancels the context", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.Context().Done()).ToNot(BeClosed())
str.CancelWrite(1234)
Expect(str.Context().Done()).To(BeClosed())
@ -551,6 +558,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't allow further calls to Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Write([]byte("foobar"))
@ -559,6 +567,7 @@ var _ = Describe("Send Stream", func() {
It("only cancels once", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
err = str.CancelWrite(4321)
@ -580,6 +589,7 @@ var _ = Describe("Send Stream", func() {
StreamID: streamID,
ErrorCode: errorCodeStopping,
})
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 101,
@ -600,6 +610,7 @@ var _ = Describe("Send Stream", func() {
close(done)
}()
waitForWrite()
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
@ -609,6 +620,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't allow further calls to Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
@ -621,32 +633,4 @@ var _ = Describe("Send Stream", func() {
})
})
})
Context("saying if it is finished", func() {
It("is finished after it is closed for shutdown", func() {
str.closeForShutdown(errors.New("testErr"))
Expect(str.finished()).To(BeTrue())
})
It("is finished after Close()", func() {
mockSender.EXPECT().onHasStreamData(streamID)
str.Close()
f, _ := str.popStreamFrame(1000)
Expect(f.FinBit).To(BeTrue())
Expect(str.finished()).To(BeTrue())
})
It("is finished after CancelWrite", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
err := str.CancelWrite(123)
Expect(err).ToNot(HaveOccurred())
Expect(str.finished()).To(BeTrue())
})
It("is finished after receiving a STOP_SENDING (and sending a RST_STREAM)", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID})
Expect(str.finished()).To(BeTrue())
})
})
})

View file

@ -408,10 +408,6 @@ runLoop:
if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout {
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
}
if err := s.streamsMap.DeleteClosedStreams(); err != nil {
s.closeLocal(err)
}
}
// only send the error the handshakeChan when the handshake is not completed yet
@ -950,6 +946,12 @@ func (s *session) onHasStreamData(id protocol.StreamID) {
s.scheduleSending()
}
func (s *session) onStreamCompleted(id protocol.StreamID) {
if err := s.streamsMap.DeleteStream(id); err != nil {
s.Close(err)
}
}
func (s *session) LocalAddr() net.Addr {
return s.conn.LocalAddr()
}

View file

@ -122,7 +122,7 @@ var _ = Describe("Session", func() {
)
Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session)
Expect(sess.streamsMap.openStreams).To(BeEmpty())
Expect(sess.streamsMap.streams).To(BeEmpty())
})
AfterEach(func() {
@ -195,9 +195,6 @@ var _ = Describe("Session", func() {
sess.streamsMap.newStream = func(id protocol.StreamID) streamI {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(id).AnyTimes()
if id == 1 {
str.EXPECT().finished().AnyTimes()
}
return str
}
})
@ -247,9 +244,9 @@ var _ = Describe("Session", func() {
return str
}
sess.handleStreamFrame(f1)
numOpenStreams := len(sess.streamsMap.openStreams)
numOpenStreams := len(sess.streamsMap.streams)
sess.handleStreamFrame(f2)
Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams))
Expect(sess.streamsMap.streams).To(HaveLen(numOpenStreams))
})
It("ignores STREAM frames for closed streams", func() {
@ -329,8 +326,7 @@ var _ = Describe("Session", func() {
It("ignores the error when the stream is not known", func() {
str, err := sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
str.(*MockStreamI).EXPECT().finished().Return(true)
sess.streamsMap.DeleteClosedStreams()
sess.onStreamCompleted(3)
str, err = sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
@ -411,9 +407,7 @@ var _ = Describe("Session", func() {
It("ignores MAX_STREAM_DATA frames for a closed stream", func() {
str, err := sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
str.(*MockStreamI).EXPECT().finished().Return(true)
err = sess.streamsMap.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
sess.onStreamCompleted(3)
str, err = sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
@ -457,8 +451,7 @@ var _ = Describe("Session", func() {
It("ignores STOP_SENDING frames for a closed stream", func() {
str, err := sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
str.(*MockStreamI).EXPECT().finished().Return(true)
err = sess.streamsMap.DeleteClosedStreams()
sess.onStreamCompleted(3)
Expect(err).ToNot(HaveOccurred())
str, err = sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
@ -1403,12 +1396,7 @@ var _ = Describe("Session", func() {
It("returns a nil-value (not an interface with value nil) for closed streams", func() {
str, err := sess.GetOrOpenStream(9)
Expect(err).ToNot(HaveOccurred())
str.Close()
str.(*stream).closeForShutdown(nil)
Expect(str.(*stream).finished()).To(BeTrue())
err = sess.streamsMap.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil())
sess.onStreamCompleted(9)
str, err = sess.GetOrOpenStream(9)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
@ -1425,31 +1413,6 @@ var _ = Describe("Session", func() {
})
})
Context("counting streams", func() {
It("errors when too many streams are opened", func() {
for i := 0; i < protocol.MaxIncomingStreams; i++ {
_, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred())
}
_, err := sess.GetOrOpenStream(protocol.StreamID(301))
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("does not error when many streams are opened and closed", func() {
for i := 2; i <= 1000; i++ {
s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred())
Expect(s.Close()).To(Succeed())
f, _ := s.(*stream).popStreamFrame(1000) // trigger "sending" of the FIN bit
Expect(f.FinBit).To(BeTrue())
s.(*stream).CloseRemote(0)
_, err = s.Read([]byte("a"))
Expect(err).To(MatchError(io.EOF))
sess.streamsMap.DeleteClosedStreams()
}
})
})
Context("ignoring errors", func() {
It("ignores duplicate acks", func() {
sess.sentPacketHandler.SentPacket(&ackhandler.Packet{
@ -1522,7 +1485,7 @@ var _ = Describe("Client Session", func() {
)
sess = sessP.(*session)
Expect(err).ToNot(HaveOccurred())
Expect(sess.streamsMap.openStreams).To(BeEmpty())
Expect(sess.streamsMap.streams).To(BeEmpty())
})
AfterEach(func() {

View file

@ -2,6 +2,7 @@ package quic
import (
"net"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
@ -19,8 +20,34 @@ type streamSender interface {
queueControlFrame(wire.Frame)
onHasWindowUpdate(protocol.StreamID)
onHasStreamData(protocol.StreamID)
onStreamCompleted(protocol.StreamID)
}
// Each of the both stream halves gets its own uniStreamSender.
// This is necessary in order to keep track when both halves have been completed.
type uniStreamSender struct {
streamSender
onStreamCompletedImpl func()
}
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
s.streamSender.queueControlFrame(f)
}
func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) {
s.streamSender.onHasWindowUpdate(id)
}
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
s.streamSender.onHasStreamData(id)
}
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
s.onStreamCompletedImpl()
}
var _ streamSender = &uniStreamSender{}
type streamI interface {
Stream
@ -28,7 +55,6 @@ type streamI interface {
handleRstStreamFrame(*wire.RstStreamFrame) error
handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
finished() bool
closeForShutdown(error)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
@ -42,6 +68,11 @@ type stream struct {
receiveStream
sendStream
completedMutex sync.Mutex
sender streamSender
receiveStreamCompleted bool
sendStreamCompleted bool
version protocol.VersionNumber
}
@ -72,10 +103,28 @@ func newStream(streamID protocol.StreamID,
flowController flowcontrol.StreamFlowController,
version protocol.VersionNumber,
) *stream {
return &stream{
sendStream: *newSendStream(streamID, sender, flowController, version),
receiveStream: *newReceiveStream(streamID, sender, flowController),
s := &stream{sender: sender}
senderForSendStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {
s.completedMutex.Lock()
s.sendStreamCompleted = true
s.checkIfCompleted()
s.completedMutex.Unlock()
},
}
s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version)
senderForReceiveStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {
s.completedMutex.Lock()
s.receiveStreamCompleted = true
s.checkIfCompleted()
s.completedMutex.Unlock()
},
}
s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController)
return s
}
// need to define StreamID() here, since both receiveStream and readStream have a StreamID()
@ -120,6 +169,10 @@ func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
return nil
}
func (s *stream) finished() bool {
return s.sendStream.finished() && s.receiveStream.finished()
// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed.
// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed.
func (s *stream) checkIfCompleted() {
if s.sendStreamCompleted && s.receiveStreamCompleted {
s.sender.onStreamCompleted(s.StreamID())
}
}

View file

@ -1,7 +1,6 @@
package quic
import (
"errors"
"io"
"os"
"strconv"
@ -72,6 +71,7 @@ var _ = Describe("Stream", func() {
ByteOffset: 1000,
ErrorCode: errorCodeStoppingGQUIC,
})
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
str.writeOffset = 1000
f := &wire.RstStreamFrame{
@ -189,26 +189,21 @@ var _ = Describe("Stream", func() {
})
})
Context("saying if it is finished", func() {
It("is finished when both the send and the receive side are finished", func() {
str.receiveStream.closeForShutdown(errors.New("shutdown"))
Expect(str.receiveStream.finished()).To(BeTrue())
Expect(str.sendStream.finished()).To(BeFalse())
Expect(str.finished()).To(BeFalse())
Context("completing", func() {
It("is not completed when only the receive side is completed", func() {
// don't EXPECT a call to mockSender.onStreamCompleted()
str.receiveStream.sender.onStreamCompleted(streamID)
})
It("is not finished when the receive side is finished", func() {
str.sendStream.closeForShutdown(errors.New("shutdown"))
Expect(str.receiveStream.finished()).To(BeFalse())
Expect(str.sendStream.finished()).To(BeTrue())
Expect(str.finished()).To(BeFalse())
It("is not completed when only the send side is completed", func() {
// don't EXPECT a call to mockSender.onStreamCompleted()
str.sendStream.sender.onStreamCompleted(streamID)
})
It("is not finished when the send side is finished", func() {
str.closeForShutdown(errors.New("shutdown"))
Expect(str.receiveStream.finished()).To(BeTrue())
Expect(str.sendStream.finished()).To(BeTrue())
Expect(str.finished()).To(BeTrue())
It("is completed when both sides are completed", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
str.sendStream.sender.onStreamCompleted(streamID)
str.receiveStream.sender.onStreamCompleted(streamID)
})
})
})

View file

@ -16,9 +16,6 @@ type streamsMap struct {
perspective protocol.Perspective
streams map[protocol.StreamID]streamI
// needed for round-robin scheduling
openStreams []protocol.StreamID
roundRobinIndex int
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
@ -51,7 +48,6 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro
sm := streamsMap{
perspective: pers,
streams: make(map[protocol.StreamID]streamI),
openStreams: make([]protocol.StreamID, 0),
newStream: newStream,
maxIncomingStreams: maxIncomingStreams,
}
@ -99,7 +95,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil // s may be nil
return s, nil
}
// ... we don't have an existing stream
@ -212,48 +208,19 @@ func (m *streamsMap) AcceptStream() (streamI, error) {
return str, nil
}
func (m *streamsMap) DeleteClosedStreams() error {
func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
var numDeletedStreams int
// for every closed stream, the streamID is replaced by 0 in the openStreams slice
for i, streamID := range m.openStreams {
str, ok := m.streams[streamID]
if !ok {
return errMapAccess
}
if !str.finished() {
continue
}
numDeletedStreams++
m.openStreams[i] = 0
if m.streamInitiatedBy(streamID) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
delete(m.streams, streamID)
_, ok := m.streams[id]
if !ok {
return errMapAccess
}
if numDeletedStreams == 0 {
return nil
delete(m.streams, id)
if m.streamInitiatedBy(id) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
// remove all 0s (representing closed streams) from the openStreams slice
// and adjust the roundRobinIndex
var j int
for i, id := range m.openStreams {
if i != j {
m.openStreams[j] = m.openStreams[i]
}
if id != 0 {
j++
} else if j < m.roundRobinIndex {
m.roundRobinIndex--
}
}
m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams]
m.openStreamOrErrCond.Signal()
return nil
}
@ -264,28 +231,16 @@ func (m *streamsMap) Range(cb func(s streamI)) {
defer m.mutex.RUnlock()
for _, s := range m.streams {
if s != nil {
cb(s)
}
cb(s)
}
}
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
str, ok := m.streams[streamID]
if !ok {
return true, errMapAccess
}
return fn(str)
}
func (m *streamsMap) putStream(s streamI) error {
id := s.StreamID()
if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id)
}
m.streams[id] = s
m.openStreams = append(m.openStreams, id)
return nil
}
@ -295,8 +250,8 @@ func (m *streamsMap) CloseWithError(err error) {
m.closeErr = err
m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast()
for _, s := range m.openStreams {
m.streams[s].closeForShutdown(err)
for _, s := range m.streams {
s.closeForShutdown(err)
}
}

View file

@ -7,22 +7,16 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Streams Map", func() {
var (
m *streamsMap
finishedStreams map[protocol.StreamID]*gomock.Call
)
var m *streamsMap
newStream := func(id protocol.StreamID) streamI {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(id).AnyTimes()
c := str.EXPECT().finished().Return(false).AnyTimes()
finishedStreams[id] = c
return str
}
@ -30,20 +24,8 @@ var _ = Describe("Streams Map", func() {
m = newStreamsMap(newStream, p, v)
}
BeforeEach(func() {
finishedStreams = make(map[protocol.StreamID]*gomock.Call)
})
AfterEach(func() {
Expect(m.openStreams).To(HaveLen(len(m.streams)))
})
deleteStream := func(id protocol.StreamID) {
str := m.streams[id]
Expect(str).ToNot(BeNil())
finishedStreams[id].Return(true)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed())
}
Context("getting and creating streams", func() {
@ -521,176 +503,63 @@ var _ = Describe("Streams Map", func() {
})
})
Context("DoS mitigation, iterating and deleting", func() {
Context("Ranging", func() {
It("ranges over all open streams", func() {
setNewStreamsMap(protocol.PerspectiveServer, protocol.VersionWhatever)
var callbackCalledForStream []protocol.StreamID
callback := func(str streamI) {
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
sort.Slice(callbackCalledForStream, func(i, j int) bool {
return callbackCalledForStream[i] < callbackCalledForStream[j]
})
}
Expect(m.streams).To(BeEmpty())
// create 5 streams, ids 4 to 8
callbackCalledForStream = callbackCalledForStream[:0]
for i := 4; i <= 8; i++ {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
err := m.putStream(str)
Expect(err).NotTo(HaveOccurred())
}
// execute the callback for all streams
m.Range(callback)
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
})
})
Context("deleting streams", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
})
closeStream := func(id protocol.StreamID) {
str := m.streams[id]
ExpectWithOffset(1, str).ToNot(BeNil())
finishedStreams[id].Return(true)
}
Context("deleting streams", func() {
Context("as a server", func() {
BeforeEach(func() {
m.UpdateMaxStreamLimit(100)
for i := 1; i <= 5; i++ {
if i%2 == 1 {
_, err := m.openRemoteStream(protocol.StreamID(i))
Expect(err).ToNot(HaveOccurred())
} else {
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
}
}
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) // 2 and 4
Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) // 1, 3 and 5
})
It("does not delete streams with Close()", func() {
str, err := m.GetOrOpenStream(55)
Expect(err).ToNot(HaveOccurred())
str.(*MockStreamI).EXPECT().Close()
str.Close()
err = m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(55)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
})
It("removes the first stream", func() {
closeStream(1)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a stream in the middle", func() {
closeStream(3)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a client-initiated stream", func() {
closeStream(2)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 3, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
Expect(m.numIncomingStreams).To(BeEquivalentTo(3))
})
It("removes a stream at the end", func() {
closeStream(5)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes all streams", func() {
for i := 1; i <= 5; i++ {
closeStream(protocol.StreamID(i))
}
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(BeEmpty())
Expect(m.openStreams).To(BeEmpty())
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeZero())
})
})
Context("as a client", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames)
m.UpdateMaxStreamLimit(100)
for i := 1; i <= 5; i++ {
if i%2 == 0 {
_, err := m.openRemoteStream(protocol.StreamID(i))
Expect(err).ToNot(HaveOccurred())
} else {
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
}
}
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 2, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) // 3, 5 and 7
Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) // 2 and 4
})
It("removes a stream that we initiated", func() {
closeStream(3)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a stream that the server initiated", func() {
closeStream(2)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
It("removes all streams", func() {
closeStream(3)
closeStream(2)
closeStream(5)
closeStream(4)
closeStream(7)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(BeEmpty())
Expect(m.openStreams).To(BeEmpty())
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeZero())
})
})
It("deletes an incoming stream", func() {
_, err := m.GetOrOpenStream(5) // open stream 3 and 5
Expect(err).ToNot(HaveOccurred())
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
err = m.DeleteStream(3)
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(1))
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
Context("Ranging", func() {
// create 5 streams, ids 4 to 8
var callbackCalledForStream []protocol.StreamID
callback := func(str streamI) {
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
sort.Slice(callbackCalledForStream, func(i, j int) bool { return callbackCalledForStream[i] < callbackCalledForStream[j] })
}
It("deletes an outgoing stream", func() {
m.UpdateMaxStreamLimit(10000)
_, err := m.OpenStream() // open stream 2
Expect(err).ToNot(HaveOccurred())
_, err = m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
err = m.DeleteStream(2)
Expect(err).ToNot(HaveOccurred())
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
})
BeforeEach(func() {
callbackCalledForStream = callbackCalledForStream[:0]
for i := 4; i <= 8; i++ {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
err := m.putStream(str)
Expect(err).NotTo(HaveOccurred())
}
})
It("ranges over all open streams", func() {
m.Range(callback)
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
})
It("errors when the stream doesn't exist", func() {
err := m.DeleteStream(1337)
Expect(err).To(MatchError(errMapAccess))
})
})
})