mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 03:47:38 +03:00
321 lines
9.4 KiB
Go
321 lines
9.4 KiB
Go
package bufio
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"syscall"
|
|
|
|
"github.com/sagernet/sing/common"
|
|
"github.com/sagernet/sing/common/buf"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
"github.com/sagernet/sing/common/task"
|
|
)
|
|
|
|
func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
|
if source == nil {
|
|
return 0, E.New("nil reader")
|
|
} else if destination == nil {
|
|
return 0, E.New("nil writer")
|
|
}
|
|
originSource := source
|
|
var readCounters, writeCounters []N.CountFunc
|
|
for {
|
|
source, readCounters = N.UnwrapCountReader(source, readCounters)
|
|
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
|
|
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
|
cachedBuffer := cachedSrc.ReadCached()
|
|
if cachedBuffer != nil {
|
|
if !cachedBuffer.IsEmpty() {
|
|
dataLen := cachedBuffer.Len()
|
|
for _, counter := range readCounters {
|
|
counter(int64(dataLen))
|
|
}
|
|
_, err = destination.Write(cachedBuffer.Bytes())
|
|
if err != nil {
|
|
cachedBuffer.Release()
|
|
return
|
|
}
|
|
for _, counter := range writeCounters {
|
|
counter(int64(dataLen))
|
|
}
|
|
}
|
|
cachedBuffer.Release()
|
|
continue
|
|
}
|
|
}
|
|
break
|
|
}
|
|
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
|
}
|
|
|
|
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
|
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
|
|
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
|
if srcIsSyscall && dstIsSyscall {
|
|
var handled bool
|
|
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
|
if handled {
|
|
return
|
|
}
|
|
}
|
|
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
|
}
|
|
|
|
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
|
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
|
rearHeadroom := N.CalculateRearHeadroom(destination)
|
|
readWaiter, isReadWaiter := CreateReadWaiter(source)
|
|
if isReadWaiter {
|
|
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
|
FrontHeadroom: frontHeadroom,
|
|
RearHeadroom: rearHeadroom,
|
|
MTU: N.CalculateMTU(source, destination),
|
|
})
|
|
if !needCopy || common.LowMemory {
|
|
var handled bool
|
|
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
|
|
if handled {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
|
}
|
|
|
|
// Deprecated: not used
|
|
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
|
buffer.IncRef()
|
|
defer buffer.DecRef()
|
|
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
|
rearHeadroom := N.CalculateRearHeadroom(destination)
|
|
buffer.Resize(frontHeadroom, 0)
|
|
buffer.Reserve(rearHeadroom)
|
|
var notFirstTime bool
|
|
for {
|
|
err = source.ReadBuffer(buffer)
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
err = nil
|
|
return
|
|
}
|
|
return
|
|
}
|
|
dataLen := buffer.Len()
|
|
buffer.OverCap(rearHeadroom)
|
|
err = destination.WriteBuffer(buffer)
|
|
if err != nil {
|
|
if !notFirstTime {
|
|
err = N.ReportHandshakeFailure(originSource, err)
|
|
}
|
|
return
|
|
}
|
|
n += int64(dataLen)
|
|
for _, counter := range readCounters {
|
|
counter(int64(dataLen))
|
|
}
|
|
for _, counter := range writeCounters {
|
|
counter(int64(dataLen))
|
|
}
|
|
notFirstTime = true
|
|
}
|
|
}
|
|
|
|
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
|
options := N.NewReadWaitOptions(source, destination)
|
|
var notFirstTime bool
|
|
for {
|
|
buffer := options.NewBuffer()
|
|
err = source.ReadBuffer(buffer)
|
|
if err != nil {
|
|
buffer.Release()
|
|
if errors.Is(err, io.EOF) {
|
|
err = nil
|
|
return
|
|
}
|
|
return
|
|
}
|
|
dataLen := buffer.Len()
|
|
for _, counter := range readCounters {
|
|
counter(int64(dataLen))
|
|
}
|
|
options.PostReturn(buffer)
|
|
err = destination.WriteBuffer(buffer)
|
|
if err != nil {
|
|
buffer.Leak()
|
|
if !notFirstTime {
|
|
err = N.ReportHandshakeFailure(originSource, err)
|
|
}
|
|
return
|
|
}
|
|
n += int64(dataLen)
|
|
for _, counter := range writeCounters {
|
|
counter(int64(dataLen))
|
|
}
|
|
notFirstTime = true
|
|
}
|
|
}
|
|
|
|
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
|
|
var group task.Group
|
|
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
|
|
group.Append("upload", func(ctx context.Context) error {
|
|
err := common.Error(Copy(destination, source))
|
|
if err == nil {
|
|
N.CloseWrite(destination)
|
|
} else {
|
|
common.Close(destination)
|
|
}
|
|
return err
|
|
})
|
|
} else {
|
|
group.Append("upload", func(ctx context.Context) error {
|
|
defer common.Close(destination)
|
|
return common.Error(Copy(destination, source))
|
|
})
|
|
}
|
|
if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
|
|
group.Append("download", func(ctx context.Context) error {
|
|
err := common.Error(Copy(source, destination))
|
|
if err == nil {
|
|
N.CloseWrite(source)
|
|
} else {
|
|
common.Close(source)
|
|
}
|
|
return err
|
|
})
|
|
} else {
|
|
group.Append("download", func(ctx context.Context) error {
|
|
defer common.Close(source)
|
|
return common.Error(Copy(source, destination))
|
|
})
|
|
}
|
|
group.Cleanup(func() {
|
|
common.Close(source, destination)
|
|
})
|
|
return group.Run(ctx)
|
|
}
|
|
|
|
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
|
var readCounters, writeCounters []N.CountFunc
|
|
var cachedPackets []*N.PacketBuffer
|
|
originSource := source
|
|
for {
|
|
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
|
|
destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
|
|
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
|
|
packet := cachedReader.ReadCachedPacket()
|
|
if packet != nil {
|
|
cachedPackets = append(cachedPackets, packet)
|
|
continue
|
|
}
|
|
}
|
|
break
|
|
}
|
|
if cachedPackets != nil {
|
|
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
|
|
n += copeN
|
|
return
|
|
}
|
|
|
|
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
|
var (
|
|
handled bool
|
|
copeN int64
|
|
)
|
|
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
|
if isReadWaiter {
|
|
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
|
|
if !needCopy || common.LowMemory {
|
|
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
|
if handled {
|
|
n += copeN
|
|
return
|
|
}
|
|
}
|
|
}
|
|
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
|
|
n += copeN
|
|
return
|
|
}
|
|
|
|
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
|
options := N.NewReadWaitOptions(source, destination)
|
|
var destinationAddress M.Socksaddr
|
|
for {
|
|
buffer := options.NewPacketBuffer()
|
|
destinationAddress, err = source.ReadPacket(buffer)
|
|
if err != nil {
|
|
buffer.Release()
|
|
return
|
|
}
|
|
dataLen := buffer.Len()
|
|
for _, counter := range readCounters {
|
|
counter(int64(dataLen))
|
|
}
|
|
options.PostReturn(buffer)
|
|
err = destination.WritePacket(buffer, destinationAddress)
|
|
if err != nil {
|
|
buffer.Leak()
|
|
if !notFirstTime {
|
|
err = N.ReportHandshakeFailure(originSource, err)
|
|
}
|
|
return
|
|
}
|
|
for _, counter := range writeCounters {
|
|
counter(int64(dataLen))
|
|
}
|
|
n += int64(dataLen)
|
|
notFirstTime = true
|
|
}
|
|
}
|
|
|
|
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
|
options := N.NewReadWaitOptions(nil, destination)
|
|
var notFirstTime bool
|
|
for _, packetBuffer := range packetBuffers {
|
|
for _, counter := range readCounters {
|
|
counter(int64(packetBuffer.Buffer.Len()))
|
|
}
|
|
buffer := options.Copy(packetBuffer.Buffer)
|
|
dataLen := buffer.Len()
|
|
err = destination.WritePacket(buffer, packetBuffer.Destination)
|
|
N.PutPacketBuffer(packetBuffer)
|
|
if err != nil {
|
|
buffer.Leak()
|
|
if !notFirstTime {
|
|
err = N.ReportHandshakeFailure(originSource, err)
|
|
}
|
|
return
|
|
}
|
|
for _, counter := range writeCounters {
|
|
counter(int64(dataLen))
|
|
}
|
|
n += int64(dataLen)
|
|
notFirstTime = true
|
|
}
|
|
return
|
|
}
|
|
|
|
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
|
|
var group task.Group
|
|
group.Append("upload", func(ctx context.Context) error {
|
|
return common.Error(CopyPacket(destination, source))
|
|
})
|
|
group.Append("download", func(ctx context.Context) error {
|
|
return common.Error(CopyPacket(source, destination))
|
|
})
|
|
group.Cleanup(func() {
|
|
common.Close(source, destination)
|
|
})
|
|
group.FastFail()
|
|
return group.Run(ctx)
|
|
}
|