From 33bf79c7357bf585ef9d4552b7da05f7034b5982 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 5 Mar 2019 13:14:46 +0900 Subject: [PATCH] fix packet buffer usage when handling coalesced packets --- buffer_pool.go | 34 ++++++++++++++++++++++++++-------- buffer_pool_test.go | 14 ++++++++++---- session.go | 3 ++- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/buffer_pool.go b/buffer_pool.go index 204eff2b..d6fb7673 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -22,23 +22,41 @@ func (b *packetBuffer) Split() { b.refCount++ } -// Release decreases the refCount. -// It should be called when processing the packet is finished. -// When the refCount reaches 0, the packet buffer is put back into the pool. -func (b *packetBuffer) Release() { - if cap(b.Slice) != int(protocol.MaxReceivePacketSize) { - panic("putPacketBuffer called with packet of wrong size!") - } +// Decrement decrements the reference counter. +// It doesn't put the buffer back into the pool. +func (b *packetBuffer) Decrement() { b.refCount-- if b.refCount < 0 { panic("negative packetBuffer refCount") } +} + +// MaybeRelease puts the packet buffer back into the pool, +// if the reference counter already reached 0. +func (b *packetBuffer) MaybeRelease() { // only put the packetBuffer back if it's not used any more if b.refCount == 0 { - bufferPool.Put(b) + b.putBack() } } +// Release puts back the packet buffer into the pool. +// It should be called when processing is definitely finished. +func (b *packetBuffer) Release() { + b.Decrement() + if b.refCount != 0 { + panic("packetBuffer refCount not zero") + } + b.putBack() +} + +func (b *packetBuffer) putBack() { + if cap(b.Slice) != int(protocol.MaxReceivePacketSize) { + panic("putPacketBuffer called with packet of wrong size!") + } + bufferPool.Put(b) +} + var bufferPool sync.Pool func getPacketBuffer() *packetBuffer { diff --git a/buffer_pool_test.go b/buffer_pool_test.go index c49141d9..3ee7037e 100644 --- a/buffer_pool_test.go +++ b/buffer_pool_test.go @@ -30,14 +30,20 @@ var _ = Describe("Buffer Pool", func() { Expect(func() { buf.Release() }).To(Panic()) }) + It("panics if it is decremented too many times", func() { + buf := getPacketBuffer() + buf.Decrement() + Expect(func() { buf.Decrement() }).To(Panic()) + }) + It("waits until all parts have been released", func() { buf := getPacketBuffer() buf.Split() buf.Split() // now we have 3 parts - buf.Release() - buf.Release() - buf.Release() - Expect(func() { buf.Release() }).To(Panic()) + buf.Decrement() + buf.Decrement() + buf.Decrement() + Expect(func() { buf.Decrement() }).To(Panic()) }) }) diff --git a/session.go b/session.go index feaadc29..c4f2309b 100644 --- a/session.go +++ b/session.go @@ -515,6 +515,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool { } p.data = rest } + p.buffer.MaybeRelease() return processed } @@ -524,7 +525,7 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool / defer func() { // Put back the packet buffer if the packet wasn't queued for later decryption. if !wasQueued { - p.buffer.Release() + p.buffer.Decrement() } }()