use in-place decryption when unpacking

This commit is contained in:
Lucas Clemente 2016-07-26 18:33:48 +02:00
parent 658ceab877
commit daa328460f
4 changed files with 43 additions and 47 deletions

View file

@ -4,7 +4,6 @@ import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/frames"
@ -22,14 +21,13 @@ type packetUnpacker struct {
aead crypto.AEAD
}
func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, r *bytes.Reader) (*unpackedPacket, error) {
ciphertext, _ := ioutil.ReadAll(r)
plaintext, err := u.aead.Open(nil, ciphertext, hdr.PacketNumber, publicHeaderBinary)
func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, data []byte) (*unpackedPacket, error) {
data, err := u.aead.Open(data[:0], data, hdr.PacketNumber, publicHeaderBinary)
if err != nil {
// Wrap err in quicError so that public reset is sent by session
return nil, qerr.Error(qerr.DecryptionFailure, err.Error())
}
r = bytes.NewReader(plaintext)
r := bytes.NewReader(data)
// read private flag byte, for QUIC Version < 34
var entropyBit bool

View file

@ -18,7 +18,7 @@ var _ = Describe("Packet unpacker", func() {
hdr *publicHeader
hdrBin []byte
aead crypto.AEAD
r *bytes.Reader
data []byte
buf *bytes.Buffer
)
@ -30,29 +30,29 @@ var _ = Describe("Packet unpacker", func() {
}
hdrBin = []byte{0x04, 0x4c, 0x01}
unpacker = &packetUnpacker{aead: aead}
r = nil
data = nil
buf = &bytes.Buffer{}
})
setReader := func(data []byte) {
setData := func(p []byte) {
if unpacker.version < protocol.Version34 { // add private flag
data = append([]byte{0x01}, data...)
p = append([]byte{0x01}, p...)
}
r = bytes.NewReader(aead.Seal(nil, data, 0, hdrBin))
data = aead.Seal(nil, p, 0, hdrBin)
}
It("returns an error for empty packets that don't have a private flag, for QUIC Version < 34", func() {
// don't use setReader here, since it adds a private flag
unpacker.version = protocol.Version34
setData(nil)
unpacker.version = protocol.Version33
r = bytes.NewReader(aead.Seal(nil, nil, 0, hdrBin))
_, err := unpacker.Unpack(hdrBin, hdr, r)
_, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).To(MatchError(qerr.MissingPayload))
})
It("returns an error for empty packets that have a private flag, for QUIC Version < 34", func() {
unpacker.version = protocol.Version33
setReader(nil)
_, err := unpacker.Unpack(hdrBin, hdr, r)
setData(nil)
_, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).To(MatchError(qerr.MissingPayload))
})
@ -61,8 +61,8 @@ var _ = Describe("Packet unpacker", func() {
f := &frames.ConnectionCloseFrame{ReasonPhrase: "foo"}
err := f.Write(buf, 0)
Expect(err).ToNot(HaveOccurred())
setReader(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, r)
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{f}))
})
@ -74,8 +74,8 @@ var _ = Describe("Packet unpacker", func() {
}
err := f.Write(buf, 0)
Expect(err).ToNot(HaveOccurred())
setReader(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, r)
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{f}))
})
@ -89,9 +89,9 @@ var _ = Describe("Packet unpacker", func() {
}
err := f.Write(buf, protocol.Version32)
Expect(err).ToNot(HaveOccurred())
setReader(buf.Bytes())
setData(buf.Bytes())
unpacker.version = protocol.Version32
packet, err := unpacker.Unpack(hdrBin, hdr, r)
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(HaveLen(1))
readFrame := packet.frames[0].(*frames.AckFrame)
@ -101,21 +101,21 @@ var _ = Describe("Packet unpacker", func() {
})
It("errors on CONGESTION_FEEDBACK frames", func() {
setReader([]byte{0x20})
_, err := unpacker.Unpack(hdrBin, hdr, r)
setData([]byte{0x20})
_, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).To(MatchError("unimplemented: CONGESTION_FEEDBACK"))
})
It("handles pad frames", func() {
setReader([]byte{0, 0, 0})
packet, err := unpacker.Unpack(hdrBin, hdr, r)
setData([]byte{0, 0, 0})
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(BeEmpty())
})
It("unpacks RST_STREAM frames", func() {
setReader([]byte{0x01, 0xEF, 0xBE, 0xAD, 0xDE, 0x44, 0x33, 0x22, 0x11, 0xAD, 0xFB, 0xCA, 0xDE, 0x34, 0x12, 0x37, 0x13})
packet, err := unpacker.Unpack(hdrBin, hdr, r)
setData([]byte{0x01, 0xEF, 0xBE, 0xAD, 0xDE, 0x44, 0x33, 0x22, 0x11, 0xAD, 0xFB, 0xCA, 0xDE, 0x34, 0x12, 0x37, 0x13})
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{
&frames.RstStreamFrame{
@ -130,21 +130,21 @@ var _ = Describe("Packet unpacker", func() {
f := &frames.ConnectionCloseFrame{ReasonPhrase: "foo"}
err := f.Write(buf, 0)
Expect(err).ToNot(HaveOccurred())
setReader(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, r)
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{f}))
})
It("accepts GOAWAY frames", func() {
setReader([]byte{
setData([]byte{
0x03,
0x01, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00,
0x03, 0x00,
'f', 'o', 'o',
})
packet, err := unpacker.Unpack(hdrBin, hdr, r)
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{
&frames.GoawayFrame{
@ -156,8 +156,8 @@ var _ = Describe("Packet unpacker", func() {
})
It("accepts WINDOW_UPDATE frames", func() {
setReader([]byte{0x04, 0xEF, 0xBE, 0xAD, 0xDE, 0x37, 0x13, 0, 0, 0, 0, 0xFE, 0xCA})
packet, err := unpacker.Unpack(hdrBin, hdr, r)
setData([]byte{0x04, 0xEF, 0xBE, 0xAD, 0xDE, 0x37, 0x13, 0, 0, 0, 0, 0xFE, 0xCA})
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{
&frames.WindowUpdateFrame{
@ -168,8 +168,8 @@ var _ = Describe("Packet unpacker", func() {
})
It("accepts BLOCKED frames", func() {
setReader([]byte{0x05, 0xEF, 0xBE, 0xAD, 0xDE})
packet, err := unpacker.Unpack(hdrBin, hdr, r)
setData([]byte{0x05, 0xEF, 0xBE, 0xAD, 0xDE})
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{
&frames.BlockedFrame{
@ -179,8 +179,8 @@ var _ = Describe("Packet unpacker", func() {
})
It("unpacks STOP_WAITING frames", func() {
setReader([]byte{0x06, 0xA4, 0x03})
packet, err := unpacker.Unpack(hdrBin, hdr, r)
setData([]byte{0x06, 0xA4, 0x03})
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{
&frames.StopWaitingFrame{
@ -191,8 +191,8 @@ var _ = Describe("Packet unpacker", func() {
})
It("accepts PING frames", func() {
setReader([]byte{0x07})
packet, err := unpacker.Unpack(hdrBin, hdr, r)
setData([]byte{0x07})
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{
&frames.PingFrame{},
@ -200,8 +200,8 @@ var _ = Describe("Packet unpacker", func() {
})
It("errors on invalid type", func() {
setReader([]byte{0x08})
_, err := unpacker.Unpack(hdrBin, hdr, r)
setData([]byte{0x08})
_, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0x8"))
})
})

View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"errors"
"fmt"
"sync"
@ -19,7 +18,7 @@ import (
)
type unpacker interface {
Unpack(publicHeaderBinary []byte, hdr *publicHeader, r *bytes.Reader) (*unpackedPacket, error)
Unpack(publicHeaderBinary []byte, hdr *publicHeader, data []byte) (*unpackedPacket, error)
}
type receivedPacket struct {
@ -236,7 +235,6 @@ func (s *Session) maybeResetTimer() {
func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, data []byte) error {
s.lastNetworkActivityTime = time.Now()
r := bytes.NewReader(data)
// Calculate packet number
hdr.PacketNumber = protocol.InferPacketNumber(
@ -246,13 +244,13 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, da
)
s.lastRcvdPacketNumber = hdr.PacketNumber
if utils.Debug() {
utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, r.Size()+int64(len(hdr.Raw)), hdr.ConnectionID)
utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID)
}
// TODO: Only do this after authenticating
s.conn.setCurrentRemoteAddr(remoteAddr)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, r)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
if err != nil {
return err
}

View file

@ -38,7 +38,7 @@ func (*mockConnection) IP() net.IP { return nil }
type mockUnpacker struct{}
func (m *mockUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, r *bytes.Reader) (*unpackedPacket, error) {
func (m *mockUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, data []byte) (*unpackedPacket, error) {
return &unpackedPacket{
entropyBit: false,
frames: nil,