diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 3fb61a5..d0acc2a 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -4,8 +4,6 @@ import ( "context" "io" "net" - "os" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" @@ -23,17 +21,34 @@ func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) { return Copy(w, r.Reader) } -func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool { - _, isTCPConn := dst.(*net.TCPConn) - if !isTCPConn { - return false - } - switch src.(type) { - case *net.TCPConn, *net.UnixConn, *os.File: - return false - default: - return true - } +func (r *readOnlyReader) Upstream() any { + return r.Reader +} + +func (r *readOnlyReader) ReaderReplaceable() bool { + return true +} + +type writeOnlyWriter struct { + io.Writer +} + +func (w *writeOnlyWriter) ReadFrom(r io.Reader) (n int64, err error) { + return Copy(w.Writer, r) +} + +func (w *writeOnlyWriter) Upstream() any { + return w.Writer +} + +func (w *writeOnlyWriter) WriterReplaceable() bool { + return true +} + +func needWrapper(src, dst any) bool { + _, srcTCPConn := src.(*net.TCPConn) + _, dstTCPConn := dst.(*net.TCPConn) + return (srcTCPConn || dstTCPConn) && !(srcTCPConn && dstTCPConn) } func Copy(dst io.Writer, src io.Reader) (n int64, err error) { @@ -45,10 +60,13 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) { src = N.UnwrapReader(src) dst = N.UnwrapWriter(dst) if wt, ok := src.(io.WriterTo); ok { + if needWrapper(dst, src) { + dst = &writeOnlyWriter{dst} + } return wt.WriteTo(dst) } if rt, ok := dst.(io.ReaderFrom); ok { - if needReadFromWrapper(rt, src) { + if needWrapper(rt, src) { src = &readOnlyReader{src} } return rt.ReadFrom(src) @@ -86,7 +104,7 @@ func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf. frontHeadroom := N.CalculateFrontHeadroom(dst) rearHeadroom := N.CalculateRearHeadroom(dst) readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) + readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) var notFirstTime bool for { readBuffer.Resize(frontHeadroom, 0) @@ -143,7 +161,7 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, for { buffer := buf.NewSize(bufferSize) readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) + readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) readBuffer.Resize(frontHeadroom, 0) err = src.ReadBuffer(readBuffer) if err != nil { @@ -235,7 +253,7 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { var destination M.Socksaddr var notFirstTime bool readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) + readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) for { readBuffer.Resize(frontHeadroom, 0) destination, err = src.ReadPacket(readBuffer) @@ -293,7 +311,7 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er for { buffer := buf.NewSize(bufferSize) readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) + readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) readBuffer.Resize(frontHeadroom, 0) destination, err = src.ReadPacket(readBuffer) if err != nil { diff --git a/common/network/thread.go b/common/network/thread.go index 3b9787e..ae16590 100644 --- a/common/network/thread.go +++ b/common/network/thread.go @@ -136,13 +136,13 @@ func calculateWriterMTU(writer any) int { } if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream { upstreamMTU := calculateWriterMTU(upstream.Upstream()) - if mtu == 0 && upstreamMTU < mtu { + if mtu == 0 || upstreamMTU > 0 && upstreamMTU < mtu { mtu = upstreamMTU } } if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream { upstreamMTU := calculateWriterMTU(upstream.UpstreamWriter()) - if mtu == 0 && upstreamMTU < mtu { + if mtu == 0 || upstreamMTU > 0 && upstreamMTU < mtu { mtu = upstreamMTU } } diff --git a/protocol/trojan/protocol.go b/protocol/trojan/protocol.go index b56061f..be45cae 100644 --- a/protocol/trojan/protocol.go +++ b/protocol/trojan/protocol.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" ) @@ -25,7 +26,7 @@ const ( var CRLF = []byte{'\r', '\n'} type ClientConn struct { - net.Conn + N.ExtendedConn key [KeyLength]byte destination M.Socksaddr headerWritten bool @@ -33,17 +34,17 @@ type ClientConn struct { func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn { return &ClientConn{ - Conn: conn, - key: key, - destination: destination, + ExtendedConn: bufio.NewExtendedConn(conn), + key: key, + destination: destination, } } func (c *ClientConn) Write(p []byte) (n int, err error) { if c.headerWritten { - return c.Conn.Write(p) + return c.ExtendedConn.Write(p) } - err = ClientHandshake(c.Conn, c.key, c.destination, p) + err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p) if err != nil { return } @@ -54,10 +55,9 @@ func (c *ClientConn) Write(p []byte) (n int, err error) { func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error { if c.headerWritten { - defer buffer.Release() - return common.Error(c.Conn.Write(buffer.Bytes())) + return c.ExtendedConn.WriteBuffer(buffer) } - err := ClientHandshakeBuffer(c.Conn, c.key, c.destination, buffer) + err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer) if err != nil { return err } @@ -69,11 +69,11 @@ func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) { if !c.headerWritten { return bufio.ReadFrom0(c, r) } - return bufio.Copy(c.Conn, r) + return bufio.Copy(c.ExtendedConn, r) } func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) { - return bufio.Copy(w, c.Conn) + return bufio.Copy(w, c.ExtendedConn) } func (c *ClientConn) FrontHeadroom() int { @@ -84,7 +84,7 @@ func (c *ClientConn) FrontHeadroom() int { } func (c *ClientConn) Upstream() any { - return c.Conn + return c.ExtendedConn } type ClientPacketConn struct {