pass the time a packet was received to the short header opener

This saves us one time.Now() syscall per received 1-RTT packet.
This commit is contained in:
Marten Seemann 2019-08-11 13:33:02 +07:00
parent ee00cbe088
commit bb48c9b31a
9 changed files with 92 additions and 80 deletions

View file

@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"io" "io"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls" "github.com/marten-seemann/qtls"
@ -34,7 +35,7 @@ type LongHeaderOpener interface {
// ShortHeaderOpener opens a short header packet // ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface { type ShortHeaderOpener interface {
headerDecryptor headerDecryptor
Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
} }
// LongHeaderSealer seals a long header packet // LongHeaderSealer seals a long header packet

View file

@ -91,14 +91,14 @@ func newUpdatableAEAD(rttStats *congestion.RTTStats, logger utils.Logger) *updat
} }
} }
func (a *updatableAEAD) rollKeys() { func (a *updatableAEAD) rollKeys(now time.Time) {
a.keyPhase++ a.keyPhase++
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
a.numRcvdWithCurrentKey = 0 a.numRcvdWithCurrentKey = 0
a.numSentWithCurrentKey = 0 a.numSentWithCurrentKey = 0
a.prevRcvAEAD = a.rcvAEAD a.prevRcvAEAD = a.rcvAEAD
a.prevRcvAEADExpiry = time.Now().Add(3 * a.rttStats.PTO()) a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO())
a.rcvAEAD = a.nextRcvAEAD a.rcvAEAD = a.nextRcvAEAD
a.sendAEAD = a.nextSendAEAD a.sendAEAD = a.nextSendAEAD
@ -142,8 +142,8 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret) a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret)
} }
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
if a.prevRcvAEAD != nil && time.Now().After(a.prevRcvAEADExpiry) { if a.prevRcvAEAD != nil && rcvTime.After(a.prevRcvAEADExpiry) {
a.prevRcvAEAD = nil a.prevRcvAEAD = nil
a.prevRcvAEADExpiry = time.Time{} a.prevRcvAEADExpiry = time.Time{}
} }
@ -175,7 +175,7 @@ func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp proto
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
return nil, qerr.Error(qerr.ProtocolViolation, "keys updated too quickly") return nil, qerr.Error(qerr.ProtocolViolation, "keys updated too quickly")
} }
a.rollKeys() a.rollKeys(rcvTime)
a.logger.Debugf("Peer updated keys to %s", a.keyPhase) a.logger.Debugf("Peer updated keys to %s", a.keyPhase)
a.firstRcvdWithCurrentKey = pn a.firstRcvdWithCurrentKey = pn
return dec, err return dec, err
@ -232,7 +232,7 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
if a.shouldInitiateKeyUpdate() { if a.shouldInitiateKeyUpdate() {
a.rollKeys() a.rollKeys(time.Now())
} }
return a.keyPhase.Bit() return a.keyPhase.Bit()
} }

View file

@ -82,122 +82,123 @@ var _ = Describe("Updatable AEAD", func() {
It("encrypts and decrypts a message", func() { It("encrypts and decrypts a message", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad) encrypted := server.Seal(nil, msg, 0x1337, ad)
opened, err := client.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad) opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(opened).To(Equal(msg)) Expect(opened).To(Equal(msg))
}) })
It("fails to open a message if the associated data is not the same", func() { It("fails to open a message if the associated data is not the same", func() {
encrypted := client.Seal(nil, msg, 0x1337, ad) encrypted := client.Seal(nil, msg, 0x1337, ad)
_, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
Expect(err).To(MatchError(ErrDecryptionFailed)) Expect(err).To(MatchError(ErrDecryptionFailed))
}) })
It("fails to open a message if the packet number is not the same", func() { It("fails to open a message if the packet number is not the same", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad) encrypted := server.Seal(nil, msg, 0x1337, ad)
_, err := client.Open(nil, encrypted, 0x42, protocol.KeyPhaseZero, ad) _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrDecryptionFailed)) Expect(err).To(MatchError(ErrDecryptionFailed))
}) })
Context("key updates", func() { Context("key updates", func() {
Context("receiving key updates", func() { Context("receiving key updates", func() {
It("updates keys", func() { It("updates keys", func() {
now := time.Now()
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted0 := server.Seal(nil, msg, 0x1337, ad) encrypted0 := server.Seal(nil, msg, 0x1337, ad)
server.rollKeys() server.rollKeys(now)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
encrypted1 := server.Seal(nil, msg, 0x1337, ad) encrypted1 := server.Seal(nil, msg, 0x1337, ad)
Expect(encrypted0).ToNot(Equal(encrypted1)) Expect(encrypted0).ToNot(Equal(encrypted1))
// expect opening to fail. The client didn't roll keys yet // expect opening to fail. The client didn't roll keys yet
_, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseZero, ad) _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrDecryptionFailed)) Expect(err).To(MatchError(ErrDecryptionFailed))
client.rollKeys() client.rollKeys(now)
decrypted, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseOne, ad) decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg)) Expect(decrypted).To(Equal(msg))
}) })
It("updates the keys when receiving a packet with the next key phase", func() { It("updates the keys when receiving a packet with the next key phase", func() {
now := time.Now()
// receive the first packet at key phase zero // receive the first packet at key phase zero
encrypted0 := client.Seal(nil, msg, 0x42, ad) encrypted0 := client.Seal(nil, msg, 0x42, ad)
decrypted, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad) decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg)) Expect(decrypted).To(Equal(msg))
// send one packet at key phase zero // send one packet at key phase zero
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_ = server.Seal(nil, msg, 0x1, ad) _ = server.Seal(nil, msg, 0x1, ad)
// now received a message at key phase one // now received a message at key phase one
client.rollKeys() client.rollKeys(now)
encrypted1 := client.Seal(nil, msg, 0x43, ad) encrypted1 := client.Seal(nil, msg, 0x43, ad)
decrypted, err = server.Open(nil, encrypted1, 0x43, protocol.KeyPhaseOne, ad) decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg)) Expect(decrypted).To(Equal(msg))
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
}) })
It("opens a reordered packet with the old keys after an update", func() { It("opens a reordered packet with the old keys after an update", func() {
rttStats.UpdateRTT(time.Hour, 0, time.Time{}) // make sure the keys don't get dropped yet now := time.Now()
encrypted01 := client.Seal(nil, msg, 0x42, ad) encrypted01 := client.Seal(nil, msg, 0x42, ad)
encrypted02 := client.Seal(nil, msg, 0x43, ad) encrypted02 := client.Seal(nil, msg, 0x43, ad)
// receive the first packet with key phase 0 // receive the first packet with key phase 0
_, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad) _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// send one packet at key phase zero // send one packet at key phase zero
_ = server.Seal(nil, msg, 0x1, ad) _ = server.Seal(nil, msg, 0x1, ad)
// now receive a packet with key phase 1 // now receive a packet with key phase 1
client.rollKeys() client.rollKeys(now)
encrypted1 := client.Seal(nil, msg, 0x44, ad) encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// now receive a reordered packet with key phase 0 // now receive a reordered packet with key phase 0
decrypted, err := server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad) decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg)) Expect(decrypted).To(Equal(msg))
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
}) })
It("drops keys 3 PTOs after a key update", func() { It("drops keys 3 PTOs after a key update", func() {
rttStats.UpdateRTT(10*time.Millisecond, 0, time.Now()) now := time.Now()
rttStats.UpdateRTT(10*time.Millisecond, 0, now)
pto := rttStats.PTO() pto := rttStats.PTO()
Expect(pto).To(BeNumerically("<", 50*time.Millisecond))
encrypted01 := client.Seal(nil, msg, 0x42, ad) encrypted01 := client.Seal(nil, msg, 0x42, ad)
encrypted02 := client.Seal(nil, msg, 0x43, ad) encrypted02 := client.Seal(nil, msg, 0x43, ad)
// receive the first packet with key phase 0 // receive the first packet with key phase 0
_, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad) _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// send one packet at key phase zero // send one packet at key phase zero
_ = server.Seal(nil, msg, 0x1, ad) _ = server.Seal(nil, msg, 0x1, ad)
// now receive a packet with key phase 1 // now receive a packet with key phase 1
client.rollKeys() client.rollKeys(now)
encrypted1 := client.Seal(nil, msg, 0x44, ad) encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// now receive a reordered packet with key phase 0 // now receive a reordered packet with key phase 0
time.Sleep(3*pto + 5*time.Millisecond) _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad)
_, err = server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrKeysDropped)) Expect(err).To(MatchError(ErrKeysDropped))
}) })
It("errors when the peer starts with key phase 1", func() { It("errors when the peer starts with key phase 1", func() {
client.rollKeys() client.rollKeys(time.Now())
encrypted := client.Seal(nil, msg, 0x1337, ad) encrypted := client.Seal(nil, msg, 0x1337, ad)
_, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseOne, ad) _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase")) Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase"))
}) })
It("errors when the peer updates keys too frequently", func() { It("errors when the peer updates keys too frequently", func() {
// receive the first packet at key phase zero // receive the first packet at key phase zero
encrypted0 := client.Seal(nil, msg, 0x42, ad) encrypted0 := client.Seal(nil, msg, 0x42, ad)
_, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad) _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// now receive a packet at key phase one, before having sent any packets // now receive a packet at key phase one, before having sent any packets
client.rollKeys() client.rollKeys(time.Now())
encrypted1 := client.Seal(nil, msg, 0x42, ad) encrypted1 := client.Seal(nil, msg, 0x42, ad)
_, err = server.Open(nil, encrypted1, 0x42, protocol.KeyPhaseOne, ad) _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly")) Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly"))
}) })
}) })
@ -227,7 +228,7 @@ var _ = Describe("Updatable AEAD", func() {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted := client.Seal(nil, msg, pn, ad) encrypted := client.Seal(nil, msg, pn, ad)
_, err := server.Open(nil, encrypted, pn, protocol.KeyPhaseZero, ad) _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
// no update allowed before receiving an acknowledgement for the current key phase // no update allowed before receiving an acknowledgement for the current key phase

View file

@ -6,6 +6,7 @@ package mocks
import ( import (
reflect "reflect" reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol" protocol "github.com/lucas-clemente/quic-go/internal/protocol"
@ -47,16 +48,16 @@ func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 inte
} }
// Open mocks base method // Open mocks base method
func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 protocol.KeyPhaseBit, arg4 []byte) ([]byte, error) { func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 time.Time, arg3 protocol.PacketNumber, arg4 protocol.KeyPhaseBit, arg5 []byte) ([]byte, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4) ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].([]byte) ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// Open indicates an expected call of Open // Open indicates an expected call of Open
func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5)
} }

View file

@ -6,6 +6,7 @@ package quic
import ( import (
reflect "reflect" reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
wire "github.com/lucas-clemente/quic-go/internal/wire" wire "github.com/lucas-clemente/quic-go/internal/wire"
@ -35,16 +36,16 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
} }
// Unpack mocks base method // Unpack mocks base method
func (m *MockUnpacker) Unpack(arg0 *wire.Header, arg1 []byte) (*unpackedPacket, error) { func (m *MockUnpacker) Unpack(arg0 *wire.Header, arg1 time.Time, arg2 []byte) (*unpackedPacket, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Unpack", arg0, arg1) ret := m.ctrl.Call(m, "Unpack", arg0, arg1, arg2)
ret0, _ := ret[0].(*unpackedPacket) ret0, _ := ret[0].(*unpackedPacket)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// Unpack indicates an expected call of Unpack // Unpack indicates an expected call of Unpack
func (mr *MockUnpackerMockRecorder) Unpack(arg0, arg1 interface{}) *gomock.Call { func (mr *MockUnpackerMockRecorder) Unpack(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockUnpacker)(nil).Unpack), arg0, arg1, arg2)
} }

View file

@ -3,6 +3,7 @@ package quic
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
@ -39,7 +40,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
} }
} }
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader var extHdr *wire.ExtendedHeader
var decrypted []byte var decrypted []byte
@ -73,7 +74,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
if err != nil { if err != nil {
return nil, err return nil, err
} }
extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, data) extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, rcvTime, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -109,7 +110,12 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene
return extHdr, decrypted, nil return extHdr, decrypted, nil
} }
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { func (u *packetUnpacker) unpackShortHeaderPacket(
opener handshake.ShortHeaderOpener,
hdr *wire.Header,
rcvTime time.Time,
data []byte,
) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpack(opener, hdr, data) extHdr, parseErr := u.unpack(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking. // If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker // This avoids a timing side-channel, which otherwise might allow an attacker
@ -118,7 +124,7 @@ func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpe
return nil, nil, parseErr return nil, nil, parseErr
} }
extHdrLen := extHdr.GetLength(u.version) extHdrLen := extHdr.GetLength(u.version)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -3,6 +3,7 @@ package quic
import ( import (
"bytes" "bytes"
"errors" "errors"
"time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
@ -50,7 +51,7 @@ var _ = Describe("Packet Unpacker", func() {
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, err := unpacker.Unpack(hdr, data) _, err := unpacker.Unpack(hdr, time.Now(), data)
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
}) })
@ -71,7 +72,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetInitialOpener().Return(opener, nil) cs.EXPECT().GetInitialOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil) opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expect(packet.data).To(Equal([]byte("decrypted"))) Expect(packet.data).To(Equal([]byte("decrypted")))
@ -85,7 +86,7 @@ var _ = Describe("Packet Unpacker", func() {
} }
hdr, hdrRaw := getHeader(extHdr) hdr, hdrRaw := getHeader(extHdr)
cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrOpenerNotYetAvailable) cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrOpenerNotYetAvailable)
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrOpenerNotYetAvailable)) Expect(err).To(MatchError(handshake.ErrOpenerNotYetAvailable))
}) })
@ -106,7 +107,7 @@ var _ = Describe("Packet Unpacker", func() {
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError("test err")) Expect(err).To(MatchError("test err"))
}) })
@ -127,7 +128,7 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().GetHandshakeOpener().Return(opener, nil) cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
}) })
@ -142,8 +143,8 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
}) })
@ -159,8 +160,8 @@ var _ = Describe("Packet Unpacker", func() {
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
testErr := errors.New("decryption error") testErr := errors.New("decryption error")
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr)
_, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -202,12 +203,13 @@ var _ = Describe("Packet Unpacker", func() {
for i := 1; i <= 100; i++ { for i := 1; i <= 100; i++ {
data = append(data, uint8(i)) data = append(data, uint8(i))
} }
packet, err := unpacker.Unpack(hdr, data) packet, err := unpacker.Unpack(hdr, time.Now(), data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
}) })
It("decodes the packet number", func() { It("decodes the packet number", func() {
rcvTime := time.Now().Add(-time.Hour)
firstHdr := &wire.ExtendedHeader{ firstHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID}, Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337, PacketNumber: 0x1337,
@ -217,9 +219,9 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2) cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw := getHeader(firstHdr) hdr, hdrRaw := getHeader(firstHdr)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, payload...)) packet, err := unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
// the real packet number is 0x1338, but only the last byte is sent // the real packet number is 0x1338, but only the last byte is sent
@ -231,9 +233,9 @@ var _ = Describe("Packet Unpacker", func() {
} }
// expect the call with the decoded packet number // expect the call with the decoded packet number
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil) opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw = getHeader(secondHdr) hdr, hdrRaw = getHeader(secondHdr)
packet, err = unpacker.Unpack(hdr, append(hdrRaw, payload...)) packet, err = unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338))) Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
}) })

View file

@ -24,7 +24,7 @@ import (
) )
type unpacker interface { type unpacker interface {
Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
} }
type streamGetter interface { type streamGetter interface {
@ -588,7 +588,7 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /
return false return false
} }
packet, err := s.unpacker.Unpack(hdr, p.data) packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data)
if err != nil { if err != nil {
switch err { switch err {
case handshake.ErrKeysDropped: case handshake.ErrKeysDropped:

View file

@ -508,7 +508,7 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
rcvTime := time.Now().Add(-10 * time.Second) rcvTime := time.Now().Add(-10 * time.Second)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
packetNumber: 0x1337, packetNumber: 0x1337,
encryptionLevel: protocol.EncryptionInitial, encryptionLevel: protocol.EncryptionInitial,
hdr: hdr, hdr: hdr,
@ -531,7 +531,7 @@ var _ = Describe("Session", func() {
rcvTime := time.Now().Add(-10 * time.Second) rcvTime := time.Now().Add(-10 * time.Second)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
Expect((&wire.PingFrame{}).Write(buf, sess.version)).To(Succeed()) Expect((&wire.PingFrame{}).Write(buf, sess.version)).To(Succeed())
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
packetNumber: 0x1337, packetNumber: 0x1337,
encryptionLevel: protocol.Encryption1RTT, encryptionLevel: protocol.Encryption1RTT,
hdr: hdr, hdr: hdr,
@ -546,7 +546,7 @@ var _ = Describe("Session", func() {
}) })
It("drops a packet when unpacking fails", func() { It("drops a packet when unpacking fails", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
@ -567,7 +567,7 @@ var _ = Describe("Session", func() {
}) })
It("closes the session when unpacking fails because the reserved bits were incorrect", func() { It("closes the session when unpacking fails because the reserved bits were incorrect", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
@ -590,7 +590,7 @@ var _ = Describe("Session", func() {
It("ignores packets when unpacking fails for any other reason", func() { It("ignores packets when unpacking fails for any other reason", func() {
testErr := errors.New("test err") testErr := errors.New("test err")
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, testErr) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr)
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
@ -612,7 +612,7 @@ var _ = Describe("Session", func() {
}) })
It("rejects packets with empty payload", func() { It("rejects packets with empty payload", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
hdr: &wire.ExtendedHeader{}, hdr: &wire.ExtendedHeader{},
data: []byte{}, // no payload data: []byte{}, // no payload
}, nil) }, nil)
@ -675,7 +675,7 @@ var _ = Describe("Session", func() {
Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
// Send one packet, which might change the connection ID. // Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker // only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.Encryption1RTT, encryptionLevel: protocol.Encryption1RTT,
hdr: hdr1, hdr: hdr1,
data: []byte{0}, // one PADDING frame data: []byte{0}, // one PADDING frame
@ -698,7 +698,7 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 1, PacketNumber: 1,
} }
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable)
packet := getPacket(hdr, nil) packet := getPacket(hdr, nil)
Expect(sess.handlePacketImpl(packet)).To(BeFalse()) Expect(sess.handlePacketImpl(packet)).To(BeFalse())
Expect(sess.undecryptablePackets).To(Equal([]*receivedPacket{packet})) Expect(sess.undecryptablePackets).To(Equal([]*receivedPacket{packet}))
@ -706,7 +706,7 @@ var _ = Describe("Session", func() {
Context("updating the remote address", func() { Context("updating the remote address", func() {
It("doesn't support connection migration", func() { It("doesn't support connection migration", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.Encryption1RTT, encryptionLevel: protocol.Encryption1RTT,
hdr: &wire.ExtendedHeader{}, hdr: &wire.ExtendedHeader{},
data: []byte{0}, // one PADDING frame data: []byte{0}, // one PADDING frame
@ -746,7 +746,7 @@ var _ = Describe("Session", func() {
It("cuts packets to the right length", func() { It("cuts packets to the right length", func() {
hdrLen, packet := getPacketWithLength(sess.srcConnID, 456) hdrLen, packet := getPacketWithLength(sess.srcConnID, 456)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen + 456 - 3)) Expect(data).To(HaveLen(hdrLen + 456 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -758,7 +758,7 @@ var _ = Describe("Session", func() {
It("handles coalesced packets", func() { It("handles coalesced packets", func() {
hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456) hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -766,7 +766,7 @@ var _ = Describe("Session", func() {
}, nil }, nil
}) })
hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123) hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -781,8 +781,8 @@ var _ = Describe("Session", func() {
hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456) hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456)
hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123) hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123)
gomock.InOrder( gomock.InOrder(
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable), unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable),
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) Expect(data).To(HaveLen(hdrLen2 + 123 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -801,7 +801,7 @@ var _ = Describe("Session", func() {
wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
Expect(sess.srcConnID).ToNot(Equal(wrongConnID)) Expect(sess.srcConnID).ToNot(Equal(wrongConnID))
hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456) hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake, encryptionLevel: protocol.EncryptionHandshake,
@ -1563,7 +1563,7 @@ var _ = Describe("Client Session", func() {
It("changes the connection ID when receiving the first packet from the server", func() { It("changes the connection ID when receiving the first packet from the server", func() {
unpacker := NewMockUnpacker(mockCtrl) unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) {
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: protocol.Encryption1RTT, encryptionLevel: protocol.Encryption1RTT,
hdr: &wire.ExtendedHeader{Header: *hdr}, hdr: &wire.ExtendedHeader{Header: *hdr},
@ -1738,7 +1738,7 @@ var _ = Describe("Client Session", func() {
// Send one packet, which might change the connection ID. // Send one packet, which might change the connection ID.
packer.EXPECT().ChangeDestConnectionID(sess.srcConnID).MaxTimes(1) packer.EXPECT().ChangeDestConnectionID(sess.srcConnID).MaxTimes(1)
// only EXPECT one call to the unpacker // only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial, encryptionLevel: protocol.EncryptionInitial,
hdr: hdr1, hdr: hdr1,
data: []byte{0}, // one PADDING frame data: []byte{0}, // one PADDING frame