mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 13:17:36 +03:00
Merge pull request #2811 from lucas-clemente/fix-first-key-update
allow the first key update immediately after handshake confirmation
This commit is contained in:
commit
145e7b10d0
7 changed files with 125 additions and 86 deletions
|
@ -653,7 +653,9 @@ func (h *cryptoSetup) dropInitialKeys() {
|
||||||
h.logger.Debugf("Dropping Initial keys.")
|
h.logger.Debugf("Dropping Initial keys.")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) DropHandshakeKeys() {
|
func (h *cryptoSetup) SetHandshakeConfirmed() {
|
||||||
|
h.aead.SetHandshakeConfirmed()
|
||||||
|
// drop Handshake keys
|
||||||
var dropped bool
|
var dropped bool
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
if h.handshakeOpener != nil {
|
if h.handshakeOpener != nil {
|
||||||
|
|
|
@ -77,7 +77,7 @@ type CryptoSetup interface {
|
||||||
|
|
||||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||||
SetLargest1RTTAcked(protocol.PacketNumber) error
|
SetLargest1RTTAcked(protocol.PacketNumber) error
|
||||||
DropHandshakeKeys()
|
SetHandshakeConfirmed()
|
||||||
ConnectionState() ConnectionState
|
ConnectionState() ConnectionState
|
||||||
|
|
||||||
GetInitialOpener() (LongHeaderOpener, error)
|
GetInitialOpener() (LongHeaderOpener, error)
|
||||||
|
|
|
@ -22,9 +22,10 @@ var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
|
||||||
type updatableAEAD struct {
|
type updatableAEAD struct {
|
||||||
suite *qtls.CipherSuiteTLS13
|
suite *qtls.CipherSuiteTLS13
|
||||||
|
|
||||||
keyPhase protocol.KeyPhase
|
keyPhase protocol.KeyPhase
|
||||||
largestAcked protocol.PacketNumber
|
largestAcked protocol.PacketNumber
|
||||||
firstPacketNumber protocol.PacketNumber
|
firstPacketNumber protocol.PacketNumber
|
||||||
|
handshakeConfirmed bool
|
||||||
|
|
||||||
keyUpdateInterval uint64
|
keyUpdateInterval uint64
|
||||||
invalidPacketLimit uint64
|
invalidPacketLimit uint64
|
||||||
|
@ -172,35 +173,24 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
|
||||||
}
|
}
|
||||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||||
if kp != a.keyPhase.Bit() {
|
if kp != a.keyPhase.Bit() {
|
||||||
var receivedWrongInitialKeyPhase bool
|
if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
||||||
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
if a.prevRcvAEAD == nil {
|
||||||
if a.keyPhase == 0 {
|
return nil, ErrKeysDropped
|
||||||
// This can only occur when the first packet received has key phase 1.
|
|
||||||
// This is an error, since the key phase starts at 0,
|
|
||||||
// and peers are only allowed to update keys after the handshake is confirmed.
|
|
||||||
// Proceed from here, and only return an error if decryption of the packet succeeds.
|
|
||||||
receivedWrongInitialKeyPhase = true
|
|
||||||
} else {
|
|
||||||
if a.prevRcvAEAD == nil {
|
|
||||||
return nil, ErrKeysDropped
|
|
||||||
}
|
|
||||||
// we updated the key, but the peer hasn't updated yet
|
|
||||||
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
|
||||||
if err != nil {
|
|
||||||
err = ErrDecryptionFailed
|
|
||||||
}
|
|
||||||
return dec, err
|
|
||||||
}
|
}
|
||||||
|
// we updated the key, but the peer hasn't updated yet
|
||||||
|
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||||
|
if err != nil {
|
||||||
|
err = ErrDecryptionFailed
|
||||||
|
}
|
||||||
|
return dec, err
|
||||||
}
|
}
|
||||||
// try opening the packet with the next key phase
|
// try opening the packet with the next key phase
|
||||||
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||||
if err == nil && receivedWrongInitialKeyPhase {
|
if err != nil {
|
||||||
return nil, qerr.NewError(qerr.KeyUpdateError, "wrong initial key phase")
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, ErrDecryptionFailed
|
return nil, ErrDecryptionFailed
|
||||||
}
|
}
|
||||||
// Opening succeeded. Check if the peer was allowed to update.
|
// Opening succeeded. Check if the peer was allowed to update.
|
||||||
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
||||||
return nil, qerr.NewError(qerr.KeyUpdateError, "keys updated too quickly")
|
return nil, qerr.NewError(qerr.KeyUpdateError, "keys updated too quickly")
|
||||||
}
|
}
|
||||||
a.rollKeys()
|
a.rollKeys()
|
||||||
|
@ -256,10 +246,20 @@ func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *updatableAEAD) SetHandshakeConfirmed() {
|
||||||
|
a.handshakeConfirmed = true
|
||||||
|
}
|
||||||
|
|
||||||
func (a *updatableAEAD) updateAllowed() bool {
|
func (a *updatableAEAD) updateAllowed() bool {
|
||||||
return a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
if !a.handshakeConfirmed {
|
||||||
a.largestAcked != protocol.InvalidPacketNumber &&
|
return false
|
||||||
a.largestAcked >= a.firstSentWithCurrentKey
|
}
|
||||||
|
// the first key update is allowed as soon as the handshake is confirmed
|
||||||
|
return a.keyPhase == 0 ||
|
||||||
|
// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
|
||||||
|
(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
||||||
|
a.largestAcked != protocol.InvalidPacketNumber &&
|
||||||
|
a.largestAcked >= a.firstSentWithCurrentKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
||||||
|
|
|
@ -215,11 +215,13 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
Expect(err).To(MatchError(ErrKeysDropped))
|
Expect(err).To(MatchError(ErrKeysDropped))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors when the peer starts with key phase 1", func() {
|
It("allows the first key update immediately", func() {
|
||||||
|
// receive a packet at key phase one, before having sent or received any packets at key phase 0
|
||||||
client.rollKeys()
|
client.rollKeys()
|
||||||
encrypted := client.Seal(nil, msg, 0x1337, ad)
|
encrypted1 := client.Seal(nil, msg, 0x1337, ad)
|
||||||
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
|
||||||
Expect(err).To(MatchError("KEY_UPDATE_ERROR: wrong initial key phase"))
|
_, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() {
|
It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() {
|
||||||
|
@ -231,14 +233,16 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors when the peer updates keys too frequently", func() {
|
It("errors when the peer updates keys too frequently", func() {
|
||||||
// receive the first packet at key phase zero
|
server.rollKeys()
|
||||||
|
client.rollKeys()
|
||||||
|
// receive the first packet at key phase one
|
||||||
encrypted0 := client.Seal(nil, msg, 0x42, ad)
|
encrypted0 := client.Seal(nil, msg, 0x42, ad)
|
||||||
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
|
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
// now receive a packet at key phase one, before having sent any packets
|
// now receive a packet at key phase two, before having sent any packets
|
||||||
client.rollKeys()
|
client.rollKeys()
|
||||||
encrypted1 := client.Seal(nil, msg, 0x42, ad)
|
encrypted1 := client.Seal(nil, msg, 0x42, ad)
|
||||||
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
|
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
|
||||||
Expect(err).To(MatchError("KEY_UPDATE_ERROR: keys updated too quickly"))
|
Expect(err).To(MatchError("KEY_UPDATE_ERROR: keys updated too quickly"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -249,25 +253,40 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
|
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
|
||||||
server.keyUpdateInterval = keyUpdateInterval
|
server.keyUpdateInterval = keyUpdateInterval
|
||||||
|
server.SetHandshakeConfirmed()
|
||||||
})
|
})
|
||||||
|
|
||||||
It("initiates a key update after sealing the maximum number of packets", func() {
|
It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
}
|
}
|
||||||
// no update allowed before receiving an acknowledgement for the current key phase
|
// the first update is allowed without receiving an acknowledgement
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
||||||
// receive an ACK for a packet sent in key phase 0
|
|
||||||
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
|
||||||
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
|
||||||
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() {
|
||||||
|
server.rollKeys()
|
||||||
|
client.rollKeys()
|
||||||
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
|
pn := protocol.PacketNumber(i)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
}
|
||||||
|
// no update allowed before receiving an acknowledgement for the current key phase
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
// receive an ACK for a packet sent in key phase 0
|
||||||
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
||||||
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
||||||
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
|
||||||
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
})
|
||||||
|
|
||||||
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
|
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
|
||||||
// First make sure that we update our keys.
|
// First make sure that we update our keys.
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
|
@ -275,14 +294,9 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
}
|
}
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
||||||
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
|
||||||
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
|
||||||
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
// Now that our keys are updated, send a packet using the new keys.
|
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
// Now that our keys are updated, send a packet using the new keys.
|
||||||
const nextPN = keyUpdateInterval + 1
|
const nextPN = keyUpdateInterval + 1
|
||||||
server.Seal(nil, msg, nextPN, ad)
|
server.Seal(nil, msg, nextPN, ad)
|
||||||
// We haven't decrypted any packet in the new key phase yet.
|
// We haven't decrypted any packet in the new key phase yet.
|
||||||
|
@ -297,7 +311,6 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
}
|
}
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
||||||
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
||||||
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -310,7 +323,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred())
|
Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("initiates a key update after opening the maximum number of packets", func() {
|
It("initiates a key update after opening the maximum number of packets, for the first update", func() {
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
@ -318,14 +331,30 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
|
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}
|
}
|
||||||
// no update allowed before receiving an acknowledgement for the current key phase
|
// the first update is allowed without receiving an acknowledgement
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
|
||||||
server.Seal(nil, msg, 1, ad)
|
|
||||||
Expect(server.SetLargestAcked(1)).To(Succeed())
|
|
||||||
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() {
|
||||||
|
server.rollKeys()
|
||||||
|
client.rollKeys()
|
||||||
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
|
pn := protocol.PacketNumber(i)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
encrypted := client.Seal(nil, msg, pn, ad)
|
||||||
|
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
// no update allowed before receiving an acknowledgement for the current key phase
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
server.Seal(nil, msg, 1, ad)
|
||||||
|
Expect(server.SetLargestAcked(1)).To(Succeed())
|
||||||
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
|
||||||
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
})
|
||||||
|
|
||||||
It("drops keys 3 PTOs after a key update", func() {
|
It("drops keys 3 PTOs after a key update", func() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
|
@ -415,6 +444,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops keys early when we initiate another key update within the 3 PTO period", func() {
|
It("drops keys early when we initiate another key update within the 3 PTO period", func() {
|
||||||
|
server.SetHandshakeConfirmed()
|
||||||
// send so many packets that we initiate the first key update
|
// send so many packets that we initiate the first key update
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
|
|
|
@ -76,18 +76,6 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropHandshakeKeys mocks base method
|
|
||||||
func (m *MockCryptoSetup) DropHandshakeKeys() {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "DropHandshakeKeys")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DropHandshakeKeys indicates an expected call of DropHandshakeKeys
|
|
||||||
func (mr *MockCryptoSetupMockRecorder) DropHandshakeKeys() *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropHandshakeKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DropHandshakeKeys))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get0RTTOpener mocks base method
|
// Get0RTTOpener mocks base method
|
||||||
func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) {
|
func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -249,6 +237,18 @@ func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetHandshakeConfirmed mocks base method
|
||||||
|
func (m *MockCryptoSetup) SetHandshakeConfirmed() {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetHandshakeConfirmed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed))
|
||||||
|
}
|
||||||
|
|
||||||
// SetLargest1RTTAcked mocks base method
|
// SetLargest1RTTAcked mocks base method
|
||||||
func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error {
|
func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
14
session.go
14
session.go
|
@ -52,7 +52,7 @@ type cryptoStreamHandler interface {
|
||||||
RunHandshake()
|
RunHandshake()
|
||||||
ChangeConnectionID(protocol.ConnectionID)
|
ChangeConnectionID(protocol.ConnectionID)
|
||||||
SetLargest1RTTAcked(protocol.PacketNumber) error
|
SetLargest1RTTAcked(protocol.PacketNumber) error
|
||||||
DropHandshakeKeys()
|
SetHandshakeConfirmed()
|
||||||
GetSessionTicket() ([]byte, error)
|
GetSessionTicket() ([]byte, error)
|
||||||
io.Closer
|
io.Closer
|
||||||
ConnectionState() handshake.ConnectionState
|
ConnectionState() handshake.ConnectionState
|
||||||
|
@ -688,6 +688,8 @@ func (s *session) handleHandshakeComplete() {
|
||||||
s.connIDGenerator.SetHandshakeComplete()
|
s.connIDGenerator.SetHandshakeComplete()
|
||||||
|
|
||||||
if s.perspective == protocol.PerspectiveServer {
|
if s.perspective == protocol.PerspectiveServer {
|
||||||
|
s.handshakeConfirmed = true
|
||||||
|
s.sentPacketHandler.SetHandshakeConfirmed()
|
||||||
ticket, err := s.cryptoStreamHandler.GetSessionTicket()
|
ticket, err := s.cryptoStreamHandler.GetSessionTicket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.closeLocal(err)
|
s.closeLocal(err)
|
||||||
|
@ -703,7 +705,7 @@ func (s *session) handleHandshakeComplete() {
|
||||||
s.closeLocal(err)
|
s.closeLocal(err)
|
||||||
}
|
}
|
||||||
s.queueControlFrame(&wire.NewTokenFrame{Token: token})
|
s.queueControlFrame(&wire.NewTokenFrame{Token: token})
|
||||||
s.cryptoStreamHandler.DropHandshakeKeys()
|
s.cryptoStreamHandler.SetHandshakeConfirmed()
|
||||||
s.queueControlFrame(&wire.HandshakeDoneFrame{})
|
s.queueControlFrame(&wire.HandshakeDoneFrame{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1238,7 +1240,9 @@ func (s *session) handleHandshakeDoneFrame() error {
|
||||||
if s.perspective == protocol.PerspectiveServer {
|
if s.perspective == protocol.PerspectiveServer {
|
||||||
return qerr.NewError(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame")
|
return qerr.NewError(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame")
|
||||||
}
|
}
|
||||||
s.cryptoStreamHandler.DropHandshakeKeys()
|
s.handshakeConfirmed = true
|
||||||
|
s.sentPacketHandler.SetHandshakeConfirmed()
|
||||||
|
s.cryptoStreamHandler.SetHandshakeConfirmed()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1347,10 +1351,6 @@ func (s *session) handleCloseError(closeErr closeError) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
|
func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
|
||||||
if encLevel == protocol.EncryptionHandshake {
|
|
||||||
s.handshakeConfirmed = true
|
|
||||||
s.sentPacketHandler.SetHandshakeConfirmed()
|
|
||||||
}
|
|
||||||
s.sentPacketHandler.DropPackets(encLevel)
|
s.sentPacketHandler.DropPackets(encLevel)
|
||||||
s.receivedPacketHandler.DropPackets(encLevel)
|
s.receivedPacketHandler.DropPackets(encLevel)
|
||||||
if s.tracer != nil {
|
if s.tracer != nil {
|
||||||
|
|
|
@ -1635,12 +1635,13 @@ var _ = Describe("Session", func() {
|
||||||
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
||||||
sph.EXPECT().TimeUntilSend().AnyTimes()
|
sph.EXPECT().TimeUntilSend().AnyTimes()
|
||||||
sph.EXPECT().SendMode().AnyTimes()
|
sph.EXPECT().SendMode().AnyTimes()
|
||||||
|
sph.EXPECT().SetHandshakeConfirmed()
|
||||||
sessionRunner.EXPECT().Retire(clientDestConnID)
|
sessionRunner.EXPECT().Retire(clientDestConnID)
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
<-finishHandshake
|
<-finishHandshake
|
||||||
cryptoSetup.EXPECT().RunHandshake()
|
cryptoSetup.EXPECT().RunHandshake()
|
||||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
cryptoSetup.EXPECT().SetHandshakeConfirmed()
|
||||||
cryptoSetup.EXPECT().GetSessionTicket()
|
cryptoSetup.EXPECT().GetSessionTicket()
|
||||||
close(sess.handshakeCompleteChan)
|
close(sess.handshakeCompleteChan)
|
||||||
sess.run()
|
sess.run()
|
||||||
|
@ -1670,7 +1671,7 @@ var _ = Describe("Session", func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
<-finishHandshake
|
<-finishHandshake
|
||||||
cryptoSetup.EXPECT().RunHandshake()
|
cryptoSetup.EXPECT().RunHandshake()
|
||||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
cryptoSetup.EXPECT().SetHandshakeConfirmed()
|
||||||
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
|
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
|
||||||
close(sess.handshakeCompleteChan)
|
close(sess.handshakeCompleteChan)
|
||||||
sess.run()
|
sess.run()
|
||||||
|
@ -1730,14 +1731,17 @@ var _ = Describe("Session", func() {
|
||||||
It("sends a HANDSHAKE_DONE frame when the handshake completes", func() {
|
It("sends a HANDSHAKE_DONE frame when the handshake completes", func() {
|
||||||
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
||||||
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
|
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
|
||||||
sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount)
|
|
||||||
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
|
||||||
sph.EXPECT().TimeUntilSend().AnyTimes()
|
sph.EXPECT().TimeUntilSend().AnyTimes()
|
||||||
sph.EXPECT().HasPacingBudget().Return(true)
|
sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
|
||||||
|
sph.EXPECT().SetHandshakeConfirmed()
|
||||||
|
sph.EXPECT().SentPacket(gomock.Any())
|
||||||
|
mconn.EXPECT().Write(gomock.Any())
|
||||||
|
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
sess.sentPacketHandler = sph
|
sess.sentPacketHandler = sph
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
sessionRunner.EXPECT().Retire(clientDestConnID)
|
sessionRunner.EXPECT().Retire(clientDestConnID)
|
||||||
packer.EXPECT().PackCoalescedPacket(gomock.Any()).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) {
|
packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) {
|
||||||
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
|
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
|
||||||
Expect(frames).ToNot(BeEmpty())
|
Expect(frames).ToNot(BeEmpty())
|
||||||
Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{}))
|
Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{}))
|
||||||
|
@ -1749,11 +1753,11 @@ var _ = Describe("Session", func() {
|
||||||
buffer: getPacketBuffer(),
|
buffer: getPacketBuffer(),
|
||||||
}, nil
|
}, nil
|
||||||
})
|
})
|
||||||
packer.EXPECT().PackCoalescedPacket(gomock.Any()).AnyTimes()
|
packer.EXPECT().PackPacket().AnyTimes()
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
cryptoSetup.EXPECT().RunHandshake()
|
cryptoSetup.EXPECT().RunHandshake()
|
||||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
cryptoSetup.EXPECT().SetHandshakeConfirmed()
|
||||||
cryptoSetup.EXPECT().GetSessionTicket()
|
cryptoSetup.EXPECT().GetSessionTicket()
|
||||||
mconn.EXPECT().Write(gomock.Any())
|
mconn.EXPECT().Write(gomock.Any())
|
||||||
close(sess.handshakeCompleteChan)
|
close(sess.handshakeCompleteChan)
|
||||||
|
@ -2027,7 +2031,7 @@ var _ = Describe("Session", func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||||
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
|
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
|
||||||
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1)
|
cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1)
|
||||||
close(sess.handshakeCompleteChan)
|
close(sess.handshakeCompleteChan)
|
||||||
err := sess.run()
|
err := sess.run()
|
||||||
nerr, ok := err.(net.Error)
|
nerr, ok := err.(net.Error)
|
||||||
|
@ -2271,7 +2275,10 @@ var _ = Describe("Client Session", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("handles HANDSHAKE_DONE frames", func() {
|
It("handles HANDSHAKE_DONE frames", func() {
|
||||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
||||||
|
sess.sentPacketHandler = sph
|
||||||
|
sph.EXPECT().SetHandshakeConfirmed()
|
||||||
|
cryptoSetup.EXPECT().SetHandshakeConfirmed()
|
||||||
Expect(sess.handleHandshakeDoneFrame()).To(Succeed())
|
Expect(sess.handleHandshakeDoneFrame()).To(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue