Merge ThreadSafeReader into ReadWaiter interface

This commit is contained in:
世界 2023-12-06 21:25:48 +08:00
parent 335ac47e45
commit 0e0da26649
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
13 changed files with 258 additions and 285 deletions

View file

@ -8,14 +8,16 @@ import (
"sync/atomic" "sync/atomic"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
) )
type Buffer struct { type Buffer struct {
data []byte data []byte
start int start int
end int end int
refs int32 refs atomic.Int32
managed bool managed bool
closed bool closed bool
} }
@ -281,24 +283,40 @@ func (b *Buffer) FullReset() {
} }
func (b *Buffer) IncRef() { func (b *Buffer) IncRef() {
atomic.AddInt32(&b.refs, 1) b.refs.Add(1)
} }
func (b *Buffer) DecRef() { func (b *Buffer) DecRef() {
atomic.AddInt32(&b.refs, -1) b.refs.Add(-1)
} }
func (b *Buffer) Release() { func (b *Buffer) Release() {
if b == nil || b.closed || !b.managed { if b == nil || b.closed || !b.managed {
return return
} }
if atomic.LoadInt32(&b.refs) > 0 { if b.refs.Load() > 0 {
return return
} }
common.Must(Put(b.data)) common.Must(Put(b.data))
*b = Buffer{closed: true} *b = Buffer{closed: true}
} }
func (b *Buffer) Leak() {
if debug.Enabled {
if b == nil || b.closed || !b.managed {
return
}
refs := b.refs.Load()
if refs == 0 {
panic("leaking buffer")
} else {
panic(F.ToString("leaking buffer with ", refs, " references"))
}
} else {
b.Release()
}
}
func (b *Buffer) Cut(start int, end int) *Buffer { func (b *Buffer) Cut(start int, end int) *Buffer {
b.start += start b.start += start
b.end = len(b.data) - end b.end = len(b.data) - end

View file

@ -12,8 +12,8 @@ type BindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter readWaiter N.PacketReadWaiter
} }
func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { func (w *BindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.readWaiter.InitializeReadWaiter(newBuffer) return w.readWaiter.InitializeReadWaiter(options)
} }
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
@ -28,8 +28,8 @@ type UnbindPacketReadWaiter struct {
addr M.Socksaddr addr M.Socksaddr
} }
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.readWaiter.InitializeReadWaiter(newBuffer) return w.readWaiter.InitializeReadWaiter(options)
} }
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {

View file

@ -5,7 +5,6 @@ import (
"errors" "errors"
"io" "io"
"net" "net"
"reflect"
"syscall" "syscall"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -57,21 +56,23 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
} }
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
safeSrc := N.IsSafeReader(source) frontHeadroom := N.CalculateFrontHeadroom(destination)
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination)
if safeSrc != nil {
if headroom == 0 {
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
}
}
readWaiter, isReadWaiter := CreateReadWaiter(source) readWaiter, isReadWaiter := CreateReadWaiter(source)
if isReadWaiter { if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destination),
})
if !needCopy || common.LowMemory {
var handled bool var handled bool
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters) handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
if handled { if handled {
return return
} }
} }
}
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
} }
@ -113,38 +114,6 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
} }
} }
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var notFirstTime bool
for {
var buffer *buf.Buffer
buffer, err = source.ReadBufferThreadSafe()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination) frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination)
@ -173,7 +142,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
buffer.Resize(readBuffer.Start(), dataLen) buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer) err = destination.WriteBuffer(buffer)
if err != nil { if err != nil {
buffer.Release() buffer.Leak()
if !notFirstTime { if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err) err = N.ReportHandshakeFailure(originSource, err)
} }
@ -256,69 +225,32 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return return
} }
} }
safeSrc := N.IsSafePacketReader(source)
frontHeadroom := N.CalculateFrontHeadroom(destinationConn) frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn)
headroom := frontHeadroom + rearHeadroom
if safeSrc != nil {
if headroom == 0 {
var copyN int64
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
n += copyN
return
}
}
var ( var (
handled bool handled bool
copeN int64 copeN int64
) )
readWaiter, isReadWaiter := CreatePacketReadWaiter(source) readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter { if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destinationConn),
})
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled { if handled {
n += copeN n += copeN
return return
} }
} }
}
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0) copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
n += copeN n += copeN
return return
} }
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
for {
buffer, destination, err = source.ReadPacketThreadSafe()
if err != nil {
return
}
if buffer == nil {
panic("nil buffer returned from " + reflect.TypeOf(source).String())
}
dataLen := buffer.Len()
if dataLen == 0 {
continue
}
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn) frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn)
@ -343,7 +275,7 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
buffer.Resize(readBuffer.Start(), dataLen) buffer.Resize(readBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, destination) err = destinationConn.WritePacket(buffer, destination)
if err != nil { if err != nil {
buffer.Release() buffer.Leak()
if !notFirstTime { if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err) err = N.ReportHandshakeFailure(originSource, err)
} }
@ -379,7 +311,7 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
buffer.Resize(readBuffer.Start(), dataLen) buffer.Resize(readBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, packetBuffer.Destination) err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
if err != nil { if err != nil {
buffer.Release() buffer.Leak()
if !notFirstTime { if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err) err = N.ReportHandshakeFailure(originSource, err)
} }

View file

@ -17,29 +17,12 @@ import (
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true handled = true
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
var ( var (
buffer *buf.Buffer buffer *buf.Buffer
readBuffer *buf.Buffer
notFirstTime bool notFirstTime bool
) )
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
for { for {
_, err = source.WaitReadBuffer() buffer, err = source.WaitReadBuffer()
if err != nil { if err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
err = nil err = nil
@ -47,11 +30,10 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
} }
return return
} }
dataLen := readBuffer.Len() dataLen := buffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer) err = destination.WriteBuffer(buffer)
if err != nil { if err != nil {
buffer.Release() buffer.Leak()
if !notFirstTime { if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err) err = N.ReportHandshakeFailure(originSource, err)
} }
@ -70,37 +52,19 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true handled = true
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var ( var (
buffer *buf.Buffer buffer *buf.Buffer
readBuffer *buf.Buffer
destination M.Socksaddr destination M.Socksaddr
) )
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
for { for {
_, destination, err = source.WaitReadPacket() buffer, destination, err = source.WaitReadPacket()
if err != nil { if err != nil {
return return
} }
dataLen := readBuffer.Len() dataLen := buffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, destination) err = destinationConn.WritePacket(buffer, destination)
if err != nil { if err != nil {
buffer.Release() buffer.Leak()
if !notFirstTime { if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err) err = N.ReportHandshakeFailure(originSource, err)
} }
@ -124,6 +88,7 @@ type syscallReadWaiter struct {
readErr error readErr error
readFunc func(fd uintptr) (done bool) readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer buffer *buf.Buffer
options N.ReadWaitOptions
} }
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
@ -136,17 +101,14 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
return nil, false return nil, false
} }
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.readErr = nil w.options = options
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) { w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer() buffer, readBuffer := w.options.NewBuffer()
var readN int var readN int
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes()) readN, w.readErr = syscall.Read(int(fd), readBuffer.FreeBytes())
if readN > 0 { if readN > 0 {
buffer.Truncate(readN) buffer.Resize(readBuffer.Start(), readN)
} else { } else {
buffer.Release() buffer.Release()
buffer = nil buffer = nil
@ -160,7 +122,7 @@ func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.buffer = buffer w.buffer = buffer
return true return true
} }
} return false
} }
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
@ -190,6 +152,7 @@ type syscallPacketReadWaiter struct {
readFrom M.Socksaddr readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool) readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer buffer *buf.Buffer
options N.ReadWaitOptions
} }
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) { func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
@ -202,19 +165,15 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
return nil, false return nil, false
} }
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.readErr = nil w.options = options
w.readFrom = M.Socksaddr{}
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) { w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer() buffer, readBuffer := w.options.NewPacketBuffer()
var readN int var readN int
var from syscall.Sockaddr var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0) readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), readBuffer.FreeBytes(), nil, 0)
if readN > 0 { if readN > 0 {
buffer.Truncate(readN) buffer.Resize(readBuffer.Start(), readN)
} else { } else {
buffer.Release() buffer.Release()
buffer = nil buffer = nil
@ -233,7 +192,7 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf
w.buffer = buffer w.buffer = buffer
return true return true
} }
} return false
} }
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {

View file

@ -5,8 +5,48 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
) )
type ReadWaitable interface {
InitializeReadWaiter(options ReadWaitOptions) (needCopy bool)
}
type ReadWaitOptions struct {
FrontHeadroom int
RearHeadroom int
MTU int
}
func (o ReadWaitOptions) NeedHeadroom() bool {
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
}
func (o ReadWaitOptions) NewBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) {
return o.newBuffer(buf.BufferSize)
}
func (o ReadWaitOptions) NewPacketBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) {
return o.newBuffer(buf.UDPBufferSize)
}
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) (buffer *buf.Buffer, readBuffer *buf.Buffer) {
var bufferSize int
if o.MTU > 0 {
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
} else {
bufferSize = defaultBufferSize
}
buffer = buf.NewSize(bufferSize)
if o.RearHeadroom > 0 {
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-o.RearHeadroom])
} else {
readBuffer = buffer
}
readBuffer.Resize(o.FrontHeadroom, 0)
return
}
type ReadWaiter interface { type ReadWaiter interface {
InitializeReadWaiter(newBuffer func() *buf.Buffer) ReadWaitable
WaitReadBuffer() (buffer *buf.Buffer, err error) WaitReadBuffer() (buffer *buf.Buffer, err error)
} }
@ -15,7 +55,7 @@ type ReadWaitCreator interface {
} }
type PacketReadWaiter interface { type PacketReadWaiter interface {
InitializeReadWaiter(newBuffer func() *buf.Buffer) ReadWaitable
WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error)
} }

View file

@ -10,11 +10,15 @@ type ThreadUnsafeWriter interface {
WriteIsThreadUnsafe() WriteIsThreadUnsafe()
} }
// Deprecated: Use ReadWaiter interface instead.
type ThreadSafeReader interface { type ThreadSafeReader interface {
// Deprecated: Use ReadWaiter interface instead.
ReadBufferThreadSafe() (buffer *buf.Buffer, err error) ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
} }
// Deprecated: Use ReadWaiter interface instead.
type ThreadSafePacketReader interface { type ThreadSafePacketReader interface {
// Deprecated: Use ReadWaiter interface instead.
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
} }
@ -23,6 +27,7 @@ func IsUnsafeWriter(writer any) bool {
return isUnsafe return isUnsafe
} }
// Deprecated: Use ReadWaiter interface instead.
func IsSafeReader(reader any) ThreadSafeReader { func IsSafeReader(reader any) ThreadSafeReader {
if safeReader, isSafe := reader.(ThreadSafeReader); isSafe { if safeReader, isSafe := reader.(ThreadSafeReader); isSafe {
return safeReader return safeReader
@ -39,6 +44,7 @@ func IsSafeReader(reader any) ThreadSafeReader {
return nil return nil
} }
// Deprecated: Use ReadWaiter interface instead.
func IsSafePacketReader(reader any) ThreadSafePacketReader { func IsSafePacketReader(reader any) ThreadSafePacketReader {
if safeReader, isSafe := reader.(ThreadSafePacketReader); isSafe { if safeReader, isSafe := reader.(ThreadSafePacketReader); isSafe {
return safeReader return safeReader

View file

@ -11,7 +11,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/sagernet/sing/common/buf" N "github.com/sagernet/sing/common/network"
) )
// pipeDeadline is an abstraction for handling timeouts. // pipeDeadline is an abstraction for handling timeouts.
@ -107,7 +107,7 @@ type pipe struct {
readDeadline pipeDeadline readDeadline pipeDeadline
writeDeadline pipeDeadline writeDeadline pipeDeadline
newBuffer func() *buf.Buffer readWaitOptions N.ReadWaitOptions
} }
// Pipe creates a synchronous, in-memory, full duplex // Pipe creates a synchronous, in-memory, full duplex

View file

@ -11,8 +11,9 @@ import (
var _ N.ReadWaiter = (*pipe)(nil) var _ N.ReadWaiter = (*pipe)(nil)
func (p *pipe) InitializeReadWaiter(newBuffer func() *buf.Buffer) { func (p *pipe) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
p.newBuffer = newBuffer p.readWaitOptions = options
return false
} }
func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) { func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) {
@ -32,15 +33,17 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
case isClosedChan(p.readDeadline.wait()): case isClosedChan(p.readDeadline.wait()):
return nil, os.ErrDeadlineExceeded return nil, os.ErrDeadlineExceeded
} }
var readBuffer *buf.Buffer
select { select {
case bw := <-p.rdRx: case bw := <-p.rdRx:
buffer = p.newBuffer() buffer, readBuffer = p.readWaitOptions.NewBuffer()
var nr int var nr int
nr, err = buffer.Write(bw) nr, err = readBuffer.Write(bw)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return return
} }
buffer.Resize(readBuffer.Start(), readBuffer.Len())
p.rdTx <- nr p.rdTx <- nr
return return
case <-p.localDone: case <-p.localDone:

View file

@ -0,0 +1,39 @@
package udpnat
import (
"io"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ N.PacketReadWaiter = (*conn)(nil)
func (c *conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case p := <-c.data:
if c.readWaitOptions.NeedHeadroom() {
var readBuffer *buf.Buffer
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer()
_, err = readBuffer.Write(p.data.Bytes())
if err != nil {
buffer.Release()
return
}
p.data.Release()
buffer.Resize(readBuffer.Start(), readBuffer.Len())
} else {
buffer = p.data
}
destination = p.destination
return
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
}

View file

@ -107,6 +107,8 @@ type packet struct {
destination M.Socksaddr destination M.Socksaddr
} }
var _ N.PacketConn = (*conn)(nil)
type conn struct { type conn struct {
ctx context.Context ctx context.Context
cancel common.ContextCancelCauseFunc cancel common.ContextCancelCauseFunc
@ -114,15 +116,7 @@ type conn struct {
localAddr M.Socksaddr localAddr M.Socksaddr
remoteAddr M.Socksaddr remoteAddr M.Socksaddr
source N.PacketWriter source N.PacketWriter
} readWaitOptions N.ReadWaitOptions
func (c *conn) ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) {
select {
case p := <-c.data:
return p.data, p.destination, nil
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
} }
func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
@ -136,37 +130,10 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
} }
} }
func (c *conn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
_, err = newBuffer().ReadOnceFrom(p.data)
p.data.Release()
return p.destination, err
case <-c.ctx.Done():
return M.Socksaddr{}, io.ErrClosedPipe
}
}
func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.source.WritePacket(buffer, destination) return c.source.WritePacket(buffer, destination)
} }
func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case pkt := <-c.data:
n = copy(p, pkt.data.Bytes())
pkt.data.Release()
addr = pkt.destination.UDPAddr()
return n, addr, nil
case <-c.ctx.Done():
return 0, nil, io.ErrClosedPipe
}
}
func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return len(p), c.source.WritePacket(buf.As(p).ToOwned(), M.SocksaddrFromNet(addr))
}
func (c *conn) Close() error { func (c *conn) Close() error {
select { select {
case <-c.ctx.Done(): case <-c.ctx.Done():

View file

@ -4,7 +4,6 @@ import (
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
"os"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -24,7 +23,7 @@ type Conn struct {
isConnect bool isConnect bool
destination M.Socksaddr destination M.Socksaddr
writer N.VectorisedWriter writer N.VectorisedWriter
newBuffer func() *buf.Buffer readWaitOptions N.ReadWaitOptions
} }
func NewConn(conn net.Conn, request Request) *Conn { func NewConn(conn net.Conn, request Request) *Conn {
@ -148,36 +147,6 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WriteVectorised([]*buf.Buffer{header, buffer}) return c.writer.WriteVectorised([]*buf.Buffer{header, buffer})
} }
func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
c.newBuffer = newBuffer
}
func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if c.newBuffer == nil {
return nil, M.Socksaddr{}, os.ErrInvalid
}
if c.isConnect {
destination = c.destination
} else {
destination, err = AddrParser.ReadAddrPort(c.Conn)
if err != nil {
return
}
}
var length uint16
err = binary.Read(c.Conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.newBuffer()
_, err = buffer.ReadFullFrom(c.Conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, E.Cause(err, "UoT read")
}
return
}
func (c *Conn) NeedAdditionalReadDeadline() bool { func (c *Conn) NeedAdditionalReadDeadline() bool {
return true return true
} }

40
common/uot/conn_wait.go Normal file
View file

@ -0,0 +1,40 @@
package uot
import (
"encoding/binary"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if c.isConnect {
destination = c.destination
} else {
destination, err = AddrParser.ReadAddrPort(c.Conn)
if err != nil {
return
}
}
var length uint16
err = binary.Read(c.Conn, binary.BigEndian, &length)
if err != nil {
return
}
var readBuffer *buf.Buffer
buffer, readBuffer = c.readWaitOptions.NewPacketBuffer()
_, err = readBuffer.ReadFullFrom(c.Conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, E.Cause(err, "UoT read")
}
buffer.Resize(readBuffer.Start(), readBuffer.Len())
return
}

View file

@ -24,8 +24,8 @@ type AssociatePacketReadWaiter struct {
readWaiter N.PacketReadWaiter readWaiter N.PacketReadWaiter
} }
func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { func (w *AssociatePacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.readWaiter.InitializeReadWaiter(newBuffer) return w.readWaiter.InitializeReadWaiter(options)
} }
func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {