pass the time a packet was received to the short header opener

This saves us one time.Now() syscall per received 1-RTT packet.
This commit is contained in:
Marten Seemann 2019-08-11 13:33:02 +07:00
parent ee00cbe088
commit bb48c9b31a
9 changed files with 92 additions and 80 deletions

View file

@ -4,6 +4,7 @@ import (
"crypto/tls"
"errors"
"io"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
@ -34,7 +35,7 @@ type LongHeaderOpener interface {
// ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface {
headerDecryptor
Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
}
// LongHeaderSealer seals a long header packet

View file

@ -91,14 +91,14 @@ func newUpdatableAEAD(rttStats *congestion.RTTStats, logger utils.Logger) *updat
}
}
func (a *updatableAEAD) rollKeys() {
func (a *updatableAEAD) rollKeys(now time.Time) {
a.keyPhase++
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
a.numRcvdWithCurrentKey = 0
a.numSentWithCurrentKey = 0
a.prevRcvAEAD = a.rcvAEAD
a.prevRcvAEADExpiry = time.Now().Add(3 * a.rttStats.PTO())
a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO())
a.rcvAEAD = a.nextRcvAEAD
a.sendAEAD = a.nextSendAEAD
@ -142,8 +142,8 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret)
}
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
if a.prevRcvAEAD != nil && time.Now().After(a.prevRcvAEADExpiry) {
func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
if a.prevRcvAEAD != nil && rcvTime.After(a.prevRcvAEADExpiry) {
a.prevRcvAEAD = nil
a.prevRcvAEADExpiry = time.Time{}
}
@ -175,7 +175,7 @@ func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp proto
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
return nil, qerr.Error(qerr.ProtocolViolation, "keys updated too quickly")
}
a.rollKeys()
a.rollKeys(rcvTime)
a.logger.Debugf("Peer updated keys to %s", a.keyPhase)
a.firstRcvdWithCurrentKey = pn
return dec, err
@ -232,7 +232,7 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
if a.shouldInitiateKeyUpdate() {
a.rollKeys()
a.rollKeys(time.Now())
}
return a.keyPhase.Bit()
}

View file

@ -82,122 +82,123 @@ var _ = Describe("Updatable AEAD", func() {
It("encrypts and decrypts a message", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad)
opened, err := client.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad)
opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(opened).To(Equal(msg))
})
It("fails to open a message if the associated data is not the same", func() {
encrypted := client.Seal(nil, msg, 0x1337, ad)
_, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
Expect(err).To(MatchError(ErrDecryptionFailed))
})
It("fails to open a message if the packet number is not the same", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad)
_, err := client.Open(nil, encrypted, 0x42, protocol.KeyPhaseZero, ad)
_, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrDecryptionFailed))
})
Context("key updates", func() {
Context("receiving key updates", func() {
It("updates keys", func() {
now := time.Now()
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted0 := server.Seal(nil, msg, 0x1337, ad)
server.rollKeys()
server.rollKeys(now)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
encrypted1 := server.Seal(nil, msg, 0x1337, ad)
Expect(encrypted0).ToNot(Equal(encrypted1))
// expect opening to fail. The client didn't roll keys yet
_, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseZero, ad)
_, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrDecryptionFailed))
client.rollKeys()
decrypted, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseOne, ad)
client.rollKeys(now)
decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
})
It("updates the keys when receiving a packet with the next key phase", func() {
now := time.Now()
// receive the first packet at key phase zero
encrypted0 := client.Seal(nil, msg, 0x42, ad)
decrypted, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad)
decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
// send one packet at key phase zero
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_ = server.Seal(nil, msg, 0x1, ad)
// now received a message at key phase one
client.rollKeys()
client.rollKeys(now)
encrypted1 := client.Seal(nil, msg, 0x43, ad)
decrypted, err = server.Open(nil, encrypted1, 0x43, protocol.KeyPhaseOne, ad)
decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
It("opens a reordered packet with the old keys after an update", func() {
rttStats.UpdateRTT(time.Hour, 0, time.Time{}) // make sure the keys don't get dropped yet
now := time.Now()
encrypted01 := client.Seal(nil, msg, 0x42, ad)
encrypted02 := client.Seal(nil, msg, 0x43, ad)
// receive the first packet with key phase 0
_, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad)
_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// send one packet at key phase zero
_ = server.Seal(nil, msg, 0x1, ad)
// now receive a packet with key phase 1
client.rollKeys()
client.rollKeys(now)
encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad)
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// now receive a reordered packet with key phase 0
decrypted, err := server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad)
decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
It("drops keys 3 PTOs after a key update", func() {
rttStats.UpdateRTT(10*time.Millisecond, 0, time.Now())
now := time.Now()
rttStats.UpdateRTT(10*time.Millisecond, 0, now)
pto := rttStats.PTO()
Expect(pto).To(BeNumerically("<", 50*time.Millisecond))
encrypted01 := client.Seal(nil, msg, 0x42, ad)
encrypted02 := client.Seal(nil, msg, 0x43, ad)
// receive the first packet with key phase 0
_, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad)
_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// send one packet at key phase zero
_ = server.Seal(nil, msg, 0x1, ad)
// now receive a packet with key phase 1
client.rollKeys()
client.rollKeys(now)
encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad)
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// now receive a reordered packet with key phase 0
time.Sleep(3*pto + 5*time.Millisecond)
_, err = server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad)
_, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrKeysDropped))
})
It("errors when the peer starts with key phase 1", func() {
client.rollKeys()
client.rollKeys(time.Now())
encrypted := client.Seal(nil, msg, 0x1337, ad)
_, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseOne, ad)
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase"))
})
It("errors when the peer updates keys too frequently", func() {
// receive the first packet at key phase zero
encrypted0 := client.Seal(nil, msg, 0x42, ad)
_, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad)
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// now receive a packet at key phase one, before having sent any packets
client.rollKeys()
client.rollKeys(time.Now())
encrypted1 := client.Seal(nil, msg, 0x42, ad)
_, err = server.Open(nil, encrypted1, 0x42, protocol.KeyPhaseOne, ad)
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly"))
})
})
@ -227,7 +228,7 @@ var _ = Describe("Updatable AEAD", func() {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted := client.Seal(nil, msg, pn, ad)
_, err := server.Open(nil, encrypted, pn, protocol.KeyPhaseZero, ad)
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
}
// no update allowed before receiving an acknowledgement for the current key phase

View file

@ -6,6 +6,7 @@ package mocks
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
@ -47,16 +48,16 @@ func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 inte
}
// Open mocks base method
func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 protocol.KeyPhaseBit, arg4 []byte) ([]byte, error) {
func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 time.Time, arg3 protocol.PacketNumber, arg4 protocol.KeyPhaseBit, arg5 []byte) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4)
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open
func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5)
}

View file

@ -6,6 +6,7 @@ package quic
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
wire "github.com/lucas-clemente/quic-go/internal/wire"
@ -35,16 +36,16 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
}
// Unpack mocks base method
func (m *MockUnpacker) Unpack(arg0 *wire.Header, arg1 []byte) (*unpackedPacket, error) {
func (m *MockUnpacker) Unpack(arg0 *wire.Header, arg1 time.Time, arg2 []byte) (*unpackedPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Unpack", arg0, arg1)
ret := m.ctrl.Call(m, "Unpack", arg0, arg1, arg2)
ret0, _ := ret[0].(*unpackedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Unpack indicates an expected call of Unpack
func (mr *MockUnpackerMockRecorder) Unpack(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockUnpackerMockRecorder) Unpack(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), arg0, arg1, arg2)
}

View file

@ -3,6 +3,7 @@ package quic
import (
"bytes"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -39,7 +40,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
}
}
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader
var decrypted []byte
@ -73,7 +74,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, data)
extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, rcvTime, data)
if err != nil {
return nil, err
}
@ -109,7 +110,12 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene
return extHdr, decrypted, nil
}
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
func (u *packetUnpacker) unpackShortHeaderPacket(
opener handshake.ShortHeaderOpener,
hdr *wire.Header,
rcvTime time.Time,
data []byte,
) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpack(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker
@ -118,7 +124,7 @@ func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpe
return nil, nil, parseErr
}
extHdrLen := extHdr.GetLength(u.version)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
if err != nil {
return nil, nil, err
}

View file

@ -3,6 +3,7 @@ package quic
import (
"bytes"
"errors"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake"
@ -50,7 +51,7 @@ var _ = Describe("Packet Unpacker", func() {
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, err := unpacker.Unpack(hdr, data)
_, err := unpacker.Unpack(hdr, time.Now(), data)
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
})
@ -71,7 +72,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetInitialOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expect(packet.data).To(Equal([]byte("decrypted")))
@ -85,7 +86,7 @@ var _ = Describe("Packet Unpacker", func() {
}
hdr, hdrRaw := getHeader(extHdr)
cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrOpenerNotYetAvailable)
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrOpenerNotYetAvailable))
})
@ -106,7 +107,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError("test err"))
})
@ -127,7 +128,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
})
@ -142,8 +143,8 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
})
@ -159,8 +160,8 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
testErr := errors.New("decryption error")
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr)
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(testErr))
})
@ -202,12 +203,13 @@ var _ = Describe("Packet Unpacker", func() {
for i := 1; i <= 100; i++ {
data = append(data, uint8(i))
}
packet, err := unpacker.Unpack(hdr, data)
packet, err := unpacker.Unpack(hdr, time.Now(), data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
})
It("decodes the packet number", func() {
rcvTime := time.Now().Add(-time.Hour)
firstHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
@ -217,9 +219,9 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil)
opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw := getHeader(firstHdr)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...))
packet, err := unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
// the real packet number is 0x1338, but only the last byte is sent
@ -231,9 +233,9 @@ var _ = Describe("Packet Unpacker", func() {
}
// expect the call with the decoded packet number
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil)
opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw = getHeader(secondHdr)
packet, err = unpacker.Unpack(hdr, append(hdrRaw, payload...))
packet, err = unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
})

View file

@ -24,7 +24,7 @@ import (
)
type unpacker interface {
Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error)
Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
}
type streamGetter interface {
@ -588,7 +588,7 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /
return false
}
packet, err := s.unpacker.Unpack(hdr, p.data)
packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data)
if err != nil {
switch err {
case handshake.ErrKeysDropped:

View file

@ -508,7 +508,7 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1,
}
rcvTime := time.Now().Add(-10 * time.Second)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
packetNumber: 0x1337,
encryptionLevel: protocol.EncryptionInitial,
hdr: hdr,
@ -531,7 +531,7 @@ var _ = Describe("Session", func() {
rcvTime := time.Now().Add(-10 * time.Second)
buf := &bytes.Buffer{}
Expect((&wire.PingFrame{}).Write(buf, sess.version)).To(Succeed())
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
packetNumber: 0x1337,
encryptionLevel: protocol.Encryption1RTT,
hdr: hdr,
@ -546,7 +546,7 @@ var _ = Describe("Session", func() {
})
It("drops a packet when unpacking fails", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
@ -567,7 +567,7 @@ var _ = Describe("Session", func() {
})
It("closes the session when unpacking fails because the reserved bits were incorrect", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits)
streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
@ -590,7 +590,7 @@ var _ = Describe("Session", func() {
It("ignores packets when unpacking fails for any other reason", func() {
testErr := errors.New("test err")
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, testErr)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr)
streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
@ -612,7 +612,7 @@ var _ = Describe("Session", func() {
})
It("rejects packets with empty payload", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
hdr: &wire.ExtendedHeader{},
data: []byte{}, // no payload
}, nil)
@ -675,7 +675,7 @@ var _ = Describe("Session", func() {
Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
// Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.Encryption1RTT,
hdr: hdr1,
data: []byte{0}, // one PADDING frame
@ -698,7 +698,7 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 1,
}
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable)
packet := getPacket(hdr, nil)
Expect(sess.handlePacketImpl(packet)).To(BeFalse())
Expect(sess.undecryptablePackets).To(Equal([]*receivedPacket{packet}))
@ -706,7 +706,7 @@ var _ = Describe("Session", func() {
Context("updating the remote address", func() {
It("doesn't support connection migration", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.Encryption1RTT,
hdr: &wire.ExtendedHeader{},
data: []byte{0}, // one PADDING frame
@ -746,7 +746,7 @@ var _ = Describe("Session", func() {
It("cuts packets to the right length", func() {
hdrLen, packet := getPacketWithLength(sess.srcConnID, 456)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen + 456 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@ -758,7 +758,7 @@ var _ = Describe("Session", func() {
It("handles coalesced packets", func() {
hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@ -766,7 +766,7 @@ var _ = Describe("Session", func() {
}, nil
})
hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@ -781,8 +781,8 @@ var _ = Describe("Session", func() {
hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456)
hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123)
gomock.InOrder(
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable),
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable),
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@ -801,7 +801,7 @@ var _ = Describe("Session", func() {
wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
Expect(sess.srcConnID).ToNot(Equal(wrongConnID))
hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
@ -1563,7 +1563,7 @@ var _ = Describe("Client Session", func() {
It("changes the connection ID when receiving the first packet from the server", func() {
unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
return &unpackedPacket{
encryptionLevel: protocol.Encryption1RTT,
hdr: &wire.ExtendedHeader{Header: *hdr},
@ -1738,7 +1738,7 @@ var _ = Describe("Client Session", func() {
// Send one packet, which might change the connection ID.
packer.EXPECT().ChangeDestConnectionID(sess.srcConnID).MaxTimes(1)
// only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
hdr: hdr1,
data: []byte{0}, // one PADDING frame