diff --git a/buffer_pool.go b/buffer_pool.go index c890d32b..204eff2b 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -22,6 +22,23 @@ func (b *packetBuffer) Split() { b.refCount++ } +// Release decreases the refCount. +// It should be called when processing the packet is finished. +// When the refCount reaches 0, the packet buffer is put back into the pool. +func (b *packetBuffer) Release() { + if cap(b.Slice) != int(protocol.MaxReceivePacketSize) { + panic("putPacketBuffer called with packet of wrong size!") + } + b.refCount-- + if b.refCount < 0 { + panic("negative packetBuffer refCount") + } + // only put the packetBuffer back if it's not used any more + if b.refCount == 0 { + bufferPool.Put(b) + } +} + var bufferPool sync.Pool func getPacketBuffer() *packetBuffer { @@ -31,20 +48,6 @@ func getPacketBuffer() *packetBuffer { return buf } -func putPacketBuffer(buf *packetBuffer) { - if cap(buf.Slice) != int(protocol.MaxReceivePacketSize) { - panic("putPacketBuffer called with packet of wrong size!") - } - buf.refCount-- - if buf.refCount < 0 { - panic("negative packetBuffer refCount") - } - // only put the packetBuffer back if it's not used any more - if buf.refCount == 0 { - bufferPool.Put(buf) - } -} - func init() { bufferPool.New = func() interface{} { return &packetBuffer{ diff --git a/buffer_pool_test.go b/buffer_pool_test.go index ef6b4852..c49141d9 100644 --- a/buffer_pool_test.go +++ b/buffer_pool_test.go @@ -13,31 +13,31 @@ var _ = Describe("Buffer Pool", func() { Expect(buf.Slice).To(HaveCap(int(protocol.MaxReceivePacketSize))) }) - It("puts buffers back", func() { + It("releases buffers", func() { buf := getPacketBuffer() - putPacketBuffer(buf) + buf.Release() }) It("panics if wrong-sized buffers are passed", func() { buf := getPacketBuffer() buf.Slice = make([]byte, 10) - Expect(func() { putPacketBuffer(buf) }).To(Panic()) + Expect(func() { buf.Release() }).To(Panic()) }) - It("panics if it is put pack twice", func() { + It("panics if it is released twice", func() { buf := getPacketBuffer() - putPacketBuffer(buf) - Expect(func() { putPacketBuffer(buf) }).To(Panic()) + buf.Release() + Expect(func() { buf.Release() }).To(Panic()) }) - It("waits until all parts have been put back", func() { + It("waits until all parts have been released", func() { buf := getPacketBuffer() buf.Split() buf.Split() // now we have 3 parts - putPacketBuffer(buf) - putPacketBuffer(buf) - putPacketBuffer(buf) - Expect(func() { putPacketBuffer(buf) }).To(Panic()) + buf.Release() + buf.Release() + buf.Release() + Expect(func() { buf.Release() }).To(Panic()) }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 5c2a1948..ba810b44 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -169,7 +169,7 @@ func (h *packetHandlerMap) handlePacket( // We still need to process the packets that were successfully parsed before. } if len(packets) == 0 { - putPacketBuffer(buffer) + buffer.Release() return } h.handleParsedPackets(packets) diff --git a/server.go b/server.go index c665283d..3fb1ca7a 100644 --- a/server.go +++ b/server.go @@ -321,20 +321,20 @@ func (s *server) handlePacket(p *receivedPacket) { return } - putPacketBuffer(p.buffer) // TODO(#943): send Stateless Reset + p.buffer.Release() } func (s *server) handleInitial(p *receivedPacket) { s.logger.Debugf("<- Received Initial packet.") sess, connID, err := s.handleInitialImpl(p) if err != nil { - putPacketBuffer(p.buffer) + p.buffer.Release() s.logger.Errorf("Error occurred handling initial packet: %s", err) return } if sess == nil { // a retry was done - putPacketBuffer(p.buffer) + p.buffer.Release() return } // Don't put the packet buffer back if a new session was created. @@ -461,7 +461,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { } func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { - defer putPacketBuffer(p.buffer) + defer p.buffer.Release() hdr := p.hdr s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) diff --git a/session.go b/session.go index 511a2b24..63aa3d9d 100644 --- a/session.go +++ b/session.go @@ -479,7 +479,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc defer func() { // Put back the packet buffer if the packet wasn't queued for later decryption. if !wasQueued { - putPacketBuffer(p.buffer) + p.buffer.Release() } }() @@ -962,7 +962,7 @@ func (s *session) sendPacket() (bool, error) { } func (s *session) sendPackedPacket(packet *packedPacket) error { - defer putPacketBuffer(packet.buffer) + defer packet.buffer.Release() s.logPacket(packet) return s.conn.Write(packet.raw) }