mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Refactor bufio (break change)
This commit is contained in:
parent
f4d911a3b1
commit
169983a8d7
10 changed files with 539 additions and 309 deletions
|
@ -29,6 +29,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error {
|
|||
if !c.cache.IsEmpty() {
|
||||
return common.Error(buffer.ReadFrom(c.cache))
|
||||
}
|
||||
c.cache.FullReset()
|
||||
err := c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -40,6 +41,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) {
|
|||
if !c.cache.IsEmpty() {
|
||||
return c.cache.Read(p)
|
||||
}
|
||||
c.cache.FullReset()
|
||||
err = c.upstream.ReadBuffer(c.cache)
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -52,6 +54,10 @@ func (c *ChunkReader) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *ChunkReader) MTU() int {
|
||||
return c.maxChunkSize
|
||||
}
|
||||
|
||||
type ChunkWriter struct {
|
||||
upstream N.ExtendedWriter
|
||||
maxChunkSize int
|
||||
|
@ -96,3 +102,7 @@ func (w *ChunkWriter) WriteBuffer(buffer *buf.Buffer) error {
|
|||
func (w *ChunkWriter) Upstream() any {
|
||||
return w.upstream
|
||||
}
|
||||
|
||||
func (w *ChunkWriter) MTU() int {
|
||||
return w.maxChunkSize
|
||||
}
|
||||
|
|
|
@ -1,276 +1,15 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
type readOnlyReader struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return Copy(w, r.Reader)
|
||||
}
|
||||
|
||||
func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool {
|
||||
_, isTCPConn := dst.(*net.TCPConn)
|
||||
if !isTCPConn {
|
||||
return false
|
||||
}
|
||||
switch src.(type) {
|
||||
case *net.TCPConn, *net.UnixConn, *os.File:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
|
||||
if src == nil {
|
||||
return 0, E.New("nil reader")
|
||||
} else if dst == nil {
|
||||
return 0, E.New("nil writer")
|
||||
}
|
||||
src = N.UnwrapReader(src)
|
||||
dst = N.UnwrapWriter(dst)
|
||||
if wt, ok := src.(io.WriterTo); ok {
|
||||
return wt.WriteTo(dst)
|
||||
}
|
||||
if rt, ok := dst.(io.ReaderFrom); ok {
|
||||
if needReadFromWrapper(rt, src) {
|
||||
src = &readOnlyReader{src}
|
||||
}
|
||||
return rt.ReadFrom(src)
|
||||
}
|
||||
return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src))
|
||||
}
|
||||
|
||||
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
||||
if _, isHandshakeConn := common.Cast[N.HandshakeConn](dst); isHandshakeConn {
|
||||
n, err = CopyExtendedOnce(dst, src)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
var copyN int64
|
||||
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafeReader](src)
|
||||
_, dstUnsafe := common.Cast[N.ThreadUnsafeWriter](dst)
|
||||
if srcUnsafe {
|
||||
copyN, err = CopyExtendedWithSrcBuffer(dst, unsafeSrc)
|
||||
} else if dstUnsafe {
|
||||
copyN, err = CopyExtendedWithPool(dst, src)
|
||||
} else {
|
||||
_buffer := buf.StackNew()
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
copyN, err = CopyExtendedBuffer(dst, src, buffer)
|
||||
}
|
||||
n += copyN
|
||||
return
|
||||
}
|
||||
|
||||
func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-1024])
|
||||
|
||||
for {
|
||||
buffer.Reset()
|
||||
readBuffer.Reset()
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
|
||||
for {
|
||||
var buffer *buf.Buffer
|
||||
buffer, err = src.ReadBufferThreadSafe()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
||||
for {
|
||||
buffer := buf.New()
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-1024])
|
||||
readBuffer.Reset()
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
}
|
||||
}
|
||||
|
||||
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
||||
var group task.Group
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(dest)
|
||||
return common.Error(Copy(dest, conn))
|
||||
})
|
||||
group.Append("download", func(ctx context.Context) error {
|
||||
defer rw.CloseRead(dest)
|
||||
defer rw.CloseWrite(conn)
|
||||
return common.Error(Copy(conn, dest))
|
||||
})
|
||||
group.Cleanup(func() {
|
||||
common.Close(conn, dest)
|
||||
})
|
||||
return group.Run(ctx)
|
||||
}
|
||||
|
||||
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src)
|
||||
_, dstUnsafe := common.Cast[N.ThreadUnsafeWriter](dst)
|
||||
if srcUnsafe {
|
||||
dstHeadroom := N.CalculateHeadroom(dst)
|
||||
if dstHeadroom == 0 {
|
||||
return CopyPacketWithSrcBuffer(dst, unsafeSrc)
|
||||
}
|
||||
}
|
||||
if dstUnsafe {
|
||||
return CopyPacketWithPool(dst, src)
|
||||
}
|
||||
|
||||
_buffer := buf.StackNewPacket()
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer.Reset()
|
||||
destination, err = src.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if buffer.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
|
||||
var buffer *buf.Buffer
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer, destination, err = src.ReadPacketThreadSafe()
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewPacket()
|
||||
destination, err = src.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if buffer.IsFull() {
|
||||
buffer.Release()
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
||||
var group task.Group
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
return common.Error(CopyPacket(dest, conn))
|
||||
})
|
||||
group.Append("download", func(ctx context.Context) error {
|
||||
return common.Error(CopyPacket(conn, dest))
|
||||
})
|
||||
group.Cleanup(func() {
|
||||
common.Close(conn, dest)
|
||||
})
|
||||
group.FastFail()
|
||||
return group.Run(ctx)
|
||||
}
|
||||
|
||||
func NewPacketConn(conn net.PacketConn) N.NetPacketConn {
|
||||
if packetConn, ok := conn.(N.NetPacketConn); ok {
|
||||
return packetConn
|
||||
|
@ -468,6 +207,34 @@ func (w *ExtendedConnWrapper) WriteBuffer(buffer *buf.Buffer) error {
|
|||
return w.writer.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (w *ExtendedConnWrapper) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
return Copy(w.writer, r)
|
||||
}
|
||||
|
||||
func (r *ExtendedConnWrapper) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return Copy(w, r.reader)
|
||||
}
|
||||
|
||||
func (w *ExtendedConnWrapper) UpstreamReader() io.Reader {
|
||||
return w.reader
|
||||
}
|
||||
|
||||
func (w *ExtendedConnWrapper) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *ExtendedConnWrapper) UpstreamWriter() io.Writer {
|
||||
return w.writer
|
||||
}
|
||||
|
||||
func (w *ExtendedConnWrapper) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *ExtendedConnWrapper) Upstream() any {
|
||||
return w.Conn
|
||||
}
|
||||
|
||||
func NewExtendedConn(conn net.Conn) N.ExtendedConn {
|
||||
if c, ok := conn.(N.ExtendedConn); ok {
|
||||
return c
|
||||
|
|
307
common/bufio/copy.go
Normal file
307
common/bufio/copy.go
Normal file
|
@ -0,0 +1,307 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
type readOnlyReader struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return Copy(w, r.Reader)
|
||||
}
|
||||
|
||||
func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool {
|
||||
_, isTCPConn := dst.(*net.TCPConn)
|
||||
if !isTCPConn {
|
||||
return false
|
||||
}
|
||||
switch src.(type) {
|
||||
case *net.TCPConn, *net.UnixConn, *os.File:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
|
||||
if src == nil {
|
||||
return 0, E.New("nil reader")
|
||||
} else if dst == nil {
|
||||
return 0, E.New("nil writer")
|
||||
}
|
||||
src = N.UnwrapReader(src)
|
||||
dst = N.UnwrapWriter(dst)
|
||||
if wt, ok := src.(io.WriterTo); ok {
|
||||
return wt.WriteTo(dst)
|
||||
}
|
||||
if rt, ok := dst.(io.ReaderFrom); ok {
|
||||
if needReadFromWrapper(rt, src) {
|
||||
src = &readOnlyReader{src}
|
||||
}
|
||||
return rt.ReadFrom(src)
|
||||
}
|
||||
return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src))
|
||||
}
|
||||
|
||||
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
||||
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafeReader](src)
|
||||
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
|
||||
if srcUnsafe {
|
||||
if headroom == 0 {
|
||||
return CopyExtendedWithSrcBuffer(dst, unsafeSrc)
|
||||
}
|
||||
}
|
||||
if N.IsUnsafeWriter(dst) {
|
||||
return CopyExtendedWithPool(dst, src)
|
||||
}
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += headroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
_buffer := buf.StackNewSize(bufferSize)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
return CopyExtendedBuffer(dst, src, buffer)
|
||||
}
|
||||
|
||||
func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
var notFirstTime bool
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
|
||||
var notFirstTime bool
|
||||
for {
|
||||
var buffer *buf.Buffer
|
||||
buffer, err = src.ReadBufferThreadSafe()
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
||||
var group task.Group
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(dest)
|
||||
return common.Error(Copy(dest, conn))
|
||||
})
|
||||
group.Append("download", func(ctx context.Context) error {
|
||||
defer rw.CloseRead(dest)
|
||||
defer rw.CloseWrite(conn)
|
||||
return common.Error(Copy(conn, dest))
|
||||
})
|
||||
group.Cleanup(func() {
|
||||
common.Close(conn, dest)
|
||||
})
|
||||
return group.Run(ctx)
|
||||
}
|
||||
|
||||
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src)
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
headroom := frontHeadroom + rearHeadroom
|
||||
if srcUnsafe {
|
||||
if headroom == 0 {
|
||||
return CopyPacketWithSrcBuffer(dst, unsafeSrc)
|
||||
}
|
||||
}
|
||||
if N.IsUnsafeWriter(dst) {
|
||||
return CopyPacketWithPool(dst, src)
|
||||
}
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += headroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
_buffer := buf.StackNewSize(bufferSize)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = src.ReadPacket(readBuffer)
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
|
||||
var buffer *buf.Buffer
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer, destination, err = src.ReadPacketThreadSafe()
|
||||
if err != nil {
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = src.ReadPacket(readBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
||||
var group task.Group
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
return common.Error(CopyPacket(dest, conn))
|
||||
})
|
||||
group.Append("download", func(ctx context.Context) error {
|
||||
return common.Error(CopyPacket(conn, dest))
|
||||
})
|
||||
group.Cleanup(func() {
|
||||
common.Close(conn, dest)
|
||||
})
|
||||
group.FastFail()
|
||||
return group.Run(ctx)
|
||||
}
|
|
@ -8,40 +8,55 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func CopyOnce(dst io.Writer, src io.Reader) (n int64, err error) {
|
||||
extendedSrc, srcExtended := src.(N.ExtendedReader)
|
||||
extendedDst, dstExtended := dst.(N.ExtendedWriter)
|
||||
if !srcExtended {
|
||||
extendedSrc = &ExtendedReaderWrapper{src}
|
||||
}
|
||||
if !dstExtended {
|
||||
extendedDst = &ExtendedWriterWrapper{dst}
|
||||
}
|
||||
return CopyExtendedOnce(extendedDst, extendedSrc)
|
||||
func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) {
|
||||
return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times)
|
||||
}
|
||||
|
||||
func CopyExtendedOnce(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
||||
var buffer *buf.Buffer
|
||||
if N.IsUnsafeWriter(dst) {
|
||||
buffer = buf.New()
|
||||
func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
_buffer := buf.StackNew()
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
dstUnsafe := N.IsUnsafeWriter(dst)
|
||||
var buffer *buf.Buffer
|
||||
if !dstUnsafe {
|
||||
_buffer := buf.StackNewSize(bufferSize)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer = common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
buffer.IncRef()
|
||||
defer buffer.DecRef()
|
||||
}
|
||||
err = src.ReadBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
return
|
||||
notFirstTime := true
|
||||
for i := 0; i < times; i++ {
|
||||
if dstUnsafe {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
}
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
dataLen := buffer.Len()
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -51,7 +66,21 @@ type ReadFromWriter interface {
|
|||
}
|
||||
|
||||
func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) {
|
||||
n, err = CopyOnce(readerFrom, reader)
|
||||
n, err = CopyTimes(readerFrom, reader, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var rn int64
|
||||
rn, err = readerFrom.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += rn
|
||||
return
|
||||
}
|
||||
|
||||
func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) {
|
||||
n, err = CopyTimes(readerFrom, reader, times)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -70,7 +99,21 @@ type WriteToReader interface {
|
|||
}
|
||||
|
||||
func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) {
|
||||
n, err = CopyOnce(writer, writerTo)
|
||||
n, err = CopyTimes(writer, writerTo, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var wn int64
|
||||
wn, err = writerTo.WriteTo(writer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += wn
|
||||
return
|
||||
}
|
||||
|
||||
func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) {
|
||||
n, err = CopyTimes(writer, writerTo, times)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -10,7 +10,10 @@ import (
|
|||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
const MaxSocksaddrLength = 2 + 255 + 2
|
||||
const (
|
||||
MaxSocksaddrLength = 2 + 255 + 2
|
||||
MaxIPSocksaddrLength = 1 + 16 + 2
|
||||
)
|
||||
|
||||
type SerializerOption func(*Serializer)
|
||||
|
||||
|
|
|
@ -77,26 +77,44 @@ type CachedReader interface {
|
|||
ReadCached() *buf.Buffer
|
||||
}
|
||||
|
||||
type WithUpstreamReader interface {
|
||||
UpstreamReader() io.Reader
|
||||
}
|
||||
|
||||
type WithUpstreamWriter interface {
|
||||
UpstreamWriter() io.Writer
|
||||
}
|
||||
|
||||
type ReaderWithUpstream interface {
|
||||
common.WithUpstream
|
||||
ReaderReplaceable() bool
|
||||
}
|
||||
|
||||
type WriterWithUpstream interface {
|
||||
common.WithUpstream
|
||||
WriterReplaceable() bool
|
||||
}
|
||||
|
||||
func UnwrapReader(reader io.Reader) io.Reader {
|
||||
if u, ok := reader.(ReaderWithUpstream); ok && u.ReaderReplaceable() {
|
||||
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
|
||||
return reader
|
||||
}
|
||||
if u, ok := reader.(WithUpstreamReader); ok {
|
||||
return UnwrapReader(u.UpstreamReader())
|
||||
}
|
||||
if u, ok := reader.(common.WithUpstream); ok {
|
||||
return UnwrapReader(u.Upstream().(io.Reader))
|
||||
}
|
||||
return reader
|
||||
panic("bad reader")
|
||||
}
|
||||
|
||||
func UnwrapWriter(writer io.Writer) io.Writer {
|
||||
if u, ok := writer.(WriterWithUpstream); ok && u.WriterReplaceable() {
|
||||
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
|
||||
return writer
|
||||
}
|
||||
if u, ok := writer.(WithUpstreamWriter); ok {
|
||||
return UnwrapWriter(u.UpstreamWriter())
|
||||
}
|
||||
if u, ok := writer.(common.WithUpstream); ok {
|
||||
return UnwrapWriter(u.Upstream().(io.Writer))
|
||||
}
|
||||
return writer
|
||||
panic("bad writer")
|
||||
}
|
||||
|
|
|
@ -18,22 +18,100 @@ type ThreadSafePacketReader interface {
|
|||
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
|
||||
}
|
||||
|
||||
type HeadroomWriter interface {
|
||||
Headroom() int
|
||||
}
|
||||
|
||||
func IsUnsafeWriter(writer any) bool {
|
||||
_, isUnsafe := common.Cast[ThreadUnsafeWriter](writer)
|
||||
return isUnsafe
|
||||
}
|
||||
|
||||
func CalculateHeadroom(writer any) int {
|
||||
type FrontHeadroom interface {
|
||||
FrontHeadroom() int
|
||||
}
|
||||
|
||||
type RearHeadroom interface {
|
||||
RearHeadroom() int
|
||||
}
|
||||
|
||||
func CalculateFrontHeadroom(writer any) int {
|
||||
var headroom int
|
||||
if headroomWriter, needHeadroom := writer.(HeadroomWriter); needHeadroom {
|
||||
headroom = headroomWriter.Headroom()
|
||||
if headroomWriter, needHeadroom := writer.(FrontHeadroom); needHeadroom {
|
||||
headroom = headroomWriter.FrontHeadroom()
|
||||
}
|
||||
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
|
||||
return headroom + CalculateHeadroom(upstream.Upstream())
|
||||
headroom += CalculateFrontHeadroom(upstream.Upstream())
|
||||
}
|
||||
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
|
||||
headroom += CalculateFrontHeadroom(upstream.UpstreamWriter())
|
||||
}
|
||||
return headroom
|
||||
}
|
||||
|
||||
func CalculateRearHeadroom(writer any) int {
|
||||
var headroom int
|
||||
if headroomWriter, needHeadroom := writer.(RearHeadroom); needHeadroom {
|
||||
headroom = headroomWriter.RearHeadroom()
|
||||
}
|
||||
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
|
||||
headroom += CalculateRearHeadroom(upstream.Upstream())
|
||||
}
|
||||
|
||||
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
|
||||
headroom += CalculateRearHeadroom(upstream.UpstreamWriter())
|
||||
}
|
||||
return headroom
|
||||
}
|
||||
|
||||
type ReaderWithMTU interface {
|
||||
ReaderMTU() int
|
||||
}
|
||||
|
||||
type WriterWithMTU interface {
|
||||
WriterMTU() int
|
||||
}
|
||||
|
||||
func CalculateMTU(reader any, writer any) int {
|
||||
mtu := calculateReaderMTU(reader)
|
||||
if mtu == 0 {
|
||||
return mtu
|
||||
}
|
||||
return calculateWriterMTU(writer)
|
||||
}
|
||||
|
||||
func calculateReaderMTU(reader any) int {
|
||||
var mtu int
|
||||
if withMTU, haveMTU := reader.(ReaderWithMTU); haveMTU {
|
||||
mtu = withMTU.ReaderMTU()
|
||||
}
|
||||
if upstream, hasUpstream := reader.(common.WithUpstream); hasUpstream {
|
||||
upstreamMTU := calculateReaderMTU(upstream.Upstream())
|
||||
if upstreamMTU > mtu {
|
||||
mtu = upstreamMTU
|
||||
}
|
||||
}
|
||||
if upstream, hasUpstream := reader.(WithUpstreamReader); hasUpstream {
|
||||
upstreamMTU := calculateReaderMTU(upstream.UpstreamReader())
|
||||
if upstreamMTU > mtu {
|
||||
mtu = upstreamMTU
|
||||
}
|
||||
}
|
||||
return mtu
|
||||
}
|
||||
|
||||
func calculateWriterMTU(writer any) int {
|
||||
var mtu int
|
||||
if withMTU, haveMTU := writer.(WriterWithMTU); haveMTU {
|
||||
mtu = withMTU.WriterMTU()
|
||||
}
|
||||
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
|
||||
upstreamMTU := calculateWriterMTU(upstream.Upstream())
|
||||
if mtu == 0 && upstreamMTU < mtu {
|
||||
mtu = upstreamMTU
|
||||
}
|
||||
}
|
||||
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
|
||||
upstreamMTU := calculateWriterMTU(upstream.UpstreamWriter())
|
||||
if mtu == 0 && upstreamMTU < mtu {
|
||||
mtu = upstreamMTU
|
||||
}
|
||||
}
|
||||
return mtu
|
||||
}
|
||||
|
|
|
@ -110,7 +110,7 @@ func (c *AssociatePacketConn) Upstream() any {
|
|||
return c.PacketConn
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) Headroom() int {
|
||||
func (c *AssociatePacketConn) FrontHeadroom() int {
|
||||
return 3 + M.MaxSocksaddrLength
|
||||
}
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
return bufio.Copy(w, c.Conn)
|
||||
}
|
||||
|
||||
func (c *ClientConn) Headroom() int {
|
||||
func (c *ClientConn) FrontHeadroom() int {
|
||||
if !c.headerWritten {
|
||||
return KeyLength + 5 + M.MaxSocksaddrLength
|
||||
}
|
||||
|
@ -132,11 +132,11 @@ func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Headroom() int {
|
||||
func (c *ClientPacketConn) FrontHeadroom() int {
|
||||
if !c.headerWritten {
|
||||
return KeyLength + 2*M.MaxSocksaddrLength + 9
|
||||
}
|
||||
return 0
|
||||
return M.MaxSocksaddrLength + 4
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Upstream() any {
|
||||
|
|
|
@ -132,6 +132,10 @@ func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) er
|
|||
return WritePacket(c.Conn, buffer, destination)
|
||||
}
|
||||
|
||||
func (c *PacketConn) FrontHeadroom() int {
|
||||
return M.MaxSocksaddrLength + 4
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Metadata M.Metadata
|
||||
Conn net.Conn
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue