From c17babe0ba3504773f4d7129552d692c52b38887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Dec 2023 21:25:48 +0800 Subject: [PATCH] Merge ThreadSafeReader into ReadWaiter interface --- common/buf/buffer.go | 26 +++++- common/bufio/bind_wait.go | 8 +- common/bufio/copy.go | 122 ++++++------------------ common/bufio/copy_direct_posix.go | 149 +++++++++++------------------- common/network/direct.go | 44 ++++++++- common/network/thread.go | 6 ++ common/pipe/pipe.go | 4 +- common/pipe/pipe_wait.go | 11 ++- common/udpnat/conn_wait.go | 39 ++++++++ common/udpnat/service.go | 51 ++-------- common/uot/conn.go | 39 +------- common/uot/conn_wait.go | 40 ++++++++ protocol/socks/packet_wait.go | 4 +- 13 files changed, 258 insertions(+), 285 deletions(-) create mode 100644 common/udpnat/conn_wait.go create mode 100644 common/uot/conn_wait.go diff --git a/common/buf/buffer.go b/common/buf/buffer.go index d0f4ba4..79714de 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -8,14 +8,16 @@ import ( "sync/atomic" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/debug" E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" ) type Buffer struct { data []byte start int end int - refs int32 + refs atomic.Int32 managed bool closed bool } @@ -281,24 +283,40 @@ func (b *Buffer) FullReset() { } func (b *Buffer) IncRef() { - atomic.AddInt32(&b.refs, 1) + b.refs.Add(1) } func (b *Buffer) DecRef() { - atomic.AddInt32(&b.refs, -1) + b.refs.Add(-1) } func (b *Buffer) Release() { if b == nil || b.closed || !b.managed { return } - if atomic.LoadInt32(&b.refs) > 0 { + if b.refs.Load() > 0 { return } common.Must(Put(b.data)) *b = Buffer{closed: true} } +func (b *Buffer) Leak() { + if debug.Enabled { + if b == nil || b.closed || !b.managed { + return + } + refs := b.refs.Load() + if refs == 0 { + panic("leaking buffer") + } else { + panic(F.ToString("leaking buffer with ", refs, " references")) + } + } else { + b.Release() + } +} + func (b *Buffer) Cut(start int, end int) *Buffer { b.start += start b.end = len(b.data) - end diff --git a/common/bufio/bind_wait.go b/common/bufio/bind_wait.go index 724a76e..1396552 100644 --- a/common/bufio/bind_wait.go +++ b/common/bufio/bind_wait.go @@ -12,8 +12,8 @@ type BindPacketReadWaiter struct { readWaiter N.PacketReadWaiter } -func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readWaiter.InitializeReadWaiter(newBuffer) +func (w *BindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) } func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { @@ -28,8 +28,8 @@ type UnbindPacketReadWaiter struct { addr M.Socksaddr } -func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readWaiter.InitializeReadWaiter(newBuffer) +func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) } func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 3bdb164..53c6e2e 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -5,7 +5,6 @@ import ( "errors" "io" "net" - "reflect" "syscall" "github.com/sagernet/sing/common" @@ -57,19 +56,21 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { } func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { - safeSrc := N.IsSafeReader(source) - headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination) - if safeSrc != nil { - if headroom == 0 { - return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters) - } - } + frontHeadroom := N.CalculateFrontHeadroom(destination) + rearHeadroom := N.CalculateRearHeadroom(destination) readWaiter, isReadWaiter := CreateReadWaiter(source) if isReadWaiter { - var handled bool - handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters) - if handled { - return + needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ + FrontHeadroom: frontHeadroom, + RearHeadroom: rearHeadroom, + MTU: N.CalculateMTU(source, destination), + }) + if !needCopy || common.LowMemory { + var handled bool + handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters) + if handled { + return + } } } return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) @@ -113,38 +114,6 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so } } -func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { - var notFirstTime bool - for { - var buffer *buf.Buffer - buffer, err = source.ReadBufferThreadSafe() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - return - } - return - } - dataLen := buffer.Len() - err = destination.WriteBuffer(buffer) - if err != nil { - buffer.Release() - if !notFirstTime { - err = N.ReportHandshakeFailure(originSource, err) - } - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true - } -} - func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination) @@ -173,7 +142,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, buffer.Resize(readBuffer.Start(), dataLen) err = destination.WriteBuffer(buffer) if err != nil { - buffer.Release() + buffer.Leak() if !notFirstTime { err = N.ReportHandshakeFailure(originSource, err) } @@ -256,28 +225,25 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, return } } - safeSrc := N.IsSafePacketReader(source) frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) - headroom := frontHeadroom + rearHeadroom - if safeSrc != nil { - if headroom == 0 { - var copyN int64 - copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0) - n += copyN - return - } - } var ( handled bool copeN int64 ) readWaiter, isReadWaiter := CreatePacketReadWaiter(source) if isReadWaiter { - handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) - if handled { - n += copeN - return + needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ + FrontHeadroom: frontHeadroom, + RearHeadroom: rearHeadroom, + MTU: N.CalculateMTU(source, destinationConn), + }) + if !needCopy || common.LowMemory { + handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) + if handled { + n += copeN + return + } } } copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0) @@ -285,40 +251,6 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, return } -func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { - var buffer *buf.Buffer - var destination M.Socksaddr - for { - buffer, destination, err = source.ReadPacketThreadSafe() - if err != nil { - return - } - if buffer == nil { - panic("nil buffer returned from " + reflect.TypeOf(source).String()) - } - dataLen := buffer.Len() - if dataLen == 0 { - continue - } - err = destinationConn.WritePacket(buffer, destination) - if err != nil { - buffer.Release() - if !notFirstTime { - err = N.ReportHandshakeFailure(originSource, err) - } - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true - } -} - func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) @@ -343,7 +275,7 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri buffer.Resize(readBuffer.Start(), dataLen) err = destinationConn.WritePacket(buffer, destination) if err != nil { - buffer.Release() + buffer.Leak() if !notFirstTime { err = N.ReportHandshakeFailure(originSource, err) } @@ -379,7 +311,7 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr buffer.Resize(readBuffer.Start(), dataLen) err = destinationConn.WritePacket(buffer, packetBuffer.Destination) if err != nil { - buffer.Release() + buffer.Leak() if !notFirstTime { err = N.ReportHandshakeFailure(originSource, err) } diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index 06da27d..0750678 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -17,29 +17,12 @@ import ( func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { handled = true - frontHeadroom := N.CalculateFrontHeadroom(destination) - rearHeadroom := N.CalculateRearHeadroom(destination) - bufferSize := N.CalculateMTU(source, destination) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } var ( buffer *buf.Buffer - readBuffer *buf.Buffer notFirstTime bool ) - source.InitializeReadWaiter(func() *buf.Buffer { - buffer = buf.NewSize(bufferSize) - readBufferRaw := buffer.Slice() - readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - return readBuffer - }) - defer source.InitializeReadWaiter(nil) for { - _, err = source.WaitReadBuffer() + buffer, err = source.WaitReadBuffer() if err != nil { if errors.Is(err, io.EOF) { err = nil @@ -47,11 +30,10 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour } return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) + dataLen := buffer.Len() err = destination.WriteBuffer(buffer) if err != nil { - buffer.Release() + buffer.Leak() if !notFirstTime { err = N.ReportHandshakeFailure(originSource, err) } @@ -70,37 +52,19 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { handled = true - frontHeadroom := N.CalculateFrontHeadroom(destinationConn) - rearHeadroom := N.CalculateRearHeadroom(destinationConn) - bufferSize := N.CalculateMTU(source, destinationConn) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.UDPBufferSize - } var ( buffer *buf.Buffer - readBuffer *buf.Buffer destination M.Socksaddr ) - source.InitializeReadWaiter(func() *buf.Buffer { - buffer = buf.NewSize(bufferSize) - readBufferRaw := buffer.Slice() - readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - return readBuffer - }) - defer source.InitializeReadWaiter(nil) for { - _, destination, err = source.WaitReadPacket() + buffer, destination, err = source.WaitReadPacket() if err != nil { return } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) + dataLen := buffer.Len() err = destinationConn.WritePacket(buffer, destination) if err != nil { - buffer.Release() + buffer.Leak() if !notFirstTime { err = N.ReportHandshakeFailure(originSource, err) } @@ -124,6 +88,7 @@ type syscallReadWaiter struct { readErr error readFunc func(fd uintptr) (done bool) buffer *buf.Buffer + options N.ReadWaitOptions } func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { @@ -136,31 +101,28 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { return nil, false } -func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readErr = nil - if newBuffer == nil { - w.readFunc = nil - } else { - w.readFunc = func(fd uintptr) (done bool) { - buffer := newBuffer() - var readN int - readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes()) - if readN > 0 { - buffer.Truncate(readN) - } else { - buffer.Release() - buffer = nil - } - if w.readErr == syscall.EAGAIN { - return false - } - if readN == 0 { - w.readErr = io.EOF - } - w.buffer = buffer - return true +func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + w.options = options + w.readFunc = func(fd uintptr) (done bool) { + buffer, readBuffer := w.options.NewBuffer() + var readN int + readN, w.readErr = syscall.Read(int(fd), readBuffer.FreeBytes()) + if readN > 0 { + buffer.Resize(readBuffer.Start(), readN) + } else { + buffer.Release() + buffer = nil } + if w.readErr == syscall.EAGAIN { + return false + } + if readN == 0 { + w.readErr = io.EOF + } + w.buffer = buffer + return true } + return false } func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { @@ -190,6 +152,7 @@ type syscallPacketReadWaiter struct { readFrom M.Socksaddr readFunc func(fd uintptr) (done bool) buffer *buf.Buffer + options N.ReadWaitOptions } func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) { @@ -202,38 +165,34 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) return nil, false } -func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readErr = nil - w.readFrom = M.Socksaddr{} - if newBuffer == nil { - w.readFunc = nil - } else { - w.readFunc = func(fd uintptr) (done bool) { - buffer := newBuffer() - var readN int - var from syscall.Sockaddr - readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0) - if readN > 0 { - buffer.Truncate(readN) - } else { - buffer.Release() - buffer = nil - } - if w.readErr == syscall.EAGAIN { - return false - } - if from != nil { - switch fromAddr := from.(type) { - case *syscall.SockaddrInet4: - w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port)) - case *syscall.SockaddrInet6: - w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap() - } - } - w.buffer = buffer - return true +func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + w.options = options + w.readFunc = func(fd uintptr) (done bool) { + buffer, readBuffer := w.options.NewPacketBuffer() + var readN int + var from syscall.Sockaddr + readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), readBuffer.FreeBytes(), nil, 0) + if readN > 0 { + buffer.Resize(readBuffer.Start(), readN) + } else { + buffer.Release() + buffer = nil } + if w.readErr == syscall.EAGAIN { + return false + } + if from != nil { + switch fromAddr := from.(type) { + case *syscall.SockaddrInet4: + w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port)) + case *syscall.SockaddrInet6: + w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap() + } + } + w.buffer = buffer + return true } + return false } func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { diff --git a/common/network/direct.go b/common/network/direct.go index b645676..7a28eca 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -5,8 +5,48 @@ import ( M "github.com/sagernet/sing/common/metadata" ) +type ReadWaitable interface { + InitializeReadWaiter(options ReadWaitOptions) (needCopy bool) +} + +type ReadWaitOptions struct { + FrontHeadroom int + RearHeadroom int + MTU int +} + +func (o ReadWaitOptions) NeedHeadroom() bool { + return o.FrontHeadroom > 0 || o.RearHeadroom > 0 +} + +func (o ReadWaitOptions) NewBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) { + return o.newBuffer(buf.BufferSize) +} + +func (o ReadWaitOptions) NewPacketBuffer() (buffer *buf.Buffer, readBuffer *buf.Buffer) { + return o.newBuffer(buf.UDPBufferSize) +} + +func (o ReadWaitOptions) newBuffer(defaultBufferSize int) (buffer *buf.Buffer, readBuffer *buf.Buffer) { + var bufferSize int + if o.MTU > 0 { + bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom + } else { + bufferSize = defaultBufferSize + } + buffer = buf.NewSize(bufferSize) + if o.RearHeadroom > 0 { + readBufferRaw := buffer.Slice() + readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-o.RearHeadroom]) + } else { + readBuffer = buffer + } + readBuffer.Resize(o.FrontHeadroom, 0) + return +} + type ReadWaiter interface { - InitializeReadWaiter(newBuffer func() *buf.Buffer) + ReadWaitable WaitReadBuffer() (buffer *buf.Buffer, err error) } @@ -15,7 +55,7 @@ type ReadWaitCreator interface { } type PacketReadWaiter interface { - InitializeReadWaiter(newBuffer func() *buf.Buffer) + ReadWaitable WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) } diff --git a/common/network/thread.go b/common/network/thread.go index a492fdd..58ccebb 100644 --- a/common/network/thread.go +++ b/common/network/thread.go @@ -10,11 +10,15 @@ type ThreadUnsafeWriter interface { WriteIsThreadUnsafe() } +// Deprecated: Use ReadWaiter interface instead. type ThreadSafeReader interface { + // Deprecated: Use ReadWaiter interface instead. ReadBufferThreadSafe() (buffer *buf.Buffer, err error) } +// Deprecated: Use ReadWaiter interface instead. type ThreadSafePacketReader interface { + // Deprecated: Use ReadWaiter interface instead. ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) } @@ -23,6 +27,7 @@ func IsUnsafeWriter(writer any) bool { return isUnsafe } +// Deprecated: Use ReadWaiter interface instead. func IsSafeReader(reader any) ThreadSafeReader { if safeReader, isSafe := reader.(ThreadSafeReader); isSafe { return safeReader @@ -39,6 +44,7 @@ func IsSafeReader(reader any) ThreadSafeReader { return nil } +// Deprecated: Use ReadWaiter interface instead. func IsSafePacketReader(reader any) ThreadSafePacketReader { if safeReader, isSafe := reader.(ThreadSafePacketReader); isSafe { return safeReader diff --git a/common/pipe/pipe.go b/common/pipe/pipe.go index 2f87ae3..915b840 100644 --- a/common/pipe/pipe.go +++ b/common/pipe/pipe.go @@ -11,7 +11,7 @@ import ( "sync" "time" - "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" ) // pipeDeadline is an abstraction for handling timeouts. @@ -107,7 +107,7 @@ type pipe struct { readDeadline pipeDeadline writeDeadline pipeDeadline - newBuffer func() *buf.Buffer + readWaitOptions N.ReadWaitOptions } // Pipe creates a synchronous, in-memory, full duplex diff --git a/common/pipe/pipe_wait.go b/common/pipe/pipe_wait.go index 27ea27b..e1a2413 100644 --- a/common/pipe/pipe_wait.go +++ b/common/pipe/pipe_wait.go @@ -11,8 +11,9 @@ import ( var _ N.ReadWaiter = (*pipe)(nil) -func (p *pipe) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - p.newBuffer = newBuffer +func (p *pipe) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + p.readWaitOptions = options + return false } func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) { @@ -32,15 +33,17 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) { case isClosedChan(p.readDeadline.wait()): return nil, os.ErrDeadlineExceeded } + var readBuffer *buf.Buffer select { case bw := <-p.rdRx: - buffer = p.newBuffer() + buffer, readBuffer = p.readWaitOptions.NewBuffer() var nr int - nr, err = buffer.Write(bw) + nr, err = readBuffer.Write(bw) if err != nil { buffer.Release() return } + buffer.Resize(readBuffer.Start(), readBuffer.Len()) p.rdTx <- nr return case <-p.localDone: diff --git a/common/udpnat/conn_wait.go b/common/udpnat/conn_wait.go new file mode 100644 index 0000000..5dde531 --- /dev/null +++ b/common/udpnat/conn_wait.go @@ -0,0 +1,39 @@ +package udpnat + +import ( + "io" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ N.PacketReadWaiter = (*conn)(nil) + +func (c *conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case p := <-c.data: + if c.readWaitOptions.NeedHeadroom() { + var readBuffer *buf.Buffer + buffer, readBuffer = c.readWaitOptions.NewPacketBuffer() + _, err = readBuffer.Write(p.data.Bytes()) + if err != nil { + buffer.Release() + return + } + p.data.Release() + buffer.Resize(readBuffer.Start(), readBuffer.Len()) + } else { + buffer = p.data + } + destination = p.destination + return + case <-c.ctx.Done(): + return nil, M.Socksaddr{}, io.ErrClosedPipe + } +} diff --git a/common/udpnat/service.go b/common/udpnat/service.go index f43dc92..100a390 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -107,22 +107,16 @@ type packet struct { destination M.Socksaddr } -type conn struct { - ctx context.Context - cancel common.ContextCancelCauseFunc - data chan packet - localAddr M.Socksaddr - remoteAddr M.Socksaddr - source N.PacketWriter -} +var _ N.PacketConn = (*conn)(nil) -func (c *conn) ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) { - select { - case p := <-c.data: - return p.data, p.destination, nil - case <-c.ctx.Done(): - return nil, M.Socksaddr{}, io.ErrClosedPipe - } +type conn struct { + ctx context.Context + cancel common.ContextCancelCauseFunc + data chan packet + localAddr M.Socksaddr + remoteAddr M.Socksaddr + source N.PacketWriter + readWaitOptions N.ReadWaitOptions } func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { @@ -136,37 +130,10 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { } } -func (c *conn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { - select { - case p := <-c.data: - _, err = newBuffer().ReadOnceFrom(p.data) - p.data.Release() - return p.destination, err - case <-c.ctx.Done(): - return M.Socksaddr{}, io.ErrClosedPipe - } -} - func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.source.WritePacket(buffer, destination) } -func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - select { - case pkt := <-c.data: - n = copy(p, pkt.data.Bytes()) - pkt.data.Release() - addr = pkt.destination.UDPAddr() - return n, addr, nil - case <-c.ctx.Done(): - return 0, nil, io.ErrClosedPipe - } -} - -func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return len(p), c.source.WritePacket(buf.As(p).ToOwned(), M.SocksaddrFromNet(addr)) -} - func (c *conn) Close() error { select { case <-c.ctx.Done(): diff --git a/common/uot/conn.go b/common/uot/conn.go index cf289d5..8e085be 100644 --- a/common/uot/conn.go +++ b/common/uot/conn.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "io" "net" - "os" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -21,10 +20,10 @@ var ( type Conn struct { net.Conn - isConnect bool - destination M.Socksaddr - writer N.VectorisedWriter - newBuffer func() *buf.Buffer + isConnect bool + destination M.Socksaddr + writer N.VectorisedWriter + readWaitOptions N.ReadWaitOptions } func NewConn(conn net.Conn, request Request) *Conn { @@ -148,36 +147,6 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.writer.WriteVectorised([]*buf.Buffer{header, buffer}) } -func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - c.newBuffer = newBuffer -} - -func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - if c.newBuffer == nil { - return nil, M.Socksaddr{}, os.ErrInvalid - } - if c.isConnect { - destination = c.destination - } else { - destination, err = AddrParser.ReadAddrPort(c.Conn) - if err != nil { - return - } - } - var length uint16 - err = binary.Read(c.Conn, binary.BigEndian, &length) - if err != nil { - return - } - buffer = c.newBuffer() - _, err = buffer.ReadFullFrom(c.Conn, int(length)) - if err != nil { - buffer.Release() - return nil, M.Socksaddr{}, E.Cause(err, "UoT read") - } - return -} - func (c *Conn) NeedAdditionalReadDeadline() bool { return true } diff --git a/common/uot/conn_wait.go b/common/uot/conn_wait.go new file mode 100644 index 0000000..eecb360 --- /dev/null +++ b/common/uot/conn_wait.go @@ -0,0 +1,40 @@ +package uot + +import ( + "encoding/binary" + + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + if c.isConnect { + destination = c.destination + } else { + destination, err = AddrParser.ReadAddrPort(c.Conn) + if err != nil { + return + } + } + var length uint16 + err = binary.Read(c.Conn, binary.BigEndian, &length) + if err != nil { + return + } + var readBuffer *buf.Buffer + buffer, readBuffer = c.readWaitOptions.NewPacketBuffer() + _, err = readBuffer.ReadFullFrom(c.Conn, int(length)) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, E.Cause(err, "UoT read") + } + buffer.Resize(readBuffer.Start(), readBuffer.Len()) + return +} diff --git a/protocol/socks/packet_wait.go b/protocol/socks/packet_wait.go index 9b9047b..32b1341 100644 --- a/protocol/socks/packet_wait.go +++ b/protocol/socks/packet_wait.go @@ -24,8 +24,8 @@ type AssociatePacketReadWaiter struct { readWaiter N.PacketReadWaiter } -func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - w.readWaiter.InitializeReadWaiter(newBuffer) +func (w *AssociatePacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) } func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {