mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 20:37:40 +03:00
Merge ThreadSafeReader into ReadWaiter interface
This commit is contained in:
parent
335ac47e45
commit
0e0da26649
13 changed files with 258 additions and 285 deletions
|
@ -8,14 +8,16 @@ import (
|
|||
"sync/atomic"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
)
|
||||
|
||||
type Buffer struct {
|
||||
data []byte
|
||||
start int
|
||||
end int
|
||||
refs int32
|
||||
refs atomic.Int32
|
||||
managed bool
|
||||
closed bool
|
||||
}
|
||||
|
@ -281,24 +283,40 @@ func (b *Buffer) FullReset() {
|
|||
}
|
||||
|
||||
func (b *Buffer) IncRef() {
|
||||
atomic.AddInt32(&b.refs, 1)
|
||||
b.refs.Add(1)
|
||||
}
|
||||
|
||||
func (b *Buffer) DecRef() {
|
||||
atomic.AddInt32(&b.refs, -1)
|
||||
b.refs.Add(-1)
|
||||
}
|
||||
|
||||
func (b *Buffer) Release() {
|
||||
if b == nil || b.closed || !b.managed {
|
||||
return
|
||||
}
|
||||
if atomic.LoadInt32(&b.refs) > 0 {
|
||||
if b.refs.Load() > 0 {
|
||||
return
|
||||
}
|
||||
common.Must(Put(b.data))
|
||||
*b = Buffer{closed: true}
|
||||
}
|
||||
|
||||
func (b *Buffer) Leak() {
|
||||
if debug.Enabled {
|
||||
if b == nil || b.closed || !b.managed {
|
||||
return
|
||||
}
|
||||
refs := b.refs.Load()
|
||||
if refs == 0 {
|
||||
panic("leaking buffer")
|
||||
} else {
|
||||
panic(F.ToString("leaking buffer with ", refs, " references"))
|
||||
}
|
||||
} else {
|
||||
b.Release()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) Cut(start int, end int) *Buffer {
|
||||
b.start += start
|
||||
b.end = len(b.data) - end
|
||||
|
|
|
@ -12,8 +12,8 @@ type BindPacketReadWaiter struct {
|
|||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readWaiter.InitializeReadWaiter(newBuffer)
|
||||
func (w *BindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
|
@ -28,8 +28,8 @@ type UnbindPacketReadWaiter struct {
|
|||
addr M.Socksaddr
|
||||
}
|
||||
|
||||
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readWaiter.InitializeReadWaiter(newBuffer)
|
||||
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -57,19 +56,21 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
|||
}
|
||||
|
||||
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
safeSrc := N.IsSafeReader(source)
|
||||
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
|
||||
}
|
||||
}
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
var handled bool
|
||||
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
FrontHeadroom: frontHeadroom,
|
||||
RearHeadroom: rearHeadroom,
|
||||
MTU: N.CalculateMTU(source, destination),
|
||||
})
|
||||
if !needCopy || common.LowMemory {
|
||||
var handled bool
|
||||
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
||||
|
@ -113,38 +114,6 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
|
|||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
var notFirstTime bool
|
||||
for {
|
||||
var buffer *buf.Buffer
|
||||
buffer, err = source.ReadBufferThreadSafe()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
|
@ -173,7 +142,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
|
|||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
|
@ -256,28 +225,25 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
|
|||
return
|
||||
}
|
||||
}
|
||||
safeSrc := N.IsSafePacketReader(source)
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
headroom := frontHeadroom + rearHeadroom
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
var copyN int64
|
||||
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
|
||||
n += copyN
|
||||
return
|
||||
}
|
||||
}
|
||||
var (
|
||||
handled bool
|
||||
copeN int64
|
||||
)
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
||||
if handled {
|
||||
n += copeN
|
||||
return
|
||||
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||
FrontHeadroom: frontHeadroom,
|
||||
RearHeadroom: rearHeadroom,
|
||||
MTU: N.CalculateMTU(source, destinationConn),
|
||||
})
|
||||
if !needCopy || common.LowMemory {
|
||||
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
||||
if handled {
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
|
||||
|
@ -285,40 +251,6 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
|
|||
return
|
||||
}
|
||||
|
||||
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||
var buffer *buf.Buffer
|
||||
var destination M.Socksaddr
|
||||
for {
|
||||
buffer, destination, err = source.ReadPacketThreadSafe()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if buffer == nil {
|
||||
panic("nil buffer returned from " + reflect.TypeOf(source).String())
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
if dataLen == 0 {
|
||||
continue
|
||||
}
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
|
@ -343,7 +275,7 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
|
|||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
|
@ -379,7 +311,7 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
|
|||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
|
|
|
@ -17,29 +17,12 @@ import (
|
|||
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
bufferSize := N.CalculateMTU(source, destination)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
notFirstTime bool
|
||||
)
|
||||
source.InitializeReadWaiter(func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
_, err = source.WaitReadBuffer()
|
||||
buffer, err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
|
@ -47,11 +30,10 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
|
|||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
|
@ -70,37 +52,19 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
|
|||
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
)
|
||||
source.InitializeReadWaiter(func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
_, destination, err = source.WaitReadPacket()
|
||||
buffer, destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
dataLen := buffer.Len()
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
|
@ -124,6 +88,7 @@ type syscallReadWaiter struct {
|
|||
readErr error
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||
|
@ -136,31 +101,28 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readErr = nil
|
||||
if newBuffer == nil {
|
||||
w.readFunc = nil
|
||||
} else {
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := newBuffer()
|
||||
var readN int
|
||||
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if readN == 0 {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
w.buffer = buffer
|
||||
return true
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer, readBuffer := w.options.NewBuffer()
|
||||
var readN int
|
||||
readN, w.readErr = syscall.Read(int(fd), readBuffer.FreeBytes())
|
||||
if readN > 0 {
|
||||
buffer.Resize(readBuffer.Start(), readN)
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if readN == 0 {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
w.buffer = buffer
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
|
@ -190,6 +152,7 @@ type syscallPacketReadWaiter struct {
|
|||
readFrom M.Socksaddr
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
options N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||
|
@ -202,38 +165,34 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readErr = nil
|
||||
w.readFrom = M.Socksaddr{}
|
||||
if newBuffer == nil {
|
||||
w.readFunc = nil
|
||||
} else {
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := newBuffer()
|
||||
var readN int
|
||||
var from syscall.Sockaddr
|
||||
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if from != nil {
|
||||
switch fromAddr := from.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *syscall.SockaddrInet6:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
}
|
||||
w.buffer = buffer
|
||||
return true
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
w.options = options
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer, readBuffer := w.options.NewPacketBuffer()
|
||||
var readN int
|
||||
var from syscall.Sockaddr
|
||||
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), readBuffer.FreeBytes(), nil, 0)
|
||||
if readN > 0 {
|
||||
buffer.Resize(readBuffer.Start(), readN)
|
||||
} else {
|
||||
buffer.Release()
|
||||
buffer = nil
|
||||
}
|
||||
if w.readErr == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
if from != nil {
|
||||
switch fromAddr := from.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *syscall.SockaddrInet6:
|
||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
}
|
||||
w.buffer = buffer
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
|
|
|
@ -5,8 +5,48 @@ import (
|
|||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type ReadWaitable interface {
|
||||
InitializeReadWaiter(options ReadWaitOptions) (needCopy bool)
|
||||
}
|
||||
|
||||
type ReadWaitOptions struct {
|
||||
FrontHeadroom int
|
||||
RearHeadroom int
|
||||
MTU int
|
||||
}
|
||||
|
||||
func (o ReadWaitOptions) NeedHeadroom() bool {
|
||||
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
|
||||
}
|
||||
|
||||
func (o ReadWaitOptions) NewBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) {
|
||||
return o.newBuffer(buf.BufferSize)
|
||||
}
|
||||
|
||||
func (o ReadWaitOptions) NewPacketBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) {
|
||||
return o.newBuffer(buf.UDPBufferSize)
|
||||
}
|
||||
|
||||
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) (buffer *buf.Buffer, readBuffer *buf.Buffer) {
|
||||
var bufferSize int
|
||||
if o.MTU > 0 {
|
||||
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
|
||||
} else {
|
||||
bufferSize = defaultBufferSize
|
||||
}
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
if o.RearHeadroom > 0 {
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-o.RearHeadroom])
|
||||
} else {
|
||||
readBuffer = buffer
|
||||
}
|
||||
readBuffer.Resize(o.FrontHeadroom, 0)
|
||||
return
|
||||
}
|
||||
|
||||
type ReadWaiter interface {
|
||||
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
||||
ReadWaitable
|
||||
WaitReadBuffer() (buffer *buf.Buffer, err error)
|
||||
}
|
||||
|
||||
|
@ -15,7 +55,7 @@ type ReadWaitCreator interface {
|
|||
}
|
||||
|
||||
type PacketReadWaiter interface {
|
||||
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
||||
ReadWaitable
|
||||
WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error)
|
||||
}
|
||||
|
||||
|
|
|
@ -10,11 +10,15 @@ type ThreadUnsafeWriter interface {
|
|||
WriteIsThreadUnsafe()
|
||||
}
|
||||
|
||||
// Deprecated: Use ReadWaiter interface instead.
|
||||
type ThreadSafeReader interface {
|
||||
// Deprecated: Use ReadWaiter interface instead.
|
||||
ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
|
||||
}
|
||||
|
||||
// Deprecated: Use ReadWaiter interface instead.
|
||||
type ThreadSafePacketReader interface {
|
||||
// Deprecated: Use ReadWaiter interface instead.
|
||||
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
|
||||
}
|
||||
|
||||
|
@ -23,6 +27,7 @@ func IsUnsafeWriter(writer any) bool {
|
|||
return isUnsafe
|
||||
}
|
||||
|
||||
// Deprecated: Use ReadWaiter interface instead.
|
||||
func IsSafeReader(reader any) ThreadSafeReader {
|
||||
if safeReader, isSafe := reader.(ThreadSafeReader); isSafe {
|
||||
return safeReader
|
||||
|
@ -39,6 +44,7 @@ func IsSafeReader(reader any) ThreadSafeReader {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Deprecated: Use ReadWaiter interface instead.
|
||||
func IsSafePacketReader(reader any) ThreadSafePacketReader {
|
||||
if safeReader, isSafe := reader.(ThreadSafePacketReader); isSafe {
|
||||
return safeReader
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
// pipeDeadline is an abstraction for handling timeouts.
|
||||
|
@ -107,7 +107,7 @@ type pipe struct {
|
|||
readDeadline pipeDeadline
|
||||
writeDeadline pipeDeadline
|
||||
|
||||
newBuffer func() *buf.Buffer
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
}
|
||||
|
||||
// Pipe creates a synchronous, in-memory, full duplex
|
||||
|
|
|
@ -11,8 +11,9 @@ import (
|
|||
|
||||
var _ N.ReadWaiter = (*pipe)(nil)
|
||||
|
||||
func (p *pipe) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
p.newBuffer = newBuffer
|
||||
func (p *pipe) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
p.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
|
@ -32,15 +33,17 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
|
|||
case isClosedChan(p.readDeadline.wait()):
|
||||
return nil, os.ErrDeadlineExceeded
|
||||
}
|
||||
var readBuffer *buf.Buffer
|
||||
select {
|
||||
case bw := <-p.rdRx:
|
||||
buffer = p.newBuffer()
|
||||
buffer, readBuffer = p.readWaitOptions.NewBuffer()
|
||||
var nr int
|
||||
nr, err = buffer.Write(bw)
|
||||
nr, err = readBuffer.Write(bw)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
buffer.Resize(readBuffer.Start(), readBuffer.Len())
|
||||
p.rdTx <- nr
|
||||
return
|
||||
case <-p.localDone:
|
||||
|
|
39
common/udpnat/conn_wait.go
Normal file
39
common/udpnat/conn_wait.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package udpnat
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.PacketReadWaiter = (*conn)(nil)
|
||||
|
||||
func (c *conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case p := <-c.data:
|
||||
if c.readWaitOptions.NeedHeadroom() {
|
||||
var readBuffer *buf.Buffer
|
||||
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = readBuffer.Write(p.data.Bytes())
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
p.data.Release()
|
||||
buffer.Resize(readBuffer.Start(), readBuffer.Len())
|
||||
} else {
|
||||
buffer = p.data
|
||||
}
|
||||
destination = p.destination
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
||||
}
|
||||
}
|
|
@ -107,22 +107,16 @@ type packet struct {
|
|||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
ctx context.Context
|
||||
cancel common.ContextCancelCauseFunc
|
||||
data chan packet
|
||||
localAddr M.Socksaddr
|
||||
remoteAddr M.Socksaddr
|
||||
source N.PacketWriter
|
||||
}
|
||||
var _ N.PacketConn = (*conn)(nil)
|
||||
|
||||
func (c *conn) ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) {
|
||||
select {
|
||||
case p := <-c.data:
|
||||
return p.data, p.destination, nil
|
||||
case <-c.ctx.Done():
|
||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
||||
}
|
||||
type conn struct {
|
||||
ctx context.Context
|
||||
cancel common.ContextCancelCauseFunc
|
||||
data chan packet
|
||||
localAddr M.Socksaddr
|
||||
remoteAddr M.Socksaddr
|
||||
source N.PacketWriter
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
|
||||
|
@ -136,37 +130,10 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *conn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case p := <-c.data:
|
||||
_, err = newBuffer().ReadOnceFrom(p.data)
|
||||
p.data.Release()
|
||||
return p.destination, err
|
||||
case <-c.ctx.Done():
|
||||
return M.Socksaddr{}, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
return c.source.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case pkt := <-c.data:
|
||||
n = copy(p, pkt.data.Bytes())
|
||||
pkt.data.Release()
|
||||
addr = pkt.destination.UDPAddr()
|
||||
return n, addr, nil
|
||||
case <-c.ctx.Done():
|
||||
return 0, nil, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return len(p), c.source.WritePacket(buf.As(p).ToOwned(), M.SocksaddrFromNet(addr))
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
|
@ -21,10 +20,10 @@ var (
|
|||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
isConnect bool
|
||||
destination M.Socksaddr
|
||||
writer N.VectorisedWriter
|
||||
newBuffer func() *buf.Buffer
|
||||
isConnect bool
|
||||
destination M.Socksaddr
|
||||
writer N.VectorisedWriter
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, request Request) *Conn {
|
||||
|
@ -148,36 +147,6 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|||
return c.writer.WriteVectorised([]*buf.Buffer{header, buffer})
|
||||
}
|
||||
|
||||
func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
c.newBuffer = newBuffer
|
||||
}
|
||||
|
||||
func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if c.newBuffer == nil {
|
||||
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
if c.isConnect {
|
||||
destination = c.destination
|
||||
} else {
|
||||
destination, err = AddrParser.ReadAddrPort(c.Conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.Conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buffer = c.newBuffer()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, int(length))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "UoT read")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) NeedAdditionalReadDeadline() bool {
|
||||
return true
|
||||
}
|
||||
|
|
40
common/uot/conn_wait.go
Normal file
40
common/uot/conn_wait.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package uot
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if c.isConnect {
|
||||
destination = c.destination
|
||||
} else {
|
||||
destination, err = AddrParser.ReadAddrPort(c.Conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.Conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var readBuffer *buf.Buffer
|
||||
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = readBuffer.ReadFullFrom(c.Conn, int(length))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "UoT read")
|
||||
}
|
||||
buffer.Resize(readBuffer.Start(), readBuffer.Len())
|
||||
return
|
||||
}
|
|
@ -24,8 +24,8 @@ type AssociatePacketReadWaiter struct {
|
|||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readWaiter.InitializeReadWaiter(newBuffer)
|
||||
func (w *AssociatePacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
return w.readWaiter.InitializeReadWaiter(options)
|
||||
}
|
||||
|
||||
func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue