diff --git a/common/bufio/cache.go b/common/bufio/cache.go index 90f68d3..e1eda0a 100644 --- a/common/bufio/cache.go +++ b/common/bufio/cache.go @@ -186,13 +186,16 @@ func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad return c.PacketConn.ReadPacket(buffer) } -func (c *CachedPacketConn) ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer) { - buffer = c.buffer +func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer { + buffer := c.buffer c.buffer = nil if buffer != nil { buffer.DecRef() } - return c.destination, buffer + return &N.PacketBuffer{ + Buffer: buffer, + Destination: c.destination, + } } func (c *CachedPacketConn) Upstream() any { diff --git a/common/bufio/conn.go b/common/bufio/conn.go index 95af4d5..1b92589 100644 --- a/common/bufio/conn.go +++ b/common/bufio/conn.go @@ -125,7 +125,7 @@ func (w *ExtendedWriterWrapper) Upstream() any { return w.Writer } -func (w *ExtendedReaderWrapper) WriterReplaceable() bool { +func (w *ExtendedWriterWrapper) WriterReplaceable() bool { return true } diff --git a/common/bufio/copy.go b/common/bufio/copy.go index b6b4010..ef42b31 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -2,8 +2,10 @@ package bufio import ( "context" + "errors" "io" "net" + "syscall" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -14,80 +16,65 @@ import ( "github.com/sagernet/sing/common/task" ) -type readOnlyReader struct { - io.Reader -} - -func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) { - return Copy(w, r.Reader) -} - -func (r *readOnlyReader) Upstream() any { - return r.Reader -} - -func (r *readOnlyReader) ReaderReplaceable() bool { - return true -} - -type writeOnlyWriter struct { - io.Writer -} - -func (w *writeOnlyWriter) ReadFrom(r io.Reader) (n int64, err error) { - return Copy(w.Writer, r) -} - -func (w *writeOnlyWriter) Upstream() any { - return w.Writer -} - -func (w *writeOnlyWriter) WriterReplaceable() bool { - return true -} - -func needWrapper(src, dst any) bool { - _, srcTCPConn := src.(*net.TCPConn) - _, dstTCPConn := dst.(*net.TCPConn) - return (srcTCPConn || dstTCPConn) && !(srcTCPConn && dstTCPConn) -} - -func Copy(dst io.Writer, src io.Reader) (n int64, err error) { - if src == nil { +func Copy(destination io.Writer, source io.Reader) (n int64, err error) { + if source == nil { return 0, E.New("nil reader") - } else if dst == nil { + } else if destination == nil { return 0, E.New("nil writer") } - origDst := dst - src = N.UnwrapReader(src) - dst = N.UnwrapWriter(dst) - if wt, ok := src.(io.WriterTo); ok { - if needWrapper(dst, src) { - dst = &writeOnlyWriter{dst} + originDestination := destination + var readCounters, writeCounters []N.CountFunc + for { + source, readCounters = N.UnwrapCountReader(source, readCounters) + destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) + if cachedSrc, isCached := source.(N.CachedReader); isCached { + cachedBuffer := cachedSrc.ReadCached() + if cachedBuffer != nil { + if !cachedBuffer.IsEmpty() { + _, err = destination.Write(cachedBuffer.Bytes()) + if err != nil { + cachedBuffer.Release() + return + } + } + cachedBuffer.Release() + continue + } } - return wt.WriteTo(dst) - } - if rt, ok := dst.(io.ReaderFrom); ok { - if needWrapper(rt, src) { - src = &readOnlyReader{src} + srcSyscallConn, srcIsSyscall := source.(syscall.Conn) + dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) + if srcIsSyscall && dstIsSyscall { + var handled bool + handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) + if handled { + return + } } - return rt.ReadFrom(src) + break } - return CopyExtended(origDst, NewExtendedWriter(dst), NewExtendedReader(src)) + return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) } -func CopyExtended(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { - safeSrc := N.IsSafeReader(src) - headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst) +func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + safeSrc := N.IsSafeReader(source) + headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination) if safeSrc != nil { if headroom == 0 { - return CopyExtendedWithSrcBuffer(origDst, dst, safeSrc) + return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters) } } - if N.IsUnsafeWriter(dst) { - return CopyExtendedWithPool(origDst, dst, src) + readWaiter, isReadWaiter := CreateReadWaiter(source) + if isReadWaiter { + var handled bool + handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters) + if handled { + return + } } - bufferSize := N.CalculateMTU(src, dst) + if N.IsUnsafeWriter(destination) { + return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters) + } + bufferSize := N.CalculateMTU(source, destination) if bufferSize > 0 { bufferSize += headroom } else { @@ -97,63 +84,83 @@ func CopyExtended(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() - return CopyExtendedBuffer(origDst, dst, src, buffer) + return CopyExtendedBuffer(originDestination, destination, source, buffer, readCounters, writeCounters) } -func CopyExtendedBuffer(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) { +func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { buffer.IncRef() defer buffer.DecRef() - frontHeadroom := N.CalculateFrontHeadroom(dst) - rearHeadroom := N.CalculateRearHeadroom(dst) + frontHeadroom := N.CalculateFrontHeadroom(destination) + rearHeadroom := N.CalculateRearHeadroom(destination) readBufferRaw := buffer.Slice() readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) var notFirstTime bool for { readBuffer.Resize(frontHeadroom, 0) - err = src.ReadBuffer(readBuffer) + err = source.ReadBuffer(readBuffer) if err != nil { + if errors.Is(err, io.EOF) { + err = nil + return + } if !notFirstTime { - err = N.HandshakeFailure(origDst, err) + err = N.HandshakeFailure(originDestination, err) } return } dataLen := readBuffer.Len() buffer.Resize(readBuffer.Start(), dataLen) - err = dst.WriteBuffer(buffer) + err = destination.WriteBuffer(buffer) if err != nil { return } n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } notFirstTime = true } } -func CopyExtendedWithSrcBuffer(origDst io.Writer, dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) { +func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { var notFirstTime bool for { var buffer *buf.Buffer - buffer, err = src.ReadBufferThreadSafe() + buffer, err = source.ReadBufferThreadSafe() if err != nil { + if errors.Is(err, io.EOF) { + err = nil + return + } if !notFirstTime { - err = N.HandshakeFailure(origDst, err) + err = N.HandshakeFailure(originDestination, err) } return } dataLen := buffer.Len() - err = dst.WriteBuffer(buffer) + err = destination.WriteBuffer(buffer) if err != nil { buffer.Release() return } n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } notFirstTime = true } } -func CopyExtendedWithPool(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(dst) - rearHeadroom := N.CalculateRearHeadroom(dst) - bufferSize := N.CalculateMTU(src, dst) +func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + frontHeadroom := N.CalculateFrontHeadroom(destination) + rearHeadroom := N.CalculateRearHeadroom(destination) + bufferSize := N.CalculateMTU(source, destination) if bufferSize > 0 { bufferSize += frontHeadroom + rearHeadroom } else { @@ -165,86 +172,129 @@ func CopyExtendedWithPool(origDst io.Writer, dst N.ExtendedWriter, src N.Extende readBufferRaw := buffer.Slice() readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) readBuffer.Resize(frontHeadroom, 0) - err = src.ReadBuffer(readBuffer) + err = source.ReadBuffer(readBuffer) if err != nil { buffer.Release() + if errors.Is(err, io.EOF) { + err = nil + return + } if !notFirstTime { - err = N.HandshakeFailure(origDst, err) + err = N.HandshakeFailure(originDestination, err) } return } dataLen := readBuffer.Len() buffer.Resize(readBuffer.Start(), dataLen) - err = dst.WriteBuffer(buffer) + err = destination.WriteBuffer(buffer) if err != nil { buffer.Release() return } n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } notFirstTime = true } } -func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { - return CopyConnContextList([]context.Context{ctx}, conn, dest) +func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error { + return CopyConnContextList([]context.Context{ctx}, source, destination) } -func CopyConnContextList(contextList []context.Context, conn net.Conn, dest net.Conn) error { +func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error { var group task.Group - if _, dstDuplex := common.Cast[rw.WriteCloser](dest); dstDuplex { + if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex { group.Append("upload", func(ctx context.Context) error { - err := common.Error(Copy(dest, conn)) - if E.IsMulti(err, io.EOF) { - rw.CloseWrite(dest) + err := common.Error(Copy(destination, source)) + if err == nil { + rw.CloseWrite(destination) } else { - common.Close(dest) + common.Close(destination) } return err }) } else { group.Append("upload", func(ctx context.Context) error { - defer common.Close(dest) - return common.Error(Copy(dest, conn)) + defer common.Close(destination) + return common.Error(Copy(destination, source)) }) } - if _, srcDuplex := common.Cast[rw.WriteCloser](conn); srcDuplex { + if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex { group.Append("download", func(ctx context.Context) error { - err := common.Error(Copy(conn, dest)) - if E.IsMulti(err, io.EOF) { - rw.CloseWrite(conn) + err := common.Error(Copy(source, destination)) + if err == nil { + rw.CloseWrite(source) } else { - common.Close(conn) + common.Close(source) } return err }) } else { group.Append("download", func(ctx context.Context) error { - defer common.Close(conn) - return common.Error(Copy(conn, dest)) + defer common.Close(source) + return common.Error(Copy(source, destination)) }) } group.Cleanup(func() { - common.Close(conn, dest) + common.Close(source, destination) }) return group.RunContextList(contextList) } -func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { - src = N.UnwrapPacketReader(src) - dst = N.UnwrapPacketWriter(dst) - safeSrc := N.IsSafePacketReader(src) - frontHeadroom := N.CalculateFrontHeadroom(dst) - rearHeadroom := N.CalculateRearHeadroom(dst) +func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { + var readCounters, writeCounters []N.CountFunc + var cachedPackets []*N.PacketBuffer + for { + source, readCounters = N.UnwrapCountPacketReader(source, readCounters) + destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters) + if cachedReader, isCached := source.(N.CachedPacketReader); isCached { + packet := cachedReader.ReadCachedPacket() + if packet != nil { + cachedPackets = append([]*N.PacketBuffer{packet}, cachedPackets...) + continue + } + } + break + } + if cachedPackets != nil { + n, err = WritePacketWithPool(destinationConn, cachedPackets) + if err != nil { + return + } + } + safeSrc := N.IsSafePacketReader(source) + frontHeadroom := N.CalculateFrontHeadroom(destinationConn) + rearHeadroom := N.CalculateRearHeadroom(destinationConn) headroom := frontHeadroom + rearHeadroom if safeSrc != nil { if headroom == 0 { - return CopyPacketWithSrcBuffer(dst, safeSrc) + var copyN int64 + copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters) + n += copyN + return } } - if N.IsUnsafeWriter(dst) { - return CopyPacketWithPool(dst, src) + readWaiter, isReadWaiter := CreatePacketReadWaiter(source) + if isReadWaiter { + var ( + handled bool + copeN int64 + ) + handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters) + if handled { + n += copeN + return + } } - bufferSize := N.CalculateMTU(src, dst) + if N.IsUnsafeWriter(destinationConn) { + return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters) + } + bufferSize := N.CalculateMTU(source, destinationConn) if bufferSize > 0 { bufferSize += headroom } else { @@ -262,51 +312,66 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) for { readBuffer.Resize(frontHeadroom, 0) - destination, err = src.ReadPacket(readBuffer) + destination, err = source.ReadPacket(readBuffer) if err != nil { if !notFirstTime { - err = N.HandshakeFailure(dst, err) + err = N.HandshakeFailure(destinationConn, err) } return } dataLen := readBuffer.Len() buffer.Resize(readBuffer.Start(), dataLen) - err = dst.WritePacket(buffer, destination) + err = destinationConn.WritePacket(buffer, destination) if err != nil { return } n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } notFirstTime = true } } -func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) { +func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { var buffer *buf.Buffer var destination M.Socksaddr var notFirstTime bool for { - buffer, destination, err = src.ReadPacketThreadSafe() + buffer, destination, err = source.ReadPacketThreadSafe() if err != nil { if !notFirstTime { - err = N.HandshakeFailure(dst, err) + err = N.HandshakeFailure(destinationConn, err) } return } dataLen := buffer.Len() - err = dst.WritePacket(buffer, destination) + if dataLen == 0 { + continue + } + err = destinationConn.WritePacket(buffer, destination) if err != nil { buffer.Release() return } n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } notFirstTime = true } } -func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(dst) - rearHeadroom := N.CalculateRearHeadroom(dst) - bufferSize := N.CalculateMTU(src, dst) +func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + frontHeadroom := N.CalculateFrontHeadroom(destinationConn) + rearHeadroom := N.CalculateRearHeadroom(destinationConn) + bufferSize := N.CalculateMTU(source, destinationConn) if bufferSize > 0 { bufferSize += frontHeadroom + rearHeadroom } else { @@ -319,40 +384,72 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er readBufferRaw := buffer.Slice() readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) readBuffer.Resize(frontHeadroom, 0) - destination, err = src.ReadPacket(readBuffer) + destination, err = source.ReadPacket(readBuffer) if err != nil { buffer.Release() if !notFirstTime { - err = N.HandshakeFailure(dst, err) + err = N.HandshakeFailure(destinationConn, err) } return } dataLen := readBuffer.Len() buffer.Resize(readBuffer.Start(), dataLen) - err = dst.WritePacket(buffer, destination) + err = destinationConn.WritePacket(buffer, destination) if err != nil { buffer.Release() return } n += int64(dataLen) + for _, counter := range readCounters { + counter(n) + } + for _, counter := range writeCounters { + counter(n) + } notFirstTime = true } } -func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error { - return CopyPacketConnContextList([]context.Context{ctx}, conn, dest) +func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { + frontHeadroom := N.CalculateFrontHeadroom(destinationConn) + rearHeadroom := N.CalculateRearHeadroom(destinationConn) + var destination M.Socksaddr + for _, packetBuffer := range packetBuffers { + buffer := buf.NewPacket() + readBufferRaw := buffer.Slice() + readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) + readBuffer.Resize(frontHeadroom, 0) + _, err = readBuffer.Write(packetBuffer.Buffer.Bytes()) + packetBuffer.Buffer.Release() + if err != nil { + continue + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = destinationConn.WritePacket(buffer, destination) + if err != nil { + buffer.Release() + return + } + n += int64(dataLen) + } + return } -func CopyPacketConnContextList(contextList []context.Context, conn N.PacketConn, dest N.PacketConn) error { +func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error { + return CopyPacketConnContextList([]context.Context{ctx}, source, destination) +} + +func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error { var group task.Group group.Append("upload", func(ctx context.Context) error { - return common.Error(CopyPacket(dest, conn)) + return common.Error(CopyPacket(destination, source)) }) group.Append("download", func(ctx context.Context) error { - return common.Error(CopyPacket(conn, dest)) + return common.Error(CopyPacket(source, destination)) }) group.Cleanup(func() { - common.Close(conn, dest) + common.Close(source, destination) }) group.FastFail() return group.RunContextList(contextList) diff --git a/common/bufio/copy_direct.go b/common/bufio/copy_direct.go new file mode 100644 index 0000000..1648c03 --- /dev/null +++ b/common/bufio/copy_direct.go @@ -0,0 +1,20 @@ +package bufio + +import ( + "syscall" + + N "github.com/sagernet/sing/common/network" +) + +func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { + rawSource, err := source.SyscallConn() + if err != nil { + return + } + rawDestination, err := destination.SyscallConn() + if err != nil { + return + } + handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters) + return +} diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go new file mode 100644 index 0000000..8f79c94 --- /dev/null +++ b/common/bufio/copy_direct_posix.go @@ -0,0 +1,232 @@ +//go:build !windows + +package bufio + +import ( + "errors" + "io" + "net/netip" + "syscall" + + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { + handled = true + frontHeadroom := N.CalculateFrontHeadroom(destination) + rearHeadroom := N.CalculateRearHeadroom(destination) + bufferSize := N.CalculateMTU(source, destination) + if bufferSize > 0 { + bufferSize += frontHeadroom + rearHeadroom + } else { + bufferSize = buf.BufferSize + } + var ( + buffer *buf.Buffer + readBuffer *buf.Buffer + ) + newBuffer := func() *buf.Buffer { + if buffer != nil { + buffer.Release() + } + buffer = buf.NewSize(bufferSize) + readBufferRaw := buffer.Slice() + readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) + readBuffer.Resize(frontHeadroom, 0) + return readBuffer + } + var notFirstTime bool + for { + err = source.WaitReadBuffer(newBuffer) + if err != nil { + buffer.Release() + if errors.Is(err, io.EOF) { + err = nil + return + } + if !notFirstTime { + err = N.HandshakeFailure(originDestination, err) + } + return + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = destination.WriteBuffer(buffer) + if err != nil { + if buffer != nil { + buffer.Release() + } + return + } + n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } + notFirstTime = true + } +} + +func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { + handled = true + frontHeadroom := N.CalculateFrontHeadroom(destinationConn) + rearHeadroom := N.CalculateRearHeadroom(destinationConn) + bufferSize := N.CalculateMTU(source, destinationConn) + if bufferSize > 0 { + bufferSize += frontHeadroom + rearHeadroom + } else { + bufferSize = buf.UDPBufferSize + } + var ( + buffer *buf.Buffer + readBuffer *buf.Buffer + ) + newBuffer := func() *buf.Buffer { + if buffer != nil { + buffer.Release() + } + buffer = buf.NewSize(bufferSize) + readBufferRaw := buffer.Slice() + readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) + readBuffer.Resize(frontHeadroom, 0) + return readBuffer + } + var destination M.Socksaddr + var notFirstTime bool + for { + destination, err = source.WaitReadPacket(newBuffer) + if err != nil { + buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(destinationConn, err) + } + return + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = destinationConn.WritePacket(buffer, destination) + if err != nil { + buffer.Release() + return + } + n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } + notFirstTime = true + } +} + +type syscallReadWaiter struct { + rawConn syscall.RawConn + readErr error + readFunc func(fd uintptr) (done bool) +} + +func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { + if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn { + rawConn, err := syscallConn.SyscallConn() + if err == nil { + return &syscallReadWaiter{rawConn: rawConn}, true + } + } + return nil, false +} + +func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error { + if w.readFunc == nil { + w.readFunc = func(fd uintptr) (done bool) { + buffer := newBuffer() + var readN int + readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes()) + if readN > 0 { + buffer.Truncate(readN) + } else { + buffer.Release() + buffer = nil + } + if w.readErr == syscall.EAGAIN { + return false + } + if readN == 0 { + w.readErr = io.EOF + } + return true + } + } + err := w.rawConn.Read(w.readFunc) + if err != nil { + return err + } + if w.readErr != nil { + return E.Cause(w.readErr, "raw read") + } + return nil +} + +type syscallPacketReadWaiter struct { + rawConn syscall.RawConn + readErr error + readFrom M.Socksaddr + readFunc func(fd uintptr) (done bool) +} + +func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) { + if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn { + rawConn, err := syscallConn.SyscallConn() + if err == nil { + return &syscallPacketReadWaiter{rawConn: rawConn}, true + } + } + return nil, false +} + +func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { + if w.readFunc == nil { + w.readFunc = func(fd uintptr) (done bool) { + buffer := newBuffer() + var readN int + var from syscall.Sockaddr + readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0) + if readN > 0 { + buffer.Truncate(readN) + } else { + buffer.Release() + buffer = nil + } + if w.readErr == syscall.EAGAIN { + return false + } + if from != nil { + switch fromAddr := from.(type) { + case *syscall.SockaddrInet4: + w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port)) + case *syscall.SockaddrInet6: + w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)) + } + } + if readN == 0 { + w.readErr = io.EOF + } + return true + } + } + err = w.rawConn.Read(w.readFunc) + if err != nil { + return + } + if w.readErr != nil { + err = E.Cause(w.readErr, "raw read") + return + } + destination = w.readFrom + return +} diff --git a/common/bufio/copy_direct_windows.go b/common/bufio/copy_direct_windows.go new file mode 100644 index 0000000..9c0743f --- /dev/null +++ b/common/bufio/copy_direct_windows.go @@ -0,0 +1,23 @@ +package bufio + +import ( + "io" + + N "github.com/sagernet/sing/common/network" +) + +func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { + return +} + +func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { + return +} + +func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) { + return nil, false +} + +func createSyscallPacketReadWaiter(reader any) (N.PacketReadWaiter, bool) { + return nil, false +} diff --git a/common/bufio/counter_conn.go b/common/bufio/counter_conn.go new file mode 100644 index 0000000..5fa4731 --- /dev/null +++ b/common/bufio/counter_conn.go @@ -0,0 +1,96 @@ +package bufio + +import ( + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" +) + +func NewInt64CounterConn(conn net.Conn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterConn { + return &CounterConn{ + NewExtendedConn(conn), + common.Map(readCounter, func(it *atomic.Int64) N.CountFunc { + return func(n int64) { + it.Add(n) + } + }), + common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc { + return func(n int64) { + it.Add(n) + } + }), + } +} + +func NewCounterConn(conn net.Conn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterConn { + return &CounterConn{NewExtendedConn(conn), readCounter, writeCounter} +} + +type CounterConn struct { + N.ExtendedConn + readCounter []N.CountFunc + writeCounter []N.CountFunc +} + +func (c *CounterConn) Read(p []byte) (n int, err error) { + n, err = c.ExtendedConn.Read(p) + if n > 0 { + for _, counter := range c.readCounter { + counter(int64(n)) + } + } + return n, err +} + +func (c *CounterConn) ReadBuffer(buffer *buf.Buffer) error { + err := c.ExtendedConn.ReadBuffer(buffer) + if err != nil { + return err + } + if buffer.Len() > 0 { + for _, counter := range c.readCounter { + counter(int64(buffer.Len())) + } + } + return nil +} + +func (c *CounterConn) Write(p []byte) (n int, err error) { + n, err = c.ExtendedConn.Write(p) + if n > 0 { + for _, counter := range c.writeCounter { + counter(int64(n)) + } + } + return n, err +} + +func (c *CounterConn) WriteBuffer(buffer *buf.Buffer) error { + dataLen := int64(buffer.Len()) + err := c.ExtendedConn.WriteBuffer(buffer) + if err != nil { + return err + } + if dataLen > 0 { + for _, counter := range c.writeCounter { + counter(dataLen) + } + } + return nil +} + +func (c *CounterConn) UnwrapReader() (io.Reader, []N.CountFunc) { + return c.ExtendedConn, c.readCounter +} + +func (c *CounterConn) UnwrapWriter() (io.Writer, []N.CountFunc) { + return c.ExtendedConn, c.writeCounter +} + +func (c *CounterConn) Upstream() any { + return c.ExtendedConn +} diff --git a/common/bufio/counter_packet_conn.go b/common/bufio/counter_packet_conn.go new file mode 100644 index 0000000..57aa178 --- /dev/null +++ b/common/bufio/counter_packet_conn.go @@ -0,0 +1,73 @@ +package bufio + +import ( + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type CounterPacketConn struct { + N.PacketConn + readCounter []N.CountFunc + writeCounter []N.CountFunc +} + +func NewInt64CounterPacketConn(conn N.PacketConn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterPacketConn { + return &CounterPacketConn{ + conn, + common.Map(readCounter, func(it *atomic.Int64) N.CountFunc { + return func(n int64) { + it.Add(n) + } + }), + common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc { + return func(n int64) { + it.Add(n) + } + }), + } +} + +func NewCounterPacketConn(conn N.PacketConn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterPacketConn { + return &CounterPacketConn{conn, readCounter, writeCounter} +} + +func (c *CounterPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.PacketConn.ReadPacket(buffer) + if err == nil { + if buffer.Len() > 0 { + for _, counter := range c.readCounter { + counter(int64(buffer.Len())) + } + } + } + return +} + +func (c *CounterPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + dataLen := int64(buffer.Len()) + err := c.PacketConn.WritePacket(buffer, destination) + if err != nil { + return err + } + if dataLen > 0 { + for _, counter := range c.writeCounter { + counter(dataLen) + } + } + return nil +} + +func (c *CounterPacketConn) UnwrapPacketReader() (N.PacketReader, []N.CountFunc) { + return c.PacketConn, c.readCounter +} + +func (c *CounterPacketConn) UnwrapPacketWriter() (N.PacketWriter, []N.CountFunc) { + return c.PacketConn, c.writeCounter +} + +func (c *CounterPacketConn) Upstream() any { + return c.PacketConn +} diff --git a/common/bufio/splice_linux.go b/common/bufio/splice_linux.go new file mode 100644 index 0000000..afcc694 --- /dev/null +++ b/common/bufio/splice_linux.go @@ -0,0 +1,79 @@ +package bufio + +import ( + "syscall" + + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" + + "golang.org/x/sys/unix" +) + +const maxSpliceSize = 1 << 20 + +func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { + handed = true + var pipeFDs [2]int + err = unix.Pipe2(pipeFDs[:], syscall.O_CLOEXEC|syscall.O_NONBLOCK) + if err != nil { + return + } + defer unix.Close(pipeFDs[0]) + defer unix.Close(pipeFDs[1]) + + _, _ = unix.FcntlInt(uintptr(pipeFDs[0]), unix.F_SETPIPE_SZ, maxSpliceSize) + var readN int + var readErr error + var writeErr error + readFunc := func(fd uintptr) (done bool) { + p0, p1 := unix.Splice(int(fd), nil, pipeFDs[1], nil, maxSpliceSize, unix.SPLICE_F_NONBLOCK) + readN = int(p0) + readErr = p1 + return readErr != unix.EAGAIN + } + writeFunc := func(fd uintptr) (done bool) { + var writeN int + size := readN + for size > 0 { + p0, p1 := unix.Splice(pipeFDs[0], nil, int(fd), nil, size, unix.SPLICE_F_NONBLOCK|unix.SPLICE_F_MOVE) + writeN = int(p0) + writeErr = p1 + if writeErr != nil { + return writeErr != unix.EAGAIN + } + size -= int(writeN) + } + return true + } + for { + err = source.Read(readFunc) + if err != nil { + readErr = err + } + if readErr != nil { + if readErr == unix.EINVAL || readErr == unix.ENOSYS { + handed = false + return + } + err = E.Cause(readErr, "splice read") + return + } + if readN == 0 { + return + } + err = destination.Write(writeFunc) + if err != nil { + writeErr = err + } + if writeErr != nil { + err = E.Cause(writeErr, "splice write") + return + } + for _, readCounter := range readCounters { + readCounter(int64(readN)) + } + for _, writeCounter := range writeCounters { + writeCounter(int64(readN)) + } + } +} diff --git a/common/bufio/splice_stub.go b/common/bufio/splice_stub.go new file mode 100644 index 0000000..44c93b5 --- /dev/null +++ b/common/bufio/splice_stub.go @@ -0,0 +1,13 @@ +//go:build !linux + +package bufio + +import ( + "syscall" + + N "github.com/sagernet/sing/common/network" +) + +func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { + return +} diff --git a/common/bufio/wait.go b/common/bufio/wait.go new file mode 100644 index 0000000..3072d7a --- /dev/null +++ b/common/bufio/wait.go @@ -0,0 +1,35 @@ +package bufio + +import ( + "io" + + N "github.com/sagernet/sing/common/network" +) + +func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) { + reader = N.UnwrapReader(reader) + if readWaiter, isReadWaiter := reader.(N.ReadWaiter); isReadWaiter { + return readWaiter, true + } + if readWaitCreator, isCreator := reader.(N.ReadWaitCreator); isCreator { + return readWaitCreator.CreateReadWaiter() + } + if readWaiter, created := createSyscallReadWaiter(reader); created { + return readWaiter, true + } + return nil, false +} + +func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) { + reader = N.UnwrapPacketReader(reader) + if readWaiter, isReadWaiter := reader.(N.PacketReadWaiter); isReadWaiter { + return readWaiter, true + } + if readWaitCreator, isCreator := reader.(N.PacketReadWaitCreator); isCreator { + return readWaitCreator.CreateReadWaiter() + } + if readWaiter, created := createSyscallPacketReadWaiter(reader); created { + return readWaiter, true + } + return nil, false +} diff --git a/common/network/conn.go b/common/network/conn.go index dec3486..9a7de21 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -89,7 +89,12 @@ type CachedReader interface { } type CachedPacketReader interface { - ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer) + ReadCachedPacket() *PacketBuffer +} + +type PacketBuffer struct { + Buffer *buf.Buffer + Destination M.Socksaddr } type WithUpstreamReader interface { @@ -118,7 +123,7 @@ func UnwrapReader(reader io.Reader) io.Reader { if u, ok := reader.(common.WithUpstream); ok { return UnwrapReader(u.Upstream().(io.Reader)) } - panic("bad reader") + return reader } func UnwrapPacketReader(reader PacketReader) PacketReader { @@ -131,7 +136,7 @@ func UnwrapPacketReader(reader PacketReader) PacketReader { if u, ok := reader.(common.WithUpstream); ok { return UnwrapPacketReader(u.Upstream().(PacketReader)) } - panic("bad reader") + return reader } func UnwrapWriter(writer io.Writer) io.Writer { @@ -144,7 +149,7 @@ func UnwrapWriter(writer io.Writer) io.Writer { if u, ok := writer.(common.WithUpstream); ok { return UnwrapWriter(u.Upstream().(io.Writer)) } - panic("bad writer") + return writer } func UnwrapPacketWriter(writer PacketWriter) PacketWriter { @@ -157,5 +162,5 @@ func UnwrapPacketWriter(writer PacketWriter) PacketWriter { if u, ok := writer.(common.WithUpstream); ok { return UnwrapPacketWriter(u.Upstream().(PacketWriter)) } - panic("bad writer") + return writer } diff --git a/common/network/counter.go b/common/network/counter.go new file mode 100644 index 0000000..a20c4d9 --- /dev/null +++ b/common/network/counter.go @@ -0,0 +1,67 @@ +package network + +import ( + "io" +) + +type CountFunc func(n int64) + +type ReadCounter interface { + io.Reader + UnwrapReader() (io.Reader, []CountFunc) +} + +type WriteCounter interface { + io.Writer + UnwrapWriter() (io.Writer, []CountFunc) +} + +type PacketReadCounter interface { + PacketReader + UnwrapPacketReader() (PacketReader, []CountFunc) +} + +type PacketWriteCounter interface { + PacketWriter + UnwrapPacketWriter() (PacketWriter, []CountFunc) +} + +func UnwrapCountReader(reader io.Reader, countFunc []CountFunc) (io.Reader, []CountFunc) { + reader = UnwrapReader(reader) + if counter, isCounter := reader.(ReadCounter); isCounter { + upstreamReader, upstreamCountFunc := counter.UnwrapReader() + countFunc = append(countFunc, upstreamCountFunc...) + return UnwrapCountReader(upstreamReader, countFunc) + } + return reader, countFunc +} + +func UnwrapCountWriter(writer io.Writer, countFunc []CountFunc) (io.Writer, []CountFunc) { + writer = UnwrapWriter(writer) + if counter, isCounter := writer.(WriteCounter); isCounter { + upstreamWriter, upstreamCountFunc := counter.UnwrapWriter() + countFunc = append(countFunc, upstreamCountFunc...) + return UnwrapCountWriter(upstreamWriter, countFunc) + } + return writer, countFunc +} + +func UnwrapCountPacketReader(reader PacketReader, countFunc []CountFunc) (PacketReader, []CountFunc) { + reader = UnwrapPacketReader(reader) + if counter, isCounter := reader.(PacketReadCounter); isCounter { + upstreamReader, upstreamCountFunc := counter.UnwrapPacketReader() + countFunc = append(countFunc, upstreamCountFunc...) + return UnwrapCountPacketReader(upstreamReader, countFunc) + } + return reader, countFunc +} + +func UnwrapCountPacketWriter(writer PacketWriter, countFunc []CountFunc) (PacketWriter, []CountFunc) { + writer = UnwrapPacketWriter(writer) + if counter, isCounter := writer.(PacketWriteCounter); isCounter { + upstreamWriter, upstreamCountFunc := counter.UnwrapPacketWriter() + countFunc = append(countFunc, upstreamCountFunc...) + return UnwrapCountPacketWriter(upstreamWriter, countFunc) + } + return writer, countFunc +} diff --git a/common/network/direct.go b/common/network/direct.go new file mode 100644 index 0000000..0f09e0a --- /dev/null +++ b/common/network/direct.go @@ -0,0 +1,22 @@ +package network + +import ( + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +type ReadWaiter interface { + WaitReadBuffer(newBuffer func() *buf.Buffer) error +} + +type ReadWaitCreator interface { + CreateReadWaiter() (ReadWaiter, bool) +} + +type PacketReadWaiter interface { + WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) +} + +type PacketReadWaitCreator interface { + CreateReadWaiter() (PacketReadWaiter, bool) +} diff --git a/common/udpnat/service.go b/common/udpnat/service.go index 17fce00..f43dc92 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -136,6 +136,17 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { } } +func (c *conn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case p := <-c.data: + _, err = newBuffer().ReadOnceFrom(p.data) + p.data.Release() + return p.destination, err + case <-c.ctx.Done(): + return M.Socksaddr{}, io.ErrClosedPipe + } +} + func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.source.WritePacket(buffer, destination) }