diff --git a/internal/flowcontrol/interface.go b/internal/flowcontrol/interface.go index 61d57e31..20297be7 100644 --- a/internal/flowcontrol/interface.go +++ b/internal/flowcontrol/interface.go @@ -21,8 +21,8 @@ type StreamFlowController interface { // UpdateHighestReceived should be called when a new highest offset is received // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame UpdateHighestReceived(offset protocol.ByteCount, final bool) error - // HasWindowUpdate says if it is necessary to update the window - HasWindowUpdate() bool + // MaybeQueueWindowUpdate queues a window update, if necessary + MaybeQueueWindowUpdate() } // The ConnectionFlowController is the flow controller for the connection. diff --git a/internal/flowcontrol/stream_flow_controller.go b/internal/flowcontrol/stream_flow_controller.go index 6501278c..aff47fc9 100644 --- a/internal/flowcontrol/stream_flow_controller.go +++ b/internal/flowcontrol/stream_flow_controller.go @@ -14,6 +14,8 @@ type streamFlowController struct { streamID protocol.StreamID + queueWindowUpdate func() + connection connectionFlowControllerI contributesToConnection bool // does the stream contribute to connection level flow control @@ -30,6 +32,7 @@ func NewStreamFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, initialSendWindow protocol.ByteCount, + queueWindowUpdate func(protocol.StreamID), rttStats *congestion.RTTStats, logger utils.Logger, ) StreamFlowController { @@ -37,6 +40,7 @@ func NewStreamFlowController( streamID: streamID, contributesToConnection: contributesToConnection, connection: cfc.(connectionFlowControllerI), + queueWindowUpdate: func() { queueWindowUpdate(streamID) }, baseFlowController: baseFlowController{ rttStats: rttStats, receiveWindow: receiveWindow, @@ -120,11 +124,13 @@ func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) { return true, c.sendWindow } -func (c *streamFlowController) HasWindowUpdate() bool { +func (c *streamFlowController) MaybeQueueWindowUpdate() { c.mutex.Lock() hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate() c.mutex.Unlock() - return hasWindowUpdate + if hasWindowUpdate { + c.queueWindowUpdate() + } } func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { diff --git a/internal/flowcontrol/stream_flow_controller_test.go b/internal/flowcontrol/stream_flow_controller_test.go index dfca659a..5707b331 100644 --- a/internal/flowcontrol/stream_flow_controller_test.go +++ b/internal/flowcontrol/stream_flow_controller_test.go @@ -12,9 +12,13 @@ import ( ) var _ = Describe("Stream Flow controller", func() { - var controller *streamFlowController + var ( + controller *streamFlowController + queuedWindowUpdate bool + ) BeforeEach(func() { + queuedWindowUpdate = false rttStats := &congestion.RTTStats{} controller = &streamFlowController{ streamID: 10, @@ -23,24 +27,38 @@ var _ = Describe("Stream Flow controller", func() { controller.maxReceiveWindowSize = 10000 controller.rttStats = rttStats controller.logger = utils.DefaultLogger + controller.queueWindowUpdate = func() { queuedWindowUpdate = true } }) Context("Constructor", func() { rttStats := &congestion.RTTStats{} + receiveWindow := protocol.ByteCount(2000) + maxReceiveWindow := protocol.ByteCount(3000) + sendWindow := protocol.ByteCount(4000) It("sets the send and receive windows", func() { - receiveWindow := protocol.ByteCount(2000) - maxReceiveWindow := protocol.ByteCount(3000) - sendWindow := protocol.ByteCount(4000) - cc := NewConnectionFlowController(0, 0, nil, utils.DefaultLogger) - fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats, utils.DefaultLogger).(*streamFlowController) + fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController) Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.receiveWindow).To(Equal(receiveWindow)) Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) Expect(fc.sendWindow).To(Equal(sendWindow)) Expect(fc.contributesToConnection).To(BeTrue()) }) + + It("queues window updates with the correction stream ID", func() { + var queued bool + queueWindowUpdate := func(id protocol.StreamID) { + Expect(id).To(Equal(protocol.StreamID(5))) + queued = true + } + + cc := NewConnectionFlowController(0, 0, nil, utils.DefaultLogger) + fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController) + fc.AddBytesRead(receiveWindow) + fc.MaybeQueueWindowUpdate() + Expect(queued).To(BeTrue()) + }) }) Context("receiving data", func() { @@ -175,12 +193,16 @@ var _ = Describe("Stream Flow controller", func() { oldWindowSize = controller.receiveWindowSize }) - It("tells if it has window updates", func() { - Expect(controller.HasWindowUpdate()).To(BeFalse()) + It("queues window updates", func() { + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeFalse()) controller.AddBytesRead(30) - Expect(controller.HasWindowUpdate()).To(BeTrue()) + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeTrue()) Expect(controller.GetWindowUpdate()).ToNot(BeZero()) - Expect(controller.HasWindowUpdate()).To(BeFalse()) + queuedWindowUpdate = false + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeFalse()) }) It("tells the connection flow controller when the window was autotuned", func() { @@ -213,7 +235,8 @@ var _ = Describe("Stream Flow controller", func() { controller.AddBytesRead(30) err := controller.UpdateHighestReceived(90, true) Expect(err).ToNot(HaveOccurred()) - Expect(controller.HasWindowUpdate()).To(BeFalse()) + controller.MaybeQueueWindowUpdate() + Expect(queuedWindowUpdate).To(BeFalse()) offset := controller.GetWindowUpdate() Expect(offset).To(BeZero()) }) diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go index a69e73f1..0f355a81 100644 --- a/internal/mocks/stream_flow_controller.go +++ b/internal/mocks/stream_flow_controller.go @@ -66,18 +66,6 @@ func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) } -// HasWindowUpdate mocks base method -func (m *MockStreamFlowController) HasWindowUpdate() bool { - ret := m.ctrl.Call(m, "HasWindowUpdate") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasWindowUpdate indicates an expected call of HasWindowUpdate -func (mr *MockStreamFlowControllerMockRecorder) HasWindowUpdate() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).HasWindowUpdate)) -} - // IsBlocked mocks base method func (m *MockStreamFlowController) IsBlocked() (bool, protocol.ByteCount) { ret := m.ctrl.Call(m, "IsBlocked") @@ -91,6 +79,16 @@ func (mr *MockStreamFlowControllerMockRecorder) IsBlocked() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsBlocked)) } +// MaybeQueueWindowUpdate mocks base method +func (m *MockStreamFlowController) MaybeQueueWindowUpdate() { + m.ctrl.Call(m, "MaybeQueueWindowUpdate") +} + +// MaybeQueueWindowUpdate indicates an expected call of MaybeQueueWindowUpdate +func (mr *MockStreamFlowControllerMockRecorder) MaybeQueueWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeQueueWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).MaybeQueueWindowUpdate)) +} + // SendWindowSize mocks base method func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { ret := m.ctrl.Call(m, "SendWindowSize") diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index da3ad8d0..d6f090a1 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -45,16 +45,6 @@ func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0) } -// onHasWindowUpdate mocks base method -func (m *MockStreamSender) onHasWindowUpdate(arg0 protocol.StreamID) { - m.ctrl.Call(m, "onHasWindowUpdate", arg0) -} - -// onHasWindowUpdate indicates an expected call of onHasWindowUpdate -func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0) -} - // onStreamCompleted mocks base method func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { m.ctrl.Call(m, "onStreamCompleted", arg0) diff --git a/receive_stream.go b/receive_stream.go index 9fc158f1..cec69f1c 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -151,10 +151,8 @@ func (s *receiveStream) Read(p []byte) (int, error) { if !s.resetRemotely { s.flowController.AddBytesRead(protocol.ByteCount(m)) } - // this call triggers the flow controller to increase the flow control window, if necessary - if s.flowController.HasWindowUpdate() { - s.sender.onHasWindowUpdate(s.streamID) - } + // increase the flow control window, if necessary + s.flowController.MaybeQueueWindowUpdate() if s.readPosInFrame >= int(frame.DataLen()) { s.frameQueue.Pop() diff --git a/receive_stream_test.go b/receive_stream_test.go index 4a3e04e4..43443785 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -43,7 +43,7 @@ var _ = Describe("Receive Stream", func() { It("reads a single STREAM frame", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -61,7 +61,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -83,7 +83,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -107,7 +107,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -130,7 +130,7 @@ var _ = Describe("Receive Stream", func() { It("waits until data is available", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() go func() { defer GinkgoRecover() frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} @@ -148,7 +148,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -173,7 +173,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -204,7 +204,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte("foob"), @@ -230,22 +230,6 @@ var _ = Describe("Receive Stream", func() { Expect(err).To(MatchError(errEmptyStreamData)) }) - It("calls the onHasWindowUpdate callback, when the a MAX_STREAM_DATA should be sent", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - mockFC.EXPECT().HasWindowUpdate().Return(true) - mockSender.EXPECT().onHasWindowUpdate(streamID) - frame1 := wire.StreamFrame{ - Offset: 0, - Data: []byte("foobar"), - } - err := str.handleStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - _, err = strWithTimeout.Read(b) - Expect(err).ToNot(HaveOccurred()) - }) - Context("deadlines", func() { It("the deadline error has the right net.Error properties", func() { Expect(errDeadline.Temporary()).To(BeTrue()) @@ -318,7 +302,7 @@ var _ = Describe("Receive Stream", func() { It("returns EOFs", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() str.handleStreamFrame(&wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -339,7 +323,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -367,7 +351,7 @@ var _ = Describe("Receive Stream", func() { It("returns EOFs with partial read", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() err := str.handleStreamFrame(&wire.StreamFrame{ Offset: 0, Data: []byte{0xde, 0xad}, @@ -385,7 +369,7 @@ var _ = Describe("Receive Stream", func() { It("handles immediate FINs", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() err := str.handleStreamFrame(&wire.StreamFrame{ Offset: 0, FinBit: true, @@ -402,7 +386,7 @@ var _ = Describe("Receive Stream", func() { It("closes when CloseRemote is called", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() str.CloseRemote(0) mockSender.EXPECT().onStreamCompleted(streamID) b := make([]byte, 8) @@ -478,7 +462,7 @@ var _ = Describe("Receive Stream", func() { It("doesn't send a RST_STREAM frame, if the FIN was already read", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - mockFC.EXPECT().HasWindowUpdate() + mockFC.EXPECT().MaybeQueueWindowUpdate() // no calls to mockSender.queueControlFrame err := str.handleStreamFrame(&wire.StreamFrame{ StreamID: streamID, @@ -601,7 +585,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)), mockSender.EXPECT().onStreamCompleted(streamID), ) - mockFC.EXPECT().HasWindowUpdate().Times(2) + mockFC.EXPECT().MaybeQueueWindowUpdate().Times(2) readReturned := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/session.go b/session.go index 8194dfe0..d144e89b 100644 --- a/session.go +++ b/session.go @@ -1137,6 +1137,7 @@ func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlow protocol.ReceiveStreamFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), initialSendWindow, + s.onHasWindowUpdate, s.rttStats, s.logger, ) @@ -1151,6 +1152,7 @@ func (s *session) newCryptoStream() cryptoStreamI { protocol.ReceiveStreamFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), 0, + s.onHasWindowUpdate, s.rttStats, s.logger, ) diff --git a/stream.go b/stream.go index 83123493..f8d851be 100644 --- a/stream.go +++ b/stream.go @@ -18,7 +18,6 @@ const ( // The streamSender is notified by the stream about various events. type streamSender interface { queueControlFrame(wire.Frame) - onHasWindowUpdate(protocol.StreamID) onHasStreamData(protocol.StreamID) onStreamCompleted(protocol.StreamID) } @@ -34,10 +33,6 @@ func (s *uniStreamSender) queueControlFrame(f wire.Frame) { s.streamSender.queueControlFrame(f) } -func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) { - s.streamSender.onHasWindowUpdate(id) -} - func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { s.streamSender.onHasStreamData(id) }