mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Add reserve support for buffer
This commit is contained in:
parent
c17babe0ba
commit
8b43ec8058
7 changed files with 104 additions and 107 deletions
|
@ -4,10 +4,9 @@ import (
|
|||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
|
@ -17,14 +16,15 @@ type Buffer struct {
|
|||
data []byte
|
||||
start int
|
||||
end int
|
||||
capacity int
|
||||
refs atomic.Int32
|
||||
managed bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func New() *Buffer {
|
||||
return &Buffer{
|
||||
data: Get(BufferSize),
|
||||
capacity: BufferSize,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ func New() *Buffer {
|
|||
func NewPacket() *Buffer {
|
||||
return &Buffer{
|
||||
data: Get(UDPBufferSize),
|
||||
capacity: UDPBufferSize,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
@ -42,39 +43,28 @@ func NewSize(size int) *Buffer {
|
|||
} else if size > 65535 {
|
||||
return &Buffer{
|
||||
data: make([]byte, size),
|
||||
capacity: size,
|
||||
}
|
||||
}
|
||||
return &Buffer{
|
||||
data: Get(size),
|
||||
capacity: size,
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: use New instead.
|
||||
func StackNew() *Buffer {
|
||||
return New()
|
||||
}
|
||||
|
||||
// Deprecated: use NewPacket instead.
|
||||
func StackNewPacket() *Buffer {
|
||||
return NewPacket()
|
||||
}
|
||||
|
||||
// Deprecated: use NewSize instead.
|
||||
func StackNewSize(size int) *Buffer {
|
||||
return NewSize(size)
|
||||
}
|
||||
|
||||
func As(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
end: len(data),
|
||||
capacity: len(data),
|
||||
}
|
||||
}
|
||||
|
||||
func With(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
capacity: len(data),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) {
|
|||
|
||||
func (b *Buffer) Extend(n int) []byte {
|
||||
end := b.end + n
|
||||
if end > cap(b.data) {
|
||||
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n))
|
||||
if end > b.capacity {
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",end ", b.end, ", need ", n))
|
||||
}
|
||||
ext := b.data[b.end:end]
|
||||
b.end = end
|
||||
|
@ -111,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
|
|||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n = copy(b.data[b.end:], data)
|
||||
n = copy(b.data[b.end:b.capacity], data)
|
||||
b.end += n
|
||||
return
|
||||
}
|
||||
|
||||
func (b *Buffer) ExtendHeader(n int) []byte {
|
||||
if b.start < n {
|
||||
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n))
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ",start ", b.start, ", need ", n))
|
||||
}
|
||||
b.start -= n
|
||||
return b.data[b.start : b.start+n]
|
||||
|
@ -171,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
|
|||
}
|
||||
|
||||
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
|
||||
if b.end+size > b.Cap() {
|
||||
if b.end+size > b.capacity {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n, err = io.ReadFull(r, b.data[b.end:b.end+size])
|
||||
|
@ -208,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
|
|||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n = copy(b.data[b.end:], s)
|
||||
n = copy(b.data[b.end:b.capacity], s)
|
||||
b.end += n
|
||||
return
|
||||
}
|
||||
|
@ -223,7 +213,7 @@ func (b *Buffer) WriteZero() error {
|
|||
}
|
||||
|
||||
func (b *Buffer) WriteZeroN(n int) error {
|
||||
if b.end+n > b.Cap() {
|
||||
if b.end+n > b.capacity {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
for i := b.end; i < b.end+n; i++ {
|
||||
|
@ -272,9 +262,24 @@ func (b *Buffer) Resize(start, end int) {
|
|||
b.end = b.start + end
|
||||
}
|
||||
|
||||
func (b *Buffer) Reserve(n int) {
|
||||
if n > b.capacity {
|
||||
panic(F.ToString("buffer overflow: capacity ", b.capacity, ", need ", n))
|
||||
}
|
||||
b.capacity -= n
|
||||
}
|
||||
|
||||
func (b *Buffer) OverCap(n int) {
|
||||
if b.capacity+n > len(b.data) {
|
||||
panic(F.ToString("buffer overflow: capacity ", len(b.data), ", need ", b.capacity+n))
|
||||
}
|
||||
b.capacity += n
|
||||
}
|
||||
|
||||
func (b *Buffer) Reset() {
|
||||
b.start = 0
|
||||
b.end = 0
|
||||
b.capacity = len(b.data)
|
||||
}
|
||||
|
||||
// Deprecated: use Reset instead.
|
||||
|
@ -291,19 +296,19 @@ func (b *Buffer) DecRef() {
|
|||
}
|
||||
|
||||
func (b *Buffer) Release() {
|
||||
if b == nil || b.closed || !b.managed {
|
||||
if b == nil || !b.managed {
|
||||
return
|
||||
}
|
||||
if b.refs.Load() > 0 {
|
||||
return
|
||||
}
|
||||
common.Must(Put(b.data))
|
||||
*b = Buffer{closed: true}
|
||||
*b = Buffer{}
|
||||
}
|
||||
|
||||
func (b *Buffer) Leak() {
|
||||
if debug.Enabled {
|
||||
if b == nil || b.closed || !b.managed {
|
||||
if b == nil || !b.managed {
|
||||
return
|
||||
}
|
||||
refs := b.refs.Load()
|
||||
|
@ -317,14 +322,6 @@ func (b *Buffer) Leak() {
|
|||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) Cut(start int, end int) *Buffer {
|
||||
b.start += start
|
||||
b.end = len(b.data) - end
|
||||
return &Buffer{
|
||||
data: b.data[b.start:b.end],
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) Start() int {
|
||||
return b.start
|
||||
}
|
||||
|
@ -342,7 +339,7 @@ func (b *Buffer) Bytes() []byte {
|
|||
}
|
||||
|
||||
func (b *Buffer) Slice() []byte {
|
||||
return b.data
|
||||
return b.data[:b.capacity]
|
||||
}
|
||||
|
||||
func (b *Buffer) From(n int) []byte {
|
||||
|
@ -362,11 +359,11 @@ func (b *Buffer) Index(start int) []byte {
|
|||
}
|
||||
|
||||
func (b *Buffer) FreeLen() int {
|
||||
return b.Cap() - b.end
|
||||
return b.capacity - b.end
|
||||
}
|
||||
|
||||
func (b *Buffer) FreeBytes() []byte {
|
||||
return b.data[b.end:b.Cap()]
|
||||
return b.data[b.end:b.capacity]
|
||||
}
|
||||
|
||||
func (b *Buffer) IsEmpty() bool {
|
||||
|
@ -374,7 +371,7 @@ func (b *Buffer) IsEmpty() bool {
|
|||
}
|
||||
|
||||
func (b *Buffer) IsFull() bool {
|
||||
return b.end == b.Cap()
|
||||
return b.end == b.capacity
|
||||
}
|
||||
|
||||
func (b *Buffer) ToOwned() *Buffer {
|
||||
|
@ -382,5 +379,6 @@ func (b *Buffer) ToOwned() *Buffer {
|
|||
copy(n.data[b.start:b.end], b.data[b.start:b.end])
|
||||
n.start = b.start
|
||||
n.end = b.end
|
||||
n.capacity = b.capacity
|
||||
return n
|
||||
}
|
||||
|
|
|
@ -81,12 +81,11 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
|
|||
defer buffer.DecRef()
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
buffer.Resize(frontHeadroom, 0)
|
||||
buffer.Reserve(rearHeadroom)
|
||||
var notFirstTime bool
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = source.ReadBuffer(readBuffer)
|
||||
err = source.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
|
@ -94,8 +93,8 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
|
|||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
dataLen := buffer.Len()
|
||||
buffer.OverCap(rearHeadroom)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
|
@ -126,10 +125,9 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
|
|||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = source.ReadBuffer(readBuffer)
|
||||
buffer.Resize(frontHeadroom, 0)
|
||||
buffer.Reserve(rearHeadroom)
|
||||
err = source.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if errors.Is(err, io.EOF) {
|
||||
|
@ -138,8 +136,8 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
|
|||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
dataLen := buffer.Len()
|
||||
buffer.OverCap(rearHeadroom)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
|
@ -263,16 +261,15 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
|
|||
var destination M.Socksaddr
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = source.ReadPacket(readBuffer)
|
||||
buffer.Resize(frontHeadroom, 0)
|
||||
buffer.Reserve(rearHeadroom)
|
||||
destination, err = source.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
dataLen := buffer.Len()
|
||||
buffer.OverCap(rearHeadroom)
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
|
|
|
@ -104,22 +104,22 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
|||
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer, readBuffer := w.options.NewBuffer()
|
||||
buffer := w.options.NewBuffer()
|
||||
var readN int
|
||||
readN, w.readErr = syscall.Read(int(fd), readBuffer.FreeBytes())
|
||||
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
|
||||
if readN > 0 {
|
||||
buffer.Resize(readBuffer.Start(), readN)
|
||||
buffer.Truncate(readN)
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if readN == 0 {
|
||||
if readN == 0 && w.readErr == nil {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
w.buffer = buffer
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -168,15 +168,16 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
|
|||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer, readBuffer := w.options.NewPacketBuffer()
|
||||
buffer := w.options.NewPacketBuffer()
|
||||
var readN int
|
||||
var from syscall.Sockaddr
|
||||
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), readBuffer.FreeBytes(), nil, 0)
|
||||
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
|
||||
if readN > 0 {
|
||||
buffer.Resize(readBuffer.Start(), readN)
|
||||
buffer.Truncate(readN)
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
|
@ -189,7 +190,6 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
|
|||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
}
|
||||
w.buffer = buffer
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
|
|
@ -19,30 +19,35 @@ func (o ReadWaitOptions) NeedHeadroom() bool {
|
|||
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
|
||||
}
|
||||
|
||||
func (o ReadWaitOptions) NewBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) {
|
||||
func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
|
||||
return o.newBuffer(buf.BufferSize)
|
||||
}
|
||||
|
||||
func (o ReadWaitOptions) NewPacketBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) {
|
||||
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
|
||||
return o.newBuffer(buf.UDPBufferSize)
|
||||
}
|
||||
|
||||
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) (buffer *buf.Buffer, readBuffer *buf.Buffer) {
|
||||
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
|
||||
var bufferSize int
|
||||
if o.MTU > 0 {
|
||||
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
|
||||
} else {
|
||||
bufferSize = defaultBufferSize
|
||||
}
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
if o.RearHeadroom > 0 {
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-o.RearHeadroom])
|
||||
} else {
|
||||
readBuffer = buffer
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
if o.FrontHeadroom > 0 {
|
||||
buffer.Resize(o.FrontHeadroom, 0)
|
||||
}
|
||||
if o.RearHeadroom > 0 {
|
||||
buffer.Reserve(o.RearHeadroom)
|
||||
}
|
||||
return buffer
|
||||
}
|
||||
|
||||
func (o ReadWaitOptions) PostReturn(buffer *buf.Buffer) {
|
||||
if o.RearHeadroom > 0 {
|
||||
buffer.OverCap(o.RearHeadroom)
|
||||
}
|
||||
readBuffer.Resize(o.FrontHeadroom, 0)
|
||||
return
|
||||
}
|
||||
|
||||
type ReadWaiter interface {
|
||||
|
|
|
@ -33,17 +33,16 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
|
|||
case isClosedChan(p.readDeadline.wait()):
|
||||
return nil, os.ErrDeadlineExceeded
|
||||
}
|
||||
var readBuffer *buf.Buffer
|
||||
select {
|
||||
case bw := <-p.rdRx:
|
||||
buffer, readBuffer = p.readWaitOptions.NewBuffer()
|
||||
buffer = p.readWaitOptions.NewBuffer()
|
||||
var nr int
|
||||
nr, err = readBuffer.Write(bw)
|
||||
nr, err = buffer.Write(bw)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
buffer.Resize(readBuffer.Start(), readBuffer.Len())
|
||||
p.readWaitOptions.PostReturn(buffer)
|
||||
p.rdTx <- nr
|
||||
return
|
||||
case <-p.localDone:
|
||||
|
|
|
@ -19,15 +19,14 @@ func (c *conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
|
|||
select {
|
||||
case p := <-c.data:
|
||||
if c.readWaitOptions.NeedHeadroom() {
|
||||
var readBuffer *buf.Buffer
|
||||
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = readBuffer.Write(p.data.Bytes())
|
||||
buffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = buffer.Write(p.data.Bytes())
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
c.readWaitOptions.PostReturn(buffer)
|
||||
p.data.Release()
|
||||
buffer.Resize(readBuffer.Start(), readBuffer.Len())
|
||||
} else {
|
||||
buffer = p.data
|
||||
}
|
||||
|
|
|
@ -28,13 +28,12 @@ func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
var readBuffer *buf.Buffer
|
||||
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = readBuffer.ReadFullFrom(c.Conn, int(length))
|
||||
buffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, int(length))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "UoT read")
|
||||
}
|
||||
buffer.Resize(readBuffer.Start(), readBuffer.Len())
|
||||
c.readWaitOptions.PostReturn(buffer)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue