mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-05 04:47:40 +03:00
Enable read wait copy for windows
This commit is contained in:
parent
8b43ec8058
commit
231d7607bc
4 changed files with 72 additions and 79 deletions
|
@ -45,7 +45,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
||||||
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
||||||
if srcIsSyscall && dstIsSyscall {
|
if srcIsSyscall && dstIsSyscall {
|
||||||
var handled bool
|
var handled bool
|
||||||
handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||||
if handled {
|
if handled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,16 @@
|
||||||
package bufio
|
package bufio
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||||
rawSource, err := source.SyscallConn()
|
rawSource, err := source.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
|
||||||
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
|
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||||
|
handled = true
|
||||||
|
var (
|
||||||
|
buffer *buf.Buffer
|
||||||
|
notFirstTime bool
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
buffer, err = source.WaitReadBuffer()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
err = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataLen := buffer.Len()
|
||||||
|
err = destination.WriteBuffer(buffer)
|
||||||
|
if err != nil {
|
||||||
|
buffer.Leak()
|
||||||
|
if !notFirstTime {
|
||||||
|
err = N.ReportHandshakeFailure(originSource, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
notFirstTime = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
||||||
|
handled = true
|
||||||
|
var (
|
||||||
|
buffer *buf.Buffer
|
||||||
|
destination M.Socksaddr
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
buffer, destination, err = source.WaitReadPacket()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataLen := buffer.Len()
|
||||||
|
err = destinationConn.WritePacket(buffer, destination)
|
||||||
|
if err != nil {
|
||||||
|
buffer.Leak()
|
||||||
|
if !notFirstTime {
|
||||||
|
err = N.ReportHandshakeFailure(originSource, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
notFirstTime = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
package bufio
|
package bufio
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
@ -15,72 +14,6 @@ import (
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
|
||||||
handled = true
|
|
||||||
var (
|
|
||||||
buffer *buf.Buffer
|
|
||||||
notFirstTime bool
|
|
||||||
)
|
|
||||||
for {
|
|
||||||
buffer, err = source.WaitReadBuffer()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
err = nil
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dataLen := buffer.Len()
|
|
||||||
err = destination.WriteBuffer(buffer)
|
|
||||||
if err != nil {
|
|
||||||
buffer.Leak()
|
|
||||||
if !notFirstTime {
|
|
||||||
err = N.ReportHandshakeFailure(originSource, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n += int64(dataLen)
|
|
||||||
for _, counter := range readCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
for _, counter := range writeCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
notFirstTime = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
|
||||||
handled = true
|
|
||||||
var (
|
|
||||||
buffer *buf.Buffer
|
|
||||||
destination M.Socksaddr
|
|
||||||
)
|
|
||||||
for {
|
|
||||||
buffer, destination, err = source.WaitReadPacket()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dataLen := buffer.Len()
|
|
||||||
err = destinationConn.WritePacket(buffer, destination)
|
|
||||||
if err != nil {
|
|
||||||
buffer.Leak()
|
|
||||||
if !notFirstTime {
|
|
||||||
err = N.ReportHandshakeFailure(originSource, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n += int64(dataLen)
|
|
||||||
for _, counter := range readCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
for _, counter := range writeCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
notFirstTime = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||||
|
|
||||||
type syscallReadWaiter struct {
|
type syscallReadWaiter struct {
|
||||||
|
|
|
@ -1,19 +1,9 @@
|
||||||
package bufio
|
package bufio
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
|
func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue