diff --git a/common/bufio/buffer.go b/common/bufio/buffer.go index d37163e..c60e7e6 100644 --- a/common/bufio/buffer.go +++ b/common/bufio/buffer.go @@ -29,6 +29,10 @@ func (r *BufferedReader) Read(p []byte) (n int, err error) { return r.buffer.Read(p) } +func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) { + return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r.upstream), r.buffer) +} + func (w *BufferedReader) Upstream() any { return w.upstream } diff --git a/common/bufio/conn.go b/common/bufio/conn.go index 834d007..aac59c1 100644 --- a/common/bufio/conn.go +++ b/common/bufio/conn.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "os" "time" "github.com/sagernet/sing/common" @@ -14,6 +15,27 @@ import ( "github.com/sagernet/sing/common/task" ) +type readOnlyReader struct { + io.Reader +} + +func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) { + return Copy(w, r.Reader) +} + +func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool { + _, isTCPConn := dst.(*net.TCPConn) + if !isTCPConn { + return false + } + switch src.(type) { + case *net.TCPConn, *net.UnixConn, *os.File: + return false + default: + return true + } +} + func Copy(dst io.Writer, src io.Reader) (n int64, err error) { src = N.UnwrapReader(src) dst = N.UnwrapWriter(dst) @@ -21,6 +43,9 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) { return wt.WriteTo(dst) } if rt, ok := dst.(io.ReaderFrom); ok { + if needReadFromWrapper(rt, src) { + src = &readOnlyReader{src} + } return rt.ReadFrom(src) } return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src)) @@ -37,7 +62,11 @@ func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err erro _buffer := buf.StackNew() defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + + return CopyExtendedBuffer(dst, src, common.Dup(_buffer)) +} + +func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) { buffer.IncRef() defer buffer.DecRef() for {