Fix handshake conn interface

This commit is contained in:
世界 2023-07-23 12:43:02 +08:00
parent 8807070904
commit 32f9f628a0
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 57 additions and 51 deletions

View file

@ -22,7 +22,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
} else if destination == nil {
return 0, E.New("nil writer")
}
originDestination := destination
originSource := source
var readCounters, writeCounters []N.CountFunc
for {
source, readCounters = N.UnwrapCountReader(source, readCounters)
@ -52,29 +52,29 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
}
break
}
return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
}
func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
safeSrc := N.IsSafeReader(source)
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
if safeSrc != nil {
if headroom == 0 {
return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters)
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
}
}
readWaiter, isReadWaiter := CreateReadWaiter(source)
if isReadWaiter {
var handled bool
handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters)
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
if handled {
return
}
}
return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters)
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
}
func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(destination)
@ -90,15 +90,15 @@ func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWrite
err = nil
return
}
if !notFirstTime {
err = N.HandshakeFailure(originDestination, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer)
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
@ -112,7 +112,7 @@ func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWrite
}
}
func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var notFirstTime bool
for {
var buffer *buf.Buffer
@ -122,15 +122,15 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend
err = nil
return
}
if !notFirstTime {
err = N.HandshakeFailure(originDestination, err)
}
return
}
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
@ -144,7 +144,7 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend
}
}
func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
@ -166,9 +166,6 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri
err = nil
return
}
if !notFirstTime {
err = N.HandshakeFailure(originDestination, err)
}
return
}
dataLen := readBuffer.Len()
@ -176,6 +173,9 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
@ -236,6 +236,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer
originSource := source
for {
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
@ -249,8 +250,9 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
break
}
if cachedPackets != nil {
n, err = WritePacketWithPool(destinationConn, cachedPackets)
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
if err != nil {
println("err in write cached packets")
return
}
}
@ -261,36 +263,34 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
if safeSrc != nil {
if headroom == 0 {
var copyN int64
copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters)
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
n += copyN
return
}
}
var (
handled bool
copeN int64
)
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
var (
handled bool
copeN int64
)
handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters)
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
n += copeN
return
}
}
return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters)
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
n += copeN
return
}
func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
var notFirstTime bool
for {
buffer, destination, err = source.ReadPacketThreadSafe()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(destinationConn, err)
}
return
}
dataLen := buffer.Len()
@ -300,6 +300,9 @@ func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafe
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
@ -313,7 +316,7 @@ func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafe
}
}
func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
@ -323,7 +326,6 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
@ -332,9 +334,6 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
destination, err = source.ReadPacket(readBuffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(destinationConn, err)
}
return
}
dataLen := readBuffer.Len()
@ -342,6 +341,9 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
@ -355,9 +357,10 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
}
}
func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
var notFirstTime bool
for _, packetBuffer := range packetBuffers {
buffer := buf.NewPacket()
readBufferRaw := buffer.Slice()
@ -366,6 +369,7 @@ func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.Pack
_, err = readBuffer.Write(packetBuffer.Buffer.Bytes())
packetBuffer.Buffer.Release()
if err != nil {
buffer.Release()
continue
}
dataLen := readBuffer.Len()
@ -373,6 +377,9 @@ func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.Pack
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)

View file

@ -15,7 +15,7 @@ import (
N "github.com/sagernet/sing/common/network"
)
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
@ -45,9 +45,6 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
err = nil
return
}
if !notFirstTime {
err = N.HandshakeFailure(originDestination, err)
}
return
}
dataLen := readBuffer.Len()
@ -55,6 +52,9 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
@ -68,7 +68,7 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
}
}
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
@ -79,10 +79,9 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
bufferSize = buf.UDPBufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
destination M.Socksaddr
notFirstTime bool
buffer *buf.Buffer
readBuffer *buf.Buffer
destination M.Socksaddr
)
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
@ -95,9 +94,6 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
for {
destination, err = source.WaitReadPacket()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(destinationConn, err)
}
return
}
dataLen := readBuffer.Len()
@ -105,6 +101,9 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)

View file

@ -6,11 +6,11 @@ import (
N "github.com/sagernet/sing/common/network"
)
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
return
}
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
return
}