move parsing of the extended header to the unpacker

This commit is contained in:
Marten Seemann 2018-12-12 12:41:17 +06:30
parent c912860a94
commit aaea375fb6
5 changed files with 165 additions and 117 deletions

View file

@ -35,7 +35,7 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder {
}
// Unpack mocks base method
func (m *MockUnpacker) Unpack(arg0 *wire.ExtendedHeader, arg1 []byte) (*unpackedPacket, error) {
func (m *MockUnpacker) Unpack(arg0 *wire.Header, arg1 []byte) (*unpackedPacket, error) {
ret := m.ctrl.Call(m, "Unpack", arg0, arg1)
ret0, _ := ret[0].(*unpackedPacket)
ret1, _ := ret[1].(error)

View file

@ -11,7 +11,8 @@ import (
)
type unpackedPacket struct {
packetNumber protocol.PacketNumber
packetNumber protocol.PacketNumber // the decoded packet number
hdr *wire.ExtendedHeader
encryptionLevel protocol.EncryptionLevel
frames []wire.Frame
}
@ -40,11 +41,30 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
}
}
func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
r := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(r, u.version)
if err != nil {
return nil, fmt.Errorf("error parsing extended header: %s", err)
}
extHdr.Raw = data[:len(data)-r.Len()]
data = data[len(data)-r.Len():]
if hdr.IsLongHeader {
if hdr.Length < protocol.ByteCount(extHdr.PacketNumberLen) {
return nil, fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", extHdr.Length, extHdr.PacketNumberLen)
}
if protocol.ByteCount(len(data))+protocol.ByteCount(extHdr.PacketNumberLen) < extHdr.Length {
return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(extHdr.PacketNumberLen), extHdr.Length)
}
data = data[:int(extHdr.Length)-int(extHdr.PacketNumberLen)]
// TODO(#1312): implement parsing of compound packets
}
pn := protocol.DecodePacketNumber(
hdr.PacketNumberLen,
extHdr.PacketNumberLen,
u.largestRcvdPacketNumber,
hdr.PacketNumber,
extHdr.PacketNumber,
)
buf := *getPacketBuffer()
@ -53,19 +73,18 @@ func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpacke
var decrypted []byte
var encryptionLevel protocol.EncryptionLevel
var err error
switch hdr.Type {
case protocol.PacketTypeInitial:
decrypted, err = u.aead.OpenInitial(buf, data, pn, hdr.Raw)
decrypted, err = u.aead.OpenInitial(buf, data, pn, extHdr.Raw)
encryptionLevel = protocol.EncryptionInitial
case protocol.PacketTypeHandshake:
decrypted, err = u.aead.OpenHandshake(buf, data, pn, hdr.Raw)
decrypted, err = u.aead.OpenHandshake(buf, data, pn, extHdr.Raw)
encryptionLevel = protocol.EncryptionHandshake
default:
if hdr.IsLongHeader {
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
}
decrypted, err = u.aead.Open1RTT(buf, data, pn, hdr.Raw)
decrypted, err = u.aead.Open1RTT(buf, data, pn, extHdr.Raw)
encryptionLevel = protocol.Encryption1RTT
}
if err != nil {
@ -81,6 +100,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpacke
}
return &unpackedPacket{
hdr: extHdr,
packetNumber: pn,
encryptionLevel: encryptionLevel,
frames: fs,

View file

@ -14,82 +14,187 @@ import (
)
var _ = Describe("Packet Unpacker", func() {
const version = protocol.VersionTLS
var (
unpacker *packetUnpacker
hdr *wire.ExtendedHeader
aead *MockQuicAEAD
connID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
)
getHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
buf := &bytes.Buffer{}
Expect(extHdr.Write(buf, protocol.VersionWhatever)).To(Succeed())
hdr, err := wire.ParseHeader(bytes.NewReader(buf.Bytes()), connID.Len())
Expect(err).ToNot(HaveOccurred())
return hdr, buf.Bytes()
}
BeforeEach(func() {
aead = NewMockQuicAEAD(mockCtrl)
hdr = &wire.ExtendedHeader{
PacketNumber: 10,
PacketNumberLen: 1,
Raw: []byte{0x04, 0x4c, 0x01},
}
unpacker = newPacketUnpacker(aead, protocol.VersionWhatever).(*packetUnpacker)
unpacker = newPacketUnpacker(aead, version).(*packetUnpacker)
})
It("errors if the packet doesn't contain any payload", func() {
data := []byte("foobar")
aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), hdr.PacketNumber, hdr.Raw).Return([]byte{}, nil)
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 42,
PacketNumberLen: protocol.PacketNumberLen2,
}
hdr, hdrRaw := getHeader(extHdr)
data := append(hdrRaw, []byte("foobar")...) // add some payload
// return an empty (unencrypted) payload
aead.EXPECT().Open1RTT(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{}, nil)
_, err := unpacker.Unpack(hdr, data)
Expect(err).To(MatchError(qerr.MissingPayload))
})
It("opens Initial packets", func() {
hdr.IsLongHeader = true
hdr.Type = protocol.PacketTypeInitial
aead.EXPECT().OpenInitial(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(hdr, nil)
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Length: 3 + 6, // packet number len + payload
DestConnectionID: connID,
Version: version,
},
PacketNumber: 2,
PacketNumberLen: 3,
}
hdr, hdrRaw := getHeader(extHdr)
aead.EXPECT().OpenInitial(gomock.Any(), []byte("foobar"), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
})
It("opens Handshake packets", func() {
hdr.IsLongHeader = true
hdr.Type = protocol.PacketTypeHandshake
aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(hdr, nil)
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Length: 3 + 6, // packet number len + payload
DestConnectionID: connID,
Version: version,
},
PacketNumber: 2,
PacketNumberLen: 3,
}
hdr, hdrRaw := getHeader(extHdr)
aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, []byte("foobar")...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionHandshake))
})
It("errors on packets that are smaller than the length in the packet header", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Length: 1000,
DestConnectionID: connID,
Version: version,
},
PacketNumberLen: protocol.PacketNumberLen2,
}
hdr, hdrRaw := getHeader(extHdr)
data := append(hdrRaw, make([]byte, 500-2 /* for packet number length */)...)
_, err := unpacker.Unpack(hdr, data)
Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
})
It("errors when receiving a packet that has a length smaller than the packet number length", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
DestConnectionID: connID,
Type: protocol.PacketTypeHandshake,
Length: 3,
Version: protocol.VersionTLS,
},
PacketNumberLen: protocol.PacketNumberLen4,
}
hdr, hdrRaw := getHeader(extHdr)
_, err := unpacker.Unpack(hdr, hdrRaw)
Expect(err).To(MatchError("packet length (3 bytes) shorter than packet number (4 bytes)"))
})
It("cuts packets to the right length", func() {
pnLen := protocol.PacketNumberLen2
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
DestConnectionID: connID,
Type: protocol.PacketTypeHandshake,
Length: 456,
Version: protocol.VersionTLS,
},
PacketNumberLen: pnLen,
}
payloadLen := 456 - int(pnLen)
hdr, hdrRaw := getHeader(extHdr)
data := append(hdrRaw, make([]byte, payloadLen)...)
aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) {
Expect(payload).To(HaveLen(payloadLen))
return []byte{0}, nil
})
_, err := unpacker.Unpack(hdr, data)
Expect(err).ToNot(HaveOccurred())
})
It("returns the error when unpacking fails", func() {
hdr.IsLongHeader = true
hdr.Type = protocol.PacketTypeHandshake
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Length: 3, // packet number len
DestConnectionID: connID,
Version: version,
},
PacketNumber: 2,
PacketNumberLen: 3,
}
hdr, hdrRaw := getHeader(extHdr)
aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err"))
_, err := unpacker.Unpack(hdr, nil)
_, err := unpacker.Unpack(hdr, hdrRaw)
Expect(err).To(MatchError(qerr.Error(qerr.DecryptionFailure, "test err")))
})
It("decodes the packet number", func() {
firstHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
PacketNumberLen: 2,
}
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), firstHdr.PacketNumber, gomock.Any()).Return([]byte{0}, nil)
packet, err := unpacker.Unpack(firstHdr, nil)
packet, err := unpacker.Unpack(getHeader(firstHdr))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
// the real packet number is 0x1338, but only the last byte is sent
secondHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x38,
PacketNumberLen: 1,
}
// expect the call with the decoded packet number
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1338), gomock.Any()).Return([]byte{0}, nil)
packet, err = unpacker.Unpack(secondHdr, nil)
packet, err = unpacker.Unpack(getHeader(secondHdr))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
})
It("unpacks the frames", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
PacketNumberLen: 2,
}
buf := &bytes.Buffer{}
(&wire.PingFrame{}).Write(buf, protocol.VersionWhatever)
(&wire.DataBlockedFrame{}).Write(buf, protocol.VersionWhatever)
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return(buf.Bytes(), nil)
packet, err := unpacker.Unpack(hdr, nil)
hdr, hdrRaw := getHeader(extHdr)
aead.EXPECT().Open1RTT(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).Return(buf.Bytes(), nil)
packet, err := unpacker.Unpack(hdr, append(hdrRaw, buf.Bytes()...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{&wire.PingFrame{}, &wire.DataBlockedFrame{}}))
})

View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
@ -22,7 +21,7 @@ import (
)
type unpacker interface {
Unpack(hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error)
}
type streamGetter interface {
@ -483,27 +482,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
return nil
}
data := p.data
r := bytes.NewReader(data)
hdr, err := p.hdr.ParseExtended(r, s.version)
if err != nil {
return fmt.Errorf("error parsing extended header: %s", err)
}
hdr.Raw = data[:len(data)-r.Len()]
data = data[len(data)-r.Len():]
if hdr.IsLongHeader {
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
}
if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length)
}
data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)]
// TODO(#1312): implement parsing of compound packets
}
packet, err := s.unpacker.Unpack(hdr, data)
packet, err := s.unpacker.Unpack(p.hdr, p.data)
// if the decryption failed, this might be a packet sent by an attacker
if err != nil {
return err
@ -511,13 +490,13 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
if s.logger.Debug() {
s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), p.hdr.DestConnectionID, packet.encryptionLevel)
hdr.Log(s.logger)
packet.hdr.Log(s.logger)
}
// The server can change the source connection ID with the first Handshake packet.
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID)
s.destConnID = hdr.SrcConnectionID
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", p.hdr.SrcConnectionID)
s.destConnID = p.hdr.SrcConnectionID
s.packer.ChangeDestConnectionID(s.destConnID)
}
@ -536,7 +515,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
// If this is a Retry packet, there's no need to send an ACK.
// The session will be closed and recreated as soon as the crypto setup processed the HRR.
if hdr.Type != protocol.PacketTypeRetry {
if p.hdr.Type != protocol.PacketTypeRetry {
isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames)
if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, p.rcvTime, isRetransmittable); err != nil {
return err

View file

@ -533,62 +533,6 @@ var _ = Describe("Session", func() {
})).To(Succeed())
})
It("errors on packets that are smaller than the length in the packet header", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
hdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Length: 1000,
DestConnectionID: connID,
Version: protocol.VersionTLS,
},
PacketNumberLen: protocol.PacketNumberLen2,
}
data := getData(hdr)
data = append(data, make([]byte, 500-2 /* for packet number length */)...)
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
})
It("errors when receiving a packet that has a length smaller than the packet number length", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
hdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
DestConnectionID: connID,
Type: protocol.PacketTypeHandshake,
Length: 3,
Version: protocol.VersionTLS,
},
PacketNumberLen: protocol.PacketNumberLen4,
}
data := getData(hdr)
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(MatchError("packet length (3 bytes) shorter than packet number (4 bytes)"))
})
It("cuts packets to the right length", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
pnLen := protocol.PacketNumberLen2
hdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
DestConnectionID: connID,
Type: protocol.PacketTypeHandshake,
Length: 456,
Version: protocol.VersionTLS,
},
PacketNumberLen: pnLen,
}
payloadLen := 456 - int(pnLen)
data := getData(hdr)
data = append(data, make([]byte, payloadLen)...)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
Expect(data).To(HaveLen(payloadLen))
return &unpackedPacket{}, nil
})
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(Succeed())
})
Context("updating the remote address", func() {
It("doesn't support connection migration", func() {
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)