diff --git a/packet_unpacker.go b/packet_unpacker.go index dd9650a2..eb8313b5 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "io/ioutil" "github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/frames" @@ -22,14 +21,13 @@ type packetUnpacker struct { aead crypto.AEAD } -func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, r *bytes.Reader) (*unpackedPacket, error) { - ciphertext, _ := ioutil.ReadAll(r) - plaintext, err := u.aead.Open(nil, ciphertext, hdr.PacketNumber, publicHeaderBinary) +func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, data []byte) (*unpackedPacket, error) { + data, err := u.aead.Open(data[:0], data, hdr.PacketNumber, publicHeaderBinary) if err != nil { // Wrap err in quicError so that public reset is sent by session return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) } - r = bytes.NewReader(plaintext) + r := bytes.NewReader(data) // read private flag byte, for QUIC Version < 34 var entropyBit bool diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 519309a7..d112492f 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -18,7 +18,7 @@ var _ = Describe("Packet unpacker", func() { hdr *publicHeader hdrBin []byte aead crypto.AEAD - r *bytes.Reader + data []byte buf *bytes.Buffer ) @@ -30,29 +30,29 @@ var _ = Describe("Packet unpacker", func() { } hdrBin = []byte{0x04, 0x4c, 0x01} unpacker = &packetUnpacker{aead: aead} - r = nil + data = nil buf = &bytes.Buffer{} }) - setReader := func(data []byte) { + setData := func(p []byte) { if unpacker.version < protocol.Version34 { // add private flag - data = append([]byte{0x01}, data...) + p = append([]byte{0x01}, p...) } - r = bytes.NewReader(aead.Seal(nil, data, 0, hdrBin)) + data = aead.Seal(nil, p, 0, hdrBin) } It("returns an error for empty packets that don't have a private flag, for QUIC Version < 34", func() { - // don't use setReader here, since it adds a private flag + unpacker.version = protocol.Version34 + setData(nil) unpacker.version = protocol.Version33 - r = bytes.NewReader(aead.Seal(nil, nil, 0, hdrBin)) - _, err := unpacker.Unpack(hdrBin, hdr, r) + _, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).To(MatchError(qerr.MissingPayload)) }) It("returns an error for empty packets that have a private flag, for QUIC Version < 34", func() { unpacker.version = protocol.Version33 - setReader(nil) - _, err := unpacker.Unpack(hdrBin, hdr, r) + setData(nil) + _, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).To(MatchError(qerr.MissingPayload)) }) @@ -61,8 +61,8 @@ var _ = Describe("Packet unpacker", func() { f := &frames.ConnectionCloseFrame{ReasonPhrase: "foo"} err := f.Write(buf, 0) Expect(err).ToNot(HaveOccurred()) - setReader(buf.Bytes()) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]frames.Frame{f})) }) @@ -74,8 +74,8 @@ var _ = Describe("Packet unpacker", func() { } err := f.Write(buf, 0) Expect(err).ToNot(HaveOccurred()) - setReader(buf.Bytes()) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]frames.Frame{f})) }) @@ -89,9 +89,9 @@ var _ = Describe("Packet unpacker", func() { } err := f.Write(buf, protocol.Version32) Expect(err).ToNot(HaveOccurred()) - setReader(buf.Bytes()) + setData(buf.Bytes()) unpacker.version = protocol.Version32 - packet, err := unpacker.Unpack(hdrBin, hdr, r) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(HaveLen(1)) readFrame := packet.frames[0].(*frames.AckFrame) @@ -101,21 +101,21 @@ var _ = Describe("Packet unpacker", func() { }) It("errors on CONGESTION_FEEDBACK frames", func() { - setReader([]byte{0x20}) - _, err := unpacker.Unpack(hdrBin, hdr, r) + setData([]byte{0x20}) + _, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).To(MatchError("unimplemented: CONGESTION_FEEDBACK")) }) It("handles pad frames", func() { - setReader([]byte{0, 0, 0}) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + setData([]byte{0, 0, 0}) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(BeEmpty()) }) It("unpacks RST_STREAM frames", func() { - setReader([]byte{0x01, 0xEF, 0xBE, 0xAD, 0xDE, 0x44, 0x33, 0x22, 0x11, 0xAD, 0xFB, 0xCA, 0xDE, 0x34, 0x12, 0x37, 0x13}) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + setData([]byte{0x01, 0xEF, 0xBE, 0xAD, 0xDE, 0x44, 0x33, 0x22, 0x11, 0xAD, 0xFB, 0xCA, 0xDE, 0x34, 0x12, 0x37, 0x13}) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]frames.Frame{ &frames.RstStreamFrame{ @@ -130,21 +130,21 @@ var _ = Describe("Packet unpacker", func() { f := &frames.ConnectionCloseFrame{ReasonPhrase: "foo"} err := f.Write(buf, 0) Expect(err).ToNot(HaveOccurred()) - setReader(buf.Bytes()) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]frames.Frame{f})) }) It("accepts GOAWAY frames", func() { - setReader([]byte{ + setData([]byte{ 0x03, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 'f', 'o', 'o', }) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]frames.Frame{ &frames.GoawayFrame{ @@ -156,8 +156,8 @@ var _ = Describe("Packet unpacker", func() { }) It("accepts WINDOW_UPDATE frames", func() { - setReader([]byte{0x04, 0xEF, 0xBE, 0xAD, 0xDE, 0x37, 0x13, 0, 0, 0, 0, 0xFE, 0xCA}) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + setData([]byte{0x04, 0xEF, 0xBE, 0xAD, 0xDE, 0x37, 0x13, 0, 0, 0, 0, 0xFE, 0xCA}) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]frames.Frame{ &frames.WindowUpdateFrame{ @@ -168,8 +168,8 @@ var _ = Describe("Packet unpacker", func() { }) It("accepts BLOCKED frames", func() { - setReader([]byte{0x05, 0xEF, 0xBE, 0xAD, 0xDE}) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + setData([]byte{0x05, 0xEF, 0xBE, 0xAD, 0xDE}) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]frames.Frame{ &frames.BlockedFrame{ @@ -179,8 +179,8 @@ var _ = Describe("Packet unpacker", func() { }) It("unpacks STOP_WAITING frames", func() { - setReader([]byte{0x06, 0xA4, 0x03}) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + setData([]byte{0x06, 0xA4, 0x03}) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]frames.Frame{ &frames.StopWaitingFrame{ @@ -191,8 +191,8 @@ var _ = Describe("Packet unpacker", func() { }) It("accepts PING frames", func() { - setReader([]byte{0x07}) - packet, err := unpacker.Unpack(hdrBin, hdr, r) + setData([]byte{0x07}) + packet, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).ToNot(HaveOccurred()) Expect(packet.frames).To(Equal([]frames.Frame{ &frames.PingFrame{}, @@ -200,8 +200,8 @@ var _ = Describe("Packet unpacker", func() { }) It("errors on invalid type", func() { - setReader([]byte{0x08}) - _, err := unpacker.Unpack(hdrBin, hdr, r) + setData([]byte{0x08}) + _, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0x8")) }) }) diff --git a/session.go b/session.go index a7d25b31..cfd6ffd3 100644 --- a/session.go +++ b/session.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "errors" "fmt" "sync" @@ -19,7 +18,7 @@ import ( ) type unpacker interface { - Unpack(publicHeaderBinary []byte, hdr *publicHeader, r *bytes.Reader) (*unpackedPacket, error) + Unpack(publicHeaderBinary []byte, hdr *publicHeader, data []byte) (*unpackedPacket, error) } type receivedPacket struct { @@ -236,7 +235,6 @@ func (s *Session) maybeResetTimer() { func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, data []byte) error { s.lastNetworkActivityTime = time.Now() - r := bytes.NewReader(data) // Calculate packet number hdr.PacketNumber = protocol.InferPacketNumber( @@ -246,13 +244,13 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, da ) s.lastRcvdPacketNumber = hdr.PacketNumber if utils.Debug() { - utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, r.Size()+int64(len(hdr.Raw)), hdr.ConnectionID) + utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) } // TODO: Only do this after authenticating s.conn.setCurrentRemoteAddr(remoteAddr) - packet, err := s.unpacker.Unpack(hdr.Raw, hdr, r) + packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) if err != nil { return err } diff --git a/session_test.go b/session_test.go index e39ce9aa..42ba6236 100644 --- a/session_test.go +++ b/session_test.go @@ -38,7 +38,7 @@ func (*mockConnection) IP() net.IP { return nil } type mockUnpacker struct{} -func (m *mockUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, r *bytes.Reader) (*unpackedPacket, error) { +func (m *mockUnpacker) Unpack(publicHeaderBinary []byte, hdr *publicHeader, data []byte) (*unpackedPacket, error) { return &unpackedPacket{ entropyBit: false, frames: nil,