Refactor bufio (break change)

This commit is contained in:
世界 2022-08-11 17:02:56 +08:00
parent f4d911a3b1
commit 169983a8d7
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 539 additions and 309 deletions

View file

@ -29,6 +29,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error {
if !c.cache.IsEmpty() {
return common.Error(buffer.ReadFrom(c.cache))
}
c.cache.FullReset()
err := c.upstream.ReadBuffer(c.cache)
if err != nil {
return err
@ -40,6 +41,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) {
if !c.cache.IsEmpty() {
return c.cache.Read(p)
}
c.cache.FullReset()
err = c.upstream.ReadBuffer(c.cache)
if err != nil {
return
@ -52,6 +54,10 @@ func (c *ChunkReader) Close() error {
return nil
}
func (c *ChunkReader) MTU() int {
return c.maxChunkSize
}
type ChunkWriter struct {
upstream N.ExtendedWriter
maxChunkSize int
@ -96,3 +102,7 @@ func (w *ChunkWriter) WriteBuffer(buffer *buf.Buffer) error {
func (w *ChunkWriter) Upstream() any {
return w.upstream
}
func (w *ChunkWriter) MTU() int {
return w.maxChunkSize
}

View file

@ -1,276 +1,15 @@
package bufio
import (
"context"
"io"
"net"
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
)
type readOnlyReader struct {
io.Reader
}
func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
return Copy(w, r.Reader)
}
func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool {
_, isTCPConn := dst.(*net.TCPConn)
if !isTCPConn {
return false
}
switch src.(type) {
case *net.TCPConn, *net.UnixConn, *os.File:
return false
default:
return true
}
}
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
if src == nil {
return 0, E.New("nil reader")
} else if dst == nil {
return 0, E.New("nil writer")
}
src = N.UnwrapReader(src)
dst = N.UnwrapWriter(dst)
if wt, ok := src.(io.WriterTo); ok {
return wt.WriteTo(dst)
}
if rt, ok := dst.(io.ReaderFrom); ok {
if needReadFromWrapper(rt, src) {
src = &readOnlyReader{src}
}
return rt.ReadFrom(src)
}
return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src))
}
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
if _, isHandshakeConn := common.Cast[N.HandshakeConn](dst); isHandshakeConn {
n, err = CopyExtendedOnce(dst, src)
if err != nil {
return
}
}
var copyN int64
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafeReader](src)
_, dstUnsafe := common.Cast[N.ThreadUnsafeWriter](dst)
if srcUnsafe {
copyN, err = CopyExtendedWithSrcBuffer(dst, unsafeSrc)
} else if dstUnsafe {
copyN, err = CopyExtendedWithPool(dst, src)
} else {
_buffer := buf.StackNew()
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
copyN, err = CopyExtendedBuffer(dst, src, buffer)
}
n += copyN
return
}
func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-1024])
for {
buffer.Reset()
readBuffer.Reset()
err = src.ReadBuffer(readBuffer)
if err != nil {
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
return
}
n += int64(dataLen)
}
}
func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
for {
var buffer *buf.Buffer
buffer, err = src.ReadBufferThreadSafe()
if err != nil {
return
}
dataLen := buffer.Len()
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
}
}
func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
for {
buffer := buf.New()
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-1024])
readBuffer.Reset()
err = src.ReadBuffer(readBuffer)
if err != nil {
buffer.Release()
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
}
}
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
var group task.Group
group.Append("upload", func(ctx context.Context) error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(dest)
return common.Error(Copy(dest, conn))
})
group.Append("download", func(ctx context.Context) error {
defer rw.CloseRead(dest)
defer rw.CloseWrite(conn)
return common.Error(Copy(conn, dest))
})
group.Cleanup(func() {
common.Close(conn, dest)
})
return group.Run(ctx)
}
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src)
_, dstUnsafe := common.Cast[N.ThreadUnsafeWriter](dst)
if srcUnsafe {
dstHeadroom := N.CalculateHeadroom(dst)
if dstHeadroom == 0 {
return CopyPacketWithSrcBuffer(dst, unsafeSrc)
}
}
if dstUnsafe {
return CopyPacketWithPool(dst, src)
}
_buffer := buf.StackNewPacket()
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
buffer.IncRef()
defer buffer.DecRef()
var destination M.Socksaddr
var notFirstTime bool
for {
buffer.Reset()
destination, err = src.ReadPacket(buffer)
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
if buffer.IsFull() {
return 0, io.ErrShortBuffer
}
dataLen := buffer.Len()
err = dst.WritePacket(buffer, destination)
if err != nil {
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
var notFirstTime bool
for {
buffer, destination, err = src.ReadPacketThreadSafe()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := buffer.Len()
err = dst.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
var destination M.Socksaddr
var notFirstTime bool
for {
buffer := buf.NewPacket()
destination, err = src.ReadPacket(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
if buffer.IsFull() {
buffer.Release()
return 0, io.ErrShortBuffer
}
dataLen := buffer.Len()
err = dst.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
var group task.Group
group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(dest, conn))
})
group.Append("download", func(ctx context.Context) error {
return common.Error(CopyPacket(conn, dest))
})
group.Cleanup(func() {
common.Close(conn, dest)
})
group.FastFail()
return group.Run(ctx)
}
func NewPacketConn(conn net.PacketConn) N.NetPacketConn {
if packetConn, ok := conn.(N.NetPacketConn); ok {
return packetConn
@ -468,6 +207,34 @@ func (w *ExtendedConnWrapper) WriteBuffer(buffer *buf.Buffer) error {
return w.writer.WriteBuffer(buffer)
}
func (w *ExtendedConnWrapper) ReadFrom(r io.Reader) (n int64, err error) {
return Copy(w.writer, r)
}
func (r *ExtendedConnWrapper) WriteTo(w io.Writer) (n int64, err error) {
return Copy(w, r.reader)
}
func (w *ExtendedConnWrapper) UpstreamReader() io.Reader {
return w.reader
}
func (w *ExtendedConnWrapper) ReaderReplaceable() bool {
return true
}
func (w *ExtendedConnWrapper) UpstreamWriter() io.Writer {
return w.writer
}
func (w *ExtendedConnWrapper) WriterReplaceable() bool {
return true
}
func (w *ExtendedConnWrapper) Upstream() any {
return w.Conn
}
func NewExtendedConn(conn net.Conn) N.ExtendedConn {
if c, ok := conn.(N.ExtendedConn); ok {
return c

307
common/bufio/copy.go Normal file
View file

@ -0,0 +1,307 @@
package bufio
import (
"context"
"io"
"net"
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
)
type readOnlyReader struct {
io.Reader
}
func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
return Copy(w, r.Reader)
}
func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool {
_, isTCPConn := dst.(*net.TCPConn)
if !isTCPConn {
return false
}
switch src.(type) {
case *net.TCPConn, *net.UnixConn, *os.File:
return false
default:
return true
}
}
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
if src == nil {
return 0, E.New("nil reader")
} else if dst == nil {
return 0, E.New("nil writer")
}
src = N.UnwrapReader(src)
dst = N.UnwrapWriter(dst)
if wt, ok := src.(io.WriterTo); ok {
return wt.WriteTo(dst)
}
if rt, ok := dst.(io.ReaderFrom); ok {
if needReadFromWrapper(rt, src) {
src = &readOnlyReader{src}
}
return rt.ReadFrom(src)
}
return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src))
}
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafeReader](src)
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
if srcUnsafe {
if headroom == 0 {
return CopyExtendedWithSrcBuffer(dst, unsafeSrc)
}
}
if N.IsUnsafeWriter(dst) {
return CopyExtendedWithPool(dst, src)
}
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += headroom
} else {
bufferSize = buf.BufferSize
}
_buffer := buf.StackNewSize(bufferSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
return CopyExtendedBuffer(dst, src, buffer)
}
func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
var notFirstTime bool
for {
readBuffer.Resize(frontHeadroom, 0)
err = src.ReadBuffer(readBuffer)
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
var notFirstTime bool
for {
var buffer *buf.Buffer
buffer, err = src.ReadBufferThreadSafe()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := buffer.Len()
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = src.ReadBuffer(readBuffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
var group task.Group
group.Append("upload", func(ctx context.Context) error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(dest)
return common.Error(Copy(dest, conn))
})
group.Append("download", func(ctx context.Context) error {
defer rw.CloseRead(dest)
defer rw.CloseWrite(conn)
return common.Error(Copy(conn, dest))
})
group.Cleanup(func() {
common.Close(conn, dest)
})
return group.Run(ctx)
}
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src)
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
headroom := frontHeadroom + rearHeadroom
if srcUnsafe {
if headroom == 0 {
return CopyPacketWithSrcBuffer(dst, unsafeSrc)
}
}
if N.IsUnsafeWriter(dst) {
return CopyPacketWithPool(dst, src)
}
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += headroom
} else {
bufferSize = buf.UDPBufferSize
}
_buffer := buf.StackNewSize(bufferSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
buffer.IncRef()
defer buffer.DecRef()
var destination M.Socksaddr
var notFirstTime bool
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
for {
readBuffer.Resize(frontHeadroom, 0)
destination, err = src.ReadPacket(readBuffer)
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WritePacket(buffer, destination)
if err != nil {
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
var notFirstTime bool
for {
buffer, destination, err = src.ReadPacketThreadSafe()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := buffer.Len()
err = dst.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
destination, err = src.ReadPacket(readBuffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
var group task.Group
group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(dest, conn))
})
group.Append("download", func(ctx context.Context) error {
return common.Error(CopyPacket(conn, dest))
})
group.Cleanup(func() {
common.Close(conn, dest)
})
group.FastFail()
return group.Run(ctx)
}

View file

@ -8,40 +8,55 @@ import (
N "github.com/sagernet/sing/common/network"
)
func CopyOnce(dst io.Writer, src io.Reader) (n int64, err error) {
extendedSrc, srcExtended := src.(N.ExtendedReader)
extendedDst, dstExtended := dst.(N.ExtendedWriter)
if !srcExtended {
extendedSrc = &ExtendedReaderWrapper{src}
}
if !dstExtended {
extendedDst = &ExtendedWriterWrapper{dst}
}
return CopyExtendedOnce(extendedDst, extendedSrc)
func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) {
return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times)
}
func CopyExtendedOnce(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
var buffer *buf.Buffer
if N.IsUnsafeWriter(dst) {
buffer = buf.New()
func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
_buffer := buf.StackNew()
bufferSize = buf.BufferSize
}
dstUnsafe := N.IsUnsafeWriter(dst)
var buffer *buf.Buffer
if !dstUnsafe {
_buffer := buf.StackNewSize(bufferSize)
defer common.KeepAlive(_buffer)
buffer = common.Dup(_buffer)
defer buffer.Release()
buffer.IncRef()
defer buffer.DecRef()
}
err = src.ReadBuffer(buffer)
if err != nil {
buffer.Release()
err = N.HandshakeFailure(dst, err)
return
notFirstTime := true
for i := 0; i < times; i++ {
if dstUnsafe {
buffer = buf.NewSize(bufferSize)
}
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = src.ReadBuffer(readBuffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
dataLen := buffer.Len()
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
return
}
@ -51,7 +66,21 @@ type ReadFromWriter interface {
}
func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) {
n, err = CopyOnce(readerFrom, reader)
n, err = CopyTimes(readerFrom, reader, 1)
if err != nil {
return
}
var rn int64
rn, err = readerFrom.ReadFrom(reader)
if err != nil {
return
}
n += rn
return
}
func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) {
n, err = CopyTimes(readerFrom, reader, times)
if err != nil {
return
}
@ -70,7 +99,21 @@ type WriteToReader interface {
}
func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) {
n, err = CopyOnce(writer, writerTo)
n, err = CopyTimes(writer, writerTo, 1)
if err != nil {
return
}
var wn int64
wn, err = writerTo.WriteTo(writer)
if err != nil {
return
}
n += wn
return
}
func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) {
n, err = CopyTimes(writer, writerTo, times)
if err != nil {
return
}

View file

@ -10,7 +10,10 @@ import (
"github.com/sagernet/sing/common/rw"
)
const MaxSocksaddrLength = 2 + 255 + 2
const (
MaxSocksaddrLength = 2 + 255 + 2
MaxIPSocksaddrLength = 1 + 16 + 2
)
type SerializerOption func(*Serializer)

View file

@ -77,26 +77,44 @@ type CachedReader interface {
ReadCached() *buf.Buffer
}
type WithUpstreamReader interface {
UpstreamReader() io.Reader
}
type WithUpstreamWriter interface {
UpstreamWriter() io.Writer
}
type ReaderWithUpstream interface {
common.WithUpstream
ReaderReplaceable() bool
}
type WriterWithUpstream interface {
common.WithUpstream
WriterReplaceable() bool
}
func UnwrapReader(reader io.Reader) io.Reader {
if u, ok := reader.(ReaderWithUpstream); ok && u.ReaderReplaceable() {
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
return reader
}
if u, ok := reader.(WithUpstreamReader); ok {
return UnwrapReader(u.UpstreamReader())
}
if u, ok := reader.(common.WithUpstream); ok {
return UnwrapReader(u.Upstream().(io.Reader))
}
return reader
panic("bad reader")
}
func UnwrapWriter(writer io.Writer) io.Writer {
if u, ok := writer.(WriterWithUpstream); ok && u.WriterReplaceable() {
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
return writer
}
if u, ok := writer.(WithUpstreamWriter); ok {
return UnwrapWriter(u.UpstreamWriter())
}
if u, ok := writer.(common.WithUpstream); ok {
return UnwrapWriter(u.Upstream().(io.Writer))
}
return writer
panic("bad writer")
}

View file

@ -18,22 +18,100 @@ type ThreadSafePacketReader interface {
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
}
type HeadroomWriter interface {
Headroom() int
}
func IsUnsafeWriter(writer any) bool {
_, isUnsafe := common.Cast[ThreadUnsafeWriter](writer)
return isUnsafe
}
func CalculateHeadroom(writer any) int {
type FrontHeadroom interface {
FrontHeadroom() int
}
type RearHeadroom interface {
RearHeadroom() int
}
func CalculateFrontHeadroom(writer any) int {
var headroom int
if headroomWriter, needHeadroom := writer.(HeadroomWriter); needHeadroom {
headroom = headroomWriter.Headroom()
if headroomWriter, needHeadroom := writer.(FrontHeadroom); needHeadroom {
headroom = headroomWriter.FrontHeadroom()
}
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
return headroom + CalculateHeadroom(upstream.Upstream())
headroom += CalculateFrontHeadroom(upstream.Upstream())
}
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
headroom += CalculateFrontHeadroom(upstream.UpstreamWriter())
}
return headroom
}
func CalculateRearHeadroom(writer any) int {
var headroom int
if headroomWriter, needHeadroom := writer.(RearHeadroom); needHeadroom {
headroom = headroomWriter.RearHeadroom()
}
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
headroom += CalculateRearHeadroom(upstream.Upstream())
}
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
headroom += CalculateRearHeadroom(upstream.UpstreamWriter())
}
return headroom
}
type ReaderWithMTU interface {
ReaderMTU() int
}
type WriterWithMTU interface {
WriterMTU() int
}
func CalculateMTU(reader any, writer any) int {
mtu := calculateReaderMTU(reader)
if mtu == 0 {
return mtu
}
return calculateWriterMTU(writer)
}
func calculateReaderMTU(reader any) int {
var mtu int
if withMTU, haveMTU := reader.(ReaderWithMTU); haveMTU {
mtu = withMTU.ReaderMTU()
}
if upstream, hasUpstream := reader.(common.WithUpstream); hasUpstream {
upstreamMTU := calculateReaderMTU(upstream.Upstream())
if upstreamMTU > mtu {
mtu = upstreamMTU
}
}
if upstream, hasUpstream := reader.(WithUpstreamReader); hasUpstream {
upstreamMTU := calculateReaderMTU(upstream.UpstreamReader())
if upstreamMTU > mtu {
mtu = upstreamMTU
}
}
return mtu
}
func calculateWriterMTU(writer any) int {
var mtu int
if withMTU, haveMTU := writer.(WriterWithMTU); haveMTU {
mtu = withMTU.WriterMTU()
}
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
upstreamMTU := calculateWriterMTU(upstream.Upstream())
if mtu == 0 && upstreamMTU < mtu {
mtu = upstreamMTU
}
}
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
upstreamMTU := calculateWriterMTU(upstream.UpstreamWriter())
if mtu == 0 && upstreamMTU < mtu {
mtu = upstreamMTU
}
}
return mtu
}

View file

@ -110,7 +110,7 @@ func (c *AssociatePacketConn) Upstream() any {
return c.PacketConn
}
func (c *AssociatePacketConn) Headroom() int {
func (c *AssociatePacketConn) FrontHeadroom() int {
return 3 + M.MaxSocksaddrLength
}

View file

@ -75,7 +75,7 @@ func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
return bufio.Copy(w, c.Conn)
}
func (c *ClientConn) Headroom() int {
func (c *ClientConn) FrontHeadroom() int {
if !c.headerWritten {
return KeyLength + 5 + M.MaxSocksaddrLength
}
@ -132,11 +132,11 @@ func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return
}
func (c *ClientPacketConn) Headroom() int {
func (c *ClientPacketConn) FrontHeadroom() int {
if !c.headerWritten {
return KeyLength + 2*M.MaxSocksaddrLength + 9
}
return 0
return M.MaxSocksaddrLength + 4
}
func (c *ClientPacketConn) Upstream() any {

View file

@ -132,6 +132,10 @@ func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) er
return WritePacket(c.Conn, buffer, destination)
}
func (c *PacketConn) FrontHeadroom() int {
return M.MaxSocksaddrLength + 4
}
type Error struct {
Metadata M.Metadata
Conn net.Conn