diff --git a/crypto_stream.go b/crypto_stream.go index 94158e17..b9263da6 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -18,7 +18,7 @@ type cryptoStreamI interface { SetReadOffset(protocol.ByteCount) // methods needed for flow control GetWindowUpdate() protocol.ByteCount - UpdateSendWindow(protocol.ByteCount) + HandleMaxStreamDataFrame(*wire.MaxStreamDataFrame) IsFlowControlBlocked() bool } diff --git a/internal/mocks/stream.go b/internal/mocks/stream.go index 0c3c5361..f50574d9 100644 --- a/internal/mocks/stream.go +++ b/internal/mocks/stream.go @@ -94,6 +94,16 @@ func (_mr *MockStreamIMockRecorder) GetWindowUpdate() *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).GetWindowUpdate)) } +// HandleMaxStreamDataFrame mocks base method +func (_m *MockStreamI) HandleMaxStreamDataFrame(_param0 *wire.MaxStreamDataFrame) { + _m.ctrl.Call(_m, "HandleMaxStreamDataFrame", _param0) +} + +// HandleMaxStreamDataFrame indicates an expected call of HandleMaxStreamDataFrame +func (_mr *MockStreamIMockRecorder) HandleMaxStreamDataFrame(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "HandleMaxStreamDataFrame", reflect.TypeOf((*MockStreamI)(nil).HandleMaxStreamDataFrame), arg0) +} + // HandleStreamFrame mocks base method func (_m *MockStreamI) HandleStreamFrame(_param0 *wire.StreamFrame) error { ret := _m.ctrl.Call(_m, "HandleStreamFrame", _param0) @@ -213,16 +223,6 @@ func (_mr *MockStreamIMockRecorder) StreamID() *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID)) } -// UpdateSendWindow mocks base method -func (_m *MockStreamI) UpdateSendWindow(_param0 protocol.ByteCount) { - _m.ctrl.Call(_m, "UpdateSendWindow", _param0) -} - -// UpdateSendWindow indicates an expected call of UpdateSendWindow -func (_mr *MockStreamIMockRecorder) UpdateSendWindow(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamI)(nil).UpdateSendWindow), arg0) -} - // Write mocks base method func (_m *MockStreamI) Write(_param0 []byte) (int, error) { ret := _m.ctrl.Call(_m, "Write", _param0) diff --git a/session.go b/session.go index 37195656..62aafc07 100644 --- a/session.go +++ b/session.go @@ -586,7 +586,7 @@ func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { if frame.StreamID == s.version.CryptoStreamID() { - s.cryptoStream.UpdateSendWindow(frame.ByteOffset) + s.cryptoStream.HandleMaxStreamDataFrame(frame) return nil } str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) @@ -597,7 +597,7 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error // stream is closed and already garbage collected return nil } - str.UpdateSendWindow(frame.ByteOffset) + str.HandleMaxStreamDataFrame(frame) return nil } @@ -684,8 +684,12 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete s.packer.SetOmitConnectionID() } s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) + // increase the flow control windows of all streams by sending them a fake MAX_STREAM_DATA frame s.streamsMap.Range(func(str streamI) { - str.UpdateSendWindow(params.StreamFlowControlWindow) + str.HandleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: str.StreamID(), + ByteOffset: params.StreamFlowControlWindow, + }) }) } diff --git a/session_test.go b/session_test.go index 899a0363..eabc6d2f 100644 --- a/session_test.go +++ b/session_test.go @@ -414,14 +414,14 @@ var _ = Describe("Session", func() { }) It("updates the flow control window of a stream", func() { - offset := protocol.ByteCount(0x1234) - str, err := sess.GetOrOpenStream(5) - str.(*mocks.MockStreamI).EXPECT().UpdateSendWindow(offset) - Expect(err).ToNot(HaveOccurred()) - err = sess.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + f := &wire.MaxStreamDataFrame{ StreamID: 5, - ByteOffset: offset, - }) + ByteOffset: 0x1234, + } + str, err := sess.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + str.(*mocks.MockStreamI).EXPECT().HandleMaxStreamDataFrame(f) + err = sess.handleMaxStreamDataFrame(f) Expect(err).ToNot(HaveOccurred()) }) @@ -432,18 +432,19 @@ var _ = Describe("Session", func() { }) It("opens a new stream when receiving a MAX_STREAM_DATA frame for an unknown stream", func() { + f := &wire.MaxStreamDataFrame{ + StreamID: 5, + ByteOffset: 0x1337, + } newStreamLambda := sess.streamsMap.newStream sess.streamsMap.newStream = func(id protocol.StreamID) streamI { str := newStreamLambda(id) if id == 5 { - str.(*mocks.MockStreamI).EXPECT().UpdateSendWindow(protocol.ByteCount(0x1337)) + str.(*mocks.MockStreamI).EXPECT().HandleMaxStreamDataFrame(f) } return str } - err := sess.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ - StreamID: 5, - ByteOffset: 0x1337, - }) + err := sess.handleMaxStreamDataFrame(f) Expect(err).ToNot(HaveOccurred()) str, err := sess.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) diff --git a/stream.go b/stream.go index 708be651..26ee53c9 100644 --- a/stream.go +++ b/stream.go @@ -24,7 +24,7 @@ type streamI interface { Cancel(error) // methods needed for flow control GetWindowUpdate() protocol.ByteCount - UpdateSendWindow(protocol.ByteCount) + HandleMaxStreamDataFrame(*wire.MaxStreamDataFrame) IsFlowControlBlocked() bool } @@ -481,8 +481,8 @@ func (s *stream) StreamID() protocol.StreamID { return s.streamID } -func (s *stream) UpdateSendWindow(n protocol.ByteCount) { - s.flowController.UpdateSendWindow(n) +func (s *stream) HandleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { + s.flowController.UpdateSendWindow(frame.ByteOffset) } func (s *stream) IsFlowControlBlocked() bool { diff --git a/stream_test.go b/stream_test.go index 42260cf6..47c1cd17 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1085,7 +1085,10 @@ var _ = Describe("Stream", func() { It("updates the flow control window", func() { mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(0x42)) - str.UpdateSendWindow(0x42) + str.HandleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: streamID, + ByteOffset: 0x42, + }) }) It("gets a window update", func() {