From 4f3ee611042b66e55f8e1c910bd920547c8aaa4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 15 Mar 2025 08:09:04 +0800 Subject: [PATCH] Fix copy early conn --- route/conn.go | 58 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/route/conn.go b/route/conn.go index 7b9f4a63..2824ffb1 100644 --- a/route/conn.go +++ b/route/conn.go @@ -5,6 +5,7 @@ import ( "io" "net" "net/netip" + "os" "sync" "sync/atomic" "time" @@ -13,6 +14,7 @@ import ( "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" @@ -190,14 +192,16 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) } -func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { - originSource := source - originDestination := destination +func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { + var ( + sourceReader io.Reader = source + destinationWriter io.Writer = destination + ) var readCounters, writeCounters []N.CountFunc for { - source, readCounters = N.UnwrapCountReader(source, readCounters) - destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) - if cachedSrc, isCached := source.(N.CachedReader); isCached { + sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters) + destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters) + if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached { cachedBuffer := cachedSrc.ReadCached() if cachedBuffer != nil { dataLen := cachedBuffer.Len() @@ -207,7 +211,7 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader if done.Swap(true) { onClose(err) } - common.Close(originSource, originDestination) + common.Close(source, destination) if !direction { m.logger.ErrorContext(ctx, "connection upload payload: ", err) } else { @@ -226,9 +230,13 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader } break } - if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destination); isEarlyConn && earlyConn.NeedHandshake() { - _, err := destination.Write(nil) + if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() { + err := m.connectionCopyEarly(source, destination) if err != nil { + if done.Swap(true) { + onClose(err) + } + common.Close(source, destination) if !direction { m.logger.ErrorContext(ctx, "connection upload handshake: ", err) } else { @@ -237,20 +245,20 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader return } } - _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters) + _, err := bufio.CopyWithCounters(destination, sourceReader, source, readCounters, writeCounters) if err != nil { - common.Close(originDestination) + common.Close(source, destination) } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex { err = duplexDst.CloseWrite() if err != nil { - common.Close(originSource, originDestination) + common.Close(source, destination) } } else { - common.Close(originDestination) + destination.Close() } if done.Swap(true) { onClose(err) - common.Close(originSource, originDestination) + common.Close(source, destination) } if !direction { if err == nil { @@ -271,6 +279,28 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader } } +func (m *ConnectionManager) connectionCopyEarly(source net.Conn, destination io.Writer) error { + payload := buf.NewPacket() + defer payload.Release() + err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout)) + if err != nil { + if err == os.ErrInvalid { + return common.Error(destination.Write(nil)) + } + return err + } + _, err = payload.ReadOnceFrom(source) + if err != nil && !E.IsTimeout(err) { + return E.Cause(err, "read payload") + } + _ = source.SetReadDeadline(time.Time{}) + _, err = destination.Write(payload.Bytes()) + if err != nil { + return E.Cause(err, "write payload") + } + return nil +} + func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { _, err := bufio.CopyPacket(destination, source) if !direction {