refactor packing of packets before and after the handshake is confirmed

This commit is contained in:
Marten Seemann 2020-02-10 14:15:50 +08:00
parent e01995041e
commit a4b4d52063
5 changed files with 70 additions and 86 deletions

View file

@ -49,18 +49,18 @@ func (mr *MockPackerMockRecorder) HandleTransportParameters(arg0 interface{}) *g
} }
// MaybePackAckPacket mocks base method // MaybePackAckPacket mocks base method
func (m *MockPacker) MaybePackAckPacket() (*packedPacket, error) { func (m *MockPacker) MaybePackAckPacket(arg0 bool) (*packedPacket, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MaybePackAckPacket") ret := m.ctrl.Call(m, "MaybePackAckPacket", arg0)
ret0, _ := ret[0].(*packedPacket) ret0, _ := ret[0].(*packedPacket)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// MaybePackAckPacket indicates an expected call of MaybePackAckPacket // MaybePackAckPacket indicates an expected call of MaybePackAckPacket
func (mr *MockPackerMockRecorder) MaybePackAckPacket() *gomock.Call { func (mr *MockPackerMockRecorder) MaybePackAckPacket(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket), arg0)
} }
// MaybePackProbePacket mocks base method // MaybePackProbePacket mocks base method
@ -78,6 +78,21 @@ func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0)
} }
// PackAppDataPacket mocks base method
func (m *MockPacker) PackAppDataPacket() (*packedPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PackAppDataPacket")
ret0, _ := ret[0].(*packedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PackAppDataPacket indicates an expected call of PackAppDataPacket
func (mr *MockPackerMockRecorder) PackAppDataPacket() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAppDataPacket", reflect.TypeOf((*MockPacker)(nil).PackAppDataPacket))
}
// PackConnectionClose mocks base method // PackConnectionClose mocks base method
func (m *MockPacker) PackConnectionClose(arg0 *wire.ConnectionCloseFrame) (*packedPacket, error) { func (m *MockPacker) PackConnectionClose(arg0 *wire.ConnectionCloseFrame) (*packedPacket, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -16,8 +16,9 @@ import (
type packer interface { type packer interface {
PackPacket() (*packedPacket, error) PackPacket() (*packedPacket, error)
PackAppDataPacket() (*packedPacket, error)
MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error)
MaybePackAckPacket() (*packedPacket, error) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error)
PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error) PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error)
HandleTransportParameters(*handshake.TransportParameters) HandleTransportParameters(*handshake.TransportParameters)
@ -138,10 +139,6 @@ type packetPacker struct {
version protocol.VersionNumber version protocol.VersionNumber
cryptoSetup sealingManager cryptoSetup sealingManager
// Once both Initial and Handshake keys are dropped, we only send 1-RTT packets.
droppedInitial bool
droppedHandshake bool
initialStream cryptoStream initialStream cryptoStream
handshakeStream cryptoStream handshakeStream cryptoStream
@ -188,10 +185,6 @@ func newPacketPacker(
} }
} }
func (p *packetPacker) handshakeConfirmed() bool {
return p.droppedInitial && p.droppedHandshake
}
// PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) {
payload := payload{ payload := payload{
@ -225,10 +218,10 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
return p.writeAndSealPacket(hdr, payload, encLevel, sealer) return p.writeAndSealPacket(hdr, payload, encLevel, sealer)
} }
func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) {
var encLevel protocol.EncryptionLevel var encLevel protocol.EncryptionLevel
var ack *wire.AckFrame var ack *wire.AckFrame
if !p.handshakeConfirmed() { if !handshakeConfirmed {
ack = p.acks.GetAckFrame(protocol.EncryptionInitial) ack = p.acks.GetAckFrame(protocol.EncryptionInitial)
if ack != nil { if ack != nil {
encLevel = protocol.EncryptionInitial encLevel = protocol.EncryptionInitial
@ -261,38 +254,33 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
return p.writeAndSealPacket(hdr, payload, encLevel, sealer) return p.writeAndSealPacket(hdr, payload, encLevel, sealer)
} }
// PackPacket packs a new packet // PackPacket packs a new packet.
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise // It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed.
func (p *packetPacker) PackPacket() (*packedPacket, error) { func (p *packetPacker) PackPacket() (*packedPacket, error) {
if !p.handshakeConfirmed() { packet, err := p.maybePackCryptoPacket()
packet, err := p.maybePackCryptoPacket() if err != nil || packet != nil {
if err != nil { return packet, err
return nil, err
}
if packet != nil {
return packet, nil
}
} }
return p.maybePackAppDataPacket()
}
// PackAppDataPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) PackAppDataPacket() (*packedPacket, error) {
return p.maybePackAppDataPacket() return p.maybePackAppDataPacket()
} }
func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
// Try packing an Initial packet. // Try packing an Initial packet.
packet, err := p.maybePackInitialPacket() packet, err := p.maybePackInitialPacket()
if err == handshake.ErrKeysDropped { if (err != nil && err != handshake.ErrKeysDropped) || packet != nil {
p.droppedInitial = true
} else if err != nil || packet != nil {
return packet, err return packet, err
} }
// No Initial was packed. Try packing a Handshake packet. // No Initial was packed. Try packing a Handshake packet.
packet, err = p.maybePackHandshakePacket() packet, err = p.maybePackHandshakePacket()
if err == handshake.ErrKeysDropped { if err == handshake.ErrKeysDropped || err == handshake.ErrKeysNotYetAvailable {
p.droppedHandshake = true
return nil, nil
}
if err == handshake.ErrKeysNotYetAvailable {
return nil, nil return nil, nil
} }
return packet, err return packet, err

View file

@ -206,7 +206,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake)
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
p, err := packer.MaybePackAckPacket() p, err := packer.MaybePackAckPacket(false)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
@ -218,7 +218,7 @@ var _ = Describe("Packet packer", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack)
p, err := packer.MaybePackAckPacket() p, err := packer.MaybePackAckPacket(false)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake))
@ -230,10 +230,8 @@ var _ = Describe("Packet packer", func() {
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake)
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack)
p, err := packer.MaybePackAckPacket() p, err := packer.MaybePackAckPacket(true)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
@ -276,8 +274,6 @@ var _ = Describe("Packet packer", func() {
Context("packing normal packets", func() { Context("packing normal packets", func() {
BeforeEach(func() { BeforeEach(func() {
sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes()
sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil).AnyTimes()
initialStream.EXPECT().HasData().AnyTimes() initialStream.EXPECT().HasData().AnyTimes()
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes()
handshakeStream.EXPECT().HasData().AnyTimes() handshakeStream.EXPECT().HasData().AnyTimes()
@ -291,7 +287,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
framer.EXPECT().AppendControlFrames(nil, gomock.Any()) framer.EXPECT().AppendControlFrames(nil, gomock.Any())
framer.EXPECT().AppendStreamFrames(nil, gomock.Any()) framer.EXPECT().AppendStreamFrames(nil, gomock.Any())
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -307,7 +303,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad}, Data: []byte{0xde, 0xca, 0xfb, 0xad},
} }
expectAppendStreamFrames(ackhandler.Frame{Frame: f}) expectAppendStreamFrames(ackhandler.Frame{Frame: f})
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
b := &bytes.Buffer{} b := &bytes.Buffer{}
@ -326,7 +322,7 @@ var _ = Describe("Packet packer", func() {
StreamID: 5, StreamID: 5,
Data: []byte("foobar"), Data: []byte("foobar"),
}}) }})
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
}) })
@ -339,7 +335,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
expectAppendControlFrames() expectAppendControlFrames()
expectAppendStreamFrames() expectAppendStreamFrames()
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.ack).To(Equal(ack)) Expect(p.ack).To(Equal(ack))
@ -371,7 +367,7 @@ var _ = Describe("Packet packer", func() {
} }
expectAppendControlFrames(frames...) expectAppendControlFrames(frames...)
expectAppendStreamFrames() expectAppendStreamFrames()
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(Equal(frames)) Expect(p.frames).To(Equal(frames))
@ -393,7 +389,7 @@ var _ = Describe("Packet packer", func() {
return fs, 0 return fs, 0
}), }),
) )
_, err := packer.PackPacket() _, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -409,7 +405,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames() expectAppendControlFrames()
expectAppendStreamFrames(ackhandler.Frame{Frame: f}) expectAppendStreamFrames(ackhandler.Frame{Frame: f})
packet, err := packer.PackPacket() packet, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// cut off the tag that the mock sealer added // cut off the tag that the mock sealer added
packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()] packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()]
@ -458,7 +454,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames() expectAppendControlFrames()
expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3})
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(3)) Expect(p.frames).To(HaveLen(3))
@ -476,7 +472,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
expectAppendControlFrames() expectAppendControlFrames()
expectAppendStreamFrames() expectAppendStreamFrames()
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.ack).ToNot(BeNil()) Expect(p.ack).ToNot(BeNil())
@ -492,7 +488,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
expectAppendControlFrames() expectAppendControlFrames()
expectAppendStreamFrames() expectAppendStreamFrames()
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(ContainElement(ackhandler.Frame{Frame: &wire.PingFrame{}})) Expect(p.frames).To(ContainElement(ackhandler.Frame{Frame: &wire.PingFrame{}}))
@ -503,7 +499,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
expectAppendControlFrames() expectAppendControlFrames()
expectAppendStreamFrames() expectAppendStreamFrames()
p, err = packer.PackPacket() p, err = packer.PackAppDataPacket()
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.ack).ToNot(BeNil()) Expect(p.ack).ToNot(BeNil())
@ -518,7 +514,7 @@ var _ = Describe("Packet packer", func() {
expectAppendControlFrames() expectAppendControlFrames()
expectAppendStreamFrames() expectAppendStreamFrames()
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
// now add some frame to send // now add some frame to send
@ -529,7 +525,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack)
p, err = packer.PackPacket() p, err = packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.ack).To(Equal(ack)) Expect(p.ack).To(Equal(ack))
Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}})) Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}}))
@ -543,7 +539,7 @@ var _ = Describe("Packet packer", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendStreamFrames() expectAppendStreamFrames()
expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}})
p, err := packer.PackPacket() p, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{}))
@ -561,7 +557,7 @@ var _ = Describe("Packet packer", func() {
return nil, 0 return nil, 0
}) })
expectAppendStreamFrames() expectAppendStreamFrames()
_, err := packer.PackPacket() _, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// now reduce the maxPacketSize // now reduce the maxPacketSize
packer.HandleTransportParameters(&handshake.TransportParameters{ packer.HandleTransportParameters(&handshake.TransportParameters{
@ -572,7 +568,7 @@ var _ = Describe("Packet packer", func() {
return nil, 0 return nil, 0
}) })
expectAppendStreamFrames() expectAppendStreamFrames()
_, err = packer.PackPacket() _, err = packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -586,7 +582,7 @@ var _ = Describe("Packet packer", func() {
return nil, 0 return nil, 0
}) })
expectAppendStreamFrames() expectAppendStreamFrames()
_, err := packer.PackPacket() _, err := packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// now try to increase the maxPacketSize // now try to increase the maxPacketSize
packer.HandleTransportParameters(&handshake.TransportParameters{ packer.HandleTransportParameters(&handshake.TransportParameters{
@ -597,7 +593,7 @@ var _ = Describe("Packet packer", func() {
return nil, 0 return nil, 0
}) })
expectAppendStreamFrames() expectAppendStreamFrames()
_, err = packer.PackPacket() _, err = packer.PackAppDataPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
}) })
@ -737,31 +733,6 @@ var _ = Describe("Packet packer", func() {
Expect(packet.ack).To(Equal(ack)) Expect(packet.ack).To(Equal(ack))
Expect(packet.frames).To(HaveLen(1)) Expect(packet.frames).To(HaveLen(1))
}) })
It("stops packing crypto packets when the keys are dropped", func() {
sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped)
sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped)
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}})
expectAppendStreamFrames()
packet, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(packet).ToNot(BeNil())
// now the packer should have realized that the handshake is confirmed
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43))
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}})
expectAppendStreamFrames()
packet, err = packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(packet).ToNot(BeNil())
})
}) })
Context("packing probe packets", func() { Context("packing probe packets", func() {

View file

@ -164,6 +164,7 @@ type session struct {
earlySessionReadyChan chan struct{} earlySessionReadyChan chan struct{}
handshakeCompleteChan chan struct{} // is closed when the handshake completes handshakeCompleteChan chan struct{} // is closed when the handshake completes
handshakeComplete bool handshakeComplete bool
handshakeConfirmed bool
receivedRetry bool receivedRetry bool
receivedFirstPacket bool receivedFirstPacket bool
@ -1139,6 +1140,9 @@ func (s *session) handleCloseError(closeErr closeError) {
} }
func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
if encLevel == protocol.EncryptionHandshake {
s.handshakeConfirmed = true
}
s.sentPacketHandler.DropPackets(encLevel) s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel)
} }
@ -1247,7 +1251,7 @@ sendLoop:
} }
func (s *session) maybeSendAckOnlyPacket() error { func (s *session) maybeSendAckOnlyPacket() error {
packet, err := s.packer.MaybePackAckPacket() packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed)
if err != nil { if err != nil {
return err return err
} }
@ -1305,7 +1309,13 @@ func (s *session) sendPacket() (bool, error) {
} }
s.windowUpdateQueue.QueueAll() s.windowUpdateQueue.QueueAll()
packet, err := s.packer.PackPacket() var packet *packedPacket
var err error
if !s.handshakeConfirmed {
packet, err = s.packer.PackPacket()
} else {
packet, err = s.packer.PackAppDataPacket()
}
if err != nil || packet == nil { if err != nil || packet == nil {
return false, err return false, err
} }

View file

@ -905,7 +905,7 @@ var _ = Describe("Session", func() {
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAck) sph.EXPECT().SendMode().Return(ackhandler.SendAck)
sph.EXPECT().ShouldSendNumPackets().Return(1000) sph.EXPECT().ShouldSendNumPackets().Return(1000)
packer.EXPECT().MaybePackAckPacket() packer.EXPECT().MaybePackAckPacket(false)
sess.sentPacketHandler = sph sess.sentPacketHandler = sph
Expect(sess.sendPackets()).To(Succeed()) Expect(sess.sendPackets()).To(Succeed())
}) })