Fix async buffer usage

This commit is contained in:
世界 2022-06-16 01:03:45 +08:00
parent ea019b3655
commit 9d88431791
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 31 additions and 8 deletions

View file

@ -21,8 +21,9 @@ type Buffer struct {
data []byte data []byte
start int start int
end int end int
managed bool
refs int32 refs int32
managed bool
closed bool
} }
func New() *Buffer { func New() *Buffer {
@ -97,6 +98,10 @@ func With(data []byte) *Buffer {
} }
} }
func (b *Buffer) Closed() bool {
return b.closed
}
func (b *Buffer) Byte(index int) byte { func (b *Buffer) Byte(index int) byte {
return b.data[b.start+index] return b.data[b.start+index]
} }
@ -339,14 +344,14 @@ func (b *Buffer) DecRef() {
} }
func (b *Buffer) Release() { func (b *Buffer) Release() {
if b == nil || b.data == nil || !b.managed { if b == nil || b.closed || !b.managed {
return return
} }
if atomic.LoadInt32(&b.refs) > 0 { if atomic.LoadInt32(&b.refs) > 0 {
return return
} }
common.Must(Put(b.data)) common.Must(Put(b.data))
*b = Buffer{} *b = Buffer{closed: true}
} }
func (b *Buffer) Cut(start int, end int) *Buffer { func (b *Buffer) Cut(start int, end int) *Buffer {

View file

@ -2,6 +2,7 @@ package bufio
import ( import (
"io" "io"
"os"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -21,10 +22,14 @@ func NewBufferedReader(upstream io.Reader, buffer *buf.Buffer) *BufferedReader {
} }
func (r *BufferedReader) Read(p []byte) (n int, err error) { func (r *BufferedReader) Read(p []byte) (n int, err error) {
if r.buffer.Closed() {
return 0, os.ErrClosed
}
if r.buffer.IsEmpty() { if r.buffer.IsEmpty() {
r.buffer.Reset() r.buffer.Reset()
err = r.upstream.ReadBuffer(r.buffer) err = r.upstream.ReadBuffer(r.buffer)
if err != nil { if err != nil {
r.buffer.Release()
return return
} }
} }
@ -32,21 +37,34 @@ func (r *BufferedReader) Read(p []byte) (n int, err error) {
} }
func (r *BufferedReader) ReadBuffer(buffer *buf.Buffer) error { func (r *BufferedReader) ReadBuffer(buffer *buf.Buffer) error {
if r.buffer.Closed() {
return os.ErrClosed
}
var err error
if r.buffer.IsEmpty() { if r.buffer.IsEmpty() {
r.buffer.Reset() r.buffer.Reset()
err := r.upstream.ReadBuffer(r.buffer) err = r.upstream.ReadBuffer(r.buffer)
if err != nil { if err != nil {
r.buffer.Release()
return err return err
} }
} }
if r.buffer.Len() > buffer.FreeLen() { if r.buffer.Len() > buffer.FreeLen() {
return common.Error(buffer.ReadFullFrom(r.buffer, buffer.FreeLen())) err = common.Error(buffer.ReadFullFrom(r.buffer, buffer.FreeLen()))
} else { } else {
return common.Error(buffer.ReadFullFrom(r.buffer, r.buffer.Len())) err = common.Error(buffer.ReadFullFrom(r.buffer, r.buffer.Len()))
} }
if err != nil {
r.buffer.Release()
}
return err
} }
func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) { func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
if r.buffer.Closed() {
return 0, os.ErrClosed
}
defer r.buffer.Release()
return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r.upstream), r.buffer) return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r.upstream), r.buffer)
} }

View file

@ -242,7 +242,7 @@ func CopyPacketWithSrcBufferTimeout(dest N.PacketWriter, src N.ThreadSafePacketR
func CopyPacketWithPool(dest N.PacketWriter, src N.PacketReader) (n int64, err error) { func CopyPacketWithPool(dest N.PacketWriter, src N.PacketReader) (n int64, err error) {
var destination M.Socksaddr var destination M.Socksaddr
for { for {
buffer := buf.New() buffer := buf.NewPacket()
destination, err = src.ReadPacket(buffer) destination, err = src.ReadPacket(buffer)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
@ -261,7 +261,7 @@ func CopyPacketWithPool(dest N.PacketWriter, src N.PacketReader) (n int64, err e
func CopyPacketWithPoolTimeout(dest N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) { func CopyPacketWithPoolTimeout(dest N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
var destination M.Socksaddr var destination M.Socksaddr
for { for {
buffer := buf.New() buffer := buf.NewPacket()
err = src.SetReadDeadline(time.Now().Add(timeout)) err = src.SetReadDeadline(time.Now().Add(timeout))
if err != nil { if err != nil {
return return