From d571c0bfd5d15adcf7fbdfa93a62683813d1833d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 20 May 2016 12:16:33 +0700 Subject: [PATCH] remove stream from BlockManager when a WindowUpdate increases its offset fixes #130 --- blocked_manager.go | 7 ++++++ blocked_manager_test.go | 11 ++++++++ packet_packer.go | 5 ++-- session.go | 14 ++++++++--- session_test.go | 56 ++++++++++++++++++++++++++++++----------- stream.go | 4 ++- stream_test.go | 14 +++++++++++ 7 files changed, 91 insertions(+), 20 deletions(-) diff --git a/blocked_manager.go b/blocked_manager.go index 9a5ff03d..a16f41bd 100644 --- a/blocked_manager.go +++ b/blocked_manager.go @@ -25,6 +25,13 @@ func (m *blockedManager) AddBlockedStream(streamID protocol.StreamID, offset pro m.blockedStreams[streamID] = offset } +func (m *blockedManager) RemoveBlockedStream(streamID protocol.StreamID) { + m.mutex.Lock() + defer m.mutex.Unlock() + + delete(m.blockedStreams, streamID) +} + func (m *blockedManager) GetBlockedFrame(streamID protocol.StreamID, offset protocol.ByteCount) *frames.BlockedFrame { m.mutex.RLock() defer m.mutex.RUnlock() diff --git a/blocked_manager_test.go b/blocked_manager_test.go index 8a0ac227..4bc2b8d5 100644 --- a/blocked_manager_test.go +++ b/blocked_manager_test.go @@ -31,6 +31,17 @@ var _ = Describe("WindowUpdateManager", func() { Expect(bm.GetBlockedFrame(1337, 0x1337)).To(BeNil()) }) + It("removes the blocked entry when the stream is not blocked anymore", func() { + bm.AddBlockedStream(1337, 0x1337) + bm.RemoveBlockedStream(1337) + Expect(bm.GetBlockedFrame(1337, 0x1337)).To(BeNil()) + }) + + It("doesn't care if the stream was previously blocked when removing the block", func() { + bm.RemoveBlockedStream(1337) + Expect(bm.GetBlockedFrame(1337, 0x1337)).To(BeNil()) + }) + It("doesn't get a blocked frame for smaller offsets", func() { bm.AddBlockedStream(1337, 0x1337) Expect(bm.GetBlockedFrame(1337, 0x1336)).To(BeNil()) diff --git a/packet_packer.go b/packet_packer.go index c9870e91..98b8467c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -35,14 +35,14 @@ type packetPacker struct { lastPacketNumber protocol.PacketNumber } -func newPacketPacker(connectionID protocol.ConnectionID, aead crypto.AEAD, sentPacketHandler ackhandler.SentPacketHandler, connectionParametersHandler *handshake.ConnectionParametersManager, version protocol.VersionNumber) *packetPacker { +func newPacketPacker(connectionID protocol.ConnectionID, aead crypto.AEAD, sentPacketHandler ackhandler.SentPacketHandler, connectionParametersHandler *handshake.ConnectionParametersManager, blockedManager *blockedManager, version protocol.VersionNumber) *packetPacker { return &packetPacker{ aead: aead, connectionID: connectionID, connectionParametersManager: connectionParametersHandler, version: version, sentPacketHandler: sentPacketHandler, - blockedManager: newBlockedManager(), + blockedManager: blockedManager, } } @@ -57,6 +57,7 @@ func (p *packetPacker) AddHighPrioStreamFrame(f frames.StreamFrame) { func (p *packetPacker) AddBlocked(streamID protocol.StreamID, byteOffset protocol.ByteCount) { // TODO: send out connection-level BlockedFrames at the right time // see https://github.com/lucas-clemente/quic-go/issues/113 + // TODO: remove this function completely once #113 is resolved if streamID == 0 { p.controlFrames = append(p.controlFrames, &frames.BlockedFrame{StreamID: 0}) } diff --git a/session.go b/session.go index beb67072..f6baa06e 100644 --- a/session.go +++ b/session.go @@ -51,6 +51,7 @@ type Session struct { receivedPacketHandler ackhandler.ReceivedPacketHandler stopWaitingManager ackhandler.StopWaitingManager windowUpdateManager *windowUpdateManager + blockedManager *blockedManager flowController flowcontrol.FlowController // connection level flow controller @@ -96,6 +97,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol stopWaitingManager: stopWaitingManager, flowController: flowcontrol.NewFlowController(0, connectionParametersManager), windowUpdateManager: newWindowUpdateManager(), + blockedManager: newBlockedManager(), receivedPackets: make(chan receivedPacket, protocol.MaxSessionUnprocessedPackets), closeChan: make(chan struct{}, 1), sendingScheduled: make(chan struct{}, 1), @@ -113,7 +115,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return nil, err } - session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.sentPacketHandler, session.connectionParametersManager, v) + session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.sentPacketHandler, session.connectionParametersManager, session.blockedManager, v) session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v} return session, err @@ -302,7 +304,10 @@ func (s *Session) isValidStreamID(streamID protocol.StreamID) bool { func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error { if frame.StreamID == 0 { - s.flowController.UpdateSendWindow(frame.ByteOffset) + updated := s.flowController.UpdateSendWindow(frame.ByteOffset) + if updated { + s.blockedManager.RemoveBlockedStream(0) + } s.streamsMutex.RLock() // tell all streams that the connection-level was updated for _, stream := range s.streams { @@ -322,7 +327,10 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error } s.streamsMutex.RUnlock() - stream.UpdateSendFlowControlWindow(frame.ByteOffset) + updated := stream.UpdateSendFlowControlWindow(frame.ByteOffset) + if updated { + s.blockedManager.RemoveBlockedStream(frame.StreamID) + } } return nil diff --git a/session_test.go b/session_test.go index bcc30454..b79e56fb 100644 --- a/session_test.go +++ b/session_test.go @@ -381,26 +381,54 @@ var _ = Describe("Session", func() { Expect(conn.written).To(HaveLen(int(protocol.WindowUpdateNumRepitions))) // no packet was sent }) - It("queues a Blocked frames", func() { - len := 500 - frame := frames.StreamFrame{ - StreamID: 0x1337, - Data: bytes.Repeat([]byte{'f'}, len), - } - session.streamBlocked(0x1337, protocol.ByteCount(len)) - session.packer.AddStreamFrame(frame) - err := session.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(1)) - Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x05, 0x37, 0x13, 0, 0}))) - }) - It("sends public reset", func() { err := session.sendPublicReset(1) Expect(err).NotTo(HaveOccurred()) Expect(conn.written).To(HaveLen(1)) Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST")))) }) + + Context("Blocked", func() { + It("queues a Blocked frames", func() { + len := 500 + frame := frames.StreamFrame{ + StreamID: 0x1337, + Data: bytes.Repeat([]byte{'f'}, len), + } + session.streamBlocked(0x1337, protocol.ByteCount(len)) + session.packer.AddStreamFrame(frame) + err := session.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(conn.written).To(HaveLen(1)) + Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x05, 0x37, 0x13, 0, 0}))) + }) + + It("does not send a blocked frame for a stream if a WindowUpdate arrived before", func() { + len := 500 + _, err := session.OpenStream(0x1337) + Expect(err).ToNot(HaveOccurred()) + session.streamBlocked(0x1337, protocol.ByteCount(len)) + wuf := frames.WindowUpdateFrame{ + StreamID: 0x1337, + ByteOffset: protocol.ByteCount(len * 2), + } + err = session.handleWindowUpdateFrame(&wuf) + Expect(err).ToNot(HaveOccurred()) + Expect(session.blockedManager.GetBlockedFrame(0x1337, protocol.ByteCount(len))).To(BeNil()) + }) + + It("does not send a blocked frame for the connection if a WindowUpdate arrived before", func() { + len := 500 + session.streamBlocked(0, protocol.ByteCount(len)) + wuf := frames.WindowUpdateFrame{ + StreamID: 0, + ByteOffset: protocol.ByteCount(len * 2), + } + err := session.handleWindowUpdateFrame(&wuf) + Expect(err).ToNot(HaveOccurred()) + Expect(session.blockedManager.GetBlockedFrame(0, protocol.ByteCount(len))).To(BeNil()) + }) + }) }) Context("scheduling sending", func() { diff --git a/stream.go b/stream.go index c76c332c..4af37f01 100644 --- a/stream.go +++ b/stream.go @@ -161,10 +161,12 @@ func (s *stream) ConnectionFlowControlWindowUpdated() { s.windowUpdateOrErrCond.Broadcast() } -func (s *stream) UpdateSendFlowControlWindow(n protocol.ByteCount) { +func (s *stream) UpdateSendFlowControlWindow(n protocol.ByteCount) bool { if s.flowController.UpdateSendWindow(n) { s.windowUpdateOrErrCond.Broadcast() + return true } + return false } func (s *stream) Write(p []byte) (int, error) { diff --git a/stream_test.go b/stream_test.go index a88449f5..e1eef2ff 100644 --- a/stream_test.go +++ b/stream_test.go @@ -352,6 +352,20 @@ var _ = Describe("Stream", func() { Expect(err).ToNot(HaveOccurred()) }) + It("returns true when the flow control window was updated", func() { + updated := str.flowController.UpdateSendWindow(4) + Expect(updated).To(BeTrue()) + updated = str.UpdateSendFlowControlWindow(5) + Expect(updated).To(BeTrue()) + }) + + It("returns false when the flow control window was not updated", func() { + updated := str.flowController.UpdateSendWindow(4) + Expect(updated).To(BeTrue()) + updated = str.UpdateSendFlowControlWindow(3) + Expect(updated).To(BeFalse()) + }) + It("waits for a stream flow control window update", func() { var b bool updated := str.flowController.UpdateSendWindow(1)