diff --git a/crypto_stream.go b/crypto_stream.go index 8639e971..8b8e8bee 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -27,8 +27,8 @@ type cryptoStream struct { var _ cryptoStreamI = &cryptoStream{} -func newCryptoStream(onData func(), flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI { - str := newStream(version.CryptoStreamID(), onData, nil, flowController, version) +func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI { + str := newStream(version.CryptoStreamID(), sender, flowController, version) return &cryptoStream{str} } diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 61ef4b04..2debfe05 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -7,10 +7,16 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Stream", func() { - var str *cryptoStream +var _ = Describe("Crypto Stream", func() { + var ( + str *cryptoStream + mockSender *MockStreamSender + ) - str = newCryptoStream(func() {}, nil, protocol.VersionWhatever).(*cryptoStream) + BeforeEach(func() { + mockSender = NewMockStreamSender(mockCtrl) + str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStream) + }) It("sets the read offset", func() { str.setReadOffset(0x42) @@ -19,6 +25,7 @@ var _ = Describe("Stream", func() { }) It("says if it has data for writing", func() { + mockSender.EXPECT().scheduleSending() Expect(str.hasDataForWriting()).To(BeFalse()) done := make(chan struct{}) go func() { diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go new file mode 100644 index 00000000..604fb925 --- /dev/null +++ b/mock_stream_sender_test.go @@ -0,0 +1,53 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: StreamSender) + +package quic + +import ( + gomock "github.com/golang/mock/gomock" + wire "github.com/lucas-clemente/quic-go/internal/wire" + reflect "reflect" +) + +// MockStreamSender is a mock of StreamSender interface +type MockStreamSender struct { + ctrl *gomock.Controller + recorder *MockStreamSenderMockRecorder +} + +// MockStreamSenderMockRecorder is the mock recorder for MockStreamSender +type MockStreamSenderMockRecorder struct { + mock *MockStreamSender +} + +// NewMockStreamSender creates a new mock instance +func NewMockStreamSender(ctrl *gomock.Controller) *MockStreamSender { + mock := &MockStreamSender{ctrl: ctrl} + mock.recorder = &MockStreamSenderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { + return _m.recorder +} + +// queueControlFrame mocks base method +func (_m *MockStreamSender) queueControlFrame(_param0 wire.Frame) { + _m.ctrl.Call(_m, "queueControlFrame", _param0) +} + +// queueControlFrame indicates an expected call of queueControlFrame +func (_mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0) +} + +// scheduleSending mocks base method +func (_m *MockStreamSender) scheduleSending() { + _m.ctrl.Call(_m, "scheduleSending") +} + +// scheduleSending indicates an expected call of scheduleSending +func (_mr *MockStreamSenderMockRecorder) scheduleSending() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "scheduleSending", reflect.TypeOf((*MockStreamSender)(nil).scheduleSending)) +} diff --git a/mockgen.go b/mockgen.go index 0aabefba..21997ace 100644 --- a/mockgen.go +++ b/mockgen.go @@ -1,3 +1,4 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender" diff --git a/packet_packer_test.go b/packet_packer_test.go index f1f5023d..a75b2407 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -61,7 +61,9 @@ var _ = Describe("Packet packer", func() { BeforeEach(func() { version := versionGQUICFrames - cryptoStream = newCryptoStream(func() {}, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version) + mockSender := NewMockStreamSender(mockCtrl) + mockSender.EXPECT().scheduleSending().AnyTimes() + cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version) streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) streamFramer = newStreamFramer(cryptoStream, streamsMap, versionGQUICFrames) diff --git a/receive_stream.go b/receive_stream.go index 0bf1b68a..791184d7 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -17,11 +17,7 @@ type receiveStream struct { streamID protocol.StreamID - // onData tells the session that there's stuff to pack into a new packet - onData func() - // queueControlFrame queues a new control frame for sending - // it does not call onData - queueControlFrame func(wire.Frame) + sender streamSender frameQueue *streamFrameSorter readPosInFrame int @@ -47,17 +43,15 @@ var _ ReceiveStream = &receiveStream{} func newReceiveStream( streamID protocol.StreamID, - onData func(), - queueControlFrame func(wire.Frame), + sender streamSender, flowController flowcontrol.StreamFlowController, ) *receiveStream { return &receiveStream{ - streamID: streamID, - onData: onData, - queueControlFrame: queueControlFrame, - flowController: flowController, - frameQueue: newStreamFrameSorter(), - readChan: make(chan struct{}, 1), + streamID: streamID, + sender: sender, + flowController: flowController, + frameQueue: newStreamFrameSorter(), + readChan: make(chan struct{}, 1), } } @@ -145,7 +139,7 @@ func (s *receiveStream) Read(p []byte) (int, error) { if !s.resetRemotely { s.flowController.AddBytesRead(protocol.ByteCount(m)) } - s.onData() // so that a possible WINDOW_UPDATE is sent + s.sender.scheduleSending() // so that a possible WINDOW_UPDATE is sent if s.readPosInFrame >= int(frame.DataLen()) { s.frameQueue.Pop() @@ -172,7 +166,7 @@ func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) erro s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) s.signalRead() if s.version.UsesIETFFrameFormat() { - s.queueControlFrame(&wire.StopSendingFrame{ + s.sender.queueControlFrame(&wire.StopSendingFrame{ StreamID: s.streamID, ErrorCode: errorCode, }) @@ -231,12 +225,11 @@ func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { func (s *receiveStream) onClose(offset protocol.ByteCount) { if s.canceledRead && !s.version.UsesIETFFrameFormat() { - s.queueControlFrame(&wire.RstStreamFrame{ + s.sender.queueControlFrame(&wire.RstStreamFrame{ StreamID: s.streamID, ByteOffset: offset, ErrorCode: 0, }) - s.onData() } } diff --git a/receive_stream_test.go b/receive_stream_test.go index 1b93bf48..73d70724 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -6,6 +6,7 @@ import ( "runtime" "time" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -19,21 +20,16 @@ var _ = Describe("Receive Stream", func() { const streamID protocol.StreamID = 1337 var ( - str *receiveStream - strWithTimeout io.Reader // str wrapped with gbytes.TimeoutReader - onDataCalled bool - queuedControlFrames []wire.Frame - mockFC *mocks.MockStreamFlowController + str *receiveStream + strWithTimeout io.Reader // str wrapped with gbytes.TimeoutReader + mockFC *mocks.MockStreamFlowController + mockSender *MockStreamSender ) - onData := func() { onDataCalled = true } - queueControlFrame := func(f wire.Frame) { queuedControlFrames = append(queuedControlFrames, f) } - BeforeEach(func() { - queuedControlFrames = queuedControlFrames[:0] - onDataCalled = false + mockSender = NewMockStreamSender(mockCtrl) mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newReceiveStream(streamID, onData, queueControlFrame, mockFC) + str = newReceiveStream(streamID, mockSender, mockFC) timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = gbytes.TimeoutReader(str, timeout) @@ -45,6 +41,7 @@ var _ = Describe("Receive Stream", func() { Context("reading", func() { It("reads a single STREAM frame", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) frame := wire.StreamFrame{ @@ -61,6 +58,7 @@ var _ = Describe("Receive Stream", func() { }) It("reads a single STREAM frame in multiple goes", func() { + mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) @@ -82,6 +80,7 @@ var _ = Describe("Receive Stream", func() { }) It("reads all data available", func() { + mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) @@ -104,7 +103,8 @@ var _ = Describe("Receive Stream", func() { Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00})) }) - It("assembles multiple StreamFrames", func() { + It("assembles multiple STREAM frames", func() { + mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) @@ -128,6 +128,7 @@ var _ = Describe("Receive Stream", func() { }) It("waits until data is available", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) go func() { @@ -144,6 +145,7 @@ var _ = Describe("Receive Stream", func() { }) It("handles STREAM frames in wrong order", func() { + mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) @@ -167,6 +169,7 @@ var _ = Describe("Receive Stream", func() { }) It("ignores duplicate STREAM frames", func() { + mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) @@ -197,6 +200,7 @@ var _ = Describe("Receive Stream", func() { }) It("doesn't rejects a STREAM frames with an overlapping data range", func() { + mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) @@ -220,21 +224,6 @@ var _ = Describe("Receive Stream", func() { Expect(b).To(Equal([]byte("foobar"))) }) - It("calls onData", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - frame := wire.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - err := str.handleStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - _, err = strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(onDataCalled).To(BeTrue()) - }) - It("passes on errors from the streamFrameSorter", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), false) err := str.handleStreamFrame(&wire.StreamFrame{StreamID: streamID}) // STREAM frame without data @@ -311,6 +300,7 @@ var _ = Describe("Receive Stream", func() { Context("closing", func() { Context("with FIN bit", func() { It("returns EOFs", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) frame := wire.StreamFrame{ @@ -330,6 +320,7 @@ var _ = Describe("Receive Stream", func() { }) It("handles out-of-order frames", func() { + mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) @@ -357,6 +348,7 @@ var _ = Describe("Receive Stream", func() { }) It("returns EOFs with partial read", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) frame := wire.StreamFrame{ @@ -374,6 +366,7 @@ var _ = Describe("Receive Stream", func() { }) It("handles immediate FINs", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) frame := wire.StreamFrame{ @@ -391,6 +384,7 @@ var _ = Describe("Receive Stream", func() { }) It("closes when CloseRemote is called", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) str.CloseRemote(0) @@ -432,6 +426,7 @@ var _ = Describe("Receive Stream", func() { Context("stream cancelations", func() { Context("canceling read", func() { It("unblocks Read", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -446,6 +441,7 @@ var _ = Describe("Receive Stream", func() { }) It("doesn't allow further calls to Read", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) err := str.CancelRead(1234) Expect(err).ToNot(HaveOccurred()) _, err = strWithTimeout.Read([]byte{0}) @@ -453,6 +449,7 @@ var _ = Describe("Receive Stream", func() { }) It("does nothing when CancelRead is called twice", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) err := str.CancelRead(1234) Expect(err).ToNot(HaveOccurred()) err = str.CancelRead(2345) @@ -462,8 +459,10 @@ var _ = Describe("Receive Stream", func() { }) It("doesn't send a RST_STREAM frame, if the FIN was already read", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + // no calls to mockSender.queueControlFrame err := str.handleStreamFrame(&wire.StreamFrame{ StreamID: streamID, Data: []byte("foobar"), @@ -474,20 +473,19 @@ var _ = Describe("Receive Stream", func() { Expect(err).To(MatchError(io.EOF)) err = str.CancelRead(1234) Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(BeEmpty()) // no RST_STREAM frame queued yet }) - Context("for IETF QUIC", func() { - It("queues a STOP_SENDING frame", func() { - err := str.CancelRead(1234) - Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(Equal([]wire.Frame{ - &wire.StopSendingFrame{ - StreamID: streamID, - ErrorCode: 1234, - }, - })) + It("queues a STOP_SENDING frame, for IETF QUIC", func() { + mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 1234, }) + err := str.CancelRead(1234) + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't queue a STOP_SENDING frame, for gQUIC", func() { + }) }) @@ -570,7 +568,8 @@ var _ = Describe("Receive Stream", func() { Eventually(readReturned).Should(BeClosed()) }) - It("sends a RST_STREAM and continues reading until the end when receiving a RST_STREAM frame with error code 0", func() { + It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() { + mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) @@ -600,7 +599,6 @@ var _ = Describe("Receive Stream", func() { Expect(err).ToNot(HaveOccurred()) Eventually(readReturned).Should(BeClosed()) }) - }) }) }) @@ -639,6 +637,7 @@ var _ = Describe("Receive Stream", func() { }) It("is finished if it is only closed for reading", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) finishReading() Expect(str.finished()).To(BeTrue()) @@ -647,6 +646,7 @@ var _ = Describe("Receive Stream", func() { // the stream still needs to stay alive until we receive the final offset // (either by receiving a STREAM frame with FIN, or a RST_STREAM) It("is not finished after CancelRead", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) err := str.CancelRead(123) Expect(err).ToNot(HaveOccurred()) Expect(str.finished()).To(BeFalse()) diff --git a/send_stream.go b/send_stream.go index d82dee14..fc1044cd 100644 --- a/send_stream.go +++ b/send_stream.go @@ -19,11 +19,7 @@ type sendStream struct { ctxCancel context.CancelFunc streamID protocol.StreamID - // onData tells the session that there's stuff to pack into a new packet - onData func() - // queueControlFrame queues a new control frame for sending - // it does not call onData - queueControlFrame func(wire.Frame) + sender streamSender writeOffset protocol.ByteCount @@ -47,18 +43,16 @@ var _ SendStream = &sendStream{} func newSendStream( streamID protocol.StreamID, - onData func(), - queueControlFrame func(wire.Frame), + sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber, ) *sendStream { s := &sendStream{ - streamID: streamID, - onData: onData, - queueControlFrame: queueControlFrame, - flowController: flowController, - writeChan: make(chan struct{}, 1), - version: version, + streamID: streamID, + sender: sender, + flowController: flowController, + writeChan: make(chan struct{}, 1), + version: version, } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) return s @@ -90,7 +84,7 @@ func (s *sendStream) Write(p []byte) (int, error) { s.dataForWriting = make([]byte, len(p)) copy(s.dataForWriting, p) - s.onData() + s.sender.scheduleSending() var bytesWritten int var err error @@ -153,11 +147,10 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFra s.finSent = true } else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked { - s.queueControlFrame(&wire.StreamBlockedFrame{ + s.sender.queueControlFrame(&wire.StreamBlockedFrame{ StreamID: s.streamID, Offset: offset, }) - s.onData() } } return frame @@ -199,7 +192,6 @@ func (s *sendStream) Close() error { } s.finishedWriting = true s.ctxCancel() - s.onData() return nil } @@ -221,13 +213,12 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr s.canceledWrite = true s.cancelWriteErr = writeErr s.signalWrite() - s.queueControlFrame(&wire.RstStreamFrame{ + s.sender.queueControlFrame(&wire.RstStreamFrame{ StreamID: s.streamID, ByteOffset: s.writeOffset, ErrorCode: errorCode, }) // TODO(#991): cancel retransmissions for this stream - s.onData() s.ctxCancel() return nil } diff --git a/send_stream_test.go b/send_stream_test.go index f8b77278..608c9e00 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -21,21 +21,16 @@ var _ = Describe("Send Stream", func() { const streamID protocol.StreamID = 1337 var ( - str *sendStream - strWithTimeout io.Writer // str wrapped with gbytes.TimeoutWriter - onDataCalled bool - queuedControlFrames []wire.Frame - mockFC *mocks.MockStreamFlowController + str *sendStream + strWithTimeout io.Writer // str wrapped with gbytes.TimeoutWriter + mockFC *mocks.MockStreamFlowController + mockSender *MockStreamSender ) - onData := func() { onDataCalled = true } - queueControlFrame := func(f wire.Frame) { queuedControlFrames = append(queuedControlFrames, f) } - BeforeEach(func() { - queuedControlFrames = queuedControlFrames[:0] - onDataCalled = false + mockSender = NewMockStreamSender(mockCtrl) mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newSendStream(streamID, onData, queueControlFrame, mockFC, protocol.VersionWhatever) + str = newSendStream(streamID, mockSender, mockFC, protocol.VersionWhatever) timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = gbytes.TimeoutWriter(str, timeout) @@ -47,6 +42,7 @@ var _ = Describe("Send Stream", func() { Context("writing", func() { It("writes and gets all data at once", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) mockFC.EXPECT().IsNewlyBlocked() @@ -64,7 +60,6 @@ var _ = Describe("Send Stream", func() { f = str.popStreamFrame(1000) return f }).ShouldNot(BeNil()) - Expect(onDataCalled).To(BeTrue()) Expect(f.Data).To(Equal([]byte("foobar"))) Expect(f.FinBit).To(BeFalse()) Expect(f.Offset).To(BeZero()) @@ -75,6 +70,7 @@ var _ = Describe("Send Stream", func() { }) It("writes and gets data in two turns", func() { + mockSender.EXPECT().scheduleSending() frameHeaderLen := protocol.ByteCount(4) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) mockFC.EXPECT().AddBytesSent(gomock.Any() /* protocol.ByteCount(3)*/).Times(2) @@ -111,6 +107,7 @@ var _ = Describe("Send Stream", func() { }) It("copies the slice while writing", func() { + mockSender.EXPECT().scheduleSending() frameHeaderSize := protocol.ByteCount(4) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) @@ -152,6 +149,11 @@ var _ = Describe("Send Stream", func() { Context("adding BLOCKED", func() { It("queues a BLOCKED frame if the stream is flow control blocked", func() { + mockSender.EXPECT().scheduleSending() + mockSender.EXPECT().queueControlFrame(&wire.StreamBlockedFrame{ + StreamID: streamID, + Offset: 10, + }) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) // don't use offset 6 here, to make sure the BLOCKED frame contains the number returned by the flow controller @@ -168,20 +170,15 @@ var _ = Describe("Send Stream", func() { f = str.popStreamFrame(1000) return f }).ShouldNot(BeNil()) - Expect(queuedControlFrames).To(Equal([]wire.Frame{ - &wire.StreamBlockedFrame{ - StreamID: streamID, - Offset: 10, - }, - })) - Expect(onDataCalled).To(BeTrue()) Eventually(done).Should(BeClosed()) }) It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - // don't EXPECT a call to IsNewlyBlocked + // don't EXPECT a call to mockFC.IsNewlyBlocked + // don't EXPECT a call to mockSender.queueControlFrame done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -197,7 +194,6 @@ var _ = Describe("Send Stream", func() { return f }).ShouldNot(BeNil()) Expect(f.FinBit).To(BeTrue()) - Expect(queuedControlFrames).To(BeEmpty()) Eventually(done).Should(BeClosed()) }) }) @@ -208,10 +204,10 @@ var _ = Describe("Send Stream", func() { n, err := strWithTimeout.Write([]byte("foobar")) Expect(err).To(MatchError(errDeadline)) Expect(n).To(BeZero()) - Expect(onDataCalled).To(BeFalse()) }) It("unblocks after the deadline", func() { + mockSender.EXPECT().scheduleSending() deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) str.SetWriteDeadline(deadline) n, err := strWithTimeout.Write([]byte("foobar")) @@ -221,6 +217,7 @@ var _ = Describe("Send Stream", func() { }) It("returns the number of bytes written, when the deadline expires", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes() mockFC.EXPECT().AddBytesSent(gomock.Any()) mockFC.EXPECT().IsNewlyBlocked() @@ -246,6 +243,7 @@ var _ = Describe("Send Stream", func() { }) It("doesn't pop any data after the deadline expired", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes() mockFC.EXPECT().AddBytesSent(gomock.Any()) mockFC.EXPECT().IsNewlyBlocked() @@ -268,6 +266,7 @@ var _ = Describe("Send Stream", func() { }) It("doesn't unblock if the deadline is changed before the first one expires", func() { + mockSender.EXPECT().scheduleSending() deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) str.SetWriteDeadline(deadline1) @@ -286,6 +285,7 @@ var _ = Describe("Send Stream", func() { }) It("unblocks earlier, when a new deadline is set", func() { + mockSender.EXPECT().scheduleSending() deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) go func() { @@ -361,6 +361,7 @@ var _ = Describe("Send Stream", func() { }) It("doesn't get data for writing if an error occurred", func() { + mockSender.EXPECT().scheduleSending() mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) mockFC.EXPECT().AddBytesSent(gomock.Any()) mockFC.EXPECT().IsNewlyBlocked() @@ -388,19 +389,19 @@ var _ = Describe("Send Stream", func() { Context("stream cancelations", func() { Context("canceling writing", func() { It("queues a RST_STREAM frame", func() { + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 1234, + ErrorCode: 9876, + }) str.writeOffset = 1234 err := str.CancelWrite(9876) Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(Equal([]wire.Frame{ - &wire.RstStreamFrame{ - StreamID: streamID, - ByteOffset: 1234, - ErrorCode: 9876, - }, - })) }) It("unblocks Write", func() { + mockSender.EXPECT().scheduleSending() + mockSender.EXPECT().queueControlFrame(gomock.Any()) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) mockFC.EXPECT().IsNewlyBlocked() @@ -425,12 +426,14 @@ var _ = Describe("Send Stream", func() { }) It("cancels the context", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) Expect(str.Context().Done()).ToNot(BeClosed()) str.CancelWrite(1234) Expect(str.Context().Done()).To(BeClosed()) }) It("doesn't allow further calls to Write", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) err := str.CancelWrite(1234) Expect(err).ToNot(HaveOccurred()) _, err = strWithTimeout.Write([]byte("foobar")) @@ -438,12 +441,11 @@ var _ = Describe("Send Stream", func() { }) It("only cancels once", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) err := str.CancelWrite(1234) Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(HaveLen(1)) err = str.CancelWrite(4321) Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(HaveLen(1)) }) It("doesn't cancel when the stream was already closed", func() { @@ -456,19 +458,19 @@ var _ = Describe("Send Stream", func() { Context("receiving STOP_SENDING frames", func() { It("queues a RST_STREAM frames with error code Stopping", func() { + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ErrorCode: errorCodeStopping, + }) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 101, }) - Expect(queuedControlFrames).To(Equal([]wire.Frame{ - &wire.RstStreamFrame{ - StreamID: streamID, - ErrorCode: errorCodeStopping, - }, - })) }) It("unblocks Write", func() { + mockSender.EXPECT().scheduleSending() + mockSender.EXPECT().queueControlFrame(gomock.Any()) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -488,6 +490,7 @@ var _ = Describe("Send Stream", func() { }) It("doesn't allow further calls to Write", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 123, @@ -515,12 +518,14 @@ var _ = Describe("Send Stream", func() { }) It("is finished after CancelWrite", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) err := str.CancelWrite(123) Expect(err).ToNot(HaveOccurred()) Expect(str.finished()).To(BeTrue()) }) It("is finished after receiving a STOP_SENDING (and sending a RST_STREAM)", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID}) Expect(str.finished()).To(BeTrue()) }) diff --git a/session.go b/session.go index 9ece24b2..1dc2ad34 100644 --- a/session.go +++ b/session.go @@ -115,6 +115,7 @@ type session struct { } var _ Session = &session{} +var _ streamSender = &session{} // newSession makes a new session func newSession( @@ -892,7 +893,7 @@ func (s *session) newStream(id protocol.StreamID) streamI { initialSendWindow, s.rttStats, ) - return newStream(id, s.scheduleSending, s.packer.QueueControlFrame, flowController, s.version) + return newStream(id, s, flowController, s.version) } func (s *session) newCryptoStream() cryptoStreamI { @@ -906,7 +907,7 @@ func (s *session) newCryptoStream() cryptoStreamI { 0, s.rttStats, ) - return newCryptoStream(s.scheduleSending, flowController, s.version) + return newCryptoStream(s, flowController, s.version) } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { @@ -965,6 +966,11 @@ func (s *session) getWindowUpdates() []wire.Frame { return res } +func (s *session) queueControlFrame(f wire.Frame) { + s.packer.QueueControlFrame(f) + s.scheduleSending() +} + func (s *session) LocalAddr() net.Addr { return s.conn.LocalAddr() } diff --git a/stream.go b/stream.go index a0c1fe52..6182e2c0 100644 --- a/stream.go +++ b/stream.go @@ -14,6 +14,12 @@ const ( errorCodeStoppingGQUIC protocol.ApplicationErrorCode = 7 ) +// The streamSender is notified by the stream about various events. +type streamSender interface { + scheduleSending() + queueControlFrame(wire.Frame) +} + type streamI interface { Stream @@ -61,14 +67,13 @@ var _ StreamError = &streamCanceledError{} // newStream creates a new Stream func newStream(streamID protocol.StreamID, - onData func(), - queueControlFrame func(wire.Frame), + sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber, ) *stream { return &stream{ - sendStream: *newSendStream(streamID, onData, queueControlFrame, flowController, version), - receiveStream: *newReceiveStream(streamID, onData, queueControlFrame, flowController), + sendStream: *newSendStream(streamID, sender, flowController, version), + receiveStream: *newReceiveStream(streamID, sender, flowController), } } diff --git a/stream_test.go b/stream_test.go index dd41016f..0a5802e4 100644 --- a/stream_test.go +++ b/stream_test.go @@ -3,15 +3,13 @@ package quic import ( "errors" "io" + "os" "strconv" "time" - "os" - "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/onsi/gomega/gbytes" @@ -33,21 +31,16 @@ var _ = Describe("Stream", func() { const streamID protocol.StreamID = 1337 var ( - str *stream - strWithTimeout io.ReadWriter // str wrapped with gbytes.Timeout{Reader,Writer} - onDataCalled bool - queuedControlFrames []wire.Frame - mockFC *mocks.MockStreamFlowController + str *stream + strWithTimeout io.ReadWriter // str wrapped with gbytes.Timeout{Reader,Writer} + mockFC *mocks.MockStreamFlowController + mockSender *MockStreamSender ) - onData := func() { onDataCalled = true } - queueControlFrame := func(f wire.Frame) { queuedControlFrames = append(queuedControlFrames, f) } - BeforeEach(func() { - queuedControlFrames = queuedControlFrames[:0] - onDataCalled = false + mockSender = NewMockStreamSender(mockCtrl) mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newStream(streamID, onData, queueControlFrame, mockFC, protocol.VersionWhatever) + str = newStream(streamID, mockSender, mockFC, protocol.VersionWhatever) timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = struct { @@ -65,6 +58,10 @@ var _ = Describe("Stream", func() { // need some stream cancelation tests here, since gQUIC doesn't cleanly separate the two stream halves Context("stream cancelations", func() { + BeforeEach(func() { + mockSender.EXPECT().scheduleSending().AnyTimes() + }) + Context("for gQUIC", func() { BeforeEach(func() { str.version = versionGQUICFrames @@ -73,6 +70,11 @@ var _ = Describe("Stream", func() { }) It("unblocks Write when receiving a RST_STREAM frame with non-zero error code", func() { + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 1000, + ErrorCode: errorCodeStoppingGQUIC, + }) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) str.writeOffset = 1000 f := &wire.RstStreamFrame{ @@ -93,17 +95,15 @@ var _ = Describe("Stream", func() { Consistently(writeReturned).ShouldNot(BeClosed()) err := str.handleRstStreamFrame(f) Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(Equal([]wire.Frame{ - &wire.RstStreamFrame{ - StreamID: streamID, - ByteOffset: 1000, - ErrorCode: errorCodeStoppingGQUIC, - }, - })) Eventually(writeReturned).Should(BeClosed()) }) It("unblocks Write when receiving a RST_STREAM frame with error code 0", func() { + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 1000, + ErrorCode: errorCodeStoppingGQUIC, + }) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) str.writeOffset = 1000 f := &wire.RstStreamFrame{ @@ -124,13 +124,6 @@ var _ = Describe("Stream", func() { Consistently(writeReturned).ShouldNot(BeClosed()) err := str.handleRstStreamFrame(f) Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(Equal([]wire.Frame{ - &wire.RstStreamFrame{ - StreamID: streamID, - ByteOffset: 1000, - ErrorCode: errorCodeStoppingGQUIC, - }, - })) Eventually(writeReturned).Should(BeClosed()) }) @@ -141,7 +134,6 @@ var _ = Describe("Stream", func() { mockFC.EXPECT().IsNewlyBlocked() err := str.CancelRead(1234) Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(BeEmpty()) // no RST_STREAM frame queued yet writeReturned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -151,27 +143,25 @@ var _ = Describe("Stream", func() { }() Eventually(func() *wire.StreamFrame { return str.popStreamFrame(1000) }).ShouldNot(BeNil()) Eventually(writeReturned).Should(BeClosed()) - Expect(queuedControlFrames).To(BeEmpty()) // no RST_STREAM frame queued yet + mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ + StreamID: streamID, + ByteOffset: 6, + ErrorCode: 0, + }) err = str.Close() Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(Equal([]wire.Frame{ - &wire.RstStreamFrame{ - StreamID: streamID, - ByteOffset: 6, - ErrorCode: 0, - }, - })) }) }) Context("for IETF QUIC", func() { It("doesn't queue a RST_STREAM after closing the stream", func() { // this is what it does for gQUIC + mockSender.EXPECT().queueControlFrame(&wire.StopSendingFrame{ + StreamID: streamID, + ErrorCode: 1234, + }) err := str.CancelRead(1234) Expect(err).ToNot(HaveOccurred()) - Expect(queuedControlFrames).To(HaveLen(1)) - Expect(queuedControlFrames[0]).To(BeAssignableToTypeOf(&wire.StopSendingFrame{})) Expect(str.Close()).To(Succeed()) - Expect(queuedControlFrames).To(HaveLen(1)) }) }) })