diff --git a/connection.go b/connection.go index 59385b78..fbaddf83 100644 --- a/connection.go +++ b/connection.go @@ -1832,15 +1832,15 @@ func (s *connection) maybeSendAckOnlyPacket() error { } now := time.Now() - p, err := s.packer.PackPacket(true, now) + p, buffer, err := s.packer.PackPacket(true, now) if err != nil { if err == errNothingToPack { return nil } return err } - s.sendPackedShortHeaderPacket(p.Buffer, p.Packet, now) - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Buffer.Len(), false) + s.sendPackedShortHeaderPacket(buffer, p.Packet, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) return nil } @@ -1903,23 +1903,23 @@ func (s *connection) sendPacket() (bool, error) { return true, nil } else if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) { ping, size := s.mtuDiscoverer.GetPing() - p, err := s.packer.PackMTUProbePacket(ping, size, now) + p, buffer, err := s.packer.PackMTUProbePacket(ping, size, now) if err != nil { return false, err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Buffer.Len(), false) - s.sendPackedShortHeaderPacket(p.Buffer, p.Packet, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return true, nil } - p, err := s.packer.PackPacket(false, now) + p, buffer, err := s.packer.PackPacket(false, now) if err != nil { if err == errNothingToPack { return false, nil } return false, err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Buffer.Len(), false) - s.sendPackedShortHeaderPacket(p.Buffer, p.Packet, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return true, nil } diff --git a/connection_test.go b/connection_test.go index c0898a32..36c5a120 100644 --- a/connection_test.go +++ b/connection_test.go @@ -53,13 +53,10 @@ var _ = Describe("Connection", func() { destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - getShortHeaderPacket := func(pn protocol.PacketNumber) shortHeaderPacket { + getShortHeaderPacket := func(pn protocol.PacketNumber) (shortHeaderPacket, *packetBuffer) { buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) - return shortHeaderPacket{ - Packet: &ackhandler.Packet{PacketNumber: pn}, - Buffer: buffer, - } + return shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: pn}}, buffer } getCoalescedPacket := func(pn protocol.PacketNumber) *coalescedPacket { @@ -609,8 +606,9 @@ var _ = Describe("Connection", func() { connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() conn.sentPacketHandler = sph - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() + p, buffer := getShortHeaderPacket(1) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() runConn() conn.queueControlFrame(&wire.PingFrame{}) conn.scheduleSending() @@ -1235,9 +1233,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) conn.sentPacketHandler = sph runConn() - p := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() + p, buffer := getShortHeaderPacket(1) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) @@ -1246,7 +1244,7 @@ var _ = Describe("Connection", func() { PacketNumber: p.PacketNumber, PacketNumberLen: p.PacketNumberLen, KeyPhase: p.KeyPhase, - }, p.Buffer.Len(), nil, []logging.Frame{}) + }, buffer.Len(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) }) @@ -1254,7 +1252,7 @@ var _ = Describe("Connection", func() { It("doesn't send packets if there's nothing to send", func() { conn.handshakeConfirmed = true runConn() - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() @@ -1285,14 +1283,14 @@ var _ = Describe("Connection", func() { fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) fc.EXPECT().IsNewlyBlocked() - p := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() + p, buffer := getShortHeaderPacket(1) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() conn.connFlowController = fc runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.Buffer.Len(), nil, []logging.Frame{}) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), buffer.Len(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) frames, _ := conn.framer.AppendControlFrames(nil, 1000) @@ -1418,8 +1416,10 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget() sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(10), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(11), nil) + p, buffer := getShortHeaderPacket(10) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + p, buffer = getShortHeaderPacket(11) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).Times(2) go func() { @@ -1435,8 +1435,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(10), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + p, buffer := getShortHeaderPacket(10) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1453,7 +1454,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget() sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny) - packer.EXPECT().PackPacket(true, gomock.Any()).Return(getShortHeaderPacket(10), nil) + p, buffer := getShortHeaderPacket(10) + packer.EXPECT().PackPacket(true, gomock.Any()).Return(p, buffer, nil) + sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1472,7 +1475,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true) sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().SendMode().Return(ackhandler.SendAck) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(100), nil) + p, buffer := getShortHeaderPacket(100) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1487,14 +1491,16 @@ var _ = Describe("Connection", func() { It("paces packets", func() { pacingDelay := scaleDuration(100 * time.Millisecond) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + p1, buffer1 := getShortHeaderPacket(100) + p2, buffer2 := getShortHeaderPacket(101) gomock.InOrder( sph.EXPECT().HasPacingBudget().Return(true), - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(100), nil), + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p1, buffer1, nil), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().HasPacingBudget(), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().HasPacingBudget().Return(true), - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(101), nil), + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p2, buffer2, nil), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().HasPacingBudget(), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), @@ -1519,9 +1525,10 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget() sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(4) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1000), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1001), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1002), nil) + for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { + p, buffer := getShortHeaderPacket(pn) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + } written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(3) @@ -1551,8 +1558,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1000), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + p, buffer := getShortHeaderPacket(1000) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) available <- struct{}{} Eventually(written).Should(BeClosed()) @@ -1574,8 +1582,9 @@ var _ = Describe("Connection", func() { }) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1000), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + p, buffer := getShortHeaderPacket(1000) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) conn.scheduleSending() @@ -1588,7 +1597,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1000), nil) + p, buffer := getShortHeaderPacket(1000) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) @@ -1609,8 +1619,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1001), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + p, buffer = getShortHeaderPacket(1001) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1624,7 +1635,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack) // don't EXPECT any calls to mconn.Write() go func() { defer GinkgoRecover() @@ -1649,7 +1660,8 @@ var _ = Describe("Connection", func() { mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) - packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), gomock.Any()).Return(getShortHeaderPacket(1), nil) + p, buffer := getShortHeaderPacket(1) + packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), gomock.Any()).Return(p, buffer, nil) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1693,8 +1705,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SentPacket(gomock.Any()) conn.sentPacketHandler = sph - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + p, buffer := getShortHeaderPacket(1) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack) go func() { defer GinkgoRecover() @@ -1712,8 +1725,9 @@ var _ = Describe("Connection", func() { }) It("sets the timer to the ack timer", func() { - packer.EXPECT().PackPacket(false, gomock.Any()).Return(getShortHeaderPacket(1234), nil) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + p, buffer := getShortHeaderPacket(1234) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1936,14 +1950,14 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph done := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackPacket(false, gomock.Any()).DoAndReturn(func(bool, time.Time) (shortHeaderPacket, error) { + packer.EXPECT().PackPacket(false, gomock.Any()).DoAndReturn(func(bool, time.Time) (shortHeaderPacket, *packetBuffer, error) { frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) defer close(done) - return shortHeaderPacket{Buffer: getPacketBuffer(), Packet: &ackhandler.Packet{}}, nil + return shortHeaderPacket{Packet: &ackhandler.Packet{}}, getPacketBuffer(), nil }) - packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() + packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() diff --git a/mock_packer_test.go b/mock_packer_test.go index 73f5baa8..324e446f 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -111,12 +111,13 @@ func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock. } // PackMTUProbePacket mocks base method. -func (m *MockPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, error) { +func (m *MockPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, *packetBuffer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackMTUProbePacket", ping, size, now) ret0, _ := ret[0].(shortHeaderPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(*packetBuffer) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // PackMTUProbePacket indicates an expected call of PackMTUProbePacket. @@ -126,12 +127,13 @@ func (mr *MockPackerMockRecorder) PackMTUProbePacket(ping, size, now interface{} } // PackPacket mocks base method. -func (m *MockPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacket, error) { +func (m *MockPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacket, *packetBuffer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackPacket", onlyAck, now) ret0, _ := ret[0].(shortHeaderPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(*packetBuffer) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // PackPacket indicates an expected call of PackPacket. diff --git a/packet_packer.go b/packet_packer.go index 030f5856..5b33c56c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -19,13 +19,13 @@ var errNothingToPack = errors.New("nothing to pack") type packer interface { PackCoalescedPacket(onlyAck bool) (*coalescedPacket, error) - PackPacket(onlyAck bool, now time.Time) (shortHeaderPacket, error) + PackPacket(onlyAck bool, now time.Time) (shortHeaderPacket, *packetBuffer, error) MaybePackProbePacket(protocol.EncryptionLevel) (*coalescedPacket, error) PackConnectionClose(*qerr.TransportError) (*coalescedPacket, error) PackApplicationClose(*qerr.ApplicationError) (*coalescedPacket, error) SetMaxPacketSize(protocol.ByteCount) - PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, error) + PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, *packetBuffer, error) HandleTransportParameters(*wire.TransportParameters) SetToken([]byte) @@ -53,7 +53,6 @@ type packetContents struct { type shortHeaderPacket struct { *ackhandler.Packet - Buffer *packetBuffer // used for logging DestConnID protocol.ConnectionID Ack *wire.AckFrame @@ -456,28 +455,27 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro // PackPacket packs a packet in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacket, error) { +func (p *packetPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacket, *packetBuffer, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { - return shortHeaderPacket{}, err + return shortHeaderPacket{}, nil, err } hdr, payload := p.maybeGetShortHeaderPacket(sealer, p.maxPacketSize, onlyAck, true) if payload == nil { - return shortHeaderPacket{}, errNothingToPack + return shortHeaderPacket{}, nil, errNothingToPack } buffer := getPacketBuffer() cont, err := p.appendShortHeaderPacket(buffer, hdr, payload, 0, sealer, false) if err != nil { - return shortHeaderPacket{}, err + return shortHeaderPacket{}, nil, err } return shortHeaderPacket{ Packet: cont.ToAckHandlerPacket(now, p.retransmissionQueue), - Buffer: buffer, DestConnID: hdr.DestConnectionID, Ack: payload.ack, PacketNumberLen: hdr.PacketNumberLen, KeyPhase: hdr.KeyPhase, - }, nil + }, buffer, nil } func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { @@ -715,7 +713,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( }, nil } -func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, error) { +func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, *packetBuffer, error) { payload := &payload{ frames: []ackhandler.Frame{ping}, length: ping.Length(p.version), @@ -723,23 +721,22 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B buffer := getPacketBuffer() sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { - return shortHeaderPacket{}, err + return shortHeaderPacket{}, nil, err } hdr := p.getShortHeader(sealer.KeyPhase()) padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead()) cont, err := p.appendShortHeaderPacket(buffer, hdr, payload, padding, sealer, true) if err != nil { - return shortHeaderPacket{}, err + return shortHeaderPacket{}, nil, err } cont.isMTUProbePacket = true return shortHeaderPacket{ Packet: cont.ToAckHandlerPacket(now, p.retransmissionQueue), - Buffer: buffer, DestConnID: hdr.DestConnectionID, Ack: payload.ack, PacketNumberLen: hdr.PacketNumberLen, KeyPhase: hdr.KeyPhase, - }, nil + }, buffer, nil } func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { diff --git a/packet_packer_test.go b/packet_packer_test.go index 9f0cdc38..7a9a0065 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -291,12 +291,12 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) - p, err := packer.PackPacket(true, time.Now()) + p, buffer, err := packer.PackPacket(true, time.Now()) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.Ack).To(Equal(ack)) Expect(p.Frames).To(BeEmpty()) - parsePacket(p.Buffer.Data) + parsePacket(buffer.Data) }) }) @@ -506,7 +506,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) framer.EXPECT().HasData() - _, err := packer.PackPacket(false, time.Now()) + _, _, err := packer.PackPacket(false, time.Now()) Expect(err).To(MatchError(errNothingToPack)) }) @@ -522,14 +522,14 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, } expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackPacket(false, time.Now()) + p, buffer, err := packer.PackPacket(false, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b, err := f.Append(nil, packer.version) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(1)) Expect(p.Frames[0].Frame.(*wire.StreamFrame).StreamID).To(Equal(f.StreamID)) - Expect(p.Buffer.Data).To(ContainSubstring(string(b))) + Expect(buffer.Data).To(ContainSubstring(string(b))) }) It("packs a single ACK", func() { @@ -539,7 +539,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - p, err := packer.PackPacket(false, time.Now()) + p, _, err := packer.PackPacket(false, time.Now()) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.Ack).To(Equal(ack)) @@ -557,14 +557,14 @@ var _ = Describe("Packet packer", func() { } expectAppendControlFrames(frames...) expectAppendStreamFrames() - p, err := packer.PackPacket(false, time.Now()) + p, buffer, err := packer.PackPacket(false, time.Now()) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(2)) for i, f := range p.Frames { Expect(f).To(BeAssignableToTypeOf(frames[i])) } - Expect(p.Buffer.Len()).ToNot(BeZero()) + Expect(buffer.Len()).ToNot(BeZero()) }) It("packs DATAGRAM frames", func() { @@ -586,12 +586,12 @@ var _ = Describe("Packet packer", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) framer.EXPECT().HasData() - p, err := packer.PackPacket(false, time.Now()) + p, buffer, err := packer.PackPacket(false, time.Now()) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(1)) Expect(p.Frames[0].Frame).To(Equal(f)) - Expect(p.Buffer.Data).ToNot(BeEmpty()) + Expect(buffer.Data).ToNot(BeEmpty()) Eventually(done).Should(BeClosed()) }) @@ -614,12 +614,12 @@ var _ = Describe("Packet packer", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) framer.EXPECT().HasData() - p, err := packer.PackPacket(false, time.Now()) + p, buffer, err := packer.PackPacket(false, time.Now()) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) Expect(p.Frames).To(BeEmpty()) - Expect(p.Buffer.Data).ToNot(BeEmpty()) + Expect(buffer.Data).ToNot(BeEmpty()) datagramQueue.CloseWithError(nil) Eventually(done).Should(BeClosed()) }) @@ -640,7 +640,7 @@ var _ = Describe("Packet packer", func() { return fs, 0 }), ) - _, err := packer.PackPacket(false, time.Now()) + _, _, err := packer.PackPacket(false, time.Now()) Expect(err).To(MatchError(errNothingToPack)) }) @@ -697,13 +697,13 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - packet, err := packer.PackPacket(false, time.Now()) + _, buffer, err := packer.PackPacket(false, time.Now()) Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added - packet.Buffer.Data = packet.Buffer.Data[:packet.Buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.Buffer.Data, packer.getDestConnID().Len()) + buffer.Data = buffer.Data[:buffer.Len()-protocol.ByteCount(sealer.Overhead())] + hdr, _, _, err := wire.ParsePacket(buffer.Data, packer.getDestConnID().Len()) Expect(err).ToNot(HaveOccurred()) - data := packet.Buffer.Data + data := buffer.Data r := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(r, packer.version) Expect(err).ToNot(HaveOccurred()) @@ -715,7 +715,7 @@ var _ = Describe("Packet packer", func() { Expect(firstPayloadByte).To(Equal(byte(0))) // ... followed by the STREAM frame frameParser := wire.NewFrameParser(true, packer.version) - l, frame, err := frameParser.ParseNext(packet.Buffer.Data[len(data)-r.Len():], protocol.Encryption1RTT) + l, frame, err := frameParser.ParseNext(buffer.Data[len(data)-r.Len():], protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) sf := frame.(*wire.StreamFrame) @@ -748,7 +748,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) - p, err := packer.PackPacket(false, time.Now()) + p, _, err := packer.PackPacket(false, time.Now()) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(3)) @@ -767,7 +767,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket(false, time.Now()) + p, _, err := packer.PackPacket(false, time.Now()) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) @@ -784,7 +784,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket(false, time.Now()) + p, _, err := packer.PackPacket(false, time.Now()) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) var hasPing bool @@ -803,7 +803,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err = packer.PackPacket(false, time.Now()) + p, _, err = packer.PackPacket(false, time.Now()) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) @@ -819,7 +819,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - _, err := packer.PackPacket(false, time.Now()) + _, _, err := packer.PackPacket(false, time.Now()) Expect(err).To(MatchError(errNothingToPack)) // now add some frame to send expectAppendControlFrames() @@ -830,7 +830,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData().Return(true) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(ack) - p, err := packer.PackPacket(false, time.Now()) + p, _, err := packer.PackPacket(false, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).To(Equal(ack)) var hasPing bool @@ -852,7 +852,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendStreamFrames() expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) - p, err := packer.PackPacket(false, time.Now()) + p, _, err := packer.PackPacket(false, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.Frames).ToNot(ContainElement(&wire.PingFrame{})) @@ -871,7 +871,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket(false, time.Now()) + _, _, err := packer.PackPacket(false, time.Now()) Expect(err).To(MatchError(errNothingToPack)) // now reduce the maxPacketSize packer.HandleTransportParameters(&wire.TransportParameters{ @@ -882,7 +882,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket(false, time.Now()) + _, _, err = packer.PackPacket(false, time.Now()) Expect(err).To(MatchError(errNothingToPack)) }) @@ -897,7 +897,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket(false, time.Now()) + _, _, err := packer.PackPacket(false, time.Now()) Expect(err).To(MatchError(errNothingToPack)) // now try to increase the maxPacketSize packer.HandleTransportParameters(&wire.TransportParameters{ @@ -908,7 +908,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket(false, time.Now()) + _, _, err = packer.PackPacket(false, time.Now()) Expect(err).To(MatchError(errNothingToPack)) }) }) @@ -925,7 +925,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket(false, time.Now()) + _, _, err := packer.PackPacket(false, time.Now()) Expect(err).To(MatchError(errNothingToPack)) // now reduce the maxPacketSize const packetSizeIncrease = 50 @@ -935,7 +935,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket(false, time.Now()) + _, _, err = packer.PackPacket(false, time.Now()) Expect(err).To(MatchError(errNothingToPack)) }) }) @@ -1546,11 +1546,11 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} const probePacketSize = maxPacketSize + 42 - p, err := packer.PackMTUProbePacket(ping, probePacketSize, time.Now()) + p, buffer, err := packer.PackMTUProbePacket(ping, probePacketSize, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(p.Length).To(BeEquivalentTo(probePacketSize)) Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(0x43))) - Expect(p.Buffer.Data).To(HaveLen(int(probePacketSize))) + Expect(buffer.Data).To(HaveLen(int(probePacketSize))) Expect(p.IsPathMTUProbePacket).To(BeTrue()) }) })