Add reserve support for buffer

This commit is contained in:
世界 2023-12-07 10:22:29 +08:00
parent c17babe0ba
commit 8b43ec8058
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 104 additions and 107 deletions

View file

@ -4,10 +4,9 @@ import (
"crypto/rand"
"io"
"net"
"strconv"
"sync/atomic"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
@ -17,14 +16,15 @@ type Buffer struct {
data []byte
start int
end int
capacity int
refs atomic.Int32
managed bool
closed bool
}
func New() *Buffer {
return &Buffer{
data: Get(BufferSize),
capacity: BufferSize,
managed: true,
}
}
@ -32,6 +32,7 @@ func New() *Buffer {
func NewPacket() *Buffer {
return &Buffer{
data: Get(UDPBufferSize),
capacity: UDPBufferSize,
managed: true,
}
}
@ -42,39 +43,28 @@ func NewSize(size int) *Buffer {
} else if size > 65535 {
return &Buffer{
data: make([]byte, size),
capacity: size,
}
}
return &Buffer{
data: Get(size),
capacity: size,
managed: true,
}
}
// Deprecated: use New instead.
func StackNew() *Buffer {
return New()
}
// Deprecated: use NewPacket instead.
func StackNewPacket() *Buffer {
return NewPacket()
}
// Deprecated: use NewSize instead.
func StackNewSize(size int) *Buffer {
return NewSize(size)
}
func As(data []byte) *Buffer {
return &Buffer{
data: data,
end: len(data),
capacity: len(data),
}
}
func With(data []byte) *Buffer {
return &Buffer{
data: data,
capacity: len(data),
}
}
@ -88,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) {
func (b *Buffer) Extend(n int) []byte {
end := b.end + n
if end > cap(b.data) {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n))
if end > b.capacity {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n))
}
ext := b.data[b.end:end]
b.end = end
@ -111,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:], data)
n = copy(b.data[b.end:b.capacity], data)
b.end += n
return
}
func (b *Buffer) ExtendHeader(n int) []byte {
if b.start < n {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n))
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n))
}
b.start -= n
return b.data[b.start : b.start+n]
@ -171,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
}
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
if b.end+size > b.Cap() {
if b.end+size > b.capacity {
return 0, io.ErrShortBuffer
}
n, err = io.ReadFull(r, b.data[b.end:b.end+size])
@ -208,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:], s)
n = copy(b.data[b.end:b.capacity], s)
b.end += n
return
}
@ -223,7 +213,7 @@ func (b *Buffer) WriteZero() error {
}
func (b *Buffer) WriteZeroN(n int) error {
if b.end+n > b.Cap() {
if b.end+n > b.capacity {
return io.ErrShortBuffer
}
for i := b.end; i < b.end+n; i++ {
@ -272,9 +262,24 @@ func (b *Buffer) Resize(start, end int) {
b.end = b.start + end
}
func (b *Buffer) Reserve(n int) {
if n > b.capacity {
panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n))
}
b.capacity -= n
}
func (b *Buffer) OverCap(n int) {
if b.capacity+n > len(b.data) {
panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n))
}
b.capacity += n
}
func (b *Buffer) Reset() {
b.start = 0
b.end = 0
b.capacity = len(b.data)
}
// Deprecated: use Reset instead.
@ -291,19 +296,19 @@ func (b *Buffer) DecRef() {
}
func (b *Buffer) Release() {
if b == nil || b.closed || !b.managed {
if b == nil || !b.managed {
return
}
if b.refs.Load() > 0 {
return
}
common.Must(Put(b.data))
*b = Buffer{closed: true}
*b = Buffer{}
}
func (b *Buffer) Leak() {
if debug.Enabled {
if b == nil || b.closed || !b.managed {
if b == nil || !b.managed {
return
}
refs := b.refs.Load()
@ -317,14 +322,6 @@ func (b *Buffer) Leak() {
}
}
func (b *Buffer) Cut(start int, end int) *Buffer {
b.start += start
b.end = len(b.data) - end
return &Buffer{
data: b.data[b.start:b.end],
}
}
func (b *Buffer) Start() int {
return b.start
}
@ -342,7 +339,7 @@ func (b *Buffer) Bytes() []byte {
}
func (b *Buffer) Slice() []byte {
return b.data
return b.data[:b.capacity]
}
func (b *Buffer) From(n int) []byte {
@ -362,11 +359,11 @@ func (b *Buffer) Index(start int) []byte {
}
func (b *Buffer) FreeLen() int {
return b.Cap() - b.end
return b.capacity - b.end
}
func (b *Buffer) FreeBytes() []byte {
return b.data[b.end:b.Cap()]
return b.data[b.end:b.capacity]
}
func (b *Buffer) IsEmpty() bool {
@ -374,7 +371,7 @@ func (b *Buffer) IsEmpty() bool {
}
func (b *Buffer) IsFull() bool {
return b.end == b.Cap()
return b.end == b.capacity
}
func (b *Buffer) ToOwned() *Buffer {
@ -382,5 +379,6 @@ func (b *Buffer) ToOwned() *Buffer {
copy(n.data[b.start:b.end], b.data[b.start:b.end])
n.start = b.start
n.end = b.end
n.capacity = b.capacity
return n
}

View file

@ -81,12 +81,11 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
var notFirstTime bool
for {
readBuffer.Resize(frontHeadroom, 0)
err = source.ReadBuffer(readBuffer)
err = source.ReadBuffer(buffer)
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
@ -94,8 +93,8 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destination.WriteBuffer(buffer)
if err != nil {
if !notFirstTime {
@ -126,10 +125,9 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = source.ReadBuffer(readBuffer)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
err = source.ReadBuffer(buffer)
if err != nil {
buffer.Release()
if errors.Is(err, io.EOF) {
@ -138,8 +136,8 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
@ -263,16 +261,15 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
var destination M.Socksaddr
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
destination, err = source.ReadPacket(readBuffer)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
destination, err = source.ReadPacket(buffer)
if err != nil {
buffer.Release()
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Leak()

View file

@ -104,22 +104,22 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer, readBuffer := w.options.NewBuffer()
buffer := w.options.NewBuffer()
var readN int
readN, w.readErr = syscall.Read(int(fd), readBuffer.FreeBytes())
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
if readN > 0 {
buffer.Resize(readBuffer.Start(), readN)
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
buffer = nil
}
if w.readErr == syscall.EAGAIN {
return false
}
if readN == 0 {
if readN == 0 && w.readErr == nil {
w.readErr = io.EOF
}
w.buffer = buffer
return true
}
return false
@ -168,15 +168,16 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer, readBuffer := w.options.NewPacketBuffer()
buffer := w.options.NewPacketBuffer()
var readN int
var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), readBuffer.FreeBytes(), nil, 0)
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
if readN > 0 {
buffer.Resize(readBuffer.Start(), readN)
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
buffer = nil
}
if w.readErr == syscall.EAGAIN {
return false
@ -189,7 +190,6 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
}
w.buffer = buffer
return true
}
return false

View file

@ -19,30 +19,35 @@ func (o ReadWaitOptions) NeedHeadroom() bool {
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
}
func (o ReadWaitOptions) NewBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) {
func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
return o.newBuffer(buf.BufferSize)
}
func (o ReadWaitOptions) NewPacketBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) {
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
return o.newBuffer(buf.UDPBufferSize)
}
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) (buffer *buf.Buffer, readBuffer *buf.Buffer) {
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
var bufferSize int
if o.MTU > 0 {
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
} else {
bufferSize = defaultBufferSize
}
buffer = buf.NewSize(bufferSize)
if o.RearHeadroom > 0 {
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-o.RearHeadroom])
} else {
readBuffer = buffer
buffer := buf.NewSize(bufferSize)
if o.FrontHeadroom > 0 {
buffer.Resize(o.FrontHeadroom, 0)
}
if o.RearHeadroom > 0 {
buffer.Reserve(o.RearHeadroom)
}
return buffer
}
func (o ReadWaitOptions) PostReturn(buffer *buf.Buffer) {
if o.RearHeadroom > 0 {
buffer.OverCap(o.RearHeadroom)
}
readBuffer.Resize(o.FrontHeadroom, 0)
return
}
type ReadWaiter interface {

View file

@ -33,17 +33,16 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
case isClosedChan(p.readDeadline.wait()):
return nil, os.ErrDeadlineExceeded
}
var readBuffer *buf.Buffer
select {
case bw := <-p.rdRx:
buffer, readBuffer = p.readWaitOptions.NewBuffer()
buffer = p.readWaitOptions.NewBuffer()
var nr int
nr, err = readBuffer.Write(bw)
nr, err = buffer.Write(bw)
if err != nil {
buffer.Release()
return
}
buffer.Resize(readBuffer.Start(), readBuffer.Len())
p.readWaitOptions.PostReturn(buffer)
p.rdTx <- nr
return
case <-p.localDone:

View file

@ -19,15 +19,14 @@ func (c *conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
select {
case p := <-c.data:
if c.readWaitOptions.NeedHeadroom() {
var readBuffer *buf.Buffer
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer()
_, err = readBuffer.Write(p.data.Bytes())
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.Write(p.data.Bytes())
if err != nil {
buffer.Release()
return
}
c.readWaitOptions.PostReturn(buffer)
p.data.Release()
buffer.Resize(readBuffer.Start(), readBuffer.Len())
} else {
buffer = p.data
}

View file

@ -28,13 +28,12 @@ func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
if err != nil {
return
}
var readBuffer *buf.Buffer
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer()
_, err = readBuffer.ReadFullFrom(c.Conn, int(length))
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.Conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, E.Cause(err, "UoT read")
}
buffer.Resize(readBuffer.Start(), readBuffer.Len())
c.readWaitOptions.PostReturn(buffer)
return
}