diff --git a/common/bufio/chunk.go b/common/bufio/chunk.go index 9d31b59..d3058e3 100644 --- a/common/bufio/chunk.go +++ b/common/bufio/chunk.go @@ -29,6 +29,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error { if !c.cache.IsEmpty() { return common.Error(buffer.ReadFrom(c.cache)) } + c.cache.FullReset() err := c.upstream.ReadBuffer(c.cache) if err != nil { return err @@ -40,6 +41,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) { if !c.cache.IsEmpty() { return c.cache.Read(p) } + c.cache.FullReset() err = c.upstream.ReadBuffer(c.cache) if err != nil { return @@ -52,6 +54,10 @@ func (c *ChunkReader) Close() error { return nil } +func (c *ChunkReader) MTU() int { + return c.maxChunkSize +} + type ChunkWriter struct { upstream N.ExtendedWriter maxChunkSize int @@ -96,3 +102,7 @@ func (w *ChunkWriter) WriteBuffer(buffer *buf.Buffer) error { func (w *ChunkWriter) Upstream() any { return w.upstream } + +func (w *ChunkWriter) MTU() int { + return w.maxChunkSize +} diff --git a/common/bufio/conn.go b/common/bufio/conn.go index 39b22a8..b879478 100644 --- a/common/bufio/conn.go +++ b/common/bufio/conn.go @@ -1,276 +1,15 @@ package bufio import ( - "context" "io" "net" - "os" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" - "github.com/sagernet/sing/common/task" ) -type readOnlyReader struct { - io.Reader -} - -func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) { - return Copy(w, r.Reader) -} - -func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool { - _, isTCPConn := dst.(*net.TCPConn) - if !isTCPConn { - return false - } - switch src.(type) { - case *net.TCPConn, *net.UnixConn, *os.File: - return false - default: - return true - } -} - -func Copy(dst io.Writer, src io.Reader) (n int64, err error) { - if src == nil { - return 0, E.New("nil reader") - } else if dst == nil { - return 0, E.New("nil writer") - } - src = N.UnwrapReader(src) - dst = N.UnwrapWriter(dst) - if wt, ok := src.(io.WriterTo); ok { - return wt.WriteTo(dst) - } - if rt, ok := dst.(io.ReaderFrom); ok { - if needReadFromWrapper(rt, src) { - src = &readOnlyReader{src} - } - return rt.ReadFrom(src) - } - return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src)) -} - -func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { - if _, isHandshakeConn := common.Cast[N.HandshakeConn](dst); isHandshakeConn { - n, err = CopyExtendedOnce(dst, src) - if err != nil { - return - } - } - var copyN int64 - unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafeReader](src) - _, dstUnsafe := common.Cast[N.ThreadUnsafeWriter](dst) - if srcUnsafe { - copyN, err = CopyExtendedWithSrcBuffer(dst, unsafeSrc) - } else if dstUnsafe { - copyN, err = CopyExtendedWithPool(dst, src) - } else { - _buffer := buf.StackNew() - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - copyN, err = CopyExtendedBuffer(dst, src, buffer) - } - n += copyN - return -} - -func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) { - buffer.IncRef() - defer buffer.DecRef() - - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-1024]) - - for { - buffer.Reset() - readBuffer.Reset() - err = src.ReadBuffer(readBuffer) - if err != nil { - return - } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = dst.WriteBuffer(buffer) - if err != nil { - return - } - n += int64(dataLen) - } -} - -func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) { - for { - var buffer *buf.Buffer - buffer, err = src.ReadBufferThreadSafe() - if err != nil { - return - } - dataLen := buffer.Len() - err = dst.WriteBuffer(buffer) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - } -} - -func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { - for { - buffer := buf.New() - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-1024]) - readBuffer.Reset() - err = src.ReadBuffer(readBuffer) - if err != nil { - buffer.Release() - return - } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = dst.WriteBuffer(buffer) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - } -} - -func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { - var group task.Group - group.Append("upload", func(ctx context.Context) error { - defer rw.CloseRead(conn) - defer rw.CloseWrite(dest) - return common.Error(Copy(dest, conn)) - }) - group.Append("download", func(ctx context.Context) error { - defer rw.CloseRead(dest) - defer rw.CloseWrite(conn) - return common.Error(Copy(conn, dest)) - }) - group.Cleanup(func() { - common.Close(conn, dest) - }) - return group.Run(ctx) -} - -func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { - unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src) - _, dstUnsafe := common.Cast[N.ThreadUnsafeWriter](dst) - if srcUnsafe { - dstHeadroom := N.CalculateHeadroom(dst) - if dstHeadroom == 0 { - return CopyPacketWithSrcBuffer(dst, unsafeSrc) - } - } - if dstUnsafe { - return CopyPacketWithPool(dst, src) - } - - _buffer := buf.StackNewPacket() - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - buffer.IncRef() - defer buffer.DecRef() - var destination M.Socksaddr - var notFirstTime bool - for { - buffer.Reset() - destination, err = src.ReadPacket(buffer) - if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(dst, err) - } - return - } - if buffer.IsFull() { - return 0, io.ErrShortBuffer - } - dataLen := buffer.Len() - err = dst.WritePacket(buffer, destination) - if err != nil { - return - } - n += int64(dataLen) - notFirstTime = true - } -} - -func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) { - var buffer *buf.Buffer - var destination M.Socksaddr - var notFirstTime bool - for { - buffer, destination, err = src.ReadPacketThreadSafe() - if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(dst, err) - } - return - } - dataLen := buffer.Len() - err = dst.WritePacket(buffer, destination) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - notFirstTime = true - } -} - -func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { - var destination M.Socksaddr - var notFirstTime bool - for { - buffer := buf.NewPacket() - destination, err = src.ReadPacket(buffer) - if err != nil { - buffer.Release() - if !notFirstTime { - err = N.HandshakeFailure(dst, err) - } - return - } - if buffer.IsFull() { - buffer.Release() - return 0, io.ErrShortBuffer - } - dataLen := buffer.Len() - err = dst.WritePacket(buffer, destination) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - notFirstTime = true - } -} - -func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error { - var group task.Group - group.Append("upload", func(ctx context.Context) error { - return common.Error(CopyPacket(dest, conn)) - }) - group.Append("download", func(ctx context.Context) error { - return common.Error(CopyPacket(conn, dest)) - }) - group.Cleanup(func() { - common.Close(conn, dest) - }) - group.FastFail() - return group.Run(ctx) -} - func NewPacketConn(conn net.PacketConn) N.NetPacketConn { if packetConn, ok := conn.(N.NetPacketConn); ok { return packetConn @@ -468,6 +207,34 @@ func (w *ExtendedConnWrapper) WriteBuffer(buffer *buf.Buffer) error { return w.writer.WriteBuffer(buffer) } +func (w *ExtendedConnWrapper) ReadFrom(r io.Reader) (n int64, err error) { + return Copy(w.writer, r) +} + +func (r *ExtendedConnWrapper) WriteTo(w io.Writer) (n int64, err error) { + return Copy(w, r.reader) +} + +func (w *ExtendedConnWrapper) UpstreamReader() io.Reader { + return w.reader +} + +func (w *ExtendedConnWrapper) ReaderReplaceable() bool { + return true +} + +func (w *ExtendedConnWrapper) UpstreamWriter() io.Writer { + return w.writer +} + +func (w *ExtendedConnWrapper) WriterReplaceable() bool { + return true +} + +func (w *ExtendedConnWrapper) Upstream() any { + return w.Conn +} + func NewExtendedConn(conn net.Conn) N.ExtendedConn { if c, ok := conn.(N.ExtendedConn); ok { return c diff --git a/common/bufio/copy.go b/common/bufio/copy.go new file mode 100644 index 0000000..aa6ca1f --- /dev/null +++ b/common/bufio/copy.go @@ -0,0 +1,307 @@ +package bufio + +import ( + "context" + "io" + "net" + "os" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/task" +) + +type readOnlyReader struct { + io.Reader +} + +func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) { + return Copy(w, r.Reader) +} + +func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool { + _, isTCPConn := dst.(*net.TCPConn) + if !isTCPConn { + return false + } + switch src.(type) { + case *net.TCPConn, *net.UnixConn, *os.File: + return false + default: + return true + } +} + +func Copy(dst io.Writer, src io.Reader) (n int64, err error) { + if src == nil { + return 0, E.New("nil reader") + } else if dst == nil { + return 0, E.New("nil writer") + } + src = N.UnwrapReader(src) + dst = N.UnwrapWriter(dst) + if wt, ok := src.(io.WriterTo); ok { + return wt.WriteTo(dst) + } + if rt, ok := dst.(io.ReaderFrom); ok { + if needReadFromWrapper(rt, src) { + src = &readOnlyReader{src} + } + return rt.ReadFrom(src) + } + return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src)) +} + +func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { + unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafeReader](src) + headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst) + if srcUnsafe { + if headroom == 0 { + return CopyExtendedWithSrcBuffer(dst, unsafeSrc) + } + } + if N.IsUnsafeWriter(dst) { + return CopyExtendedWithPool(dst, src) + } + bufferSize := N.CalculateMTU(src, dst) + if bufferSize > 0 { + bufferSize += headroom + } else { + bufferSize = buf.BufferSize + } + _buffer := buf.StackNewSize(bufferSize) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + return CopyExtendedBuffer(dst, src, buffer) +} + +func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) { + buffer.IncRef() + defer buffer.DecRef() + frontHeadroom := N.CalculateFrontHeadroom(dst) + rearHeadroom := N.CalculateRearHeadroom(dst) + readBufferRaw := buffer.Slice() + readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) + var notFirstTime bool + for { + readBuffer.Resize(frontHeadroom, 0) + err = src.ReadBuffer(readBuffer) + if err != nil { + if !notFirstTime { + err = N.HandshakeFailure(dst, err) + } + return + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = dst.WriteBuffer(buffer) + if err != nil { + return + } + n += int64(dataLen) + notFirstTime = true + } +} + +func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) { + var notFirstTime bool + for { + var buffer *buf.Buffer + buffer, err = src.ReadBufferThreadSafe() + if err != nil { + if !notFirstTime { + err = N.HandshakeFailure(dst, err) + } + return + } + dataLen := buffer.Len() + err = dst.WriteBuffer(buffer) + if err != nil { + buffer.Release() + return + } + n += int64(dataLen) + notFirstTime = true + } +} + +func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { + frontHeadroom := N.CalculateFrontHeadroom(dst) + rearHeadroom := N.CalculateRearHeadroom(dst) + bufferSize := N.CalculateMTU(src, dst) + if bufferSize > 0 { + bufferSize += frontHeadroom + rearHeadroom + } else { + bufferSize = buf.BufferSize + } + var notFirstTime bool + for { + buffer := buf.NewSize(bufferSize) + readBufferRaw := buffer.Slice() + readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) + readBuffer.Resize(frontHeadroom, 0) + err = src.ReadBuffer(readBuffer) + if err != nil { + buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(dst, err) + } + return + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = dst.WriteBuffer(buffer) + if err != nil { + buffer.Release() + return + } + n += int64(dataLen) + notFirstTime = true + } +} + +func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { + var group task.Group + group.Append("upload", func(ctx context.Context) error { + defer rw.CloseRead(conn) + defer rw.CloseWrite(dest) + return common.Error(Copy(dest, conn)) + }) + group.Append("download", func(ctx context.Context) error { + defer rw.CloseRead(dest) + defer rw.CloseWrite(conn) + return common.Error(Copy(conn, dest)) + }) + group.Cleanup(func() { + common.Close(conn, dest) + }) + return group.Run(ctx) +} + +func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { + unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src) + frontHeadroom := N.CalculateFrontHeadroom(dst) + rearHeadroom := N.CalculateRearHeadroom(dst) + headroom := frontHeadroom + rearHeadroom + if srcUnsafe { + if headroom == 0 { + return CopyPacketWithSrcBuffer(dst, unsafeSrc) + } + } + if N.IsUnsafeWriter(dst) { + return CopyPacketWithPool(dst, src) + } + bufferSize := N.CalculateMTU(src, dst) + if bufferSize > 0 { + bufferSize += headroom + } else { + bufferSize = buf.UDPBufferSize + } + _buffer := buf.StackNewSize(bufferSize) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + buffer.IncRef() + defer buffer.DecRef() + var destination M.Socksaddr + var notFirstTime bool + readBufferRaw := buffer.Slice() + readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) + for { + readBuffer.Resize(frontHeadroom, 0) + destination, err = src.ReadPacket(readBuffer) + if err != nil { + if !notFirstTime { + err = N.HandshakeFailure(dst, err) + } + return + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = dst.WritePacket(buffer, destination) + if err != nil { + return + } + n += int64(dataLen) + notFirstTime = true + } +} + +func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) { + var buffer *buf.Buffer + var destination M.Socksaddr + var notFirstTime bool + for { + buffer, destination, err = src.ReadPacketThreadSafe() + if err != nil { + if !notFirstTime { + err = N.HandshakeFailure(dst, err) + } + return + } + dataLen := buffer.Len() + err = dst.WritePacket(buffer, destination) + if err != nil { + buffer.Release() + return + } + n += int64(dataLen) + notFirstTime = true + } +} + +func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { + frontHeadroom := N.CalculateFrontHeadroom(dst) + rearHeadroom := N.CalculateRearHeadroom(dst) + bufferSize := N.CalculateMTU(src, dst) + if bufferSize > 0 { + bufferSize += frontHeadroom + rearHeadroom + } else { + bufferSize = buf.UDPBufferSize + } + var destination M.Socksaddr + var notFirstTime bool + for { + buffer := buf.NewSize(bufferSize) + readBufferRaw := buffer.Slice() + readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) + readBuffer.Resize(frontHeadroom, 0) + destination, err = src.ReadPacket(readBuffer) + if err != nil { + buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(dst, err) + } + return + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = dst.WritePacket(buffer, destination) + if err != nil { + buffer.Release() + return + } + n += int64(dataLen) + notFirstTime = true + } +} + +func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error { + var group task.Group + group.Append("upload", func(ctx context.Context) error { + return common.Error(CopyPacket(dest, conn)) + }) + group.Append("download", func(ctx context.Context) error { + return common.Error(CopyPacket(conn, dest)) + }) + group.Cleanup(func() { + common.Close(conn, dest) + }) + group.FastFail() + return group.Run(ctx) +} diff --git a/common/bufio/once.go b/common/bufio/once.go index 7c37e31..5bfd0aa 100644 --- a/common/bufio/once.go +++ b/common/bufio/once.go @@ -8,40 +8,55 @@ import ( N "github.com/sagernet/sing/common/network" ) -func CopyOnce(dst io.Writer, src io.Reader) (n int64, err error) { - extendedSrc, srcExtended := src.(N.ExtendedReader) - extendedDst, dstExtended := dst.(N.ExtendedWriter) - if !srcExtended { - extendedSrc = &ExtendedReaderWrapper{src} - } - if !dstExtended { - extendedDst = &ExtendedWriterWrapper{dst} - } - return CopyExtendedOnce(extendedDst, extendedSrc) +func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) { + return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times) } -func CopyExtendedOnce(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { - var buffer *buf.Buffer - if N.IsUnsafeWriter(dst) { - buffer = buf.New() +func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) { + frontHeadroom := N.CalculateFrontHeadroom(dst) + rearHeadroom := N.CalculateRearHeadroom(dst) + bufferSize := N.CalculateMTU(src, dst) + if bufferSize > 0 { + bufferSize += frontHeadroom + rearHeadroom } else { - _buffer := buf.StackNew() + bufferSize = buf.BufferSize + } + dstUnsafe := N.IsUnsafeWriter(dst) + var buffer *buf.Buffer + if !dstUnsafe { + _buffer := buf.StackNewSize(bufferSize) defer common.KeepAlive(_buffer) buffer = common.Dup(_buffer) + defer buffer.Release() + buffer.IncRef() + defer buffer.DecRef() } - err = src.ReadBuffer(buffer) - if err != nil { - buffer.Release() - err = N.HandshakeFailure(dst, err) - return + notFirstTime := true + for i := 0; i < times; i++ { + if dstUnsafe { + buffer = buf.NewSize(bufferSize) + } + readBufferRaw := buffer.Slice() + readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) + readBuffer.Resize(frontHeadroom, 0) + err = src.ReadBuffer(readBuffer) + if err != nil { + buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(dst, err) + } + return + } + dataLen := readBuffer.Len() + buffer.Resize(readBuffer.Start(), dataLen) + err = dst.WriteBuffer(buffer) + if err != nil { + buffer.Release() + return + } + n += int64(dataLen) + notFirstTime = true } - dataLen := buffer.Len() - err = dst.WriteBuffer(buffer) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) return } @@ -51,7 +66,21 @@ type ReadFromWriter interface { } func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) { - n, err = CopyOnce(readerFrom, reader) + n, err = CopyTimes(readerFrom, reader, 1) + if err != nil { + return + } + var rn int64 + rn, err = readerFrom.ReadFrom(reader) + if err != nil { + return + } + n += rn + return +} + +func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) { + n, err = CopyTimes(readerFrom, reader, times) if err != nil { return } @@ -70,7 +99,21 @@ type WriteToReader interface { } func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) { - n, err = CopyOnce(writer, writerTo) + n, err = CopyTimes(writer, writerTo, 1) + if err != nil { + return + } + var wn int64 + wn, err = writerTo.WriteTo(writer) + if err != nil { + return + } + n += wn + return +} + +func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) { + n, err = CopyTimes(writer, writerTo, times) if err != nil { return } diff --git a/common/metadata/serializer.go b/common/metadata/serializer.go index fbecd88..a221a9c 100644 --- a/common/metadata/serializer.go +++ b/common/metadata/serializer.go @@ -10,7 +10,10 @@ import ( "github.com/sagernet/sing/common/rw" ) -const MaxSocksaddrLength = 2 + 255 + 2 +const ( + MaxSocksaddrLength = 2 + 255 + 2 + MaxIPSocksaddrLength = 1 + 16 + 2 +) type SerializerOption func(*Serializer) diff --git a/common/network/conn.go b/common/network/conn.go index 271c9ae..89500b9 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -77,26 +77,44 @@ type CachedReader interface { ReadCached() *buf.Buffer } +type WithUpstreamReader interface { + UpstreamReader() io.Reader +} + +type WithUpstreamWriter interface { + UpstreamWriter() io.Writer +} + type ReaderWithUpstream interface { - common.WithUpstream ReaderReplaceable() bool } type WriterWithUpstream interface { - common.WithUpstream WriterReplaceable() bool } func UnwrapReader(reader io.Reader) io.Reader { - if u, ok := reader.(ReaderWithUpstream); ok && u.ReaderReplaceable() { + if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return reader + } + if u, ok := reader.(WithUpstreamReader); ok { + return UnwrapReader(u.UpstreamReader()) + } + if u, ok := reader.(common.WithUpstream); ok { return UnwrapReader(u.Upstream().(io.Reader)) } - return reader + panic("bad reader") } func UnwrapWriter(writer io.Writer) io.Writer { - if u, ok := writer.(WriterWithUpstream); ok && u.WriterReplaceable() { + if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { + return writer + } + if u, ok := writer.(WithUpstreamWriter); ok { + return UnwrapWriter(u.UpstreamWriter()) + } + if u, ok := writer.(common.WithUpstream); ok { return UnwrapWriter(u.Upstream().(io.Writer)) } - return writer + panic("bad writer") } diff --git a/common/network/thread.go b/common/network/thread.go index dca2073..f71d62e 100644 --- a/common/network/thread.go +++ b/common/network/thread.go @@ -18,22 +18,100 @@ type ThreadSafePacketReader interface { ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) } -type HeadroomWriter interface { - Headroom() int -} - func IsUnsafeWriter(writer any) bool { _, isUnsafe := common.Cast[ThreadUnsafeWriter](writer) return isUnsafe } -func CalculateHeadroom(writer any) int { +type FrontHeadroom interface { + FrontHeadroom() int +} + +type RearHeadroom interface { + RearHeadroom() int +} + +func CalculateFrontHeadroom(writer any) int { var headroom int - if headroomWriter, needHeadroom := writer.(HeadroomWriter); needHeadroom { - headroom = headroomWriter.Headroom() + if headroomWriter, needHeadroom := writer.(FrontHeadroom); needHeadroom { + headroom = headroomWriter.FrontHeadroom() } if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream { - return headroom + CalculateHeadroom(upstream.Upstream()) + headroom += CalculateFrontHeadroom(upstream.Upstream()) + } + if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream { + headroom += CalculateFrontHeadroom(upstream.UpstreamWriter()) } return headroom } + +func CalculateRearHeadroom(writer any) int { + var headroom int + if headroomWriter, needHeadroom := writer.(RearHeadroom); needHeadroom { + headroom = headroomWriter.RearHeadroom() + } + if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream { + headroom += CalculateRearHeadroom(upstream.Upstream()) + } + + if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream { + headroom += CalculateRearHeadroom(upstream.UpstreamWriter()) + } + return headroom +} + +type ReaderWithMTU interface { + ReaderMTU() int +} + +type WriterWithMTU interface { + WriterMTU() int +} + +func CalculateMTU(reader any, writer any) int { + mtu := calculateReaderMTU(reader) + if mtu == 0 { + return mtu + } + return calculateWriterMTU(writer) +} + +func calculateReaderMTU(reader any) int { + var mtu int + if withMTU, haveMTU := reader.(ReaderWithMTU); haveMTU { + mtu = withMTU.ReaderMTU() + } + if upstream, hasUpstream := reader.(common.WithUpstream); hasUpstream { + upstreamMTU := calculateReaderMTU(upstream.Upstream()) + if upstreamMTU > mtu { + mtu = upstreamMTU + } + } + if upstream, hasUpstream := reader.(WithUpstreamReader); hasUpstream { + upstreamMTU := calculateReaderMTU(upstream.UpstreamReader()) + if upstreamMTU > mtu { + mtu = upstreamMTU + } + } + return mtu +} + +func calculateWriterMTU(writer any) int { + var mtu int + if withMTU, haveMTU := writer.(WriterWithMTU); haveMTU { + mtu = withMTU.WriterMTU() + } + if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream { + upstreamMTU := calculateWriterMTU(upstream.Upstream()) + if mtu == 0 && upstreamMTU < mtu { + mtu = upstreamMTU + } + } + if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream { + upstreamMTU := calculateWriterMTU(upstream.UpstreamWriter()) + if mtu == 0 && upstreamMTU < mtu { + mtu = upstreamMTU + } + } + return mtu +} diff --git a/protocol/socks/packet.go b/protocol/socks/packet.go index c95c164..7af9e39 100644 --- a/protocol/socks/packet.go +++ b/protocol/socks/packet.go @@ -110,7 +110,7 @@ func (c *AssociatePacketConn) Upstream() any { return c.PacketConn } -func (c *AssociatePacketConn) Headroom() int { +func (c *AssociatePacketConn) FrontHeadroom() int { return 3 + M.MaxSocksaddrLength } diff --git a/protocol/trojan/protocol.go b/protocol/trojan/protocol.go index 2128bb2..72edc5b 100644 --- a/protocol/trojan/protocol.go +++ b/protocol/trojan/protocol.go @@ -75,7 +75,7 @@ func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) { return bufio.Copy(w, c.Conn) } -func (c *ClientConn) Headroom() int { +func (c *ClientConn) FrontHeadroom() int { if !c.headerWritten { return KeyLength + 5 + M.MaxSocksaddrLength } @@ -132,11 +132,11 @@ func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return } -func (c *ClientPacketConn) Headroom() int { +func (c *ClientPacketConn) FrontHeadroom() int { if !c.headerWritten { return KeyLength + 2*M.MaxSocksaddrLength + 9 } - return 0 + return M.MaxSocksaddrLength + 4 } func (c *ClientPacketConn) Upstream() any { diff --git a/protocol/trojan/service.go b/protocol/trojan/service.go index a65df8f..eb8b093 100644 --- a/protocol/trojan/service.go +++ b/protocol/trojan/service.go @@ -132,6 +132,10 @@ func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) er return WritePacket(c.Conn, buffer, destination) } +func (c *PacketConn) FrontHeadroom() int { + return M.MaxSocksaddrLength + 4 +} + type Error struct { Metadata M.Metadata Conn net.Conn