From 9930de68c88bc29a7afec382f17a34ecc34ba29a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 19 May 2016 11:53:51 +0700 Subject: [PATCH] send stream-level Blocked frames at the right time fixes #110 --- blocked_manager.go | 42 +++++++++++++++++++++ blocked_manager_test.go | 43 +++++++++++++++++++++ packet_packer.go | 42 ++++++++++++++++++++- packet_packer_test.go | 82 +++++++++++++++++++++++++++++++++++++++++ session.go | 32 ++++------------ session_test.go | 26 ++++--------- stream.go | 8 ++-- stream_test.go | 2 +- 8 files changed, 229 insertions(+), 48 deletions(-) create mode 100644 blocked_manager.go create mode 100644 blocked_manager_test.go diff --git a/blocked_manager.go b/blocked_manager.go new file mode 100644 index 00000000..9a5ff03d --- /dev/null +++ b/blocked_manager.go @@ -0,0 +1,42 @@ +package quic + +import ( + "sync" + + "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/protocol" +) + +type blockedManager struct { + blockedStreams map[protocol.StreamID]protocol.ByteCount + mutex sync.RWMutex +} + +func newBlockedManager() *blockedManager { + return &blockedManager{ + blockedStreams: make(map[protocol.StreamID]protocol.ByteCount), + } +} + +func (m *blockedManager) AddBlockedStream(streamID protocol.StreamID, offset protocol.ByteCount) { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.blockedStreams[streamID] = offset +} + +func (m *blockedManager) GetBlockedFrame(streamID protocol.StreamID, offset protocol.ByteCount) *frames.BlockedFrame { + m.mutex.RLock() + defer m.mutex.RUnlock() + + blockedOffset, ok := m.blockedStreams[streamID] + if !ok { + return nil + } + if blockedOffset > offset { + return nil + } + + delete(m.blockedStreams, streamID) + return &frames.BlockedFrame{StreamID: streamID} +} diff --git a/blocked_manager_test.go b/blocked_manager_test.go new file mode 100644 index 00000000..8a0ac227 --- /dev/null +++ b/blocked_manager_test.go @@ -0,0 +1,43 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("WindowUpdateManager", func() { + var bm *blockedManager + + BeforeEach(func() { + bm = newBlockedManager() + }) + + It("accepts new entries", func() { + bm.AddBlockedStream(1337, 0x1337) + Expect(bm.blockedStreams).To(HaveKey(protocol.StreamID(1337))) + Expect(bm.blockedStreams[1337]).To(Equal(protocol.ByteCount(0x1337))) + }) + + It("gets a blocked frame for the right offset", func() { + bm.AddBlockedStream(1337, 0x1337) + Expect(bm.GetBlockedFrame(1337, 0x1337)).To(Equal(&frames.BlockedFrame{StreamID: 1337})) + }) + + It("doesn't get a blocked frame twice for the same offset", func() { + bm.AddBlockedStream(1337, 0x1337) + Expect(bm.GetBlockedFrame(1337, 0x1337)).ToNot(BeNil()) + Expect(bm.GetBlockedFrame(1337, 0x1337)).To(BeNil()) + }) + + It("doesn't get a blocked frame for smaller offsets", func() { + bm.AddBlockedStream(1337, 0x1337) + Expect(bm.GetBlockedFrame(1337, 0x1336)).To(BeNil()) + }) + + It("doesn't get a blocked frame for the wrong stream", func() { + bm.AddBlockedStream(1337, 0x1337) + Expect(bm.GetBlockedFrame(1336, 0x1337)).To(BeNil()) + }) +}) diff --git a/packet_packer.go b/packet_packer.go index 8750360e..698d1574 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -30,10 +30,22 @@ type packetPacker struct { streamFrameQueue streamFrameQueue controlFrames []frames.Frame + blockedManager *blockedManager lastPacketNumber protocol.PacketNumber } +func newPacketPacker(connectionID protocol.ConnectionID, aead crypto.AEAD, sentPacketHandler ackhandler.SentPacketHandler, connectionParametersHandler *handshake.ConnectionParametersManager, version protocol.VersionNumber) *packetPacker { + return &packetPacker{ + aead: aead, + connectionID: connectionID, + connectionParametersManager: connectionParametersHandler, + version: version, + sentPacketHandler: sentPacketHandler, + blockedManager: newBlockedManager(), + } +} + func (p *packetPacker) AddStreamFrame(f frames.StreamFrame) { p.streamFrameQueue.Push(&f, false) } @@ -42,6 +54,16 @@ func (p *packetPacker) AddHighPrioStreamFrame(f frames.StreamFrame) { p.streamFrameQueue.Push(&f, true) } +func (p *packetPacker) AddBlocked(streamID protocol.StreamID, byteOffset protocol.ByteCount) { + // TODO: send out connection-level BlockedFrames at the right time + // see https://github.com/lucas-clemente/quic-go/issues/113 + if streamID == 0 { + p.controlFrames = append(p.controlFrames, &frames.BlockedFrame{StreamID: 0}) + } + + p.blockedManager.AddBlockedStream(streamID, byteOffset) +} + func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, includeStreamFrames bool) (*packedPacket, error) { // don't send out packets that only contain a StopWaitingFrame if len(controlFrames) == 0 && (p.streamFrameQueue.Len() == 0 || !includeStreamFrames) { @@ -157,6 +179,7 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra } hasStreamFrames := false + lastFrameIsBlockedFrame := false // temporarily increase the maxFrameSize by 2 bytes // this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set @@ -191,11 +214,28 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra payloadLength += frameMinLength payloadFrames = append(payloadFrames, frame) hasStreamFrames = true + lastFrameIsBlockedFrame = false + + blockedFrame := p.blockedManager.GetBlockedFrame(frame.StreamID, frame.Offset+protocol.ByteCount(len(frame.Data))) + if blockedFrame != nil { + blockedMinLength, _ := blockedFrame.MinLength() // BlockedFrame.MinLength *never* returns an error + if payloadLength+blockedMinLength <= maxFrameSize { + payloadFrames = append(payloadFrames, blockedFrame) + payloadLength += blockedMinLength + lastFrameIsBlockedFrame = true + } else { + p.controlFrames = append(p.controlFrames, blockedFrame) + } + } } // remove the dataLen for the last StreamFrame in the packet if hasStreamFrames { - lastStreamFrame, ok := payloadFrames[len(payloadFrames)-1].(*frames.StreamFrame) + lastStreamFrameIndex := len(payloadFrames) - 1 + if lastFrameIsBlockedFrame { + lastStreamFrameIndex-- + } + lastStreamFrame, ok := payloadFrames[lastStreamFrameIndex].(*frames.StreamFrame) if !ok { return nil, errors.New("PacketPacker BUG: StreamFrame type assertion failed") } diff --git a/packet_packer_test.go b/packet_packer_test.go index 7171e72a..3d35e99b 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -40,6 +40,7 @@ var _ = Describe("Packet packer", func() { aead: aead, connectionParametersManager: handshake.NewConnectionParamatersManager(), sentPacketHandler: newMockSentPacketHandler(), + blockedManager: newBlockedManager(), } publicHeaderLen = 1 + 8 + 1 // 1 flag byte, 8 connection ID, 1 packet number }) @@ -332,6 +333,87 @@ var _ = Describe("Packet packer", func() { }) }) + Context("Blocked frames", func() { + It("adds a blocked frame to a packet if there is enough space", func() { + length := 100 + packer.AddBlocked(5, protocol.ByteCount(length)) + f := frames.StreamFrame{ + StreamID: 5, + Data: bytes.Repeat([]byte{'f'}, length), + } + packer.AddStreamFrame(f) + p, err := packer.composeNextPacket(nil, publicHeaderLen, true) + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(HaveLen(2)) + Expect(p[1]).To(Equal(&frames.BlockedFrame{StreamID: 5})) + }) + + It("removes the dataLen attribute from the last StreamFrame, even if the last frame is a BlockedFrame", func() { + length := 100 + packer.AddBlocked(5, protocol.ByteCount(length)) + f := frames.StreamFrame{ + StreamID: 5, + Data: bytes.Repeat([]byte{'f'}, length), + } + packer.AddStreamFrame(f) + p, err := packer.composeNextPacket(nil, publicHeaderLen, true) + Expect(err).ToNot(HaveOccurred()) + Expect(p[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) + }) + + It("correctly removes the dataLen attribute from the last StreamFrame, when packing one StreamFrame, one BlockedFrame, and another StreamFrame", func() { + length := 10 + packer.AddBlocked(5, protocol.ByteCount(length)) + f := frames.StreamFrame{ + StreamID: 5, + Data: bytes.Repeat([]byte{'f'}, length), + } + packer.AddStreamFrame(f) + f = frames.StreamFrame{ + StreamID: 7, + Data: []byte("foobar"), + } + packer.AddStreamFrame(f) + p, err := packer.composeNextPacket(nil, publicHeaderLen, true) + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(HaveLen(3)) + Expect(p[0].(*frames.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p[2].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) + }) + + It("packs a BlockedFrame in the next packet if the current packet doesn't have enough space", func() { + dataLen := int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen) - (1 + 1 + 2) + 1 + packer.AddBlocked(5, protocol.ByteCount(dataLen)) + f := frames.StreamFrame{ + StreamID: 5, + Data: bytes.Repeat([]byte{'f'}, dataLen), + } + packer.AddStreamFrame(f) + p, err := packer.composeNextPacket(nil, publicHeaderLen, true) + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(HaveLen(1)) + p, err = packer.composeNextPacket(nil, publicHeaderLen, true) + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(HaveLen(1)) + Expect(p[0]).To(Equal(&frames.BlockedFrame{StreamID: 5})) + }) + + // TODO: fix this once connection-level BlockedFrames are sent out at the right time + // see https://github.com/lucas-clemente/quic-go/issues/113 + It("packs a connection-level BlockedFrame", func() { + packer.AddBlocked(0, 0x1337) + f := frames.StreamFrame{ + StreamID: 5, + Data: []byte("foobar"), + } + packer.AddStreamFrame(f) + p, err := packer.composeNextPacket(nil, publicHeaderLen, true) + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(HaveLen(2)) + Expect(p[0]).To(Equal(&frames.BlockedFrame{StreamID: 0})) + }) + }) + It("says whether it is empty", func() { Expect(packer.Empty()).To(BeTrue()) f := frames.StreamFrame{ diff --git a/session.go b/session.go index e553823d..06810f53 100644 --- a/session.go +++ b/session.go @@ -46,12 +46,10 @@ type Session struct { streams map[protocol.StreamID]*stream streamsMutex sync.RWMutex - sentPacketHandler ackhandler.SentPacketHandler - receivedPacketHandler ackhandler.ReceivedPacketHandler - stopWaitingManager ackhandler.StopWaitingManager - windowUpdateManager *windowUpdateManager - blockedFrameQueue []*frames.BlockedFrame - blockedFrameQueueMutex sync.Mutex + sentPacketHandler ackhandler.SentPacketHandler + receivedPacketHandler ackhandler.ReceivedPacketHandler + stopWaitingManager ackhandler.StopWaitingManager + windowUpdateManager *windowUpdateManager flowController *flowController // connection level flow controller @@ -108,13 +106,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return nil, err } - session.packer = &packetPacker{ - aead: session.cryptoSetup, - connectionParametersManager: session.connectionParametersManager, - sentPacketHandler: session.sentPacketHandler, - connectionID: connectionID, - version: v, - } + session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.sentPacketHandler, session.connectionParametersManager, v) session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v} return session, err @@ -456,13 +448,6 @@ func (s *Session) sendPacket() error { controlFrames = append(controlFrames, wuf) } - s.blockedFrameQueueMutex.Lock() - for _, bf := range s.blockedFrameQueue { - controlFrames = append(controlFrames, bf) - } - s.blockedFrameQueue = s.blockedFrameQueue[:0] - s.blockedFrameQueueMutex.Unlock() - ack, err := s.receivedPacketHandler.GetAckFrame(true) if err != nil { return err @@ -542,11 +527,8 @@ func (s *Session) updateReceiveFlowControlWindow(streamID protocol.StreamID, byt return nil } -func (s *Session) streamBlocked(streamID protocol.StreamID) { - s.blockedFrameQueueMutex.Lock() - defer s.blockedFrameQueueMutex.Unlock() - - s.blockedFrameQueue = append(s.blockedFrameQueue, &frames.BlockedFrame{StreamID: streamID}) +func (s *Session) streamBlocked(streamID protocol.StreamID, byteOffset protocol.ByteCount) { + s.packer.AddBlocked(streamID, byteOffset) } // OpenStream creates a new stream open for reading and writing diff --git a/session_test.go b/session_test.go index c4543585..aff121c2 100644 --- a/session_test.go +++ b/session_test.go @@ -381,28 +381,18 @@ var _ = Describe("Session", func() { Expect(conn.written).To(HaveLen(int(protocol.WindowUpdateNumRepitions))) // no packet was sent }) - It("sends queued Blocked frames", func() { - bf1 := frames.BlockedFrame{StreamID: 0x1337} - bf2 := frames.BlockedFrame{StreamID: 0xDECAFBAD} - session.blockedFrameQueue = append(session.blockedFrameQueue, &bf1) - session.blockedFrameQueue = append(session.blockedFrameQueue, &bf2) + It("queues a Blocked frames", func() { + len := 500 + frame := frames.StreamFrame{ + StreamID: 0x1337, + Data: bytes.Repeat([]byte{'f'}, len), + } + session.streamBlocked(0x1337, protocol.ByteCount(len)) + session.packer.AddStreamFrame(frame) err := session.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(conn.written).To(HaveLen(1)) Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x05, 0x37, 0x13, 0, 0}))) - Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x05, 0xAD, 0xFB, 0xCA, 0xDE}))) - }) - - It("only sends every queued Blocked frame once", func() { - bf := frames.BlockedFrame{StreamID: 0x1337} - session.blockedFrameQueue = append(session.blockedFrameQueue, &bf) - err := session.sendPacket() - Expect(err).NotTo(HaveOccurred()) - session.queueStreamFrame(&frames.StreamFrame{StreamID: 5, Data: []byte("foobar")}) // queue something, so that a packet can actually be sent - err = session.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(conn.written).To(HaveLen(2)) - Expect(conn.written[1]).ToNot(ContainSubstring(string([]byte{0x05, 0x37, 0x13, 0, 0}))) }) It("sends public reset", func() { diff --git a/stream.go b/stream.go index b8787032..200f340f 100644 --- a/stream.go +++ b/stream.go @@ -15,7 +15,7 @@ import ( type streamHandler interface { queueStreamFrame(*frames.StreamFrame) error updateReceiveFlowControlWindow(streamID protocol.StreamID, byteOffset protocol.ByteCount) error - streamBlocked(streamID protocol.StreamID) + streamBlocked(streamID protocol.StreamID, byteOffset protocol.ByteCount) } var ( @@ -277,14 +277,16 @@ func (s *stream) maybeTriggerBlocked() { streamBlocked := s.flowController.MaybeTriggerBlocked() if streamBlocked { - s.session.streamBlocked(s.streamID) + s.session.streamBlocked(s.streamID, s.writeOffset) } if s.contributesToConnectionFlowControl { connectionBlocked := s.connectionFlowController.MaybeTriggerBlocked() if connectionBlocked { - s.session.streamBlocked(0) + // TODO: send out connection-level BlockedFrames at the right time + // see https://github.com/lucas-clemente/quic-go/issues/113 + s.session.streamBlocked(0, 0) } } } diff --git a/stream_test.go b/stream_test.go index 934a5d48..6cbcf713 100644 --- a/stream_test.go +++ b/stream_test.go @@ -28,7 +28,7 @@ func (m *mockStreamHandler) queueStreamFrame(f *frames.StreamFrame) error { return nil } -func (m *mockStreamHandler) streamBlocked(streamID protocol.StreamID) { +func (m *mockStreamHandler) streamBlocked(streamID protocol.StreamID, byteOffset protocol.ByteCount) { m.receivedBlockedCalled = true m.receivedBlockedForStream = streamID }