From f8d24ef1e9b8ce829ed799a3d9d5e3db12ced0e0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 2 Jun 2023 14:14:04 +0300 Subject: [PATCH] don't use closures for passing OnLost and OnAcked STREAM frame callbacks (#3833) --- connection.go | 24 +- framer.go | 25 +- framer_test.go | 68 +++--- internal/ackhandler/frame.go | 8 + internal/ackhandler/packet.go | 3 + internal/ackhandler/sent_packet_handler.go | 11 +- mock_frame_source_test.go | 4 +- mock_send_stream_internal_test.go | 7 +- mock_stream_internal_test.go | 7 +- packet_packer.go | 44 ++-- packet_packer_test.go | 85 ++++--- send_stream.go | 27 +-- send_stream_test.go | 262 ++++++++++++--------- stream.go | 2 +- 14 files changed, 327 insertions(+), 250 deletions(-) diff --git a/connection.go b/connection.go index c4a91a35..b45de4dc 100644 --- a/connection.go +++ b/connection.go @@ -1825,7 +1825,7 @@ func (s *connection) maybeSendAckOnlyPacket() error { } return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return nil } @@ -1893,7 +1893,7 @@ func (s *connection) sendPacket() (bool, error) { if err != nil { return false, err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return true, nil } @@ -1904,13 +1904,13 @@ func (s *connection) sendPacket() (bool, error) { } return false, err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return true, nil } func (s *connection) sendPackedShortHeaderPacket(buffer *packetBuffer, p *ackhandler.Packet, now time.Time) { - if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && ackhandler.HasAckElicitingFrames(p.Frames) { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { s.firstAckElicitingPacketAfterIdleSentTime = now } @@ -1989,6 +1989,7 @@ func (s *connection) logShortHeaderPacket( destConnID protocol.ConnectionID, ackFrame *wire.AckFrame, frames []*ackhandler.Frame, + streamFrames []ackhandler.StreamFrame, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, @@ -2004,17 +2005,23 @@ func (s *connection) logShortHeaderPacket( if ackFrame != nil { wire.LogFrame(s.logger, ackFrame, true) } - for _, frame := range frames { - wire.LogFrame(s.logger, frame.Frame, true) + for _, f := range frames { + wire.LogFrame(s.logger, f.Frame, true) + } + for _, f := range streamFrames { + wire.LogFrame(s.logger, f.Frame, true) } } // tracing if s.tracer != nil { - fs := make([]logging.Frame, 0, len(frames)) + fs := make([]logging.Frame, 0, len(frames)+len(streamFrames)) for _, f := range frames { fs = append(fs, logutils.ConvertFrame(f.Frame)) } + for _, f := range streamFrames { + fs = append(fs, logutils.ConvertFrame(f.Frame)) + } var ack *logging.AckFrame if ackFrame != nil { ack = logutils.ConvertAckFrame(ackFrame) @@ -2042,6 +2049,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { packet.shortHdrPacket.DestConnID, packet.shortHdrPacket.Ack, packet.shortHdrPacket.Frames, + packet.shortHdrPacket.StreamFrames, packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.KeyPhase, @@ -2060,7 +2068,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { s.logLongHeaderPacket(p) } if p := packet.shortHdrPacket; p != nil { - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) } } diff --git a/framer.go b/framer.go index 79c28e23..92a3805a 100644 --- a/framer.go +++ b/framer.go @@ -18,7 +18,7 @@ type framer interface { AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AddActiveStream(protocol.StreamID) - AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) Handle0RTTRejection() error } @@ -91,9 +91,9 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Unlock() } -func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { +func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { + startLen := len(frames) var length protocol.ByteCount - var lastFrame *ackhandler.Frame f.mutex.Lock() // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet numActiveStreams := f.streamQueue.Len() @@ -115,28 +115,27 @@ func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol // Therefore, we can pretend to have more bytes available when popping // the STREAM frame (which will always have the DataLen set). remainingLen += quicvarint.Len(uint64(remainingLen)) - frame, hasMoreData := str.popStreamFrame(remainingLen, v) + frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v) if hasMoreData { // put the stream back in the queue (at the end) f.streamQueue.PushBack(id) - } else { // no more data to send. Stream is not active any more + } else { // no more data to send. Stream is not active delete(f.activeStreams, id) } - // The frame can be nil + // The frame can be "nil" // * if the receiveStream was canceled after it said it had data // * the remaining size doesn't allow us to add another STREAM frame - if frame == nil { + if !ok { continue } frames = append(frames, frame) - length += frame.Length(v) - lastFrame = frame + length += frame.Frame.Length(v) } f.mutex.Unlock() - if lastFrame != nil { - lastFrameLen := lastFrame.Length(v) + if len(frames) > startLen { + l := frames[len(frames)-1].Frame.Length(v) // account for the smaller size of the last STREAM frame - lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false - length += lastFrame.Length(v) - lastFrameLen + frames[len(frames)-1].Frame.DataLenPresent = false + length += frames[len(frames)-1].Frame.Length(v) - l } return frames, length } diff --git a/framer_test.go b/framer_test.go index 041c24bf..7adc53e4 100644 --- a/framer_test.go +++ b/framer_test.go @@ -124,11 +124,11 @@ var _ = Describe("Framer", func() { Offset: 42, DataLenPresent: true, } - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f}, false) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false) framer.AddActiveStream(id1) fs, length := framer.AppendStreamFrames(nil, 1000, protocol.Version1) Expect(fs).To(HaveLen(1)) - Expect(fs[0].Frame.(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + Expect(fs[0].Frame.DataLenPresent).To(BeFalse()) Expect(length).To(Equal(f.Length(version))) }) @@ -139,8 +139,8 @@ var _ = Describe("Framer", func() { Expect(framer.HasData()).To(BeTrue()) f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foo")} f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("bar")} - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f1}, true) - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f2}, false) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, true) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false) frames, _ := framer.AppendStreamFrames(nil, protocol.MaxByteCount, protocol.Version1) Expect(frames).To(HaveLen(1)) Expect(frames[0].Frame).To(Equal(f1)) @@ -158,15 +158,15 @@ var _ = Describe("Framer", func() { Data: []byte("foobar"), DataLenPresent: true, } - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f}, false) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false) framer.AddActiveStream(id1) - mdf := &wire.MaxDataFrame{MaximumData: 1337} - frames := []*ackhandler.Frame{{Frame: mdf}} + f0 := ackhandler.StreamFrame{Frame: &wire.StreamFrame{StreamID: 9999}} + frames := []ackhandler.StreamFrame{f0} fs, length := framer.AppendStreamFrames(frames, 1000, protocol.Version1) Expect(fs).To(HaveLen(2)) - Expect(fs[0].Frame).To(Equal(mdf)) - Expect(fs[1].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) - Expect(fs[1].Frame.(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + Expect(fs[0]).To(Equal(f0)) + Expect(fs[1].Frame.Data).To(Equal([]byte("foobar"))) + Expect(fs[1].Frame.DataLenPresent).To(BeFalse()) Expect(length).To(Equal(f.Length(version))) }) @@ -178,7 +178,7 @@ var _ = Describe("Framer", func() { Data: []byte("foobar"), DataLenPresent: true, } - stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f}, false) + stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false) framer.AddActiveStream(id1) framer.AddActiveStream(id2) frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1) @@ -194,8 +194,8 @@ var _ = Describe("Framer", func() { Data: []byte("foobar"), DataLenPresent: true, } - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(nil, false) - stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f}, false) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{}, false, false) + stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false) framer.AddActiveStream(id1) framer.AddActiveStream(id2) frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1) @@ -207,8 +207,8 @@ var _ = Describe("Framer", func() { streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f1}, true) - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f2}, false) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, true) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false) framer.AddActiveStream(id1) // only add it once frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize, protocol.Version1) Expect(frames).To(HaveLen(1)) @@ -227,9 +227,9 @@ var _ = Describe("Framer", func() { f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f11}, true) - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f12}, false) - stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f2}, false) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f11}, true, true) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f12}, true, false) + stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false) framer.AddActiveStream(id1) // only add it once framer.AddActiveStream(id2) // first a frame from stream 1 @@ -252,8 +252,8 @@ var _ = Describe("Framer", func() { f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} // both streams have more data, and will be re-queued - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f1}, true) - stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f2}, true) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, true) + stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, true) framer.AddActiveStream(id1) framer.AddActiveStream(id2) frames, length := framer.AppendStreamFrames(nil, 1000, protocol.Version1) @@ -268,8 +268,8 @@ var _ = Describe("Framer", func() { streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) f1 := &wire.StreamFrame{Data: []byte("foobar")} f2 := &wire.StreamFrame{Data: []byte("foobaz")} - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f1}, false) - stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f2}, false) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, true, false) + stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, true, false) framer.AddActiveStream(id2) framer.AddActiveStream(id1) frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1) @@ -281,7 +281,7 @@ var _ = Describe("Framer", func() { It("only asks a stream for data once, even if it was reported active multiple times", func() { streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) f := &wire.StreamFrame{Data: []byte("foobar")} - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f}, false) // only one call to this function + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false) // only one call to this function framer.AddActiveStream(id1) framer.AddActiveStream(id1) frames, _ := framer.AppendStreamFrames(nil, 1000, protocol.Version1) @@ -297,19 +297,19 @@ var _ = Describe("Framer", func() { It("pops maximum size STREAM frames", func() { for i := protocol.MinStreamFrameSize; i < 2000; i++ { streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) { + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { f := &wire.StreamFrame{ StreamID: id1, DataLenPresent: true, } f.Data = make([]byte, f.MaxDataLen(size, v)) Expect(f.Length(version)).To(Equal(size)) - return &ackhandler.Frame{Frame: f}, false + return ackhandler.StreamFrame{Frame: f}, true, false }) framer.AddActiveStream(id1) frames, _ := framer.AppendStreamFrames(nil, i, protocol.Version1) Expect(frames).To(HaveLen(1)) - f := frames[0].Frame.(*wire.StreamFrame) + f := frames[0].Frame Expect(f.DataLenPresent).To(BeFalse()) Expect(f.Length(version)).To(Equal(i)) } @@ -319,29 +319,29 @@ var _ = Describe("Framer", func() { for i := 2 * protocol.MinStreamFrameSize; i < 2000; i++ { streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) { + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { f := &wire.StreamFrame{ StreamID: id2, DataLenPresent: true, } f.Data = make([]byte, f.MaxDataLen(protocol.MinStreamFrameSize, v)) - return &ackhandler.Frame{Frame: f}, false + return ackhandler.StreamFrame{Frame: f}, true, false }) - stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) { + stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { f := &wire.StreamFrame{ StreamID: id2, DataLenPresent: true, } f.Data = make([]byte, f.MaxDataLen(size, v)) Expect(f.Length(version)).To(Equal(size)) - return &ackhandler.Frame{Frame: f}, false + return ackhandler.StreamFrame{Frame: f}, true, false }) framer.AddActiveStream(id1) framer.AddActiveStream(id2) frames, _ := framer.AppendStreamFrames(nil, i, protocol.Version1) Expect(frames).To(HaveLen(2)) - f1 := frames[0].Frame.(*wire.StreamFrame) - f2 := frames[1].Frame.(*wire.StreamFrame) + f1 := frames[0].Frame + f2 := frames[1].Frame Expect(f1.DataLenPresent).To(BeTrue()) Expect(f2.DataLenPresent).To(BeFalse()) Expect(f1.Length(version) + f2.Length(version)).To(Equal(i)) @@ -351,7 +351,7 @@ var _ = Describe("Framer", func() { It("pops frames that when asked for the the minimum STREAM frame size", func() { streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) f := &wire.StreamFrame{Data: []byte("foobar")} - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f}, false) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false) framer.AddActiveStream(id1) framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize, protocol.Version1) }) @@ -369,7 +369,7 @@ var _ = Describe("Framer", func() { Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)), DataLenPresent: true, } - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(&ackhandler.Frame{Frame: f}, false) + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, true, false) framer.AddActiveStream(id1) fs, length := framer.AppendStreamFrames(nil, 500, protocol.Version1) Expect(fs).To(HaveLen(1)) diff --git a/internal/ackhandler/frame.go b/internal/ackhandler/frame.go index deb23cfc..cd75d56a 100644 --- a/internal/ackhandler/frame.go +++ b/internal/ackhandler/frame.go @@ -27,3 +27,11 @@ func putFrame(f *Frame) { f.OnAcked = nil framePool.Put(f) } + +type StreamFrame struct { + Frame *wire.StreamFrame + Handler interface { + OnLost(*wire.StreamFrame) + OnAcked(*wire.StreamFrame) + } +} diff --git a/internal/ackhandler/packet.go b/internal/ackhandler/packet.go index 849eea69..257dd83b 100644 --- a/internal/ackhandler/packet.go +++ b/internal/ackhandler/packet.go @@ -11,6 +11,7 @@ import ( type Packet struct { SendTime time.Time PacketNumber protocol.PacketNumber + StreamFrames []StreamFrame Frames []*Frame LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK Length protocol.ByteCount @@ -32,6 +33,7 @@ var packetPool = sync.Pool{New: func() any { return &Packet{} }} func GetPacket() *Packet { p := packetPool.Get().(*Packet) p.PacketNumber = 0 + p.StreamFrames = nil p.Frames = nil p.LargestAcked = 0 p.Length = 0 @@ -51,5 +53,6 @@ func putPacket(p *Packet) { putFrame(f) } p.Frames = nil + p.StreamFrames = nil packetPool.Put(p) } diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 550dc74d..27f72cce 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -243,7 +243,7 @@ func (h *sentPacketHandler) SentPacket(p *Packet) { } pnSpace.largestSent = p.PacketNumber - isAckEliciting := len(p.Frames) > 0 + isAckEliciting := len(p.StreamFrames) > 0 || len(p.Frames) > 0 if isAckEliciting { pnSpace.lastAckElicitingPacketTime = p.SendTime @@ -424,6 +424,9 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL f.OnAcked(f.Frame) } } + for _, f := range p.StreamFrames { + f.Handler.OnAcked(f.Frame) + } if err := pnSpace.history.Remove(p.PacketNumber); err != nil { return nil, err } @@ -790,12 +793,16 @@ func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) } func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { - if len(p.Frames) == 0 { + if len(p.Frames) == 0 && len(p.StreamFrames) == 0 { panic("no frames") } for _, f := range p.Frames { f.OnLost(f.Frame) } + for _, f := range p.StreamFrames { + f.Handler.OnLost(f.Frame) + } + p.StreamFrames = nil p.Frames = nil } diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index fde92852..21255ec6 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -51,10 +51,10 @@ func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1, arg2 inte } // AppendStreamFrames mocks base method. -func (m *MockFrameSource) AppendStreamFrames(arg0 []*ackhandler.Frame, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { +func (m *MockFrameSource) AppendStreamFrames(arg0 []ackhandler.StreamFrame, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1, arg2) - ret0, _ := ret[0].([]*ackhandler.Frame) + ret0, _ := ret[0].([]ackhandler.StreamFrame) ret1, _ := ret[1].(protocol.ByteCount) return ret0, ret1 } diff --git a/mock_send_stream_internal_test.go b/mock_send_stream_internal_test.go index df9e4fbb..c0581bc2 100644 --- a/mock_send_stream_internal_test.go +++ b/mock_send_stream_internal_test.go @@ -161,12 +161,13 @@ func (mr *MockSendStreamIMockRecorder) hasData() *gomock.Call { } // popStreamFrame mocks base method. -func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.VersionNumber) (*ackhandler.Frame, bool) { +func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "popStreamFrame", arg0, arg1) - ret0, _ := ret[0].(*ackhandler.Frame) + ret0, _ := ret[0].(ackhandler.StreamFrame) ret1, _ := ret[1].(bool) - return ret0, ret1 + ret2, _ := ret[2].(bool) + return ret0, ret1, ret2 } // popStreamFrame indicates an expected call of popStreamFrame. diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index f617c6e3..512b4b1d 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -258,12 +258,13 @@ func (mr *MockStreamIMockRecorder) hasData() *gomock.Call { } // popStreamFrame mocks base method. -func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.VersionNumber) (*ackhandler.Frame, bool) { +func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "popStreamFrame", arg0, arg1) - ret0, _ := ret[0].(*ackhandler.Frame) + ret0, _ := ret[0].(ackhandler.StreamFrame) ret1, _ := ret[1].(bool) - return ret0, ret1 + ret2, _ := ret[2].(bool) + return ret0, ret1, ret2 } // popStreamFrame indicates an expected call of popStreamFrame. diff --git a/packet_packer.go b/packet_packer.go index 180a03e1..27e42fa6 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -35,15 +35,17 @@ type sealer interface { } type payload struct { - frames []*ackhandler.Frame - ack *wire.AckFrame - length protocol.ByteCount + streamFrames []ackhandler.StreamFrame + frames []*ackhandler.Frame + ack *wire.AckFrame + length protocol.ByteCount } type longHeaderPacket struct { - header *wire.ExtendedHeader - ack *wire.AckFrame - frames []*ackhandler.Frame + header *wire.ExtendedHeader + ack *wire.AckFrame + frames []*ackhandler.Frame + streamFrames []ackhandler.StreamFrame // only used for 0-RTT packets length protocol.ByteCount @@ -108,6 +110,7 @@ func (p *longHeaderPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQu ap.PacketNumber = p.header.PacketNumber ap.LargestAcked = largestAcked ap.Frames = p.frames + ap.StreamFrames = p.streamFrames ap.Length = p.length ap.EncryptionLevel = encLevel ap.SendTime = now @@ -143,7 +146,7 @@ type sealingManager interface { type frameSource interface { HasData() bool - AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) } @@ -595,7 +598,7 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v) // check if we have anything to send - if len(pl.frames) == 0 { + if len(pl.frames) == 0 && len(pl.streamFrames) == 0 { if pl.ack == nil { return payload{} } @@ -629,7 +632,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc return payload{} } - pl := payload{frames: make([]*ackhandler.Frame, 0, 1)} + pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)} hasData := p.framer.HasData() hasRetransmission := p.retransmissionQueue.HasAppData() @@ -684,7 +687,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v) pl.length += lengthAdded - pl.frames, lengthAdded = p.framer.AppendStreamFrames(pl.frames, maxFrameSize-pl.length, v) + pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v) pl.length += lengthAdded } return pl @@ -842,10 +845,11 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] return &longHeaderPacket{ - header: header, - ack: pl.ack, - frames: pl.frames, - length: protocol.ByteCount(len(raw)), + header: header, + ack: pl.ack, + frames: pl.frames, + streamFrames: pl.streamFrames, + length: protocol.ByteCount(len(raw)), }, nil } @@ -907,6 +911,7 @@ func (p *packetPacker) appendShortHeaderPacket( ap.PacketNumber = pn ap.LargestAcked = largestAcked ap.Frames = pl.frames + ap.StreamFrames = pl.streamFrames ap.Length = protocol.ByteCount(len(raw)) ap.EncryptionLevel = protocol.Encryption1RTT ap.SendTime = time.Now() @@ -927,9 +932,16 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr if paddingLen > 0 { raw = append(raw, make([]byte, paddingLen)...) } - for _, frame := range pl.frames { + for _, f := range pl.frames { var err error - raw, err = frame.Append(raw, v) + raw, err = f.Append(raw, v) + if err != nil { + return nil, err + } + } + for _, f := range pl.streamFrames { + var err error + raw, err = f.Frame.Append(raw, v) if err != nil { return nil, err } diff --git a/packet_packer_test.go b/packet_packer_test.go index 05dc8689..07d9e9a8 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -63,23 +63,23 @@ var _ = Describe("Packet packer", func() { ExpectWithOffset(1, len(data)-l+int(pnLen)).To(BeNumerically(">=", 4)) } - appendFrames := func(fs, frames []*ackhandler.Frame, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { - var length protocol.ByteCount - for _, f := range frames { - length += f.Frame.Length(v) - } - return append(fs, frames...), length - } - - expectAppendStreamFrames := func(frames ...*ackhandler.Frame) { - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(fs []*ackhandler.Frame, _ protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { - return appendFrames(fs, frames, v) + expectAppendStreamFrames := func(frames ...ackhandler.StreamFrame) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.StreamFrame, _ protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { + var length protocol.ByteCount + for _, f := range frames { + length += f.Frame.Length(v) + } + return append(fs, frames...), length }) } expectAppendControlFrames := func(frames ...*ackhandler.Frame) { framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(fs []*ackhandler.Frame, _ protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { - return appendFrames(fs, frames, v) + var length protocol.ByteCount + for _, f := range frames { + length += f.Frame.Length(v) + } + return append(fs, frames...), length }) } @@ -168,14 +168,15 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} - expectAppendStreamFrames(&ackhandler.Frame{Frame: f}) + expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) p, err := packer.PackCoalescedPacket(false, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket).ToNot(BeNil()) - Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) - Expect(p.shortHdrPacket.Frames[0].Frame).To(Equal(f)) + Expect(p.shortHdrPacket.Frames).To(BeEmpty()) + Expect(p.shortHdrPacket.StreamFrames).To(HaveLen(1)) + Expect(p.shortHdrPacket.StreamFrames[0].Frame).To(Equal(f)) hdrRawEncrypted := append([]byte{}, hdrRaw...) hdrRawEncrypted[0] ^= 0xff hdrRawEncrypted[len(hdrRaw)-2] ^= 0xff @@ -305,7 +306,7 @@ var _ = Describe("Packet packer", func() { return append(frames, cf), cf.Length(v) }) // TODO: check sizes - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(frames []*ackhandler.Frame, _ protocol.ByteCount, _ protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(frames []ackhandler.StreamFrame, _ protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { return frames, 0 }) p, err := packer.PackCoalescedPacket(false, protocol.Version1) @@ -508,14 +509,15 @@ var _ = Describe("Packet packer", func() { StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, } - expectAppendStreamFrames(&ackhandler.Frame{Frame: f}) + expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) p, buffer, err := packer.PackPacket(false, time.Now(), protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(p.Frames).To(HaveLen(1)) - Expect(p.Frames[0].Frame.(*wire.StreamFrame).StreamID).To(Equal(f.StreamID)) + Expect(p.Frames).To(BeEmpty()) + Expect(p.StreamFrames).To(HaveLen(1)) + Expect(p.StreamFrames[0].Frame.StreamID).To(Equal(f.StreamID)) Expect(buffer.Data).To(ContainSubstring(string(b))) }) @@ -622,7 +624,7 @@ var _ = Describe("Packet packer", func() { maxSize = maxLen return fs, 444 }), - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(fs []*ackhandler.Frame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(fs []ackhandler.StreamFrame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { Expect(maxLen).To(Equal(maxSize - 444)) return fs, 0 }), @@ -683,7 +685,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData().Return(true) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() - expectAppendStreamFrames(&ackhandler.Frame{Frame: f}) + expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) _, buffer, err := packer.PackPacket(false, time.Now(), protocol.Version1) Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added @@ -732,14 +734,15 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData().Return(true) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() - expectAppendStreamFrames(&ackhandler.Frame{Frame: f1}, &ackhandler.Frame{Frame: f2}, &ackhandler.Frame{Frame: f3}) + expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f1}, ackhandler.StreamFrame{Frame: f2}, ackhandler.StreamFrame{Frame: f3}) p, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(p.Frames).To(HaveLen(3)) - Expect(p.Frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) - Expect(p.Frames[1].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) - Expect(p.Frames[2].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) + Expect(p.Frames).To(BeEmpty()) + Expect(p.StreamFrames).To(HaveLen(3)) + Expect(p.StreamFrames[0].Frame.Data).To(Equal([]byte("frame 1"))) + Expect(p.StreamFrames[1].Frame.Data).To(Equal([]byte("frame 2"))) + Expect(p.StreamFrames[2].Frame.Data).To(Equal([]byte("frame 3"))) }) Context("making ACK packets ack-eliciting", func() { @@ -1149,7 +1152,7 @@ var _ = Describe("Packet packer", func() { return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} }) expectAppendControlFrames() - expectAppendStreamFrames(&ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) + expectAppendStreamFrames(ackhandler.StreamFrame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) p, err := packer.PackCoalescedPacket(false, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) @@ -1158,9 +1161,11 @@ var _ = Describe("Packet packer", func() { Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) Expect(p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.longHdrPackets[0].streamFrames).To(BeEmpty()) Expect(p.longHdrPackets[1].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) - Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) - Expect(p.longHdrPackets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + Expect(p.longHdrPackets[1].frames).To(BeEmpty()) + Expect(p.longHdrPackets[1].streamFrames).To(HaveLen(1)) + Expect(p.longHdrPackets[1].streamFrames[0].Frame.Data).To(Equal([]byte("foobar"))) hdrs, more := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) @@ -1184,7 +1189,7 @@ var _ = Describe("Packet packer", func() { return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} }) expectAppendControlFrames() - expectAppendStreamFrames(&ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) + expectAppendStreamFrames(ackhandler.StreamFrame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) p, err := packer.PackCoalescedPacket(false, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically("<", 100)) @@ -1193,8 +1198,9 @@ var _ = Describe("Packet packer", func() { Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) Expect(p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) Expect(p.shortHdrPacket).ToNot(BeNil()) - Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) - Expect(p.shortHdrPacket.Frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + Expect(p.shortHdrPacket.Frames).To(BeEmpty()) + Expect(p.shortHdrPacket.StreamFrames).To(HaveLen(1)) + Expect(p.shortHdrPacket.StreamFrames[0].Frame.Data).To(Equal([]byte("foobar"))) hdrs, more := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(1)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeHandshake)) @@ -1481,7 +1487,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) framer.EXPECT().HasData().Return(true) expectAppendControlFrames() - expectAppendStreamFrames(&ackhandler.Frame{Frame: f}) + expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, protocol.Version1) Expect(err).ToNot(HaveOccurred()) @@ -1489,8 +1495,9 @@ var _ = Describe("Packet packer", func() { Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket).ToNot(BeNil()) packet := p.shortHdrPacket - Expect(packet.Frames).To(HaveLen(1)) - Expect(packet.Frames[0].Frame).To(Equal(f)) + Expect(packet.Frames).To(BeEmpty()) + Expect(packet.StreamFrames).To(HaveLen(1)) + Expect(packet.StreamFrames[0].Frame).To(Equal(f)) }) It("packs a full size 1-RTT probe packet", func() { @@ -1502,10 +1509,10 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) framer.EXPECT().HasData().Return(true) expectAppendControlFrames() - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(fs []*ackhandler.Frame, maxSize protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(fs []ackhandler.StreamFrame, maxSize protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { sf, split := f.MaybeSplitOffFrame(maxSize, v) Expect(split).To(BeTrue()) - return append(fs, &ackhandler.Frame{Frame: sf}), sf.Length(v) + return append(fs, ackhandler.StreamFrame{Frame: sf}), sf.Length(v) }) p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, protocol.Version1) @@ -1514,8 +1521,8 @@ var _ = Describe("Packet packer", func() { Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket).ToNot(BeNil()) packet := p.shortHdrPacket - Expect(packet.Frames).To(HaveLen(1)) - Expect(packet.Frames[0].Frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(packet.Frames).To(BeEmpty()) + Expect(packet.StreamFrames).To(HaveLen(1)) Expect(packet.Length).To(Equal(maxPacketSize)) }) diff --git a/send_stream.go b/send_stream.go index cebe30ef..abe1067e 100644 --- a/send_stream.go +++ b/send_stream.go @@ -18,7 +18,7 @@ type sendStreamI interface { SendStream handleStopSendingFrame(*wire.StopSendingFrame) hasData() bool - popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) + popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool) closeForShutdown(error) updateSendWindow(protocol.ByteCount) } @@ -198,7 +198,7 @@ func (s *sendStream) canBufferStreamFrame() bool { // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. -func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool /* has more data to send */) { +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) { s.mutex.Lock() f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) if f != nil { @@ -207,13 +207,12 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers s.mutex.Unlock() if f == nil { - return nil, hasMoreData + return ackhandler.StreamFrame{}, false, hasMoreData } - af := ackhandler.GetFrame() - af.Frame = f - af.OnLost = s.queueRetransmission - af.OnAcked = s.frameAcked - return af, hasMoreData + return ackhandler.StreamFrame{ + Frame: f, + Handler: s, + }, true, hasMoreData } func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) { @@ -348,9 +347,8 @@ func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.By } } -func (s *sendStream) frameAcked(f wire.Frame) { - f.(*wire.StreamFrame).PutBack() - +func (s *sendStream) OnAcked(f *wire.StreamFrame) { + f.PutBack() s.mutex.Lock() if s.cancelWriteErr != nil { s.mutex.Unlock() @@ -377,15 +375,14 @@ func (s *sendStream) isNewlyCompleted() bool { return false } -func (s *sendStream) queueRetransmission(f wire.Frame) { - sf := f.(*wire.StreamFrame) - sf.DataLenPresent = true +func (s *sendStream) OnLost(f *wire.StreamFrame) { s.mutex.Lock() if s.cancelWriteErr != nil { s.mutex.Unlock() return } - s.retransmissionQueue = append(s.retransmissionQueue, sf) + f.DataLenPresent = true + s.retransmissionQueue = append(s.retransmissionQueue, f) s.numOutstandingFrames-- if s.numOutstandingFrames < 0 { panic("numOutStandingFrames negative") diff --git a/send_stream_test.go b/send_stream_test.go index 0bc4ce90..f2929159 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -85,8 +85,9 @@ var _ = Describe("Send Stream", func() { waitForWrite() mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - frame, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - f := frame.Frame.(*wire.StreamFrame) + frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) + f := frame.Frame Expect(f.Data).To(Equal([]byte("foobar"))) Expect(f.Fin).To(BeFalse()) Expect(f.Offset).To(BeZero()) @@ -109,19 +110,22 @@ var _ = Describe("Send Stream", func() { waitForWrite() mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) - frame, _ := str.popStreamFrame(expectedFrameHeaderLen(0)+3, protocol.Version1) - f := frame.Frame.(*wire.StreamFrame) + frame, ok, _ := str.popStreamFrame(expectedFrameHeaderLen(0)+3, protocol.Version1) + Expect(ok).To(BeTrue()) + f := frame.Frame Expect(f.Offset).To(BeZero()) Expect(f.Fin).To(BeFalse()) Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.DataLenPresent).To(BeTrue()) - frame, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - f = frame.Frame.(*wire.StreamFrame) + frame, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) + f = frame.Frame Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.Fin).To(BeFalse()) Expect(f.Offset).To(Equal(protocol.ByteCount(3))) Expect(f.DataLenPresent).To(BeTrue()) - Expect(str.popStreamFrame(1000, protocol.Version1)).To(BeNil()) + _, ok, _ = str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeFalse()) Eventually(done).Should(BeClosed()) }) @@ -141,8 +145,9 @@ var _ = Describe("Send Stream", func() { Eventually(done).Should(BeClosed()) // both Write calls returned without any data having been dequeued yet mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) - frame, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - f := frame.Frame.(*wire.StreamFrame) + frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) + f := frame.Frame Expect(f.Offset).To(BeZero()) Expect(f.Fin).To(BeFalse()) Expect(f.Data).To(Equal([]byte("foobar"))) @@ -163,8 +168,9 @@ var _ = Describe("Send Stream", func() { }() waitForWrite() for i := 0; i < 5; i++ { - frame, _ := str.popStreamFrame(1100, protocol.Version1) - f := frame.Frame.(*wire.StreamFrame) + frame, ok, _ := str.popStreamFrame(1100, protocol.Version1) + Expect(ok).To(BeTrue()) + f := frame.Frame Expect(f.Offset).To(BeNumerically("~", 1100*i, 10*i)) Expect(f.Fin).To(BeFalse()) Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) @@ -186,15 +192,17 @@ var _ = Describe("Send Stream", func() { waitForWrite() mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) - frame, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0)+2, protocol.Version1) + frame, ok, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0)+2, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(hasMoreData).To(BeTrue()) - f := frame.Frame.(*wire.StreamFrame) + f := frame.Frame Expect(f.DataLen()).To(Equal(protocol.ByteCount(2))) Consistently(done).ShouldNot(BeClosed()) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) - frame, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(1)+1, protocol.Version1) + frame, ok, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(1)+1, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(hasMoreData).To(BeTrue()) - f = frame.Frame.(*wire.StreamFrame) + f = frame.Frame Expect(f.DataLen()).To(Equal(protocol.ByteCount(1))) Eventually(done).Should(BeClosed()) }) @@ -214,22 +222,24 @@ var _ = Describe("Send Stream", func() { waitForWrite() mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) - frame, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0)+2, protocol.Version1) + frame, ok, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0)+2, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(hasMoreData).To(BeTrue()) - f := frame.Frame.(*wire.StreamFrame) + f := frame.Frame Expect(f.Data).To(Equal([]byte("fo"))) Consistently(done).ShouldNot(BeClosed()) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) - frame, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(2)+4, protocol.Version1) + frame, ok, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(2)+4, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(hasMoreData).To(BeTrue()) - f = frame.Frame.(*wire.StreamFrame) + f = frame.Frame Expect(f.Data).To(Equal([]byte("obar"))) Eventually(done).Should(BeClosed()) }) It("popStreamFrame returns nil if no data is available", func() { - frame, hasMoreData := str.popStreamFrame(1000, protocol.Version1) - Expect(frame).To(BeNil()) + _, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) }) @@ -246,16 +256,16 @@ var _ = Describe("Send Stream", func() { waitForWrite() mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) - frame, hasMoreData := str.popStreamFrame(50, protocol.Version1) - Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.(*wire.StreamFrame).Fin).To(BeFalse()) + frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) + Expect(ok).To(BeTrue()) + Expect(frame.Frame.Fin).To(BeFalse()) Expect(hasMoreData).To(BeTrue()) - frame, hasMoreData = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.(*wire.StreamFrame).Fin).To(BeFalse()) + frame, ok, hasMoreData = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) + Expect(frame.Frame.Fin).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) - frame, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - Expect(frame).To(BeNil()) + _, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeFalse()) Eventually(done).Should(BeClosed()) }) @@ -275,12 +285,14 @@ var _ = Describe("Send Stream", func() { Expect(n).To(Equal(3)) }() waitForWrite() - frame, _ := str.popStreamFrame(frameHeaderSize+1, protocol.Version1) - f := frame.Frame.(*wire.StreamFrame) + frame, ok, _ := str.popStreamFrame(frameHeaderSize+1, protocol.Version1) + Expect(ok).To(BeTrue()) + f := frame.Frame Expect(f.Data).To(Equal([]byte("f"))) - frame, _ = str.popStreamFrame(100, protocol.Version1) + frame, ok, _ = str.popStreamFrame(100, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) - f = frame.Frame.(*wire.StreamFrame) + f = frame.Frame Expect(f.Data).To(Equal([]byte("oo"))) s[1] = 'e' Expect(f.Data).To(Equal([]byte("oo"))) @@ -323,8 +335,8 @@ var _ = Describe("Send Stream", func() { Expect(err).ToNot(HaveOccurred()) }() waitForWrite() - f, hasMoreData := str.popStreamFrame(1000, protocol.Version1) - Expect(f).To(BeNil()) + _, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) // make the Write go routine return str.closeForShutdown(nil) @@ -345,7 +357,8 @@ var _ = Describe("Send Stream", func() { // first pop a STREAM frame of the maximum size allowed by flow control mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(3)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) - f, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0)+3, protocol.Version1) + f, ok, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0)+3, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(f).ToNot(BeNil()) Expect(hasMoreData).To(BeTrue()) @@ -357,8 +370,8 @@ var _ = Describe("Send Stream", func() { StreamID: streamID, MaximumStreamData: 10, }) - f, hasMoreData = str.popStreamFrame(1000, protocol.Version1) - Expect(f).To(BeNil()) + _, ok, hasMoreData = str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) // make the Write go routine return str.closeForShutdown(nil) @@ -416,11 +429,12 @@ var _ = Describe("Send Stream", func() { Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) }() waitForWrite() - frame, hasMoreData := str.popStreamFrame(50, protocol.Version1) + frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) Expect(hasMoreData).To(BeTrue()) Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) - Expect(n).To(BeEquivalentTo(frame.Frame.(*wire.StreamFrame).DataLen())) + Expect(n).To(BeEquivalentTo(frame.Frame.DataLen())) }) It("doesn't pop any data after the deadline expired", func() { @@ -437,12 +451,13 @@ var _ = Describe("Send Stream", func() { Expect(err).To(MatchError(errDeadline)) }() waitForWrite() - frame, hasMoreData := str.popStreamFrame(50, protocol.Version1) + frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) Expect(hasMoreData).To(BeTrue()) Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) - frame, hasMoreData = str.popStreamFrame(50, protocol.Version1) - Expect(frame).To(BeNil()) + _, ok, hasMoreData = str.popStreamFrame(50, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) }) @@ -529,9 +544,10 @@ var _ = Describe("Send Stream", func() { It("allows FIN", func() { mockSender.EXPECT().onHasStreamData(streamID) str.Close() - frame, hasMoreData := str.popStreamFrame(1000, protocol.Version1) + frame, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) + f := frame.Frame Expect(f.Data).To(BeEmpty()) Expect(f.Fin).To(BeTrue()) Expect(f.DataLenPresent).To(BeTrue()) @@ -546,13 +562,15 @@ var _ = Describe("Send Stream", func() { Expect(str.Close()).To(Succeed()) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) - frame, _ := str.popStreamFrame(3+frameHeaderLen, protocol.Version1) + frame, ok, _ := str.popStreamFrame(3+frameHeaderLen, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) + f := frame.Frame Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.Fin).To(BeFalse()) - frame, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) - f = frame.Frame.(*wire.StreamFrame) + frame, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) + f = frame.Frame Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.Fin).To(BeTrue()) }) @@ -575,9 +593,10 @@ var _ = Describe("Send Stream", func() { if i == 5 { Eventually(done).Should(BeClosed()) } - frame, _ := str.popStreamFrame(1100, protocol.Version1) + frame, ok, _ := str.popStreamFrame(1100, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) + f := frame.Frame Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) Expect(f.Fin).To(Equal(i == 5)) // the last frame should have the FIN bit set } @@ -585,26 +604,27 @@ var _ = Describe("Send Stream", func() { It("doesn't allow FIN after it is closed for shutdown", func() { str.closeForShutdown(errors.New("test")) - f, hasMoreData := str.popStreamFrame(1000, protocol.Version1) - Expect(f).To(BeNil()) + _, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) Expect(str.Close()).To(Succeed()) - f, hasMoreData = str.popStreamFrame(1000, protocol.Version1) - Expect(f).To(BeNil()) + _, ok, hasMoreData = str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) }) It("doesn't allow FIN twice", func() { mockSender.EXPECT().onHasStreamData(streamID) str.Close() - frame, _ := str.popStreamFrame(1000, protocol.Version1) + frame, ok, _ := str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) + f := frame.Frame Expect(f.Data).To(BeEmpty()) Expect(f.Fin).To(BeTrue()) - frame, hasMoreData := str.popStreamFrame(1000, protocol.Version1) - Expect(frame).To(BeNil()) + _, ok, hasMoreData := str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) }) }) @@ -631,12 +651,13 @@ var _ = Describe("Send Stream", func() { close(done) }() waitForWrite() - frame, hasMoreData := str.popStreamFrame(50, protocol.Version1) // get a STREAM frame containing some data, but not all + frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) // get a STREAM frame containing some data, but not all + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) Expect(hasMoreData).To(BeTrue()) str.closeForShutdown(testErr) - frame, hasMoreData = str.popStreamFrame(1000, protocol.Version1) - Expect(frame).To(BeNil()) + _, ok, hasMoreData = str.popStreamFrame(1000, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) Eventually(done).Should(BeClosed()) }) @@ -734,12 +755,13 @@ var _ = Describe("Send Stream", func() { close(writeReturned) }() waitForWrite() - frame, _ := str.popStreamFrame(50, protocol.Version1) + frame, ok, _ := str.popStreamFrame(50, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) mockSender.EXPECT().onStreamCompleted(streamID) str.CancelWrite(1234) Eventually(writeReturned).Should(BeClosed()) - Expect(n).To(BeEquivalentTo(frame.Frame.(*wire.StreamFrame).DataLen())) + Expect(n).To(BeEquivalentTo(frame.Frame.DataLen())) }) It("doesn't pop STREAM frames after being canceled", func() { @@ -754,13 +776,14 @@ var _ = Describe("Send Stream", func() { close(writeReturned) }() waitForWrite() - frame, hasMoreData := str.popStreamFrame(50, protocol.Version1) + frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(hasMoreData).To(BeTrue()) Expect(frame).ToNot(BeNil()) mockSender.EXPECT().onStreamCompleted(streamID) str.CancelWrite(1234) - frame, hasMoreData = str.popStreamFrame(10, protocol.Version1) - Expect(frame).To(BeNil()) + _, ok, hasMoreData = str.popStreamFrame(10, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) Eventually(writeReturned).Should(BeClosed()) }) @@ -782,14 +805,15 @@ var _ = Describe("Send Stream", func() { close(writeReturned) }() waitForWrite() - frame, hasMoreData := str.popStreamFrame(50, protocol.Version1) + frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(hasMoreData).To(BeTrue()) Expect(frame).ToNot(BeNil()) mockSender.EXPECT().onStreamCompleted(streamID) str.CancelWrite(1234) - frame, hasMoreData = str.popStreamFrame(10, protocol.Version1) + _, ok, hasMoreData = str.popStreamFrame(10, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) - Expect(frame).To(BeNil()) Eventually(writeReturned).Should(BeClosed()) }) @@ -805,12 +829,13 @@ var _ = Describe("Send Stream", func() { close(writeReturned) }() waitForWrite() - frame, hasMoreData := str.popStreamFrame(50, protocol.Version1) + frame, ok, hasMoreData := str.popStreamFrame(50, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(hasMoreData).To(BeTrue()) Expect(frame).ToNot(BeNil()) mockSender.EXPECT().onStreamCompleted(streamID) str.CancelWrite(1234) - frame.OnAcked(frame.Frame) + frame.Handler.OnAcked(frame.Frame) }) It("cancels the context", func() { @@ -915,10 +940,11 @@ var _ = Describe("Send Stream", func() { DataLenPresent: false, } mockSender.EXPECT().onHasStreamData(streamID) - str.queueRetransmission(f) - frame, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + str.OnLost(f) + frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) - f = frame.Frame.(*wire.StreamFrame) + f = frame.Frame Expect(f.Offset).To(Equal(protocol.ByteCount(0x42))) Expect(f.Data).To(Equal([]byte("foobar"))) Expect(f.DataLenPresent).To(BeTrue()) @@ -932,17 +958,19 @@ var _ = Describe("Send Stream", func() { DataLenPresent: false, } mockSender.EXPECT().onHasStreamData(streamID) - str.queueRetransmission(sf) - frame, hasMoreData := str.popStreamFrame(sf.Length(protocol.Version1)-3, protocol.Version1) + str.OnLost(sf) + frame, ok, hasMoreData := str.popStreamFrame(sf.Length(protocol.Version1)-3, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) - f := frame.Frame.(*wire.StreamFrame) + f := frame.Frame Expect(hasMoreData).To(BeTrue()) Expect(f.Offset).To(Equal(protocol.ByteCount(0x42))) Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.DataLenPresent).To(BeTrue()) - frame, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + frame, ok, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) - f = frame.Frame.(*wire.StreamFrame) + f = frame.Frame Expect(f.Offset).To(Equal(protocol.ByteCount(0x45))) Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.DataLenPresent).To(BeTrue()) @@ -956,10 +984,10 @@ var _ = Describe("Send Stream", func() { DataLenPresent: false, } mockSender.EXPECT().onHasStreamData(streamID) - str.queueRetransmission(f) - frame, hasMoreData := str.popStreamFrame(2, protocol.Version1) + str.OnLost(f) + _, ok, hasMoreData := str.popStreamFrame(2, protocol.Version1) + Expect(ok).To(BeFalse()) Expect(hasMoreData).To(BeTrue()) - Expect(frame).To(BeNil()) }) It("queues lost STREAM frames", func() { @@ -974,17 +1002,19 @@ var _ = Describe("Send Stream", func() { close(done) }() waitForWrite() - frame, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) Eventually(done).Should(BeClosed()) Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + Expect(frame.Frame.Data).To(Equal([]byte("foobar"))) // now lose the frame mockSender.EXPECT().onHasStreamData(streamID) - frame.OnLost(frame.Frame) - newFrame, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + frame.Handler.OnLost(frame.Frame) + newFrame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(newFrame).ToNot(BeNil()) - Expect(newFrame.Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + Expect(newFrame.Frame.Data).To(Equal([]byte("foobar"))) }) It("doesn't queue retransmissions for a stream that was canceled", func() { @@ -999,7 +1029,8 @@ var _ = Describe("Send Stream", func() { close(done) }() waitForWrite() - f, _ := str.popStreamFrame(100, protocol.Version1) + f, ok, _ := str.popStreamFrame(100, protocol.Version1) + Expect(ok).To(BeTrue()) Eventually(done).Should(BeClosed()) Expect(f).ToNot(BeNil()) gomock.InOrder( @@ -1008,7 +1039,7 @@ var _ = Describe("Send Stream", func() { ) str.CancelWrite(9876) // don't EXPECT any calls to onHasStreamData - f.OnLost(f.Frame) + f.Handler.OnLost(f.Frame) Expect(str.retransmissionQueue).To(BeEmpty()) }) }) @@ -1031,13 +1062,13 @@ var _ = Describe("Send Stream", func() { waitForWrite() // get a bunch of small frames (max. 20 bytes) - var frames []ackhandler.Frame + var frames []ackhandler.StreamFrame for { - frame, hasMoreData := str.popStreamFrame(20, protocol.Version1) - if frame == nil { + frame, ok, hasMoreData := str.popStreamFrame(20, protocol.Version1) + if !ok { continue } - frames = append(frames, *frame) + frames = append(frames, frame) if !hasMoreData { break } @@ -1047,16 +1078,17 @@ var _ = Describe("Send Stream", func() { // Acknowledge all frames. // We don't expect the stream to be completed, since we still need to send the FIN. for _, f := range frames { - f.OnAcked(f.Frame) + f.Handler.OnAcked(f.Frame) } // Now close the stream and acknowledge the FIN. mockSender.EXPECT().onHasStreamData(streamID) Expect(str.Close()).To(Succeed()) - frame, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + frame, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(frame).ToNot(BeNil()) mockSender.EXPECT().onStreamCompleted(streamID) - frame.OnAcked(frame.Frame) + frame.Handler.OnAcked(frame.Frame) }) It("says when a stream is completed, if Close() is called before popping the frame", func() { @@ -1072,13 +1104,14 @@ var _ = Describe("Send Stream", func() { Eventually(done).Should(BeClosed()) Expect(str.Close()).To(Succeed()) - frame, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + frame, ok, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(hasMoreData).To(BeFalse()) Expect(frame).ToNot(BeNil()) - Expect(frame.Frame.(*wire.StreamFrame).Fin).To(BeTrue()) + Expect(frame.Frame.Fin).To(BeTrue()) mockSender.EXPECT().onStreamCompleted(streamID) - frame.OnAcked(frame.Frame) + frame.Handler.OnAcked(frame.Frame) }) It("doesn't say it's completed when there are frames waiting to be retransmitted", func() { @@ -1095,14 +1128,14 @@ var _ = Describe("Send Stream", func() { waitForWrite() // get a bunch of small frames (max. 20 bytes) - var frames []ackhandler.Frame + var frames []ackhandler.StreamFrame for { - frame, _ := str.popStreamFrame(20, protocol.Version1) - if frame == nil { + frame, ok, _ := str.popStreamFrame(20, protocol.Version1) + if !ok { continue } - frames = append(frames, *frame) - if frame.Frame.(*wire.StreamFrame).Fin { + frames = append(frames, frame) + if frame.Frame.Fin { break } } @@ -1110,16 +1143,17 @@ var _ = Describe("Send Stream", func() { // lose the first frame, acknowledge all others for _, f := range frames[1:] { - f.OnAcked(f.Frame) + f.Handler.OnAcked(f.Frame) } mockSender.EXPECT().onHasStreamData(streamID) - frames[0].OnLost(frames[0].Frame) + frames[0].Handler.OnLost(frames[0].Frame) // get the retransmission and acknowledge it - ret, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + ret, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + Expect(ok).To(BeTrue()) Expect(ret).ToNot(BeNil()) mockSender.EXPECT().onStreamCompleted(streamID) - ret.OnAcked(ret.Frame) + ret.Handler.OnAcked(ret.Frame) }) // This test is kind of an integration test. @@ -1154,18 +1188,18 @@ var _ = Describe("Send Stream", func() { if completed { break } - f, _ := str.popStreamFrame(protocol.ByteCount(mrand.Intn(300)+100), protocol.Version1) - if f == nil { + f, ok, _ := str.popStreamFrame(protocol.ByteCount(mrand.Intn(300)+100), protocol.Version1) + if !ok { continue } - sf := f.Frame.(*wire.StreamFrame) + sf := f.Frame // 50%: acknowledge the frame and save the data // 50%: lose the frame if mrand.Intn(100) < 50 { copy(received[sf.Offset:sf.Offset+sf.DataLen()], sf.Data) - f.OnAcked(f.Frame) + f.Handler.OnAcked(f.Frame) } else { - f.OnLost(f.Frame) + f.Handler.OnLost(f.Frame) } } Expect(received).To(Equal(data)) diff --git a/stream.go b/stream.go index 98d2fc6e..ab76eaf8 100644 --- a/stream.go +++ b/stream.go @@ -60,7 +60,7 @@ type streamI interface { // for sending hasData() bool handleStopSendingFrame(*wire.StopSendingFrame) - popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) + popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) updateSendWindow(protocol.ByteCount) }