mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 11:57:39 +03:00
Fix handshake conn interface
This commit is contained in:
parent
8807070904
commit
32f9f628a0
3 changed files with 57 additions and 51 deletions
|
@ -22,7 +22,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
|||
} else if destination == nil {
|
||||
return 0, E.New("nil writer")
|
||||
}
|
||||
originDestination := destination
|
||||
originSource := source
|
||||
var readCounters, writeCounters []N.CountFunc
|
||||
for {
|
||||
source, readCounters = N.UnwrapCountReader(source, readCounters)
|
||||
|
@ -52,29 +52,29 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
|||
}
|
||||
break
|
||||
}
|
||||
return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
||||
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
||||
}
|
||||
|
||||
func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
safeSrc := N.IsSafeReader(source)
|
||||
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters)
|
||||
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
|
||||
}
|
||||
}
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
var handled bool
|
||||
handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters)
|
||||
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters)
|
||||
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
||||
}
|
||||
|
||||
func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
|
@ -90,15 +90,15 @@ func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWrite
|
|||
err = nil
|
||||
return
|
||||
}
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originDestination, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -112,7 +112,7 @@ func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWrite
|
|||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
var notFirstTime bool
|
||||
for {
|
||||
var buffer *buf.Buffer
|
||||
|
@ -122,15 +122,15 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend
|
|||
err = nil
|
||||
return
|
||||
}
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originDestination, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -144,7 +144,7 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend
|
|||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
bufferSize := N.CalculateMTU(source, destination)
|
||||
|
@ -166,9 +166,6 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri
|
|||
err = nil
|
||||
return
|
||||
}
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originDestination, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
|
@ -176,6 +173,9 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri
|
|||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -236,6 +236,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
|
|||
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
||||
var readCounters, writeCounters []N.CountFunc
|
||||
var cachedPackets []*N.PacketBuffer
|
||||
originSource := source
|
||||
for {
|
||||
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
|
||||
destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
|
||||
|
@ -249,8 +250,9 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
|
|||
break
|
||||
}
|
||||
if cachedPackets != nil {
|
||||
n, err = WritePacketWithPool(destinationConn, cachedPackets)
|
||||
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
|
||||
if err != nil {
|
||||
println("err in write cached packets")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -261,36 +263,34 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
|
|||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
var copyN int64
|
||||
copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters)
|
||||
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
|
||||
n += copyN
|
||||
return
|
||||
}
|
||||
}
|
||||
var (
|
||||
handled bool
|
||||
copeN int64
|
||||
)
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
||||
if isReadWaiter {
|
||||
var (
|
||||
handled bool
|
||||
copeN int64
|
||||
)
|
||||
handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters)
|
||||
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
||||
if handled {
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
}
|
||||
return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters)
|
||||
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
|
||||
n += copeN
|
||||
return
|
||||
}
|
||||
|
||||
func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||
var buffer *buf.Buffer
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer, destination, err = source.ReadPacketThreadSafe()
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(destinationConn, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
|
@ -300,6 +300,9 @@ func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafe
|
|||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -313,7 +316,7 @@ func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafe
|
|||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||
|
@ -323,7 +326,6 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
|
|||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
|
@ -332,9 +334,6 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
|
|||
destination, err = source.ReadPacket(readBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(destinationConn, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
|
@ -342,6 +341,9 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
|
|||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -355,9 +357,10 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r
|
|||
}
|
||||
}
|
||||
|
||||
func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
|
||||
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
var notFirstTime bool
|
||||
for _, packetBuffer := range packetBuffers {
|
||||
buffer := buf.NewPacket()
|
||||
readBufferRaw := buffer.Slice()
|
||||
|
@ -366,6 +369,7 @@ func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.Pack
|
|||
_, err = readBuffer.Write(packetBuffer.Buffer.Bytes())
|
||||
packetBuffer.Buffer.Release()
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
continue
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
|
@ -373,6 +377,9 @@ func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.Pack
|
|||
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||
|
@ -45,9 +45,6 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
|
|||
err = nil
|
||||
return
|
||||
}
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originDestination, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
|
@ -55,6 +52,9 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
|
|||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -68,7 +68,7 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
|
|||
}
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||
|
@ -79,10 +79,9 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
|
|||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
notFirstTime bool
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
)
|
||||
source.InitializeReadWaiter(func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
|
@ -95,9 +94,6 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
|
|||
for {
|
||||
destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(destinationConn, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
|
@ -105,6 +101,9 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
|
|||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
|
|
@ -6,11 +6,11 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue