mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 12:27:37 +03:00
Add direct copy
This commit is contained in:
parent
44534566a3
commit
bf0aaacc67
15 changed files with 912 additions and 136 deletions
|
@ -186,13 +186,16 @@ func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
|
||||||
return c.PacketConn.ReadPacket(buffer)
|
return c.PacketConn.ReadPacket(buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CachedPacketConn) ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer) {
|
func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
|
||||||
buffer = c.buffer
|
buffer := c.buffer
|
||||||
c.buffer = nil
|
c.buffer = nil
|
||||||
if buffer != nil {
|
if buffer != nil {
|
||||||
buffer.DecRef()
|
buffer.DecRef()
|
||||||
}
|
}
|
||||||
return c.destination, buffer
|
return &N.PacketBuffer{
|
||||||
|
Buffer: buffer,
|
||||||
|
Destination: c.destination,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CachedPacketConn) Upstream() any {
|
func (c *CachedPacketConn) Upstream() any {
|
||||||
|
|
|
@ -125,7 +125,7 @@ func (w *ExtendedWriterWrapper) Upstream() any {
|
||||||
return w.Writer
|
return w.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ExtendedReaderWrapper) WriterReplaceable() bool {
|
func (w *ExtendedWriterWrapper) WriterReplaceable() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,10 @@ package bufio
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
@ -14,80 +16,65 @@ import (
|
||||||
"github.com/sagernet/sing/common/task"
|
"github.com/sagernet/sing/common/task"
|
||||||
)
|
)
|
||||||
|
|
||||||
type readOnlyReader struct {
|
func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
||||||
io.Reader
|
if source == nil {
|
||||||
}
|
|
||||||
|
|
||||||
func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
|
|
||||||
return Copy(w, r.Reader)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *readOnlyReader) Upstream() any {
|
|
||||||
return r.Reader
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *readOnlyReader) ReaderReplaceable() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
type writeOnlyWriter struct {
|
|
||||||
io.Writer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *writeOnlyWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
|
||||||
return Copy(w.Writer, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *writeOnlyWriter) Upstream() any {
|
|
||||||
return w.Writer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *writeOnlyWriter) WriterReplaceable() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func needWrapper(src, dst any) bool {
|
|
||||||
_, srcTCPConn := src.(*net.TCPConn)
|
|
||||||
_, dstTCPConn := dst.(*net.TCPConn)
|
|
||||||
return (srcTCPConn || dstTCPConn) && !(srcTCPConn && dstTCPConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
|
|
||||||
if src == nil {
|
|
||||||
return 0, E.New("nil reader")
|
return 0, E.New("nil reader")
|
||||||
} else if dst == nil {
|
} else if destination == nil {
|
||||||
return 0, E.New("nil writer")
|
return 0, E.New("nil writer")
|
||||||
}
|
}
|
||||||
origDst := dst
|
originDestination := destination
|
||||||
src = N.UnwrapReader(src)
|
var readCounters, writeCounters []N.CountFunc
|
||||||
dst = N.UnwrapWriter(dst)
|
for {
|
||||||
if wt, ok := src.(io.WriterTo); ok {
|
source, readCounters = N.UnwrapCountReader(source, readCounters)
|
||||||
if needWrapper(dst, src) {
|
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
|
||||||
dst = &writeOnlyWriter{dst}
|
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
||||||
|
cachedBuffer := cachedSrc.ReadCached()
|
||||||
|
if cachedBuffer != nil {
|
||||||
|
if !cachedBuffer.IsEmpty() {
|
||||||
|
_, err = destination.Write(cachedBuffer.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
cachedBuffer.Release()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cachedBuffer.Release()
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return wt.WriteTo(dst)
|
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
|
||||||
}
|
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
||||||
if rt, ok := dst.(io.ReaderFrom); ok {
|
if srcIsSyscall && dstIsSyscall {
|
||||||
if needWrapper(rt, src) {
|
var handled bool
|
||||||
src = &readOnlyReader{src}
|
handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||||
|
if handled {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return rt.ReadFrom(src)
|
break
|
||||||
}
|
}
|
||||||
return CopyExtended(origDst, NewExtendedWriter(dst), NewExtendedReader(src))
|
return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyExtended(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
safeSrc := N.IsSafeReader(src)
|
safeSrc := N.IsSafeReader(source)
|
||||||
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
|
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
|
||||||
if safeSrc != nil {
|
if safeSrc != nil {
|
||||||
if headroom == 0 {
|
if headroom == 0 {
|
||||||
return CopyExtendedWithSrcBuffer(origDst, dst, safeSrc)
|
return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if N.IsUnsafeWriter(dst) {
|
readWaiter, isReadWaiter := CreateReadWaiter(source)
|
||||||
return CopyExtendedWithPool(origDst, dst, src)
|
if isReadWaiter {
|
||||||
|
var handled bool
|
||||||
|
handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters)
|
||||||
|
if handled {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
bufferSize := N.CalculateMTU(src, dst)
|
if N.IsUnsafeWriter(destination) {
|
||||||
|
return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters)
|
||||||
|
}
|
||||||
|
bufferSize := N.CalculateMTU(source, destination)
|
||||||
if bufferSize > 0 {
|
if bufferSize > 0 {
|
||||||
bufferSize += headroom
|
bufferSize += headroom
|
||||||
} else {
|
} else {
|
||||||
|
@ -97,63 +84,83 @@ func CopyExtended(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader)
|
||||||
defer common.KeepAlive(_buffer)
|
defer common.KeepAlive(_buffer)
|
||||||
buffer := common.Dup(_buffer)
|
buffer := common.Dup(_buffer)
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
return CopyExtendedBuffer(origDst, dst, src, buffer)
|
return CopyExtendedBuffer(originDestination, destination, source, buffer, readCounters, writeCounters)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyExtendedBuffer(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
|
func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
buffer.IncRef()
|
buffer.IncRef()
|
||||||
defer buffer.DecRef()
|
defer buffer.DecRef()
|
||||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||||
readBufferRaw := buffer.Slice()
|
readBufferRaw := buffer.Slice()
|
||||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||||
var notFirstTime bool
|
var notFirstTime bool
|
||||||
for {
|
for {
|
||||||
readBuffer.Resize(frontHeadroom, 0)
|
readBuffer.Resize(frontHeadroom, 0)
|
||||||
err = src.ReadBuffer(readBuffer)
|
err = source.ReadBuffer(readBuffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
err = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
err = N.HandshakeFailure(origDst, err)
|
err = N.HandshakeFailure(originDestination, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := readBuffer.Len()
|
dataLen := readBuffer.Len()
|
||||||
buffer.Resize(readBuffer.Start(), dataLen)
|
buffer.Resize(readBuffer.Start(), dataLen)
|
||||||
err = dst.WriteBuffer(buffer)
|
err = destination.WriteBuffer(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
n += int64(dataLen)
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
notFirstTime = true
|
notFirstTime = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyExtendedWithSrcBuffer(origDst io.Writer, dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
|
func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
var notFirstTime bool
|
var notFirstTime bool
|
||||||
for {
|
for {
|
||||||
var buffer *buf.Buffer
|
var buffer *buf.Buffer
|
||||||
buffer, err = src.ReadBufferThreadSafe()
|
buffer, err = source.ReadBufferThreadSafe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
err = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
err = N.HandshakeFailure(origDst, err)
|
err = N.HandshakeFailure(originDestination, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := buffer.Len()
|
dataLen := buffer.Len()
|
||||||
err = dst.WriteBuffer(buffer)
|
err = destination.WriteBuffer(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
n += int64(dataLen)
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
notFirstTime = true
|
notFirstTime = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyExtendedWithPool(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||||
bufferSize := N.CalculateMTU(src, dst)
|
bufferSize := N.CalculateMTU(source, destination)
|
||||||
if bufferSize > 0 {
|
if bufferSize > 0 {
|
||||||
bufferSize += frontHeadroom + rearHeadroom
|
bufferSize += frontHeadroom + rearHeadroom
|
||||||
} else {
|
} else {
|
||||||
|
@ -165,86 +172,129 @@ func CopyExtendedWithPool(origDst io.Writer, dst N.ExtendedWriter, src N.Extende
|
||||||
readBufferRaw := buffer.Slice()
|
readBufferRaw := buffer.Slice()
|
||||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||||
readBuffer.Resize(frontHeadroom, 0)
|
readBuffer.Resize(frontHeadroom, 0)
|
||||||
err = src.ReadBuffer(readBuffer)
|
err = source.ReadBuffer(readBuffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
err = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
err = N.HandshakeFailure(origDst, err)
|
err = N.HandshakeFailure(originDestination, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := readBuffer.Len()
|
dataLen := readBuffer.Len()
|
||||||
buffer.Resize(readBuffer.Start(), dataLen)
|
buffer.Resize(readBuffer.Start(), dataLen)
|
||||||
err = dst.WriteBuffer(buffer)
|
err = destination.WriteBuffer(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
n += int64(dataLen)
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
notFirstTime = true
|
notFirstTime = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
|
||||||
return CopyConnContextList([]context.Context{ctx}, conn, dest)
|
return CopyConnContextList([]context.Context{ctx}, source, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyConnContextList(contextList []context.Context, conn net.Conn, dest net.Conn) error {
|
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
|
||||||
var group task.Group
|
var group task.Group
|
||||||
if _, dstDuplex := common.Cast[rw.WriteCloser](dest); dstDuplex {
|
if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex {
|
||||||
group.Append("upload", func(ctx context.Context) error {
|
group.Append("upload", func(ctx context.Context) error {
|
||||||
err := common.Error(Copy(dest, conn))
|
err := common.Error(Copy(destination, source))
|
||||||
if E.IsMulti(err, io.EOF) {
|
if err == nil {
|
||||||
rw.CloseWrite(dest)
|
rw.CloseWrite(destination)
|
||||||
} else {
|
} else {
|
||||||
common.Close(dest)
|
common.Close(destination)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
group.Append("upload", func(ctx context.Context) error {
|
group.Append("upload", func(ctx context.Context) error {
|
||||||
defer common.Close(dest)
|
defer common.Close(destination)
|
||||||
return common.Error(Copy(dest, conn))
|
return common.Error(Copy(destination, source))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if _, srcDuplex := common.Cast[rw.WriteCloser](conn); srcDuplex {
|
if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex {
|
||||||
group.Append("download", func(ctx context.Context) error {
|
group.Append("download", func(ctx context.Context) error {
|
||||||
err := common.Error(Copy(conn, dest))
|
err := common.Error(Copy(source, destination))
|
||||||
if E.IsMulti(err, io.EOF) {
|
if err == nil {
|
||||||
rw.CloseWrite(conn)
|
rw.CloseWrite(source)
|
||||||
} else {
|
} else {
|
||||||
common.Close(conn)
|
common.Close(source)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
group.Append("download", func(ctx context.Context) error {
|
group.Append("download", func(ctx context.Context) error {
|
||||||
defer common.Close(conn)
|
defer common.Close(source)
|
||||||
return common.Error(Copy(conn, dest))
|
return common.Error(Copy(source, destination))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
group.Cleanup(func() {
|
group.Cleanup(func() {
|
||||||
common.Close(conn, dest)
|
common.Close(source, destination)
|
||||||
})
|
})
|
||||||
return group.RunContextList(contextList)
|
return group.RunContextList(contextList)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
||||||
src = N.UnwrapPacketReader(src)
|
var readCounters, writeCounters []N.CountFunc
|
||||||
dst = N.UnwrapPacketWriter(dst)
|
var cachedPackets []*N.PacketBuffer
|
||||||
safeSrc := N.IsSafePacketReader(src)
|
for {
|
||||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
|
||||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
|
||||||
|
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
|
||||||
|
packet := cachedReader.ReadCachedPacket()
|
||||||
|
if packet != nil {
|
||||||
|
cachedPackets = append([]*N.PacketBuffer{packet}, cachedPackets...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if cachedPackets != nil {
|
||||||
|
n, err = WritePacketWithPool(destinationConn, cachedPackets)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
safeSrc := N.IsSafePacketReader(source)
|
||||||
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
headroom := frontHeadroom + rearHeadroom
|
headroom := frontHeadroom + rearHeadroom
|
||||||
if safeSrc != nil {
|
if safeSrc != nil {
|
||||||
if headroom == 0 {
|
if headroom == 0 {
|
||||||
return CopyPacketWithSrcBuffer(dst, safeSrc)
|
var copyN int64
|
||||||
|
copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters)
|
||||||
|
n += copyN
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if N.IsUnsafeWriter(dst) {
|
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
||||||
return CopyPacketWithPool(dst, src)
|
if isReadWaiter {
|
||||||
|
var (
|
||||||
|
handled bool
|
||||||
|
copeN int64
|
||||||
|
)
|
||||||
|
handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters)
|
||||||
|
if handled {
|
||||||
|
n += copeN
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
bufferSize := N.CalculateMTU(src, dst)
|
if N.IsUnsafeWriter(destinationConn) {
|
||||||
|
return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters)
|
||||||
|
}
|
||||||
|
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||||
if bufferSize > 0 {
|
if bufferSize > 0 {
|
||||||
bufferSize += headroom
|
bufferSize += headroom
|
||||||
} else {
|
} else {
|
||||||
|
@ -262,51 +312,66 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||||
for {
|
for {
|
||||||
readBuffer.Resize(frontHeadroom, 0)
|
readBuffer.Resize(frontHeadroom, 0)
|
||||||
destination, err = src.ReadPacket(readBuffer)
|
destination, err = source.ReadPacket(readBuffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
err = N.HandshakeFailure(dst, err)
|
err = N.HandshakeFailure(destinationConn, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := readBuffer.Len()
|
dataLen := readBuffer.Len()
|
||||||
buffer.Resize(readBuffer.Start(), dataLen)
|
buffer.Resize(readBuffer.Start(), dataLen)
|
||||||
err = dst.WritePacket(buffer, destination)
|
err = destinationConn.WritePacket(buffer, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
n += int64(dataLen)
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
notFirstTime = true
|
notFirstTime = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
|
func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
var buffer *buf.Buffer
|
var buffer *buf.Buffer
|
||||||
var destination M.Socksaddr
|
var destination M.Socksaddr
|
||||||
var notFirstTime bool
|
var notFirstTime bool
|
||||||
for {
|
for {
|
||||||
buffer, destination, err = src.ReadPacketThreadSafe()
|
buffer, destination, err = source.ReadPacketThreadSafe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
err = N.HandshakeFailure(dst, err)
|
err = N.HandshakeFailure(destinationConn, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := buffer.Len()
|
dataLen := buffer.Len()
|
||||||
err = dst.WritePacket(buffer, destination)
|
if dataLen == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = destinationConn.WritePacket(buffer, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
n += int64(dataLen)
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
notFirstTime = true
|
notFirstTime = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
bufferSize := N.CalculateMTU(src, dst)
|
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||||
if bufferSize > 0 {
|
if bufferSize > 0 {
|
||||||
bufferSize += frontHeadroom + rearHeadroom
|
bufferSize += frontHeadroom + rearHeadroom
|
||||||
} else {
|
} else {
|
||||||
|
@ -319,40 +384,72 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
|
||||||
readBufferRaw := buffer.Slice()
|
readBufferRaw := buffer.Slice()
|
||||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||||
readBuffer.Resize(frontHeadroom, 0)
|
readBuffer.Resize(frontHeadroom, 0)
|
||||||
destination, err = src.ReadPacket(readBuffer)
|
destination, err = source.ReadPacket(readBuffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
err = N.HandshakeFailure(dst, err)
|
err = N.HandshakeFailure(destinationConn, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := readBuffer.Len()
|
dataLen := readBuffer.Len()
|
||||||
buffer.Resize(readBuffer.Start(), dataLen)
|
buffer.Resize(readBuffer.Start(), dataLen)
|
||||||
err = dst.WritePacket(buffer, destination)
|
err = destinationConn.WritePacket(buffer, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
n += int64(dataLen)
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(n)
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(n)
|
||||||
|
}
|
||||||
notFirstTime = true
|
notFirstTime = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
|
||||||
return CopyPacketConnContextList([]context.Context{ctx}, conn, dest)
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
|
var destination M.Socksaddr
|
||||||
|
for _, packetBuffer := range packetBuffers {
|
||||||
|
buffer := buf.NewPacket()
|
||||||
|
readBufferRaw := buffer.Slice()
|
||||||
|
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||||
|
readBuffer.Resize(frontHeadroom, 0)
|
||||||
|
_, err = readBuffer.Write(packetBuffer.Buffer.Bytes())
|
||||||
|
packetBuffer.Buffer.Release()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dataLen := readBuffer.Len()
|
||||||
|
buffer.Resize(readBuffer.Start(), dataLen)
|
||||||
|
err = destinationConn.WritePacket(buffer, destination)
|
||||||
|
if err != nil {
|
||||||
|
buffer.Release()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n += int64(dataLen)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketConnContextList(contextList []context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
|
||||||
|
return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
|
||||||
var group task.Group
|
var group task.Group
|
||||||
group.Append("upload", func(ctx context.Context) error {
|
group.Append("upload", func(ctx context.Context) error {
|
||||||
return common.Error(CopyPacket(dest, conn))
|
return common.Error(CopyPacket(destination, source))
|
||||||
})
|
})
|
||||||
group.Append("download", func(ctx context.Context) error {
|
group.Append("download", func(ctx context.Context) error {
|
||||||
return common.Error(CopyPacket(conn, dest))
|
return common.Error(CopyPacket(source, destination))
|
||||||
})
|
})
|
||||||
group.Cleanup(func() {
|
group.Cleanup(func() {
|
||||||
common.Close(conn, dest)
|
common.Close(source, destination)
|
||||||
})
|
})
|
||||||
group.FastFail()
|
group.FastFail()
|
||||||
return group.RunContextList(contextList)
|
return group.RunContextList(contextList)
|
||||||
|
|
20
common/bufio/copy_direct.go
Normal file
20
common/bufio/copy_direct.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package bufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||||
|
rawSource, err := source.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rawDestination, err := destination.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
|
||||||
|
return
|
||||||
|
}
|
232
common/bufio/copy_direct_posix.go
Normal file
232
common/bufio/copy_direct_posix.go
Normal file
|
@ -0,0 +1,232 @@
|
||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package bufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||||
|
handled = true
|
||||||
|
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||||
|
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||||
|
bufferSize := N.CalculateMTU(source, destination)
|
||||||
|
if bufferSize > 0 {
|
||||||
|
bufferSize += frontHeadroom + rearHeadroom
|
||||||
|
} else {
|
||||||
|
bufferSize = buf.BufferSize
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
buffer *buf.Buffer
|
||||||
|
readBuffer *buf.Buffer
|
||||||
|
)
|
||||||
|
newBuffer := func() *buf.Buffer {
|
||||||
|
if buffer != nil {
|
||||||
|
buffer.Release()
|
||||||
|
}
|
||||||
|
buffer = buf.NewSize(bufferSize)
|
||||||
|
readBufferRaw := buffer.Slice()
|
||||||
|
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||||
|
readBuffer.Resize(frontHeadroom, 0)
|
||||||
|
return readBuffer
|
||||||
|
}
|
||||||
|
var notFirstTime bool
|
||||||
|
for {
|
||||||
|
err = source.WaitReadBuffer(newBuffer)
|
||||||
|
if err != nil {
|
||||||
|
buffer.Release()
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
err = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !notFirstTime {
|
||||||
|
err = N.HandshakeFailure(originDestination, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataLen := readBuffer.Len()
|
||||||
|
buffer.Resize(readBuffer.Start(), dataLen)
|
||||||
|
err = destination.WriteBuffer(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if buffer != nil {
|
||||||
|
buffer.Release()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
notFirstTime = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||||
|
handled = true
|
||||||
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
|
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||||
|
if bufferSize > 0 {
|
||||||
|
bufferSize += frontHeadroom + rearHeadroom
|
||||||
|
} else {
|
||||||
|
bufferSize = buf.UDPBufferSize
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
buffer *buf.Buffer
|
||||||
|
readBuffer *buf.Buffer
|
||||||
|
)
|
||||||
|
newBuffer := func() *buf.Buffer {
|
||||||
|
if buffer != nil {
|
||||||
|
buffer.Release()
|
||||||
|
}
|
||||||
|
buffer = buf.NewSize(bufferSize)
|
||||||
|
readBufferRaw := buffer.Slice()
|
||||||
|
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||||
|
readBuffer.Resize(frontHeadroom, 0)
|
||||||
|
return readBuffer
|
||||||
|
}
|
||||||
|
var destination M.Socksaddr
|
||||||
|
var notFirstTime bool
|
||||||
|
for {
|
||||||
|
destination, err = source.WaitReadPacket(newBuffer)
|
||||||
|
if err != nil {
|
||||||
|
buffer.Release()
|
||||||
|
if !notFirstTime {
|
||||||
|
err = N.HandshakeFailure(destinationConn, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataLen := readBuffer.Len()
|
||||||
|
buffer.Resize(readBuffer.Start(), dataLen)
|
||||||
|
err = destinationConn.WritePacket(buffer, destination)
|
||||||
|
if err != nil {
|
||||||
|
buffer.Release()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n += int64(dataLen)
|
||||||
|
for _, counter := range readCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
for _, counter := range writeCounters {
|
||||||
|
counter(int64(dataLen))
|
||||||
|
}
|
||||||
|
notFirstTime = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type syscallReadWaiter struct {
|
||||||
|
rawConn syscall.RawConn
|
||||||
|
readErr error
|
||||||
|
readFunc func(fd uintptr) (done bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||||
|
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||||
|
rawConn, err := syscallConn.SyscallConn()
|
||||||
|
if err == nil {
|
||||||
|
return &syscallReadWaiter{rawConn: rawConn}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
|
||||||
|
if w.readFunc == nil {
|
||||||
|
w.readFunc = func(fd uintptr) (done bool) {
|
||||||
|
buffer := newBuffer()
|
||||||
|
var readN int
|
||||||
|
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
|
||||||
|
if readN > 0 {
|
||||||
|
buffer.Truncate(readN)
|
||||||
|
} else {
|
||||||
|
buffer.Release()
|
||||||
|
buffer = nil
|
||||||
|
}
|
||||||
|
if w.readErr == syscall.EAGAIN {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if readN == 0 {
|
||||||
|
w.readErr = io.EOF
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := w.rawConn.Read(w.readFunc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if w.readErr != nil {
|
||||||
|
return E.Cause(w.readErr, "raw read")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type syscallPacketReadWaiter struct {
|
||||||
|
rawConn syscall.RawConn
|
||||||
|
readErr error
|
||||||
|
readFrom M.Socksaddr
|
||||||
|
readFunc func(fd uintptr) (done bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||||
|
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
|
||||||
|
rawConn, err := syscallConn.SyscallConn()
|
||||||
|
if err == nil {
|
||||||
|
return &syscallPacketReadWaiter{rawConn: rawConn}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||||
|
if w.readFunc == nil {
|
||||||
|
w.readFunc = func(fd uintptr) (done bool) {
|
||||||
|
buffer := newBuffer()
|
||||||
|
var readN int
|
||||||
|
var from syscall.Sockaddr
|
||||||
|
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
|
||||||
|
if readN > 0 {
|
||||||
|
buffer.Truncate(readN)
|
||||||
|
} else {
|
||||||
|
buffer.Release()
|
||||||
|
buffer = nil
|
||||||
|
}
|
||||||
|
if w.readErr == syscall.EAGAIN {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if from != nil {
|
||||||
|
switch fromAddr := from.(type) {
|
||||||
|
case *syscall.SockaddrInet4:
|
||||||
|
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||||
|
case *syscall.SockaddrInet6:
|
||||||
|
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if readN == 0 {
|
||||||
|
w.readErr = io.EOF
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = w.rawConn.Read(w.readFunc)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if w.readErr != nil {
|
||||||
|
err = E.Cause(w.readErr, "raw read")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
destination = w.readFrom
|
||||||
|
return
|
||||||
|
}
|
23
common/bufio/copy_direct_windows.go
Normal file
23
common/bufio/copy_direct_windows.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package bufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSyscallPacketReadWaiter(reader any) (N.PacketReadWaiter, bool) {
|
||||||
|
return nil, false
|
||||||
|
}
|
96
common/bufio/counter_conn.go
Normal file
96
common/bufio/counter_conn.go
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
package bufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/atomic"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewInt64CounterConn(conn net.Conn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterConn {
|
||||||
|
return &CounterConn{
|
||||||
|
NewExtendedConn(conn),
|
||||||
|
common.Map(readCounter, func(it *atomic.Int64) N.CountFunc {
|
||||||
|
return func(n int64) {
|
||||||
|
it.Add(n)
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc {
|
||||||
|
return func(n int64) {
|
||||||
|
it.Add(n)
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCounterConn(conn net.Conn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterConn {
|
||||||
|
return &CounterConn{NewExtendedConn(conn), readCounter, writeCounter}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CounterConn struct {
|
||||||
|
N.ExtendedConn
|
||||||
|
readCounter []N.CountFunc
|
||||||
|
writeCounter []N.CountFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterConn) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = c.ExtendedConn.Read(p)
|
||||||
|
if n > 0 {
|
||||||
|
for _, counter := range c.readCounter {
|
||||||
|
counter(int64(n))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||||
|
err := c.ExtendedConn.ReadBuffer(buffer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if buffer.Len() > 0 {
|
||||||
|
for _, counter := range c.readCounter {
|
||||||
|
counter(int64(buffer.Len()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterConn) Write(p []byte) (n int, err error) {
|
||||||
|
n, err = c.ExtendedConn.Write(p)
|
||||||
|
if n > 0 {
|
||||||
|
for _, counter := range c.writeCounter {
|
||||||
|
counter(int64(n))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||||
|
dataLen := int64(buffer.Len())
|
||||||
|
err := c.ExtendedConn.WriteBuffer(buffer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if dataLen > 0 {
|
||||||
|
for _, counter := range c.writeCounter {
|
||||||
|
counter(dataLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterConn) UnwrapReader() (io.Reader, []N.CountFunc) {
|
||||||
|
return c.ExtendedConn, c.readCounter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterConn) UnwrapWriter() (io.Writer, []N.CountFunc) {
|
||||||
|
return c.ExtendedConn, c.writeCounter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterConn) Upstream() any {
|
||||||
|
return c.ExtendedConn
|
||||||
|
}
|
73
common/bufio/counter_packet_conn.go
Normal file
73
common/bufio/counter_packet_conn.go
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
package bufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/atomic"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CounterPacketConn struct {
|
||||||
|
N.PacketConn
|
||||||
|
readCounter []N.CountFunc
|
||||||
|
writeCounter []N.CountFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInt64CounterPacketConn(conn N.PacketConn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterPacketConn {
|
||||||
|
return &CounterPacketConn{
|
||||||
|
conn,
|
||||||
|
common.Map(readCounter, func(it *atomic.Int64) N.CountFunc {
|
||||||
|
return func(n int64) {
|
||||||
|
it.Add(n)
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc {
|
||||||
|
return func(n int64) {
|
||||||
|
it.Add(n)
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCounterPacketConn(conn N.PacketConn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterPacketConn {
|
||||||
|
return &CounterPacketConn{conn, readCounter, writeCounter}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||||
|
destination, err = c.PacketConn.ReadPacket(buffer)
|
||||||
|
if err == nil {
|
||||||
|
if buffer.Len() > 0 {
|
||||||
|
for _, counter := range c.readCounter {
|
||||||
|
counter(int64(buffer.Len()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
|
dataLen := int64(buffer.Len())
|
||||||
|
err := c.PacketConn.WritePacket(buffer, destination)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if dataLen > 0 {
|
||||||
|
for _, counter := range c.writeCounter {
|
||||||
|
counter(dataLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterPacketConn) UnwrapPacketReader() (N.PacketReader, []N.CountFunc) {
|
||||||
|
return c.PacketConn, c.readCounter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterPacketConn) UnwrapPacketWriter() (N.PacketWriter, []N.CountFunc) {
|
||||||
|
return c.PacketConn, c.writeCounter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CounterPacketConn) Upstream() any {
|
||||||
|
return c.PacketConn
|
||||||
|
}
|
79
common/bufio/splice_linux.go
Normal file
79
common/bufio/splice_linux.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
package bufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxSpliceSize = 1 << 20
|
||||||
|
|
||||||
|
func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||||
|
handed = true
|
||||||
|
var pipeFDs [2]int
|
||||||
|
err = unix.Pipe2(pipeFDs[:], syscall.O_CLOEXEC|syscall.O_NONBLOCK)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer unix.Close(pipeFDs[0])
|
||||||
|
defer unix.Close(pipeFDs[1])
|
||||||
|
|
||||||
|
_, _ = unix.FcntlInt(uintptr(pipeFDs[0]), unix.F_SETPIPE_SZ, maxSpliceSize)
|
||||||
|
var readN int
|
||||||
|
var readErr error
|
||||||
|
var writeErr error
|
||||||
|
readFunc := func(fd uintptr) (done bool) {
|
||||||
|
p0, p1 := unix.Splice(int(fd), nil, pipeFDs[1], nil, maxSpliceSize, unix.SPLICE_F_NONBLOCK)
|
||||||
|
readN = int(p0)
|
||||||
|
readErr = p1
|
||||||
|
return readErr != unix.EAGAIN
|
||||||
|
}
|
||||||
|
writeFunc := func(fd uintptr) (done bool) {
|
||||||
|
var writeN int
|
||||||
|
size := readN
|
||||||
|
for size > 0 {
|
||||||
|
p0, p1 := unix.Splice(pipeFDs[0], nil, int(fd), nil, size, unix.SPLICE_F_NONBLOCK|unix.SPLICE_F_MOVE)
|
||||||
|
writeN = int(p0)
|
||||||
|
writeErr = p1
|
||||||
|
if writeErr != nil {
|
||||||
|
return writeErr != unix.EAGAIN
|
||||||
|
}
|
||||||
|
size -= int(writeN)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
err = source.Read(readFunc)
|
||||||
|
if err != nil {
|
||||||
|
readErr = err
|
||||||
|
}
|
||||||
|
if readErr != nil {
|
||||||
|
if readErr == unix.EINVAL || readErr == unix.ENOSYS {
|
||||||
|
handed = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = E.Cause(readErr, "splice read")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if readN == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = destination.Write(writeFunc)
|
||||||
|
if err != nil {
|
||||||
|
writeErr = err
|
||||||
|
}
|
||||||
|
if writeErr != nil {
|
||||||
|
err = E.Cause(writeErr, "splice write")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, readCounter := range readCounters {
|
||||||
|
readCounter(int64(readN))
|
||||||
|
}
|
||||||
|
for _, writeCounter := range writeCounters {
|
||||||
|
writeCounter(int64(readN))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
13
common/bufio/splice_stub.go
Normal file
13
common/bufio/splice_stub.go
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package bufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
|
||||||
|
return
|
||||||
|
}
|
35
common/bufio/wait.go
Normal file
35
common/bufio/wait.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package bufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) {
|
||||||
|
reader = N.UnwrapReader(reader)
|
||||||
|
if readWaiter, isReadWaiter := reader.(N.ReadWaiter); isReadWaiter {
|
||||||
|
return readWaiter, true
|
||||||
|
}
|
||||||
|
if readWaitCreator, isCreator := reader.(N.ReadWaitCreator); isCreator {
|
||||||
|
return readWaitCreator.CreateReadWaiter()
|
||||||
|
}
|
||||||
|
if readWaiter, created := createSyscallReadWaiter(reader); created {
|
||||||
|
return readWaiter, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) {
|
||||||
|
reader = N.UnwrapPacketReader(reader)
|
||||||
|
if readWaiter, isReadWaiter := reader.(N.PacketReadWaiter); isReadWaiter {
|
||||||
|
return readWaiter, true
|
||||||
|
}
|
||||||
|
if readWaitCreator, isCreator := reader.(N.PacketReadWaitCreator); isCreator {
|
||||||
|
return readWaitCreator.CreateReadWaiter()
|
||||||
|
}
|
||||||
|
if readWaiter, created := createSyscallPacketReadWaiter(reader); created {
|
||||||
|
return readWaiter, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
|
@ -89,7 +89,12 @@ type CachedReader interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type CachedPacketReader interface {
|
type CachedPacketReader interface {
|
||||||
ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer)
|
ReadCachedPacket() *PacketBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketBuffer struct {
|
||||||
|
Buffer *buf.Buffer
|
||||||
|
Destination M.Socksaddr
|
||||||
}
|
}
|
||||||
|
|
||||||
type WithUpstreamReader interface {
|
type WithUpstreamReader interface {
|
||||||
|
@ -118,7 +123,7 @@ func UnwrapReader(reader io.Reader) io.Reader {
|
||||||
if u, ok := reader.(common.WithUpstream); ok {
|
if u, ok := reader.(common.WithUpstream); ok {
|
||||||
return UnwrapReader(u.Upstream().(io.Reader))
|
return UnwrapReader(u.Upstream().(io.Reader))
|
||||||
}
|
}
|
||||||
panic("bad reader")
|
return reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnwrapPacketReader(reader PacketReader) PacketReader {
|
func UnwrapPacketReader(reader PacketReader) PacketReader {
|
||||||
|
@ -131,7 +136,7 @@ func UnwrapPacketReader(reader PacketReader) PacketReader {
|
||||||
if u, ok := reader.(common.WithUpstream); ok {
|
if u, ok := reader.(common.WithUpstream); ok {
|
||||||
return UnwrapPacketReader(u.Upstream().(PacketReader))
|
return UnwrapPacketReader(u.Upstream().(PacketReader))
|
||||||
}
|
}
|
||||||
panic("bad reader")
|
return reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnwrapWriter(writer io.Writer) io.Writer {
|
func UnwrapWriter(writer io.Writer) io.Writer {
|
||||||
|
@ -144,7 +149,7 @@ func UnwrapWriter(writer io.Writer) io.Writer {
|
||||||
if u, ok := writer.(common.WithUpstream); ok {
|
if u, ok := writer.(common.WithUpstream); ok {
|
||||||
return UnwrapWriter(u.Upstream().(io.Writer))
|
return UnwrapWriter(u.Upstream().(io.Writer))
|
||||||
}
|
}
|
||||||
panic("bad writer")
|
return writer
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnwrapPacketWriter(writer PacketWriter) PacketWriter {
|
func UnwrapPacketWriter(writer PacketWriter) PacketWriter {
|
||||||
|
@ -157,5 +162,5 @@ func UnwrapPacketWriter(writer PacketWriter) PacketWriter {
|
||||||
if u, ok := writer.(common.WithUpstream); ok {
|
if u, ok := writer.(common.WithUpstream); ok {
|
||||||
return UnwrapPacketWriter(u.Upstream().(PacketWriter))
|
return UnwrapPacketWriter(u.Upstream().(PacketWriter))
|
||||||
}
|
}
|
||||||
panic("bad writer")
|
return writer
|
||||||
}
|
}
|
||||||
|
|
67
common/network/counter.go
Normal file
67
common/network/counter.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CountFunc func(n int64)
|
||||||
|
|
||||||
|
type ReadCounter interface {
|
||||||
|
io.Reader
|
||||||
|
UnwrapReader() (io.Reader, []CountFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WriteCounter interface {
|
||||||
|
io.Writer
|
||||||
|
UnwrapWriter() (io.Writer, []CountFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketReadCounter interface {
|
||||||
|
PacketReader
|
||||||
|
UnwrapPacketReader() (PacketReader, []CountFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketWriteCounter interface {
|
||||||
|
PacketWriter
|
||||||
|
UnwrapPacketWriter() (PacketWriter, []CountFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnwrapCountReader(reader io.Reader, countFunc []CountFunc) (io.Reader, []CountFunc) {
|
||||||
|
reader = UnwrapReader(reader)
|
||||||
|
if counter, isCounter := reader.(ReadCounter); isCounter {
|
||||||
|
upstreamReader, upstreamCountFunc := counter.UnwrapReader()
|
||||||
|
countFunc = append(countFunc, upstreamCountFunc...)
|
||||||
|
return UnwrapCountReader(upstreamReader, countFunc)
|
||||||
|
}
|
||||||
|
return reader, countFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnwrapCountWriter(writer io.Writer, countFunc []CountFunc) (io.Writer, []CountFunc) {
|
||||||
|
writer = UnwrapWriter(writer)
|
||||||
|
if counter, isCounter := writer.(WriteCounter); isCounter {
|
||||||
|
upstreamWriter, upstreamCountFunc := counter.UnwrapWriter()
|
||||||
|
countFunc = append(countFunc, upstreamCountFunc...)
|
||||||
|
return UnwrapCountWriter(upstreamWriter, countFunc)
|
||||||
|
}
|
||||||
|
return writer, countFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnwrapCountPacketReader(reader PacketReader, countFunc []CountFunc) (PacketReader, []CountFunc) {
|
||||||
|
reader = UnwrapPacketReader(reader)
|
||||||
|
if counter, isCounter := reader.(PacketReadCounter); isCounter {
|
||||||
|
upstreamReader, upstreamCountFunc := counter.UnwrapPacketReader()
|
||||||
|
countFunc = append(countFunc, upstreamCountFunc...)
|
||||||
|
return UnwrapCountPacketReader(upstreamReader, countFunc)
|
||||||
|
}
|
||||||
|
return reader, countFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnwrapCountPacketWriter(writer PacketWriter, countFunc []CountFunc) (PacketWriter, []CountFunc) {
|
||||||
|
writer = UnwrapPacketWriter(writer)
|
||||||
|
if counter, isCounter := writer.(PacketWriteCounter); isCounter {
|
||||||
|
upstreamWriter, upstreamCountFunc := counter.UnwrapPacketWriter()
|
||||||
|
countFunc = append(countFunc, upstreamCountFunc...)
|
||||||
|
return UnwrapCountPacketWriter(upstreamWriter, countFunc)
|
||||||
|
}
|
||||||
|
return writer, countFunc
|
||||||
|
}
|
22
common/network/direct.go
Normal file
22
common/network/direct.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ReadWaiter interface {
|
||||||
|
WaitReadBuffer(newBuffer func() *buf.Buffer) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReadWaitCreator interface {
|
||||||
|
CreateReadWaiter() (ReadWaiter, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketReadWaiter interface {
|
||||||
|
WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketReadWaitCreator interface {
|
||||||
|
CreateReadWaiter() (PacketReadWaiter, bool)
|
||||||
|
}
|
|
@ -136,6 +136,17 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||||
|
select {
|
||||||
|
case p := <-c.data:
|
||||||
|
_, err = newBuffer().ReadOnceFrom(p.data)
|
||||||
|
p.data.Release()
|
||||||
|
return p.destination, err
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return M.Socksaddr{}, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
return c.source.WritePacket(buffer, destination)
|
return c.source.WritePacket(buffer, destination)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue