Merge pull request #2811 from lucas-clemente/fix-first-key-update

allow the first key update immediately after handshake confirmation
This commit is contained in:
Marten Seemann 2020-10-06 20:23:19 +07:00 committed by GitHub
commit 145e7b10d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 125 additions and 86 deletions

View file

@ -653,7 +653,9 @@ func (h *cryptoSetup) dropInitialKeys() {
h.logger.Debugf("Dropping Initial keys.") h.logger.Debugf("Dropping Initial keys.")
} }
func (h *cryptoSetup) DropHandshakeKeys() { func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed()
// drop Handshake keys
var dropped bool var dropped bool
h.mutex.Lock() h.mutex.Lock()
if h.handshakeOpener != nil { if h.handshakeOpener != nil {

View file

@ -77,7 +77,7 @@ type CryptoSetup interface {
HandleMessage([]byte, protocol.EncryptionLevel) bool HandleMessage([]byte, protocol.EncryptionLevel) bool
SetLargest1RTTAcked(protocol.PacketNumber) error SetLargest1RTTAcked(protocol.PacketNumber) error
DropHandshakeKeys() SetHandshakeConfirmed()
ConnectionState() ConnectionState ConnectionState() ConnectionState
GetInitialOpener() (LongHeaderOpener, error) GetInitialOpener() (LongHeaderOpener, error)

View file

@ -22,9 +22,10 @@ var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
type updatableAEAD struct { type updatableAEAD struct {
suite *qtls.CipherSuiteTLS13 suite *qtls.CipherSuiteTLS13
keyPhase protocol.KeyPhase keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber largestAcked protocol.PacketNumber
firstPacketNumber protocol.PacketNumber firstPacketNumber protocol.PacketNumber
handshakeConfirmed bool
keyUpdateInterval uint64 keyUpdateInterval uint64
invalidPacketLimit uint64 invalidPacketLimit uint64
@ -172,35 +173,24 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
} }
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
if kp != a.keyPhase.Bit() { if kp != a.keyPhase.Bit() {
var receivedWrongInitialKeyPhase bool if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { if a.prevRcvAEAD == nil {
if a.keyPhase == 0 { return nil, ErrKeysDropped
// This can only occur when the first packet received has key phase 1.
// This is an error, since the key phase starts at 0,
// and peers are only allowed to update keys after the handshake is confirmed.
// Proceed from here, and only return an error if decryption of the packet succeeds.
receivedWrongInitialKeyPhase = true
} else {
if a.prevRcvAEAD == nil {
return nil, ErrKeysDropped
}
// we updated the key, but the peer hasn't updated yet
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
err = ErrDecryptionFailed
}
return dec, err
} }
// we updated the key, but the peer hasn't updated yet
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
err = ErrDecryptionFailed
}
return dec, err
} }
// try opening the packet with the next key phase // try opening the packet with the next key phase
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err == nil && receivedWrongInitialKeyPhase { if err != nil {
return nil, qerr.NewError(qerr.KeyUpdateError, "wrong initial key phase")
} else if err != nil {
return nil, ErrDecryptionFailed return nil, ErrDecryptionFailed
} }
// Opening succeeded. Check if the peer was allowed to update. // Opening succeeded. Check if the peer was allowed to update.
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
return nil, qerr.NewError(qerr.KeyUpdateError, "keys updated too quickly") return nil, qerr.NewError(qerr.KeyUpdateError, "keys updated too quickly")
} }
a.rollKeys() a.rollKeys()
@ -256,10 +246,20 @@ func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
return nil return nil
} }
func (a *updatableAEAD) SetHandshakeConfirmed() {
a.handshakeConfirmed = true
}
func (a *updatableAEAD) updateAllowed() bool { func (a *updatableAEAD) updateAllowed() bool {
return a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && if !a.handshakeConfirmed {
a.largestAcked != protocol.InvalidPacketNumber && return false
a.largestAcked >= a.firstSentWithCurrentKey }
// the first key update is allowed as soon as the handshake is confirmed
return a.keyPhase == 0 ||
// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
a.largestAcked != protocol.InvalidPacketNumber &&
a.largestAcked >= a.firstSentWithCurrentKey)
} }
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {

View file

@ -215,11 +215,13 @@ var _ = Describe("Updatable AEAD", func() {
Expect(err).To(MatchError(ErrKeysDropped)) Expect(err).To(MatchError(ErrKeysDropped))
}) })
It("errors when the peer starts with key phase 1", func() { It("allows the first key update immediately", func() {
// receive a packet at key phase one, before having sent or received any packets at key phase 0
client.rollKeys() client.rollKeys()
encrypted := client.Seal(nil, msg, 0x1337, ad) encrypted1 := client.Seal(nil, msg, 0x1337, ad)
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
Expect(err).To(MatchError("KEY_UPDATE_ERROR: wrong initial key phase")) _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
}) })
It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() {
@ -231,14 +233,16 @@ var _ = Describe("Updatable AEAD", func() {
}) })
It("errors when the peer updates keys too frequently", func() { It("errors when the peer updates keys too frequently", func() {
// receive the first packet at key phase zero server.rollKeys()
client.rollKeys()
// receive the first packet at key phase one
encrypted0 := client.Seal(nil, msg, 0x42, ad) encrypted0 := client.Seal(nil, msg, 0x42, ad)
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad) _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// now receive a packet at key phase one, before having sent any packets // now receive a packet at key phase two, before having sent any packets
client.rollKeys() client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x42, ad) encrypted1 := client.Seal(nil, msg, 0x42, ad)
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad) _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError("KEY_UPDATE_ERROR: keys updated too quickly")) Expect(err).To(MatchError("KEY_UPDATE_ERROR: keys updated too quickly"))
}) })
}) })
@ -249,25 +253,40 @@ var _ = Describe("Updatable AEAD", func() {
BeforeEach(func() { BeforeEach(func() {
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
server.keyUpdateInterval = keyUpdateInterval server.keyUpdateInterval = keyUpdateInterval
server.SetHandshakeConfirmed()
}) })
It("initiates a key update after sealing the maximum number of packets", func() { It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
} }
// no update allowed before receiving an acknowledgement for the current key phase // the first update is allowed without receiving an acknowledgement
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
// receive an ACK for a packet sent in key phase 0
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
}) })
It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() {
server.rollKeys()
client.rollKeys()
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
server.Seal(nil, msg, pn, ad)
}
// no update allowed before receiving an acknowledgement for the current key phase
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// receive an ACK for a packet sent in key phase 0
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad"))
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
})
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
// First make sure that we update our keys. // First make sure that we update our keys.
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < keyUpdateInterval; i++ {
@ -275,14 +294,9 @@ var _ = Describe("Updatable AEAD", func() {
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
} }
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
// Now that our keys are updated, send a packet using the new keys.
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// Now that our keys are updated, send a packet using the new keys.
const nextPN = keyUpdateInterval + 1 const nextPN = keyUpdateInterval + 1
server.Seal(nil, msg, nextPN, ad) server.Seal(nil, msg, nextPN, ad)
// We haven't decrypted any packet in the new key phase yet. // We haven't decrypted any packet in the new key phase yet.
@ -297,7 +311,6 @@ var _ = Describe("Updatable AEAD", func() {
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
} }
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -310,7 +323,7 @@ var _ = Describe("Updatable AEAD", func() {
Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred())
}) })
It("initiates a key update after opening the maximum number of packets", func() { It("initiates a key update after opening the maximum number of packets, for the first update", func() {
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
@ -318,14 +331,30 @@ var _ = Describe("Updatable AEAD", func() {
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
// no update allowed before receiving an acknowledgement for the current key phase // the first update is allowed without receiving an acknowledgement
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, 1, ad)
Expect(server.SetLargestAcked(1)).To(Succeed())
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
}) })
It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() {
server.rollKeys()
client.rollKeys()
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
encrypted := client.Seal(nil, msg, pn, ad)
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
}
// no update allowed before receiving an acknowledgement for the current key phase
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
server.Seal(nil, msg, 1, ad)
Expect(server.SetLargestAcked(1)).To(Succeed())
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
})
It("drops keys 3 PTOs after a key update", func() { It("drops keys 3 PTOs after a key update", func() {
now := time.Now() now := time.Now()
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < keyUpdateInterval; i++ {
@ -415,6 +444,7 @@ var _ = Describe("Updatable AEAD", func() {
}) })
It("drops keys early when we initiate another key update within the 3 PTO period", func() { It("drops keys early when we initiate another key update within the 3 PTO period", func() {
server.SetHandshakeConfirmed()
// send so many packets that we initiate the first key update // send so many packets that we initiate the first key update
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)

View file

@ -76,18 +76,6 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
} }
// DropHandshakeKeys mocks base method
func (m *MockCryptoSetup) DropHandshakeKeys() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DropHandshakeKeys")
}
// DropHandshakeKeys indicates an expected call of DropHandshakeKeys
func (mr *MockCryptoSetupMockRecorder) DropHandshakeKeys() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropHandshakeKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DropHandshakeKeys))
}
// Get0RTTOpener mocks base method // Get0RTTOpener mocks base method
func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -249,6 +237,18 @@ func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
} }
// SetHandshakeConfirmed mocks base method
func (m *MockCryptoSetup) SetHandshakeConfirmed() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetHandshakeConfirmed")
}
// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed
func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed))
}
// SetLargest1RTTAcked mocks base method // SetLargest1RTTAcked mocks base method
func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error { func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -52,7 +52,7 @@ type cryptoStreamHandler interface {
RunHandshake() RunHandshake()
ChangeConnectionID(protocol.ConnectionID) ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber) error SetLargest1RTTAcked(protocol.PacketNumber) error
DropHandshakeKeys() SetHandshakeConfirmed()
GetSessionTicket() ([]byte, error) GetSessionTicket() ([]byte, error)
io.Closer io.Closer
ConnectionState() handshake.ConnectionState ConnectionState() handshake.ConnectionState
@ -688,6 +688,8 @@ func (s *session) handleHandshakeComplete() {
s.connIDGenerator.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete()
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
ticket, err := s.cryptoStreamHandler.GetSessionTicket() ticket, err := s.cryptoStreamHandler.GetSessionTicket()
if err != nil { if err != nil {
s.closeLocal(err) s.closeLocal(err)
@ -703,7 +705,7 @@ func (s *session) handleHandshakeComplete() {
s.closeLocal(err) s.closeLocal(err)
} }
s.queueControlFrame(&wire.NewTokenFrame{Token: token}) s.queueControlFrame(&wire.NewTokenFrame{Token: token})
s.cryptoStreamHandler.DropHandshakeKeys() s.cryptoStreamHandler.SetHandshakeConfirmed()
s.queueControlFrame(&wire.HandshakeDoneFrame{}) s.queueControlFrame(&wire.HandshakeDoneFrame{})
} }
} }
@ -1238,7 +1240,9 @@ func (s *session) handleHandshakeDoneFrame() error {
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
return qerr.NewError(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame") return qerr.NewError(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame")
} }
s.cryptoStreamHandler.DropHandshakeKeys() s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
s.cryptoStreamHandler.SetHandshakeConfirmed()
return nil return nil
} }
@ -1347,10 +1351,6 @@ func (s *session) handleCloseError(closeErr closeError) {
} }
func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
if encLevel == protocol.EncryptionHandshake {
s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
}
s.sentPacketHandler.DropPackets(encLevel) s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel)
if s.tracer != nil { if s.tracer != nil {

View file

@ -1635,12 +1635,13 @@ var _ = Describe("Session", func() {
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode().AnyTimes() sph.EXPECT().SendMode().AnyTimes()
sph.EXPECT().SetHandshakeConfirmed()
sessionRunner.EXPECT().Retire(clientDestConnID) sessionRunner.EXPECT().Retire(clientDestConnID)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
<-finishHandshake <-finishHandshake
cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys() cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket() cryptoSetup.EXPECT().GetSessionTicket()
close(sess.handshakeCompleteChan) close(sess.handshakeCompleteChan)
sess.run() sess.run()
@ -1670,7 +1671,7 @@ var _ = Describe("Session", func() {
defer GinkgoRecover() defer GinkgoRecover()
<-finishHandshake <-finishHandshake
cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys() cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
close(sess.handshakeCompleteChan) close(sess.handshakeCompleteChan)
sess.run() sess.run()
@ -1730,14 +1731,17 @@ var _ = Describe("Session", func() {
It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().HasPacingBudget().Return(true) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
sph.EXPECT().SetHandshakeConfirmed()
sph.EXPECT().SentPacket(gomock.Any())
mconn.EXPECT().Write(gomock.Any())
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sess.sentPacketHandler = sph sess.sentPacketHandler = sph
done := make(chan struct{}) done := make(chan struct{})
sessionRunner.EXPECT().Retire(clientDestConnID) sessionRunner.EXPECT().Retire(clientDestConnID)
packer.EXPECT().PackCoalescedPacket(gomock.Any()).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) { packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) {
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
Expect(frames).ToNot(BeEmpty()) Expect(frames).ToNot(BeEmpty())
Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{}))
@ -1749,11 +1753,11 @@ var _ = Describe("Session", func() {
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
}, nil }, nil
}) })
packer.EXPECT().PackCoalescedPacket(gomock.Any()).AnyTimes() packer.EXPECT().PackPacket().AnyTimes()
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys() cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket() cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().Write(gomock.Any()) mconn.EXPECT().Write(gomock.Any())
close(sess.handshakeCompleteChan) close(sess.handshakeCompleteChan)
@ -2027,7 +2031,7 @@ var _ = Describe("Session", func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1)
close(sess.handshakeCompleteChan) close(sess.handshakeCompleteChan)
err := sess.run() err := sess.run()
nerr, ok := err.(net.Error) nerr, ok := err.(net.Error)
@ -2271,7 +2275,10 @@ var _ = Describe("Client Session", func() {
}) })
It("handles HANDSHAKE_DONE frames", func() { It("handles HANDSHAKE_DONE frames", func() {
cryptoSetup.EXPECT().DropHandshakeKeys() sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
sph.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().SetHandshakeConfirmed()
Expect(sess.handleHandshakeDoneFrame()).To(Succeed()) Expect(sess.handleHandshakeDoneFrame()).To(Succeed())
}) })