diff --git a/common/bufio/copy.go b/common/bufio/copy.go index ebb03fe..506eab1 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -30,27 +30,38 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { cachedBuffer := cachedSrc.ReadCached() if cachedBuffer != nil { if !cachedBuffer.IsEmpty() { + dataLen := cachedBuffer.Len() + for _, counter := range readCounters { + counter(int64(dataLen)) + } _, err = destination.Write(cachedBuffer.Bytes()) if err != nil { cachedBuffer.Release() return } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } } cachedBuffer.Release() continue } } - srcSyscallConn, srcIsSyscall := source.(syscall.Conn) - dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) - if srcIsSyscall && dstIsSyscall { - var handled bool - handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) - if handled { - return - } - } break } + return CopyWithCounters(destination, source, originSource, readCounters, writeCounters) +} + +func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + srcSyscallConn, srcIsSyscall := source.(syscall.Conn) + dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) + if srcIsSyscall && dstIsSyscall { + var handled bool + handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) + if handled { + return + } + } return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) } @@ -75,6 +86,7 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) } +// Deprecated: not used func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { buffer.IncRef() defer buffer.DecRef() @@ -113,19 +125,10 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so } func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(destination) - rearHeadroom := N.CalculateRearHeadroom(destination) - bufferSize := N.CalculateMTU(source, destination) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } + options := N.NewReadWaitOptions(source, destination) var notFirstTime bool for { - buffer := buf.NewSize(bufferSize) - buffer.Resize(frontHeadroom, 0) - buffer.Reserve(rearHeadroom) + buffer := options.NewBuffer() err = source.ReadBuffer(buffer) if err != nil { buffer.Release() @@ -136,7 +139,10 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, return } dataLen := buffer.Len() - buffer.OverCap(rearHeadroom) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + options.PostReturn(buffer) err = destination.WriteBuffer(buffer) if err != nil { buffer.Leak() @@ -146,9 +152,6 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, return } n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } for _, counter := range writeCounters { counter(int64(dataLen)) } @@ -196,18 +199,6 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error return group.Run(ctx) } -// Deprecated: not used -func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error { - switch len(contextList) { - case 0: - return CopyConn(context.Background(), source, destination) - case 1: - return CopyConn(contextList[0], source, destination) - default: - panic("invalid context list") - } -} - func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { var readCounters, writeCounters []N.CountFunc var cachedPackets []*N.PacketBuffer @@ -225,24 +216,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, break } if cachedPackets != nil { - n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets) + n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters) if err != nil { return } } - frontHeadroom := N.CalculateFrontHeadroom(destinationConn) - rearHeadroom := N.CalculateRearHeadroom(destinationConn) + copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters) + n += copeN + return +} + +func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { var ( handled bool copeN int64 ) readWaiter, isReadWaiter := CreatePacketReadWaiter(source) if isReadWaiter { - needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ - FrontHeadroom: frontHeadroom, - RearHeadroom: rearHeadroom, - MTU: N.CalculateMTU(source, destinationConn), - }) + needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn)) if !needCopy || common.LowMemory { handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) if handled { @@ -256,28 +247,22 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, return } -func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(destinationConn) - rearHeadroom := N.CalculateRearHeadroom(destinationConn) - bufferSize := N.CalculateMTU(source, destinationConn) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.UDPBufferSize - } - var destination M.Socksaddr +func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { + options := N.NewReadWaitOptions(source, destination) + var destinationAddress M.Socksaddr for { - buffer := buf.NewSize(bufferSize) - buffer.Resize(frontHeadroom, 0) - buffer.Reserve(rearHeadroom) - destination, err = source.ReadPacket(buffer) + buffer := options.NewPacketBuffer() + destinationAddress, err = source.ReadPacket(buffer) if err != nil { buffer.Release() return } dataLen := buffer.Len() - buffer.OverCap(rearHeadroom) - err = destinationConn.WritePacket(buffer, destination) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + options.PostReturn(buffer) + err = destination.WritePacket(buffer, destinationAddress) if err != nil { buffer.Leak() if !notFirstTime { @@ -285,34 +270,25 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri } return } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } for _, counter := range writeCounters { counter(int64(dataLen)) } + n += int64(dataLen) notFirstTime = true } } -func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(destinationConn) - rearHeadroom := N.CalculateRearHeadroom(destinationConn) +func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + options := N.NewReadWaitOptions(nil, destination) var notFirstTime bool for _, packetBuffer := range packetBuffers { - buffer := buf.NewPacket() - buffer.Resize(frontHeadroom, 0) - buffer.Reserve(rearHeadroom) - _, err = buffer.Write(packetBuffer.Buffer.Bytes()) - packetBuffer.Buffer.Release() - if err != nil { - buffer.Release() - continue + for _, counter := range readCounters { + counter(int64(packetBuffer.Buffer.Len())) } + buffer := options.Copy(packetBuffer.Buffer) dataLen := buffer.Len() - buffer.OverCap(rearHeadroom) - err = destinationConn.WritePacket(buffer, packetBuffer.Destination) + err = destination.WritePacket(buffer, packetBuffer.Destination) + N.PutPacketBuffer(packetBuffer) if err != nil { buffer.Leak() if !notFirstTime { @@ -320,7 +296,11 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr } return } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } n += int64(dataLen) + notFirstTime = true } return } @@ -339,15 +319,3 @@ func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.Pack group.FastFail() return group.Run(ctx) } - -// Deprecated: not used -func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error { - switch len(contextList) { - case 0: - return CopyPacketConn(context.Background(), source, destination) - case 1: - return CopyPacketConn(contextList[0], source, destination) - default: - panic("invalid context list") - } -} diff --git a/common/network/direct.go b/common/network/direct.go index 1122d70..f587cd6 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -15,6 +15,14 @@ type ReadWaitOptions struct { MTU int } +func NewReadWaitOptions(source any, destination any) ReadWaitOptions { + return ReadWaitOptions{ + FrontHeadroom: CalculateFrontHeadroom(destination), + RearHeadroom: CalculateRearHeadroom(destination), + MTU: CalculateMTU(source, destination), + } +} + func (o ReadWaitOptions) NeedHeadroom() bool { return o.FrontHeadroom > 0 || o.RearHeadroom > 0 } diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 5d474e6..9d5bfa9 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -12,7 +12,14 @@ import ( "github.com/sagernet/sing/common/pipe" ) -type Conn struct { +type Conn interface { + N.PacketConn + SetHandler(handler N.UDPHandlerEx) +} + +var _ Conn = (*natConn)(nil) + +type natConn struct { writer N.PacketWriter localAddr M.Socksaddr handler N.UDPHandlerEx @@ -22,7 +29,7 @@ type Conn struct { readWaitOptions N.ReadWaitOptions } -func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { +func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { select { case p := <-c.packetChan: _, err = buffer.ReadOnceFrom(p.Buffer) @@ -37,12 +44,17 @@ func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { } } -func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { +func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.writer.WritePacket(buffer, destination) } -func (c *Conn) SetHandler(handler N.UDPHandlerEx) { +func (c *natConn) SetHandler(handler N.UDPHandlerEx) { + select { + case <-c.doneChan: + default: + } c.handler = handler + c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler) fetch: for { select { @@ -56,12 +68,12 @@ fetch: } } -func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { +func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { c.readWaitOptions = options return false } -func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { +func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { select { case packet := <-c.packetChan: buffer = c.readWaitOptions.Copy(packet.Buffer) @@ -75,7 +87,7 @@ func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er } } -func (c *Conn) Close() error { +func (c *natConn) Close() error { select { case <-c.doneChan: default: @@ -84,27 +96,27 @@ func (c *Conn) Close() error { return nil } -func (c *Conn) LocalAddr() net.Addr { +func (c *natConn) LocalAddr() net.Addr { return c.localAddr } -func (c *Conn) RemoteAddr() net.Addr { +func (c *natConn) RemoteAddr() net.Addr { return M.Socksaddr{} } -func (c *Conn) SetDeadline(t time.Time) error { +func (c *natConn) SetDeadline(t time.Time) error { return os.ErrInvalid } -func (c *Conn) SetReadDeadline(t time.Time) error { +func (c *natConn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) return nil } -func (c *Conn) SetWriteDeadline(t time.Time) error { +func (c *natConn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } -func (c *Conn) Upstream() any { +func (c *natConn) Upstream() any { return c.writer } diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 4bbef75..ac5da1d 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -14,7 +14,7 @@ import ( ) type Service struct { - cache freelru.Cache[netip.AddrPort, *Conn] + cache freelru.Cache[netip.AddrPort, *natConn] handler N.UDPConnectionHandlerEx prepare PrepareFunc metrics Metrics @@ -30,15 +30,15 @@ type Metrics struct { } func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service { - var cache freelru.Cache[netip.AddrPort, *Conn] + var cache freelru.Cache[netip.AddrPort, *natConn] if !shared { - cache = common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + cache = common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } else { - cache = common.Must1(freelru.NewSharded[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } cache.SetLifetime(timeout) cache.SetUpdateLifetimeOnGet(true) - cache.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { + cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { select { case <-conn.doneChan: return false @@ -46,7 +46,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur return true } }) - cache.SetOnEvict(func(_ netip.AddrPort, conn *Conn) { + cache.SetOnEvict(func(_ netip.AddrPort, conn *natConn) { conn.Close() }) return &Service{ @@ -64,7 +64,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati s.metrics.Rejects++ return } - conn = &Conn{ + conn = &natConn{ writer: writer, localAddr: source, packetChan: make(chan *N.PacketBuffer, 64),