use in-place decryption when unpacking

This commit is contained in:
Lucas Clemente 2016-07-26 18:33:48 +02:00
parent 658ceab877
commit daa328460f
4 changed files with 43 additions and 47 deletions

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
@ -22,14 +21,13 @@ type packetUnpacker struct {
aead crypto.AEAD aead crypto.AEAD
} }
func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, r *bytes.Reader) (*unpackedPacket, error) { func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, data []byte) (*unpackedPacket, error) {
ciphertext, _ := ioutil.ReadAll(r) data, err := u.aead.Open(data[:0], data, hdr.PacketNumber, publicHeaderBinary)
plaintext, err := u.aead.Open(nil, ciphertext, hdr.PacketNumber, publicHeaderBinary)
if err != nil { if err != nil {
// Wrap err in quicError so that public reset is sent by session // Wrap err in quicError so that public reset is sent by session
return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) return nil, qerr.Error(qerr.DecryptionFailure, err.Error())
} }
r = bytes.NewReader(plaintext) r := bytes.NewReader(data)
// read private flag byte, for QUIC Version < 34 // read private flag byte, for QUIC Version < 34
var entropyBit bool var entropyBit bool

View file

@ -18,7 +18,7 @@ var _ = Describe("Packet unpacker", func() {
hdr *publicHeader hdr *publicHeader
hdrBin []byte hdrBin []byte
aead crypto.AEAD aead crypto.AEAD
r *bytes.Reader data []byte
buf *bytes.Buffer buf *bytes.Buffer
) )
@ -30,29 +30,29 @@ var _ = Describe("Packet unpacker", func() {
} }
hdrBin = []byte{0x04, 0x4c, 0x01} hdrBin = []byte{0x04, 0x4c, 0x01}
unpacker = &packetUnpacker{aead: aead} unpacker = &packetUnpacker{aead: aead}
r = nil data = nil
buf = &bytes.Buffer{} buf = &bytes.Buffer{}
}) })
setReader := func(data []byte) { setData := func(p []byte) {
if unpacker.version < protocol.Version34 { // add private flag if unpacker.version < protocol.Version34 { // add private flag
data = append([]byte{0x01}, data...) p = append([]byte{0x01}, p...)
} }
r = bytes.NewReader(aead.Seal(nil, data, 0, hdrBin)) data = aead.Seal(nil, p, 0, hdrBin)
} }
It("returns an error for empty packets that don't have a private flag, for QUIC Version < 34", func() { It("returns an error for empty packets that don't have a private flag, for QUIC Version < 34", func() {
// don't use setReader here, since it adds a private flag unpacker.version = protocol.Version34
setData(nil)
unpacker.version = protocol.Version33 unpacker.version = protocol.Version33
r = bytes.NewReader(aead.Seal(nil, nil, 0, hdrBin)) _, err := unpacker.Unpack(hdrBin, hdr, data)
_, err := unpacker.Unpack(hdrBin, hdr, r)
Expect(err).To(MatchError(qerr.MissingPayload)) Expect(err).To(MatchError(qerr.MissingPayload))
}) })
It("returns an error for empty packets that have a private flag, for QUIC Version < 34", func() { It("returns an error for empty packets that have a private flag, for QUIC Version < 34", func() {
unpacker.version = protocol.Version33 unpacker.version = protocol.Version33
setReader(nil) setData(nil)
_, err := unpacker.Unpack(hdrBin, hdr, r) _, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).To(MatchError(qerr.MissingPayload)) Expect(err).To(MatchError(qerr.MissingPayload))
}) })
@ -61,8 +61,8 @@ var _ = Describe("Packet unpacker", func() {
f := &frames.ConnectionCloseFrame{ReasonPhrase: "foo"} f := &frames.ConnectionCloseFrame{ReasonPhrase: "foo"}
err := f.Write(buf, 0) err := f.Write(buf, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
setReader(buf.Bytes()) setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{f})) Expect(packet.frames).To(Equal([]frames.Frame{f}))
}) })
@ -74,8 +74,8 @@ var _ = Describe("Packet unpacker", func() {
} }
err := f.Write(buf, 0) err := f.Write(buf, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
setReader(buf.Bytes()) setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{f})) Expect(packet.frames).To(Equal([]frames.Frame{f}))
}) })
@ -89,9 +89,9 @@ var _ = Describe("Packet unpacker", func() {
} }
err := f.Write(buf, protocol.Version32) err := f.Write(buf, protocol.Version32)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
setReader(buf.Bytes()) setData(buf.Bytes())
unpacker.version = protocol.Version32 unpacker.version = protocol.Version32
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(HaveLen(1)) Expect(packet.frames).To(HaveLen(1))
readFrame := packet.frames[0].(*frames.AckFrame) readFrame := packet.frames[0].(*frames.AckFrame)
@ -101,21 +101,21 @@ var _ = Describe("Packet unpacker", func() {
}) })
It("errors on CONGESTION_FEEDBACK frames", func() { It("errors on CONGESTION_FEEDBACK frames", func() {
setReader([]byte{0x20}) setData([]byte{0x20})
_, err := unpacker.Unpack(hdrBin, hdr, r) _, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).To(MatchError("unimplemented: CONGESTION_FEEDBACK")) Expect(err).To(MatchError("unimplemented: CONGESTION_FEEDBACK"))
}) })
It("handles pad frames", func() { It("handles pad frames", func() {
setReader([]byte{0, 0, 0}) setData([]byte{0, 0, 0})
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(BeEmpty()) Expect(packet.frames).To(BeEmpty())
}) })
It("unpacks RST_STREAM frames", func() { It("unpacks RST_STREAM frames", func() {
setReader([]byte{0x01, 0xEF, 0xBE, 0xAD, 0xDE, 0x44, 0x33, 0x22, 0x11, 0xAD, 0xFB, 0xCA, 0xDE, 0x34, 0x12, 0x37, 0x13}) setData([]byte{0x01, 0xEF, 0xBE, 0xAD, 0xDE, 0x44, 0x33, 0x22, 0x11, 0xAD, 0xFB, 0xCA, 0xDE, 0x34, 0x12, 0x37, 0x13})
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{ Expect(packet.frames).To(Equal([]frames.Frame{
&frames.RstStreamFrame{ &frames.RstStreamFrame{
@ -130,21 +130,21 @@ var _ = Describe("Packet unpacker", func() {
f := &frames.ConnectionCloseFrame{ReasonPhrase: "foo"} f := &frames.ConnectionCloseFrame{ReasonPhrase: "foo"}
err := f.Write(buf, 0) err := f.Write(buf, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
setReader(buf.Bytes()) setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{f})) Expect(packet.frames).To(Equal([]frames.Frame{f}))
}) })
It("accepts GOAWAY frames", func() { It("accepts GOAWAY frames", func() {
setReader([]byte{ setData([]byte{
0x03, 0x03,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x03, 0x00, 0x03, 0x00,
'f', 'o', 'o', 'f', 'o', 'o',
}) })
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{ Expect(packet.frames).To(Equal([]frames.Frame{
&frames.GoawayFrame{ &frames.GoawayFrame{
@ -156,8 +156,8 @@ var _ = Describe("Packet unpacker", func() {
}) })
It("accepts WINDOW_UPDATE frames", func() { It("accepts WINDOW_UPDATE frames", func() {
setReader([]byte{0x04, 0xEF, 0xBE, 0xAD, 0xDE, 0x37, 0x13, 0, 0, 0, 0, 0xFE, 0xCA}) setData([]byte{0x04, 0xEF, 0xBE, 0xAD, 0xDE, 0x37, 0x13, 0, 0, 0, 0, 0xFE, 0xCA})
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{ Expect(packet.frames).To(Equal([]frames.Frame{
&frames.WindowUpdateFrame{ &frames.WindowUpdateFrame{
@ -168,8 +168,8 @@ var _ = Describe("Packet unpacker", func() {
}) })
It("accepts BLOCKED frames", func() { It("accepts BLOCKED frames", func() {
setReader([]byte{0x05, 0xEF, 0xBE, 0xAD, 0xDE}) setData([]byte{0x05, 0xEF, 0xBE, 0xAD, 0xDE})
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{ Expect(packet.frames).To(Equal([]frames.Frame{
&frames.BlockedFrame{ &frames.BlockedFrame{
@ -179,8 +179,8 @@ var _ = Describe("Packet unpacker", func() {
}) })
It("unpacks STOP_WAITING frames", func() { It("unpacks STOP_WAITING frames", func() {
setReader([]byte{0x06, 0xA4, 0x03}) setData([]byte{0x06, 0xA4, 0x03})
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{ Expect(packet.frames).To(Equal([]frames.Frame{
&frames.StopWaitingFrame{ &frames.StopWaitingFrame{
@ -191,8 +191,8 @@ var _ = Describe("Packet unpacker", func() {
}) })
It("accepts PING frames", func() { It("accepts PING frames", func() {
setReader([]byte{0x07}) setData([]byte{0x07})
packet, err := unpacker.Unpack(hdrBin, hdr, r) packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]frames.Frame{ Expect(packet.frames).To(Equal([]frames.Frame{
&frames.PingFrame{}, &frames.PingFrame{},
@ -200,8 +200,8 @@ var _ = Describe("Packet unpacker", func() {
}) })
It("errors on invalid type", func() { It("errors on invalid type", func() {
setReader([]byte{0x08}) setData([]byte{0x08})
_, err := unpacker.Unpack(hdrBin, hdr, r) _, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0x8")) Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0x8"))
}) })
}) })

View file

@ -1,7 +1,6 @@
package quic package quic
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
@ -19,7 +18,7 @@ import (
) )
type unpacker interface { type unpacker interface {
Unpack(publicHeaderBinary []byte, hdr *publicHeader, r *bytes.Reader) (*unpackedPacket, error) Unpack(publicHeaderBinary []byte, hdr *publicHeader, data []byte) (*unpackedPacket, error)
} }
type receivedPacket struct { type receivedPacket struct {
@ -236,7 +235,6 @@ func (s *Session) maybeResetTimer() {
func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, data []byte) error { func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, data []byte) error {
s.lastNetworkActivityTime = time.Now() s.lastNetworkActivityTime = time.Now()
r := bytes.NewReader(data)
// Calculate packet number // Calculate packet number
hdr.PacketNumber = protocol.InferPacketNumber( hdr.PacketNumber = protocol.InferPacketNumber(
@ -246,13 +244,13 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, da
) )
s.lastRcvdPacketNumber = hdr.PacketNumber s.lastRcvdPacketNumber = hdr.PacketNumber
if utils.Debug() { if utils.Debug() {
utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, r.Size()+int64(len(hdr.Raw)), hdr.ConnectionID) utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID)
} }
// TODO: Only do this after authenticating // TODO: Only do this after authenticating
s.conn.setCurrentRemoteAddr(remoteAddr) s.conn.setCurrentRemoteAddr(remoteAddr)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, r) packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
if err != nil { if err != nil {
return err return err
} }

View file

@ -38,7 +38,7 @@ func (*mockConnection) IP() net.IP { return nil }
type mockUnpacker struct{} type mockUnpacker struct{}
func (m *mockUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, r *bytes.Reader) (*unpackedPacket, error) { func (m *mockUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, data []byte) (*unpackedPacket, error) {
return &unpackedPacket{ return &unpackedPacket{
entropyBit: false, entropyBit: false,
frames: nil, frames: nil,