mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-04-03 20:07:36 +03:00
Fix copy early conn
This commit is contained in:
parent
96eb98c00a
commit
4f3ee61104
1 changed files with 44 additions and 14 deletions
|
@ -5,6 +5,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
@ -13,6 +14,7 @@ import (
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
"github.com/sagernet/sing/common/canceler"
|
"github.com/sagernet/sing/common/canceler"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
@ -190,14 +192,16 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
|
||||||
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
|
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
||||||
originSource := source
|
var (
|
||||||
originDestination := destination
|
sourceReader io.Reader = source
|
||||||
|
destinationWriter io.Writer = destination
|
||||||
|
)
|
||||||
var readCounters, writeCounters []N.CountFunc
|
var readCounters, writeCounters []N.CountFunc
|
||||||
for {
|
for {
|
||||||
source, readCounters = N.UnwrapCountReader(source, readCounters)
|
sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
|
||||||
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
|
destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
|
||||||
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
|
||||||
cachedBuffer := cachedSrc.ReadCached()
|
cachedBuffer := cachedSrc.ReadCached()
|
||||||
if cachedBuffer != nil {
|
if cachedBuffer != nil {
|
||||||
dataLen := cachedBuffer.Len()
|
dataLen := cachedBuffer.Len()
|
||||||
|
@ -207,7 +211,7 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
|
||||||
if done.Swap(true) {
|
if done.Swap(true) {
|
||||||
onClose(err)
|
onClose(err)
|
||||||
}
|
}
|
||||||
common.Close(originSource, originDestination)
|
common.Close(source, destination)
|
||||||
if !direction {
|
if !direction {
|
||||||
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
|
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
|
||||||
} else {
|
} else {
|
||||||
|
@ -226,9 +230,13 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destination); isEarlyConn && earlyConn.NeedHandshake() {
|
if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() {
|
||||||
_, err := destination.Write(nil)
|
err := m.connectionCopyEarly(source, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if done.Swap(true) {
|
||||||
|
onClose(err)
|
||||||
|
}
|
||||||
|
common.Close(source, destination)
|
||||||
if !direction {
|
if !direction {
|
||||||
m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
|
m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
|
||||||
} else {
|
} else {
|
||||||
|
@ -237,20 +245,20 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
_, err := bufio.CopyWithCounters(destination, sourceReader, source, readCounters, writeCounters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.Close(originDestination)
|
common.Close(source, destination)
|
||||||
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
|
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
|
||||||
err = duplexDst.CloseWrite()
|
err = duplexDst.CloseWrite()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.Close(originSource, originDestination)
|
common.Close(source, destination)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
common.Close(originDestination)
|
destination.Close()
|
||||||
}
|
}
|
||||||
if done.Swap(true) {
|
if done.Swap(true) {
|
||||||
onClose(err)
|
onClose(err)
|
||||||
common.Close(originSource, originDestination)
|
common.Close(source, destination)
|
||||||
}
|
}
|
||||||
if !direction {
|
if !direction {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -271,6 +279,28 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionManager) connectionCopyEarly(source net.Conn, destination io.Writer) error {
|
||||||
|
payload := buf.NewPacket()
|
||||||
|
defer payload.Release()
|
||||||
|
err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
|
||||||
|
if err != nil {
|
||||||
|
if err == os.ErrInvalid {
|
||||||
|
return common.Error(destination.Write(nil))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = payload.ReadOnceFrom(source)
|
||||||
|
if err != nil && !E.IsTimeout(err) {
|
||||||
|
return E.Cause(err, "read payload")
|
||||||
|
}
|
||||||
|
_ = source.SetReadDeadline(time.Time{})
|
||||||
|
_, err = destination.Write(payload.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "write payload")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
||||||
_, err := bufio.CopyPacket(destination, source)
|
_, err := bufio.CopyPacket(destination, source)
|
||||||
if !direction {
|
if !direction {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue