Add reserve support for buffer

This commit is contained in:
世界 2023-12-07 10:22:29 +08:00
parent c17babe0ba
commit 8b43ec8058
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 104 additions and 107 deletions

View file

@ -4,35 +4,36 @@ import (
"crypto/rand" "crypto/rand"
"io" "io"
"net" "net"
"strconv"
"sync/atomic"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/debug" "github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
) )
type Buffer struct { type Buffer struct {
data []byte data []byte
start int start int
end int end int
refs atomic.Int32 capacity int
managed bool refs atomic.Int32
closed bool managed bool
} }
func New() *Buffer { func New() *Buffer {
return &Buffer{ return &Buffer{
data: Get(BufferSize), data: Get(BufferSize),
managed: true, capacity: BufferSize,
managed: true,
} }
} }
func NewPacket() *Buffer { func NewPacket() *Buffer {
return &Buffer{ return &Buffer{
data: Get(UDPBufferSize), data: Get(UDPBufferSize),
managed: true, capacity: UDPBufferSize,
managed: true,
} }
} }
@ -41,40 +42,29 @@ func NewSize(size int) *Buffer {
return &Buffer{} return &Buffer{}
} else if size > 65535 { } else if size > 65535 {
return &Buffer{ return &Buffer{
data: make([]byte, size), data: make([]byte, size),
capacity: size,
} }
} }
return &Buffer{ return &Buffer{
data: Get(size), data: Get(size),
managed: true, capacity: size,
managed: true,
} }
} }
// Deprecated: use New instead.
func StackNew() *Buffer {
return New()
}
// Deprecated: use NewPacket instead.
func StackNewPacket() *Buffer {
return NewPacket()
}
// Deprecated: use NewSize instead.
func StackNewSize(size int) *Buffer {
return NewSize(size)
}
func As(data []byte) *Buffer { func As(data []byte) *Buffer {
return &Buffer{ return &Buffer{
data: data, data: data,
end: len(data), end: len(data),
capacity: len(data),
} }
} }
func With(data []byte) *Buffer { func With(data []byte) *Buffer {
return &Buffer{ return &Buffer{
data: data, data: data,
capacity: len(data),
} }
} }
@ -88,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) {
func (b *Buffer) Extend(n int) []byte { func (b *Buffer) Extend(n int) []byte {
end := b.end + n end := b.end + n
if end > cap(b.data) { if end > b.capacity {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n)) panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n))
} }
ext := b.data[b.end:end] ext := b.data[b.end:end]
b.end = end b.end = end
@ -111,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
if b.IsFull() { if b.IsFull() {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
n = copy(b.data[b.end:], data) n = copy(b.data[b.end:b.capacity], data)
b.end += n b.end += n
return return
} }
func (b *Buffer) ExtendHeader(n int) []byte { func (b *Buffer) ExtendHeader(n int) []byte {
if b.start < n { if b.start < n {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n)) panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n))
} }
b.start -= n b.start -= n
return b.data[b.start : b.start+n] return b.data[b.start : b.start+n]
@ -171,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
} }
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) { func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
if b.end+size > b.Cap() { if b.end+size > b.capacity {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
n, err = io.ReadFull(r, b.data[b.end:b.end+size]) n, err = io.ReadFull(r, b.data[b.end:b.end+size])
@ -208,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
if b.IsFull() { if b.IsFull() {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
n = copy(b.data[b.end:], s) n = copy(b.data[b.end:b.capacity], s)
b.end += n b.end += n
return return
} }
@ -223,7 +213,7 @@ func (b *Buffer) WriteZero() error {
} }
func (b *Buffer) WriteZeroN(n int) error { func (b *Buffer) WriteZeroN(n int) error {
if b.end+n > b.Cap() { if b.end+n > b.capacity {
return io.ErrShortBuffer return io.ErrShortBuffer
} }
for i := b.end; i < b.end+n; i++ { for i := b.end; i < b.end+n; i++ {
@ -272,9 +262,24 @@ func (b *Buffer) Resize(start, end int) {
b.end = b.start + end b.end = b.start + end
} }
func (b *Buffer) Reserve(n int) {
if n > b.capacity {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n))
}
b.capacity -= n
}
func (b *Buffer) OverCap(n int) {
if b.capacity+n > len(b.data) {
panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n))
}
b.capacity += n
}
func (b *Buffer) Reset() { func (b *Buffer) Reset() {
b.start = 0 b.start = 0
b.end = 0 b.end = 0
b.capacity = len(b.data)
} }
// Deprecated: use Reset instead. // Deprecated: use Reset instead.
@ -291,19 +296,19 @@ func (b *Buffer) DecRef() {
} }
func (b *Buffer) Release() { func (b *Buffer) Release() {
if b == nil || b.closed || !b.managed { if b == nil || !b.managed {
return return
} }
if b.refs.Load() > 0 { if b.refs.Load() > 0 {
return return
} }
common.Must(Put(b.data)) common.Must(Put(b.data))
*b = Buffer{closed: true} *b = Buffer{}
} }
func (b *Buffer) Leak() { func (b *Buffer) Leak() {
if debug.Enabled { if debug.Enabled {
if b == nil || b.closed || !b.managed { if b == nil || !b.managed {
return return
} }
refs := b.refs.Load() refs := b.refs.Load()
@ -317,14 +322,6 @@ func (b *Buffer) Leak() {
} }
} }
func (b *Buffer) Cut(start int, end int) *Buffer {
b.start += start
b.end = len(b.data) - end
return &Buffer{
data: b.data[b.start:b.end],
}
}
func (b *Buffer) Start() int { func (b *Buffer) Start() int {
return b.start return b.start
} }
@ -342,7 +339,7 @@ func (b *Buffer) Bytes() []byte {
} }
func (b *Buffer) Slice() []byte { func (b *Buffer) Slice() []byte {
return b.data return b.data[:b.capacity]
} }
func (b *Buffer) From(n int) []byte { func (b *Buffer) From(n int) []byte {
@ -362,11 +359,11 @@ func (b *Buffer) Index(start int) []byte {
} }
func (b *Buffer) FreeLen() int { func (b *Buffer) FreeLen() int {
return b.Cap() - b.end return b.capacity - b.end
} }
func (b *Buffer) FreeBytes() []byte { func (b *Buffer) FreeBytes() []byte {
return b.data[b.end:b.Cap()] return b.data[b.end:b.capacity]
} }
func (b *Buffer) IsEmpty() bool { func (b *Buffer) IsEmpty() bool {
@ -374,7 +371,7 @@ func (b *Buffer) IsEmpty() bool {
} }
func (b *Buffer) IsFull() bool { func (b *Buffer) IsFull() bool {
return b.end == b.Cap() return b.end == b.capacity
} }
func (b *Buffer) ToOwned() *Buffer { func (b *Buffer) ToOwned() *Buffer {
@ -382,5 +379,6 @@ func (b *Buffer) ToOwned() *Buffer {
copy(n.data[b.start:b.end], b.data[b.start:b.end]) copy(n.data[b.start:b.end], b.data[b.start:b.end])
n.start = b.start n.start = b.start
n.end = b.end n.end = b.end
n.capacity = b.capacity
return n return n
} }

View file

@ -81,12 +81,11 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
defer buffer.DecRef() defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(destination) frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination)
readBufferRaw := buffer.Slice() buffer.Resize(frontHeadroom, 0)
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) buffer.Reserve(rearHeadroom)
var notFirstTime bool var notFirstTime bool
for { for {
readBuffer.Resize(frontHeadroom, 0) err = source.ReadBuffer(buffer)
err = source.ReadBuffer(readBuffer)
if err != nil { if err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
err = nil err = nil
@ -94,8 +93,8 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
} }
return return
} }
dataLen := readBuffer.Len() dataLen := buffer.Len()
buffer.Resize(readBuffer.Start(), dataLen) buffer.OverCap(rearHeadroom)
err = destination.WriteBuffer(buffer) err = destination.WriteBuffer(buffer)
if err != nil { if err != nil {
if !notFirstTime { if !notFirstTime {
@ -126,10 +125,9 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
var notFirstTime bool var notFirstTime bool
for { for {
buffer := buf.NewSize(bufferSize) buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice() buffer.Resize(frontHeadroom, 0)
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) buffer.Reserve(rearHeadroom)
readBuffer.Resize(frontHeadroom, 0) err = source.ReadBuffer(buffer)
err = source.ReadBuffer(readBuffer)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
@ -138,8 +136,8 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
} }
return return
} }
dataLen := readBuffer.Len() dataLen := buffer.Len()
buffer.Resize(readBuffer.Start(), dataLen) buffer.OverCap(rearHeadroom)
err = destination.WriteBuffer(buffer) err = destination.WriteBuffer(buffer)
if err != nil { if err != nil {
buffer.Leak() buffer.Leak()
@ -263,16 +261,15 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
var destination M.Socksaddr var destination M.Socksaddr
for { for {
buffer := buf.NewSize(bufferSize) buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice() buffer.Resize(frontHeadroom, 0)
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) buffer.Reserve(rearHeadroom)
readBuffer.Resize(frontHeadroom, 0) destination, err = source.ReadPacket(buffer)
destination, err = source.ReadPacket(readBuffer)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return return
} }
dataLen := readBuffer.Len() dataLen := buffer.Len()
buffer.Resize(readBuffer.Start(), dataLen) buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, destination) err = destinationConn.WritePacket(buffer, destination)
if err != nil { if err != nil {
buffer.Leak() buffer.Leak()

View file

@ -104,22 +104,22 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options w.options = options
w.readFunc = func(fd uintptr) (done bool) { w.readFunc = func(fd uintptr) (done bool) {
buffer, readBuffer := w.options.NewBuffer() buffer := w.options.NewBuffer()
var readN int var readN int
readN, w.readErr = syscall.Read(int(fd), readBuffer.FreeBytes()) readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
if readN > 0 { if readN > 0 {
buffer.Resize(readBuffer.Start(), readN) buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else { } else {
buffer.Release() buffer.Release()
buffer = nil
} }
if w.readErr == syscall.EAGAIN { if w.readErr == syscall.EAGAIN {
return false return false
} }
if readN == 0 { if readN == 0 && w.readErr == nil {
w.readErr = io.EOF w.readErr = io.EOF
} }
w.buffer = buffer
return true return true
} }
return false return false
@ -168,15 +168,16 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options w.options = options
w.readFunc = func(fd uintptr) (done bool) { w.readFunc = func(fd uintptr) (done bool) {
buffer, readBuffer := w.options.NewPacketBuffer() buffer := w.options.NewPacketBuffer()
var readN int var readN int
var from syscall.Sockaddr var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), readBuffer.FreeBytes(), nil, 0) readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
if readN > 0 { if readN > 0 {
buffer.Resize(readBuffer.Start(), readN) buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else { } else {
buffer.Release() buffer.Release()
buffer = nil
} }
if w.readErr == syscall.EAGAIN { if w.readErr == syscall.EAGAIN {
return false return false
@ -189,7 +190,6 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap() w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
} }
} }
w.buffer = buffer
return true return true
} }
return false return false

View file

@ -19,30 +19,35 @@ func (o ReadWaitOptions) NeedHeadroom() bool {
return o.FrontHeadroom > 0 || o.RearHeadroom > 0 return o.FrontHeadroom > 0 || o.RearHeadroom > 0
} }
func (o ReadWaitOptions) NewBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) { func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
return o.newBuffer(buf.BufferSize) return o.newBuffer(buf.BufferSize)
} }
func (o ReadWaitOptions) NewPacketBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) { func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
return o.newBuffer(buf.UDPBufferSize) return o.newBuffer(buf.UDPBufferSize)
} }
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) (buffer *buf.Buffer, readBuffer *buf.Buffer) { func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
var bufferSize int var bufferSize int
if o.MTU > 0 { if o.MTU > 0 {
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
} else { } else {
bufferSize = defaultBufferSize bufferSize = defaultBufferSize
} }
buffer = buf.NewSize(bufferSize) buffer := buf.NewSize(bufferSize)
if o.RearHeadroom > 0 { if o.FrontHeadroom > 0 {
readBufferRaw := buffer.Slice() buffer.Resize(o.FrontHeadroom, 0)
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-o.RearHeadroom]) }
} else { if o.RearHeadroom > 0 {
readBuffer = buffer buffer.Reserve(o.RearHeadroom)
}
return buffer
}
func (o ReadWaitOptions) PostReturn(buffer *buf.Buffer) {
if o.RearHeadroom > 0 {
buffer.OverCap(o.RearHeadroom)
} }
readBuffer.Resize(o.FrontHeadroom, 0)
return
} }
type ReadWaiter interface { type ReadWaiter interface {

View file

@ -33,17 +33,16 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
case isClosedChan(p.readDeadline.wait()): case isClosedChan(p.readDeadline.wait()):
return nil, os.ErrDeadlineExceeded return nil, os.ErrDeadlineExceeded
} }
var readBuffer *buf.Buffer
select { select {
case bw := <-p.rdRx: case bw := <-p.rdRx:
buffer, readBuffer = p.readWaitOptions.NewBuffer() buffer = p.readWaitOptions.NewBuffer()
var nr int var nr int
nr, err = readBuffer.Write(bw) nr, err = buffer.Write(bw)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return return
} }
buffer.Resize(readBuffer.Start(), readBuffer.Len()) p.readWaitOptions.PostReturn(buffer)
p.rdTx <- nr p.rdTx <- nr
return return
case <-p.localDone: case <-p.localDone:

View file

@ -19,15 +19,14 @@ func (c *conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
select { select {
case p := <-c.data: case p := <-c.data:
if c.readWaitOptions.NeedHeadroom() { if c.readWaitOptions.NeedHeadroom() {
var readBuffer *buf.Buffer buffer = c.readWaitOptions.NewPacketBuffer()
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer() _, err = buffer.Write(p.data.Bytes())
_, err = readBuffer.Write(p.data.Bytes())
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return return
} }
c.readWaitOptions.PostReturn(buffer)
p.data.Release() p.data.Release()
buffer.Resize(readBuffer.Start(), readBuffer.Len())
} else { } else {
buffer = p.data buffer = p.data
} }

View file

@ -28,13 +28,12 @@ func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
if err != nil { if err != nil {
return return
} }
var readBuffer *buf.Buffer buffer = c.readWaitOptions.NewPacketBuffer()
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer() _, err = buffer.ReadFullFrom(c.Conn, int(length))
_, err = readBuffer.ReadFullFrom(c.Conn, int(length))
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return nil, M.Socksaddr{}, E.Cause(err, "UoT read") return nil, M.Socksaddr{}, E.Cause(err, "UoT read")
} }
buffer.Resize(readBuffer.Start(), readBuffer.Len()) c.readWaitOptions.PostReturn(buffer)
return return
} }