Enable read wait copy for windows

This commit is contained in:
世界 2023-12-07 18:01:10 +08:00
parent 8b43ec8058
commit 231d7607bc
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 72 additions and 79 deletions

View file

@ -45,7 +45,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}

View file

@ -1,12 +1,16 @@
package bufio
import (
"errors"
"io"
"syscall"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
rawSource, err := source.SyscallConn()
if err != nil {
return
@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
return
}
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
notFirstTime bool
)
for {
buffer, err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
destination M.Socksaddr
)
for {
buffer, destination, err = source.WaitReadPacket()
if err != nil {
return
}
dataLen := buffer.Len()
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}

View file

@ -3,7 +3,6 @@
package bufio
import (
"errors"
"io"
"net/netip"
"os"
@ -15,72 +14,6 @@ import (
N "github.com/sagernet/sing/common/network"
)
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
notFirstTime bool
)
for {
buffer, err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
destination M.Socksaddr
)
for {
buffer, destination, err = source.WaitReadPacket()
if err != nil {
return
}
dataLen := buffer.Len()
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
type syscallReadWaiter struct {

View file

@ -1,19 +1,9 @@
package bufio
import (
"io"
N "github.com/sagernet/sing/common/network"
)
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
return
}
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
return
}
func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
return nil, false
}