diff --git a/framer.go b/framer.go index fbfe9bb7..d5be6fc9 100644 --- a/framer.go +++ b/framer.go @@ -12,7 +12,7 @@ type framer interface { AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) AddActiveStream(protocol.StreamID) - AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame + AppendStreamFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) } type framerI struct { @@ -73,7 +73,7 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Unlock() } -func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { +func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { var length protocol.ByteCount f.mutex.Lock() // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet @@ -105,5 +105,5 @@ func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCo length += frame.Length(f.version) } f.mutex.Unlock() - return frames + return frames, length } diff --git a/framer_test.go b/framer_test.go index 214528d7..075bb6d2 100644 --- a/framer_test.go +++ b/framer_test.go @@ -85,8 +85,9 @@ var _ = Describe("Stream Framer", func() { } stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) - fs := framer.AppendStreamFrames(nil, 1000) + fs, length := framer.AppendStreamFrames(nil, 1000) Expect(fs).To(Equal([]wire.Frame{f})) + Expect(length).To(Equal(f.Length(version))) }) It("appends to a frame slice", func() { @@ -99,8 +100,9 @@ var _ = Describe("Stream Framer", func() { framer.AddActiveStream(id1) mdf := &wire.MaxDataFrame{ByteOffset: 1337} frames := []wire.Frame{mdf} - fs := framer.AppendStreamFrames(frames, 1000) + fs, length := framer.AppendStreamFrames(frames, 1000) Expect(fs).To(Equal([]wire.Frame{mdf, f})) + Expect(length).To(Equal(f.Length(version))) }) It("skips a stream that was reported active, but was completed shortly after", func() { @@ -113,7 +115,8 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) framer.AddActiveStream(id2) - Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f})) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{f})) }) It("skips a stream that was reported active, but doesn't have any data", func() { @@ -127,7 +130,8 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) framer.AddActiveStream(id2) - Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f})) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{f})) }) It("pops from a stream multiple times, if it has enough data", func() { @@ -137,10 +141,13 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true) stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) framer.AddActiveStream(id1) // only add it once - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f1})) - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f2})) + frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f1})) + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f2})) // no further calls to popStreamFrame, after popStreamFrame said there's no more data - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(BeNil()) + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(BeNil()) }) It("re-queues a stream at the end, if it has enough data", func() { @@ -155,11 +162,14 @@ var _ = Describe("Stream Framer", func() { framer.AddActiveStream(id1) // only add it once framer.AddActiveStream(id2) // first a frame from stream 1 - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f11})) + frames, _ := framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f11})) // then a frame from stream 2 - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f2})) + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f2})) // then another frame from stream 1 - Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f12})) + frames, _ = framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) + Expect(frames).To(Equal([]wire.Frame{f12})) }) It("only dequeues data from each stream once per packet", func() { @@ -172,7 +182,9 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, true) framer.AddActiveStream(id1) framer.AddActiveStream(id2) - Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f1, f2})) + frames, length := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{f1, f2})) + Expect(length).To(Equal(f1.Length(version) + f2.Length(version))) }) It("returns multiple normal frames in the order they were reported active", func() { @@ -184,7 +196,8 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) framer.AddActiveStream(id2) framer.AddActiveStream(id1) - Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f2, f1})) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(Equal([]wire.Frame{f2, f1})) }) It("only asks a stream for data once, even if it was reported active multiple times", func() { @@ -193,12 +206,14 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function framer.AddActiveStream(id1) framer.AddActiveStream(id1) - Expect(framer.AppendStreamFrames(nil, 1000)).To(HaveLen(1)) + frames, _ := framer.AppendStreamFrames(nil, 1000) + Expect(frames).To(HaveLen(1)) }) It("does not pop empty frames", func() { - fs := framer.AppendStreamFrames(nil, 500) + fs, length := framer.AppendStreamFrames(nil, 500) Expect(fs).To(BeEmpty()) + Expect(length).To(BeZero()) }) It("pops frames that have the minimum size", func() { @@ -222,8 +237,9 @@ var _ = Describe("Stream Framer", func() { } stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f, false) framer.AddActiveStream(id1) - fs := framer.AppendStreamFrames(nil, 500) + fs, length := framer.AppendStreamFrames(nil, 500) Expect(fs).To(Equal([]wire.Frame{f})) + Expect(length).To(Equal(f.Length(version))) }) }) }) diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index e2f682a8..676da023 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -51,11 +51,12 @@ func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1 interface{ } // AppendStreamFrames mocks base method -func (m *MockFrameSource) AppendStreamFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) []wire.Frame { +func (m *MockFrameSource) AppendStreamFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1) ret0, _ := ret[0].([]wire.Frame) - return ret0 + ret1, _ := ret[1].(protocol.ByteCount) + return ret0, ret1 } // AppendStreamFrames indicates an expected call of AppendStreamFrames diff --git a/packet_packer.go b/packet_packer.go index bc994b54..69684556 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -25,6 +25,11 @@ type packer interface { ChangeDestConnectionID(protocol.ConnectionID) } +type payload struct { + frames []wire.Frame + length protocol.ByteCount +} + type packedPacket struct { header *wire.ExtendedHeader raw []byte @@ -104,7 +109,7 @@ type sealingManager interface { } type frameSource interface { - AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame + AppendStreamFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) } @@ -165,10 +170,13 @@ func newPacketPacker( // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { - frames := []wire.Frame{ccf} + payload := payload{ + frames: []wire.Frame{ccf}, + length: ccf.Length(p.version), + } encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) - return p.writeAndSealPacket(header, frames, encLevel, sealer) + return p.writeAndSealPacket(header, payload, encLevel, sealer) } func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { @@ -176,11 +184,14 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { if ack == nil { return nil, nil } + payload := payload{ + frames: []wire.Frame{ack}, + length: ack.Length(p.version), + } // TODO(#1534): only pack ACKs with the right encryption level encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) - frames := []wire.Frame{ack} - return p.writeAndSealPacket(header, frames, encLevel, sealer) + return p.writeAndSealPacket(header, payload, encLevel, sealer) } // PackRetransmission packs a retransmission @@ -247,7 +258,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok { sf.DataLenPresent = false } - p, err := p.writeAndSealPacket(header, frames, encLevel, sealer) + p, err := p.writeAndSealPacket(header, payload{frames: frames, length: length}, encLevel, sealer) if err != nil { return nil, err } @@ -275,19 +286,21 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { } maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen - frames, err := p.composeNextPacket(maxSize) + payload, err := p.composeNextPacket(maxSize) if err != nil { return nil, err } // Check if we have enough frames to send - if len(frames) == 0 { + if len(payload.frames) == 0 { return nil, nil } // check if this packet only contains an ACK - if !ackhandler.HasAckElicitingFrames(frames) { + if !ackhandler.HasAckElicitingFrames(payload.frames) { if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { - frames = append(frames, &wire.PingFrame{}) + ping := &wire.PingFrame{} + payload.frames = append(payload.frames, ping) + payload.length += ping.Length(p.version) p.numNonAckElicitingAcks = 0 } else { p.numNonAckElicitingAcks++ @@ -296,7 +309,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { p.numNonAckElicitingAcks = 0 } - return p.writeAndSealPacket(header, frames, encLevel, sealer) + return p.writeAndSealPacket(header, payload, encLevel, sealer) } func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { @@ -336,11 +349,12 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { if hasData { cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length) frames = append(frames, cf) + length += cf.Length(p.version) } - return p.writeAndSealPacket(hdr, frames, encLevel, sealer) + return p.writeAndSealPacket(hdr, payload{frames: frames, length: length}, encLevel, sealer) } -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) { +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (payload, error) { var length protocol.ByteCount var frames []wire.Frame @@ -360,14 +374,15 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wir // the length is encoded to either 1 or 2 bytes maxFrameSize++ - frames = p.framer.AppendStreamFrames(frames, maxFrameSize-length) + frames, lengthAdded = p.framer.AppendStreamFrames(frames, maxFrameSize-length) if len(frames) > 0 { lastFrame := frames[len(frames)-1] if sf, ok := lastFrame.(*wire.StreamFrame); ok { sf.DataLenPresent = false } + length += lengthAdded } - return frames, nil + return payload{frames: frames, length: length}, nil } func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { @@ -401,12 +416,13 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Extend func (p *packetPacker) writeAndSealPacket( header *wire.ExtendedHeader, - frames []wire.Frame, + payload payload, encLevel protocol.EncryptionLevel, sealer handshake.Sealer, ) (*packedPacket, error) { packetBuffer := getPacketBuffer() buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) + frames := payload.frames addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial @@ -419,11 +435,7 @@ func (p *packetPacker) writeAndSealPacket( header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen } else { // long header packets always use 4 byte packet number, so we never need to pad short payloads - length := protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen) - for _, frame := range frames { - length += frame.Length(p.version) - } - header.Length = length + header.Length = protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen) + payload.length } } diff --git a/packet_packer_test.go b/packet_packer_test.go index fae6389a..222f4ccb 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -37,19 +37,23 @@ var _ = Describe("Packet packer", func() { ExpectWithOffset(0, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen))) } + appendFrames := func(fs, frames []wire.Frame) ([]wire.Frame, protocol.ByteCount) { + var length protocol.ByteCount + for _, f := range frames { + length += f.Length(packer.version) + } + return append(fs, frames...), length + } + expectAppendStreamFrames := func(frames ...wire.Frame) { - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) []wire.Frame { - return append(fs, frames...) + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + return appendFrames(fs, frames) }) } expectAppendControlFrames := func(frames ...wire.Frame) { framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - var length protocol.ByteCount - for _, f := range frames { - length += f.Length(packer.version) - } - return append(fs, frames...), length + return appendFrames(fs, frames) }) } @@ -311,9 +315,9 @@ var _ = Describe("Packet packer", func() { maxSize = maxLen return fs, 444 }), - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(fs []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { Expect(maxLen).To(Equal(maxSize - 444 + 1 /* data length of the STREAM frame */)) - return nil + return fs, 0 }), ) _, err := packer.PackPacket() @@ -803,7 +807,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData() handshakeStream.EXPECT().HasData() framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()) - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Return([]wire.Frame{f}) + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Return([]wire.Frame{f}, f.Length(packer.version)) packet, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added