diff --git a/common/bufio/copy.go b/common/bufio/copy.go index ea279eb..80cc823 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -22,7 +22,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { } else if destination == nil { return 0, E.New("nil writer") } - originDestination := destination + originSource := source var readCounters, writeCounters []N.CountFunc for { source, readCounters = N.UnwrapCountReader(source, readCounters) @@ -52,29 +52,29 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { } break } - return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) + return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) } -func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { safeSrc := N.IsSafeReader(source) headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination) if safeSrc != nil { if headroom == 0 { - return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters) + return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters) } } readWaiter, isReadWaiter := CreateReadWaiter(source) if isReadWaiter { var handled bool - handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters) + handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters) if handled { return } } - return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters) + return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) } -func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { buffer.IncRef() defer buffer.DecRef() frontHeadroom := N.CalculateFrontHeadroom(destination) @@ -90,15 +90,15 @@ func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWrite err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } dataLen := readBuffer.Len() buffer.Resize(readBuffer.Start(), dataLen) err = destination.WriteBuffer(buffer) if err != nil { + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -112,7 +112,7 @@ func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWrite } } -func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { var notFirstTime bool for { var buffer *buf.Buffer @@ -122,15 +122,15 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } dataLen := buffer.Len() err = destination.WriteBuffer(buffer) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -144,7 +144,7 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend } } -func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination) bufferSize := N.CalculateMTU(source, destination) @@ -166,9 +166,6 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } dataLen := readBuffer.Len() @@ -176,6 +173,9 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri err = destination.WriteBuffer(buffer) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -236,6 +236,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { var readCounters, writeCounters []N.CountFunc var cachedPackets []*N.PacketBuffer + originSource := source for { source, readCounters = N.UnwrapCountPacketReader(source, readCounters) destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters) @@ -249,8 +250,9 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, break } if cachedPackets != nil { - n, err = WritePacketWithPool(destinationConn, cachedPackets) + n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets) if err != nil { + println("err in write cached packets") return } } @@ -261,36 +263,34 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, if safeSrc != nil { if headroom == 0 { var copyN int64 - copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters) + copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0) n += copyN return } } + var ( + handled bool + copeN int64 + ) readWaiter, isReadWaiter := CreatePacketReadWaiter(source) if isReadWaiter { - var ( - handled bool - copeN int64 - ) - handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters) + handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) if handled { n += copeN return } } - return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters) + copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0) + n += copeN + return } -func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { var buffer *buf.Buffer var destination M.Socksaddr - var notFirstTime bool for { buffer, destination, err = source.ReadPacketThreadSafe() if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } return } dataLen := buffer.Len() @@ -300,6 +300,9 @@ func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafe err = destinationConn.WritePacket(buffer, destination) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -313,7 +316,7 @@ func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafe } } -func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) bufferSize := N.CalculateMTU(source, destinationConn) @@ -323,7 +326,6 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r bufferSize = buf.UDPBufferSize } var destination M.Socksaddr - var notFirstTime bool for { buffer := buf.NewSize(bufferSize) readBufferRaw := buffer.Slice() @@ -332,9 +334,6 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r destination, err = source.ReadPacket(readBuffer) if err != nil { buffer.Release() - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } return } dataLen := readBuffer.Len() @@ -342,6 +341,9 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r err = destinationConn.WritePacket(buffer, destination) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -355,9 +357,10 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r } } -func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { +func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) + var notFirstTime bool for _, packetBuffer := range packetBuffers { buffer := buf.NewPacket() readBufferRaw := buffer.Slice() @@ -366,6 +369,7 @@ func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.Pack _, err = readBuffer.Write(packetBuffer.Buffer.Bytes()) packetBuffer.Buffer.Release() if err != nil { + buffer.Release() continue } dataLen := readBuffer.Len() @@ -373,6 +377,9 @@ func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.Pack err = destinationConn.WritePacket(buffer, packetBuffer.Destination) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index d682558..63643a1 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -15,7 +15,7 @@ import ( N "github.com/sagernet/sing/common/network" ) -func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { handled = true frontHeadroom := N.CalculateFrontHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination) @@ -45,9 +45,6 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } dataLen := readBuffer.Len() @@ -55,6 +52,9 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, err = destination.WriteBuffer(buffer) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -68,7 +68,7 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, } } -func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { handled = true frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) @@ -79,10 +79,9 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW bufferSize = buf.UDPBufferSize } var ( - buffer *buf.Buffer - readBuffer *buf.Buffer - destination M.Socksaddr - notFirstTime bool + buffer *buf.Buffer + readBuffer *buf.Buffer + destination M.Socksaddr ) source.InitializeReadWaiter(func() *buf.Buffer { buffer = buf.NewSize(bufferSize) @@ -95,9 +94,6 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW for { destination, err = source.WaitReadPacket() if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } return } dataLen := readBuffer.Len() @@ -105,6 +101,9 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW err = destinationConn.WritePacket(buffer, destination) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) diff --git a/common/bufio/copy_direct_windows.go b/common/bufio/copy_direct_windows.go index 9c0743f..22a2de0 100644 --- a/common/bufio/copy_direct_windows.go +++ b/common/bufio/copy_direct_windows.go @@ -6,11 +6,11 @@ import ( N "github.com/sagernet/sing/common/network" ) -func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { return } -func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { return }