Fix copy early conn

This commit is contained in:
世界 2025-03-15 08:09:04 +08:00
parent 96eb98c00a
commit 4f3ee61104
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -5,6 +5,7 @@ import (
"io"
"net"
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
@ -13,6 +14,7 @@ import (
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
@ -190,14 +192,16 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
}
func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
originSource := source
originDestination := destination
func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
var (
sourceReader io.Reader = source
destinationWriter io.Writer = destination
)
var readCounters, writeCounters []N.CountFunc
for {
source, readCounters = N.UnwrapCountReader(source, readCounters)
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
if cachedSrc, isCached := source.(N.CachedReader); isCached {
sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
dataLen := cachedBuffer.Len()
@ -207,7 +211,7 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
if done.Swap(true) {
onClose(err)
}
common.Close(originSource, originDestination)
common.Close(source, destination)
if !direction {
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
} else {
@ -226,9 +230,13 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
}
break
}
if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destination); isEarlyConn && earlyConn.NeedHandshake() {
_, err := destination.Write(nil)
if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() {
err := m.connectionCopyEarly(source, destination)
if err != nil {
if done.Swap(true) {
onClose(err)
}
common.Close(source, destination)
if !direction {
m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
} else {
@ -237,20 +245,20 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
return
}
}
_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
_, err := bufio.CopyWithCounters(destination, sourceReader, source, readCounters, writeCounters)
if err != nil {
common.Close(originDestination)
common.Close(source, destination)
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
err = duplexDst.CloseWrite()
if err != nil {
common.Close(originSource, originDestination)
common.Close(source, destination)
}
} else {
common.Close(originDestination)
destination.Close()
}
if done.Swap(true) {
onClose(err)
common.Close(originSource, originDestination)
common.Close(source, destination)
}
if !direction {
if err == nil {
@ -271,6 +279,28 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
}
}
func (m *ConnectionManager) connectionCopyEarly(source net.Conn, destination io.Writer) error {
payload := buf.NewPacket()
defer payload.Release()
err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
if err != nil {
if err == os.ErrInvalid {
return common.Error(destination.Write(nil))
}
return err
}
_, err = payload.ReadOnceFrom(source)
if err != nil && !E.IsTimeout(err) {
return E.Cause(err, "read payload")
}
_ = source.SetReadDeadline(time.Time{})
_, err = destination.Write(payload.Bytes())
if err != nil {
return E.Cause(err, "write payload")
}
return nil
}
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
_, err := bufio.CopyPacket(destination, source)
if !direction {