diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index 604fb925..208d955d 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -5,6 +5,7 @@ package quic import ( gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" wire "github.com/lucas-clemente/quic-go/internal/wire" reflect "reflect" ) @@ -32,6 +33,16 @@ func (_m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { return _m.recorder } +// onHasWindowUpdate mocks base method +func (_m *MockStreamSender) onHasWindowUpdate(_param0 protocol.StreamID, _param1 protocol.ByteCount) { + _m.ctrl.Call(_m, "onHasWindowUpdate", _param0, _param1) +} + +// onHasWindowUpdate indicates an expected call of onHasWindowUpdate +func (_mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0, arg1) +} + // queueControlFrame mocks base method func (_m *MockStreamSender) queueControlFrame(_param0 wire.Frame) { _m.ctrl.Call(_m, "queueControlFrame", _param0) diff --git a/receive_stream.go b/receive_stream.go index 791184d7..6c9e0460 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -139,7 +139,10 @@ func (s *receiveStream) Read(p []byte) (int, error) { if !s.resetRemotely { s.flowController.AddBytesRead(protocol.ByteCount(m)) } - s.sender.scheduleSending() // so that a possible WINDOW_UPDATE is sent + // this call triggers the flow controller to increase the flow control window, if necessary + if offset := s.flowController.GetWindowUpdate(); offset != 0 { + s.sender.onHasWindowUpdate(s.streamID, offset) + } if s.readPosInFrame >= int(frame.DataLen()) { s.frameQueue.Pop() diff --git a/receive_stream_test.go b/receive_stream_test.go index 73d70724..058768e8 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -41,9 +41,9 @@ var _ = Describe("Receive Stream", func() { Context("reading", func() { It("reads a single STREAM frame", func() { - mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + mockFC.EXPECT().GetWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -58,10 +58,10 @@ var _ = Describe("Receive Stream", func() { }) It("reads a single STREAM frame in multiple goes", func() { - mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().GetWindowUpdate().Times(2) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -80,10 +80,10 @@ var _ = Describe("Receive Stream", func() { }) It("reads all data available", func() { - mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().GetWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -104,10 +104,10 @@ var _ = Describe("Receive Stream", func() { }) It("assembles multiple STREAM frames", func() { - mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().GetWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -128,9 +128,9 @@ var _ = Describe("Receive Stream", func() { }) It("waits until data is available", func() { - mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().GetWindowUpdate() go func() { defer GinkgoRecover() frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} @@ -145,10 +145,10 @@ var _ = Describe("Receive Stream", func() { }) It("handles STREAM frames in wrong order", func() { - mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().GetWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -169,11 +169,11 @@ var _ = Describe("Receive Stream", func() { }) It("ignores duplicate STREAM frames", func() { - mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().GetWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -200,11 +200,11 @@ var _ = Describe("Receive Stream", func() { }) It("doesn't rejects a STREAM frames with an overlapping data range", func() { - mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + mockFC.EXPECT().GetWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte("foob"), @@ -230,6 +230,22 @@ var _ = Describe("Receive Stream", func() { Expect(err).To(MatchError(errEmptyStreamData)) }) + It("calls the onHasWindowUpdate callback, when the a MAX_STREAM_DATA should be sent", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(1337)) + mockSender.EXPECT().onHasWindowUpdate(streamID, protocol.ByteCount(1337)) + frame1 := wire.StreamFrame{ + Offset: 0, + Data: []byte("foobar"), + } + err := str.handleStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 6) + _, err = strWithTimeout.Read(b) + Expect(err).ToNot(HaveOccurred()) + }) + Context("deadlines", func() { It("the deadline error has the right net.Error properties", func() { Expect(errDeadline.Temporary()).To(BeTrue()) @@ -300,9 +316,9 @@ var _ = Describe("Receive Stream", func() { Context("closing", func() { Context("with FIN bit", func() { It("returns EOFs", func() { - mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) + mockFC.EXPECT().GetWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -320,10 +336,10 @@ var _ = Describe("Receive Stream", func() { }) It("handles out-of-order frames", func() { - mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().GetWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -348,9 +364,9 @@ var _ = Describe("Receive Stream", func() { }) It("returns EOFs with partial read", func() { - mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().GetWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -366,9 +382,9 @@ var _ = Describe("Receive Stream", func() { }) It("handles immediate FINs", func() { - mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) + mockFC.EXPECT().GetWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{}, @@ -384,9 +400,9 @@ var _ = Describe("Receive Stream", func() { }) It("closes when CloseRemote is called", func() { - mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) + mockFC.EXPECT().GetWindowUpdate() str.CloseRemote(0) b := make([]byte, 8) n, err := strWithTimeout.Read(b) @@ -459,9 +475,9 @@ var _ = Describe("Receive Stream", func() { }) It("doesn't send a RST_STREAM frame, if the FIN was already read", func() { - mockSender.EXPECT().scheduleSending() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + mockFC.EXPECT().GetWindowUpdate() // no calls to mockSender.queueControlFrame err := str.handleStreamFrame(&wire.StreamFrame{ StreamID: streamID, @@ -569,10 +585,10 @@ var _ = Describe("Receive Stream", func() { }) It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() { - mockSender.EXPECT().scheduleSending().Times(2) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().GetWindowUpdate().Times(2) readReturned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -637,8 +653,8 @@ var _ = Describe("Receive Stream", func() { }) It("is finished if it is only closed for reading", func() { - mockSender.EXPECT().scheduleSending() mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) + mockFC.EXPECT().GetWindowUpdate() finishReading() Expect(str.finished()).To(BeTrue()) }) diff --git a/session.go b/session.go index 1dc2ad34..1cf5ad90 100644 --- a/session.go +++ b/session.go @@ -62,8 +62,8 @@ type session struct { sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler streamFramer *streamFramer - - connFlowController flowcontrol.ConnectionFlowController + windowUpdateQueue *windowUpdateQueue + connFlowController flowcontrol.ConnectionFlowController unpacker unpacker packer *packetPacker @@ -316,7 +316,6 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) - s.packer = newPacketPacker(s.connectionID, initialPacketNumber, s.cryptoSetup, @@ -324,6 +323,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.perspective, s.version, ) + s.windowUpdateQueue = newWindowUpdateQueue(s.packer.QueueControlFrame) s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} return nil } @@ -717,14 +717,13 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete func (s *session) sendPacket() error { s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) - // Get MAX_DATA and MAX_STREAM_DATA frames - // this call triggers the flow controller to increase the flow control windows, if necessary - for _, f := range s.getWindowUpdates() { - s.packer.QueueControlFrame(f) + if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { + s.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: offset}) } if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset}) } + s.windowUpdateQueue.QueueAll() ack := s.receivedPacketHandler.GetAckFrame() if ack != nil { @@ -948,29 +947,16 @@ func (s *session) tryDecryptingQueuedPackets() { s.undecryptablePackets = s.undecryptablePackets[:0] } -func (s *session) getWindowUpdates() []wire.Frame { - var res []wire.Frame - s.streamsMap.Range(func(str streamI) { - if offset := str.getWindowUpdate(); offset != 0 { - res = append(res, &wire.MaxStreamDataFrame{ - StreamID: str.StreamID(), - ByteOffset: offset, - }) - } - }) - if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { - res = append(res, &wire.MaxDataFrame{ - ByteOffset: offset, - }) - } - return res -} - func (s *session) queueControlFrame(f wire.Frame) { s.packer.QueueControlFrame(f) s.scheduleSending() } +func (s *session) onHasWindowUpdate(streamID protocol.StreamID, offset protocol.ByteCount) { + s.windowUpdateQueue.Add(streamID, offset) + s.scheduleSending() +} + func (s *session) LocalAddr() net.Addr { return s.conn.LocalAddr() } diff --git a/session_test.go b/session_test.go index acf9023a..2eeca88d 100644 --- a/session_test.go +++ b/session_test.go @@ -857,6 +857,45 @@ var _ = Describe("Session", func() { Expect(mconn.written).To(HaveLen(1)) }) + It("adds a MAX_DATA frames", func() { + fc := mocks.NewMockConnectionFlowController(mockCtrl) + fc.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337)) + fc.EXPECT().IsNewlyBlocked() + sess.connFlowController = fc + sph := mocks.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().SendingAllowed().Return(true) + sph.EXPECT().SendingAllowed() + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(Equal([]wire.Frame{ + &wire.MaxDataFrame{ByteOffset: 0x1337}, + })) + }) + sess.sentPacketHandler = sph + err := sess.sendPacket() + Expect(err).ToNot(HaveOccurred()) + }) + + It("adds MAX_STREAM_DATA frames", func() { + sess.windowUpdateQueue.Add(1, 10) + sess.windowUpdateQueue.Add(2, 20) + sph := mocks.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().SendingAllowed().Return(true) + sph.EXPECT().SendingAllowed() + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, ByteOffset: 10})) + Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) + }) + sess.sentPacketHandler = sph + err := sess.sendPacket() + Expect(err).ToNot(HaveOccurred()) + }) + It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().GetWindowUpdate() diff --git a/stream.go b/stream.go index 6182e2c0..9da1fb60 100644 --- a/stream.go +++ b/stream.go @@ -18,6 +18,7 @@ const ( type streamSender interface { scheduleSending() queueControlFrame(wire.Frame) + onHasWindowUpdate(protocol.StreamID, protocol.ByteCount) } type streamI interface { diff --git a/window_update_queue.go b/window_update_queue.go new file mode 100644 index 00000000..bed3d60c --- /dev/null +++ b/window_update_queue.go @@ -0,0 +1,40 @@ +package quic + +import ( + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type windowUpdateQueue struct { + mutex sync.Mutex + + queue map[protocol.StreamID]protocol.ByteCount + callback func(wire.Frame) +} + +func newWindowUpdateQueue(cb func(wire.Frame)) *windowUpdateQueue { + return &windowUpdateQueue{ + queue: make(map[protocol.StreamID]protocol.ByteCount), + callback: cb, + } +} + +func (q *windowUpdateQueue) Add(stream protocol.StreamID, offset protocol.ByteCount) { + q.mutex.Lock() + q.queue[stream] = offset + q.mutex.Unlock() +} + +func (q *windowUpdateQueue) QueueAll() { + q.mutex.Lock() + for stream, offset := range q.queue { + q.callback(&wire.MaxStreamDataFrame{ + StreamID: stream, + ByteOffset: offset, + }) + delete(q.queue, stream) + } + q.mutex.Unlock() +} diff --git a/window_update_queue_test.go b/window_update_queue_test.go new file mode 100644 index 00000000..ab488e16 --- /dev/null +++ b/window_update_queue_test.go @@ -0,0 +1,49 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Window Update Queue", func() { + var ( + q *windowUpdateQueue + queuedFrames []wire.Frame + ) + + BeforeEach(func() { + queuedFrames = queuedFrames[:0] + q = newWindowUpdateQueue(func(f wire.Frame) { + queuedFrames = append(queuedFrames, f) + }) + }) + + It("adds stream offsets and gets MAX_STREAM_DATA frames", func() { + q.Add(1, 10) + q.Add(2, 20) + q.Add(3, 30) + q.QueueAll() + Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, ByteOffset: 10})) + Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) + Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 3, ByteOffset: 30})) + }) + + It("deletes the entry after getting the MAX_STREAM_DATA frame", func() { + q.Add(10, 100) + q.QueueAll() + Expect(queuedFrames).To(HaveLen(1)) + q.QueueAll() + Expect(queuedFrames).To(HaveLen(1)) + }) + + It("replaces old entries", func() { + q.Add(10, 100) + q.Add(10, 200) + q.QueueAll() + Expect(queuedFrames).To(Equal([]wire.Frame{ + &wire.MaxStreamDataFrame{StreamID: 10, ByteOffset: 200}, + })) + }) +})