mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 11:57:39 +03:00
Enable read wait copy for windows
This commit is contained in:
parent
8b43ec8058
commit
231d7607bc
4 changed files with 72 additions and 79 deletions
|
@ -45,7 +45,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
|||
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
||||
if srcIsSyscall && dstIsSyscall {
|
||||
var handled bool
|
||||
handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||
rawSource, err := source.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
|
|||
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
|
||||
return
|
||||
}
|
||||
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
notFirstTime bool
|
||||
)
|
||||
for {
|
||||
buffer, err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
)
|
||||
for {
|
||||
buffer, destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
@ -15,72 +14,6 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
notFirstTime bool
|
||||
)
|
||||
for {
|
||||
buffer, err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
handled = true
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
)
|
||||
for {
|
||||
buffer, destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = destinationConn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Leak()
|
||||
if !notFirstTime {
|
||||
err = N.ReportHandshakeFailure(originSource, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
for _, counter := range writeCounters {
|
||||
counter(int64(dataLen))
|
||||
}
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||
|
||||
type syscallReadWaiter struct {
|
||||
|
|
|
@ -1,19 +1,9 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue