From a4b4d520632c340a15a6e64198689a3c715c84df Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 10 Feb 2020 14:15:50 +0800 Subject: [PATCH 1/7] refactor packing of packets before and after the handshake is confirmed --- mock_packer_test.go | 23 +++++++++++--- packet_packer.go | 46 +++++++++++----------------- packet_packer_test.go | 71 +++++++++++++------------------------------ session.go | 14 +++++++-- session_test.go | 2 +- 5 files changed, 70 insertions(+), 86 deletions(-) diff --git a/mock_packer_test.go b/mock_packer_test.go index 6425c1ea..e257b877 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -49,18 +49,18 @@ func (mr *MockPackerMockRecorder) HandleTransportParameters(arg0 interface{}) *g } // MaybePackAckPacket mocks base method -func (m *MockPacker) MaybePackAckPacket() (*packedPacket, error) { +func (m *MockPacker) MaybePackAckPacket(arg0 bool) (*packedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MaybePackAckPacket") + ret := m.ctrl.Call(m, "MaybePackAckPacket", arg0) ret0, _ := ret[0].(*packedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // MaybePackAckPacket indicates an expected call of MaybePackAckPacket -func (mr *MockPackerMockRecorder) MaybePackAckPacket() *gomock.Call { +func (mr *MockPackerMockRecorder) MaybePackAckPacket(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket), arg0) } // MaybePackProbePacket mocks base method @@ -78,6 +78,21 @@ func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0) } +// PackAppDataPacket mocks base method +func (m *MockPacker) PackAppDataPacket() (*packedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackAppDataPacket") + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackAppDataPacket indicates an expected call of PackAppDataPacket +func (mr *MockPackerMockRecorder) PackAppDataPacket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAppDataPacket", reflect.TypeOf((*MockPacker)(nil).PackAppDataPacket)) +} + // PackConnectionClose mocks base method func (m *MockPacker) PackConnectionClose(arg0 *wire.ConnectionCloseFrame) (*packedPacket, error) { m.ctrl.T.Helper() diff --git a/packet_packer.go b/packet_packer.go index cf010b88..c96307cb 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -16,8 +16,9 @@ import ( type packer interface { PackPacket() (*packedPacket, error) + PackAppDataPacket() (*packedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) - MaybePackAckPacket() (*packedPacket, error) + MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error) HandleTransportParameters(*handshake.TransportParameters) @@ -138,10 +139,6 @@ type packetPacker struct { version protocol.VersionNumber cryptoSetup sealingManager - // Once both Initial and Handshake keys are dropped, we only send 1-RTT packets. - droppedInitial bool - droppedHandshake bool - initialStream cryptoStream handshakeStream cryptoStream @@ -188,10 +185,6 @@ func newPacketPacker( } } -func (p *packetPacker) handshakeConfirmed() bool { - return p.droppedInitial && p.droppedHandshake -} - // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { payload := payload{ @@ -225,10 +218,10 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac return p.writeAndSealPacket(hdr, payload, encLevel, sealer) } -func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { +func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { var encLevel protocol.EncryptionLevel var ack *wire.AckFrame - if !p.handshakeConfirmed() { + if !handshakeConfirmed { ack = p.acks.GetAckFrame(protocol.EncryptionInitial) if ack != nil { encLevel = protocol.EncryptionInitial @@ -261,38 +254,33 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { return p.writeAndSealPacket(hdr, payload, encLevel, sealer) } -// PackPacket packs a new packet -// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise +// PackPacket packs a new packet. +// It packs an Initial / Handshake if there is data to send in these packet number spaces. +// It should only be called before the handshake is confirmed. func (p *packetPacker) PackPacket() (*packedPacket, error) { - if !p.handshakeConfirmed() { - packet, err := p.maybePackCryptoPacket() - if err != nil { - return nil, err - } - if packet != nil { - return packet, nil - } + packet, err := p.maybePackCryptoPacket() + if err != nil || packet != nil { + return packet, err } + return p.maybePackAppDataPacket() +} +// PackAppDataPacket packs a packet in the application data packet number space. +// It should be called after the handshake is confirmed. +func (p *packetPacker) PackAppDataPacket() (*packedPacket, error) { return p.maybePackAppDataPacket() } func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { // Try packing an Initial packet. packet, err := p.maybePackInitialPacket() - if err == handshake.ErrKeysDropped { - p.droppedInitial = true - } else if err != nil || packet != nil { + if (err != nil && err != handshake.ErrKeysDropped) || packet != nil { return packet, err } // No Initial was packed. Try packing a Handshake packet. packet, err = p.maybePackHandshakePacket() - if err == handshake.ErrKeysDropped { - p.droppedHandshake = true - return nil, nil - } - if err == handshake.ErrKeysNotYetAvailable { + if err == handshake.ErrKeysDropped || err == handshake.ErrKeysNotYetAvailable { return nil, nil } return packet, err diff --git a/packet_packer_test.go b/packet_packer_test.go index 547c0640..06ce4905 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -206,7 +206,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - p, err := packer.MaybePackAckPacket() + p, err := packer.MaybePackAckPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -218,7 +218,7 @@ var _ = Describe("Packet packer", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) - p, err := packer.MaybePackAckPacket() + p, err := packer.MaybePackAckPacket(false) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) @@ -230,10 +230,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) - p, err := packer.MaybePackAckPacket() + p, err := packer.MaybePackAckPacket(true) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) @@ -276,8 +274,6 @@ var _ = Describe("Packet packer", func() { Context("packing normal packets", func() { BeforeEach(func() { - sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes() - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil).AnyTimes() initialStream.EXPECT().HasData().AnyTimes() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes() handshakeStream.EXPECT().HasData().AnyTimes() @@ -291,7 +287,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) framer.EXPECT().AppendControlFrames(nil, gomock.Any()) framer.EXPECT().AppendStreamFrames(nil, gomock.Any()) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) @@ -307,7 +303,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, } expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b := &bytes.Buffer{} @@ -326,7 +322,7 @@ var _ = Describe("Packet packer", func() { StreamID: 5, Data: []byte("foobar"), }}) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) }) @@ -339,7 +335,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.ack).To(Equal(ack)) @@ -371,7 +367,7 @@ var _ = Describe("Packet packer", func() { } expectAppendControlFrames(frames...) expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(Equal(frames)) @@ -393,7 +389,7 @@ var _ = Describe("Packet packer", func() { return fs, 0 }), ) - _, err := packer.PackPacket() + _, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) }) @@ -409,7 +405,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - packet, err := packer.PackPacket() + packet, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()] @@ -458,7 +454,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(3)) @@ -476,7 +472,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).ToNot(BeNil()) @@ -492,7 +488,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(ContainElement(ackhandler.Frame{Frame: &wire.PingFrame{}})) @@ -503,7 +499,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err = packer.PackPacket() + p, err = packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).ToNot(BeNil()) @@ -518,7 +514,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) // now add some frame to send @@ -529,7 +525,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) - p, err = packer.PackPacket() + p, err = packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.ack).To(Equal(ack)) Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}})) @@ -543,7 +539,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendStreamFrames() expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) @@ -561,7 +557,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket() + _, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) // now reduce the maxPacketSize packer.HandleTransportParameters(&handshake.TransportParameters{ @@ -572,7 +568,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket() + _, err = packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) }) @@ -586,7 +582,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket() + _, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) // now try to increase the maxPacketSize packer.HandleTransportParameters(&handshake.TransportParameters{ @@ -597,7 +593,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket() + _, err = packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) }) }) @@ -737,31 +733,6 @@ var _ = Describe("Packet packer", func() { Expect(packet.ack).To(Equal(ack)) Expect(packet.frames).To(HaveLen(1)) }) - - It("stops packing crypto packets when the keys are dropped", func() { - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}}) - expectAppendStreamFrames() - packet, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - - // now the packer should have realized that the handshake is confirmed - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}}) - expectAppendStreamFrames() - packet, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - }) }) Context("packing probe packets", func() { diff --git a/session.go b/session.go index 1b065c30..e1b51152 100644 --- a/session.go +++ b/session.go @@ -164,6 +164,7 @@ type session struct { earlySessionReadyChan chan struct{} handshakeCompleteChan chan struct{} // is closed when the handshake completes handshakeComplete bool + handshakeConfirmed bool receivedRetry bool receivedFirstPacket bool @@ -1139,6 +1140,9 @@ func (s *session) handleCloseError(closeErr closeError) { } func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { + if encLevel == protocol.EncryptionHandshake { + s.handshakeConfirmed = true + } s.sentPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel) } @@ -1247,7 +1251,7 @@ sendLoop: } func (s *session) maybeSendAckOnlyPacket() error { - packet, err := s.packer.MaybePackAckPacket() + packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed) if err != nil { return err } @@ -1305,7 +1309,13 @@ func (s *session) sendPacket() (bool, error) { } s.windowUpdateQueue.QueueAll() - packet, err := s.packer.PackPacket() + var packet *packedPacket + var err error + if !s.handshakeConfirmed { + packet, err = s.packer.PackPacket() + } else { + packet, err = s.packer.PackAppDataPacket() + } if err != nil || packet == nil { return false, err } diff --git a/session_test.go b/session_test.go index 6e744ded..64aebdc0 100644 --- a/session_test.go +++ b/session_test.go @@ -905,7 +905,7 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAck) sph.EXPECT().ShouldSendNumPackets().Return(1000) - packer.EXPECT().MaybePackAckPacket() + packer.EXPECT().MaybePackAckPacket(false) sess.sentPacketHandler = sph Expect(sess.sendPackets()).To(Succeed()) }) From 077504f55782d7b3a3e540724599ca33542a3f47 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 10 Feb 2020 19:32:52 +0700 Subject: [PATCH 2/7] refactor sealing of packets --- packet_packer.go | 79 ++++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index c96307cb..9d3073d7 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -215,7 +215,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac hdr = p.getShortHeader(s.KeyPhase()) } - return p.writeAndSealPacket(hdr, payload, encLevel, sealer) + return p.writeSinglePacket(hdr, payload, encLevel, sealer) } func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { @@ -251,7 +251,7 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke if err != nil { return nil, err } - return p.writeAndSealPacket(hdr, payload, encLevel, sealer) + return p.writeSinglePacket(hdr, payload, encLevel, sealer) } // PackPacket packs a new packet. @@ -357,7 +357,7 @@ func (p *packetPacker) packCryptoPacket( payload.frames = []ackhandler.Frame{{Frame: cf}} payload.length += cf.Length(p.version) } - return p.writeAndSealPacket(hdr, payload, encLevel, sealer) + return p.writeSinglePacket(hdr, payload, encLevel, sealer) } func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { @@ -403,7 +403,7 @@ func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { p.numNonAckElicitingAcks = 0 } - return p.writeAndSealPacket(header, payload, encLevel, sealer) + return p.writeSinglePacket(header, payload, encLevel, sealer) } func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payload { @@ -529,15 +529,37 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex return hdr } -func (p *packetPacker) writeAndSealPacket( +// writeSinglePacket packs a single packet. +func (p *packetPacker) writeSinglePacket( header *wire.ExtendedHeader, payload payload, encLevel protocol.EncryptionLevel, sealer sealer, ) (*packedPacket, error) { + packetBuffer := getPacketBuffer() + + n, err := p.appendPacket(packetBuffer.Slice[:0], header, payload, encLevel, sealer) + if err != nil { + return nil, err + } + return &packedPacket{ + header: header, + raw: packetBuffer.Slice[:n], + ack: payload.ack, + frames: payload.frames, + buffer: packetBuffer, + }, nil +} + +func (p *packetPacker) appendPacket( + raw []byte, + header *wire.ExtendedHeader, + payload payload, + encLevel protocol.EncryptionLevel, + sealer sealer, +) (int, error) { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) - if encLevel != protocol.Encryption1RTT { if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { headerLen := header.GetLength(p.version) @@ -549,27 +571,17 @@ func (p *packetPacker) writeAndSealPacket( } else if payload.length < 4-pnLen { paddingLen = 4 - pnLen - payload.length } - return p.writeAndSealPacketWithPadding(header, payload, paddingLen, encLevel, sealer) -} - -func (p *packetPacker) writeAndSealPacketWithPadding( - header *wire.ExtendedHeader, - payload payload, - paddingLen protocol.ByteCount, - encLevel protocol.EncryptionLevel, - sealer sealer, -) (*packedPacket, error) { - packetBuffer := getPacketBuffer() - buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) + hdrOffset := len(raw) + buffer := bytes.NewBuffer(raw) if err := header.Write(buffer, p.version); err != nil { - return nil, err + return 0, err } payloadOffset := buffer.Len() if payload.ack != nil { if err := payload.ack.Write(buffer, p.version); err != nil { - return nil, err + return 0, err } } if paddingLen > 0 { @@ -577,40 +589,29 @@ func (p *packetPacker) writeAndSealPacketWithPadding( } for _, frame := range payload.frames { if err := frame.Write(buffer, p.version); err != nil { - return nil, err + return 0, err } } if payloadSize := protocol.ByteCount(buffer.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { - fmt.Printf("%#v\n", payload) - return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) + return 0, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) } if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize { - return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + return 0, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } - raw := buffer.Bytes() - _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset]) + raw = raw[:buffer.Len()] + _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset]) raw = raw[0 : buffer.Len()+sealer.Overhead()] pnOffset := payloadOffset - int(header.PacketNumberLen) - sealer.EncryptHeader( - raw[pnOffset+4:pnOffset+4+16], - &raw[0], - raw[pnOffset:payloadOffset], - ) + sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset]) num := p.pnManager.PopPacketNumber(encLevel) if num != header.PacketNumber { - return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + return 0, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - return &packedPacket{ - header: header, - raw: raw, - ack: payload.ack, - frames: payload.frames, - buffer: packetBuffer, - }, nil + return len(raw) - hdrOffset, nil } func (p *packetPacker) SetToken(token []byte) { From d642bf9098485792d5070878096d9d9c336f9e14 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 11 Feb 2020 11:50:57 +0700 Subject: [PATCH 3/7] simplify content storage in packed packets It's not necessary to store both the packetBuffer and the slice containing the raw data in the packet. --- buffer_pool.go | 15 ++++++++---- buffer_pool_test.go | 10 ++++++-- packet_handler_map.go | 2 +- packet_packer.go | 57 +++++++++++++++++++++---------------------- packet_packer_test.go | 40 +++++++++++++++--------------- send_queue.go | 10 ++++---- send_queue_test.go | 13 ++++------ server.go | 2 +- session.go | 10 ++++---- session_test.go | 56 +++++++++++++++++++++--------------------- 10 files changed, 112 insertions(+), 103 deletions(-) diff --git a/buffer_pool.go b/buffer_pool.go index d6fb7673..721677e2 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -7,9 +7,9 @@ import ( ) type packetBuffer struct { - Slice []byte + Data []byte - // refCount counts how many packets the Slice is used in. + // refCount counts how many packets Data is used in. // It doesn't support concurrent use. // It is > 1 when used for coalesced packet. refCount int @@ -50,8 +50,13 @@ func (b *packetBuffer) Release() { b.putBack() } +// Len returns the length of Data +func (b *packetBuffer) Len() protocol.ByteCount { + return protocol.ByteCount(len(b.Data)) +} + func (b *packetBuffer) putBack() { - if cap(b.Slice) != int(protocol.MaxReceivePacketSize) { + if cap(b.Data) != int(protocol.MaxReceivePacketSize) { panic("putPacketBuffer called with packet of wrong size!") } bufferPool.Put(b) @@ -62,14 +67,14 @@ var bufferPool sync.Pool func getPacketBuffer() *packetBuffer { buf := bufferPool.Get().(*packetBuffer) buf.refCount = 1 - buf.Slice = buf.Slice[:protocol.MaxReceivePacketSize] + buf.Data = buf.Data[:0] return buf } func init() { bufferPool.New = func() interface{} { return &packetBuffer{ - Slice: make([]byte, 0, protocol.MaxReceivePacketSize), + Data: make([]byte, 0, protocol.MaxReceivePacketSize), } } } diff --git a/buffer_pool_test.go b/buffer_pool_test.go index 3ee7037e..7aafbc46 100644 --- a/buffer_pool_test.go +++ b/buffer_pool_test.go @@ -10,7 +10,7 @@ import ( var _ = Describe("Buffer Pool", func() { It("returns buffers of cap", func() { buf := getPacketBuffer() - Expect(buf.Slice).To(HaveCap(int(protocol.MaxReceivePacketSize))) + Expect(buf.Data).To(HaveCap(int(protocol.MaxReceivePacketSize))) }) It("releases buffers", func() { @@ -18,9 +18,15 @@ var _ = Describe("Buffer Pool", func() { buf.Release() }) + It("gets the length", func() { + buf := getPacketBuffer() + buf.Data = append(buf.Data, []byte("foobar")...) + Expect(buf.Len()).To(BeEquivalentTo(6)) + }) + It("panics if wrong-sized buffers are passed", func() { buf := getPacketBuffer() - buf.Slice = make([]byte, 10) + buf.Data = make([]byte, 10) Expect(func() { buf.Release() }).To(Panic()) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index e40b2776..ab0eac96 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -218,7 +218,7 @@ func (h *packetHandlerMap) listen() { defer close(h.listening) for { buffer := getPacketBuffer() - data := buffer.Slice + data := buffer.Data[:protocol.MaxReceivePacketSize] // The packet size should not exceed protocol.MaxReceivePacketSize bytes // If it does, we only read a truncated packet, which will then end up undecryptable n, addr, err := h.conn.ReadFrom(data) diff --git a/packet_packer.go b/packet_packer.go index 9d3073d7..b8dfe6d5 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -37,7 +37,6 @@ type payload struct { type packedPacket struct { header *wire.ExtendedHeader - raw []byte ack *wire.AckFrame frames []ackhandler.Frame @@ -87,7 +86,7 @@ func (p *packedPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) PacketNumber: p.header.PacketNumber, LargestAcked: largestAcked, Frames: p.frames, - Length: protocol.ByteCount(len(p.raw)), + Length: p.buffer.Len(), EncryptionLevel: encLevel, SendTime: now, } @@ -536,28 +535,25 @@ func (p *packetPacker) writeSinglePacket( encLevel protocol.EncryptionLevel, sealer sealer, ) (*packedPacket, error) { - packetBuffer := getPacketBuffer() - - n, err := p.appendPacket(packetBuffer.Slice[:0], header, payload, encLevel, sealer) - if err != nil { + buffer := getPacketBuffer() + if err := p.appendPacket(buffer, header, payload, encLevel, sealer); err != nil { return nil, err } return &packedPacket{ + buffer: buffer, header: header, - raw: packetBuffer.Slice[:n], ack: payload.ack, frames: payload.frames, - buffer: packetBuffer, }, nil } func (p *packetPacker) appendPacket( - raw []byte, + buffer *packetBuffer, header *wire.ExtendedHeader, payload payload, encLevel protocol.EncryptionLevel, sealer sealer, -) (int, error) { +) error { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) if encLevel != protocol.Encryption1RTT { @@ -572,46 +568,49 @@ func (p *packetPacker) appendPacket( paddingLen = 4 - pnLen - payload.length } - hdrOffset := len(raw) - buffer := bytes.NewBuffer(raw) - if err := header.Write(buffer, p.version); err != nil { - return 0, err + hdrOffset := buffer.Len() + buf := bytes.NewBuffer(buffer.Data) + if err := header.Write(buf, p.version); err != nil { + return err } - payloadOffset := buffer.Len() + payloadOffset := buf.Len() if payload.ack != nil { - if err := payload.ack.Write(buffer, p.version); err != nil { - return 0, err + if err := payload.ack.Write(buf, p.version); err != nil { + return err } } if paddingLen > 0 { - buffer.Write(bytes.Repeat([]byte{0}, int(paddingLen))) + buf.Write(bytes.Repeat([]byte{0}, int(paddingLen))) } for _, frame := range payload.frames { - if err := frame.Write(buffer, p.version); err != nil { - return 0, err + if err := frame.Write(buf, p.version); err != nil { + return err } } - if payloadSize := protocol.ByteCount(buffer.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { - return 0, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) + if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { + return fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) } - if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize { - return 0, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize { + return fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } - raw = raw[:buffer.Len()] + raw := buffer.Data + // encrypt the packet + raw = raw[:buf.Len()] _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset]) - raw = raw[0 : buffer.Len()+sealer.Overhead()] - + raw = raw[0 : buf.Len()+sealer.Overhead()] + // apply header protection pnOffset := payloadOffset - int(header.PacketNumberLen) sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset]) + buffer.Data = raw num := p.pnManager.PopPacketNumber(encLevel) if num != header.PacketNumber { - return 0, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + return errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - return len(raw) - hdrOffset, nil + return nil } func (p *packetPacker) SetToken(token []byte) { diff --git a/packet_packer_test.go b/packet_packer_test.go index 06ce4905..f045a586 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -183,8 +183,8 @@ var _ = Describe("Packet packer", func() { hdrRawEncrypted[0] ^= 0xff hdrRawEncrypted[len(hdrRaw)-2] ^= 0xff hdrRawEncrypted[len(hdrRaw)-1] ^= 0xff - Expect(p.raw[0:len(hdrRaw)]).To(Equal(hdrRawEncrypted)) - Expect(p.raw[len(p.raw)-4:]).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) + Expect(p.buffer.Data[0:len(hdrRaw)]).To(Equal(hdrRawEncrypted)) + Expect(p.buffer.Data[p.buffer.Len()-4:]).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) }) }) @@ -309,7 +309,7 @@ var _ = Describe("Packet packer", func() { b := &bytes.Buffer{} f.Write(b, packer.version) Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) - Expect(p.raw).To(ContainSubstring(b.String())) + Expect(p.buffer.Data).To(ContainSubstring(b.String())) }) It("stores the encryption level a packet was sealed with", func() { @@ -371,7 +371,7 @@ var _ = Describe("Packet packer", func() { Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(Equal(frames)) - Expect(p.raw).NotTo(BeEmpty()) + Expect(p.buffer.Len()).ToNot(BeZero()) }) It("accounts for the space consumed by control frames", func() { @@ -408,10 +408,10 @@ var _ = Describe("Packet packer", func() { packet, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added - packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()] - hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.getDestConnID())) + packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(packet.raw) + r := bytes.NewReader(packet.buffer.Data) extHdr, err := hdr.ParseExtended(r, packer.version) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) @@ -613,7 +613,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - checkLength(p.raw) + checkLength(p.buffer.Data) }) It("packs a maximum size Handshake packet", func() { @@ -635,9 +635,9 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) - Expect(p.raw).To(HaveLen(int(packer.maxPacketSize))) + Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) Expect(p.header.IsLongHeader).To(BeTrue()) - checkLength(p.raw) + checkLength(p.buffer.Data) }) It("adds retransmissions", func() { @@ -654,7 +654,7 @@ var _ = Describe("Packet packer", func() { Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) Expect(p.header.IsLongHeader).To(BeTrue()) - checkLength(p.raw) + checkLength(p.buffer.Data) }) It("sends an Initial packet containing only an ACK", func() { @@ -709,11 +709,11 @@ var _ = Describe("Packet packer", func() { packet, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(packet.header.Token).To(Equal(token)) - Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) + Expect(packet.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize)) Expect(packet.frames).To(HaveLen(1)) cf := packet.frames[0].Frame.(*wire.CryptoFrame) Expect(cf.Data).To(Equal([]byte("foobar"))) - checkLength(packet.raw) + checkLength(packet.buffer.Data) }) It("adds an ACK frame", func() { @@ -729,7 +729,7 @@ var _ = Describe("Packet packer", func() { packer.perspective = protocol.PerspectiveClient packet, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) + Expect(packet.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize)) Expect(packet.ack).To(Equal(ack)) Expect(packet.frames).To(HaveLen(1)) }) @@ -751,7 +751,7 @@ var _ = Describe("Packet packer", func() { Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(packet.frames).To(HaveLen(1)) Expect(packet.frames[0].Frame).To(Equal(f)) - checkLength(packet.raw) + checkLength(packet.buffer.Data) }) It("packs a Handshake probe packet", func() { @@ -769,7 +769,7 @@ var _ = Describe("Packet packer", func() { Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) Expect(packet.frames).To(HaveLen(1)) Expect(packet.frames[0].Frame).To(Equal(f)) - checkLength(packet.raw) + checkLength(packet.buffer.Data) }) It("packs a 1-RTT probe packet", func() { @@ -795,11 +795,13 @@ var _ = Describe("Packet packer", func() { var _ = Describe("Converting to AckHandler packets", func() { It("convert a packet", func() { + buffer := getPacketBuffer() + buffer.Data = append(buffer.Data, []byte("foobar")...) packet := &packedPacket{ + buffer: buffer, header: &wire.ExtendedHeader{Header: wire.Header{}}, frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, ack: &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 80}}}, - raw: []byte("foobar"), } t := time.Now() p := packet.ToAckHandlerPacket(t, nil) @@ -811,9 +813,9 @@ var _ = Describe("Converting to AckHandler packets", func() { It("sets the LargestAcked to invalid, if the packet doesn't have an ACK frame", func() { packet := &packedPacket{ + buffer: getPacketBuffer(), header: &wire.ExtendedHeader{Header: wire.Header{}}, frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, - raw: []byte("foobar"), } p := packet.ToAckHandlerPacket(time.Now(), nil) Expect(p.LargestAcked).To(Equal(protocol.InvalidPacketNumber)) @@ -822,12 +824,12 @@ var _ = Describe("Converting to AckHandler packets", func() { It("doesn't overwrite the OnLost callback, if it is set", func() { var pingLost bool packet := &packedPacket{ + buffer: getPacketBuffer(), header: &wire.ExtendedHeader{Header: wire.Header{Type: protocol.PacketTypeHandshake}}, frames: []ackhandler.Frame{ {Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { pingLost = true }}, }, - raw: []byte("foobar"), } p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue(protocol.VersionTLS)) Expect(p.Frames).To(HaveLen(2)) diff --git a/send_queue.go b/send_queue.go index d4992cfe..8d9ec5ed 100644 --- a/send_queue.go +++ b/send_queue.go @@ -1,7 +1,7 @@ package quic type sendQueue struct { - queue chan *packedPacket + queue chan *packetBuffer closeCalled chan struct{} // runStopped when Close() is called runStopped chan struct{} // runStopped when the run loop returns conn connection @@ -12,12 +12,12 @@ func newSendQueue(conn connection) *sendQueue { conn: conn, runStopped: make(chan struct{}), closeCalled: make(chan struct{}), - queue: make(chan *packedPacket, 1), + queue: make(chan *packetBuffer, 1), } return s } -func (h *sendQueue) Send(p *packedPacket) { +func (h *sendQueue) Send(p *packetBuffer) { h.queue <- p } @@ -34,10 +34,10 @@ func (h *sendQueue) Run() error { // make sure that all queued packets are actually sent out shouldClose = true case p := <-h.queue: - if err := h.conn.Write(p.raw); err != nil { + if err := h.conn.Write(p.Data); err != nil { return err } - p.buffer.Release() + p.Release() } } } diff --git a/send_queue_test.go b/send_queue_test.go index 37b5a2e4..978a9dde 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -15,14 +15,11 @@ var _ = Describe("Send Queue", func() { q = newSendQueue(c) }) - getPacket := func(b []byte) *packedPacket { + getPacket := func(b []byte) *packetBuffer { buf := getPacketBuffer() - buf.Slice = buf.Slice[:len(b)] - copy(buf.Slice, b) - return &packedPacket{ - buffer: buf, - raw: buf.Slice, - } + buf.Data = buf.Data[:len(b)] + copy(buf.Data, b) + return buf } It("sends a packet", func() { @@ -30,7 +27,7 @@ var _ = Describe("Send Queue", func() { q.Send(p) written := make(chan struct{}) - c.EXPECT().Write(p.raw).Do(func([]byte) { close(written) }) + c.EXPECT().Write([]byte("foobar")).Do(func([]byte) { close(written) }) done := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/server.go b/server.go index 93129721..44e1e6b3 100644 --- a/server.go +++ b/server.go @@ -494,7 +494,7 @@ func (s *baseServer) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer) packetBuffer := getPacketBuffer() defer packetBuffer.Release() - buf := bytes.NewBuffer(packetBuffer.Slice[:0]) + buf := bytes.NewBuffer(packetBuffer.Data) ccf := &wire.ConnectionCloseFrame{ErrorCode: qerr.ServerBusy} diff --git a/session.go b/session.go index e1b51152..4595ace3 100644 --- a/session.go +++ b/session.go @@ -1340,7 +1340,7 @@ func (s *session) sendPackedPacket(packet *packedPacket) { TransportState: s.sentPacketHandler.GetStats(), EncryptionLevel: packet.EncryptionLevel(), PacketNumber: packet.header.PacketNumber, - PacketSize: protocol.ByteCount(len(packet.raw)), + PacketSize: packet.buffer.Len(), Frames: frames, }) } @@ -1349,11 +1349,11 @@ func (s *session) sendPackedPacket(packet *packedPacket) { for _, f := range packet.frames { frames = append(frames, f.Frame) } - s.qlogger.SentPacket(now, packet.header, protocol.ByteCount(len(packet.raw)), packet.ack, frames) + s.qlogger.SentPacket(now, packet.header, packet.buffer.Len(), packet.ack, frames) } s.logPacket(packet) s.connIDManager.SentPacket() - s.sendQueue.Send(packet) + s.sendQueue.Send(packet.buffer) } func (s *session) sendConnectionClose(quicErr *qerr.QuicError) ([]byte, error) { @@ -1376,7 +1376,7 @@ func (s *session) sendConnectionClose(quicErr *qerr.QuicError) ([]byte, error) { return nil, err } s.logPacket(packet) - return packet.raw, s.conn.Write(packet.raw) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data) } func (s *session) logPacket(packet *packedPacket) { @@ -1384,7 +1384,7 @@ func (s *session) logPacket(packet *packedPacket) { // We don't need to allocate the slices for calling the format functions return } - s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.logID, packet.EncryptionLevel()) + s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, packet.buffer.Len(), s.logID, packet.EncryptionLevel()) packet.header.Log(s.logger) if packet.ack != nil { wire.LogFrame(s.logger, packet.ack, true) diff --git a/session_test.go b/session_test.go index 64aebdc0..0833c7df 100644 --- a/session_test.go +++ b/session_test.go @@ -54,10 +54,8 @@ var _ = Describe("Session", func() { getPacket := func(pn protocol.PacketNumber) *packedPacket { buffer := getPacketBuffer() - data := buffer.Slice[:0] - data = append(data, []byte("foobar")...) + buffer.Data = append(buffer.Data, []byte("foobar")...) return &packedPacket{ - raw: data, buffer: buffer, header: &wire.ExtendedHeader{PacketNumber: pn}, } @@ -417,12 +415,14 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(qerr.ApplicationError(0, "")) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() + buffer := getPacketBuffer() + buffer.Data = append(buffer.Data, []byte("connection close")...) packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { Expect(f.IsApplicationError).To(BeTrue()) Expect(f.ErrorCode).To(Equal(qerr.NoError)) Expect(f.FrameType).To(BeZero()) Expect(f.ReasonPhrase).To(BeEmpty()) - return &packedPacket{raw: []byte("connection close")}, nil + return &packedPacket{buffer: buffer}, nil }) mconn.EXPECT().Write([]byte("connection close")) sess.shutdown() @@ -434,7 +434,7 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) mconn.EXPECT().Write(gomock.Any()) sess.shutdown() sess.shutdown() @@ -450,7 +450,7 @@ var _ = Describe("Session", func() { Expect(f.IsApplicationError).To(BeTrue()) Expect(f.ErrorCode).To(BeEquivalentTo(0x1337)) Expect(f.ReasonPhrase).To(Equal("test error")) - return &packedPacket{}, nil + return &packedPacket{buffer: getPacketBuffer()}, nil }) mconn.EXPECT().Write(gomock.Any()) sess.CloseWithError(0x1337, "test error") @@ -468,7 +468,7 @@ var _ = Describe("Session", func() { Expect(f.FrameType).To(BeEquivalentTo(0x42)) Expect(f.ErrorCode).To(BeEquivalentTo(0x1337)) Expect(f.ReasonPhrase).To(Equal("test error")) - return &packedPacket{}, nil + return &packedPacket{buffer: getPacketBuffer()}, nil }) mconn.EXPECT().Write(gomock.Any()) sess.closeLocal(testErr) @@ -485,7 +485,7 @@ var _ = Describe("Session", func() { Expect(f.IsApplicationError).To(BeFalse()) Expect(f.ErrorCode).To(BeEquivalentTo(0x15a)) Expect(f.ReasonPhrase).To(BeEmpty()) - return &packedPacket{}, nil + return &packedPacket{buffer: getPacketBuffer()}, nil }) mconn.EXPECT().Write(gomock.Any()) sess.CloseWithError(0x1337, "test error") @@ -518,7 +518,7 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) returned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -607,7 +607,7 @@ var _ = Describe("Session", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -629,7 +629,7 @@ var _ = Describe("Session", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -653,7 +653,7 @@ var _ = Describe("Session", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) runErr := make(chan error) go func() { defer GinkgoRecover() @@ -679,7 +679,7 @@ var _ = Describe("Session", func() { }, nil) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -875,7 +875,7 @@ var _ = Describe("Session", func() { AfterEach(func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -1002,7 +1002,7 @@ var _ = Describe("Session", func() { AfterEach(func() { // make the go routine return - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -1117,7 +1117,7 @@ var _ = Describe("Session", func() { // make the go routine return expectReplaceWithClosed() streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) sess.shutdown() @@ -1206,7 +1206,7 @@ var _ = Describe("Session", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) sess.shutdown() @@ -1247,7 +1247,7 @@ var _ = Describe("Session", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) sess.shutdown() @@ -1258,7 +1258,7 @@ var _ = Describe("Session", func() { packer.EXPECT().PackPacket().AnyTimes() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() go func() { defer GinkgoRecover() @@ -1301,7 +1301,7 @@ var _ = Describe("Session", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) sess.shutdown() @@ -1318,7 +1318,7 @@ var _ = Describe("Session", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) sess.shutdown() @@ -1337,7 +1337,7 @@ var _ = Describe("Session", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) Expect(sess.CloseWithError(0x1337, testErr.Error())).To(Succeed()) @@ -1374,7 +1374,7 @@ var _ = Describe("Session", func() { Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) s.shutdown() }).Times(4) // initial connection ID + initial client dest conn ID + 2 newly issued conn IDs - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) sess.shutdown() @@ -1406,7 +1406,7 @@ var _ = Describe("Session", func() { // make the go routine return expectReplaceWithClosed() streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) sess.shutdown() @@ -1506,7 +1506,7 @@ var _ = Describe("Session", func() { sess.lastPacketReceivedTime = time.Now().Add(-time.Minute) packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { Expect(f.ErrorCode).To(Equal(qerr.NoError)) - return &packedPacket{}, nil + return &packedPacket{buffer: getPacketBuffer()}, nil }) // the handshake timeout is irrelevant here, since it depends on the time the session was created, // and not on the last network activity @@ -1562,7 +1562,7 @@ var _ = Describe("Session", func() { }() Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -1738,7 +1738,7 @@ var _ = Describe("Client Session", func() { PacketNumberLen: protocol.PacketNumberLen2, }, []byte{0}))).To(BeTrue()) // make sure the go routine returns - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -1863,7 +1863,7 @@ var _ = Describe("Client Session", func() { Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) s.shutdown() }) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil).MaxTimes(1) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) } From 5aad7cae5dc2415f1d6f9458cfa76916dff399f1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 13 Feb 2020 17:13:29 +0700 Subject: [PATCH 4/7] send coalesced packets --- internal/protocol/params.go | 4 + mock_packer_test.go | 4 +- packet_packer.go | 201 +++++++++++++++++++++++----------- packet_packer_test.go | 207 +++++++++++++++++++++++++----------- session.go | 108 ++++++++++++------- session_test.go | 114 +++++++++++++++++--- 6 files changed, 452 insertions(+), 186 deletions(-) diff --git a/internal/protocol/params.go b/internal/protocol/params.go index fe88b495..7e3dab51 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -81,6 +81,10 @@ const MaxStreamFrameSorterGaps = 1000 // very small STREAM frames to consume a lot of memory. const MinStreamFrameBufferSize = 128 +// MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack. +// If a packet has less than this number of bytes, we won't coalesce any more packets onto it. +const MinCoalescedPacketSize = 128 + // MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. // This limits the size of the ClientHello and Certificates that can be received. const MaxCryptoStreamOffset = 16 * (1 << 10) diff --git a/mock_packer_test.go b/mock_packer_test.go index e257b877..1fc58d12 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -109,10 +109,10 @@ func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock. } // PackPacket mocks base method -func (m *MockPacker) PackPacket() (*packedPacket, error) { +func (m *MockPacker) PackPacket() (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackPacket") - ret0, _ := ret[0].(*packedPacket) + ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/packet_packer.go b/packet_packer.go index b8dfe6d5..2eaa8410 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -15,7 +15,7 @@ import ( ) type packer interface { - PackPacket() (*packedPacket, error) + PackPacket() (*coalescedPacket, error) PackAppDataPacket() (*packedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) @@ -36,14 +36,25 @@ type payload struct { } type packedPacket struct { + buffer *packetBuffer + *packetContents +} + +type packetContents struct { header *wire.ExtendedHeader ack *wire.AckFrame frames []ackhandler.Frame - buffer *packetBuffer + length protocol.ByteCount } -func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel { +type coalescedPacket struct { + buffer *packetBuffer + + packets []*packetContents +} + +func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel { if !p.header.IsLongHeader { return protocol.Encryption1RTT } @@ -59,11 +70,11 @@ func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel { } } -func (p *packedPacket) IsAckEliciting() bool { +func (p *packetContents) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) } -func (p *packedPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { +func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { largestAcked := protocol.InvalidPacketNumber if p.ack != nil { largestAcked = p.ack.LargestAcked() @@ -86,7 +97,7 @@ func (p *packedPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) PacketNumber: p.header.PacketNumber, LargestAcked: largestAcked, Frames: p.frames, - Length: p.buffer.Len(), + Length: p.length, EncryptionLevel: encLevel, SendTime: now, } @@ -256,36 +267,79 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke // PackPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *packetPacker) PackPacket() (*packedPacket, error) { - packet, err := p.maybePackCryptoPacket() - if err != nil || packet != nil { - return packet, err +func (p *packetPacker) PackPacket() (*coalescedPacket, error) { + buffer := getPacketBuffer() + packet, err := p.packCoalescedPacket(buffer) + if err != nil { + return nil, err } - return p.maybePackAppDataPacket() + + if len(packet.packets) == 0 { // nothing to send + buffer.Release() + return nil, nil + } + return packet, nil +} + +func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPacket, error) { + packet := &coalescedPacket{ + buffer: buffer, + packets: make([]*packetContents, 0, 3), + } + // Try packing an Initial packet. + contents, err := p.maybeAppendInitialPacket(buffer) + if err != nil && err != handshake.ErrKeysDropped { + return nil, err + } + if contents != nil { + packet.packets = append(packet.packets, contents) + } + if buffer.Len() >= p.maxPacketSize-protocol.MinCoalescedPacketSize { + return packet, nil + } + + // Add a Handshake packet. + contents, err = p.maybeAppendHandshakePacket(buffer) + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if contents != nil { + packet.packets = append(packet.packets, contents) + } + if buffer.Len() >= p.maxPacketSize-protocol.MinCoalescedPacketSize { + return packet, nil + } + + // Add a 0-RTT / 1-RTT packet. + contents, err = p.maybeAppendAppDataPacket(buffer) + if err == handshake.ErrKeysNotYetAvailable { + return packet, nil + } + if err != nil { + return nil, err + } + if contents != nil { + packet.packets = append(packet.packets, contents) + } + return packet, nil } // PackAppDataPacket packs a packet in the application data packet number space. // It should be called after the handshake is confirmed. func (p *packetPacker) PackAppDataPacket() (*packedPacket, error) { - return p.maybePackAppDataPacket() + buffer := getPacketBuffer() + contents, err := p.maybeAppendAppDataPacket(buffer) + if err != nil || contents == nil { + buffer.Release() + return nil, err + } + return &packedPacket{ + buffer: buffer, + packetContents: contents, + }, nil } -func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { - // Try packing an Initial packet. - packet, err := p.maybePackInitialPacket() - if (err != nil && err != handshake.ErrKeysDropped) || packet != nil { - return packet, err - } - - // No Initial was packed. Try packing a Handshake packet. - packet, err = p.maybePackHandshakePacket() - if err == handshake.ErrKeysDropped || err == handshake.ErrKeysNotYetAvailable { - return nil, nil - } - return packet, err -} - -func (p *packetPacker) maybePackInitialPacket() (*packedPacket, error) { +func (p *packetPacker) maybeAppendInitialPacket(buffer *packetBuffer) (*packetContents, error) { sealer, err := p.cryptoSetup.GetInitialSealer() if err != nil { return nil, err @@ -297,69 +351,76 @@ func (p *packetPacker) maybePackInitialPacket() (*packedPacket, error) { // nothing to send return nil, nil } - return p.packCryptoPacket(protocol.EncryptionInitial, sealer, ack, hasRetransmission) + return p.appendCryptoPacket(buffer, protocol.EncryptionInitial, sealer, ack, hasRetransmission) } -func (p *packetPacker) maybePackHandshakePacket() (*packedPacket, error) { +func (p *packetPacker) maybeAppendHandshakePacket(buffer *packetBuffer) (*packetContents, error) { sealer, err := p.cryptoSetup.GetHandshakeSealer() - if err != nil { return nil, err } hasRetransmission := p.retransmissionQueue.HasHandshakeData() + // TODO: make sure that the ACK always fits ack := p.acks.GetAckFrame(protocol.EncryptionHandshake) if !p.handshakeStream.HasData() && !hasRetransmission && ack == nil { // nothing to send return nil, nil } - return p.packCryptoPacket(protocol.EncryptionHandshake, sealer, ack, hasRetransmission) + return p.appendCryptoPacket(buffer, protocol.EncryptionHandshake, sealer, ack, hasRetransmission) } -func (p *packetPacker) packCryptoPacket( +func (p *packetPacker) appendCryptoPacket( + buffer *packetBuffer, encLevel protocol.EncryptionLevel, sealer handshake.LongHeaderSealer, ack *wire.AckFrame, hasRetransmission bool, -) (*packedPacket, error) { - s := p.initialStream - if encLevel == protocol.EncryptionHandshake { - s = p.handshakeStream +) (*packetContents, error) { + s := p.handshakeStream + maxPacketSize := p.maxPacketSize + if encLevel == protocol.EncryptionInitial { + s = p.initialStream + if p.perspective == protocol.PerspectiveClient { + maxPacketSize = protocol.MinInitialPacketSize + } } + remainingLen := maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) var payload payload if ack != nil { payload.ack = ack payload.length = ack.Length(p.version) + remainingLen -= payload.length } hdr := p.getLongHeader(encLevel) - hdrLen := hdr.GetLength(p.version) + remainingLen -= hdr.GetLength(p.version) if hasRetransmission { for { var f wire.Frame switch encLevel { case protocol.EncryptionInitial: - remainingLen := protocol.MinInitialPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length f = p.retransmissionQueue.GetInitialFrame(remainingLen) case protocol.EncryptionHandshake: - remainingLen := p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length f = p.retransmissionQueue.GetHandshakeFrame(remainingLen) } if f == nil { break } payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) - payload.length += f.Length(p.version) + frameLen := f.Length(p.version) + payload.length += frameLen + remainingLen -= frameLen } } else if s.HasData() { - cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length) + cf := s.PopCryptoFrame(remainingLen) payload.frames = []ackhandler.Frame{{Frame: cf}} payload.length += cf.Length(p.version) } - return p.writeSinglePacket(hdr, payload, encLevel, sealer) + return p.appendPacket(buffer, hdr, payload, encLevel, sealer) } -func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { +func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer) (*packetContents, error) { var sealer sealer var header *wire.ExtendedHeader var encLevel protocol.EncryptionLevel @@ -382,7 +443,7 @@ func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { } headerLen := header.GetLength(p.version) - maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen + maxSize := p.maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) - headerLen payload := p.composeNextPacket(maxSize) // check if we have anything to send @@ -402,7 +463,7 @@ func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { p.numNonAckElicitingAcks = 0 } - return p.writeSinglePacket(header, payload, encLevel, sealer) + return p.appendPacket(buffer, header, payload, encLevel, sealer) } func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payload { @@ -437,16 +498,26 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payloa } func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*packedPacket, error) { + var contents *packetContents + var err error + buffer := getPacketBuffer() switch encLevel { case protocol.EncryptionInitial: - return p.maybePackInitialPacket() + contents, err = p.maybeAppendInitialPacket(buffer) case protocol.EncryptionHandshake: - return p.maybePackHandshakePacket() + contents, err = p.maybeAppendHandshakePacket(buffer) case protocol.Encryption1RTT: - return p.maybePackAppDataPacket() + contents, err = p.maybeAppendAppDataPacket(buffer) default: panic("unknown encryption level") } + if err != nil { + return nil, err + } + return &packedPacket{ + buffer: buffer, + packetContents: contents, + }, nil } func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { @@ -536,14 +607,13 @@ func (p *packetPacker) writeSinglePacket( sealer sealer, ) (*packedPacket, error) { buffer := getPacketBuffer() - if err := p.appendPacket(buffer, header, payload, encLevel, sealer); err != nil { + contents, err := p.appendPacket(buffer, header, payload, encLevel, sealer) + if err != nil { return nil, err } return &packedPacket{ - buffer: buffer, - header: header, - ack: payload.ack, - frames: payload.frames, + buffer: buffer, + packetContents: contents, }, nil } @@ -553,7 +623,7 @@ func (p *packetPacker) appendPacket( payload payload, encLevel protocol.EncryptionLevel, sealer sealer, -) error { +) (*packetContents, error) { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) if encLevel != protocol.Encryption1RTT { @@ -571,13 +641,13 @@ func (p *packetPacker) appendPacket( hdrOffset := buffer.Len() buf := bytes.NewBuffer(buffer.Data) if err := header.Write(buf, p.version); err != nil { - return err + return nil, err } payloadOffset := buf.Len() if payload.ack != nil { if err := payload.ack.Write(buf, p.version); err != nil { - return err + return nil, err } } if paddingLen > 0 { @@ -585,15 +655,15 @@ func (p *packetPacker) appendPacket( } for _, frame := range payload.frames { if err := frame.Write(buf, p.version); err != nil { - return err + return nil, err } } if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { - return fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) + return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) } if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize { - return fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } raw := buffer.Data @@ -603,14 +673,19 @@ func (p *packetPacker) appendPacket( raw = raw[0 : buf.Len()+sealer.Overhead()] // apply header protection pnOffset := payloadOffset - int(header.PacketNumberLen) - sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset]) + sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[hdrOffset], raw[pnOffset:payloadOffset]) buffer.Data = raw num := p.pnManager.PopPacketNumber(encLevel) if num != header.PacketNumber { - return errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - return nil + return &packetContents{ + header: header, + ack: payload.ack, + frames: payload.frames, + length: buffer.Len() - hdrOffset, + }, nil } func (p *packetPacker) SetToken(token []byte) { diff --git a/packet_packer_test.go b/packet_packer_test.go index f045a586..b5ee06bb 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -178,7 +178,8 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{{Frame: f}})) hdrRawEncrypted := append([]byte{}, hdrRaw...) hdrRawEncrypted[0] ^= 0xff hdrRawEncrypted[len(hdrRaw)-2] ^= 0xff @@ -189,17 +190,17 @@ var _ = Describe("Packet packer", func() { }) Context("packing packets", func() { - var sealer *mocks.MockShortHeaderSealer - - BeforeEach(func() { - sealer = mocks.NewMockShortHeaderSealer(mockCtrl) + // getSealer gets a sealer that's expected to seal exactly one packet + getSealer := func() *mocks.MockShortHeaderSealer { + sealer := mocks.NewMockShortHeaderSealer(mockCtrl) sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne).AnyTimes() sealer.EXPECT().Overhead().Return(7).AnyTimes() sealer.EXPECT().EncryptHeader(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte { return append(src, bytes.Repeat([]byte{0}, sealer.Overhead())...) }).AnyTimes() - }) + return sealer + } Context("packing ACK packets", func() { It("doesn't pack a packet if there's no ACK to send", func() { @@ -214,7 +215,7 @@ var _ = Describe("Packet packer", func() { It("packs Handshake ACK-only packets", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) @@ -228,7 +229,7 @@ var _ = Describe("Packet packer", func() { It("packs 1-RTT ACK-only packets", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) p, err := packer.MaybePackAckPacket(true) @@ -253,22 +254,24 @@ var _ = Describe("Packet packer", func() { }) It("packs a 0-RTT packet", func() { - sealingManager.EXPECT().Get0RTTSealer().Return(sealer, nil).AnyTimes() + sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil).AnyTimes() pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) cf := ackhandler.Frame{Frame: &wire.MaxDataFrame{ByteOffset: 0x1337}} framer.EXPECT().AppendControlFrames(nil, gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { return append(frames, cf), cf.Length(packer.version) }) + // TODO: check sizes framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { return frames, 0 }) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(p.header.Type).To(Equal(protocol.PacketType0RTT)) - Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) - Expect(p.frames).To(Equal([]ackhandler.Frame{cf})) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].header.Type).To(Equal(protocol.PacketType0RTT)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) + Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{cf})) }) }) @@ -283,7 +286,7 @@ var _ = Describe("Packet packer", func() { It("returns nil when no packet is queued", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) // don't expect any calls to PopPacketNumber - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) framer.EXPECT().AppendControlFrames(nil, gomock.Any()) framer.EXPECT().AppendStreamFrames(nil, gomock.Any()) @@ -295,7 +298,7 @@ var _ = Describe("Packet packer", func() { It("packs single packets", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() f := &wire.StreamFrame{ @@ -315,7 +318,7 @@ var _ = Describe("Packet packer", func() { It("stores the encryption level a packet was sealed with", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{ @@ -332,7 +335,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.PackAppDataPacket() @@ -349,7 +352,7 @@ var _ = Describe("Packet packer", func() { ErrorCode: 0x1337, ReasonPhrase: "foobar", } - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) p, err := packer.PackConnectionClose(&ccf) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) @@ -359,7 +362,7 @@ var _ = Describe("Packet packer", func() { It("packs control frames", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) frames := []ackhandler.Frame{ {Frame: &wire.ResetStreamFrame{}}, @@ -376,7 +379,7 @@ var _ = Describe("Packet packer", func() { It("accounts for the space consumed by control frames", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) var maxSize protocol.ByteCount gomock.InOrder( @@ -401,6 +404,7 @@ var _ = Describe("Packet packer", func() { Expect(f.Length(packer.version)).To(BeEquivalentTo(2)) pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealer := getSealer() sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() @@ -450,7 +454,7 @@ var _ = Describe("Packet packer", func() { } pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) @@ -468,7 +472,7 @@ var _ = Describe("Packet packer", func() { for i := 0; i < protocol.MaxNonAckElicitingAcks; i++ { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() @@ -484,7 +488,7 @@ var _ = Describe("Packet packer", func() { sendMaxNumNonAckElicitingAcks() pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() @@ -495,7 +499,7 @@ var _ = Describe("Packet packer", func() { // make sure the next packet doesn't contain another PING pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() @@ -510,7 +514,7 @@ var _ = Describe("Packet packer", func() { sendMaxNumNonAckElicitingAcks() // nothing to send pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) expectAppendControlFrames() expectAppendStreamFrames() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) @@ -522,7 +526,7 @@ var _ = Describe("Packet packer", func() { expectAppendStreamFrames() pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) p, err = packer.PackAppDataPacket() @@ -535,7 +539,7 @@ var _ = Describe("Packet packer", func() { sendMaxNumNonAckElicitingAcks() pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendStreamFrames() expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) @@ -549,7 +553,7 @@ var _ = Describe("Packet packer", func() { Context("max packet size", func() { It("sets the maximum packet size", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Times(2) var initialMaxPacketSize protocol.ByteCount framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { @@ -574,7 +578,7 @@ var _ = Describe("Packet packer", func() { It("doesn't increase the max packet size", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Times(2) var initialMaxPacketSize protocol.ByteCount framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { @@ -610,7 +614,9 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData().Return(true).AnyTimes() initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) checkLength(p.buffer.Data) @@ -621,7 +627,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(mocks.NewMockShortHeaderSealer(mockCtrl), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) initialStream.EXPECT().HasData() @@ -634,9 +640,68 @@ var _ = Describe("Packet packer", func() { }) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].header.IsLongHeader).To(BeTrue()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - Expect(p.header.IsLongHeader).To(BeTrue()) + checkLength(p.buffer.Data) + }) + + It("packs a coalesced packet", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} + }) + handshakeStream.EXPECT().HasData().Return(true).Times(2) + handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} + }) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(2)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) + hdr, _, rest, err := wire.ParsePacket(p.buffer.Data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) + hdr, _, rest, err = wire.ParsePacket(rest, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(rest).To(BeEmpty()) + }) + + It("doesn't add a coalesced packet if the remaining size is smaller than MaxCoalescedPacketSize", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + // don't EXPECT any calls to GetHandshakeSealer and Get1RTTSealer + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + s := size - protocol.MinCoalescedPacketSize + f := &wire.CryptoFrame{Offset: 0x1337} + f.Data = bytes.Repeat([]byte{'f'}, int(s-f.Length(packer.version)-1)) + Expect(f.Length(packer.version)).To(Equal(s)) + return f + }) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(len(p.buffer.Data)).To(BeEquivalentTo(maxPacketSize - protocol.MinCoalescedPacketSize)) checkLength(p.buffer.Data) }) @@ -646,14 +711,17 @@ var _ = Describe("Packet packer", func() { retransmissionQueue.AddHandshake(&wire.CryptoFrame{Data: []byte("Handshake")}) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData() p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) - Expect(p.header.IsLongHeader).To(BeTrue()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{{Frame: f}})) + Expect(p.packets[0].header.IsLongHeader).To(BeTrue()) checkLength(p.buffer.Data) }) @@ -661,16 +729,19 @@ var _ = Describe("Packet packer", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) initialStream.EXPECT().HasData().Times(2) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.ack).To(Equal(ack)) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].ack).To(Equal(ack)) }) It("doesn't pack anything if there's nothing to send at Initial and Handshake keys are not yet available", func() { - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) initialStream.EXPECT().HasData() @@ -687,12 +758,14 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData() handshakeStream.EXPECT().HasData().Times(2) sealingManager.EXPECT().GetInitialSealer().Return(mocks.NewMockShortHeaderSealer(mockCtrl), nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.ack).To(Equal(ack)) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].ack).To(Equal(ack)) }) It("pads Initial packets to the required minimum packet size", func() { @@ -701,19 +774,23 @@ var _ = Describe("Packet packer", func() { f := &wire.CryptoFrame{Data: []byte("foobar")} pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient - packet, err := packer.PackPacket() + p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(packet.header.Token).To(Equal(token)) - Expect(packet.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize)) - Expect(packet.frames).To(HaveLen(1)) - cf := packet.frames[0].Frame.(*wire.CryptoFrame) + Expect(p.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize)) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].header.Token).To(Equal(token)) + Expect(p.packets[0].frames).To(HaveLen(1)) + cf := p.packets[0].frames[0].Frame.(*wire.CryptoFrame) Expect(cf.Data).To(Equal([]byte("foobar"))) - checkLength(packet.buffer.Data) + checkLength(p.buffer.Data) }) It("adds an ACK frame", func() { @@ -721,17 +798,21 @@ var _ = Describe("Packet packer", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 42, Largest: 1337}}} pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.perspective = protocol.PerspectiveClient - packet, err := packer.PackPacket() + p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(packet.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize)) - Expect(packet.ack).To(Equal(ack)) - Expect(packet.frames).To(HaveLen(1)) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].ack).To(Equal(ack)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize)) }) }) @@ -739,7 +820,7 @@ var _ = Describe("Packet packer", func() { It("packs an Initial probe packet", func() { f := &wire.CryptoFrame{Data: []byte("Initial")} retransmissionQueue.AddInitial(f) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData() pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) @@ -757,7 +838,7 @@ var _ = Describe("Packet packer", func() { It("packs a Handshake probe packet", func() { f := &wire.CryptoFrame{Data: []byte("Handshake")} retransmissionQueue.AddHandshake(f) - sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) handshakeStream.EXPECT().HasData() pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) @@ -775,7 +856,7 @@ var _ = Describe("Packet packer", func() { It("packs a 1-RTT probe packet", func() { f := &wire.StreamFrame{Data: []byte("1-RTT")} retransmissionQueue.AddInitial(f) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) @@ -795,25 +876,22 @@ var _ = Describe("Packet packer", func() { var _ = Describe("Converting to AckHandler packets", func() { It("convert a packet", func() { - buffer := getPacketBuffer() - buffer.Data = append(buffer.Data, []byte("foobar")...) - packet := &packedPacket{ - buffer: buffer, + packet := &packetContents{ header: &wire.ExtendedHeader{Header: wire.Header{}}, frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, ack: &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 80}}}, + length: 42, } t := time.Now() p := packet.ToAckHandlerPacket(t, nil) - Expect(p.Length).To(Equal(protocol.ByteCount(6))) + Expect(p.Length).To(Equal(protocol.ByteCount(42))) Expect(p.Frames).To(Equal(packet.frames)) Expect(p.LargestAcked).To(Equal(protocol.PacketNumber(100))) Expect(p.SendTime).To(Equal(t)) }) It("sets the LargestAcked to invalid, if the packet doesn't have an ACK frame", func() { - packet := &packedPacket{ - buffer: getPacketBuffer(), + packet := &packetContents{ header: &wire.ExtendedHeader{Header: wire.Header{}}, frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, } @@ -823,8 +901,7 @@ var _ = Describe("Converting to AckHandler packets", func() { It("doesn't overwrite the OnLost callback, if it is set", func() { var pingLost bool - packet := &packedPacket{ - buffer: getPacketBuffer(), + packet := &packetContents{ header: &wire.ExtendedHeader{Header: wire.Header{Type: protocol.PacketTypeHandshake}}, frames: []ackhandler.Frame{ {Frame: &wire.MaxDataFrame{}}, diff --git a/session.go b/session.go index 4595ace3..2795c9a6 100644 --- a/session.go +++ b/session.go @@ -1309,13 +1309,24 @@ func (s *session) sendPacket() (bool, error) { } s.windowUpdateQueue.QueueAll() - var packet *packedPacket - var err error if !s.handshakeConfirmed { - packet, err = s.packer.PackPacket() - } else { - packet, err = s.packer.PackAppDataPacket() + now := time.Now() + packet, err := s.packer.PackPacket() + if err != nil || packet == nil { + return false, err + } + for _, p := range packet.packets { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { + s.firstAckElicitingPacketAfterIdleSentTime = now + } + s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) + } + s.connIDManager.SentPacket() + s.logCoalescedPacket(now, packet) + s.sendQueue.Send(packet.buffer) + return true, nil } + packet, err := s.packer.PackAppDataPacket() if err != nil || packet == nil { return false, err } @@ -1324,35 +1335,13 @@ func (s *session) sendPacket() (bool, error) { } func (s *session) sendPackedPacket(packet *packedPacket) { - if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() { - s.firstAckElicitingPacketAfterIdleSentTime = time.Now() - } now := time.Now() - s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(now, s.retransmissionQueue)) - if s.traceCallback != nil { - frames := make([]wire.Frame, 0, len(packet.frames)) - for _, f := range packet.frames { - frames = append(frames, f.Frame) - } - s.traceCallback(quictrace.Event{ - Time: now, - EventType: quictrace.PacketSent, - TransportState: s.sentPacketHandler.GetStats(), - EncryptionLevel: packet.EncryptionLevel(), - PacketNumber: packet.header.PacketNumber, - PacketSize: packet.buffer.Len(), - Frames: frames, - }) + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() { + s.firstAckElicitingPacketAfterIdleSentTime = now } - if s.qlogger != nil { - frames := make([]wire.Frame, 0, len(packet.frames)) - for _, f := range packet.frames { - frames = append(frames, f.Frame) - } - s.qlogger.SentPacket(now, packet.header, packet.buffer.Len(), packet.ack, frames) - } - s.logPacket(packet) + s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(time.Now(), s.retransmissionQueue)) s.connIDManager.SentPacket() + s.logPacket(now, packet) s.sendQueue.Send(packet.buffer) } @@ -1375,25 +1364,66 @@ func (s *session) sendConnectionClose(quicErr *qerr.QuicError) ([]byte, error) { if err != nil { return nil, err } - s.logPacket(packet) + s.logPacket(time.Now(), packet) return packet.buffer.Data, s.conn.Write(packet.buffer.Data) } -func (s *session) logPacket(packet *packedPacket) { +func (s *session) logPacketContents(now time.Time, p *packetContents) { + // qlog + if s.qlogger != nil { + frames := make([]wire.Frame, 0, len(p.frames)) + for _, f := range p.frames { + frames = append(frames, f.Frame) + } + s.qlogger.SentPacket(now, p.header, p.length, p.ack, frames) + } + + // quic-trace + if s.traceCallback != nil { + frames := make([]wire.Frame, 0, len(p.frames)) + for _, f := range p.frames { + frames = append(frames, f.Frame) + } + s.traceCallback(quictrace.Event{ + Time: now, + EventType: quictrace.PacketSent, + TransportState: s.sentPacketHandler.GetStats(), + EncryptionLevel: p.EncryptionLevel(), + PacketNumber: p.header.PacketNumber, + PacketSize: p.length, + Frames: frames, + }) + } + + // quic-go logging if !s.logger.Debug() { - // We don't need to allocate the slices for calling the format functions return } - s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, packet.buffer.Len(), s.logID, packet.EncryptionLevel()) - packet.header.Log(s.logger) - if packet.ack != nil { - wire.LogFrame(s.logger, packet.ack, true) + p.header.Log(s.logger) + if p.ack != nil { + wire.LogFrame(s.logger, p.ack, true) } - for _, frame := range packet.frames { + for _, frame := range p.frames { wire.LogFrame(s.logger, frame.Frame, true) } } +func (s *session) logCoalescedPacket(now time.Time, packet *coalescedPacket) { + if s.logger.Debug() { + s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.packets), packet.buffer.Len(), s.logID) + } + for _, p := range packet.packets { + s.logPacketContents(now, p) + } +} + +func (s *session) logPacket(now time.Time, packet *packedPacket) { + if s.logger.Debug() { + s.logger.Debugf("-> Sending packet %#x (%d bytes) for connection %s, %s", packet.header.PacketNumber, packet.buffer.Len(), s.logID, packet.EncryptionLevel()) + } + s.logPacketContents(now, packet.packetContents) +} + // AcceptStream returns the next stream openend by the peer func (s *session) AcceptStream(ctx context.Context) (Stream, error) { return s.streamsMap.AcceptStream(ctx) diff --git a/session_test.go b/session_test.go index 0833c7df..28744de2 100644 --- a/session_test.go +++ b/session_test.go @@ -57,7 +57,9 @@ var _ = Describe("Session", func() { buffer.Data = append(buffer.Data, []byte("foobar")...) return &packedPacket{ buffer: buffer, - header: &wire.ExtendedHeader{PacketNumber: pn}, + packetContents: &packetContents{ + header: &wire.ExtendedHeader{PacketNumber: pn}, + }, } } @@ -884,7 +886,8 @@ var _ = Describe("Session", func() { }) It("sends packets", func() { - packer.EXPECT().PackPacket().Return(getPacket(1), nil) + sess.handshakeConfirmed = true + packer.EXPECT().PackAppDataPacket().Return(getPacket(1), nil) sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) mconn.EXPECT().Write(gomock.Any()) sent, err := sess.sendPacket() @@ -893,7 +896,8 @@ var _ = Describe("Session", func() { }) It("doesn't send packets if there's nothing to send", func() { - packer.EXPECT().PackPacket().Return(nil, nil) + sess.handshakeConfirmed = true + packer.EXPECT().PackAppDataPacket().Return(nil, nil) sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) @@ -911,9 +915,10 @@ var _ = Describe("Session", func() { }) It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { + sess.handshakeConfirmed = true fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) - packer.EXPECT().PackPacket().Return(getPacket(1), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(1), nil) sess.connFlowController = fc mconn.EXPECT().Write(gomock.Any()) sent, err := sess.sendPacket() @@ -996,6 +1001,7 @@ var _ = Describe("Session", func() { BeforeEach(func() { sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sess.handshakeConfirmed = true sess.sentPacketHandler = sph streamManager.EXPECT().CloseWithError(gomock.Any()) }) @@ -1016,8 +1022,8 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now()).Times(2) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) // allow 2 packets... - packer.EXPECT().PackPacket().Return(getPacket(10), nil) - packer.EXPECT().PackPacket().Return(getPacket(11), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(10), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(11), nil) mconn.EXPECT().Write(gomock.Any()).Times(2) go func() { defer GinkgoRecover() @@ -1036,7 +1042,7 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now()) sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().SendMode().Return(ackhandler.SendAck) - packer.EXPECT().PackPacket().Return(getPacket(100), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(100), nil) mconn.EXPECT().Write(gomock.Any()) go func() { defer GinkgoRecover() @@ -1055,8 +1061,8 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().ShouldSendNumPackets().Times(2).Return(1) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackPacket().Return(getPacket(100), nil) - packer.EXPECT().PackPacket().Return(getPacket(101), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(100), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(101), nil) written := make(chan struct{}, 2) mconn.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { written <- struct{}{} @@ -1079,9 +1085,9 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) - packer.EXPECT().PackPacket().Return(getPacket(1000), nil) - packer.EXPECT().PackPacket().Return(getPacket(1001), nil) - packer.EXPECT().PackPacket().Return(getPacket(1002), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(1000), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(1001), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(1002), nil) written := make(chan struct{}, 3) mconn.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { written <- struct{}{} @@ -1100,7 +1106,7 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now()) sph.EXPECT().ShouldSendNumPackets().Return(1) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackPacket() + packer.EXPECT().PackAppDataPacket() // don't EXPECT any calls to mconn.Write() go func() { defer GinkgoRecover() @@ -1113,6 +1119,10 @@ var _ = Describe("Session", func() { }) Context("scheduling sending", func() { + BeforeEach(func() { + sess.handshakeConfirmed = true + }) + AfterEach(func() { // make the go routine return expectReplaceWithClosed() @@ -1132,7 +1142,7 @@ var _ = Describe("Session", func() { sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(1) sph.EXPECT().SentPacket(gomock.Any()) sess.sentPacketHandler = sph - packer.EXPECT().PackPacket().Return(getPacket(1), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(1), nil) go func() { defer GinkgoRecover() @@ -1149,7 +1159,7 @@ var _ = Describe("Session", func() { }) It("sets the timer to the ack timer", func() { - packer.EXPECT().PackPacket().Return(getPacket(1234), nil) + packer.EXPECT().PackAppDataPacket().Return(getPacket(1234), nil) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().Return(time.Now()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) @@ -1177,6 +1187,75 @@ var _ = Describe("Session", func() { }) }) + It("sends coalesced packets before the handshake is confirmed", func() { + sess.handshakeConfirmed = false + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sess.sentPacketHandler = sph + buffer := getPacketBuffer() + buffer.Data = append(buffer.Data, []byte("foobar")...) + packer.EXPECT().PackPacket().Return(&coalescedPacket{ + buffer: buffer, + packets: []*packetContents{ + { + header: &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + }, + PacketNumber: 13, + }, + length: 123, + }, + { + header: &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + }, + PacketNumber: 37, + }, + length: 1234, + }, + }, + }, nil) + packer.EXPECT().PackPacket().AnyTimes() + + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes() + sph.EXPECT().ShouldSendNumPackets().Return(1).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(13))) + Expect(p.Length).To(BeEquivalentTo(123)) + }) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionHandshake)) + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(37))) + Expect(p.Length).To(BeEquivalentTo(1234)) + }) + sent := make(chan struct{}) + mconn.EXPECT().Write([]byte("foobar")).Do(func([]byte) { close(sent) }) + + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + sess.run() + }() + + sess.scheduleSending() + Eventually(sent).Should(BeClosed()) + + // make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + sess.shutdown() + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + It("cancels the HandshakeComplete context and informs the SentPacketHandler when the handshake completes", func() { packer.EXPECT().PackPacket().AnyTimes() finishHandshake := make(chan struct{}) @@ -1282,7 +1361,9 @@ var _ = Describe("Session", func() { Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) defer close(done) return &packedPacket{ - header: &wire.ExtendedHeader{}, + packetContents: &packetContents{ + header: &wire.ExtendedHeader{}, + }, buffer: getPacketBuffer(), }, nil }) @@ -1303,7 +1384,6 @@ var _ = Describe("Session", func() { expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) sess.shutdown() Eventually(sess.Context().Done()).Should(BeClosed()) }) From 29b784e782fb90860f528ed556e4327157c47121 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 13 Feb 2020 17:18:19 +0700 Subject: [PATCH 5/7] rename packet packing functions in the packet packer --- mock_packer_test.go | 18 ++++++------- packet_packer.go | 12 ++++----- packet_packer_test.go | 60 +++++++++++++++++++++---------------------- session.go | 4 +-- session_test.go | 52 ++++++++++++++++++------------------- 5 files changed, 73 insertions(+), 73 deletions(-) diff --git a/mock_packer_test.go b/mock_packer_test.go index 1fc58d12..2349880f 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -78,19 +78,19 @@ func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0) } -// PackAppDataPacket mocks base method -func (m *MockPacker) PackAppDataPacket() (*packedPacket, error) { +// PackCoalescedPacket mocks base method +func (m *MockPacker) PackCoalescedPacket() (*coalescedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackAppDataPacket") - ret0, _ := ret[0].(*packedPacket) + ret := m.ctrl.Call(m, "PackCoalescedPacket") + ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } -// PackAppDataPacket indicates an expected call of PackAppDataPacket -func (mr *MockPackerMockRecorder) PackAppDataPacket() *gomock.Call { +// PackCoalescedPacket indicates an expected call of PackCoalescedPacket +func (mr *MockPackerMockRecorder) PackCoalescedPacket() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAppDataPacket", reflect.TypeOf((*MockPacker)(nil).PackAppDataPacket)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket)) } // PackConnectionClose mocks base method @@ -109,10 +109,10 @@ func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock. } // PackPacket mocks base method -func (m *MockPacker) PackPacket() (*coalescedPacket, error) { +func (m *MockPacker) PackPacket() (*packedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackPacket") - ret0, _ := ret[0].(*coalescedPacket) + ret0, _ := ret[0].(*packedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/packet_packer.go b/packet_packer.go index 2eaa8410..b612d6ef 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -15,8 +15,8 @@ import ( ) type packer interface { - PackPacket() (*coalescedPacket, error) - PackAppDataPacket() (*packedPacket, error) + PackCoalescedPacket() (*coalescedPacket, error) + PackPacket() (*packedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error) @@ -264,10 +264,10 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke return p.writeSinglePacket(hdr, payload, encLevel, sealer) } -// PackPacket packs a new packet. +// PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *packetPacker) PackPacket() (*coalescedPacket, error) { +func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { buffer := getPacketBuffer() packet, err := p.packCoalescedPacket(buffer) if err != nil { @@ -324,9 +324,9 @@ func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPack return packet, nil } -// PackAppDataPacket packs a packet in the application data packet number space. +// PackPacket packs a packet in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) PackAppDataPacket() (*packedPacket, error) { +func (p *packetPacker) PackPacket() (*packedPacket, error) { buffer := getPacketBuffer() contents, err := p.maybeAppendAppDataPacket(buffer) if err != nil || contents == nil { diff --git a/packet_packer_test.go b/packet_packer_test.go index b5ee06bb..d980125f 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -175,7 +175,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.packets).To(HaveLen(1)) @@ -265,7 +265,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { return frames, 0 }) - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) @@ -290,7 +290,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) framer.EXPECT().AppendControlFrames(nil, gomock.Any()) framer.EXPECT().AppendStreamFrames(nil, gomock.Any()) - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) @@ -306,7 +306,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, } expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b := &bytes.Buffer{} @@ -325,7 +325,7 @@ var _ = Describe("Packet packer", func() { StreamID: 5, Data: []byte("foobar"), }}) - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) }) @@ -338,7 +338,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.ack).To(Equal(ack)) @@ -370,7 +370,7 @@ var _ = Describe("Packet packer", func() { } expectAppendControlFrames(frames...) expectAppendStreamFrames() - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(Equal(frames)) @@ -392,7 +392,7 @@ var _ = Describe("Packet packer", func() { return fs, 0 }), ) - _, err := packer.PackAppDataPacket() + _, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) }) @@ -409,7 +409,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - packet, err := packer.PackAppDataPacket() + packet, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] @@ -458,7 +458,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(3)) @@ -476,7 +476,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).ToNot(BeNil()) @@ -492,7 +492,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(ContainElement(ackhandler.Frame{Frame: &wire.PingFrame{}})) @@ -503,7 +503,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err = packer.PackAppDataPacket() + p, err = packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).ToNot(BeNil()) @@ -518,7 +518,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) // now add some frame to send @@ -529,7 +529,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) - p, err = packer.PackAppDataPacket() + p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.ack).To(Equal(ack)) Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}})) @@ -543,7 +543,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendStreamFrames() expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) - p, err := packer.PackAppDataPacket() + p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) @@ -561,7 +561,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackAppDataPacket() + _, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) // now reduce the maxPacketSize packer.HandleTransportParameters(&handshake.TransportParameters{ @@ -572,7 +572,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackAppDataPacket() + _, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) }) @@ -586,7 +586,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackAppDataPacket() + _, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) // now try to increase the maxPacketSize packer.HandleTransportParameters(&handshake.TransportParameters{ @@ -597,7 +597,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackAppDataPacket() + _, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) }) }) @@ -617,7 +617,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) checkLength(p.buffer.Data) }) @@ -638,7 +638,7 @@ var _ = Describe("Packet packer", func() { Expect(f.Length(packer.version)).To(Equal(size)) return f }) - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].frames).To(HaveLen(1)) @@ -665,7 +665,7 @@ var _ = Describe("Packet packer", func() { handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} }) - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(2)) Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -697,7 +697,7 @@ var _ = Describe("Packet packer", func() { Expect(f.Length(packer.version)).To(Equal(s)) return f }) - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -716,7 +716,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData() - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -734,7 +734,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].ack).To(Equal(ack)) @@ -746,7 +746,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) initialStream.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -762,7 +762,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].ack).To(Equal(ack)) @@ -782,7 +782,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize)) Expect(p.packets).To(HaveLen(1)) @@ -807,7 +807,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.perspective = protocol.PerspectiveClient - p, err := packer.PackPacket() + p, err := packer.PackCoalescedPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].ack).To(Equal(ack)) diff --git a/session.go b/session.go index 2795c9a6..5587ba7c 100644 --- a/session.go +++ b/session.go @@ -1311,7 +1311,7 @@ func (s *session) sendPacket() (bool, error) { if !s.handshakeConfirmed { now := time.Now() - packet, err := s.packer.PackPacket() + packet, err := s.packer.PackCoalescedPacket() if err != nil || packet == nil { return false, err } @@ -1326,7 +1326,7 @@ func (s *session) sendPacket() (bool, error) { s.sendQueue.Send(packet.buffer) return true, nil } - packet, err := s.packer.PackAppDataPacket() + packet, err := s.packer.PackPacket() if err != nil || packet == nil { return false, err } diff --git a/session_test.go b/session_test.go index 28744de2..5495056a 100644 --- a/session_test.go +++ b/session_test.go @@ -887,7 +887,7 @@ var _ = Describe("Session", func() { It("sends packets", func() { sess.handshakeConfirmed = true - packer.EXPECT().PackAppDataPacket().Return(getPacket(1), nil) + packer.EXPECT().PackPacket().Return(getPacket(1), nil) sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) mconn.EXPECT().Write(gomock.Any()) sent, err := sess.sendPacket() @@ -897,7 +897,7 @@ var _ = Describe("Session", func() { It("doesn't send packets if there's nothing to send", func() { sess.handshakeConfirmed = true - packer.EXPECT().PackAppDataPacket().Return(nil, nil) + packer.EXPECT().PackPacket().Return(nil, nil) sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) @@ -918,7 +918,7 @@ var _ = Describe("Session", func() { sess.handshakeConfirmed = true fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) - packer.EXPECT().PackAppDataPacket().Return(getPacket(1), nil) + packer.EXPECT().PackPacket().Return(getPacket(1), nil) sess.connFlowController = fc mconn.EXPECT().Write(gomock.Any()) sent, err := sess.sendPacket() @@ -1022,8 +1022,8 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now()).Times(2) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) // allow 2 packets... - packer.EXPECT().PackAppDataPacket().Return(getPacket(10), nil) - packer.EXPECT().PackAppDataPacket().Return(getPacket(11), nil) + packer.EXPECT().PackPacket().Return(getPacket(10), nil) + packer.EXPECT().PackPacket().Return(getPacket(11), nil) mconn.EXPECT().Write(gomock.Any()).Times(2) go func() { defer GinkgoRecover() @@ -1042,7 +1042,7 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now()) sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().SendMode().Return(ackhandler.SendAck) - packer.EXPECT().PackAppDataPacket().Return(getPacket(100), nil) + packer.EXPECT().PackPacket().Return(getPacket(100), nil) mconn.EXPECT().Write(gomock.Any()) go func() { defer GinkgoRecover() @@ -1061,8 +1061,8 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().ShouldSendNumPackets().Times(2).Return(1) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackAppDataPacket().Return(getPacket(100), nil) - packer.EXPECT().PackAppDataPacket().Return(getPacket(101), nil) + packer.EXPECT().PackPacket().Return(getPacket(100), nil) + packer.EXPECT().PackPacket().Return(getPacket(101), nil) written := make(chan struct{}, 2) mconn.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { written <- struct{}{} @@ -1085,9 +1085,9 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) - packer.EXPECT().PackAppDataPacket().Return(getPacket(1000), nil) - packer.EXPECT().PackAppDataPacket().Return(getPacket(1001), nil) - packer.EXPECT().PackAppDataPacket().Return(getPacket(1002), nil) + packer.EXPECT().PackPacket().Return(getPacket(1000), nil) + packer.EXPECT().PackPacket().Return(getPacket(1001), nil) + packer.EXPECT().PackPacket().Return(getPacket(1002), nil) written := make(chan struct{}, 3) mconn.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { written <- struct{}{} @@ -1106,7 +1106,7 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now()) sph.EXPECT().ShouldSendNumPackets().Return(1) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackAppDataPacket() + packer.EXPECT().PackPacket() // don't EXPECT any calls to mconn.Write() go func() { defer GinkgoRecover() @@ -1142,7 +1142,7 @@ var _ = Describe("Session", func() { sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(1) sph.EXPECT().SentPacket(gomock.Any()) sess.sentPacketHandler = sph - packer.EXPECT().PackAppDataPacket().Return(getPacket(1), nil) + packer.EXPECT().PackPacket().Return(getPacket(1), nil) go func() { defer GinkgoRecover() @@ -1159,7 +1159,7 @@ var _ = Describe("Session", func() { }) It("sets the timer to the ack timer", func() { - packer.EXPECT().PackAppDataPacket().Return(getPacket(1234), nil) + packer.EXPECT().PackPacket().Return(getPacket(1234), nil) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().Return(time.Now()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) @@ -1193,7 +1193,7 @@ var _ = Describe("Session", func() { sess.sentPacketHandler = sph buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) - packer.EXPECT().PackPacket().Return(&coalescedPacket{ + packer.EXPECT().PackCoalescedPacket().Return(&coalescedPacket{ buffer: buffer, packets: []*packetContents{ { @@ -1218,7 +1218,7 @@ var _ = Describe("Session", func() { }, }, }, nil) - packer.EXPECT().PackPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1257,7 +1257,7 @@ var _ = Describe("Session", func() { }) It("cancels the HandshakeComplete context and informs the SentPacketHandler when the handshake completes", func() { - packer.EXPECT().PackPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket().AnyTimes() finishHandshake := make(chan struct{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph @@ -1294,7 +1294,7 @@ var _ = Describe("Session", func() { It("sends a session ticket when the handshake completes", func() { const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 - packer.EXPECT().PackPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket().AnyTimes() finishHandshake := make(chan struct{}) sessionRunner.EXPECT().Retire(clientDestConnID) go func() { @@ -1334,7 +1334,7 @@ var _ = Describe("Session", func() { }) It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { - packer.EXPECT().PackPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket().AnyTimes() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{buffer: getPacketBuffer()}, nil) @@ -1355,7 +1355,7 @@ var _ = Describe("Session", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { done := make(chan struct{}) sessionRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket().DoAndReturn(func() (*packedPacket, error) { frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) @@ -1367,7 +1367,7 @@ var _ = Describe("Session", func() { buffer: getPacketBuffer(), }, nil }) - packer.EXPECT().PackPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket().AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -1441,7 +1441,7 @@ var _ = Describe("Session", func() { } streamManager.EXPECT().UpdateLimits(params) packer.EXPECT().HandleTransportParameters(params) - packer.EXPECT().PackPacket().MaxTimes(3) + packer.EXPECT().PackCoalescedPacket().MaxTimes(3) Expect(sess.earlySessionReady()).ToNot(BeClosed()) sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) sessionRunner.EXPECT().Add(gomock.Any(), sess).Times(2) @@ -1497,7 +1497,7 @@ var _ = Describe("Session", func() { setRemoteIdleTimeout(5 * time.Second) sess.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) sent := make(chan struct{}) - packer.EXPECT().PackPacket().Do(func() (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { close(sent) return nil, nil }) @@ -1510,7 +1510,7 @@ var _ = Describe("Session", func() { setRemoteIdleTimeout(time.Hour) sess.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) sent := make(chan struct{}) - packer.EXPECT().PackPacket().Do(func() (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { close(sent) return nil, nil }) @@ -1606,7 +1606,7 @@ var _ = Describe("Session", func() { }) It("closes the session due to the idle timeout after handshake", func() { - packer.EXPECT().PackPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket().AnyTimes() gomock.InOrder( sessionRunner.EXPECT().Retire(clientDestConnID), sessionRunner.EXPECT().Remove(gomock.Any()), @@ -1965,7 +1965,7 @@ var _ = Describe("Client Session", func() { }, } packer.EXPECT().HandleTransportParameters(gomock.Any()) - packer.EXPECT().PackPacket().MaxTimes(1) + packer.EXPECT().PackCoalescedPacket().MaxTimes(1) sess.processTransportParameters(params) cf, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(cf).To(HaveLen(1)) From db7fc0eb022fbbd4257178d77ac5d6ef6d6763ac Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 13 Feb 2020 19:35:31 +0700 Subject: [PATCH 6/7] simplify packing of Initial and Handshake packets --- packet_packer.go | 76 ++++++++++++++++++++---------------------------- 1 file changed, 31 insertions(+), 45 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index b612d6ef..4128d3e3 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -287,7 +287,7 @@ func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPack packets: make([]*packetContents, 0, 3), } // Try packing an Initial packet. - contents, err := p.maybeAppendInitialPacket(buffer) + contents, err := p.maybeAppendCryptoPacket(buffer, protocol.EncryptionInitial) if err != nil && err != handshake.ErrKeysDropped { return nil, err } @@ -299,7 +299,7 @@ func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPack } // Add a Handshake packet. - contents, err = p.maybeAppendHandshakePacket(buffer) + contents, err = p.maybeAppendCryptoPacket(buffer, protocol.EncryptionHandshake) if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { return nil, err } @@ -339,52 +339,38 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { }, nil } -func (p *packetPacker) maybeAppendInitialPacket(buffer *packetBuffer) (*packetContents, error) { - sealer, err := p.cryptoSetup.GetInitialSealer() - if err != nil { - return nil, err - } - - hasRetransmission := p.retransmissionQueue.HasInitialData() - ack := p.acks.GetAckFrame(protocol.EncryptionInitial) - if !p.initialStream.HasData() && !hasRetransmission && ack == nil { - // nothing to send - return nil, nil - } - return p.appendCryptoPacket(buffer, protocol.EncryptionInitial, sealer, ack, hasRetransmission) -} - -func (p *packetPacker) maybeAppendHandshakePacket(buffer *packetBuffer) (*packetContents, error) { - sealer, err := p.cryptoSetup.GetHandshakeSealer() - if err != nil { - return nil, err - } - - hasRetransmission := p.retransmissionQueue.HasHandshakeData() - // TODO: make sure that the ACK always fits - ack := p.acks.GetAckFrame(protocol.EncryptionHandshake) - if !p.handshakeStream.HasData() && !hasRetransmission && ack == nil { - // nothing to send - return nil, nil - } - return p.appendCryptoPacket(buffer, protocol.EncryptionHandshake, sealer, ack, hasRetransmission) -} - -func (p *packetPacker) appendCryptoPacket( - buffer *packetBuffer, - encLevel protocol.EncryptionLevel, - sealer handshake.LongHeaderSealer, - ack *wire.AckFrame, - hasRetransmission bool, -) (*packetContents, error) { - s := p.handshakeStream +func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, encLevel protocol.EncryptionLevel) (*packetContents, error) { + var sealer sealer + var s cryptoStream + var hasRetransmission bool maxPacketSize := p.maxPacketSize - if encLevel == protocol.EncryptionInitial { - s = p.initialStream + switch encLevel { + case protocol.EncryptionInitial: if p.perspective == protocol.PerspectiveClient { maxPacketSize = protocol.MinInitialPacketSize } + s = p.initialStream + hasRetransmission = p.retransmissionQueue.HasInitialData() + var err error + sealer, err = p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, err + } + case protocol.EncryptionHandshake: + s = p.handshakeStream + hasRetransmission = p.retransmissionQueue.HasHandshakeData() + var err error + sealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil { + return nil, err + } } + ack := p.acks.GetAckFrame(encLevel) + if !s.HasData() && !hasRetransmission && ack == nil { + // nothing to send + return nil, nil + } + remainingLen := maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) var payload payload @@ -503,9 +489,9 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( buffer := getPacketBuffer() switch encLevel { case protocol.EncryptionInitial: - contents, err = p.maybeAppendInitialPacket(buffer) + contents, err = p.maybeAppendCryptoPacket(buffer, protocol.EncryptionInitial) case protocol.EncryptionHandshake: - contents, err = p.maybeAppendHandshakePacket(buffer) + contents, err = p.maybeAppendCryptoPacket(buffer, protocol.EncryptionHandshake) case protocol.Encryption1RTT: contents, err = p.maybeAppendAppDataPacket(buffer) default: From 7a532326ec29026ae5686e6324c909c6375b9881 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 14 Feb 2020 14:43:31 +0700 Subject: [PATCH 7/7] don't pack ACK frames in the second part of a coalesced packet This prevents a possible overflow of the maximum packet size if the ACK frames ends up being really large. --- packet_packer.go | 21 ++++++++++++++------- packet_packer_test.go | 38 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 4128d3e3..2e0c3b2b 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -365,7 +365,11 @@ func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, encLevel pr return nil, err } } - ack := p.acks.GetAckFrame(encLevel) + + var ack *wire.AckFrame + if encLevel != protocol.EncryptionHandshake || buffer.Len() == 0 { + ack = p.acks.GetAckFrame(encLevel) + } if !s.HasData() && !hasRetransmission && ack == nil { // nothing to send return nil, nil @@ -430,7 +434,7 @@ func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer) (*packetCo headerLen := header.GetLength(p.version) maxSize := p.maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) - headerLen - payload := p.composeNextPacket(maxSize) + payload := p.composeNextPacket(maxSize, encLevel != protocol.Encryption0RTT && buffer.Len() == 0) // check if we have anything to send if len(payload.frames) == 0 && payload.ack == nil { @@ -452,13 +456,16 @@ func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer) (*packetCo return p.appendPacket(buffer, header, payload, encLevel, sealer) } -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payload { +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) payload { var payload payload - // TODO: we don't need to request ACKs when sending 0-RTT packets - if ack := p.acks.GetAckFrame(protocol.Encryption1RTT); ack != nil { - payload.ack = ack - payload.length += ack.Length(p.version) + var ack *wire.AckFrame + if ackAllowed { + ack = p.acks.GetAckFrame(protocol.Encryption1RTT) + if ack != nil { + payload.ack = ack + payload.length += ack.Length(p.version) + } } for { diff --git a/packet_packer_test.go b/packet_packer_test.go index d980125f..1456252a 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -647,7 +647,7 @@ var _ = Describe("Packet packer", func() { checkLength(p.buffer.Data) }) - It("packs a coalesced packet", func() { + It("packs a coalesced packet with Initial / Handshake", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) @@ -656,7 +656,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + // don't EXPECT any calls for a Handshake ACK frame initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} @@ -683,6 +683,40 @@ var _ = Describe("Packet packer", func() { Expect(rest).To(BeEmpty()) }) + It("packs a coalesced packet with Handshake / 1-RTT", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + // don't EXPECT any calls for a 1-RTT ACK frame + handshakeStream.EXPECT().HasData().Return(true).Times(2) + handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} + }) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) + p, err := packer.PackCoalescedPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.packets).To(HaveLen(2)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) + Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(p.packets[1].frames).To(HaveLen(1)) + Expect(p.packets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + hdr, _, rest, err := wire.ParsePacket(p.buffer.Data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) + hdr, _, rest, err = wire.ParsePacket(rest, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsLongHeader).To(BeFalse()) + Expect(rest).To(BeEmpty()) + }) + It("doesn't add a coalesced packet if the remaining size is smaller than MaxCoalescedPacketSize", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24))