implement a buffer pool for STREAM frames

This commit is contained in:
Marten Seemann 2019-09-04 16:45:39 +07:00
parent 326ec9e16e
commit 5ea33cd31e
70 changed files with 193 additions and 48 deletions

View file

@ -1 +1 @@
s ��� ���

View file

@ -1 +1 @@
Ą ´ŮÉ ÝťRß×›MvB›az źź ;Ą[ ŔÖL�…5„Ëŕp›X?aÓuĽ´ôů)áŹÚžo‚ĺNtŽ�çžK˝oăLÜş„>čÖ>ŚOţëęTmŹ¬Ý¬Î.˘‡|UyϢǎ d ´ŮÉ ÝťRß×›MvB›az źź ;Ą[ ŔÖL�…5„Ëŕp›X?aÓuĽ´ôů)áŹÚžo‚ĺNtŽ�çžK˝oăLÜş„>čÖ>ŚOţëęTmŹ¬Ý¬Î.˘‡|UyϢǎ

View file

@ -1 +1 @@
�74 W

View file

@ -1 +1 @@
¦z‹˙˙˙˙˙˙˙˙ �fL�l�Q5

View file

@ -1 +1,2 @@
?イlコQ 6
u*.租マ�メ

View file

@ -1 +1 @@
��������� wK

View file

@ -1 +1 @@
5��M ��������

View file

@ -1 +1 @@
p������� H�

Binary file not shown.

View file

@ -1 +1 @@
M�� ���u��������

View file

@ -1 +1 @@
/' Z•ňG&

View file

@ -1 +1 @@
�R� •`Ο

View file

@ -1 +1 @@
P������� 

View file

@ -1 +1 @@
o/ ���������

View file

@ -1 +1 @@
��g�������� �5z

View file

@ -1 +1 @@
îE k€��������

View file

@ -1 +1 @@
�L� В•Ј.

View file

@ -1 +1 @@
 ���M

View file

@ -1 +1 @@
C .T^'?ãWkoí'ÿ‰tºÀÊýšÐV’±6ç8–MýÇž�SCsfýf×Oì‰I·#n w .T^'?ăWkoí'˙‰tşŔĘýšĐV’±6ç8–MýÇžŤSCsfýf×Oě‰I·#n

View file

@ -1 +1 @@
� ��ҕ�H���G(����+^��  @d=€çËż‡ŇŮ) ¤yĄ[{säď˛ACN6ÔO}y8Ťˇž\ţ�!cu�&ß9ř3i\˘*¦Q•H�ĽÜĎf­=s¨á«ź�‡}&;»'őĽI{ÖTGMˆ*´Ýh› M ®h �ÖĄčtÄä®

View file

@ -1 +1 @@
ËF­F­ÓÔ˙ćŐ™VQ__S‡ä”f™Ł®ÁJ|¶®» w˝[ŢmĽ!źvť)n§vDݸ#¨/ş

View file

@ -1 +1,2 @@
� g���������Ku!�g } gʆ+«ÔKu!_²‚ˆ]2`ñÞ—‚ƒÔ š6ùl ”FãíM¦F©®‹O§´ü: ºúuí2z†¸°ÃšñÏÒ³µêyS:¿D�,G�2`u3�YÈÌð�£ž’�J¼¼±i9{·4çï
nñ…Më_ä$þ÷¬!‘õ—^)*¢¸ª|aͳ3sëç,' ˜À!—Úæ·2ÃQßf�‡N,Ÿàœ¨`çâH0?ô#áHíS�h_vò§˜¼dÞ]²†K*ÓÂlæ8#BvZÖ–å-÷`öÃF^) Ü¤jàÕp;ÂqõõJd1Q?ØB›Jẟ*ƒ†®r³fè,*ÆM¿F

View file

@ -1 +1 @@
с ��В_џџџџџџџџЮІшПF`ёо�� ¦ ^Эяяяяяяяъ$в0S

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -57,6 +57,7 @@ func Fuzz(data []byte) int {
// We accept empty STREAM frames, but we don't write them. // We accept empty STREAM frames, but we don't write them.
if sf, ok := f.(*wire.StreamFrame); ok { if sf, ok := f.(*wire.StreamFrame); ok {
if sf.DataLen() == 0 { if sf.DataLen() == 0 {
sf.PutBack()
continue continue
} }
} }
@ -68,6 +69,9 @@ func Fuzz(data []byte) int {
if f.Length(version) != protocol.ByteCount(frameLen) { if f.Length(version) != protocol.ByteCount(frameLen) {
panic(fmt.Sprintf("Inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version))) panic(fmt.Sprintf("Inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version)))
} }
if sf, ok := f.(*wire.StreamFrame); ok {
sf.PutBack()
}
} }
if b.Len() > parsedLen { if b.Len() > parsedLen {
panic(fmt.Sprintf("Serialized length (%d) is longer than parsed length (%d)", b.Len(), parsedLen)) panic(fmt.Sprintf("Serialized length (%d) is longer than parsed length (%d)", b.Len(), parsedLen))

View file

@ -88,6 +88,12 @@ func getFrames() []wire.Frame {
Data: getRandomData(50), Data: getRandomData(50),
FinBit: true, FinBit: true,
}, },
&wire.StreamFrame{ // STREAM frame at non-zero offset, with data and FIN bit. Long enough to use the buffer.
StreamID: protocol.StreamID(getRandomNumber()),
Offset: protocol.ByteCount(getRandomNumber()),
Data: getRandomData(2 * protocol.MinStreamFrameBufferSize),
FinBit: true,
},
&wire.StreamFrame{ // STREAM frame at maximum offset, with FIN bit &wire.StreamFrame{ // STREAM frame at maximum offset, with FIN bit
StreamID: protocol.StreamID(getRandomNumber()), StreamID: protocol.StreamID(getRandomNumber()),
Offset: protocol.MaxByteCount - 5, Offset: protocol.MaxByteCount - 5,

View file

@ -81,6 +81,11 @@ const MaxNonAckElicitingAcks = 19
// prevents DoS attacks against the streamFrameSorter // prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000 const MaxStreamFrameSorterGaps = 1000
// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame
// that we use the buffer for. This protects against a DoS where an attacker would send us
// very small STREAM frames to consume a lot of memory.
const MinStreamFrameBufferSize = 128
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. // MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
// This limits the size of the ClientHello and Certificates that can be received. // This limits the size of the ClientHello and Certificates that can be received.
const MaxCryptoStreamOffset = 16 * (1 << 10) const MaxCryptoStreamOffset = 16 * (1 << 10)

33
internal/wire/pool.go Normal file
View file

@ -0,0 +1,33 @@
package wire
import (
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var pool sync.Pool
func init() {
pool.New = func() interface{} {
return &StreamFrame{
Data: make([]byte, 0, protocol.MaxReceivePacketSize),
fromPool: true,
}
}
}
func getStreamFrame() *StreamFrame {
f := pool.Get().(*StreamFrame)
return f
}
func putStreamFrame(f *StreamFrame) {
if !f.fromPool {
return
}
if protocol.ByteCount(cap(f.Data)) != protocol.MaxReceivePacketSize {
panic("wire.PutStreamFrame called with packet of wrong size!")
}
pool.Put(f)
}

View file

@ -0,0 +1,24 @@
package wire
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Pool", func() {
It("gets and puts STREAM frames", func() {
f := getStreamFrame()
putStreamFrame(f)
})
It("panics when putting a STREAM frame with a wrong capacity", func() {
f := getStreamFrame()
f.Data = []byte("foobar")
Expect(func() { putStreamFrame(f) }).To(Panic())
})
It("accepts STREAM frames not from the buffer, but ignores them", func() {
f := &StreamFrame{Data: []byte("foobar")}
putStreamFrame(f)
})
})

View file

@ -17,6 +17,8 @@ type StreamFrame struct {
DataLenPresent bool DataLenPresent bool
Offset protocol.ByteCount Offset protocol.ByteCount
Data []byte Data []byte
fromPool bool
} }
func parseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) { func parseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) {
@ -26,45 +28,53 @@ func parseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF
} }
hasOffset := typeByte&0x4 > 0 hasOffset := typeByte&0x4 > 0
frame := &StreamFrame{ fin := typeByte&0x1 > 0
FinBit: typeByte&0x1 > 0, hasDataLen := typeByte&0x2 > 0
DataLenPresent: typeByte&0x2 > 0,
}
streamID, err := utils.ReadVarInt(r) streamID, err := utils.ReadVarInt(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
frame.StreamID = protocol.StreamID(streamID) var offset uint64
if hasOffset { if hasOffset {
offset, err := utils.ReadVarInt(r) offset, err = utils.ReadVarInt(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
frame.Offset = protocol.ByteCount(offset)
} }
var dataLen uint64 var dataLen uint64
if frame.DataLenPresent { if hasDataLen {
var err error var err error
dataLen, err = utils.ReadVarInt(r) dataLen, err = utils.ReadVarInt(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// shortcut to prevent the unnecessary allocation of dataLen bytes
// if the dataLen is larger than the remaining length of the packet
// reading the packet contents would result in EOF when attempting to READ
if dataLen > uint64(r.Len()) {
return nil, io.EOF
}
} else { } else {
// The rest of the packet is data // The rest of the packet is data
dataLen = uint64(r.Len()) dataLen = uint64(r.Len())
} }
var frame *StreamFrame
if dataLen < protocol.MinStreamFrameBufferSize {
frame = &StreamFrame{Data: make([]byte, dataLen)}
} else {
frame = getStreamFrame()
// The STREAM frame can't be larger than the StreamFrame we obtained from the buffer,
// since those StreamFrames have a buffer length of the maximum packet size.
if dataLen > uint64(cap(frame.Data)) {
return nil, io.EOF
}
frame.Data = frame.Data[:dataLen]
}
frame.StreamID = protocol.StreamID(streamID)
frame.Offset = protocol.ByteCount(offset)
frame.FinBit = fin
frame.DataLenPresent = hasDataLen
if dataLen != 0 { if dataLen != 0 {
frame.Data = make([]byte, dataLen)
if _, err := io.ReadFull(r, frame.Data); err != nil { if _, err := io.ReadFull(r, frame.Data); err != nil {
// this should never happen, since we already checked the dataLen earlier
return nil, err return nil, err
} }
} }
@ -156,16 +166,25 @@ func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version pro
if n == 0 { if n == 0 {
return nil, true return nil, true
} }
newFrame := &StreamFrame{
FinBit: false,
StreamID: f.StreamID,
Offset: f.Offset,
Data: f.Data[:n],
DataLenPresent: f.DataLenPresent,
}
f.Data = f.Data[n:] new := getStreamFrame()
new.StreamID = f.StreamID
new.Offset = f.Offset
new.FinBit = false
new.DataLenPresent = f.DataLenPresent
// swap the data slices
new.Data, f.Data = f.Data, new.Data
new.fromPool, f.fromPool = f.fromPool, new.fromPool
f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n]
copy(f.Data, new.Data[n:])
new.Data = new.Data[:n]
f.Offset += n f.Offset += n
return newFrame, true return new, true
}
func (f *StreamFrame) PutBack() {
putStreamFrame(f)
} }

View file

@ -2,6 +2,7 @@ package wire
import ( import (
"bytes" "bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
@ -78,6 +79,16 @@ var _ = Describe("STREAM frame", func() {
Expect(err).To(MatchError("FRAME_ENCODING_ERROR: stream data overflows maximum offset")) Expect(err).To(MatchError("FRAME_ENCODING_ERROR: stream data overflows maximum offset"))
}) })
It("rejects frames that claim to be longer than the packet size", func() {
data := []byte{0x8 ^ 0x2}
data = append(data, encodeVarInt(0x12345)...) // stream ID
data = append(data, encodeVarInt(uint64(protocol.MaxReceivePacketSize)+1)...) // data length
data = append(data, make([]byte, protocol.MaxReceivePacketSize+1)...)
r := bytes.NewReader(data)
_, err := parseStreamFrame(r, versionIETFFrames)
Expect(err).To(Equal(io.EOF))
})
It("errors on EOFs", func() { It("errors on EOFs", func() {
data := []byte{0x8 ^ 0x4 ^ 0x2} data := []byte{0x8 ^ 0x4 ^ 0x2}
data = append(data, encodeVarInt(0x12345)...) // stream ID data = append(data, encodeVarInt(0x12345)...) // stream ID
@ -93,6 +104,40 @@ var _ = Describe("STREAM frame", func() {
}) })
}) })
Context("using the buffer", func() {
It("uses the buffer for long STREAM frames", func() {
data := []byte{0x8}
data = append(data, encodeVarInt(0x12345)...) // stream ID
data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...)
r := bytes.NewReader(data)
frame, err := parseStreamFrame(r, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345)))
Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)))
Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize))
Expect(frame.FinBit).To(BeFalse())
Expect(frame.fromPool).To(BeTrue())
Expect(r.Len()).To(BeZero())
Expect(frame.PutBack).ToNot(Panic())
})
It("doesn't use the buffer for short STREAM frames", func() {
data := []byte{0x8}
data = append(data, encodeVarInt(0x12345)...) // stream ID
data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...)
r := bytes.NewReader(data)
frame, err := parseStreamFrame(r, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345)))
Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)))
Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize - 1))
Expect(frame.FinBit).To(BeFalse())
Expect(frame.fromPool).To(BeFalse())
Expect(r.Len()).To(BeZero())
Expect(frame.PutBack).ToNot(Panic())
})
})
Context("when writing", func() { Context("when writing", func() {
It("writes a frame without offset", func() { It("writes a frame without offset", func() {
f := &StreamFrame{ f := &StreamFrame{
@ -294,6 +339,7 @@ var _ = Describe("STREAM frame", func() {
frame, needsSplit = f.MaybeSplitOffFrame(f.Length(versionIETFFrames)-1, versionIETFFrames) frame, needsSplit = f.MaybeSplitOffFrame(f.Length(versionIETFFrames)-1, versionIETFFrames)
Expect(needsSplit).To(BeTrue()) Expect(needsSplit).To(BeTrue())
Expect(frame.DataLen()).To(BeEquivalentTo(99)) Expect(frame.DataLen()).To(BeEquivalentTo(99))
f.PutBack()
}) })
It("keeps the data len", func() { It("keeps the data len", func() {
@ -353,6 +399,7 @@ var _ = Describe("STREAM frame", func() {
Expect(f).To(BeNil()) Expect(f).To(BeNil())
} }
for i := minFrameSize; i < size; i++ { for i := minFrameSize; i < size; i++ {
f.fromPool = false
f.Data = make([]byte, size) f.Data = make([]byte, size)
f, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames) f, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames)
Expect(needsSplit).To(BeTrue()) Expect(needsSplit).To(BeTrue())
@ -376,6 +423,7 @@ var _ = Describe("STREAM frame", func() {
} }
var frameOneByteTooSmallCounter int var frameOneByteTooSmallCounter int
for i := minFrameSize; i < size; i++ { for i := minFrameSize; i < size; i++ {
f.fromPool = false
f.Data = make([]byte, size) f.Data = make([]byte, size)
newFrame, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames) newFrame, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames)
Expect(needsSplit).To(BeTrue()) Expect(needsSplit).To(BeTrue())

View file

@ -412,7 +412,11 @@ var _ = Describe("Packet packer", func() {
frameParser := wire.NewFrameParser(packer.version) frameParser := wire.NewFrameParser(packer.version)
frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f)) Expect(frame).To(BeAssignableToTypeOf(&wire.StreamFrame{}))
sf := frame.(*wire.StreamFrame)
Expect(sf.StreamID).To(Equal(f.StreamID))
Expect(sf.FinBit).To(Equal(f.FinBit))
Expect(sf.Data).To(BeEmpty())
Expect(r.Len()).To(BeZero()) Expect(r.Len()).To(BeZero())
}) })