Fix copy early conn

This commit is contained in:
世界 2025-03-15 08:09:04 +08:00
parent 96eb98c00a
commit 4f3ee61104
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -5,6 +5,7 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"os"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -13,6 +14,7 @@ import (
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler" "github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@ -190,14 +192,16 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
} }
func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
originSource := source var (
originDestination := destination sourceReader io.Reader = source
destinationWriter io.Writer = destination
)
var readCounters, writeCounters []N.CountFunc var readCounters, writeCounters []N.CountFunc
for { for {
source, readCounters = N.UnwrapCountReader(source, readCounters) sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
if cachedSrc, isCached := source.(N.CachedReader); isCached { if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached() cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil { if cachedBuffer != nil {
dataLen := cachedBuffer.Len() dataLen := cachedBuffer.Len()
@ -207,7 +211,7 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
if done.Swap(true) { if done.Swap(true) {
onClose(err) onClose(err)
} }
common.Close(originSource, originDestination) common.Close(source, destination)
if !direction { if !direction {
m.logger.ErrorContext(ctx, "connection upload payload: ", err) m.logger.ErrorContext(ctx, "connection upload payload: ", err)
} else { } else {
@ -226,9 +230,13 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
} }
break break
} }
if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destination); isEarlyConn && earlyConn.NeedHandshake() { if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() {
_, err := destination.Write(nil) err := m.connectionCopyEarly(source, destination)
if err != nil { if err != nil {
if done.Swap(true) {
onClose(err)
}
common.Close(source, destination)
if !direction { if !direction {
m.logger.ErrorContext(ctx, "connection upload handshake: ", err) m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
} else { } else {
@ -237,20 +245,20 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
return return
} }
} }
_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters) _, err := bufio.CopyWithCounters(destination, sourceReader, source, readCounters, writeCounters)
if err != nil { if err != nil {
common.Close(originDestination) common.Close(source, destination)
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex { } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
err = duplexDst.CloseWrite() err = duplexDst.CloseWrite()
if err != nil { if err != nil {
common.Close(originSource, originDestination) common.Close(source, destination)
} }
} else { } else {
common.Close(originDestination) destination.Close()
} }
if done.Swap(true) { if done.Swap(true) {
onClose(err) onClose(err)
common.Close(originSource, originDestination) common.Close(source, destination)
} }
if !direction { if !direction {
if err == nil { if err == nil {
@ -271,6 +279,28 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
} }
} }
func (m *ConnectionManager) connectionCopyEarly(source net.Conn, destination io.Writer) error {
payload := buf.NewPacket()
defer payload.Release()
err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
if err != nil {
if err == os.ErrInvalid {
return common.Error(destination.Write(nil))
}
return err
}
_, err = payload.ReadOnceFrom(source)
if err != nil && !E.IsTimeout(err) {
return E.Cause(err, "read payload")
}
_ = source.SetReadDeadline(time.Time{})
_, err = destination.Write(payload.Bytes())
if err != nil {
return E.Cause(err, "write payload")
}
return nil
}
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
_, err := bufio.CopyPacket(destination, source) _, err := bufio.CopyPacket(destination, source)
if !direction { if !direction {