diff --git a/common/bufio/buffer.go b/common/bufio/buffer.go index ad869f3..735e440 100644 --- a/common/bufio/buffer.go +++ b/common/bufio/buffer.go @@ -100,6 +100,14 @@ func (r *BufferedReader) ReaderReplaceable() bool { return buffer == nil || buffer.Closed() } +func (r *BufferedReader) CreateReadWaiter() (ReadWaiter, bool) { + reader, created := CreateReadWaiter(r.upstream) + if !created { + return nil, false + } + return &bufferedReadWaiter{r, reader}, true +} + type BufferedWriter struct { upstream io.Writer buffer *buf.Buffer diff --git a/common/bufio/buffer_wait.go b/common/bufio/buffer_wait.go new file mode 100644 index 0000000..fffd544 --- /dev/null +++ b/common/bufio/buffer_wait.go @@ -0,0 +1,45 @@ +package bufio + +import ( + "os" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" +) + +type bufferedReadWaiter struct { + *BufferedReader + upstream ReadWaiter +} + +func (w *bufferedReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error { + if w.buffer == nil { + return w.upstream.WaitReadBuffer(newBuffer) + } + if w.buffer.Closed() { + return os.ErrClosed + } + var err error + if w.buffer.IsEmpty() { + w.buffer.Reset() + w.buffer.IncRef() + err = w.upstream.WaitReadBuffer(func() *buf.Buffer { + return w.buffer + }) + w.buffer.DecRef() + if err != nil { + w.buffer.Release() + return err + } + } + buffer := newBuffer() + if w.buffer.Len() > buffer.FreeLen() { + err = common.Error(buffer.ReadFullFrom(w.buffer, buffer.FreeLen())) + } else { + err = common.Error(buffer.ReadFullFrom(w.buffer, w.buffer.Len())) + } + if err != nil { + w.buffer.Release() + } + return err +} diff --git a/common/bufio/cache.go b/common/bufio/cache.go index 90f68d3..af3fa96 100644 --- a/common/bufio/cache.go +++ b/common/bufio/cache.go @@ -83,6 +83,17 @@ func (c *CachedConn) WriterReplaceable() bool { return true } +func (c *CachedConn) CreateReadWaiter() (ReadWaiter, bool) { + reader, created := CreateReadWaiter(c.Conn) + if !created { + return nil, false + } + if c.buffer == nil { + return reader, true + } + return &cachedReadWaiter{reader, c.buffer}, true +} + func (c *CachedConn) Close() error { if buffer := c.buffer; buffer != nil { buffer.DecRef() @@ -207,6 +218,17 @@ func (c *CachedPacketConn) WriterReplaceable() bool { return true } +func (c *CachedPacketConn) CreatePacketReadWaiter() (PacketReadWaiter, bool) { + reader, created := CreatePacketReadWaiter(c.PacketConn) + if !created { + return nil, false + } + if c.buffer == nil { + return reader, true + } + return &cachedPacketReadWaiter{reader, c.buffer, c.destination}, true +} + func (c *CachedPacketConn) Close() error { if buffer := c.buffer; buffer != nil { buffer.DecRef() diff --git a/common/bufio/cache_wait.go b/common/bufio/cache_wait.go new file mode 100644 index 0000000..a7c36d0 --- /dev/null +++ b/common/bufio/cache_wait.go @@ -0,0 +1,49 @@ +package bufio + +import ( + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +type cachedReadWaiter struct { + reader ReadWaiter + cache *buf.Buffer +} + +func (c *cachedReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error { + cache := c.cache + if cache != nil { + var err error + if !cache.IsEmpty() { + _, err = newBuffer().ReadOnceFrom(c.cache) + } + if cache.IsEmpty() { + cache.Release() + c.cache = nil + } + return err + } + return c.reader.WaitReadBuffer(newBuffer) +} + +type cachedPacketReadWaiter struct { + reader PacketReadWaiter + cache *buf.Buffer + destination M.Socksaddr +} + +func (c *cachedPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { + cache := c.cache + if cache != nil { + if !cache.IsEmpty() { + _, err = newBuffer().ReadOnceFrom(c.cache) + } + if cache.IsEmpty() { + cache.Release() + c.cache = nil + } + destination = c.destination + return + } + return c.reader.WaitReadPacket(newBuffer) +} diff --git a/common/bufio/chunk.go b/common/bufio/chunk.go index 9c60f80..583da46 100644 --- a/common/bufio/chunk.go +++ b/common/bufio/chunk.go @@ -76,6 +76,10 @@ func (c *ChunkReader) MTU() int { return c.maxChunkSize } +func (c *ChunkReader) Upstream() any { + return c.upstream +} + type ChunkWriter struct { upstream N.ExtendedWriter maxChunkSize int diff --git a/common/bufio/copy.go b/common/bufio/copy.go index b6b4010..28c4f78 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -84,6 +84,11 @@ func CopyExtended(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) return CopyExtendedWithSrcBuffer(origDst, dst, safeSrc) } } + if common.LowMemory { + if reader, created := CreateReadWaiter(src); created { + return CopyExtendedWithPoolWait(origDst, dst, reader) + } + } if N.IsUnsafeWriter(dst) { return CopyExtendedWithPool(origDst, dst, src) } @@ -185,6 +190,48 @@ func CopyExtendedWithPool(origDst io.Writer, dst N.ExtendedWriter, src N.Extende } } +func CopyExtendedWithPoolWait(origDst io.Writer, dst N.ExtendedWriter, src ReadWaiter) (n int64, err error) { + frontHeadroom := N.CalculateFrontHeadroom(dst) + rearHeadroom := N.CalculateRearHeadroom(dst) + bufferSize := N.CalculateMTU(src, dst) + if bufferSize > 0 { + bufferSize += frontHeadroom + rearHeadroom + } else { + bufferSize = buf.BufferSize + } + var ( + buffer *buf.Buffer + readBuffer *buf.Buffer + ) + newBuffer := func() *buf.Buffer { + buffer = buf.NewSize(bufferSize) + readBufferRaw := buffer.Slice() + readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) + readBuffer.Resize(frontHeadroom, 0) + return readBuffer + } + var notFirstTime bool + for { + err = src.WaitReadBuffer(newBuffer) + if err != nil { + buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(origDst, err) + } + return + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = dst.WriteBuffer(buffer) + if err != nil { + buffer.Release() + return + } + n += int64(dataLen) + notFirstTime = true + } +} + func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { return CopyConnContextList([]context.Context{ctx}, conn, dest) } @@ -241,6 +288,11 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { return CopyPacketWithSrcBuffer(dst, safeSrc) } } + if common.LowMemory { + if reader, created := CreatePacketReadWaiter(src); created { + return CopyPacketWithPoolWait(dst, reader) + } + } if N.IsUnsafeWriter(dst) { return CopyPacketWithPool(dst, src) } @@ -339,6 +391,49 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er } } +func CopyPacketWithPoolWait(dst N.PacketWriter, src PacketReadWaiter) (n int64, err error) { + frontHeadroom := N.CalculateFrontHeadroom(dst) + rearHeadroom := N.CalculateRearHeadroom(dst) + bufferSize := N.CalculateMTU(src, dst) + if bufferSize > 0 { + bufferSize += frontHeadroom + rearHeadroom + } else { + bufferSize = buf.UDPBufferSize + } + var ( + buffer *buf.Buffer + readBuffer *buf.Buffer + ) + newBuffer := func() *buf.Buffer { + buffer = buf.NewSize(bufferSize) + readBufferRaw := buffer.Slice() + readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) + readBuffer.Resize(frontHeadroom, 0) + return readBuffer + } + var destination M.Socksaddr + var notFirstTime bool + for { + destination, err = src.WaitReadPacket(newBuffer) + if err != nil { + buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(dst, err) + } + return + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = dst.WritePacket(buffer, destination) + if err != nil { + buffer.Release() + return + } + n += int64(dataLen) + notFirstTime = true + } +} + func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error { return CopyPacketConnContextList([]context.Context{ctx}, conn, dest) } diff --git a/common/bufio/deadline/chunk_wait.go b/common/bufio/deadline/chunk_wait.go new file mode 100644 index 0000000..d73be9e --- /dev/null +++ b/common/bufio/deadline/chunk_wait.go @@ -0,0 +1,3 @@ +package deadline + + diff --git a/common/bufio/vectorised.go b/common/bufio/vectorised.go index ef875fd..7339ed5 100644 --- a/common/bufio/vectorised.go +++ b/common/bufio/vectorised.go @@ -19,6 +19,9 @@ func NewVectorisedWriter(writer io.Writer) N.VectorisedWriter { } func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) { + if ioWriter, isWriter := writer.(io.Writer); isWriter { + writer = N.UnwrapWriter(ioWriter) + } switch w := writer.(type) { case N.VectorisedWriter: return w, true @@ -42,6 +45,9 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) { } func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) { + if packetWriter, isWriter := writer.(N.PacketWriter); isWriter { + writer = N.UnwrapPacketWriter(packetWriter) + } switch w := writer.(type) { case N.VectorisedPacketWriter: return w, true diff --git a/common/bufio/wait.go b/common/bufio/wait.go new file mode 100644 index 0000000..59d858e --- /dev/null +++ b/common/bufio/wait.go @@ -0,0 +1,114 @@ +package bufio + +import ( + "net/netip" + "syscall" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +type ReadWaiter interface { + WaitReadBuffer(newBuffer func() *buf.Buffer) error +} + +type PacketReadWaiter interface { + WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) +} + +type ReadWaiterCreator interface { + CreateReadWaiter() (ReadWaiter, bool) +} + +type PacketReadWaiterCreator interface { + CreatePacketReadWaiter() (PacketReadWaiter, bool) +} + +func CreateReadWaiter(conn any) (ReadWaiter, bool) { + if waiter, loaded := common.Cast[ReadWaiterCreator](conn); loaded { + return waiter.CreateReadWaiter() + } + if waiter, loaded := common.Cast[ReadWaiter](conn); loaded { + return waiter, true + } + if rawConn, loaded := common.Cast[syscall.RawConn](conn); loaded { + return &syscallReadWaiter{rawConn}, true + } + if syscallConn, loaded := common.Cast[syscall.Conn](conn); loaded { + rawConn, err := syscallConn.SyscallConn() + if err != nil { + return nil, false + } + return &syscallReadWaiter{rawConn}, true + } + return nil, false +} + +func CreatePacketReadWaiter(conn any) (PacketReadWaiter, bool) { + if waiter, loaded := common.Cast[PacketReadWaiterCreator](conn); loaded { + return waiter.CreatePacketReadWaiter() + } + if waiter, loaded := common.Cast[PacketReadWaiter](conn); loaded { + return waiter, true + } + if rawConn, loaded := common.Cast[syscall.RawConn](conn); loaded { + return &syscallReadWaiter{rawConn}, true + } + if syscallConn, loaded := common.Cast[syscall.Conn](conn); loaded { + rawConn, err := syscallConn.SyscallConn() + if err != nil { + return nil, false + } + return &syscallReadWaiter{rawConn}, true + } + return nil, false +} + +type syscallReadWaiter struct { + syscall.RawConn +} + +func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error { + var ( + buffer *buf.Buffer + n int + err error + ) + err = w.RawConn.Read(func(fd uintptr) (done bool) { + buffer = newBuffer() + n, err = syscall.Read(int(fd), buffer.FreeBytes()) + if err == syscall.EAGAIN { + buffer.Release() + return false + } + buffer.Truncate(n) + return true + }) + return err +} + +func (w *syscallReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { + var ( + buffer *buf.Buffer + n int + from syscall.Sockaddr + ) + err = w.RawConn.Read(func(fd uintptr) (done bool) { + buffer = newBuffer() + n, _, _, from, err = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0) + if err == syscall.EAGAIN { + buffer.Release() + return false + } + buffer.Truncate(n) + return true + }) + switch fromAddr := from.(type) { + case *syscall.SockaddrInet4: + destination = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port)) + case *syscall.SockaddrInet6: + destination = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)) + } + return destination, err +} diff --git a/common/udpnat/service.go b/common/udpnat/service.go index 4b978c1..e276828 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -71,6 +71,7 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu c, loaded := s.nat.LoadOrStore(key, func() *conn { c := &conn{ data: make(chan packet, 64), + newData: make(chan struct{}, 1), localAddr: metadata.Source, remoteAddr: metadata.Destination, } @@ -97,6 +98,10 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu } return } + select { + case c.newData <- struct{}{}: + default: + } c.data <- packet{ data: buffer, destination: metadata.Destination, @@ -116,6 +121,7 @@ type conn struct { ctx context.Context cancel common.ContextCancelCauseFunc data chan packet + newData chan struct{} localAddr M.Socksaddr remoteAddr M.Socksaddr source N.PacketWriter @@ -161,12 +167,24 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return len(p), c.source.WritePacket(buf.As(p).ToOwned(), M.SocksaddrFromNet(addr)) } +func (c *conn) WaitRead() { + select { + case <-c.newData: + case <-c.ctx.Done(): + } +} + func (c *conn) Close() error { select { case <-c.ctx.Done(): default: c.cancel(net.ErrClosed) } + select { + case <-c.newData: + default: + close(c.newData) + } if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser { return sourceCloser.Close() }