This commit is contained in:
世界 2023-04-19 21:31:26 +08:00
parent 3b5e6c1812
commit b84f8d5313
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 364 additions and 0 deletions

View file

@ -100,6 +100,14 @@ func (r *BufferedReader) ReaderReplaceable() bool {
return buffer == nil || buffer.Closed()
}
func (r *BufferedReader) CreateReadWaiter() (ReadWaiter, bool) {
reader, created := CreateReadWaiter(r.upstream)
if !created {
return nil, false
}
return &bufferedReadWaiter{r, reader}, true
}
type BufferedWriter struct {
upstream io.Writer
buffer *buf.Buffer

View file

@ -0,0 +1,45 @@
package bufio
import (
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
)
type bufferedReadWaiter struct {
*BufferedReader
upstream ReadWaiter
}
func (w *bufferedReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
if w.buffer == nil {
return w.upstream.WaitReadBuffer(newBuffer)
}
if w.buffer.Closed() {
return os.ErrClosed
}
var err error
if w.buffer.IsEmpty() {
w.buffer.Reset()
w.buffer.IncRef()
err = w.upstream.WaitReadBuffer(func() *buf.Buffer {
return w.buffer
})
w.buffer.DecRef()
if err != nil {
w.buffer.Release()
return err
}
}
buffer := newBuffer()
if w.buffer.Len() > buffer.FreeLen() {
err = common.Error(buffer.ReadFullFrom(w.buffer, buffer.FreeLen()))
} else {
err = common.Error(buffer.ReadFullFrom(w.buffer, w.buffer.Len()))
}
if err != nil {
w.buffer.Release()
}
return err
}

View file

@ -83,6 +83,17 @@ func (c *CachedConn) WriterReplaceable() bool {
return true
}
func (c *CachedConn) CreateReadWaiter() (ReadWaiter, bool) {
reader, created := CreateReadWaiter(c.Conn)
if !created {
return nil, false
}
if c.buffer == nil {
return reader, true
}
return &cachedReadWaiter{reader, c.buffer}, true
}
func (c *CachedConn) Close() error {
if buffer := c.buffer; buffer != nil {
buffer.DecRef()
@ -207,6 +218,17 @@ func (c *CachedPacketConn) WriterReplaceable() bool {
return true
}
func (c *CachedPacketConn) CreatePacketReadWaiter() (PacketReadWaiter, bool) {
reader, created := CreatePacketReadWaiter(c.PacketConn)
if !created {
return nil, false
}
if c.buffer == nil {
return reader, true
}
return &cachedPacketReadWaiter{reader, c.buffer, c.destination}, true
}
func (c *CachedPacketConn) Close() error {
if buffer := c.buffer; buffer != nil {
buffer.DecRef()

View file

@ -0,0 +1,49 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type cachedReadWaiter struct {
reader ReadWaiter
cache *buf.Buffer
}
func (c *cachedReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
cache := c.cache
if cache != nil {
var err error
if !cache.IsEmpty() {
_, err = newBuffer().ReadOnceFrom(c.cache)
}
if cache.IsEmpty() {
cache.Release()
c.cache = nil
}
return err
}
return c.reader.WaitReadBuffer(newBuffer)
}
type cachedPacketReadWaiter struct {
reader PacketReadWaiter
cache *buf.Buffer
destination M.Socksaddr
}
func (c *cachedPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
cache := c.cache
if cache != nil {
if !cache.IsEmpty() {
_, err = newBuffer().ReadOnceFrom(c.cache)
}
if cache.IsEmpty() {
cache.Release()
c.cache = nil
}
destination = c.destination
return
}
return c.reader.WaitReadPacket(newBuffer)
}

View file

@ -76,6 +76,10 @@ func (c *ChunkReader) MTU() int {
return c.maxChunkSize
}
func (c *ChunkReader) Upstream() any {
return c.upstream
}
type ChunkWriter struct {
upstream N.ExtendedWriter
maxChunkSize int

View file

@ -84,6 +84,11 @@ func CopyExtended(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader)
return CopyExtendedWithSrcBuffer(origDst, dst, safeSrc)
}
}
if common.LowMemory {
if reader, created := CreateReadWaiter(src); created {
return CopyExtendedWithPoolWait(origDst, dst, reader)
}
}
if N.IsUnsafeWriter(dst) {
return CopyExtendedWithPool(origDst, dst, src)
}
@ -185,6 +190,48 @@ func CopyExtendedWithPool(origDst io.Writer, dst N.ExtendedWriter, src N.Extende
}
}
func CopyExtendedWithPoolWait(origDst io.Writer, dst N.ExtendedWriter, src ReadWaiter) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
)
newBuffer := func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
}
var notFirstTime bool
for {
err = src.WaitReadBuffer(newBuffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(origDst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
return CopyConnContextList([]context.Context{ctx}, conn, dest)
}
@ -241,6 +288,11 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
return CopyPacketWithSrcBuffer(dst, safeSrc)
}
}
if common.LowMemory {
if reader, created := CreatePacketReadWaiter(src); created {
return CopyPacketWithPoolWait(dst, reader)
}
}
if N.IsUnsafeWriter(dst) {
return CopyPacketWithPool(dst, src)
}
@ -339,6 +391,49 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
}
}
func CopyPacketWithPoolWait(dst N.PacketWriter, src PacketReadWaiter) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
)
newBuffer := func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
}
var destination M.Socksaddr
var notFirstTime bool
for {
destination, err = src.WaitReadPacket(newBuffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = dst.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
return CopyPacketConnContextList([]context.Context{ctx}, conn, dest)
}

View file

@ -0,0 +1,3 @@
package deadline

View file

@ -19,6 +19,9 @@ func NewVectorisedWriter(writer io.Writer) N.VectorisedWriter {
}
func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
if ioWriter, isWriter := writer.(io.Writer); isWriter {
writer = N.UnwrapWriter(ioWriter)
}
switch w := writer.(type) {
case N.VectorisedWriter:
return w, true
@ -42,6 +45,9 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
}
func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) {
if packetWriter, isWriter := writer.(N.PacketWriter); isWriter {
writer = N.UnwrapPacketWriter(packetWriter)
}
switch w := writer.(type) {
case N.VectorisedPacketWriter:
return w, true

114
common/bufio/wait.go Normal file
View file

@ -0,0 +1,114 @@
package bufio
import (
"net/netip"
"syscall"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type ReadWaiter interface {
WaitReadBuffer(newBuffer func() *buf.Buffer) error
}
type PacketReadWaiter interface {
WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error)
}
type ReadWaiterCreator interface {
CreateReadWaiter() (ReadWaiter, bool)
}
type PacketReadWaiterCreator interface {
CreatePacketReadWaiter() (PacketReadWaiter, bool)
}
func CreateReadWaiter(conn any) (ReadWaiter, bool) {
if waiter, loaded := common.Cast[ReadWaiterCreator](conn); loaded {
return waiter.CreateReadWaiter()
}
if waiter, loaded := common.Cast[ReadWaiter](conn); loaded {
return waiter, true
}
if rawConn, loaded := common.Cast[syscall.RawConn](conn); loaded {
return &syscallReadWaiter{rawConn}, true
}
if syscallConn, loaded := common.Cast[syscall.Conn](conn); loaded {
rawConn, err := syscallConn.SyscallConn()
if err != nil {
return nil, false
}
return &syscallReadWaiter{rawConn}, true
}
return nil, false
}
func CreatePacketReadWaiter(conn any) (PacketReadWaiter, bool) {
if waiter, loaded := common.Cast[PacketReadWaiterCreator](conn); loaded {
return waiter.CreatePacketReadWaiter()
}
if waiter, loaded := common.Cast[PacketReadWaiter](conn); loaded {
return waiter, true
}
if rawConn, loaded := common.Cast[syscall.RawConn](conn); loaded {
return &syscallReadWaiter{rawConn}, true
}
if syscallConn, loaded := common.Cast[syscall.Conn](conn); loaded {
rawConn, err := syscallConn.SyscallConn()
if err != nil {
return nil, false
}
return &syscallReadWaiter{rawConn}, true
}
return nil, false
}
type syscallReadWaiter struct {
syscall.RawConn
}
func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
var (
buffer *buf.Buffer
n int
err error
)
err = w.RawConn.Read(func(fd uintptr) (done bool) {
buffer = newBuffer()
n, err = syscall.Read(int(fd), buffer.FreeBytes())
if err == syscall.EAGAIN {
buffer.Release()
return false
}
buffer.Truncate(n)
return true
})
return err
}
func (w *syscallReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
var (
buffer *buf.Buffer
n int
from syscall.Sockaddr
)
err = w.RawConn.Read(func(fd uintptr) (done bool) {
buffer = newBuffer()
n, _, _, from, err = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
if err == syscall.EAGAIN {
buffer.Release()
return false
}
buffer.Truncate(n)
return true
})
switch fromAddr := from.(type) {
case *syscall.SockaddrInet4:
destination = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *syscall.SockaddrInet6:
destination = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port))
}
return destination, err
}

View file

@ -71,6 +71,7 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
c, loaded := s.nat.LoadOrStore(key, func() *conn {
c := &conn{
data: make(chan packet, 64),
newData: make(chan struct{}, 1),
localAddr: metadata.Source,
remoteAddr: metadata.Destination,
}
@ -97,6 +98,10 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
}
return
}
select {
case c.newData <- struct{}{}:
default:
}
c.data <- packet{
data: buffer,
destination: metadata.Destination,
@ -116,6 +121,7 @@ type conn struct {
ctx context.Context
cancel common.ContextCancelCauseFunc
data chan packet
newData chan struct{}
localAddr M.Socksaddr
remoteAddr M.Socksaddr
source N.PacketWriter
@ -161,12 +167,24 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return len(p), c.source.WritePacket(buf.As(p).ToOwned(), M.SocksaddrFromNet(addr))
}
func (c *conn) WaitRead() {
select {
case <-c.newData:
case <-c.ctx.Done():
}
}
func (c *conn) Close() error {
select {
case <-c.ctx.Done():
default:
c.cancel(net.ErrClosed)
}
select {
case <-c.newData:
default:
close(c.newData)
}
if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser {
return sourceCloser.Close()
}