diff --git a/crypto_stream.go b/crypto_stream.go index 8b8e8bee..de9275b7 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -9,6 +9,7 @@ import ( ) type cryptoStreamI interface { + StreamID() protocol.StreamID io.Reader io.Writer handleStreamFrame(*wire.StreamFrame) error diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go new file mode 100644 index 00000000..d87fb9af --- /dev/null +++ b/mock_crypto_stream_test.go @@ -0,0 +1,152 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: CryptoStream) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockCryptoStream is a mock of CryptoStream interface +type MockCryptoStream struct { + ctrl *gomock.Controller + recorder *MockCryptoStreamMockRecorder +} + +// MockCryptoStreamMockRecorder is the mock recorder for MockCryptoStream +type MockCryptoStreamMockRecorder struct { + mock *MockCryptoStream +} + +// NewMockCryptoStream creates a new mock instance +func NewMockCryptoStream(ctrl *gomock.Controller) *MockCryptoStream { + mock := &MockCryptoStream{ctrl: ctrl} + mock.recorder = &MockCryptoStreamMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder { + return m.recorder +} + +// Read mocks base method +func (m *MockCryptoStream) Read(arg0 []byte) (int, error) { + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockCryptoStreamMockRecorder) Read(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockCryptoStream)(nil).Read), arg0) +} + +// StreamID mocks base method +func (m *MockCryptoStream) StreamID() protocol.StreamID { + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(protocol.StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID +func (mr *MockCryptoStreamMockRecorder) StreamID() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockCryptoStream)(nil).StreamID)) +} + +// Write mocks base method +func (m *MockCryptoStream) Write(arg0 []byte) (int, error) { + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write +func (mr *MockCryptoStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0) +} + +// closeForShutdown mocks base method +func (m *MockCryptoStream) closeForShutdown(arg0 error) { + m.ctrl.Call(m, "closeForShutdown", arg0) +} + +// closeForShutdown indicates an expected call of closeForShutdown +func (mr *MockCryptoStreamMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockCryptoStream)(nil).closeForShutdown), arg0) +} + +// getWindowUpdate mocks base method +func (m *MockCryptoStream) getWindowUpdate() protocol.ByteCount { + ret := m.ctrl.Call(m, "getWindowUpdate") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// getWindowUpdate indicates an expected call of getWindowUpdate +func (mr *MockCryptoStreamMockRecorder) getWindowUpdate() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockCryptoStream)(nil).getWindowUpdate)) +} + +// handleMaxStreamDataFrame mocks base method +func (m *MockCryptoStream) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) { + m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0) +} + +// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame +func (mr *MockCryptoStreamMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleMaxStreamDataFrame), arg0) +} + +// handleStreamFrame mocks base method +func (m *MockCryptoStream) handleStreamFrame(arg0 *wire.StreamFrame) error { + ret := m.ctrl.Call(m, "handleStreamFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// handleStreamFrame indicates an expected call of handleStreamFrame +func (mr *MockCryptoStreamMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0) +} + +// hasDataForWriting mocks base method +func (m *MockCryptoStream) hasDataForWriting() bool { + ret := m.ctrl.Call(m, "hasDataForWriting") + ret0, _ := ret[0].(bool) + return ret0 +} + +// hasDataForWriting indicates an expected call of hasDataForWriting +func (mr *MockCryptoStreamMockRecorder) hasDataForWriting() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasDataForWriting", reflect.TypeOf((*MockCryptoStream)(nil).hasDataForWriting)) +} + +// popStreamFrame mocks base method +func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { + ret := m.ctrl.Call(m, "popStreamFrame", arg0) + ret0, _ := ret[0].(*wire.StreamFrame) + return ret0 +} + +// popStreamFrame indicates an expected call of popStreamFrame +func (mr *MockCryptoStreamMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).popStreamFrame), arg0) +} + +// setReadOffset mocks base method +func (m *MockCryptoStream) setReadOffset(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "setReadOffset", arg0) +} + +// setReadOffset indicates an expected call of setReadOffset +func (mr *MockCryptoStreamMockRecorder) setReadOffset(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setReadOffset", reflect.TypeOf((*MockCryptoStream)(nil).setReadOffset), arg0) +} diff --git a/mock_stream_getter_test.go b/mock_stream_getter_test.go new file mode 100644 index 00000000..4b0a5122 --- /dev/null +++ b/mock_stream_getter_test.go @@ -0,0 +1,48 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: StreamGetter) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockStreamGetter is a mock of StreamGetter interface +type MockStreamGetter struct { + ctrl *gomock.Controller + recorder *MockStreamGetterMockRecorder +} + +// MockStreamGetterMockRecorder is the mock recorder for MockStreamGetter +type MockStreamGetterMockRecorder struct { + mock *MockStreamGetter +} + +// NewMockStreamGetter creates a new mock instance +func NewMockStreamGetter(ctrl *gomock.Controller) *MockStreamGetter { + mock := &MockStreamGetter{ctrl: ctrl} + mock.recorder = &MockStreamGetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStreamGetter) EXPECT() *MockStreamGetterMockRecorder { + return m.recorder +} + +// GetOrOpenStream mocks base method +func (m *MockStreamGetter) GetOrOpenStream(arg0 protocol.StreamID) (streamI, error) { + ret := m.ctrl.Call(m, "GetOrOpenStream", arg0) + ret0, _ := ret[0].(streamI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOrOpenStream indicates an expected call of GetOrOpenStream +func (mr *MockStreamGetterMockRecorder) GetOrOpenStream(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenStream), arg0) +} diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 64e70b3d..7d1f2a6e 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -6,11 +6,12 @@ package quic import ( context "context" + reflect "reflect" + time "time" + gomock "github.com/golang/mock/gomock" protocol "github.com/lucas-clemente/quic-go/internal/protocol" wire "github.com/lucas-clemente/quic-go/internal/wire" - reflect "reflect" - time "time" ) // MockStreamI is a mock of StreamI interface diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index 37326920..637e8cba 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -5,10 +5,11 @@ package quic import ( + reflect "reflect" + gomock "github.com/golang/mock/gomock" protocol "github.com/lucas-clemente/quic-go/internal/protocol" wire "github.com/lucas-clemente/quic-go/internal/wire" - reflect "reflect" ) // MockStreamSender is a mock of StreamSender interface @@ -35,13 +36,13 @@ func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { } // onHasWindowUpdate mocks base method -func (m *MockStreamSender) onHasWindowUpdate(arg0 protocol.StreamID, arg1 protocol.ByteCount) { - m.ctrl.Call(m, "onHasWindowUpdate", arg0, arg1) +func (m *MockStreamSender) onHasWindowUpdate(arg0 protocol.StreamID) { + m.ctrl.Call(m, "onHasWindowUpdate", arg0) } // onHasWindowUpdate indicates an expected call of onHasWindowUpdate -func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0, arg1 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0, arg1) +func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0) } // queueControlFrame mocks base method diff --git a/mockgen.go b/mockgen.go index 21997ace..03ef893d 100644 --- a/mockgen.go +++ b/mockgen.go @@ -2,3 +2,7 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI" //go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" +//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go" +//go:generate sh -c "goimports -w mock*_test.go" diff --git a/receive_stream.go b/receive_stream.go index 6c9e0460..eb34f24c 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -140,8 +140,8 @@ func (s *receiveStream) Read(p []byte) (int, error) { s.flowController.AddBytesRead(protocol.ByteCount(m)) } // this call triggers the flow controller to increase the flow control window, if necessary - if offset := s.flowController.GetWindowUpdate(); offset != 0 { - s.sender.onHasWindowUpdate(s.streamID, offset) + if s.flowController.HasWindowUpdate() { + s.sender.onHasWindowUpdate(s.streamID) } if s.readPosInFrame >= int(frame.DataLen()) { diff --git a/receive_stream_test.go b/receive_stream_test.go index 058768e8..98e1868a 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -33,6 +33,7 @@ var _ = Describe("Receive Stream", func() { timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = gbytes.TimeoutReader(str, timeout) + strWithTimeout = str }) It("gets stream id", func() { @@ -43,7 +44,7 @@ var _ = Describe("Receive Stream", func() { It("reads a single STREAM frame", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().GetWindowUpdate() + mockFC.EXPECT().HasWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -61,7 +62,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().GetWindowUpdate().Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -83,7 +84,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().GetWindowUpdate().Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -107,7 +108,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().GetWindowUpdate().Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -130,7 +131,7 @@ var _ = Describe("Receive Stream", func() { It("waits until data is available", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().GetWindowUpdate() + mockFC.EXPECT().HasWindowUpdate() go func() { defer GinkgoRecover() frame := wire.StreamFrame{Data: []byte{0xDE, 0xAD}} @@ -148,7 +149,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().GetWindowUpdate().Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -173,7 +174,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().GetWindowUpdate().Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -204,7 +205,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().GetWindowUpdate().Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte("foob"), @@ -233,8 +234,8 @@ var _ = Describe("Receive Stream", func() { It("calls the onHasWindowUpdate callback, when the a MAX_STREAM_DATA should be sent", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - mockFC.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(1337)) - mockSender.EXPECT().onHasWindowUpdate(streamID, protocol.ByteCount(1337)) + mockFC.EXPECT().HasWindowUpdate().Return(true) + mockSender.EXPECT().onHasWindowUpdate(streamID) frame1 := wire.StreamFrame{ Offset: 0, Data: []byte("foobar"), @@ -318,7 +319,7 @@ var _ = Describe("Receive Stream", func() { It("returns EOFs", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().GetWindowUpdate() + mockFC.EXPECT().HasWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, @@ -339,7 +340,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) - mockFC.EXPECT().GetWindowUpdate().Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) frame1 := wire.StreamFrame{ Offset: 2, Data: []byte{0xBE, 0xEF}, @@ -366,7 +367,7 @@ var _ = Describe("Receive Stream", func() { It("returns EOFs with partial read", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().GetWindowUpdate() + mockFC.EXPECT().HasWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD}, @@ -384,7 +385,7 @@ var _ = Describe("Receive Stream", func() { It("handles immediate FINs", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - mockFC.EXPECT().GetWindowUpdate() + mockFC.EXPECT().HasWindowUpdate() frame := wire.StreamFrame{ Offset: 0, Data: []byte{}, @@ -402,7 +403,7 @@ var _ = Describe("Receive Stream", func() { It("closes when CloseRemote is called", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - mockFC.EXPECT().GetWindowUpdate() + mockFC.EXPECT().HasWindowUpdate() str.CloseRemote(0) b := make([]byte, 8) n, err := strWithTimeout.Read(b) @@ -477,7 +478,7 @@ var _ = Describe("Receive Stream", func() { It("doesn't send a RST_STREAM frame, if the FIN was already read", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - mockFC.EXPECT().GetWindowUpdate() + mockFC.EXPECT().HasWindowUpdate() // no calls to mockSender.queueControlFrame err := str.handleStreamFrame(&wire.StreamFrame{ StreamID: streamID, @@ -588,7 +589,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) - mockFC.EXPECT().GetWindowUpdate().Times(2) + mockFC.EXPECT().HasWindowUpdate().Times(2) readReturned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -654,7 +655,7 @@ var _ = Describe("Receive Stream", func() { It("is finished if it is only closed for reading", func() { mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - mockFC.EXPECT().GetWindowUpdate() + mockFC.EXPECT().HasWindowUpdate() finishReading() Expect(str.finished()).To(BeTrue()) }) diff --git a/session.go b/session.go index 1cf5ad90..9b269c78 100644 --- a/session.go +++ b/session.go @@ -23,6 +23,10 @@ type unpacker interface { Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) } +type streamGetter interface { + GetOrOpenStream(protocol.StreamID) (streamI, error) +} + type receivedPacket struct { remoteAddr net.Addr header *wire.Header @@ -323,7 +327,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.perspective, s.version, ) - s.windowUpdateQueue = newWindowUpdateQueue(s.packer.QueueControlFrame) + s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame) s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} return nil } @@ -952,8 +956,8 @@ func (s *session) queueControlFrame(f wire.Frame) { s.scheduleSending() } -func (s *session) onHasWindowUpdate(streamID protocol.StreamID, offset protocol.ByteCount) { - s.windowUpdateQueue.Add(streamID, offset) +func (s *session) onHasWindowUpdate(id protocol.StreamID) { + s.windowUpdateQueue.Add(id) s.scheduleSending() } diff --git a/session_test.go b/session_test.go index 2eeca88d..58bcb23f 100644 --- a/session_test.go +++ b/session_test.go @@ -879,8 +879,10 @@ var _ = Describe("Session", func() { }) It("adds MAX_STREAM_DATA frames", func() { - sess.windowUpdateQueue.Add(1, 10) - sess.windowUpdateQueue.Add(2, 20) + sess.windowUpdateQueue.callback(&wire.MaxStreamDataFrame{ + StreamID: 2, + ByteOffset: 20, + }) sph := mocks.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().SendingAllowed().Return(true) @@ -888,7 +890,6 @@ var _ = Describe("Session", func() { sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, ByteOffset: 10})) Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) }) sess.sentPacketHandler = sph @@ -1603,30 +1604,6 @@ var _ = Describe("Session", func() { }) }) - // Context("window updates", func() { - // It("gets stream level window updates", func() { - // _, err := sess.GetOrOpenStream(3) - // Expect(err).ToNot(HaveOccurred()) - // err = sess.flowControlManager.AddBytesRead(3, protocol.ReceiveStreamFlowControlWindow) - // Expect(err).NotTo(HaveOccurred()) - // frames := sess.getWindowUpdateFrames() - // Expect(frames).To(HaveLen(1)) - // Expect(frames[0].StreamID).To(Equal(protocol.StreamID(3))) - // Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow * 2)) - // }) - - // It("gets connection level window updates", func() { - // _, err := sess.GetOrOpenStream(5) - // Expect(err).NotTo(HaveOccurred()) - // err = sess.flowControlManager.AddBytesRead(5, protocol.ReceiveConnectionFlowControlWindow) - // Expect(err).NotTo(HaveOccurred()) - // frames := sess.getWindowUpdateFrames() - // Expect(frames).To(HaveLen(1)) - // Expect(frames[0].StreamID).To(Equal(protocol.StreamID(0))) - // Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow * 2)) - // }) - // }) - It("returns the local address", func() { addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} mconn.localAddr = addr diff --git a/stream.go b/stream.go index 9da1fb60..398557fd 100644 --- a/stream.go +++ b/stream.go @@ -18,7 +18,7 @@ const ( type streamSender interface { scheduleSending() queueControlFrame(wire.Frame) - onHasWindowUpdate(protocol.StreamID, protocol.ByteCount) + onHasWindowUpdate(protocol.StreamID) } type streamI interface { diff --git a/window_update_queue.go b/window_update_queue.go index bed3d60c..e1ca3152 100644 --- a/window_update_queue.go +++ b/window_update_queue.go @@ -10,31 +10,48 @@ import ( type windowUpdateQueue struct { mutex sync.Mutex - queue map[protocol.StreamID]protocol.ByteCount - callback func(wire.Frame) + queue map[protocol.StreamID]bool // used as a set + callback func(wire.Frame) + cryptoStream cryptoStreamI + streamGetter streamGetter } -func newWindowUpdateQueue(cb func(wire.Frame)) *windowUpdateQueue { +func newWindowUpdateQueue(streamGetter streamGetter, cryptoStream cryptoStreamI, cb func(wire.Frame)) *windowUpdateQueue { return &windowUpdateQueue{ - queue: make(map[protocol.StreamID]protocol.ByteCount), - callback: cb, + queue: make(map[protocol.StreamID]bool), + streamGetter: streamGetter, + cryptoStream: cryptoStream, + callback: cb, } } -func (q *windowUpdateQueue) Add(stream protocol.StreamID, offset protocol.ByteCount) { +func (q *windowUpdateQueue) Add(id protocol.StreamID) { q.mutex.Lock() - q.queue[stream] = offset + q.queue[id] = true q.mutex.Unlock() } func (q *windowUpdateQueue) QueueAll() { q.mutex.Lock() - for stream, offset := range q.queue { + var offset protocol.ByteCount + for id := range q.queue { + if id == q.cryptoStream.StreamID() { + offset = q.cryptoStream.getWindowUpdate() + } else { + str, err := q.streamGetter.GetOrOpenStream(id) + if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update + continue + } + offset = str.getWindowUpdate() + } + if offset == 0 { // can happen if we received a final offset, right after queueing the window update + continue + } q.callback(&wire.MaxStreamDataFrame{ - StreamID: stream, + StreamID: id, ByteOffset: offset, }) - delete(q.queue, stream) + delete(q.queue, id) } q.mutex.Unlock() } diff --git a/window_update_queue_test.go b/window_update_queue_test.go index ab488e16..4e377ad7 100644 --- a/window_update_queue_test.go +++ b/window_update_queue_test.go @@ -1,6 +1,7 @@ package quic import ( + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -10,37 +11,77 @@ import ( var _ = Describe("Window Update Queue", func() { var ( q *windowUpdateQueue + streamGetter *MockStreamGetter queuedFrames []wire.Frame + cryptoStream *MockCryptoStream ) BeforeEach(func() { + streamGetter = NewMockStreamGetter(mockCtrl) + cryptoStream = NewMockCryptoStream(mockCtrl) + cryptoStream.EXPECT().StreamID().Return(protocol.StreamID(0)).AnyTimes() queuedFrames = queuedFrames[:0] - q = newWindowUpdateQueue(func(f wire.Frame) { + q = newWindowUpdateQueue(streamGetter, cryptoStream, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }) }) It("adds stream offsets and gets MAX_STREAM_DATA frames", func() { - q.Add(1, 10) - q.Add(2, 20) - q.Add(3, 30) + stream1 := NewMockStreamI(mockCtrl) + stream1.EXPECT().getWindowUpdate().Return(protocol.ByteCount(10)) + stream3 := NewMockStreamI(mockCtrl) + stream3.EXPECT().getWindowUpdate().Return(protocol.ByteCount(30)) + streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(3)).Return(stream3, nil) + streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(1)).Return(stream1, nil) + q.Add(3) + q.Add(1) q.QueueAll() Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 1, ByteOffset: 10})) - Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) Expect(queuedFrames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 3, ByteOffset: 30})) }) It("deletes the entry after getting the MAX_STREAM_DATA frame", func() { - q.Add(10, 100) + stream10 := NewMockStreamI(mockCtrl) + stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(100)) + streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(10)).Return(stream10, nil) + q.Add(10) q.QueueAll() Expect(queuedFrames).To(HaveLen(1)) q.QueueAll() Expect(queuedFrames).To(HaveLen(1)) }) - It("replaces old entries", func() { - q.Add(10, 100) - q.Add(10, 200) + It("doesn't queue a MAX_STREAM_DATA for a closed stream", func() { + streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(12)).Return(nil, nil) + q.Add(12) + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + }) + + It("doesn't queue a MAX_STREAM_DATA if the flow controller returns an offset of 0", func() { + stream5 := NewMockStreamI(mockCtrl) + stream5.EXPECT().getWindowUpdate().Return(protocol.ByteCount(0)) + streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(5)).Return(stream5, nil) + q.Add(5) + q.QueueAll() + Expect(queuedFrames).To(BeEmpty()) + }) + + It("adds MAX_STREAM_DATA frames for the crypto stream", func() { + cryptoStream.EXPECT().getWindowUpdate().Return(protocol.ByteCount(42)) + q.Add(0) + q.QueueAll() + Expect(queuedFrames).To(Equal([]wire.Frame{ + &wire.MaxStreamDataFrame{StreamID: 0, ByteOffset: 42}, + })) + }) + + It("deduplicates", func() { + stream10 := NewMockStreamI(mockCtrl) + stream10.EXPECT().getWindowUpdate().Return(protocol.ByteCount(200)) + streamGetter.EXPECT().GetOrOpenStream(protocol.StreamID(10)).Return(stream10, nil) + q.Add(10) + q.Add(10) q.QueueAll() Expect(queuedFrames).To(Equal([]wire.Frame{ &wire.MaxStreamDataFrame{StreamID: 10, ByteOffset: 200},