bufio: more copy funcs

This commit is contained in:
世界 2024-11-20 11:27:20 +08:00
parent 30fbafd954
commit fa5355e99e
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 97 additions and 109 deletions

View file

@ -30,27 +30,38 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
cachedBuffer := cachedSrc.ReadCached() cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil { if cachedBuffer != nil {
if !cachedBuffer.IsEmpty() { if !cachedBuffer.IsEmpty() {
dataLen := cachedBuffer.Len()
for _, counter := range readCounters {
counter(int64(dataLen))
}
_, err = destination.Write(cachedBuffer.Bytes()) _, err = destination.Write(cachedBuffer.Bytes())
if err != nil { if err != nil {
cachedBuffer.Release() cachedBuffer.Release()
return return
} }
for _, counter := range writeCounters {
counter(int64(dataLen))
}
} }
cachedBuffer.Release() cachedBuffer.Release()
continue continue
} }
} }
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
break break
} }
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
}
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
} }
@ -75,6 +86,7 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
} }
// Deprecated: not used
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
buffer.IncRef() buffer.IncRef()
defer buffer.DecRef() defer buffer.DecRef()
@ -113,19 +125,10 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
} }
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination) options := N.NewReadWaitOptions(source, destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
var notFirstTime bool var notFirstTime bool
for { for {
buffer := buf.NewSize(bufferSize) buffer := options.NewBuffer()
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
err = source.ReadBuffer(buffer) err = source.ReadBuffer(buffer)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
@ -136,7 +139,10 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
return return
} }
dataLen := buffer.Len() dataLen := buffer.Len()
buffer.OverCap(rearHeadroom) for _, counter := range readCounters {
counter(int64(dataLen))
}
options.PostReturn(buffer)
err = destination.WriteBuffer(buffer) err = destination.WriteBuffer(buffer)
if err != nil { if err != nil {
buffer.Leak() buffer.Leak()
@ -146,9 +152,6 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
return return
} }
n += int64(dataLen) n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters { for _, counter := range writeCounters {
counter(int64(dataLen)) counter(int64(dataLen))
} }
@ -196,18 +199,6 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error
return group.Run(ctx) return group.Run(ctx)
} }
// Deprecated: not used
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
switch len(contextList) {
case 0:
return CopyConn(context.Background(), source, destination)
case 1:
return CopyConn(contextList[0], source, destination)
default:
panic("invalid context list")
}
}
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
var readCounters, writeCounters []N.CountFunc var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer var cachedPackets []*N.PacketBuffer
@ -225,24 +216,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
break break
} }
if cachedPackets != nil { if cachedPackets != nil {
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets) n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
if err != nil { if err != nil {
return return
} }
} }
frontHeadroom := N.CalculateFrontHeadroom(destinationConn) copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
rearHeadroom := N.CalculateRearHeadroom(destinationConn) n += copeN
return
}
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var ( var (
handled bool handled bool
copeN int64 copeN int64
) )
readWaiter, isReadWaiter := CreatePacketReadWaiter(source) readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter { if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destinationConn),
})
if !needCopy || common.LowMemory { if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled { if handled {
@ -256,28 +247,22 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return return
} }
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn) options := N.NewReadWaitOptions(source, destination)
rearHeadroom := N.CalculateRearHeadroom(destinationConn) var destinationAddress M.Socksaddr
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
for { for {
buffer := buf.NewSize(bufferSize) buffer := options.NewPacketBuffer()
buffer.Resize(frontHeadroom, 0) destinationAddress, err = source.ReadPacket(buffer)
buffer.Reserve(rearHeadroom)
destination, err = source.ReadPacket(buffer)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return return
} }
dataLen := buffer.Len() dataLen := buffer.Len()
buffer.OverCap(rearHeadroom) for _, counter := range readCounters {
err = destinationConn.WritePacket(buffer, destination) counter(int64(dataLen))
}
options.PostReturn(buffer)
err = destination.WritePacket(buffer, destinationAddress)
if err != nil { if err != nil {
buffer.Leak() buffer.Leak()
if !notFirstTime { if !notFirstTime {
@ -285,34 +270,25 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
} }
return return
} }
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters { for _, counter := range writeCounters {
counter(int64(dataLen)) counter(int64(dataLen))
} }
n += int64(dataLen)
notFirstTime = true notFirstTime = true
} }
} }
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn) options := N.NewReadWaitOptions(nil, destination)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
var notFirstTime bool var notFirstTime bool
for _, packetBuffer := range packetBuffers { for _, packetBuffer := range packetBuffers {
buffer := buf.NewPacket() for _, counter := range readCounters {
buffer.Resize(frontHeadroom, 0) counter(int64(packetBuffer.Buffer.Len()))
buffer.Reserve(rearHeadroom)
_, err = buffer.Write(packetBuffer.Buffer.Bytes())
packetBuffer.Buffer.Release()
if err != nil {
buffer.Release()
continue
} }
buffer := options.Copy(packetBuffer.Buffer)
dataLen := buffer.Len() dataLen := buffer.Len()
buffer.OverCap(rearHeadroom) err = destination.WritePacket(buffer, packetBuffer.Destination)
err = destinationConn.WritePacket(buffer, packetBuffer.Destination) N.PutPacketBuffer(packetBuffer)
if err != nil { if err != nil {
buffer.Leak() buffer.Leak()
if !notFirstTime { if !notFirstTime {
@ -320,7 +296,11 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
} }
return return
} }
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen) n += int64(dataLen)
notFirstTime = true
} }
return return
} }
@ -339,15 +319,3 @@ func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.Pack
group.FastFail() group.FastFail()
return group.Run(ctx) return group.Run(ctx)
} }
// Deprecated: not used
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
switch len(contextList) {
case 0:
return CopyPacketConn(context.Background(), source, destination)
case 1:
return CopyPacketConn(contextList[0], source, destination)
default:
panic("invalid context list")
}
}

View file

@ -15,6 +15,14 @@ type ReadWaitOptions struct {
MTU int MTU int
} }
func NewReadWaitOptions(source any, destination any) ReadWaitOptions {
return ReadWaitOptions{
FrontHeadroom: CalculateFrontHeadroom(destination),
RearHeadroom: CalculateRearHeadroom(destination),
MTU: CalculateMTU(source, destination),
}
}
func (o ReadWaitOptions) NeedHeadroom() bool { func (o ReadWaitOptions) NeedHeadroom() bool {
return o.FrontHeadroom > 0 || o.RearHeadroom > 0 return o.FrontHeadroom > 0 || o.RearHeadroom > 0
} }

View file

@ -12,7 +12,14 @@ import (
"github.com/sagernet/sing/common/pipe" "github.com/sagernet/sing/common/pipe"
) )
type Conn struct { type Conn interface {
N.PacketConn
SetHandler(handler N.UDPHandlerEx)
}
var _ Conn = (*natConn)(nil)
type natConn struct {
writer N.PacketWriter writer N.PacketWriter
localAddr M.Socksaddr localAddr M.Socksaddr
handler N.UDPHandlerEx handler N.UDPHandlerEx
@ -22,7 +29,7 @@ type Conn struct {
readWaitOptions N.ReadWaitOptions readWaitOptions N.ReadWaitOptions
} }
func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
select { select {
case p := <-c.packetChan: case p := <-c.packetChan:
_, err = buffer.ReadOnceFrom(p.Buffer) _, err = buffer.ReadOnceFrom(p.Buffer)
@ -37,12 +44,17 @@ func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
} }
} }
func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WritePacket(buffer, destination) return c.writer.WritePacket(buffer, destination)
} }
func (c *Conn) SetHandler(handler N.UDPHandlerEx) { func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
select {
case <-c.doneChan:
default:
}
c.handler = handler c.handler = handler
c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler)
fetch: fetch:
for { for {
select { select {
@ -56,12 +68,12 @@ fetch:
} }
} }
func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options c.readWaitOptions = options
return false return false
} }
func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select { select {
case packet := <-c.packetChan: case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer) buffer = c.readWaitOptions.Copy(packet.Buffer)
@ -75,7 +87,7 @@ func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
} }
} }
func (c *Conn) Close() error { func (c *natConn) Close() error {
select { select {
case <-c.doneChan: case <-c.doneChan:
default: default:
@ -84,27 +96,27 @@ func (c *Conn) Close() error {
return nil return nil
} }
func (c *Conn) LocalAddr() net.Addr { func (c *natConn) LocalAddr() net.Addr {
return c.localAddr return c.localAddr
} }
func (c *Conn) RemoteAddr() net.Addr { func (c *natConn) RemoteAddr() net.Addr {
return M.Socksaddr{} return M.Socksaddr{}
} }
func (c *Conn) SetDeadline(t time.Time) error { func (c *natConn) SetDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *Conn) SetReadDeadline(t time.Time) error { func (c *natConn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t) c.readDeadline.Set(t)
return nil return nil
} }
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *natConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *Conn) Upstream() any { func (c *natConn) Upstream() any {
return c.writer return c.writer
} }

View file

@ -14,7 +14,7 @@ import (
) )
type Service struct { type Service struct {
cache freelru.Cache[netip.AddrPort, *Conn] cache freelru.Cache[netip.AddrPort, *natConn]
handler N.UDPConnectionHandlerEx handler N.UDPConnectionHandlerEx
prepare PrepareFunc prepare PrepareFunc
metrics Metrics metrics Metrics
@ -30,15 +30,15 @@ type Metrics struct {
} }
func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service { func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service {
var cache freelru.Cache[netip.AddrPort, *Conn] var cache freelru.Cache[netip.AddrPort, *natConn]
if !shared { if !shared {
cache = common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) cache = common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
} else { } else {
cache = common.Must1(freelru.NewSharded[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
} }
cache.SetLifetime(timeout) cache.SetLifetime(timeout)
cache.SetUpdateLifetimeOnGet(true) cache.SetUpdateLifetimeOnGet(true)
cache.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool {
select { select {
case <-conn.doneChan: case <-conn.doneChan:
return false return false
@ -46,7 +46,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur
return true return true
} }
}) })
cache.SetOnEvict(func(_ netip.AddrPort, conn *Conn) { cache.SetOnEvict(func(_ netip.AddrPort, conn *natConn) {
conn.Close() conn.Close()
}) })
return &Service{ return &Service{
@ -64,7 +64,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
s.metrics.Rejects++ s.metrics.Rejects++
return return
} }
conn = &Conn{ conn = &natConn{
writer: writer, writer: writer,
localAddr: source, localAddr: source,
packetChan: make(chan *N.PacketBuffer, 64), packetChan: make(chan *N.PacketBuffer, 64),