refactor packing of packets before and after the handshake is confirmed

This commit is contained in:
Marten Seemann 2020-02-10 14:15:50 +08:00
parent e01995041e
commit a4b4d52063
5 changed files with 70 additions and 86 deletions

View file

@ -49,18 +49,18 @@ func (mr *MockPackerMockRecorder) HandleTransportParameters(arg0 interface{}) *g
}
// MaybePackAckPacket mocks base method
func (m *MockPacker) MaybePackAckPacket() (*packedPacket, error) {
func (m *MockPacker) MaybePackAckPacket(arg0 bool) (*packedPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MaybePackAckPacket")
ret := m.ctrl.Call(m, "MaybePackAckPacket", arg0)
ret0, _ := ret[0].(*packedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MaybePackAckPacket indicates an expected call of MaybePackAckPacket
func (mr *MockPackerMockRecorder) MaybePackAckPacket() *gomock.Call {
func (mr *MockPackerMockRecorder) MaybePackAckPacket(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket), arg0)
}
// MaybePackProbePacket mocks base method
@ -78,6 +78,21 @@ func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0)
}
// PackAppDataPacket mocks base method
func (m *MockPacker) PackAppDataPacket() (*packedPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PackAppDataPacket")
ret0, _ := ret[0].(*packedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PackAppDataPacket indicates an expected call of PackAppDataPacket
func (mr *MockPackerMockRecorder) PackAppDataPacket() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAppDataPacket", reflect.TypeOf((*MockPacker)(nil).PackAppDataPacket))
}
// PackConnectionClose mocks base method
func (m *MockPacker) PackConnectionClose(arg0 *wire.ConnectionCloseFrame) (*packedPacket, error) {
m.ctrl.T.Helper()

View file

@ -16,8 +16,9 @@ import (
type packer interface {
PackPacket() (*packedPacket, error)
PackAppDataPacket() (*packedPacket, error)
MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error)
MaybePackAckPacket() (*packedPacket, error)
MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error)
PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error)
HandleTransportParameters(*handshake.TransportParameters)
@ -138,10 +139,6 @@ type packetPacker struct {
version protocol.VersionNumber
cryptoSetup sealingManager
// Once both Initial and Handshake keys are dropped, we only send 1-RTT packets.
droppedInitial bool
droppedHandshake bool
initialStream cryptoStream
handshakeStream cryptoStream
@ -188,10 +185,6 @@ func newPacketPacker(
}
}
func (p *packetPacker) handshakeConfirmed() bool {
return p.droppedInitial && p.droppedHandshake
}
// PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) {
payload := payload{
@ -225,10 +218,10 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
return p.writeAndSealPacket(hdr, payload, encLevel, sealer)
}
func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) {
var encLevel protocol.EncryptionLevel
var ack *wire.AckFrame
if !p.handshakeConfirmed() {
if !handshakeConfirmed {
ack = p.acks.GetAckFrame(protocol.EncryptionInitial)
if ack != nil {
encLevel = protocol.EncryptionInitial
@ -261,38 +254,33 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
return p.writeAndSealPacket(hdr, payload, encLevel, sealer)
}
// PackPacket packs a new packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
// PackPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed.
func (p *packetPacker) PackPacket() (*packedPacket, error) {
if !p.handshakeConfirmed() {
packet, err := p.maybePackCryptoPacket()
if err != nil {
return nil, err
}
if packet != nil {
return packet, nil
if err != nil || packet != nil {
return packet, err
}
return p.maybePackAppDataPacket()
}
// PackAppDataPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) PackAppDataPacket() (*packedPacket, error) {
return p.maybePackAppDataPacket()
}
func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
// Try packing an Initial packet.
packet, err := p.maybePackInitialPacket()
if err == handshake.ErrKeysDropped {
p.droppedInitial = true
} else if err != nil || packet != nil {
if (err != nil && err != handshake.ErrKeysDropped) || packet != nil {
return packet, err
}
// No Initial was packed. Try packing a Handshake packet.
packet, err = p.maybePackHandshakePacket()
if err == handshake.ErrKeysDropped {
p.droppedHandshake = true
return nil, nil
}
if err == handshake.ErrKeysNotYetAvailable {
if err == handshake.ErrKeysDropped || err == handshake.ErrKeysNotYetAvailable {
return nil, nil
}
return packet, err

View file

@ -206,7 +206,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake)
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
p, err := packer.MaybePackAckPacket()
p, err := packer.MaybePackAckPacket(false)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
@ -218,7 +218,7 @@ var _ = Describe("Packet packer", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack)
p, err := packer.MaybePackAckPacket()
p, err := packer.MaybePackAckPacket(false)
Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake))
@ -230,10 +230,8 @@ var _ = Describe("Packet packer", func() {
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake)
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack)
p, err := packer.MaybePackAckPacket()
p, err := packer.MaybePackAckPacket(true)
Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
@ -276,8 +274,6 @@ var _ = Describe("Packet packer", func() {
Context("packing normal packets", func() {
BeforeEach(func() {
sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes()
sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil).AnyTimes()
initialStream.EXPECT().HasData().AnyTimes()
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes()
handshakeStream.EXPECT().HasData().AnyTimes()
@ -291,7 +287,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
framer.EXPECT().AppendControlFrames(nil, gomock.Any())
framer.EXPECT().AppendStreamFrames(nil, gomock.Any())
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred())
})
@ -307,7 +303,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad},
}
expectAppendStreamFrames(ackhandler.Frame{Frame: f})
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
b := &bytes.Buffer{}
@ -326,7 +322,7 @@ var _ = Describe("Packet packer", func() {
StreamID: 5,
Data: []byte("foobar"),
}})
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
})
@ -339,7 +335,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
expectAppendControlFrames()
expectAppendStreamFrames()
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.ack).To(Equal(ack))
@ -371,7 +367,7 @@ var _ = Describe("Packet packer", func() {
}
expectAppendControlFrames(frames...)
expectAppendStreamFrames()
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(Equal(frames))
@ -393,7 +389,7 @@ var _ = Describe("Packet packer", func() {
return fs, 0
}),
)
_, err := packer.PackPacket()
_, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
})
@ -409,7 +405,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames()
expectAppendStreamFrames(ackhandler.Frame{Frame: f})
packet, err := packer.PackPacket()
packet, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
// cut off the tag that the mock sealer added
packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()]
@ -458,7 +454,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames()
expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3})
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(3))
@ -476,7 +472,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
expectAppendControlFrames()
expectAppendStreamFrames()
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.ack).ToNot(BeNil())
@ -492,7 +488,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
expectAppendControlFrames()
expectAppendStreamFrames()
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(ContainElement(ackhandler.Frame{Frame: &wire.PingFrame{}}))
@ -503,7 +499,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
expectAppendControlFrames()
expectAppendStreamFrames()
p, err = packer.PackPacket()
p, err = packer.PackAppDataPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.ack).ToNot(BeNil())
@ -518,7 +514,7 @@ var _ = Describe("Packet packer", func() {
expectAppendControlFrames()
expectAppendStreamFrames()
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
// now add some frame to send
@ -529,7 +525,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack)
p, err = packer.PackPacket()
p, err = packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.ack).To(Equal(ack))
Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}}))
@ -543,7 +539,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendStreamFrames()
expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}})
p, err := packer.PackPacket()
p, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{}))
@ -561,7 +557,7 @@ var _ = Describe("Packet packer", func() {
return nil, 0
})
expectAppendStreamFrames()
_, err := packer.PackPacket()
_, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
// now reduce the maxPacketSize
packer.HandleTransportParameters(&handshake.TransportParameters{
@ -572,7 +568,7 @@ var _ = Describe("Packet packer", func() {
return nil, 0
})
expectAppendStreamFrames()
_, err = packer.PackPacket()
_, err = packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
})
@ -586,7 +582,7 @@ var _ = Describe("Packet packer", func() {
return nil, 0
})
expectAppendStreamFrames()
_, err := packer.PackPacket()
_, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
// now try to increase the maxPacketSize
packer.HandleTransportParameters(&handshake.TransportParameters{
@ -597,7 +593,7 @@ var _ = Describe("Packet packer", func() {
return nil, 0
})
expectAppendStreamFrames()
_, err = packer.PackPacket()
_, err = packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred())
})
})
@ -737,31 +733,6 @@ var _ = Describe("Packet packer", func() {
Expect(packet.ack).To(Equal(ack))
Expect(packet.frames).To(HaveLen(1))
})
It("stops packing crypto packets when the keys are dropped", func() {
sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped)
sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped)
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}})
expectAppendStreamFrames()
packet, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(packet).ToNot(BeNil())
// now the packer should have realized that the handshake is confirmed
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43))
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}})
expectAppendStreamFrames()
packet, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(packet).ToNot(BeNil())
})
})
Context("packing probe packets", func() {

View file

@ -164,6 +164,7 @@ type session struct {
earlySessionReadyChan chan struct{}
handshakeCompleteChan chan struct{} // is closed when the handshake completes
handshakeComplete bool
handshakeConfirmed bool
receivedRetry bool
receivedFirstPacket bool
@ -1139,6 +1140,9 @@ func (s *session) handleCloseError(closeErr closeError) {
}
func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
if encLevel == protocol.EncryptionHandshake {
s.handshakeConfirmed = true
}
s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel)
}
@ -1247,7 +1251,7 @@ sendLoop:
}
func (s *session) maybeSendAckOnlyPacket() error {
packet, err := s.packer.MaybePackAckPacket()
packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed)
if err != nil {
return err
}
@ -1305,7 +1309,13 @@ func (s *session) sendPacket() (bool, error) {
}
s.windowUpdateQueue.QueueAll()
packet, err := s.packer.PackPacket()
var packet *packedPacket
var err error
if !s.handshakeConfirmed {
packet, err = s.packer.PackPacket()
} else {
packet, err = s.packer.PackAppDataPacket()
}
if err != nil || packet == nil {
return false, err
}

View file

@ -905,7 +905,7 @@ var _ = Describe("Session", func() {
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAck)
sph.EXPECT().ShouldSendNumPackets().Return(1000)
packer.EXPECT().MaybePackAckPacket()
packer.EXPECT().MaybePackAckPacket(false)
sess.sentPacketHandler = sph
Expect(sess.sendPackets()).To(Succeed())
})