diff --git a/internal/flowcontrol/base_flow_controller.go b/internal/flowcontrol/base_flow_controller.go index baa8b045..e74c1d11 100644 --- a/internal/flowcontrol/base_flow_controller.go +++ b/internal/flowcontrol/base_flow_controller.go @@ -1,7 +1,7 @@ package flowcontrol import ( - "errors" + "sync" "time" "github.com/lucas-clemente/quic-go/congestion" @@ -10,6 +10,8 @@ import ( ) type baseFlowController struct { + mutex sync.RWMutex + rttStats *congestion.RTTStats bytesSent protocol.ByteCount @@ -24,24 +26,25 @@ type baseFlowController struct { maxReceiveWindowIncrement protocol.ByteCount } -// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously -var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset") - func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.bytesSent += n } // UpdateSendWindow should be called after receiving a WindowUpdateFrame // it returns true if the window was actually updated -func (c *baseFlowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { - if newOffset > c.sendWindow { - c.sendWindow = newOffset - return true +func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if offset > c.sendWindow { + c.sendWindow = offset } - return false } -func (c *baseFlowController) SendWindowSize() protocol.ByteCount { +func (c *baseFlowController) sendWindowSize() protocol.ByteCount { // this only happens during connection establishment, when data is sent before we receive the peer's transport parameters if c.bytesSent > c.sendWindow { return 0 @@ -50,6 +53,9 @@ func (c *baseFlowController) SendWindowSize() protocol.ByteCount { } func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + // pretend we sent a WindowUpdate when reading the first byte // this way auto-tuning of the window increment already works for the first WindowUpdate if c.bytesRead == 0 { @@ -58,28 +64,26 @@ func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) { c.bytesRead += n } -// MaybeUpdateWindow updates the receive window, if necessary -// if the receive window increment is changed, the new value is returned, otherwise a 0 -// the last return value is the new offset of the receive window -func (c *baseFlowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) { +// getWindowUpdate updates the receive window, if necessary +// it returns the new offset +func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { diff := c.receiveWindow - c.bytesRead - - // Chromium implements the same threshold - if diff < (c.receiveWindowIncrement / 2) { - var newWindowIncrement protocol.ByteCount - oldWindowIncrement := c.receiveWindowIncrement - - c.maybeAdjustWindowIncrement() - if c.receiveWindowIncrement != oldWindowIncrement { - newWindowIncrement = c.receiveWindowIncrement - } - - c.lastWindowUpdateTime = time.Now() - c.receiveWindow = c.bytesRead + c.receiveWindowIncrement - return true, newWindowIncrement, c.receiveWindow + // update the window when more than half of it was already consumed + if diff >= (c.receiveWindowIncrement / 2) { + return 0 } - return false, 0, 0 + c.maybeAdjustWindowIncrement() + c.receiveWindow = c.bytesRead + c.receiveWindowIncrement + c.lastWindowUpdateTime = time.Now() + return c.receiveWindow +} + +func (c *baseFlowController) IsBlocked() bool { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.sendWindowSize() == 0 } // maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often @@ -94,7 +98,6 @@ func (c *baseFlowController) maybeAdjustWindowIncrement() { } timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime) - // interval between the window updates is sufficiently large, no need to increase the increment if timeSinceLastWindowUpdate >= 2*rtt { return @@ -102,6 +105,6 @@ func (c *baseFlowController) maybeAdjustWindowIncrement() { c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement) } -func (c *baseFlowController) CheckFlowControlViolation() bool { +func (c *baseFlowController) checkFlowControlViolation() bool { return c.highestReceived > c.receiveWindow } diff --git a/internal/flowcontrol/base_flow_controller_test.go b/internal/flowcontrol/base_flow_controller_test.go index e68c91a2..0ac218bf 100644 --- a/internal/flowcontrol/base_flow_controller_test.go +++ b/internal/flowcontrol/base_flow_controller_test.go @@ -27,30 +27,34 @@ var _ = Describe("Base Flow controller", func() { It("gets the size of the remaining flow control window", func() { controller.bytesSent = 5 controller.sendWindow = 12 - Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(12 - 5))) - }) - - It("gets the offset of the flow control window", func() { - controller.bytesSent = 5 - controller.sendWindow = 12 - Expect(controller.sendWindow).To(Equal(protocol.ByteCount(12))) + Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(12 - 5))) }) It("updates the size of the flow control window", func() { - controller.bytesSent = 5 - updateSuccessful := controller.UpdateSendWindow(15) - Expect(updateSuccessful).To(BeTrue()) + controller.AddBytesSent(5) + controller.UpdateSendWindow(15) Expect(controller.sendWindow).To(Equal(protocol.ByteCount(15))) - Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(15 - 5))) + Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(15 - 5))) + }) + + It("says that the window size is 0 if we sent more than we were allowed to", func() { + controller.AddBytesSent(15) + controller.UpdateSendWindow(10) + Expect(controller.sendWindowSize()).To(BeZero()) }) It("does not decrease the flow control window", func() { - updateSuccessful := controller.UpdateSendWindow(20) - Expect(updateSuccessful).To(BeTrue()) - Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(20))) - updateSuccessful = controller.UpdateSendWindow(10) - Expect(updateSuccessful).To(BeFalse()) - Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(20))) + controller.UpdateSendWindow(20) + Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) + controller.UpdateSendWindow(10) + Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) + }) + + It("says when it's blocked", func() { + controller.UpdateSendWindow(100) + Expect(controller.IsBlocked()).To(BeFalse()) + controller.AddBytesSent(100) + Expect(controller.IsBlocked()).To(BeTrue()) }) }) @@ -73,8 +77,7 @@ var _ = Describe("Base Flow controller", func() { controller.lastWindowUpdateTime = time.Now().Add(-time.Hour) readPosition := receiveWindow - receiveWindowIncrement/2 + 1 controller.bytesRead = readPosition - updateNecessary, _, offset := controller.MaybeUpdateWindow() - Expect(updateNecessary).To(BeTrue()) + offset := controller.getWindowUpdate() Expect(offset).To(Equal(readPosition + receiveWindowIncrement)) Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowIncrement)) Expect(controller.lastWindowUpdateTime).To(BeTemporally("~", time.Now(), 20*time.Millisecond)) @@ -85,8 +88,8 @@ var _ = Describe("Base Flow controller", func() { controller.lastWindowUpdateTime = lastWindowUpdateTime readPosition := receiveWindow - receiveWindow/2 - 1 controller.bytesRead = readPosition - updateNecessary, _, _ := controller.MaybeUpdateWindow() - Expect(updateNecessary).To(BeFalse()) + offset := controller.getWindowUpdate() + Expect(offset).To(BeZero()) Expect(controller.lastWindowUpdateTime).To(Equal(lastWindowUpdateTime)) }) @@ -148,39 +151,28 @@ var _ = Describe("Base Flow controller", func() { setRtt(20 * time.Millisecond) controller.AddBytesRead(9900) // receive window is 10000 controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) - necessary, newIncrement, offset := controller.MaybeUpdateWindow() - Expect(necessary).To(BeTrue()) + offset := controller.getWindowUpdate() + Expect(offset).ToNot(BeZero()) + newIncrement := controller.receiveWindowIncrement Expect(newIncrement).To(Equal(2 * oldIncrement)) - Expect(controller.receiveWindowIncrement).To(Equal(newIncrement)) Expect(offset).To(Equal(protocol.ByteCount(9900 + newIncrement))) }) It("increases the increment sent in the first WindowUpdate, if data is read fast enough", func() { setRtt(20 * time.Millisecond) controller.AddBytesRead(9900) - necessary, newIncrement, _ := controller.MaybeUpdateWindow() - Expect(necessary).To(BeTrue()) - Expect(newIncrement).To(Equal(2 * oldIncrement)) + offset := controller.getWindowUpdate() + Expect(offset).ToNot(BeZero()) + Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) }) It("doesn't increamse the increment sent in the first WindowUpdate, if data is read slowly", func() { setRtt(5 * time.Millisecond) controller.AddBytesRead(9900) time.Sleep(15 * time.Millisecond) // more than 2x RTT - necessary, newIncrement, _ := controller.MaybeUpdateWindow() - Expect(necessary).To(BeTrue()) - Expect(newIncrement).To(BeZero()) - }) - - It("only returns the increment if it was increased", func() { - setRtt(20 * time.Millisecond) - controller.AddBytesRead(9900) // receive window is 10000 - controller.lastWindowUpdateTime = time.Now().Add(-45 * time.Millisecond) - necessary, newIncrement, offset := controller.MaybeUpdateWindow() - Expect(necessary).To(BeTrue()) - Expect(newIncrement).To(BeZero()) + offset := controller.getWindowUpdate() + Expect(offset).ToNot(BeZero()) Expect(controller.receiveWindowIncrement).To(Equal(oldIncrement)) - Expect(offset).To(Equal(protocol.ByteCount(9900 + oldIncrement))) }) }) }) diff --git a/internal/flowcontrol/connection_flow_controller.go b/internal/flowcontrol/connection_flow_controller.go index 74ab8f42..934d646d 100644 --- a/internal/flowcontrol/connection_flow_controller.go +++ b/internal/flowcontrol/connection_flow_controller.go @@ -1,55 +1,77 @@ package flowcontrol import ( + "fmt" "time" "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" ) type connectionFlowController struct { baseFlowController } -// newConnectionFlowController gets a new flow controller for the connection -func newConnectionFlowController( +var _ ConnectionFlowController = &connectionFlowController{} + +// NewConnectionFlowController gets a new flow controller for the connection +// It is created before we receive the peer's transport paramenters, thus it starts with a sendWindow of 0. +func NewConnectionFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, - initialSendWindow protocol.ByteCount, rttStats *congestion.RTTStats, -) *connectionFlowController { +) ConnectionFlowController { return &connectionFlowController{ baseFlowController: baseFlowController{ rttStats: rttStats, receiveWindow: receiveWindow, receiveWindowIncrement: receiveWindow, maxReceiveWindowIncrement: maxReceiveWindow, - sendWindow: initialSendWindow, }, } } +func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.baseFlowController.sendWindowSize() +} + +// IncrementHighestReceived adds an increment to the highestReceived value +func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.highestReceived += increment + if c.checkFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow)) + } + return nil +} + +func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { + c.mutex.Lock() + defer c.mutex.Unlock() + + oldWindowIncrement := c.receiveWindowIncrement + offset := c.baseFlowController.getWindowUpdate() + if oldWindowIncrement < c.receiveWindowIncrement { + utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) + } + return offset +} + // EnsureMinimumWindowIncrement sets a minimum window increment // it should make sure that the connection-level window is increased when a stream-level window grows func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + if inc > c.receiveWindowIncrement { c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement) c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update } } - -// IncrementHighestReceived adds an increment to the highestReceived value -func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) { - c.highestReceived += increment -} - -func (c *connectionFlowController) MaybeUpdateWindow() (bool, protocol.ByteCount, protocol.ByteCount) { - oldWindowSize := c.receiveWindowIncrement - updated, newIncrement, newOffset := c.baseFlowController.MaybeUpdateWindow() - // debug log, if the window size was actually increased - if oldWindowSize < c.receiveWindowIncrement { - utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) - } - return updated, newIncrement, newOffset -} diff --git a/internal/flowcontrol/connection_flow_controller_test.go b/internal/flowcontrol/connection_flow_controller_test.go index 89300e01..dc400e1f 100644 --- a/internal/flowcontrol/connection_flow_controller_test.go +++ b/internal/flowcontrol/connection_flow_controller_test.go @@ -12,6 +12,12 @@ import ( var _ = Describe("Connection Flow controller", func() { var controller *connectionFlowController + // update the congestion such that it returns a given value for the smoothed RTT + setRtt := func(t time.Duration) { + controller.rttStats.UpdateRTT(t, 0, time.Now()) + Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked + } + BeforeEach(func() { controller = &connectionFlowController{} controller.rttStats = &congestion.RTTStats{} @@ -23,12 +29,10 @@ var _ = Describe("Connection Flow controller", func() { It("sets the send and receive windows", func() { receiveWindow := protocol.ByteCount(2000) maxReceiveWindow := protocol.ByteCount(3000) - sendWindow := protocol.ByteCount(4000) - fc := newConnectionFlowController(receiveWindow, maxReceiveWindow, sendWindow, rttStats) + fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, rttStats).(*connectionFlowController) Expect(fc.receiveWindow).To(Equal(receiveWindow)) Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow)) - Expect(fc.sendWindow).To(Equal(sendWindow)) }) }) @@ -38,12 +42,36 @@ var _ = Describe("Connection Flow controller", func() { controller.IncrementHighestReceived(123) Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123))) }) + + Context("getting window updates", func() { + BeforeEach(func() { + controller.receiveWindow = 100 + controller.receiveWindowIncrement = 60 + controller.maxReceiveWindowIncrement = 1000 + }) + + It("gets a window update", func() { + controller.AddBytesRead(80) + offset := controller.GetWindowUpdate() + Expect(offset).To(Equal(protocol.ByteCount(80 + 60))) + }) + + It("autotunes the window", func() { + controller.AddBytesRead(80) + setRtt(20 * time.Millisecond) + controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) + offset := controller.GetWindowUpdate() + Expect(offset).To(Equal(protocol.ByteCount(80 + 2*60))) + }) + }) }) Context("setting the minimum increment", func() { - var oldIncrement protocol.ByteCount - var receiveWindow protocol.ByteCount = 10000 - var receiveWindowIncrement protocol.ByteCount = 600 + var ( + oldIncrement protocol.ByteCount + receiveWindow protocol.ByteCount = 10000 + receiveWindowIncrement protocol.ByteCount = 600 + ) BeforeEach(func() { controller.receiveWindow = receiveWindow @@ -52,12 +80,6 @@ var _ = Describe("Connection Flow controller", func() { controller.maxReceiveWindowIncrement = 3000 }) - // update the congestion such that it returns a given value for the smoothed RTT - setRtt := func(t time.Duration) { - controller.rttStats.UpdateRTT(t, 0, time.Now()) - Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked - } - It("sets the minimum window increment", func() { controller.EnsureMinimumWindowIncrement(1000) Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(1000))) @@ -79,9 +101,8 @@ var _ = Describe("Connection Flow controller", func() { controller.bytesRead = 9900 // receive window is 10000 controller.lastWindowUpdateTime = time.Now().Add(-20 * time.Millisecond) controller.EnsureMinimumWindowIncrement(912) - necessary, newIncrement, offset := controller.MaybeUpdateWindow() - Expect(necessary).To(BeTrue()) - Expect(newIncrement).To(BeZero()) // no auto-tuning + offset := controller.getWindowUpdate() + Expect(controller.receiveWindowIncrement).To(Equal(protocol.ByteCount(912))) // no auto-tuning Expect(offset).To(Equal(protocol.ByteCount(9900 + 912))) }) }) diff --git a/internal/flowcontrol/flow_control_manager.go b/internal/flowcontrol/flow_control_manager.go deleted file mode 100644 index 55be904e..00000000 --- a/internal/flowcontrol/flow_control_manager.go +++ /dev/null @@ -1,250 +0,0 @@ -package flowcontrol - -import ( - "errors" - "fmt" - "sync" - - "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/internal/handshake" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/qerr" -) - -type flowControlManager struct { - rttStats *congestion.RTTStats - maxReceiveStreamWindow protocol.ByteCount - - streamFlowController map[protocol.StreamID]*streamFlowController - connFlowController *connectionFlowController - mutex sync.RWMutex - - initialStreamSendWindow protocol.ByteCount -} - -var _ FlowControlManager = &flowControlManager{} - -var errMapAccess = errors.New("Error accessing the flowController map") - -// NewFlowControlManager creates a new flow control manager -func NewFlowControlManager( - maxReceiveStreamWindow protocol.ByteCount, - maxReceiveConnectionWindow protocol.ByteCount, - rttStats *congestion.RTTStats, -) FlowControlManager { - return &flowControlManager{ - rttStats: rttStats, - maxReceiveStreamWindow: maxReceiveStreamWindow, - streamFlowController: make(map[protocol.StreamID]*streamFlowController), - connFlowController: newConnectionFlowController(protocol.ReceiveConnectionFlowControlWindow, maxReceiveConnectionWindow, 0, rttStats), - } -} - -// NewStream creates new flow controllers for a stream -// it does nothing if the stream already exists -func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) { - f.mutex.Lock() - defer f.mutex.Unlock() - - if _, ok := f.streamFlowController[streamID]; ok { - return - } - f.streamFlowController[streamID] = newStreamFlowController(streamID, contributesToConnection, protocol.ReceiveStreamFlowControlWindow, f.maxReceiveStreamWindow, f.initialStreamSendWindow, f.rttStats) -} - -// RemoveStream removes a closed stream from flow control -func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { - f.mutex.Lock() - delete(f.streamFlowController, streamID) - f.mutex.Unlock() -} - -func (f *flowControlManager) UpdateTransportParameters(params *handshake.TransportParameters) { - f.mutex.Lock() - defer f.mutex.Unlock() - - f.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) - f.initialStreamSendWindow = params.StreamFlowControlWindow - for _, fc := range f.streamFlowController { - fc.UpdateSendWindow(params.StreamFlowControlWindow) - } -} - -// ResetStream should be called when receiving a RstStreamFrame -// it updates the byte offset to the value in the RstStreamFrame -// streamID must not be 0 here -func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - streamFlowController, err := f.getFlowController(streamID) - if err != nil { - return err - } - increment, err := streamFlowController.UpdateHighestReceived(byteOffset) - if err != nil { - return qerr.StreamDataAfterTermination - } - - if streamFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow)) - } - - if streamFlowController.ContributesToConnection() { - f.connFlowController.IncrementHighestReceived(increment) - if f.connFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow)) - } - } - - return nil -} - -// UpdateHighestReceived updates the highest received byte offset for a stream -// it adds the number of additional bytes to connection level flow control -// streamID must not be 0 here -func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - streamFlowController, err := f.getFlowController(streamID) - if err != nil { - return err - } - // UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered - // this error can be ignored here - increment, _ := streamFlowController.UpdateHighestReceived(byteOffset) - - if streamFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow)) - } - - if streamFlowController.ContributesToConnection() { - f.connFlowController.IncrementHighestReceived(increment) - if f.connFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow)) - } - } - - return nil -} - -// streamID must not be 0 here -func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - fc, err := f.getFlowController(streamID) - if err != nil { - return err - } - - fc.AddBytesRead(n) - if fc.ContributesToConnection() { - f.connFlowController.AddBytesRead(n) - } - - return nil -} - -func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) { - f.mutex.Lock() - defer f.mutex.Unlock() - - // get WindowUpdates for streams - for id, fc := range f.streamFlowController { - if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary { - res = append(res, WindowUpdate{StreamID: id, Offset: offset}) - if fc.ContributesToConnection() && newIncrement != 0 { - f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier)) - } - } - } - // get a WindowUpdate for the connection - if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary { - res = append(res, WindowUpdate{StreamID: 0, Offset: offset}) - } - - return -} - -func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) { - f.mutex.RLock() - defer f.mutex.RUnlock() - - // StreamID can be 0 when retransmitting - if streamID == 0 { - return f.connFlowController.receiveWindow, nil - } - - flowController, err := f.getFlowController(streamID) - if err != nil { - return 0, err - } - return flowController.receiveWindow, nil -} - -// streamID must not be 0 here -func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - fc, err := f.getFlowController(streamID) - if err != nil { - return err - } - - fc.AddBytesSent(n) - if fc.ContributesToConnection() { - f.connFlowController.AddBytesSent(n) - } - - return nil -} - -// must not be called with StreamID 0 -func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) { - f.mutex.RLock() - defer f.mutex.RUnlock() - - fc, err := f.getFlowController(streamID) - if err != nil { - return 0, err - } - res := fc.SendWindowSize() - - if fc.ContributesToConnection() { - res = utils.MinByteCount(res, f.connFlowController.SendWindowSize()) - } - - return res, nil -} - -func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount { - f.mutex.RLock() - defer f.mutex.RUnlock() - - return f.connFlowController.SendWindowSize() -} - -// streamID must not be 0 here -func (f *flowControlManager) UpdateStreamWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) { - fc, err := f.getFlowController(streamID) - if err != nil { - return false, err - } - return fc.UpdateSendWindow(offset), nil -} - -func (f *flowControlManager) UpdateConnectionWindow(offset protocol.ByteCount) bool { - return f.connFlowController.UpdateSendWindow(offset) -} - -func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*streamFlowController, error) { - streamFlowController, ok := f.streamFlowController[streamID] - if !ok { - return nil, errMapAccess - } - return streamFlowController, nil -} diff --git a/internal/flowcontrol/flow_control_manager_test.go b/internal/flowcontrol/flow_control_manager_test.go deleted file mode 100644 index 4e6b9c87..00000000 --- a/internal/flowcontrol/flow_control_manager_test.go +++ /dev/null @@ -1,369 +0,0 @@ -package flowcontrol - -import ( - "time" - - "github.com/lucas-clemente/quic-go/internal/handshake" - - "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/qerr" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Flow Control Manager", func() { - var fcm *flowControlManager - - BeforeEach(func() { - fcm = NewFlowControlManager( - 0x2000, // maxReceiveStreamWindow - 0x4000, // maxReceiveConnectionWindow - &congestion.RTTStats{}, - ).(*flowControlManager) - }) - - It("creates a connection level flow controller", func() { - Expect(fcm.streamFlowController).To(BeEmpty()) - Expect(fcm.connFlowController.sendWindow).To(BeZero()) - Expect(fcm.connFlowController.maxReceiveWindowIncrement).To(Equal(protocol.ByteCount(0x4000))) - }) - - Context("creating new streams", func() { - It("creates a new stream", func() { - fcm.NewStream(5, false) - Expect(fcm.streamFlowController).To(HaveKey(protocol.StreamID(5))) - fc := fcm.streamFlowController[5] - Expect(fc.streamID).To(Equal(protocol.StreamID(5))) - Expect(fc.ContributesToConnection()).To(BeFalse()) - // the transport parameters have not yet been received. Start with a window of size 0 - Expect(fc.sendWindow).To(BeZero()) - Expect(fc.maxReceiveWindowIncrement).To(Equal(protocol.ByteCount(0x2000))) - }) - - It("creates a new stream after it has received transport parameters", func() { - fcm.UpdateTransportParameters(&handshake.TransportParameters{ - StreamFlowControlWindow: 0x3000, - }) - fcm.NewStream(5, false) - Expect(fcm.streamFlowController).To(HaveKey(protocol.StreamID(5))) - fc := fcm.streamFlowController[5] - Expect(fc.sendWindow).To(Equal(protocol.ByteCount(0x3000))) - }) - - It("doesn't create a new flow controller if called for an existing stream", func() { - fcm.NewStream(5, true) - Expect(fcm.streamFlowController).To(HaveKey(protocol.StreamID(5))) - fcm.streamFlowController[5].bytesRead = 0x1337 - fcm.NewStream(5, false) - fc := fcm.streamFlowController[5] - Expect(fc.bytesRead).To(BeEquivalentTo(0x1337)) - Expect(fc.ContributesToConnection()).To(BeTrue()) - }) - }) - - It("removes streams", func() { - fcm.NewStream(5, true) - Expect(fcm.streamFlowController).To(HaveKey(protocol.StreamID(5))) - fcm.RemoveStream(5) - Expect(fcm.streamFlowController).ToNot(HaveKey(protocol.StreamID(5))) - }) - - It("updates the send windows for existing streams when receiveing the transport parameters", func() { - fcm.NewStream(5, false) - fcm.UpdateTransportParameters(&handshake.TransportParameters{ - StreamFlowControlWindow: 0x3000, - ConnectionFlowControlWindow: 0x6000, - }) - Expect(fcm.connFlowController.sendWindow).To(Equal(protocol.ByteCount(0x6000))) - Expect(fcm.streamFlowController[5].sendWindow).To(Equal(protocol.ByteCount(0x3000))) - }) - - Context("receiving data", func() { - BeforeEach(func() { - fcm.NewStream(1, false) - fcm.NewStream(4, true) - fcm.NewStream(6, true) - - for _, fc := range fcm.streamFlowController { - fc.receiveWindow = 100 - fc.receiveWindowIncrement = 100 - } - fcm.connFlowController.receiveWindow = 200 - fcm.connFlowController.receiveWindowIncrement = 200 - }) - - It("updates the connection level flow controller if the stream contributes", func() { - err := fcm.UpdateHighestReceived(4, 100) - Expect(err).ToNot(HaveOccurred()) - Expect(fcm.connFlowController.highestReceived).To(Equal(protocol.ByteCount(100))) - Expect(fcm.streamFlowController[4].highestReceived).To(Equal(protocol.ByteCount(100))) - }) - - It("adds the offsets of multiple streams for the connection flow control window", func() { - err := fcm.UpdateHighestReceived(4, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.UpdateHighestReceived(6, 50) - Expect(err).ToNot(HaveOccurred()) - Expect(fcm.connFlowController.highestReceived).To(Equal(protocol.ByteCount(100 + 50))) - }) - - It("does not update the connection level flow controller if the stream does not contribute", func() { - err := fcm.UpdateHighestReceived(1, 100) - // fcm.streamFlowController[4].receiveWindow = 0x1000 - Expect(err).ToNot(HaveOccurred()) - Expect(fcm.connFlowController.highestReceived).To(BeZero()) - Expect(fcm.streamFlowController[1].highestReceived).To(Equal(protocol.ByteCount(100))) - }) - - It("returns an error when called with an unknown stream", func() { - err := fcm.UpdateHighestReceived(1337, 0x1337) - Expect(err).To(MatchError(errMapAccess)) - }) - - It("gets the offset of the receive window", func() { - offset, err := fcm.GetReceiveWindow(4) - Expect(err).ToNot(HaveOccurred()) - Expect(offset).To(Equal(protocol.ByteCount(100))) - }) - - It("errors when asked for the receive window of a stream that doesn't exist", func() { - _, err := fcm.GetReceiveWindow(17) - Expect(err).To(MatchError(errMapAccess)) - }) - - It("gets the offset of the connection-level receive window", func() { - offset, err := fcm.GetReceiveWindow(0) - Expect(err).ToNot(HaveOccurred()) - Expect(offset).To(Equal(protocol.ByteCount(200))) - }) - - Context("flow control violations", func() { - It("errors when encountering a stream level flow control violation", func() { - err := fcm.UpdateHighestReceived(4, 101) - Expect(err).To(MatchError(qerr.Error(qerr.FlowControlReceivedTooMuchData, "Received 101 bytes on stream 4, allowed 100 bytes"))) - }) - - It("errors when encountering a connection-level flow control violation", func() { - fcm.streamFlowController[4].receiveWindow = 300 - fcm.streamFlowController[6].receiveWindow = 300 - err := fcm.UpdateHighestReceived(6, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.UpdateHighestReceived(4, 103) - Expect(err).To(MatchError(qerr.Error(qerr.FlowControlReceivedTooMuchData, "Received 203 bytes for the connection, allowed 200 bytes"))) - }) - }) - - Context("window updates", func() { - // update the congestion such that it returns a given value for the smoothed RTT - setRtt := func(t time.Duration) { - for _, controller := range fcm.streamFlowController { - controller.rttStats.UpdateRTT(t, 0, time.Now()) - Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked - } - } - - It("gets stream level window updates", func() { - err := fcm.UpdateHighestReceived(4, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.AddBytesRead(4, 90) - Expect(err).ToNot(HaveOccurred()) - updates := fcm.GetWindowUpdates() - Expect(updates).To(HaveLen(1)) - Expect(updates[0]).To(Equal(WindowUpdate{StreamID: 4, Offset: 190})) - }) - - It("gets connection level window updates", func() { - err := fcm.UpdateHighestReceived(4, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.UpdateHighestReceived(6, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.AddBytesRead(4, 90) - Expect(err).ToNot(HaveOccurred()) - err = fcm.AddBytesRead(6, 90) - Expect(err).ToNot(HaveOccurred()) - updates := fcm.GetWindowUpdates() - Expect(updates).To(HaveLen(3)) - Expect(updates).ToNot(ContainElement(WindowUpdate{StreamID: 0, Offset: 200})) - }) - - It("errors when AddBytesRead is called for a stream doesn't exist", func() { - err := fcm.AddBytesRead(17, 1000) - Expect(err).To(MatchError(errMapAccess)) - }) - - It("increases the connection-level window, when a stream window was increased by autotuning", func() { - setRtt(10 * time.Millisecond) - fcm.streamFlowController[4].lastWindowUpdateTime = time.Now().Add(-1 * time.Millisecond) - err := fcm.UpdateHighestReceived(4, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.AddBytesRead(4, 90) - Expect(err).ToNot(HaveOccurred()) - updates := fcm.GetWindowUpdates() - Expect(updates).To(HaveLen(2)) - connLevelIncrement := protocol.ByteCount(protocol.ConnectionFlowControlMultiplier * 200) // 300 - Expect(updates).To(ContainElement(WindowUpdate{StreamID: 4, Offset: 290})) - Expect(updates).To(ContainElement(WindowUpdate{StreamID: 0, Offset: 90 + connLevelIncrement})) - }) - - It("doesn't increase the connection-level window, when a non-contributing stream window was increased by autotuning", func() { - setRtt(10 * time.Millisecond) - fcm.streamFlowController[1].lastWindowUpdateTime = time.Now().Add(-1 * time.Millisecond) - err := fcm.UpdateHighestReceived(1, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.AddBytesRead(1, 90) - Expect(err).ToNot(HaveOccurred()) - updates := fcm.GetWindowUpdates() - Expect(updates).To(HaveLen(1)) - Expect(updates).To(ContainElement(WindowUpdate{StreamID: 1, Offset: 290})) - // the only window update is for stream 1, thus there's no connection-level window update - }) - }) - }) - - Context("resetting a stream", func() { - BeforeEach(func() { - fcm.NewStream(1, false) - fcm.NewStream(4, true) - fcm.NewStream(6, true) - fcm.streamFlowController[1].bytesSent = 41 - fcm.streamFlowController[4].bytesSent = 42 - - for _, fc := range fcm.streamFlowController { - fc.receiveWindow = 100 - fc.receiveWindowIncrement = 100 - } - fcm.connFlowController.receiveWindow = 200 - fcm.connFlowController.receiveWindowIncrement = 200 - }) - - It("updates the connection level flow controller if the stream contributes", func() { - err := fcm.ResetStream(4, 100) - Expect(err).ToNot(HaveOccurred()) - Expect(fcm.connFlowController.highestReceived).To(Equal(protocol.ByteCount(100))) - Expect(fcm.streamFlowController[4].highestReceived).To(Equal(protocol.ByteCount(100))) - }) - - It("does not update the connection level flow controller if the stream does not contribute", func() { - err := fcm.ResetStream(1, 100) - Expect(err).ToNot(HaveOccurred()) - Expect(fcm.connFlowController.highestReceived).To(BeZero()) - Expect(fcm.streamFlowController[1].highestReceived).To(Equal(protocol.ByteCount(100))) - }) - - It("errors if the byteOffset is smaller than a byteOffset that set earlier", func() { - err := fcm.UpdateHighestReceived(4, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.ResetStream(4, 50) - Expect(err).To(MatchError(qerr.StreamDataAfterTermination)) - }) - - It("returns an error when called with an unknown stream", func() { - err := fcm.ResetStream(1337, 0x1337) - Expect(err).To(MatchError(errMapAccess)) - }) - - Context("flow control violations", func() { - It("errors when encountering a stream level flow control violation", func() { - err := fcm.ResetStream(4, 101) - Expect(err).To(MatchError(qerr.Error(qerr.FlowControlReceivedTooMuchData, "Received 101 bytes on stream 4, allowed 100 bytes"))) - }) - - It("errors when encountering a connection-level flow control violation", func() { - fcm.streamFlowController[4].receiveWindow = 300 - fcm.streamFlowController[6].receiveWindow = 300 - err := fcm.ResetStream(4, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.ResetStream(6, 101) - Expect(err).To(MatchError(qerr.Error(qerr.FlowControlReceivedTooMuchData, "Received 201 bytes for the connection, allowed 200 bytes"))) - }) - }) - }) - - Context("sending data", func() { - It("adds bytes sent for all stream contributing to connection level flow control", func() { - fcm.NewStream(1, false) - fcm.NewStream(3, true) - fcm.NewStream(5, true) - err := fcm.AddBytesSent(1, 100) - Expect(err).ToNot(HaveOccurred()) - err = fcm.AddBytesSent(3, 200) - Expect(err).ToNot(HaveOccurred()) - err = fcm.AddBytesSent(5, 500) - Expect(err).ToNot(HaveOccurred()) - Expect(fcm.connFlowController.bytesSent).To(Equal(protocol.ByteCount(200 + 500))) - }) - - It("errors when called for a stream doesn't exist", func() { - err := fcm.AddBytesSent(17, 1000) - Expect(err).To(MatchError(errMapAccess)) - }) - - Context("window updates", func() { - It("updates the window for a normal stream", func() { - fcm.NewStream(5, true) - updated, err := fcm.UpdateStreamWindow(5, 1000) - Expect(err).ToNot(HaveOccurred()) - Expect(updated).To(BeTrue()) - }) - - It("updates the connection level window", func() { - updated := fcm.UpdateConnectionWindow(1000) - Expect(updated).To(BeTrue()) - }) - - It("errors when called for a stream that doesn't exist", func() { - _, err := fcm.UpdateStreamWindow(17, 1000) - Expect(err).To(MatchError(errMapAccess)) - }) - }) - - Context("window sizes", func() { - It("gets the window size of a stream", func() { - fcm.NewStream(5, false) - updated, err := fcm.UpdateStreamWindow(5, 1000) - Expect(err).ToNot(HaveOccurred()) - Expect(updated).To(BeTrue()) - fcm.AddBytesSent(5, 500) - size, err := fcm.SendWindowSize(5) - Expect(err).ToNot(HaveOccurred()) - Expect(size).To(Equal(protocol.ByteCount(1000 - 500))) - }) - - It("gets the connection window size", func() { - fcm.NewStream(5, true) - updated := fcm.UpdateConnectionWindow(1000) - Expect(updated).To(BeTrue()) - fcm.AddBytesSent(5, 500) - size := fcm.RemainingConnectionWindowSize() - Expect(size).To(Equal(protocol.ByteCount(1000 - 500))) - }) - - It("erros when asked for the send window size of a stream that doesn't exist", func() { - _, err := fcm.SendWindowSize(17) - Expect(err).To(MatchError(errMapAccess)) - }) - - It("limits the stream window size by the connection window size", func() { - fcm.NewStream(5, true) - updated := fcm.UpdateConnectionWindow(500) - Expect(updated).To(BeTrue()) - updated, err := fcm.UpdateStreamWindow(5, 1000) - Expect(err).ToNot(HaveOccurred()) - Expect(updated).To(BeTrue()) - size, err := fcm.SendWindowSize(5) - Expect(err).NotTo(HaveOccurred()) - Expect(size).To(Equal(protocol.ByteCount(500))) - }) - - It("does not reduce the size of the connection level window, if the stream does not contribute", func() { - fcm.NewStream(3, false) - updated := fcm.UpdateConnectionWindow(1000) - Expect(updated).To(BeTrue()) - fcm.AddBytesSent(3, 456) // WindowSize should return the same value no matter how much was sent - size := fcm.RemainingConnectionWindowSize() - Expect(size).To(Equal(protocol.ByteCount(1000))) - }) - }) - }) -}) diff --git a/internal/flowcontrol/interface.go b/internal/flowcontrol/interface.go index 69b84a9c..75ec6fac 100644 --- a/internal/flowcontrol/interface.go +++ b/internal/flowcontrol/interface.go @@ -1,29 +1,37 @@ package flowcontrol import "github.com/lucas-clemente/quic-go/internal/protocol" -import "github.com/lucas-clemente/quic-go/internal/handshake" -// WindowUpdate provides the data for WindowUpdateFrames. -type WindowUpdate struct { - StreamID protocol.StreamID - Offset protocol.ByteCount +type flowController interface { + // for sending + SendWindowSize() protocol.ByteCount + IsBlocked() bool + UpdateSendWindow(protocol.ByteCount) + AddBytesSent(protocol.ByteCount) + // for receiving + AddBytesRead(protocol.ByteCount) + GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary } -// A FlowControlManager manages the flow control -type FlowControlManager interface { - NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) - RemoveStream(streamID protocol.StreamID) - UpdateTransportParameters(*handshake.TransportParameters) - // methods needed for receiving data - ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error - UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error - AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error - GetWindowUpdates() []WindowUpdate - GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) - // methods needed for sending data - AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error - SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) - RemainingConnectionWindowSize() protocol.ByteCount - UpdateStreamWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) - UpdateConnectionWindow(offset protocol.ByteCount) bool +// A StreamFlowController is a flow controller for a QUIC stream. +type StreamFlowController interface { + flowController + // for receiving + // UpdateHighestReceived should be called when a new highest offset is received + // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame + UpdateHighestReceived(offset protocol.ByteCount, final bool) error +} + +// The ConnectionFlowController is the flow controller for the connection. +type ConnectionFlowController interface { + flowController +} + +type connectionFlowControllerI interface { + ConnectionFlowController + // The following two methods are not supposed to be called from outside this packet, but are needed internally + // for sending + EnsureMinimumWindowIncrement(protocol.ByteCount) + // for receiving + IncrementHighestReceived(protocol.ByteCount) error } diff --git a/internal/flowcontrol/stream_flow_controller.go b/internal/flowcontrol/stream_flow_controller.go index 6ff80f28..8cad9e66 100644 --- a/internal/flowcontrol/stream_flow_controller.go +++ b/internal/flowcontrol/stream_flow_controller.go @@ -1,30 +1,39 @@ package flowcontrol import ( + "fmt" + "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" ) type streamFlowController struct { baseFlowController + connection connectionFlowControllerI + streamID protocol.StreamID contributesToConnection bool // does the stream contribute to connection level flow control } -// newStreamFlowController gets a new flow controller for a stream -func newStreamFlowController( +var _ StreamFlowController = &streamFlowController{} + +// NewStreamFlowController gets a new flow controller for a stream +func NewStreamFlowController( streamID protocol.StreamID, contributesToConnection bool, + cfc ConnectionFlowController, receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, initialSendWindow protocol.ByteCount, rttStats *congestion.RTTStats, -) *streamFlowController { +) StreamFlowController { return &streamFlowController{ streamID: streamID, contributesToConnection: contributesToConnection, + connection: cfc.(connectionFlowControllerI), baseFlowController: baseFlowController{ rttStats: rttStats, receiveWindow: receiveWindow, @@ -35,32 +44,73 @@ func newStreamFlowController( } } -func (c *streamFlowController) ContributesToConnection() bool { - return c.contributesToConnection -} - // UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher // it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before -// This error occurs every time StreamFrames get reordered and has to be ignored in that case -// It should only be treated as an error when resetting a stream -func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) { +func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCount, final bool) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + // TODO(#382): check for StreamDataAfterTermination errors, when receiving an offset after we already received a final offset if byteOffset == c.highestReceived { - return 0, nil + return nil } - if byteOffset > c.highestReceived { - increment := byteOffset - c.highestReceived - c.highestReceived = byteOffset - return increment, nil + if byteOffset <= c.highestReceived { + // a STREAM_FRAME with a higher offset was received before. + if final { + // If the current byteOffset is smaller than the offset in that STREAM_FRAME, this STREAM_FRAME contained data after the end of the stream + return qerr.StreamDataAfterTermination + } + // this is a reordered STREAM_FRAME + return nil } - return 0, ErrReceivedSmallerByteOffset + + increment := byteOffset - c.highestReceived + c.highestReceived = byteOffset + if c.checkFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow)) + } + if c.contributesToConnection { + return c.connection.IncrementHighestReceived(increment) + } + return nil } -func (c *streamFlowController) MaybeUpdateWindow() (bool, protocol.ByteCount, protocol.ByteCount) { - oldWindowSize := c.receiveWindowIncrement - updated, newIncrement, newOffset := c.baseFlowController.MaybeUpdateWindow() - // debug log, if the window size was actually increased - if oldWindowSize < c.receiveWindowIncrement { - utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) +func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) { + c.baseFlowController.AddBytesRead(n) + if c.contributesToConnection { + c.connection.AddBytesRead(n) } - return updated, newIncrement, newOffset +} + +func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { + c.baseFlowController.AddBytesSent(n) + if c.contributesToConnection { + c.connection.AddBytesSent(n) + } +} + +func (c *streamFlowController) SendWindowSize() protocol.ByteCount { + c.mutex.Lock() + defer c.mutex.Unlock() + + window := c.baseFlowController.sendWindowSize() + if c.contributesToConnection { + window = utils.MinByteCount(window, c.connection.SendWindowSize()) + } + return window +} + +func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { + c.mutex.Lock() + defer c.mutex.Unlock() + + oldWindowIncrement := c.receiveWindowIncrement + offset := c.baseFlowController.getWindowUpdate() + if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment + utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) + if c.contributesToConnection { + c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier)) + } + } + return offset } diff --git a/internal/flowcontrol/stream_flow_controller_test.go b/internal/flowcontrol/stream_flow_controller_test.go index 1c273bda..0273ad54 100644 --- a/internal/flowcontrol/stream_flow_controller_test.go +++ b/internal/flowcontrol/stream_flow_controller_test.go @@ -1,8 +1,11 @@ package flowcontrol import ( + "time" + "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -11,8 +14,13 @@ var _ = Describe("Stream Flow controller", func() { var controller *streamFlowController BeforeEach(func() { - controller = &streamFlowController{} - controller.rttStats = &congestion.RTTStats{} + rttStats := &congestion.RTTStats{} + controller = &streamFlowController{ + streamID: 10, + connection: NewConnectionFlowController(1000, 1000, rttStats).(*connectionFlowController), + } + controller.maxReceiveWindowIncrement = 10000 + controller.rttStats = rttStats }) Context("Constructor", func() { @@ -23,61 +31,171 @@ var _ = Describe("Stream Flow controller", func() { maxReceiveWindow := protocol.ByteCount(3000) sendWindow := protocol.ByteCount(4000) - fc := newStreamFlowController(5, true, receiveWindow, maxReceiveWindow, sendWindow, rttStats) + cc := NewConnectionFlowController(0, 0, nil) + fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats).(*streamFlowController) Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.receiveWindow).To(Equal(receiveWindow)) Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow)) Expect(fc.sendWindow).To(Equal(sendWindow)) - }) - - It("says if it contributes to connection-level flow control", func() { - fc := newStreamFlowController(1, false, protocol.MaxByteCount, protocol.MaxByteCount, protocol.MaxByteCount, rttStats) - Expect(fc.ContributesToConnection()).To(BeFalse()) - fc = newStreamFlowController(5, true, protocol.MaxByteCount, protocol.MaxByteCount, protocol.MaxByteCount, rttStats) - Expect(fc.ContributesToConnection()).To(BeTrue()) + Expect(fc.contributesToConnection).To(BeTrue()) }) }) - Context("receive flow control", func() { - var receiveWindow protocol.ByteCount = 10000 - var receiveWindowIncrement protocol.ByteCount = 600 + Context("receiving data", func() { + Context("registering received offsets", func() { + var receiveWindow protocol.ByteCount = 10000 + var receiveWindowIncrement protocol.ByteCount = 600 - BeforeEach(func() { - controller.receiveWindow = receiveWindow - controller.receiveWindowIncrement = receiveWindowIncrement + BeforeEach(func() { + controller.receiveWindow = receiveWindow + controller.receiveWindowIncrement = receiveWindowIncrement + }) + + It("updates the highestReceived", func() { + controller.highestReceived = 1337 + err := controller.UpdateHighestReceived(1338, false) + Expect(err).ToNot(HaveOccurred()) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1338))) + }) + + It("informs the connection flow controller about received data", func() { + controller.highestReceived = 10 + controller.contributesToConnection = true + controller.connection.(*connectionFlowController).highestReceived = 100 + err := controller.UpdateHighestReceived(20, false) + Expect(err).ToNot(HaveOccurred()) + Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100 + 10))) + }) + + It("doesn't informs the connection flow controller about received data if it doesn't contribute", func() { + controller.highestReceived = 10 + controller.connection.(*connectionFlowController).highestReceived = 100 + err := controller.UpdateHighestReceived(20, false) + Expect(err).ToNot(HaveOccurred()) + Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100))) + }) + + It("does not decrease the highestReceived", func() { + controller.highestReceived = 1337 + err := controller.UpdateHighestReceived(1000, false) + Expect(err).ToNot(HaveOccurred()) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337))) + }) + + It("does nothing when setting the same byte offset", func() { + controller.highestReceived = 1337 + err := controller.UpdateHighestReceived(1337, false) + Expect(err).ToNot(HaveOccurred()) + }) + + It("does not give a flow control violation when using the window completely", func() { + err := controller.UpdateHighestReceived(receiveWindow, false) + Expect(err).ToNot(HaveOccurred()) + }) + + It("detects a flow control violation", func() { + err := controller.UpdateHighestReceived(receiveWindow+1, false) + Expect(err).To(MatchError("FlowControlReceivedTooMuchData: Received 10001 bytes on stream 10, allowed 10000 bytes")) + }) + + It("accepts a final offset higher than the highest received", func() { + controller.highestReceived = 100 + err := controller.UpdateHighestReceived(101, true) + Expect(err).ToNot(HaveOccurred()) + Expect(controller.highestReceived).To(Equal(protocol.ByteCount(101))) + }) + + It("errors when receiving a final offset smaller than the highest offset received so far", func() { + controller.highestReceived = 100 + err := controller.UpdateHighestReceived(99, true) + Expect(err).To(MatchError(qerr.StreamDataAfterTermination)) + }) }) - It("updates the highestReceived", func() { - controller.highestReceived = 1337 - increment, err := controller.UpdateHighestReceived(1338) - Expect(err).ToNot(HaveOccurred()) - Expect(increment).To(Equal(protocol.ByteCount(1338 - 1337))) - Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1338))) + Context("registering data read", func() { + It("saves when data is read, on a stream not contributing to the connection", func() { + controller.AddBytesRead(100) + Expect(controller.bytesRead).To(Equal(protocol.ByteCount(100))) + Expect(controller.connection.(*connectionFlowController).bytesRead).To(BeZero()) + }) + + It("saves when data is read, on a stream not contributing to the connection", func() { + controller.contributesToConnection = true + controller.AddBytesRead(200) + Expect(controller.bytesRead).To(Equal(protocol.ByteCount(200))) + Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(200))) + }) }) - It("does not decrease the highestReceived", func() { - controller.highestReceived = 1337 - increment, err := controller.UpdateHighestReceived(1000) - Expect(err).To(MatchError(ErrReceivedSmallerByteOffset)) - Expect(increment).To(BeZero()) - Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337))) + Context("generating window updates", func() { + var oldIncrement protocol.ByteCount + + // update the congestion such that it returns a given value for the smoothed RTT + setRtt := func(t time.Duration) { + controller.rttStats.UpdateRTT(t, 0, time.Now()) + Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked + } + + BeforeEach(func() { + controller.receiveWindow = 100 + controller.receiveWindowIncrement = 60 + controller.connection.(*connectionFlowController).receiveWindowIncrement = 120 + oldIncrement = controller.receiveWindowIncrement + }) + + It("tells the connection flow controller when the window was autotuned", func() { + controller.contributesToConnection = true + controller.AddBytesRead(75) + setRtt(20 * time.Millisecond) + controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) + offset := controller.GetWindowUpdate() + Expect(offset).To(Equal(protocol.ByteCount(75 + 2*60))) + Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) + Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(float64(controller.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier))) + }) + + It("doesn't tell the connection flow controller if it doesn't contribute", func() { + controller.contributesToConnection = false + controller.AddBytesRead(75) + setRtt(20 * time.Millisecond) + controller.lastWindowUpdateTime = time.Now().Add(-35 * time.Millisecond) + offset := controller.GetWindowUpdate() + Expect(offset).ToNot(BeZero()) + Expect(controller.receiveWindowIncrement).To(Equal(2 * oldIncrement)) + Expect(controller.connection.(*connectionFlowController).receiveWindowIncrement).To(Equal(protocol.ByteCount(120))) // unchanged + }) + }) + }) + + Context("sending data", func() { + It("gets the size of the send window", func() { + controller.UpdateSendWindow(15) + controller.AddBytesSent(5) + Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10))) }) - It("does not error when setting the same byte offset", func() { - controller.highestReceived = 1337 - increment, err := controller.UpdateHighestReceived(1337) - Expect(err).ToNot(HaveOccurred()) - Expect(increment).To(BeZero()) + It("doesn't care about the connection-level window, if it doesn't contribute", func() { + controller.UpdateSendWindow(15) + controller.connection.UpdateSendWindow(1) + controller.AddBytesSent(5) + Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10))) }) - It("detects a flow control violation", func() { - controller.UpdateHighestReceived(receiveWindow + 1) - Expect(controller.CheckFlowControlViolation()).To(BeTrue()) + It("makes sure that it doesn't overflow the connection-level window", func() { + controller.contributesToConnection = true + controller.connection.UpdateSendWindow(12) + controller.UpdateSendWindow(20) + controller.AddBytesSent(10) + Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(2))) }) - It("does not give a flow control violation when using the window completely", func() { - controller.UpdateHighestReceived(receiveWindow) - Expect(controller.CheckFlowControlViolation()).To(BeFalse()) + It("doesn't say that it's blocked, if only the connection is blocked", func() { + controller.contributesToConnection = true + controller.connection.UpdateSendWindow(50) + controller.UpdateSendWindow(100) + controller.AddBytesSent(50) + Expect(controller.connection.IsBlocked()).To(BeTrue()) + Expect(controller.IsBlocked()).To(BeFalse()) }) }) }) diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go new file mode 100644 index 00000000..26c77cbe --- /dev/null +++ b/internal/mocks/connection_flow_controller.go @@ -0,0 +1,100 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/flowcontrol (interfaces: ConnectionFlowController) + +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockConnectionFlowController is a mock of ConnectionFlowController interface +type MockConnectionFlowController struct { + ctrl *gomock.Controller + recorder *MockConnectionFlowControllerMockRecorder +} + +// MockConnectionFlowControllerMockRecorder is the mock recorder for MockConnectionFlowController +type MockConnectionFlowControllerMockRecorder struct { + mock *MockConnectionFlowController +} + +// NewMockConnectionFlowController creates a new mock instance +func NewMockConnectionFlowController(ctrl *gomock.Controller) *MockConnectionFlowController { + mock := &MockConnectionFlowController{ctrl: ctrl} + mock.recorder = &MockConnectionFlowControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockConnectionFlowController) EXPECT() *MockConnectionFlowControllerMockRecorder { + return _m.recorder +} + +// AddBytesRead mocks base method +func (_m *MockConnectionFlowController) AddBytesRead(_param0 protocol.ByteCount) { + _m.ctrl.Call(_m, "AddBytesRead", _param0) +} + +// AddBytesRead indicates an expected call of AddBytesRead +func (_mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) +} + +// AddBytesSent mocks base method +func (_m *MockConnectionFlowController) AddBytesSent(_param0 protocol.ByteCount) { + _m.ctrl.Call(_m, "AddBytesSent", _param0) +} + +// AddBytesSent indicates an expected call of AddBytesSent +func (_mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) +} + +// GetWindowUpdate mocks base method +func (_m *MockConnectionFlowController) GetWindowUpdate() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWindowUpdate indicates an expected call of GetWindowUpdate +func (_mr *MockConnectionFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).GetWindowUpdate)) +} + +// IsBlocked mocks base method +func (_m *MockConnectionFlowController) IsBlocked() bool { + ret := _m.ctrl.Call(_m, "IsBlocked") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsBlocked indicates an expected call of IsBlocked +func (_mr *MockConnectionFlowControllerMockRecorder) IsBlocked() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "IsBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsBlocked)) +} + +// SendWindowSize mocks base method +func (_m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "SendWindowSize") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// SendWindowSize indicates an expected call of SendWindowSize +func (_mr *MockConnectionFlowControllerMockRecorder) SendWindowSize() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SendWindowSize", reflect.TypeOf((*MockConnectionFlowController)(nil).SendWindowSize)) +} + +// UpdateSendWindow mocks base method +func (_m *MockConnectionFlowController) UpdateSendWindow(_param0 protocol.ByteCount) { + _m.ctrl.Call(_m, "UpdateSendWindow", _param0) +} + +// UpdateSendWindow indicates an expected call of UpdateSendWindow +func (_mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) +} diff --git a/internal/mocks/flow_control_manager.go b/internal/mocks/flow_control_manager.go deleted file mode 100644 index b90d7c03..00000000 --- a/internal/mocks/flow_control_manager.go +++ /dev/null @@ -1,189 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go/internal/flowcontrol (interfaces: FlowControlManager) - -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - flowcontrol "github.com/lucas-clemente/quic-go/internal/flowcontrol" - handshake "github.com/lucas-clemente/quic-go/internal/handshake" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// MockFlowControlManager is a mock of FlowControlManager interface -type MockFlowControlManager struct { - ctrl *gomock.Controller - recorder *MockFlowControlManagerMockRecorder -} - -// MockFlowControlManagerMockRecorder is the mock recorder for MockFlowControlManager -type MockFlowControlManagerMockRecorder struct { - mock *MockFlowControlManager -} - -// NewMockFlowControlManager creates a new mock instance -func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager { - mock := &MockFlowControlManager{ctrl: ctrl} - mock.recorder = &MockFlowControlManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (_m *MockFlowControlManager) EXPECT() *MockFlowControlManagerMockRecorder { - return _m.recorder -} - -// AddBytesRead mocks base method -func (_m *MockFlowControlManager) AddBytesRead(_param0 protocol.StreamID, _param1 protocol.ByteCount) error { - ret := _m.ctrl.Call(_m, "AddBytesRead", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -// AddBytesRead indicates an expected call of AddBytesRead -func (_mr *MockFlowControlManagerMockRecorder) AddBytesRead(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesRead", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesRead), arg0, arg1) -} - -// AddBytesSent mocks base method -func (_m *MockFlowControlManager) AddBytesSent(_param0 protocol.StreamID, _param1 protocol.ByteCount) error { - ret := _m.ctrl.Call(_m, "AddBytesSent", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -// AddBytesSent indicates an expected call of AddBytesSent -func (_mr *MockFlowControlManagerMockRecorder) AddBytesSent(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesSent", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesSent), arg0, arg1) -} - -// GetReceiveWindow mocks base method -func (_m *MockFlowControlManager) GetReceiveWindow(_param0 protocol.StreamID) (protocol.ByteCount, error) { - ret := _m.ctrl.Call(_m, "GetReceiveWindow", _param0) - ret0, _ := ret[0].(protocol.ByteCount) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetReceiveWindow indicates an expected call of GetReceiveWindow -func (_mr *MockFlowControlManagerMockRecorder) GetReceiveWindow(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetReceiveWindow", reflect.TypeOf((*MockFlowControlManager)(nil).GetReceiveWindow), arg0) -} - -// GetWindowUpdates mocks base method -func (_m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate { - ret := _m.ctrl.Call(_m, "GetWindowUpdates") - ret0, _ := ret[0].([]flowcontrol.WindowUpdate) - return ret0 -} - -// GetWindowUpdates indicates an expected call of GetWindowUpdates -func (_mr *MockFlowControlManagerMockRecorder) GetWindowUpdates() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWindowUpdates", reflect.TypeOf((*MockFlowControlManager)(nil).GetWindowUpdates)) -} - -// NewStream mocks base method -func (_m *MockFlowControlManager) NewStream(_param0 protocol.StreamID, _param1 bool) { - _m.ctrl.Call(_m, "NewStream", _param0, _param1) -} - -// NewStream indicates an expected call of NewStream -func (_mr *MockFlowControlManagerMockRecorder) NewStream(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "NewStream", reflect.TypeOf((*MockFlowControlManager)(nil).NewStream), arg0, arg1) -} - -// RemainingConnectionWindowSize mocks base method -func (_m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "RemainingConnectionWindowSize") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// RemainingConnectionWindowSize indicates an expected call of RemainingConnectionWindowSize -func (_mr *MockFlowControlManagerMockRecorder) RemainingConnectionWindowSize() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RemainingConnectionWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).RemainingConnectionWindowSize)) -} - -// RemoveStream mocks base method -func (_m *MockFlowControlManager) RemoveStream(_param0 protocol.StreamID) { - _m.ctrl.Call(_m, "RemoveStream", _param0) -} - -// RemoveStream indicates an expected call of RemoveStream -func (_mr *MockFlowControlManagerMockRecorder) RemoveStream(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RemoveStream", reflect.TypeOf((*MockFlowControlManager)(nil).RemoveStream), arg0) -} - -// ResetStream mocks base method -func (_m *MockFlowControlManager) ResetStream(_param0 protocol.StreamID, _param1 protocol.ByteCount) error { - ret := _m.ctrl.Call(_m, "ResetStream", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -// ResetStream indicates an expected call of ResetStream -func (_mr *MockFlowControlManagerMockRecorder) ResetStream(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ResetStream", reflect.TypeOf((*MockFlowControlManager)(nil).ResetStream), arg0, arg1) -} - -// SendWindowSize mocks base method -func (_m *MockFlowControlManager) SendWindowSize(_param0 protocol.StreamID) (protocol.ByteCount, error) { - ret := _m.ctrl.Call(_m, "SendWindowSize", _param0) - ret0, _ := ret[0].(protocol.ByteCount) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SendWindowSize indicates an expected call of SendWindowSize -func (_mr *MockFlowControlManagerMockRecorder) SendWindowSize(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SendWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).SendWindowSize), arg0) -} - -// UpdateConnectionWindow mocks base method -func (_m *MockFlowControlManager) UpdateConnectionWindow(_param0 protocol.ByteCount) bool { - ret := _m.ctrl.Call(_m, "UpdateConnectionWindow", _param0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// UpdateConnectionWindow indicates an expected call of UpdateConnectionWindow -func (_mr *MockFlowControlManagerMockRecorder) UpdateConnectionWindow(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateConnectionWindow", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateConnectionWindow), arg0) -} - -// UpdateHighestReceived mocks base method -func (_m *MockFlowControlManager) UpdateHighestReceived(_param0 protocol.StreamID, _param1 protocol.ByteCount) error { - ret := _m.ctrl.Call(_m, "UpdateHighestReceived", _param0, _param1) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateHighestReceived indicates an expected call of UpdateHighestReceived -func (_mr *MockFlowControlManagerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateHighestReceived), arg0, arg1) -} - -// UpdateStreamWindow mocks base method -func (_m *MockFlowControlManager) UpdateStreamWindow(_param0 protocol.StreamID, _param1 protocol.ByteCount) (bool, error) { - ret := _m.ctrl.Call(_m, "UpdateStreamWindow", _param0, _param1) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateStreamWindow indicates an expected call of UpdateStreamWindow -func (_mr *MockFlowControlManagerMockRecorder) UpdateStreamWindow(arg0, arg1 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateStreamWindow", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateStreamWindow), arg0, arg1) -} - -// UpdateTransportParameters mocks base method -func (_m *MockFlowControlManager) UpdateTransportParameters(_param0 *handshake.TransportParameters) { - _m.ctrl.Call(_m, "UpdateTransportParameters", _param0) -} - -// UpdateTransportParameters indicates an expected call of UpdateTransportParameters -func (_mr *MockFlowControlManagerMockRecorder) UpdateTransportParameters(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateTransportParameters", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateTransportParameters), arg0) -} diff --git a/internal/mocks/gen.go b/internal/mocks/gen.go index cd59f3e5..4894c9f2 100644 --- a/internal/mocks/gen.go +++ b/internal/mocks/gen.go @@ -1,4 +1,6 @@ package mocks -//go:generate sh -c "./mockgen_internal.sh mocks flow_control_manager.go github.com/lucas-clemente/quic-go/internal/flowcontrol FlowControlManager" +//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" +//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "./mockgen_stream.sh mocks stream.go github.com/lucas-clemente/quic-go StreamI" //go:generate sh -c "goimports -w ." diff --git a/internal/mocks/mockgen_internal.sh b/internal/mocks/mockgen_internal.sh index c3793620..c45e4563 100755 --- a/internal/mocks/mockgen_internal.sh +++ b/internal/mocks/mockgen_internal.sh @@ -18,4 +18,4 @@ PACKAGE_PATH=${3/internal/internalpackage} mockgen -package $1 -self_package $1 -destination $2 $PACKAGE_PATH $4 sed -i '' 's/internalpackage/internal/g' $2 -rm -rf "$TEMP_DIR" +rm -r "$TEMP_DIR" diff --git a/internal/mocks/mockgen_stream.sh b/internal/mocks/mockgen_stream.sh new file mode 100755 index 00000000..725a897d --- /dev/null +++ b/internal/mocks/mockgen_stream.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# Mockgen refuses to generate mocks for internal packages. +# This script copies the internal directory and renames it to internalpackage. +# That way, mockgen can generate the mock. +# Afterwards, it corrects the import paths (replaces internalpackage back to internal). + +TEMP_DIR=$(mktemp -d) +mkdir -p $TEMP_DIR/src/github.com/lucas-clemente/quic-go/ + +cp -r $GOPATH/src/github.com/lucas-clemente/quic-go/ $TEMP_DIR/src/github.com/lucas-clemente/quic-go/ +echo "type StreamI = streamI" >> $TEMP_DIR/src/github.com/lucas-clemente/quic-go/stream.go + +export GOPATH="$TEMP_DIR:$GOPATH" + +mockgen -package $1 -self_package $1 -destination $2 $3 $4 + +rm -r "$TEMP_DIR" diff --git a/internal/mocks/stream.go b/internal/mocks/stream.go new file mode 100644 index 00000000..5592282c --- /dev/null +++ b/internal/mocks/stream.go @@ -0,0 +1,283 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: StreamI) + +package mocks + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockStreamI is a mock of StreamI interface +type MockStreamI struct { + ctrl *gomock.Controller + recorder *MockStreamIMockRecorder +} + +// MockStreamIMockRecorder is the mock recorder for MockStreamI +type MockStreamIMockRecorder struct { + mock *MockStreamI +} + +// NewMockStreamI creates a new mock instance +func NewMockStreamI(ctrl *gomock.Controller) *MockStreamI { + mock := &MockStreamI{ctrl: ctrl} + mock.recorder = &MockStreamIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockStreamI) EXPECT() *MockStreamIMockRecorder { + return _m.recorder +} + +// AddStreamFrame mocks base method +func (_m *MockStreamI) AddStreamFrame(_param0 *wire.StreamFrame) error { + ret := _m.ctrl.Call(_m, "AddStreamFrame", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddStreamFrame indicates an expected call of AddStreamFrame +func (_mr *MockStreamIMockRecorder) AddStreamFrame(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddStreamFrame", reflect.TypeOf((*MockStreamI)(nil).AddStreamFrame), arg0) +} + +// Cancel mocks base method +func (_m *MockStreamI) Cancel(_param0 error) { + _m.ctrl.Call(_m, "Cancel", _param0) +} + +// Cancel indicates an expected call of Cancel +func (_mr *MockStreamIMockRecorder) Cancel(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Cancel", reflect.TypeOf((*MockStreamI)(nil).Cancel), arg0) +} + +// Close mocks base method +func (_m *MockStreamI) Close() error { + ret := _m.ctrl.Call(_m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (_mr *MockStreamIMockRecorder) Close() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Close", reflect.TypeOf((*MockStreamI)(nil).Close)) +} + +// Context mocks base method +func (_m *MockStreamI) Context() context.Context { + ret := _m.ctrl.Call(_m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context +func (_mr *MockStreamIMockRecorder) Context() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Context", reflect.TypeOf((*MockStreamI)(nil).Context)) +} + +// Finished mocks base method +func (_m *MockStreamI) Finished() bool { + ret := _m.ctrl.Call(_m, "Finished") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Finished indicates an expected call of Finished +func (_mr *MockStreamIMockRecorder) Finished() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Finished", reflect.TypeOf((*MockStreamI)(nil).Finished)) +} + +// GetDataForWriting mocks base method +func (_m *MockStreamI) GetDataForWriting(_param0 protocol.ByteCount) []byte { + ret := _m.ctrl.Call(_m, "GetDataForWriting", _param0) + ret0, _ := ret[0].([]byte) + return ret0 +} + +// GetDataForWriting indicates an expected call of GetDataForWriting +func (_mr *MockStreamIMockRecorder) GetDataForWriting(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetDataForWriting", reflect.TypeOf((*MockStreamI)(nil).GetDataForWriting), arg0) +} + +// GetWindowUpdate mocks base method +func (_m *MockStreamI) GetWindowUpdate() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWindowUpdate indicates an expected call of GetWindowUpdate +func (_mr *MockStreamIMockRecorder) GetWindowUpdate() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).GetWindowUpdate)) +} + +// GetWriteOffset mocks base method +func (_m *MockStreamI) GetWriteOffset() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetWriteOffset") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWriteOffset indicates an expected call of GetWriteOffset +func (_mr *MockStreamIMockRecorder) GetWriteOffset() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWriteOffset", reflect.TypeOf((*MockStreamI)(nil).GetWriteOffset)) +} + +// IsFlowControlBlocked mocks base method +func (_m *MockStreamI) IsFlowControlBlocked() bool { + ret := _m.ctrl.Call(_m, "IsFlowControlBlocked") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsFlowControlBlocked indicates an expected call of IsFlowControlBlocked +func (_mr *MockStreamIMockRecorder) IsFlowControlBlocked() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "IsFlowControlBlocked", reflect.TypeOf((*MockStreamI)(nil).IsFlowControlBlocked)) +} + +// LenOfDataForWriting mocks base method +func (_m *MockStreamI) LenOfDataForWriting() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "LenOfDataForWriting") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// LenOfDataForWriting indicates an expected call of LenOfDataForWriting +func (_mr *MockStreamIMockRecorder) LenOfDataForWriting() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "LenOfDataForWriting", reflect.TypeOf((*MockStreamI)(nil).LenOfDataForWriting)) +} + +// Read mocks base method +func (_m *MockStreamI) Read(_param0 []byte) (int, error) { + ret := _m.ctrl.Call(_m, "Read", _param0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (_mr *MockStreamIMockRecorder) Read(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0) +} + +// RegisterRemoteError mocks base method +func (_m *MockStreamI) RegisterRemoteError(_param0 error, _param1 protocol.ByteCount) error { + ret := _m.ctrl.Call(_m, "RegisterRemoteError", _param0, _param1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterRemoteError indicates an expected call of RegisterRemoteError +func (_mr *MockStreamIMockRecorder) RegisterRemoteError(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RegisterRemoteError", reflect.TypeOf((*MockStreamI)(nil).RegisterRemoteError), arg0, arg1) +} + +// Reset mocks base method +func (_m *MockStreamI) Reset(_param0 error) { + _m.ctrl.Call(_m, "Reset", _param0) +} + +// Reset indicates an expected call of Reset +func (_mr *MockStreamIMockRecorder) Reset(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Reset", reflect.TypeOf((*MockStreamI)(nil).Reset), arg0) +} + +// SentFin mocks base method +func (_m *MockStreamI) SentFin() { + _m.ctrl.Call(_m, "SentFin") +} + +// SentFin indicates an expected call of SentFin +func (_mr *MockStreamIMockRecorder) SentFin() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SentFin", reflect.TypeOf((*MockStreamI)(nil).SentFin)) +} + +// SetDeadline mocks base method +func (_m *MockStreamI) SetDeadline(_param0 time.Time) error { + ret := _m.ctrl.Call(_m, "SetDeadline", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (_mr *MockStreamIMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), arg0) +} + +// SetReadDeadline mocks base method +func (_m *MockStreamI) SetReadDeadline(_param0 time.Time) error { + ret := _m.ctrl.Call(_m, "SetReadDeadline", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline +func (_mr *MockStreamIMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), arg0) +} + +// SetWriteDeadline mocks base method +func (_m *MockStreamI) SetWriteDeadline(_param0 time.Time) error { + ret := _m.ctrl.Call(_m, "SetWriteDeadline", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline +func (_mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0) +} + +// ShouldSendFin mocks base method +func (_m *MockStreamI) ShouldSendFin() bool { + ret := _m.ctrl.Call(_m, "ShouldSendFin") + ret0, _ := ret[0].(bool) + return ret0 +} + +// ShouldSendFin indicates an expected call of ShouldSendFin +func (_mr *MockStreamIMockRecorder) ShouldSendFin() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ShouldSendFin", reflect.TypeOf((*MockStreamI)(nil).ShouldSendFin)) +} + +// StreamID mocks base method +func (_m *MockStreamI) StreamID() protocol.StreamID { + ret := _m.ctrl.Call(_m, "StreamID") + ret0, _ := ret[0].(protocol.StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID +func (_mr *MockStreamIMockRecorder) StreamID() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID)) +} + +// UpdateSendWindow mocks base method +func (_m *MockStreamI) UpdateSendWindow(_param0 protocol.ByteCount) { + _m.ctrl.Call(_m, "UpdateSendWindow", _param0) +} + +// UpdateSendWindow indicates an expected call of UpdateSendWindow +func (_mr *MockStreamIMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamI)(nil).UpdateSendWindow), arg0) +} + +// Write mocks base method +func (_m *MockStreamI) Write(_param0 []byte) (int, error) { + ret := _m.ctrl.Call(_m, "Write", _param0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write +func (_mr *MockStreamIMockRecorder) Write(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), arg0) +} diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go new file mode 100644 index 00000000..7e970bb4 --- /dev/null +++ b/internal/mocks/stream_flow_controller.go @@ -0,0 +1,112 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/flowcontrol (interfaces: StreamFlowController) + +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockStreamFlowController is a mock of StreamFlowController interface +type MockStreamFlowController struct { + ctrl *gomock.Controller + recorder *MockStreamFlowControllerMockRecorder +} + +// MockStreamFlowControllerMockRecorder is the mock recorder for MockStreamFlowController +type MockStreamFlowControllerMockRecorder struct { + mock *MockStreamFlowController +} + +// NewMockStreamFlowController creates a new mock instance +func NewMockStreamFlowController(ctrl *gomock.Controller) *MockStreamFlowController { + mock := &MockStreamFlowController{ctrl: ctrl} + mock.recorder = &MockStreamFlowControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockStreamFlowController) EXPECT() *MockStreamFlowControllerMockRecorder { + return _m.recorder +} + +// AddBytesRead mocks base method +func (_m *MockStreamFlowController) AddBytesRead(_param0 protocol.ByteCount) { + _m.ctrl.Call(_m, "AddBytesRead", _param0) +} + +// AddBytesRead indicates an expected call of AddBytesRead +func (_mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) +} + +// AddBytesSent mocks base method +func (_m *MockStreamFlowController) AddBytesSent(_param0 protocol.ByteCount) { + _m.ctrl.Call(_m, "AddBytesSent", _param0) +} + +// AddBytesSent indicates an expected call of AddBytesSent +func (_mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) +} + +// GetWindowUpdate mocks base method +func (_m *MockStreamFlowController) GetWindowUpdate() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetWindowUpdate indicates an expected call of GetWindowUpdate +func (_mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) +} + +// IsBlocked mocks base method +func (_m *MockStreamFlowController) IsBlocked() bool { + ret := _m.ctrl.Call(_m, "IsBlocked") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsBlocked indicates an expected call of IsBlocked +func (_mr *MockStreamFlowControllerMockRecorder) IsBlocked() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "IsBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsBlocked)) +} + +// SendWindowSize mocks base method +func (_m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "SendWindowSize") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// SendWindowSize indicates an expected call of SendWindowSize +func (_mr *MockStreamFlowControllerMockRecorder) SendWindowSize() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SendWindowSize", reflect.TypeOf((*MockStreamFlowController)(nil).SendWindowSize)) +} + +// UpdateHighestReceived mocks base method +func (_m *MockStreamFlowController) UpdateHighestReceived(_param0 protocol.ByteCount, _param1 bool) error { + ret := _m.ctrl.Call(_m, "UpdateHighestReceived", _param0, _param1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateHighestReceived indicates an expected call of UpdateHighestReceived +func (_mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) +} + +// UpdateSendWindow mocks base method +func (_m *MockStreamFlowController) UpdateSendWindow(_param0 protocol.ByteCount) { + _m.ctrl.Call(_m, "UpdateSendWindow", _param0) +} + +// UpdateSendWindow indicates an expected call of UpdateSendWindow +func (_mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) +} diff --git a/packet_packer_test.go b/packet_packer_test.go index fea46350..15455a78 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -6,6 +6,7 @@ import ( "math" "github.com/lucas-clemente/quic-go/ackhandler" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -60,9 +61,9 @@ var _ = Describe("Packet packer", func() { ) BeforeEach(func() { - cryptoStream = &stream{} + cryptoStream = &stream{flowController: flowcontrol.NewStreamFlowController(1, false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)} - streamsMap := newStreamsMap(nil, nil, protocol.PerspectiveServer) + streamsMap := newStreamsMap(nil, protocol.PerspectiveServer) streamsMap.streams[1] = cryptoStream streamsMap.openStreams = []protocol.StreamID{1} streamFramer = newStreamFramer(streamsMap, nil) diff --git a/session.go b/session.go index 1da6623a..23ec4cb5 100644 --- a/session.go +++ b/session.go @@ -67,7 +67,7 @@ type session struct { receivedPacketHandler ackhandler.ReceivedPacketHandler streamFramer *streamFramer - flowControlManager flowcontrol.FlowControlManager + connFlowController flowcontrol.ConnectionFlowController unpacker unpacker packer *packetPacker @@ -109,7 +109,8 @@ type session struct { sessionCreationTime time.Time lastNetworkActivityTime time.Time - remoteIdleTimeout time.Duration + + peerParams *handshake.TransportParameters timer *utils.Timer // keepAlivePingSent stores whether a Ping frame was sent to the peer or not @@ -251,13 +252,13 @@ func (s *session) setup( return nil, nil, err } - s.flowControlManager = flowcontrol.NewFlowControlManager( - protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), + s.connFlowController = flowcontrol.NewConnectionFlowController( + protocol.ReceiveConnectionFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), s.rttStats, ) - s.streamsMap = newStreamsMap(s.newStream, s.flowControlManager.RemoveStream, s.perspective) - s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) + s.streamsMap = newStreamsMap(s.newStream, s.perspective) + s.streamFramer = newStreamFramer(s.streamsMap, s.connFlowController) s.packer = newPacketPacker(s.connectionID, s.cryptoSetup, s.streamFramer, @@ -348,7 +349,7 @@ runLoop: s.sentPacketHandler.OnAlarm() } - if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.remoteIdleTimeout/2 { + if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.peerParams.IdleTimeout/2 { // send the PING frame since there is no activity in the session s.packer.QueueControlFrame(&wire.PingFrame{}) s.keepAlivePingSent = true @@ -389,7 +390,7 @@ func (s *session) Context() context.Context { func (s *session) maybeResetTimer() { var deadline time.Time if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent { - deadline = s.lastNetworkActivityTime.Add(s.remoteIdleTimeout / 2) + deadline = s.lastNetworkActivityTime.Add(s.peerParams.IdleTimeout / 2) } else { deadline = s.lastNetworkActivityTime.Add(s.config.IdleTimeout) } @@ -538,7 +539,7 @@ func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { func (s *session) handleWindowUpdateFrame(frame *wire.WindowUpdateFrame) error { if frame.StreamID == 0 { - s.flowControlManager.UpdateConnectionWindow(frame.ByteOffset) + s.connFlowController.UpdateSendWindow(frame.ByteOffset) return nil } @@ -549,8 +550,8 @@ func (s *session) handleWindowUpdateFrame(frame *wire.WindowUpdateFrame) error { if str == nil { return errWindowUpdateOnClosedStream } - _, err = s.flowControlManager.UpdateStreamWindow(frame.StreamID, frame.ByteOffset) - return err + str.UpdateSendWindow(frame.ByteOffset) + return nil } func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { @@ -561,9 +562,7 @@ func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { if str == nil { return errRstStreamOnInvalidStream } - - str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) - return s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) + return str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode), frame.ByteOffset) } func (s *session) handleAckFrame(frame *wire.AckFrame) error { @@ -627,12 +626,15 @@ func (s *session) handleCloseError(closeErr closeError) error { } func (s *session) processTransportParameters(params *handshake.TransportParameters) { - s.remoteIdleTimeout = params.IdleTimeout - s.flowControlManager.UpdateTransportParameters(params) + s.peerParams = params s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams) if params.OmitConnectionID { s.packer.SetOmitConnectionID() } + s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) + s.streamsMap.Range(func(str streamI) { + str.UpdateSendWindow(params.StreamFlowControlWindow) + }) } func (s *session) sendPacket() error { @@ -693,15 +695,10 @@ func (s *session) sendPacket() error { utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) // resend the frames that were in the packet for _, frame := range retransmitPacket.GetFramesForRetransmission() { + // TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window switch f := frame.(type) { case *wire.StreamFrame: s.streamFramer.AddFrameForRetransmission(f) - case *wire.WindowUpdateFrame: - // only retransmit WindowUpdates if the stream is not yet closed and the we haven't sent another WindowUpdate with a higher ByteOffset for the stream - currentOffset, err := s.flowControlManager.GetReceiveWindow(f.StreamID) - if err == nil && f.ByteOffset >= currentOffset { - s.packer.QueueControlFrame(f) - } default: s.packer.QueueControlFrame(frame) } @@ -813,14 +810,26 @@ func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.By s.scheduleSending() } -func (s *session) newStream(id protocol.StreamID) *stream { +func (s *session) newStream(id protocol.StreamID) streamI { // TODO: find a better solution for determining which streams contribute to connection level flow control - if id == 1 || id == 3 { - s.flowControlManager.NewStream(id, false) - } else { - s.flowControlManager.NewStream(id, true) + var contributesToConnection bool + if id != 1 && id != 3 { + contributesToConnection = true } - return newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager) + var initialSendWindow protocol.ByteCount + if s.peerParams != nil { + initialSendWindow = s.peerParams.StreamFlowControlWindow + } + flowController := flowcontrol.NewStreamFlowController( + id, + contributesToConnection, + s.connFlowController, + protocol.ReceiveStreamFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), + initialSendWindow, + s.rttStats, + ) + return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController) } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { @@ -862,10 +871,20 @@ func (s *session) tryDecryptingQueuedPackets() { } func (s *session) getWindowUpdateFrames() []*wire.WindowUpdateFrame { - updates := s.flowControlManager.GetWindowUpdates() - res := make([]*wire.WindowUpdateFrame, len(updates)) - for i, u := range updates { - res[i] = &wire.WindowUpdateFrame{StreamID: u.StreamID, ByteOffset: u.Offset} + var res []*wire.WindowUpdateFrame + s.streamsMap.Range(func(str streamI) { + if offset := str.GetWindowUpdate(); offset != 0 { + res = append(res, &wire.WindowUpdateFrame{ + StreamID: str.StreamID(), + ByteOffset: offset, + }) + } + }) + if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { + res = append(res, &wire.WindowUpdateFrame{ + StreamID: 0, + ByteOffset: offset, + }) } return res } diff --git a/session_test.go b/session_test.go index c9d6c1dc..fc7d5567 100644 --- a/session_test.go +++ b/session_test.go @@ -12,7 +12,6 @@ import ( "time" "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -256,357 +255,241 @@ var _ = Describe("Session", func() { }) }) - Context("when handling stream frames", func() { + Context("frame handling", func() { BeforeEach(func() { - sess.streamsMap.UpdateMaxStreamLimit(100) + sess.streamsMap.newStream = func(id protocol.StreamID) streamI { + str := mocks.NewMockStreamI(mockCtrl) + str.EXPECT().StreamID().Return(id).AnyTimes() + if id == 1 { + str.EXPECT().Finished().AnyTimes() + } + return str + } }) - It("makes new streams", func() { - sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xde, 0xca, 0xfb, 0xad}, + Context("when handling STREAM frames", func() { + BeforeEach(func() { + sess.streamsMap.UpdateMaxStreamLimit(100) + }) + + It("makes new streams", func() { + f := &wire.StreamFrame{ + StreamID: 5, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, + } + newStreamLambda := sess.streamsMap.newStream + sess.streamsMap.newStream = func(id protocol.StreamID) streamI { + str := newStreamLambda(id) + if id == 5 { + str.(*mocks.MockStreamI).EXPECT().AddStreamFrame(f) + } + return str + } + err := sess.handleStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.streamsMap.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Expect(str).ToNot(BeNil()) + }) + + It("handles existing streams", func() { + f1 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte{0xde, 0xca}, + } + f2 := &wire.StreamFrame{ + StreamID: 5, + Offset: 2, + Data: []byte{0xfb, 0xad}, + } + newStreamLambda := sess.streamsMap.newStream + sess.streamsMap.newStream = func(id protocol.StreamID) streamI { + str := newStreamLambda(id) + if id == 5 { + str.(*mocks.MockStreamI).EXPECT().AddStreamFrame(f1) + str.(*mocks.MockStreamI).EXPECT().AddStreamFrame(f2) + } + return str + } + sess.handleStreamFrame(f1) + numOpenStreams := len(sess.streamsMap.openStreams) + sess.handleStreamFrame(f2) + Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams)) + }) + + It("ignores STREAM frames for closed streams", func() { + sess.streamsMap.streams[5] = nil + str, err := sess.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) // make sure the stream is gone + err = sess.handleStreamFrame(&wire.StreamFrame{ + StreamID: 5, + Data: []byte("foobar"), + }) + Expect(err).ToNot(HaveOccurred()) }) - p := make([]byte, 4) - str, err := sess.streamsMap.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - _, err = str.Read(p) - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) }) - It("does not reject existing streams with even StreamIDs", func() { + Context("handling RST_STREAM frames", func() { + It("closes the streams for writing", func() { + str, err := sess.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + str.(*mocks.MockStreamI).EXPECT().RegisterRemoteError( + errors.New("RST_STREAM received with code 42"), + protocol.ByteCount(0x1337), + ) + err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: 5, + ErrorCode: 42, + ByteOffset: 0x1337, + }) + Expect(err).ToNot(HaveOccurred()) + }) + + It("queues a RST_STERAM frame", func() { + sess.queueResetStreamFrame(5, 0x1337) + Expect(sess.packer.controlFrames).To(HaveLen(1)) + Expect(sess.packer.controlFrames[0].(*wire.RstStreamFrame)).To(Equal(&wire.RstStreamFrame{ + StreamID: 5, + ByteOffset: 0x1337, + })) + }) + + It("returns errors", func() { + testErr := errors.New("flow control violation") + str, err := sess.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + str.(*mocks.MockStreamI).EXPECT().RegisterRemoteError(gomock.Any(), gomock.Any()).Return(testErr) + err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: 5, + ByteOffset: 0x1337, + }) + Expect(err).To(MatchError(testErr)) + }) + + It("ignores the error when the stream is not known", func() { + str, err := sess.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + str.(*mocks.MockStreamI).EXPECT().Finished().Return(true) + sess.streamsMap.DeleteClosedStreams() + str, err = sess.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + err = sess.handleFrames([]wire.Frame{&wire.RstStreamFrame{ + StreamID: 3, + ErrorCode: 42, + }}) + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("handling WINDOW_UPDATE frames", func() { + var connFC *mocks.MockConnectionFlowController + + BeforeEach(func() { + connFC = mocks.NewMockConnectionFlowController(mockCtrl) + sess.connFlowController = connFC + }) + + It("updates the flow control window of a stream", func() { + offset := protocol.ByteCount(0x1234) + str, err := sess.GetOrOpenStream(5) + str.(*mocks.MockStreamI).EXPECT().UpdateSendWindow(offset) + Expect(err).ToNot(HaveOccurred()) + err = sess.handleWindowUpdateFrame(&wire.WindowUpdateFrame{ + StreamID: 5, + ByteOffset: offset, + }) + Expect(err).ToNot(HaveOccurred()) + }) + + It("updates the flow control window of the connection", func() { + offset := protocol.ByteCount(0x800000) + connFC.EXPECT().UpdateSendWindow(offset) + err := sess.handleWindowUpdateFrame(&wire.WindowUpdateFrame{ + StreamID: 0, + ByteOffset: offset, + }) + Expect(err).ToNot(HaveOccurred()) + }) + + It("opens a new stream when receiving a WINDOW_UPDATE for an unknown stream", func() { + newStreamLambda := sess.streamsMap.newStream + sess.streamsMap.newStream = func(id protocol.StreamID) streamI { + str := newStreamLambda(id) + if id == 5 { + str.(*mocks.MockStreamI).EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337)) + } + return str + } + err := sess.handleWindowUpdateFrame(&wire.WindowUpdateFrame{ + StreamID: 5, + ByteOffset: 0x1337, + }) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.streamsMap.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(str).ToNot(BeNil()) + }) + + It("ignores WINDOW_UPDATEs for a closed stream", func() { + str, err := sess.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + str.(*mocks.MockStreamI).EXPECT().Finished().Return(true) + err = sess.streamsMap.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + str, err = sess.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + err = sess.handleFrames([]wire.Frame{&wire.WindowUpdateFrame{ + StreamID: 3, + ByteOffset: 1337, + }}) + Expect(err).NotTo(HaveOccurred()) + }) + }) + + It("handles PING frames", func() { + err := sess.handleFrames([]wire.Frame{&wire.PingFrame{}}) + Expect(err).NotTo(HaveOccurred()) + }) + + It("handles BLOCKED frames", func() { + err := sess.handleFrames([]wire.Frame{&wire.BlockedFrame{}}) + Expect(err).NotTo(HaveOccurred()) + }) + + It("errors on GOAWAY frames", func() { + err := sess.handleFrames([]wire.Frame{&wire.GoawayFrame{}}) + Expect(err).To(MatchError("unimplemented: handling GOAWAY frames")) + }) + + It("handles STOP_WAITING frames", func() { + err := sess.handleFrames([]wire.Frame{&wire.StopWaitingFrame{LeastUnacked: 10}}) + Expect(err).NotTo(HaveOccurred()) + }) + + It("handles CONNECTION_CLOSE frames", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err).To(MatchError("ProofInvalid: foobar")) + close(done) + }() _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - err = sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xde, 0xca, 0xfb, 0xad}, + sess.streamsMap.Range(func(s streamI) { + if s.StreamID() == 1 { // the crypto stream is created by the session setup and is not a mock stream + return + } + s.(*mocks.MockStreamI).EXPECT().Cancel(gomock.Any()) }) - Expect(err).ToNot(HaveOccurred()) - }) - - It("handles existing streams", func() { - sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xde, 0xca}, - }) - numOpenStreams := len(sess.streamsMap.openStreams) - sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - Offset: 2, - Data: []byte{0xfb, 0xad}, - }) - Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams)) - p := make([]byte, 4) - str, _ := sess.streamsMap.GetOrOpenStream(5) - Expect(str).ToNot(BeNil()) - _, err := str.Read(p) - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) - }) - - It("cancels streams with error", func() { - testErr := errors.New("test") - sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xde, 0xca, 0xfb, 0xad}, - }) - str, err := sess.streamsMap.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - p := make([]byte, 4) - _, err = str.Read(p) - Expect(err).ToNot(HaveOccurred()) - sess.handleCloseError(closeError{err: testErr, remote: true}) - _, err = str.Read(p) - Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error()))) - }) - - It("cancels empty streams with error", func() { - testErr := errors.New("test") - sess.GetOrOpenStream(5) - str, err := sess.streamsMap.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - sess.handleCloseError(closeError{err: testErr, remote: true}) - _, err = str.Read([]byte{0}) - Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error()))) - }) - - It("informs the FlowControlManager about new streams", func() { - // since the stream doesn't yet exist, this will throw an error - err := sess.flowControlManager.UpdateHighestReceived(5, 1000) - Expect(err).To(HaveOccurred()) - sess.GetOrOpenStream(5) - err = sess.flowControlManager.UpdateHighestReceived(5, 2000) - Expect(err).ToNot(HaveOccurred()) - }) - - It("ignores STREAM frames for closed streams (client-side)", func() { - sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - FinBit: true, - }) - str, _ := sess.streamsMap.GetOrOpenStream(5) - Expect(str).ToNot(BeNil()) - _, err := str.Read([]byte{0}) - Expect(err).To(MatchError(io.EOF)) - str.Close() - str.sentFin() - err = sess.streamsMap.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - str, _ = sess.streamsMap.GetOrOpenStream(5) - Expect(str).To(BeNil()) // make sure the stream is gone - err = sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - }) - Expect(err).ToNot(HaveOccurred()) - }) - - It("ignores STREAM frames for closed streams (server-side)", func() { - ostr, err := sess.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(ostr.StreamID()).To(Equal(protocol.StreamID(2))) - err = sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 2, - FinBit: true, - }) - Expect(err).ToNot(HaveOccurred()) - str, _ := sess.streamsMap.GetOrOpenStream(2) - Expect(str).ToNot(BeNil()) - _, err = str.Read([]byte{0}) - Expect(err).To(MatchError(io.EOF)) - str.Close() - str.sentFin() - err = sess.streamsMap.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - str, _ = sess.streamsMap.GetOrOpenStream(2) - Expect(str).To(BeNil()) // make sure the stream is gone - err = sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 2, - FinBit: true, - }) - Expect(err).ToNot(HaveOccurred()) - }) - }) - - Context("handling RST_STREAM frames", func() { - It("closes the streams for writing", func() { - s, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, - ErrorCode: 42, - }) - Expect(err).ToNot(HaveOccurred()) - n, err := s.Write([]byte{0}) - Expect(n).To(BeZero()) - Expect(err).To(MatchError("RST_STREAM received with code 42")) - }) - - It("doesn't close the stream for reading", func() { - s, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - }) - err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, - ErrorCode: 42, - ByteOffset: 6, - }) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 3) - n, err := s.Read(b) - Expect(n).To(Equal(3)) - Expect(err).ToNot(HaveOccurred()) - }) - - It("queues a RST_STERAM frame with the correct offset", func() { - str, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - str.(*stream).writeOffset = 0x1337 - err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(sess.packer.controlFrames).To(HaveLen(1)) - Expect(sess.packer.controlFrames[0].(*wire.RstStreamFrame)).To(Equal(&wire.RstStreamFrame{ - StreamID: 5, - ByteOffset: 0x1337, - })) - Expect(str.(*stream).finished()).To(BeTrue()) - }) - - It("doesn't queue a RST_STREAM for a stream that it already sent a FIN on", func() { - str, err := sess.GetOrOpenStream(5) + err = sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}) Expect(err).NotTo(HaveOccurred()) - str.(*stream).sentFin() - str.Close() - err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(sess.packer.controlFrames).To(BeEmpty()) - Expect(str.(*stream).finished()).To(BeTrue()) + Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(done).Should(BeClosed()) }) - - It("passes the byte offset to the flow controller", func() { - sess.streamsMap.GetOrOpenStream(5) - fcm := mocks.NewMockFlowControlManager(mockCtrl) - sess.flowControlManager = fcm - fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337)) - err := sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, - ByteOffset: 0x1337, - }) - Expect(err).ToNot(HaveOccurred()) - }) - - It("returns errors from the flow controller", func() { - testErr := errors.New("flow control violation") - sess.streamsMap.GetOrOpenStream(5) - fcm := mocks.NewMockFlowControlManager(mockCtrl) - sess.flowControlManager = fcm - fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337)).Return(testErr) - err := sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, - ByteOffset: 0x1337, - }) - Expect(err).To(MatchError(testErr)) - }) - - It("ignores the error when the stream is not known", func() { - err := sess.handleFrames([]wire.Frame{&wire.RstStreamFrame{ - StreamID: 5, - ErrorCode: 42, - }}) - Expect(err).NotTo(HaveOccurred()) - }) - - It("queues a RST_STREAM when a stream gets reset locally", func() { - testErr := errors.New("testErr") - str, err := sess.streamsMap.GetOrOpenStream(5) - str.writeOffset = 0x1337 - Expect(err).ToNot(HaveOccurred()) - str.Reset(testErr) - Expect(sess.packer.controlFrames).To(HaveLen(1)) - Expect(sess.packer.controlFrames[0]).To(Equal(&wire.RstStreamFrame{ - StreamID: 5, - ByteOffset: 0x1337, - })) - Expect(str.finished()).To(BeFalse()) - }) - - It("doesn't queue another RST_STREAM, when it receives an RST_STREAM as a response for the first", func() { - testErr := errors.New("testErr") - str, err := sess.streamsMap.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - str.Reset(testErr) - Expect(sess.packer.controlFrames).To(HaveLen(1)) - err = sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: 5, - ByteOffset: 0x42, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(sess.packer.controlFrames).To(HaveLen(1)) - }) - }) - - Context("handling WINDOW_UPDATE frames", func() { - var fcm *mocks.MockFlowControlManager - - BeforeEach(func() { - fcm = mocks.NewMockFlowControlManager(mockCtrl) - sess.flowControlManager = fcm - fcm.EXPECT().NewStream(gomock.Any(), gomock.Any()).AnyTimes() - }) - - It("updates the Flow Control Window of a stream", func() { - offset := protocol.ByteCount(0x1234) - fcm.EXPECT().UpdateStreamWindow(protocol.StreamID(5), offset) - _, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - err = sess.handleWindowUpdateFrame(&wire.WindowUpdateFrame{ - StreamID: 5, - ByteOffset: offset, - }) - Expect(err).ToNot(HaveOccurred()) - }) - - It("updates the Flow Control Window of the connection", func() { - offset := protocol.ByteCount(0x800000) - fcm.EXPECT().UpdateConnectionWindow(offset) - err := sess.handleWindowUpdateFrame(&wire.WindowUpdateFrame{ - StreamID: 0, - ByteOffset: offset, - }) - Expect(err).ToNot(HaveOccurred()) - }) - - It("opens a new stream when receiving a WINDOW_UPDATE for an unknown stream", func() { - offset := protocol.ByteCount(0x1337) - fcm.EXPECT().UpdateStreamWindow(protocol.StreamID(5), offset) - err := sess.handleWindowUpdateFrame(&wire.WindowUpdateFrame{ - StreamID: 5, - ByteOffset: offset, - }) - Expect(err).ToNot(HaveOccurred()) - str, err := sess.streamsMap.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - Expect(str).ToNot(BeNil()) - }) - - It("ignores WINDOW_UPDATEs for a closed stream", func() { - str, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - str.Close() - str.(*stream).Cancel(nil) - Expect(str.(*stream).finished()).To(BeTrue()) - err = sess.streamsMap.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - str, err = sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeNil()) - err = sess.handleFrames([]wire.Frame{&wire.WindowUpdateFrame{ - StreamID: 5, - ByteOffset: 1337, - }}) - Expect(err).NotTo(HaveOccurred()) - }) - }) - - It("handles PING frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.PingFrame{}}) - Expect(err).NotTo(HaveOccurred()) - }) - - It("handles BLOCKED frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.BlockedFrame{}}) - Expect(err).NotTo(HaveOccurred()) - }) - - It("errors on GOAWAY frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.GoawayFrame{}}) - Expect(err).To(MatchError("unimplemented: handling GOAWAY frames")) - }) - - It("handles STOP_WAITING frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.StopWaitingFrame{LeastUnacked: 10}}) - Expect(err).NotTo(HaveOccurred()) - }) - - It("handles CONNECTION_CLOSE frames", func(done Done) { - go sess.run() - str, _ := sess.GetOrOpenStream(5) - err := sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}}) - Expect(err).NotTo(HaveOccurred()) - Eventually(sess.Context().Done()).Should(BeClosed()) - _, err = str.Read([]byte{0}) - Expect(err).To(MatchError(qerr.Error(42, "foobar"))) - close(done) }) It("tells its versions", func() { @@ -615,22 +498,24 @@ var _ = Describe("Session", func() { }) Context("waiting until the handshake completes", func() { - It("waits until the handshake is complete", func(done Done) { - go sess.run() + It("waits until the handshake is complete", func() { + go func() { + defer GinkgoRecover() + sess.run() + }() - var waitReturned bool + done := make(chan struct{}) go func() { defer GinkgoRecover() err := sess.WaitUntilHandshakeComplete() Expect(err).ToNot(HaveOccurred()) - waitReturned = true + close(done) }() aeadChanged <- protocol.EncryptionForwardSecure - Consistently(func() bool { return waitReturned }).Should(BeFalse()) + Consistently(done).ShouldNot(BeClosed()) close(aeadChanged) - Eventually(func() bool { return waitReturned }).Should(BeTrue()) + Eventually(done).Should(BeClosed()) Expect(sess.Close(nil)).To(Succeed()) - close(done) }) It("errors if the handshake fails", func(done Done) { @@ -668,6 +553,11 @@ var _ = Describe("Session", func() { }) Context("accepting streams", func() { + BeforeEach(func() { + // don't use the mock here + sess.streamsMap.newStream = sess.newStream + }) + It("waits for new streams", func() { strChan := make(chan Stream) // accept two streams @@ -680,10 +570,8 @@ var _ = Describe("Session", func() { } }() Consistently(strChan).ShouldNot(Receive()) - err := sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - }) + // this could happen e.g. by receiving a STREAM frame + _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) var str Stream Eventually(strChan).Should(Receive(&str)) @@ -924,18 +812,26 @@ var _ = Describe("Session", func() { }) It("sends two WindowUpdate frames", func() { - _, err := sess.GetOrOpenStream(5) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1000)) + mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0)).Times(2) + str, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - sess.flowControlManager.AddBytesRead(5, protocol.ReceiveStreamFlowControlWindow) + str.(*stream).flowController = mockFC err = sess.sendPacket() Expect(err).NotTo(HaveOccurred()) err = sess.sendPacket() Expect(err).NotTo(HaveOccurred()) err = sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + buf := &bytes.Buffer{} + (&wire.WindowUpdateFrame{ + StreamID: 5, + ByteOffset: 0x1000, + }).Write(buf, protocol.VersionWhatever) Expect(mconn.written).To(HaveLen(2)) - Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0})))) - Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0})))) + Expect(mconn.written).To(Receive(ContainSubstring(string(buf.Bytes())))) + Expect(mconn.written).To(Receive(ContainSubstring(string(buf.Bytes())))) }) It("sends public reset", func() { @@ -1093,69 +989,6 @@ var _ = Describe("Session", func() { _, ok = sentPackets[1].Frames[0].(*wire.StopWaitingFrame) Expect(ok).To(BeTrue()) }) - - It("retransmits a WindowUpdate if it hasn't already sent a WindowUpdate with a higher ByteOffset", func() { - _, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - fcm := mocks.NewMockFlowControlManager(mockCtrl) - sess.flowControlManager = fcm - fcm.EXPECT().GetWindowUpdates() - fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x1000), nil) - wuf := &wire.WindowUpdateFrame{ - StreamID: 5, - ByteOffset: 0x1000, - } - sph.retransmissionQueue = []*ackhandler.Packet{{ - Frames: []wire.Frame{wuf}, - EncryptionLevel: protocol.EncryptionForwardSecure, - }} - err = sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(sph.sentPackets).To(HaveLen(1)) - Expect(sph.sentPackets[0].Frames).To(ContainElement(wuf)) - }) - - It("doesn't retransmit WindowUpdates if it already sent a WindowUpdate with a higher ByteOffset", func() { - _, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - fcm := mocks.NewMockFlowControlManager(mockCtrl) - sess.flowControlManager = fcm - fcm.EXPECT().GetWindowUpdates() - fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x2000), nil) - sph.retransmissionQueue = []*ackhandler.Packet{{ - Frames: []wire.Frame{&wire.WindowUpdateFrame{ - StreamID: 5, - ByteOffset: 0x1000, - }}, - EncryptionLevel: protocol.EncryptionForwardSecure, - }} - err = sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(sph.sentPackets).To(BeEmpty()) - }) - - It("doesn't retransmit WindowUpdates for closed streams", func() { - str, err := sess.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - // close the stream - str.(*stream).sentFin() - str.Close() - str.(*stream).RegisterRemoteError(nil) - err = sess.streamsMap.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - _, err = sess.flowControlManager.SendWindowSize(5) - Expect(err).To(MatchError("Error accessing the flowController map")) - sph.retransmissionQueue = []*ackhandler.Packet{{ - Frames: []wire.Frame{&wire.WindowUpdateFrame{ - StreamID: 5, - ByteOffset: 0x1337, - }}, - EncryptionLevel: protocol.EncryptionForwardSecure, - }} - err = sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(sph.sentPackets).To(BeEmpty()) - }) }) }) @@ -1204,7 +1037,7 @@ var _ = Describe("Session", func() { close(done) }() Eventually(sess.sendingScheduled).Should(Receive()) - s.(*stream).getDataForWriting(1000) // unblock + s.(*stream).GetDataForWriting(1000) // unblock }) It("sets the timer to the ack timer", func() { @@ -1435,16 +1268,17 @@ var _ = Describe("Session", func() { _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) go sess.run() - paramsChan <- handshake.TransportParameters{ + params := handshake.TransportParameters{ MaxStreams: 123, IdleTimeout: 90 * time.Second, StreamFlowControlWindow: 0x5000, ConnectionFlowControlWindow: 0x5000, OmitConnectionID: true, } - Eventually(func() time.Duration { return sess.remoteIdleTimeout }).Should(Equal(90 * time.Second)) + paramsChan <- params + Eventually(func() *handshake.TransportParameters { return sess.peerParams }).Should(Equal(¶ms)) Eventually(func() uint32 { return sess.streamsMap.maxOutgoingStreams }).Should(Equal(uint32(123))) - Eventually(func() (protocol.ByteCount, error) { return sess.flowControlManager.SendWindowSize(5) }).Should(Equal(protocol.ByteCount(0x5000))) + // Eventually(func() (protocol.ByteCount, error) { return sess.flowControlManager.SendWindowSize(5) }).Should(Equal(protocol.ByteCount(0x5000))) Eventually(func() bool { return sess.packer.omitConnectionID }).Should(BeTrue()) Expect(sess.Close(nil)).To(Succeed()) }) @@ -1455,7 +1289,7 @@ var _ = Describe("Session", func() { remoteIdleTimeout := 20 * time.Second BeforeEach(func() { - sess.remoteIdleTimeout = remoteIdleTimeout + sess.peerParams = &handshake.TransportParameters{IdleTimeout: remoteIdleTimeout} }) It("sends a PING", func() { @@ -1565,7 +1399,7 @@ var _ = Describe("Session", func() { Expect(err).ToNot(HaveOccurred()) str.Close() str.(*stream).Cancel(nil) - Expect(str.(*stream).finished()).To(BeTrue()) + Expect(str.(*stream).Finished()).To(BeTrue()) err = sess.streamsMap.DeleteClosedStreams() Expect(err).ToNot(HaveOccurred()) Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil()) @@ -1601,7 +1435,7 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) err = s.Close() Expect(err).NotTo(HaveOccurred()) - s.(*stream).sentFin() + s.(*stream).SentFin() s.(*stream).CloseRemote(0) _, err = s.Read([]byte("a")) Expect(err).To(MatchError(io.EOF)) @@ -1627,29 +1461,29 @@ var _ = Describe("Session", func() { }) }) - Context("window updates", func() { - It("gets stream level window updates", func() { - _, err := sess.GetOrOpenStream(3) - Expect(err).ToNot(HaveOccurred()) - err = sess.flowControlManager.AddBytesRead(3, protocol.ReceiveStreamFlowControlWindow) - Expect(err).NotTo(HaveOccurred()) - frames := sess.getWindowUpdateFrames() - Expect(frames).To(HaveLen(1)) - Expect(frames[0].StreamID).To(Equal(protocol.StreamID(3))) - Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow * 2)) - }) + // Context("window updates", func() { + // It("gets stream level window updates", func() { + // _, err := sess.GetOrOpenStream(3) + // Expect(err).ToNot(HaveOccurred()) + // err = sess.flowControlManager.AddBytesRead(3, protocol.ReceiveStreamFlowControlWindow) + // Expect(err).NotTo(HaveOccurred()) + // frames := sess.getWindowUpdateFrames() + // Expect(frames).To(HaveLen(1)) + // Expect(frames[0].StreamID).To(Equal(protocol.StreamID(3))) + // Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow * 2)) + // }) - It("gets connection level window updates", func() { - _, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - err = sess.flowControlManager.AddBytesRead(5, protocol.ReceiveConnectionFlowControlWindow) - Expect(err).NotTo(HaveOccurred()) - frames := sess.getWindowUpdateFrames() - Expect(frames).To(HaveLen(1)) - Expect(frames[0].StreamID).To(Equal(protocol.StreamID(0))) - Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow * 2)) - }) - }) + // It("gets connection level window updates", func() { + // _, err := sess.GetOrOpenStream(5) + // Expect(err).NotTo(HaveOccurred()) + // err = sess.flowControlManager.AddBytesRead(5, protocol.ReceiveConnectionFlowControlWindow) + // Expect(err).NotTo(HaveOccurred()) + // frames := sess.getWindowUpdateFrames() + // Expect(frames).To(HaveLen(1)) + // Expect(frames[0].StreamID).To(Equal(protocol.StreamID(0))) + // Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow * 2)) + // }) + // }) It("returns the local address", func() { addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} diff --git a/stream.go b/stream.go index 41c8e105..6af8526f 100644 --- a/stream.go +++ b/stream.go @@ -14,6 +14,24 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) +type streamI interface { + Stream + + AddStreamFrame(*wire.StreamFrame) error + RegisterRemoteError(error, protocol.ByteCount) error + LenOfDataForWriting() protocol.ByteCount + GetDataForWriting(maxBytes protocol.ByteCount) []byte + GetWriteOffset() protocol.ByteCount + Finished() bool + Cancel(error) + ShouldSendFin() bool + SentFin() + // methods needed for flow control + GetWindowUpdate() protocol.ByteCount + UpdateSendWindow(protocol.ByteCount) + IsFlowControlBlocked() bool +} + // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. @@ -56,10 +74,11 @@ type stream struct { writeChan chan struct{} writeDeadline time.Time - flowControlManager flowcontrol.FlowControlManager + flowController flowcontrol.StreamFlowController } var _ Stream = &stream{} +var _ streamI = &stream{} type deadlineError struct{} @@ -73,15 +92,16 @@ var errDeadline net.Error = &deadlineError{} func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), - flowControlManager flowcontrol.FlowControlManager) *stream { + flowController flowcontrol.StreamFlowController, +) *stream { s := &stream{ - onData: onData, - onReset: onReset, - streamID: StreamID, - flowControlManager: flowControlManager, - frameQueue: newStreamFrameSorter(), - readChan: make(chan struct{}, 1), - writeChan: make(chan struct{}, 1), + onData: onData, + onReset: onReset, + streamID: StreamID, + flowController: flowController, + frameQueue: newStreamFrameSorter(), + readChan: make(chan struct{}, 1), + writeChan: make(chan struct{}, 1), } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) return s @@ -162,7 +182,7 @@ func (s *stream) Read(p []byte) (int, error) { // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream if !s.resetRemotely.Get() { - s.flowControlManager.AddBytesRead(s.streamID, protocol.ByteCount(m)) + s.flowController.AddBytesRead(protocol.ByteCount(m)) } s.onData() // so that a possible WINDOW_UPDATE is sent @@ -231,7 +251,11 @@ func (s *stream) Write(p []byte) (int, error) { return len(p), nil } -func (s *stream) lenOfDataForWriting() protocol.ByteCount { +func (s *stream) GetWriteOffset() protocol.ByteCount { + return s.writeOffset +} + +func (s *stream) LenOfDataForWriting() protocol.ByteCount { s.mutex.Lock() var l protocol.ByteCount if s.err == nil { @@ -241,7 +265,7 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount { return l } -func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { +func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { s.mutex.Lock() defer s.mutex.Unlock() @@ -249,6 +273,14 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { return nil } + // TODO(#657): Flow control for the crypto stream + if s.streamID != 1 { + maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) + } + if maxBytes == 0 { + return nil + } + var ret []byte if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { ret = s.dataForWriting[:maxBytes] @@ -259,6 +291,7 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { s.signalWrite() } s.writeOffset += protocol.ByteCount(len(ret)) + s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) return ret } @@ -277,29 +310,27 @@ func (s *stream) shouldSendReset() bool { return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() } -func (s *stream) shouldSendFin() bool { +func (s *stream) ShouldSendFin() bool { s.mutex.Lock() res := s.finishedWriting.Get() && !s.finSent.Get() && s.err == nil && s.dataForWriting == nil s.mutex.Unlock() return res } -func (s *stream) sentFin() { +func (s *stream) SentFin() { s.finSent.Set(true) } // AddStreamFrame adds a new stream frame func (s *stream) AddStreamFrame(frame *wire.StreamFrame) error { maxOffset := frame.Offset + frame.DataLen() - err := s.flowControlManager.UpdateHighestReceived(s.streamID, maxOffset) - if err != nil { + if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil { return err } s.mutex.Lock() defer s.mutex.Unlock() - err = s.frameQueue.Push(frame) - if err != nil && err != errDuplicateStreamData { + if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData { return err } s.signalRead() @@ -393,9 +424,9 @@ func (s *stream) Reset(err error) { } // resets the stream remotely -func (s *stream) RegisterRemoteError(err error) { +func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error { if s.resetRemotely.Get() { - return + return nil } s.mutex.Lock() s.resetRemotely.Set(true) @@ -405,18 +436,22 @@ func (s *stream) RegisterRemoteError(err error) { s.err = err s.signalWrite() } + if err := s.flowController.UpdateHighestReceived(offset, true); err != nil { + return err + } if s.shouldSendReset() { s.onReset(s.streamID, s.writeOffset) s.rstSent.Set(true) } s.mutex.Unlock() + return nil } func (s *stream) finishedWriteAndSentFin() bool { return s.finishedWriting.Get() && s.finSent.Get() } -func (s *stream) finished() bool { +func (s *stream) Finished() bool { return s.cancelled.Get() || (s.finishedReading.Get() && s.finishedWriteAndSentFin()) || (s.resetRemotely.Get() && s.rstSent.Get()) || @@ -431,3 +466,15 @@ func (s *stream) Context() context.Context { func (s *stream) StreamID() protocol.StreamID { return s.streamID } + +func (s *stream) UpdateSendWindow(n protocol.ByteCount) { + s.flowController.UpdateSendWindow(n) +} + +func (s *stream) IsFlowControlBlocked() bool { + return s.flowController.IsBlocked() +} + +func (s *stream) GetWindowUpdate() protocol.ByteCount { + return s.flowController.GetWindowUpdate() +} diff --git a/stream_framer.go b/stream_framer.go index d5916c21..0e6137e4 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -3,23 +3,22 @@ package quic import ( "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) type streamFramer struct { streamsMap *streamsMap - flowControlManager flowcontrol.FlowControlManager + connFlowController flowcontrol.ConnectionFlowController retransmissionQueue []*wire.StreamFrame blockedFrameQueue []*wire.BlockedFrame } -func newStreamFramer(streamsMap *streamsMap, flowControlManager flowcontrol.FlowControlManager) *streamFramer { +func newStreamFramer(streamsMap *streamsMap, cfc flowcontrol.ConnectionFlowController) *streamFramer { return &streamFramer{ streamsMap: streamsMap, - flowControlManager: flowControlManager, + connFlowController: cfc, } } @@ -46,13 +45,11 @@ func (f *streamFramer) HasFramesForRetransmission() bool { } func (f *streamFramer) HasCryptoStreamFrame() bool { - // TODO(#657): Flow control cs, _ := f.streamsMap.GetOrOpenStream(1) - return cs.lenOfDataForWriting() > 0 + return cs.LenOfDataForWriting() > 0 } // TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. -// TODO(#657): Flow control func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { if !f.HasCryptoStreamFrame() { return nil @@ -60,10 +57,10 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str cs, _ := f.streamsMap.GetOrOpenStream(1) frame := &wire.StreamFrame{ StreamID: 1, - Offset: cs.writeOffset, + Offset: cs.GetWriteOffset(), } frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error - frame.Data = cs.getDataForWriting(maxLen - frameHeaderBytes) + frame.Data = cs.GetDataForWriting(maxLen - frameHeaderBytes) return frame } @@ -97,60 +94,45 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] frame := &wire.StreamFrame{DataLenPresent: true} var currentLen protocol.ByteCount - fn := func(s *stream) (bool, error) { - if s == nil || s.streamID == 1 /* crypto stream is handled separately */ { + fn := func(s streamI) (bool, error) { + if s == nil || s.StreamID() == 1 /* crypto stream is handled separately */ { return true, nil } - frame.StreamID = s.streamID + frame.StreamID = s.StreamID() + frame.Offset = s.GetWriteOffset() // not perfect, but thread-safe since writeOffset is only written when getting data - frame.Offset = s.writeOffset frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error if currentLen+frameHeaderBytes > maxBytes { return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here } maxLen := maxBytes - currentLen - frameHeaderBytes - var sendWindowSize protocol.ByteCount - lenStreamData := s.lenOfDataForWriting() - if lenStreamData != 0 { - sendWindowSize, _ = f.flowControlManager.SendWindowSize(s.streamID) - maxLen = utils.MinByteCount(maxLen, sendWindowSize) - } - - if maxLen == 0 { - return true, nil - } - var data []byte - if lenStreamData != 0 { - // Only getDataForWriting() if we didn't have data earlier, so that we - // don't send without FC approval (if a Write() raced). - data = s.getDataForWriting(maxLen) + if s.LenOfDataForWriting() > 0 { + data = s.GetDataForWriting(maxLen) } // This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests. - shouldSendFin := s.shouldSendFin() + shouldSendFin := s.ShouldSendFin() if data == nil && !shouldSendFin { return true, nil } if shouldSendFin { frame.FinBit = true - s.sentFin() + s.SentFin() } frame.Data = data - f.flowControlManager.AddBytesSent(s.streamID, protocol.ByteCount(len(data))) // Finally, check if we are now FC blocked and should queue a BLOCKED frame - if f.flowControlManager.RemainingConnectionWindowSize() == 0 { - // We are now connection-level FC blocked - f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{StreamID: 0}) - } else if !frame.FinBit && sendWindowSize-frame.DataLen() == 0 { - // We are now stream-level FC blocked + if !frame.FinBit && s.IsFlowControlBlocked() { f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{StreamID: s.StreamID()}) } + if f.connFlowController.IsBlocked() { + f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{StreamID: 0}) + } res = append(res, frame) currentLen += frameHeaderBytes + frame.DataLen() diff --git a/stream_framer_test.go b/stream_framer_test.go index 1036faa7..e94f592c 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -3,7 +3,9 @@ package quic import ( "bytes" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/mocks" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -20,8 +22,8 @@ var _ = Describe("Stream Framer", func() { retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame framer *streamFramer streamsMap *streamsMap - stream1, stream2 *stream - mockFcm *mocks.MockFlowControlManager + stream1, stream2 *mocks.MockStreamI + connFC *mocks.MockConnectionFlowController ) BeforeEach(func() { @@ -34,17 +36,26 @@ var _ = Describe("Stream Framer", func() { Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } - stream1 = &stream{streamID: id1} - stream2 = &stream{streamID: id2} + stream1 = mocks.NewMockStreamI(mockCtrl) + stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes() + stream2 = mocks.NewMockStreamI(mockCtrl) + stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() - streamsMap = newStreamsMap(nil, nil, protocol.PerspectiveServer) + streamsMap = newStreamsMap(nil, protocol.PerspectiveServer) streamsMap.putStream(stream1) streamsMap.putStream(stream2) - mockFcm = mocks.NewMockFlowControlManager(mockCtrl) - framer = newStreamFramer(streamsMap, mockFcm) + connFC = mocks.NewMockConnectionFlowController(mockCtrl) + framer = newStreamFramer(streamsMap, connFC) }) + setNoData := func(str *mocks.MockStreamI) { + str.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(0)).AnyTimes() + str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil).AnyTimes() + str.EXPECT().ShouldSendFin().Return(false).AnyTimes() + str.EXPECT().GetWriteOffset().AnyTimes() + } + It("says if it has retransmissions", func() { Expect(framer.HasFramesForRetransmission()).To(BeFalse()) framer.AddFrameForRetransmission(retransmittedFrame1) @@ -52,6 +63,8 @@ var _ = Describe("Stream Framer", func() { }) It("sets the DataLenPresent for dequeued retransmitted frames", func() { + setNoData(stream1) + setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) fs := framer.PopStreamFrames(protocol.MaxByteCount) Expect(fs).To(HaveLen(1)) @@ -59,21 +72,35 @@ var _ = Describe("Stream Framer", func() { }) It("sets the DataLenPresent for dequeued normal frames", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foobar") + connFC.EXPECT().IsBlocked() + setNoData(stream2) + stream1.EXPECT().GetWriteOffset() + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(8)) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().IsFlowControlBlocked() + stream1.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(protocol.MaxByteCount) Expect(fs).To(HaveLen(1)) Expect(fs[0].DataLenPresent).To(BeTrue()) }) Context("Popping", func() { + BeforeEach(func() { + // nothing is blocked here + connFC.EXPECT().IsBlocked().AnyTimes() + stream1.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes() + stream2.EXPECT().IsFlowControlBlocked().Return(false).AnyTimes() + }) + It("returns nil when popping an empty framer", func() { + setNoData(stream1) + setNoData(stream2) Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) }) It("pops frames for retransmission", func() { + setNoData(stream1) + setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame2) fs := framer.PopStreamFrames(1000) @@ -84,75 +111,71 @@ var _ = Describe("Stream Framer", func() { }) It("returns normal frames", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foobar") + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().GetWriteOffset() + stream1.EXPECT().ShouldSendFin() + setNoData(stream2) fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.streamID)) + Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) Expect(fs[0].Data).To(Equal([]byte("foobar"))) - Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) - }) - - It("returns multiple normal frames", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - mockFcm.EXPECT().SendWindowSize(id2).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id2, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foobar") - stream2.dataForWriting = []byte("foobaz") - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(2)) - // Swap if we dequeued in other order - if fs[0].StreamID != stream1.streamID { - fs[0], fs[1] = fs[1], fs[0] - } - Expect(fs[0].StreamID).To(Equal(stream1.streamID)) - Expect(fs[0].Data).To(Equal([]byte("foobar"))) - Expect(fs[1].StreamID).To(Equal(stream2.streamID)) - Expect(fs[1].Data).To(Equal([]byte("foobaz"))) - Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) - }) - - It("returns retransmission frames before normal frames", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - framer.AddFrameForRetransmission(retransmittedFrame1) - stream1.dataForWriting = []byte("foobar") - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(2)) - Expect(fs[0]).To(Equal(retransmittedFrame1)) - Expect(fs[1].StreamID).To(Equal(stream1.streamID)) - Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) - }) - - It("does not pop empty frames", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - stream1.dataForWriting = []byte("foobar") - fs := framer.PopStreamFrames(4) - Expect(fs).To(HaveLen(0)) - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(1)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - fs = framer.PopStreamFrames(5) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].Data).ToNot(BeEmpty()) Expect(fs[0].FinBit).To(BeFalse()) }) + It("returns multiple normal frames", func() { + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().GetWriteOffset() + stream1.EXPECT().ShouldSendFin() + stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz")) + stream2.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream2.EXPECT().GetWriteOffset() + stream2.EXPECT().ShouldSendFin() + fs := framer.PopStreamFrames(1000) + Expect(fs).To(HaveLen(2)) + // Swap if we dequeued in other order + if fs[0].StreamID != stream1.StreamID() { + fs[0], fs[1] = fs[1], fs[0] + } + Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) + Expect(fs[0].Data).To(Equal([]byte("foobar"))) + Expect(fs[1].StreamID).To(Equal(stream2.StreamID())) + Expect(fs[1].Data).To(Equal([]byte("foobaz"))) + }) + + It("returns retransmission frames before normal frames", func() { + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().GetWriteOffset() + stream1.EXPECT().ShouldSendFin() + setNoData(stream2) + framer.AddFrameForRetransmission(retransmittedFrame1) + fs := framer.PopStreamFrames(1000) + Expect(fs).To(HaveLen(2)) + Expect(fs[0]).To(Equal(retransmittedFrame1)) + Expect(fs[1].StreamID).To(Equal(stream1.StreamID())) + }) + + It("does not pop empty frames", func() { + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(0)) + stream1.EXPECT().ShouldSendFin() + stream1.EXPECT().GetWriteOffset() + setNoData(stream2) + fs := framer.PopStreamFrames(5) + Expect(fs).To(BeEmpty()) + }) + It("uses the round-robin scheduling", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - mockFcm.EXPECT().SendWindowSize(id2).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id2, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = bytes.Repeat([]byte("f"), 100) - stream2.dataForWriting = bytes.Repeat([]byte("e"), 100) + streamFrameHeaderLen := protocol.ByteCount(4) + stream1.EXPECT().GetDataForWriting(10 - streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(10-streamFrameHeaderLen))) + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(100)) + stream1.EXPECT().GetWriteOffset() + stream1.EXPECT().ShouldSendFin() + stream2.EXPECT().GetDataForWriting(protocol.ByteCount(10 - streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(10-streamFrameHeaderLen))) + stream2.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(100)) + stream2.EXPECT().GetWriteOffset() + stream2.EXPECT().ShouldSendFin() fs := framer.PopStreamFrames(10) Expect(fs).To(HaveLen(1)) // it doesn't matter here if this data is from stream1 or from stream2... @@ -198,6 +221,8 @@ var _ = Describe("Stream Framer", func() { }) It("splits a frame", func() { + setNoData(stream1) + setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame2) origlen := retransmittedFrame2.DataLen() fs := framer.PopStreamFrames(6) @@ -239,6 +264,8 @@ var _ = Describe("Stream Framer", func() { }) It("only removes a frame from the framer after returning all split parts", func() { + setNoData(stream1) + setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame2) fs := framer.PopStreamFrames(6) Expect(fs).To(HaveLen(1)) @@ -247,149 +274,56 @@ var _ = Describe("Stream Framer", func() { Expect(fs).To(HaveLen(1)) Expect(framer.retransmissionQueue).To(BeEmpty()) }) - - It("gets the whole data of a frame if it was split", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(3)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(3)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - origdata := []byte("foobar") - stream1.dataForWriting = origdata - fs := framer.PopStreamFrames(7) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].Data).To(Equal([]byte("foo"))) - var b bytes.Buffer - fs[0].Write(&b, 0) - Expect(b.Len()).To(Equal(7)) - fs = framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].Data).To(Equal([]byte("bar"))) - }) }) Context("sending FINs", func() { It("sends FINs when streams are closed", func() { - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(0)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.writeOffset = 42 - stream1.finishedWriting.Set(true) - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.streamID)) - Expect(fs[0].Offset).To(Equal(stream1.writeOffset)) - Expect(fs[0].FinBit).To(BeTrue()) - Expect(fs[0].Data).To(BeEmpty()) - }) + offset := protocol.ByteCount(42) + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(0)) + stream1.EXPECT().GetWriteOffset().Return(offset) + stream1.EXPECT().ShouldSendFin().Return(true) + stream1.EXPECT().SentFin() + setNoData(stream2) - It("sends FINs when flow-control blocked", func() { - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(0)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.writeOffset = 42 - stream1.finishedWriting.Set(true) fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.streamID)) - Expect(fs[0].Offset).To(Equal(stream1.writeOffset)) + Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) + Expect(fs[0].Offset).To(Equal(offset)) Expect(fs[0].FinBit).To(BeTrue()) Expect(fs[0].Data).To(BeEmpty()) }) It("bundles FINs with data", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foobar") - stream1.finishedWriting.Set(true) + offset := protocol.ByteCount(42) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().GetWriteOffset().Return(offset) + stream1.EXPECT().ShouldSendFin().Return(true) + stream1.EXPECT().SentFin() + setNoData(stream2) + fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.streamID)) + Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) Expect(fs[0].Data).To(Equal([]byte("foobar"))) Expect(fs[0].FinBit).To(BeTrue()) }) }) }) - Context("flow control", func() { - It("tells the FlowControlManager how many bytes it sent", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foobar") - framer.PopStreamFrames(1000) - }) - - It("does not count retransmitted frames as sent bytes", func() { - framer.AddFrameForRetransmission(retransmittedFrame1) - framer.PopStreamFrames(1000) - }) - - It("returns the whole frame if it fits", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.ByteCount(10+6), nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.writeOffset = 10 - stream1.dataForWriting = []byte("foobar") - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].DataLen()).To(Equal(protocol.ByteCount(6))) - }) - - It("returns a smaller frame if the whole frame doesn't fit", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.ByteCount(3), nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(3)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foobar") - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].Data).To(Equal([]byte("foo"))) - }) - - It("returns a smaller frame if the whole frame doesn't fit in the stream flow control window, for non-zero StreamFrame offset", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.ByteCount(3), nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(3)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.writeOffset = 1 - stream1.dataForWriting = []byte("foobar") - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].Data).To(Equal([]byte("foo"))) - }) - - It("selects a stream that is not flow control blocked", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.ByteCount(0), nil) - mockFcm.EXPECT().SendWindowSize(id2).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id2, protocol.ByteCount(6)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foobar") - stream2.dataForWriting = []byte("foobaz") - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream2.StreamID())) - Expect(fs[0].Data).To(Equal([]byte("foobaz"))) - }) - - It("returns nil if every stream is individually flow control blocked", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.ByteCount(0), nil) - mockFcm.EXPECT().SendWindowSize(id2).Return(protocol.ByteCount(0), nil) - stream1.dataForWriting = []byte("foobar") - stream2.dataForWriting = []byte("foobaz") - fs := framer.PopStreamFrames(1000) - Expect(fs).To(BeEmpty()) - }) - }) - Context("BLOCKED frames", func() { It("Pop returns nil if no frame is queued", func() { Expect(framer.PopBlockedFrame()).To(BeNil()) }) It("queues and pops BLOCKED frames for individually blocked streams", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.ByteCount(3), nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(3)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foo") + connFC.EXPECT().IsBlocked() + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar")) + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(6)) + stream1.EXPECT().GetWriteOffset() + stream1.EXPECT().ShouldSendFin() + stream1.EXPECT().IsFlowControlBlocked().Return(true) + setNoData(stream2) frames := framer.PopStreamFrames(1000) Expect(frames).To(HaveLen(1)) blockedFrame := framer.PopBlockedFrame() @@ -399,56 +333,34 @@ var _ = Describe("Stream Framer", func() { }) It("does not queue a stream-level BLOCKED frame after sending the FinBit frame", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.ByteCount(5000), nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(3)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(0)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foo") + connFC.EXPECT().IsBlocked() + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo")) + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(3)) + stream1.EXPECT().GetWriteOffset() + stream1.EXPECT().ShouldSendFin().Return(true) + stream1.EXPECT().SentFin() + setNoData(stream2) frames := framer.PopStreamFrames(1000) Expect(frames).To(HaveLen(1)) - Expect(frames[0].FinBit).To(BeFalse()) - stream1.finishedWriting.Set(true) - frames = framer.PopStreamFrames(1000) - Expect(frames).To(HaveLen(1)) Expect(frames[0].FinBit).To(BeTrue()) - Expect(frames[0].DataLen()).To(BeZero()) + Expect(frames[0].DataLen()).To(Equal(protocol.ByteCount(3))) blockedFrame := framer.PopBlockedFrame() Expect(blockedFrame).To(BeNil()) }) It("queues and pops BLOCKED frames for connection blocked streams", func() { - // FCM already considers the connection window size - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.ByteCount(3), nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(3)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.ByteCount(0)) - stream1.dataForWriting = []byte("foo") + connFC.EXPECT().IsBlocked().Return(true) + stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo")) + stream1.EXPECT().LenOfDataForWriting().Return(protocol.ByteCount(3)) + stream1.EXPECT().GetWriteOffset() + stream1.EXPECT().ShouldSendFin() + stream1.EXPECT().IsFlowControlBlocked().Return(false) + setNoData(stream2) framer.PopStreamFrames(1000) blockedFrame := framer.PopBlockedFrame() Expect(blockedFrame).ToNot(BeNil()) Expect(blockedFrame.StreamID).To(BeZero()) Expect(framer.PopBlockedFrame()).To(BeNil()) }) - - It("does not queue BLOCKED frames for non-contributing streams", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.MaxByteCount, nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(3)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foo") - framer.PopStreamFrames(1000) - Expect(framer.PopBlockedFrame()).To(BeNil()) - }) - - It("does not queue BLOCKED frames twice", func() { - mockFcm.EXPECT().SendWindowSize(id1).Return(protocol.ByteCount(3), nil) - mockFcm.EXPECT().AddBytesSent(id1, protocol.ByteCount(3)) - mockFcm.EXPECT().RemainingConnectionWindowSize().Return(protocol.MaxByteCount) - stream1.dataForWriting = []byte("foobar") - framer.PopStreamFrames(1000) - blockedFrame := framer.PopBlockedFrame() - Expect(blockedFrame).ToNot(BeNil()) - Expect(blockedFrame.StreamID).To(Equal(stream1.StreamID())) - Expect(framer.PopBlockedFrame()).To(BeNil()) - }) }) }) diff --git a/stream_test.go b/stream_test.go index 9d757851..b8187f50 100644 --- a/stream_test.go +++ b/stream_test.go @@ -30,7 +30,7 @@ var _ = Describe("Stream", func() { resetCalledForStream protocol.StreamID resetCalledAtOffset protocol.ByteCount - mockFcm *mocks.MockFlowControlManager + mockFC *mocks.MockStreamFlowController ) // in the tests for the stream deadlines we set a deadline @@ -58,8 +58,8 @@ var _ = Describe("Stream", func() { BeforeEach(func() { onDataCalled = false resetCalled = false - mockFcm = mocks.NewMockFlowControlManager(mockCtrl) - str = newStream(streamID, onData, onReset, mockFcm) + mockFC = mocks.NewMockStreamFlowController(mockCtrl) + str = newStream(streamID, onData, onReset, mockFC) timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = struct { @@ -77,8 +77,8 @@ var _ = Describe("Stream", func() { Context("reading", func() { It("reads a single StreamFrame", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(4)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -93,9 +93,9 @@ var _ = Describe("Stream", func() { }) It("reads a single StreamFrame in multiple goes", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -114,9 +114,9 @@ var _ = Describe("Stream", func() { }) It("reads all data available", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -137,9 +137,9 @@ var _ = Describe("Stream", func() { }) It("assembles multiple StreamFrames", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -160,8 +160,8 @@ var _ = Describe("Stream", func() { }) It("waits until data is available", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) go func() { defer GinkgoRecover() frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} @@ -176,9 +176,9 @@ var _ = Describe("Stream", func() { }) It("handles StreamFrames in wrong order", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -199,10 +199,10 @@ var _ = Describe("Stream", func() { }) It("ignores duplicate StreamFrames", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -229,10 +229,10 @@ var _ = Describe("Stream", func() { }) It("doesn't rejects a StreamFrames with an overlapping data range", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(6)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(4)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte("foob"), @@ -253,8 +253,8 @@ var _ = Describe("Stream", func() { }) It("calls onData", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(4)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -273,7 +273,7 @@ var _ = Describe("Stream", func() { }) It("returns an error when Read is called after the deadline", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(6)).AnyTimes() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() f := &wire.StreamFrame{Data: []byte("foobar")} err := str.AddStreamFrame(f) Expect(err).ToNot(HaveOccurred()) @@ -332,7 +332,7 @@ var _ = Describe("Stream", func() { }) It("sets a read deadline, when SetDeadline is called", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(6)).AnyTimes() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes() f := &wire.StreamFrame{Data: []byte("foobar")} err := str.AddStreamFrame(f) Expect(err).ToNot(HaveOccurred()) @@ -347,8 +347,8 @@ var _ = Describe("Stream", func() { Context("closing", func() { Context("with FIN bit", func() { It("returns EOFs", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(4)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -366,9 +366,9 @@ var _ = Describe("Stream", func() { }) It("handles out-of-order frames", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)).Times(2) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -393,8 +393,8 @@ var _ = Describe("Stream", func() { }) It("returns EOFs with partial read", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(2)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -410,8 +410,8 @@ var _ = Describe("Stream", func() { }) It("handles immediate FINs", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(0)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(0)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) frame := wire.StreamFrame{ Offset: 0, Data: []byte{}, @@ -428,8 +428,8 @@ var _ = Describe("Stream", func() { Context("when CloseRemote is called", func() { It("closes", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(0)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(0)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) str.CloseRemote(0) b := make([]byte, 8) n, err := strWithTimeout.Read(b) @@ -438,7 +438,7 @@ var _ = Describe("Stream", func() { }) It("doesn't cancel the context", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(0)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) str.CloseRemote(0) Expect(str.Context().Done()).ToNot(BeClosed()) }) @@ -484,13 +484,14 @@ var _ = Describe("Stream", func() { Context("reset by the peer", func() { It("continues reading after receiving a remote error", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 10) b := make([]byte, 4) n, err := strWithTimeout.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -498,8 +499,9 @@ var _ = Describe("Stream", func() { }) It("reads a delayed StreamFrame that arrives after receiving a remote error", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) - str.RegisterRemoteError(testErr) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + str.RegisterRemoteError(testErr, 4) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -513,13 +515,14 @@ var _ = Describe("Stream", func() { }) It("returns the error if reading past the offset of the frame received", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 8) b := make([]byte, 10) n, err := strWithTimeout.Read(b) Expect(b[0:4]).To(Equal(frame.Data)) @@ -528,14 +531,15 @@ var _ = Describe("Stream", func() { }) It("returns an EOF when reading past the offset, if the stream received a finbit", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), true) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, FinBit: true, } str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 8) b := make([]byte, 10) n, err := strWithTimeout.Read(b) Expect(b[:4]).To(Equal(frame.Data)) @@ -544,14 +548,15 @@ var _ = Describe("Stream", func() { }) It("continues reading in small chunks after receiving a remote error", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, FinBit: true, } str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 4) b := make([]byte, 3) _, err := strWithTimeout.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -564,7 +569,8 @@ var _ = Describe("Stream", func() { }) It("doesn't inform the flow controller about bytes read after receiving the remote error", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(4)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) // No AddBytesRead() frame := wire.StreamFrame{ Offset: 0, @@ -572,13 +578,14 @@ var _ = Describe("Stream", func() { Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, } str.AddStreamFrame(&frame) - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 10) b := make([]byte, 3) _, err := strWithTimeout.Read(b) Expect(err).ToNot(HaveOccurred()) }) It("stops writing after receiving a remote error", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -587,12 +594,14 @@ var _ = Describe("Stream", func() { Expect(err).To(MatchError(testErr)) close(done) }() - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 10) Eventually(done).Should(BeClosed()) - }) It("returns how much was written when recieving a remote error", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -602,19 +611,20 @@ var _ = Describe("Stream", func() { close(done) }() - Eventually(func() []byte { return str.getDataForWriting(4) }).ShouldNot(BeEmpty()) - str.RegisterRemoteError(testErr) + Eventually(func() []byte { return str.GetDataForWriting(4) }).ShouldNot(BeEmpty()) + str.RegisterRemoteError(testErr, 10) Eventually(done).Should(BeClosed()) }) It("calls onReset when receiving a remote error", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) done := make(chan struct{}) str.writeOffset = 0x1000 go func() { _, _ = strWithTimeout.Write([]byte("foobar")) close(done) }() - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 0) Expect(resetCalled).To(BeTrue()) Expect(resetCalledForStream).To(Equal(protocol.StreamID(1337))) Expect(resetCalledAtOffset).To(Equal(protocol.ByteCount(0x1000))) @@ -622,25 +632,28 @@ var _ = Describe("Stream", func() { }) It("doesn't call onReset if it already sent a FIN", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) str.Close() - str.sentFin() - str.RegisterRemoteError(testErr) + str.SentFin() + str.RegisterRemoteError(testErr, 0) Expect(resetCalled).To(BeFalse()) }) It("doesn't call onReset if the stream was reset locally before", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) str.Reset(testErr) Expect(resetCalled).To(BeTrue()) resetCalled = false - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 0) Expect(resetCalled).To(BeFalse()) }) It("doesn't call onReset twice, when it gets two remote errors", func() { - str.RegisterRemoteError(testErr) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + str.RegisterRemoteError(testErr, 0) Expect(resetCalled).To(BeTrue()) resetCalled = false - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 0) Expect(resetCalled).To(BeFalse()) }) }) @@ -657,7 +670,7 @@ var _ = Describe("Stream", func() { }() Consistently(done).ShouldNot(BeClosed()) str.Reset(testErr) - Expect(str.getDataForWriting(6)).To(BeNil()) + Expect(str.GetDataForWriting(6)).To(BeNil()) Eventually(done).Should(BeClosed()) }) @@ -666,7 +679,7 @@ var _ = Describe("Stream", func() { n, err := strWithTimeout.Write([]byte("foobar")) Expect(n).To(BeZero()) Expect(err).To(MatchError(testErr)) - Expect(str.getDataForWriting(6)).To(BeNil()) + Expect(str.GetDataForWriting(6)).To(BeNil()) }) It("stops reading", func() { @@ -685,7 +698,7 @@ var _ = Describe("Stream", func() { }) It("doesn't allow further reads", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(6)) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) str.AddStreamFrame(&wire.StreamFrame{ Data: []byte("foobar"), }) @@ -706,13 +719,14 @@ var _ = Describe("Stream", func() { It("doesn't call onReset if it already sent a FIN", func() { str.Close() - str.sentFin() + str.SentFin() str.Reset(testErr) Expect(resetCalled).To(BeFalse()) }) It("doesn't call onReset if the stream was reset remotely before", func() { - str.RegisterRemoteError(testErr) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) + str.RegisterRemoteError(testErr, 0) Expect(resetCalled).To(BeTrue()) resetCalled = false str.Reset(testErr) @@ -737,6 +751,8 @@ var _ = Describe("Stream", func() { Context("writing", func() { It("writes and gets all data at once", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -752,8 +768,8 @@ var _ = Describe("Stream", func() { }).Should(Equal([]byte("foobar"))) Consistently(done).ShouldNot(BeClosed()) Expect(onDataCalled).To(BeTrue()) - Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6))) - data := str.getDataForWriting(1000) + Expect(str.LenOfDataForWriting()).To(Equal(protocol.ByteCount(6))) + data := str.GetDataForWriting(1000) Expect(data).To(Equal([]byte("foobar"))) Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) Expect(str.dataForWriting).To(BeNil()) @@ -761,6 +777,8 @@ var _ = Describe("Stream", func() { }) It("writes and gets data in two turns", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -775,25 +793,27 @@ var _ = Describe("Stream", func() { return str.dataForWriting }).Should(Equal([]byte("foobar"))) Consistently(done).ShouldNot(BeClosed()) - Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6))) - data := str.getDataForWriting(3) + Expect(str.LenOfDataForWriting()).To(Equal(protocol.ByteCount(6))) + data := str.GetDataForWriting(3) Expect(data).To(Equal([]byte("foo"))) Expect(str.writeOffset).To(Equal(protocol.ByteCount(3))) Expect(str.dataForWriting).ToNot(BeNil()) - Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(3))) - data = str.getDataForWriting(3) + Expect(str.LenOfDataForWriting()).To(Equal(protocol.ByteCount(3))) + data = str.GetDataForWriting(3) Expect(data).To(Equal([]byte("bar"))) Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) Expect(str.dataForWriting).To(BeNil()) - Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(0))) + Expect(str.LenOfDataForWriting()).To(Equal(protocol.ByteCount(0))) Eventually(done).Should(BeClosed()) }) It("getDataForWriting returns nil if no data is available", func() { - Expect(str.getDataForWriting(1000)).To(BeNil()) + Expect(str.GetDataForWriting(1000)).To(BeNil()) }) It("copies the slice while writing", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) s := []byte("foo") go func() { defer GinkgoRecover() @@ -801,9 +821,9 @@ var _ = Describe("Stream", func() { Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(3)) }() - Eventually(func() protocol.ByteCount { return str.lenOfDataForWriting() }).ShouldNot(BeZero()) + Eventually(func() protocol.ByteCount { return str.LenOfDataForWriting() }).ShouldNot(BeZero()) s[0] = 'v' - Expect(str.getDataForWriting(3)).To(Equal([]byte("foo"))) + Expect(str.GetDataForWriting(3)).To(Equal([]byte("foo"))) }) It("returns when given a nil input", func() { @@ -892,29 +912,29 @@ var _ = Describe("Stream", func() { It("allows FIN", func() { str.Close() - Expect(str.shouldSendFin()).To(BeTrue()) + Expect(str.ShouldSendFin()).To(BeTrue()) }) It("does not allow FIN when there's still data", func() { str.dataForWriting = []byte("foobar") str.Close() - Expect(str.shouldSendFin()).To(BeFalse()) + Expect(str.ShouldSendFin()).To(BeFalse()) }) It("does not allow FIN when the stream is not closed", func() { - Expect(str.shouldSendFin()).To(BeFalse()) + Expect(str.ShouldSendFin()).To(BeFalse()) }) It("does not allow FIN after an error", func() { str.Cancel(errors.New("test")) - Expect(str.shouldSendFin()).To(BeFalse()) + Expect(str.ShouldSendFin()).To(BeFalse()) }) It("does not allow FIN twice", func() { str.Close() - Expect(str.shouldSendFin()).To(BeTrue()) - str.sentFin() - Expect(str.shouldSendFin()).To(BeFalse()) + Expect(str.ShouldSendFin()).To(BeTrue()) + str.SentFin() + Expect(str.ShouldSendFin()).To(BeFalse()) }) }) @@ -935,18 +955,18 @@ var _ = Describe("Stream", func() { Expect(err).To(MatchError(testErr)) }() Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil()) - Expect(str.lenOfDataForWriting()).ToNot(BeZero()) + Expect(str.LenOfDataForWriting()).ToNot(BeZero()) str.Cancel(testErr) - data := str.getDataForWriting(6) + data := str.GetDataForWriting(6) Expect(data).To(BeNil()) - Expect(str.lenOfDataForWriting()).To(BeZero()) + Expect(str.LenOfDataForWriting()).To(BeZero()) }) }) }) It("errors when a StreamFrames causes a flow control violation", func() { testErr := errors.New("flow control violation") - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(8)).Return(testErr) + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false).Return(testErr) frame := wire.StreamFrame{ Offset: 2, Data: []byte("foobar"), @@ -959,6 +979,7 @@ var _ = Describe("Stream", func() { testErr := errors.New("testErr") finishReading := func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) err := str.AddStreamFrame(&wire.StreamFrame{FinBit: true}) Expect(err).ToNot(HaveOccurred()) b := make([]byte, 100) @@ -968,63 +989,84 @@ var _ = Describe("Stream", func() { It("is finished after it is canceled", func() { str.Cancel(testErr) - Expect(str.finished()).To(BeTrue()) + Expect(str.Finished()).To(BeTrue()) }) It("is not finished if it is only closed for writing", func() { str.Close() - str.sentFin() - Expect(str.finished()).To(BeFalse()) + str.SentFin() + Expect(str.Finished()).To(BeFalse()) }) It("cancels the context after it is closed", func() { Expect(str.Context().Done()).ToNot(BeClosed()) str.Close() - str.sentFin() + str.SentFin() Expect(str.Context().Done()).To(BeClosed()) }) It("is not finished if it is only closed for reading", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(0)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(0)) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) finishReading() - Expect(str.finished()).To(BeFalse()) + Expect(str.Finished()).To(BeFalse()) }) It("is finished after receiving a RST and sending one", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) // this directly sends a rst - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 0) Expect(str.rstSent.Get()).To(BeTrue()) - Expect(str.finished()).To(BeTrue()) + Expect(str.Finished()).To(BeTrue()) }) It("cancels the context after receiving a RST", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) Expect(str.Context().Done()).ToNot(BeClosed()) - str.RegisterRemoteError(testErr) + str.RegisterRemoteError(testErr, 0) Expect(str.Context().Done()).To(BeClosed()) }) It("is finished after being locally reset and receiving a RST in response", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(13), true) str.Reset(testErr) - Expect(str.finished()).To(BeFalse()) - str.RegisterRemoteError(testErr) - Expect(str.finished()).To(BeTrue()) + Expect(str.Finished()).To(BeFalse()) + str.RegisterRemoteError(testErr, 13) + Expect(str.Finished()).To(BeTrue()) }) It("is finished after finishing writing and receiving a RST", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(13), true) str.Close() - str.sentFin() - str.RegisterRemoteError(testErr) - Expect(str.finished()).To(BeTrue()) + str.SentFin() + str.RegisterRemoteError(testErr, 13) + Expect(str.Finished()).To(BeTrue()) }) It("is finished after finishing reading and being locally reset", func() { - mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(0)) - mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(0)) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) finishReading() - Expect(str.finished()).To(BeFalse()) + Expect(str.Finished()).To(BeFalse()) str.Reset(testErr) - Expect(str.finished()).To(BeTrue()) + Expect(str.Finished()).To(BeTrue()) + }) + }) + + Context("flow control", func() { + It("says when it's flow control blocked", func() { + mockFC.EXPECT().IsBlocked().Return(false) + Expect(str.IsFlowControlBlocked()).To(BeFalse()) + mockFC.EXPECT().IsBlocked().Return(true) + Expect(str.IsFlowControlBlocked()).To(BeTrue()) + }) + + It("updates the flow control window", func() { + mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x42)) + str.UpdateSendWindow(0x42) + }) + + It("gets a window update", func() { + mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x100)) + Expect(str.GetWindowUpdate()).To(Equal(protocol.ByteCount(0x100))) }) }) }) diff --git a/streams_map.go b/streams_map.go index dba65ce4..915d35e3 100644 --- a/streams_map.go +++ b/streams_map.go @@ -15,7 +15,7 @@ type streamsMap struct { perspective protocol.Perspective - streams map[protocol.StreamID]*stream + streams map[protocol.StreamID]streamI // needed for round-robin scheduling openStreams []protocol.StreamID roundRobinIndex int @@ -28,8 +28,7 @@ type streamsMap struct { closeErr error nextStreamToAccept protocol.StreamID - newStream newStreamLambda - removeStreamCallback removeStreamCallback + newStream newStreamLambda numOutgoingStreams uint32 numIncomingStreams uint32 @@ -37,13 +36,12 @@ type streamsMap struct { maxOutgoingStreams uint32 } -type streamLambda func(*stream) (bool, error) -type removeStreamCallback func(protocol.StreamID) -type newStreamLambda func(protocol.StreamID) *stream +type streamLambda func(streamI) (bool, error) +type newStreamLambda func(protocol.StreamID) streamI var errMapAccess = errors.New("streamsMap: Error accessing the streams map") -func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective) *streamsMap { +func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) *streamsMap { // add some tolerance to the maximum incoming streams value maxStreams := uint32(protocol.MaxIncomingStreams) maxIncomingStreams := utils.MaxUint32( @@ -51,12 +49,11 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), ) sm := streamsMap{ - perspective: pers, - streams: make(map[protocol.StreamID]*stream), - openStreams: make([]protocol.StreamID, 0), - newStream: newStream, - removeStreamCallback: removeStreamCallback, - maxIncomingStreams: maxIncomingStreams, + perspective: pers, + streams: make(map[protocol.StreamID]streamI), + openStreams: make([]protocol.StreamID, 0), + newStream: newStream, + maxIncomingStreams: maxIncomingStreams, } sm.nextStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex @@ -76,7 +73,7 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. -func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { +func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { m.mutex.RLock() s, ok := m.streams[id] m.mutex.RUnlock() @@ -134,7 +131,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { return m.streams[id], nil } -func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { +func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) { if m.numIncomingStreams >= m.maxIncomingStreams { return nil, qerr.TooManyOpenStreams } @@ -157,7 +154,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { return s, nil } -func (m *streamsMap) openStreamImpl() (*stream, error) { +func (m *streamsMap) openStreamImpl() (streamI, error) { id := m.nextStream if m.numOutgoingStreams >= m.maxOutgoingStreams { return nil, qerr.TooManyOpenStreams @@ -176,7 +173,7 @@ func (m *streamsMap) openStreamImpl() (*stream, error) { } // OpenStream opens the next available stream -func (m *streamsMap) OpenStream() (*stream, error) { +func (m *streamsMap) OpenStream() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -186,7 +183,7 @@ func (m *streamsMap) OpenStream() (*stream, error) { return m.openStreamImpl() } -func (m *streamsMap) OpenStreamSync() (*stream, error) { +func (m *streamsMap) OpenStreamSync() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -207,10 +204,10 @@ func (m *streamsMap) OpenStreamSync() (*stream, error) { // AcceptStream returns the next stream opened by the peer // it blocks until a new stream is opened -func (m *streamsMap) AcceptStream() (*stream, error) { +func (m *streamsMap) AcceptStream() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() - var str *stream + var str streamI for { var ok bool if m.closeErr != nil { @@ -237,10 +234,9 @@ func (m *streamsMap) DeleteClosedStreams() error { if !ok { return errMapAccess } - if !str.finished() { + if !str.Finished() { continue } - m.removeStreamCallback(streamID) numDeletedStreams++ m.openStreams[i] = 0 if streamID%2 == 0 { @@ -312,7 +308,7 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { } // Range executes a callback for all streams, in pseudo-random order -func (m *streamsMap) Range(cb func(s *stream)) { +func (m *streamsMap) Range(cb func(s streamI)) { m.mutex.RLock() defer m.mutex.RUnlock() @@ -331,7 +327,7 @@ func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (b return fn(str) } -func (m *streamsMap) putStream(s *stream) error { +func (m *streamsMap) putStream(s streamI) error { id := s.StreamID() if _, ok := m.streams[id]; ok { return fmt.Errorf("a stream with ID %d already exists", id) diff --git a/streams_map_test.go b/streams_map_test.go index 98692a66..e76b3c4f 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -4,25 +4,37 @@ import ( "errors" "sort" + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/qerr" + + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Streams Map", func() { var ( - m *streamsMap + m *streamsMap + finishedStreams map[protocol.StreamID]*gomock.Call ) - setNewStreamsMap := func(p protocol.Perspective) { - newStream := func(id protocol.StreamID) *stream { - return newStream(id, func() {}, nil, nil) - } - removeStreamCallback := func(protocol.StreamID) {} - m = newStreamsMap(newStream, removeStreamCallback, p) + newStream := func(id protocol.StreamID) streamI { + str := mocks.NewMockStreamI(mockCtrl) + str.EXPECT().StreamID().Return(id).AnyTimes() + c := str.EXPECT().Finished().Return(false).AnyTimes() + finishedStreams[id] = c + return str } + setNewStreamsMap := func(p protocol.Perspective) { + m = newStreamsMap(newStream, p) + } + + BeforeEach(func() { + finishedStreams = make(map[protocol.StreamID]*gomock.Call) + }) + AfterEach(func() { Expect(m.openStreams).To(HaveLen(len(m.streams))) }) @@ -30,8 +42,7 @@ var _ = Describe("Streams Map", func() { deleteStream := func(id protocol.StreamID) { str := m.streams[id] Expect(str).ToNot(BeNil()) - str.cancelled.Set(true) - Expect(str.finished()).To(BeTrue()) + finishedStreams[id].Return(true) err := m.DeleteClosedStreams() Expect(err).ToNot(HaveOccurred()) } @@ -215,7 +226,7 @@ var _ = Describe("Streams Map", func() { It("waits until another stream is closed", func() { openMaxNumStreams() var returned bool - var str *stream + var str streamI go func() { defer GinkgoRecover() var err error @@ -231,19 +242,23 @@ var _ = Describe("Streams Map", func() { }) It("stops waiting when an error is registered", func() { - openMaxNumStreams() testErr := errors.New("test error") - var err error - var returned bool + openMaxNumStreams() + for _, str := range m.streams { + str.(*mocks.MockStreamI).EXPECT().Cancel(testErr) + } + + done := make(chan struct{}) go func() { - _, err = m.OpenStreamSync() - returned = true + defer GinkgoRecover() + _, err := m.OpenStreamSync() + Expect(err).To(MatchError(testErr)) + close(done) }() - Consistently(func() bool { return returned }).Should(BeFalse()) + Consistently(done).ShouldNot(BeClosed()) m.CloseWithError(testErr) - Eventually(func() bool { return returned }).Should(BeTrue()) - Expect(err).To(MatchError(testErr)) + Eventually(done).Should(BeClosed()) }) It("immediately returns when OpenStreamSync is called after an error was registered", func() { @@ -266,7 +281,7 @@ var _ = Describe("Streams Map", func() { }) It("accepts stream 1 first", func() { - var str *stream + var str streamI go func() { defer GinkgoRecover() var err error @@ -280,7 +295,7 @@ var _ = Describe("Streams Map", func() { }) It("returns an implicitly opened stream, if a stream number is skipped", func() { - var str *stream + var str streamI go func() { defer GinkgoRecover() var err error @@ -294,7 +309,7 @@ var _ = Describe("Streams Map", func() { }) It("returns to multiple accepts", func() { - var str1, str2 *stream + var str1, str2 streamI go func() { defer GinkgoRecover() var err error @@ -309,29 +324,29 @@ var _ = Describe("Streams Map", func() { }() _, err := m.GetOrOpenStream(3) // opens stream 1 and 3 Expect(err).ToNot(HaveOccurred()) - Eventually(func() *stream { return str1 }).ShouldNot(BeNil()) - Eventually(func() *stream { return str2 }).ShouldNot(BeNil()) + Eventually(func() streamI { return str1 }).ShouldNot(BeNil()) + Eventually(func() streamI { return str2 }).ShouldNot(BeNil()) Expect(str1.StreamID()).ToNot(Equal(str2.StreamID())) Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(1 + 3)) }) It("waits a new stream is available", func() { - var str *stream + var str streamI go func() { defer GinkgoRecover() var err error str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) }() - Consistently(func() *stream { return str }).Should(BeNil()) + Consistently(func() streamI { return str }).Should(BeNil()) _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) - Eventually(func() *stream { return str }).ShouldNot(BeNil()) + Eventually(func() streamI { return str }).ShouldNot(BeNil()) Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) }) It("returns multiple streams on subsequent Accept calls, if available", func() { - var str *stream + var str streamI go func() { defer GinkgoRecover() var err error @@ -340,7 +355,7 @@ var _ = Describe("Streams Map", func() { }() _, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) - Eventually(func() *stream { return str }).ShouldNot(BeNil()) + Eventually(func() streamI { return str }).ShouldNot(BeNil()) Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) @@ -459,7 +474,7 @@ var _ = Describe("Streams Map", func() { Context("accepting streams", func() { It("accepts stream 2 first", func() { - var str *stream + var str streamI go func() { defer GinkgoRecover() var err error @@ -468,7 +483,7 @@ var _ = Describe("Streams Map", func() { }() _, err := m.GetOrOpenStream(2) Expect(err).ToNot(HaveOccurred()) - Eventually(func() *stream { return str }).ShouldNot(BeNil()) + Eventually(func() streamI { return str }).ShouldNot(BeNil()) Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) }) }) @@ -483,15 +498,13 @@ var _ = Describe("Streams Map", func() { closeStream := func(id protocol.StreamID) { str := m.streams[id] Expect(str).ToNot(BeNil()) - Expect(str.finished()).To(BeFalse()) - str.cancelled.Set(true) - Expect(str.finished()).To(BeTrue()) + finishedStreams[id].Return(true) } Context("deleting streams", func() { BeforeEach(func() { for i := 1; i <= 5; i++ { - err := m.putStream(&stream{streamID: protocol.StreamID(i)}) + err := m.putStream(newStream(protocol.StreamID(i))) Expect(err).ToNot(HaveOccurred()) } Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5})) @@ -500,6 +513,7 @@ var _ = Describe("Streams Map", func() { It("does not delete streams with Close()", func() { str, err := m.GetOrOpenStream(55) Expect(err).ToNot(HaveOccurred()) + str.(*mocks.MockStreamI).EXPECT().Close() str.Close() err = m.DeleteClosedStreams() Expect(err).ToNot(HaveOccurred()) @@ -546,7 +560,7 @@ var _ = Describe("Streams Map", func() { Context("Ranging", func() { // create 5 streams, ids 4 to 8 var callbackCalledForStream []protocol.StreamID - callback := func(str *stream) { + callback := func(str streamI) { callbackCalledForStream = append(callbackCalledForStream, str.StreamID()) sort.Slice(callbackCalledForStream, func(i, j int) bool { return callbackCalledForStream[i] < callbackCalledForStream[j] }) } @@ -574,13 +588,13 @@ var _ = Describe("Streams Map", func() { lambdaCalledForStream = lambdaCalledForStream[:0] numIterations = 0 for i := 4; i <= 8; i++ { - err := m.putStream(&stream{streamID: protocol.StreamID(i)}) + err := m.putStream(newStream(protocol.StreamID(i))) Expect(err).NotTo(HaveOccurred()) } }) It("executes the lambda exactly once for every stream", func() { - fn := func(str *stream) (bool, error) { + fn := func(str streamI) (bool, error) { lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) numIterations++ return true, nil @@ -593,7 +607,7 @@ var _ = Describe("Streams Map", func() { }) It("goes around once when starting in the middle", func() { - fn := func(str *stream) (bool, error) { + fn := func(str streamI) (bool, error) { lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) numIterations++ return true, nil @@ -607,7 +621,7 @@ var _ = Describe("Streams Map", func() { }) It("picks up at the index+1 where it last stopped", func() { - fn := func(str *stream) (bool, error) { + fn := func(str streamI) (bool, error) { lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) numIterations++ if str.StreamID() == 5 { @@ -622,7 +636,7 @@ var _ = Describe("Streams Map", func() { Expect(m.roundRobinIndex).To(BeEquivalentTo(2)) numIterations = 0 lambdaCalledForStream = lambdaCalledForStream[:0] - fn2 := func(str *stream) (bool, error) { + fn2 := func(str streamI) (bool, error) { lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) numIterations++ if str.StreamID() == 7 { @@ -681,7 +695,7 @@ var _ = Describe("Streams Map", func() { It("gets crypto- and header stream first, then picks up at the round-robin position", func() { m.roundRobinIndex = 3 // stream 7 - fn := func(str *stream) (bool, error) { + fn := func(str streamI) (bool, error) { if numIterations >= 3 { return false, nil }