mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-02 03:17:37 +03:00
w
This commit is contained in:
parent
3b5e6c1812
commit
b84f8d5313
10 changed files with 364 additions and 0 deletions
|
@ -100,6 +100,14 @@ func (r *BufferedReader) ReaderReplaceable() bool {
|
|||
return buffer == nil || buffer.Closed()
|
||||
}
|
||||
|
||||
func (r *BufferedReader) CreateReadWaiter() (ReadWaiter, bool) {
|
||||
reader, created := CreateReadWaiter(r.upstream)
|
||||
if !created {
|
||||
return nil, false
|
||||
}
|
||||
return &bufferedReadWaiter{r, reader}, true
|
||||
}
|
||||
|
||||
type BufferedWriter struct {
|
||||
upstream io.Writer
|
||||
buffer *buf.Buffer
|
||||
|
|
45
common/bufio/buffer_wait.go
Normal file
45
common/bufio/buffer_wait.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type bufferedReadWaiter struct {
|
||||
*BufferedReader
|
||||
upstream ReadWaiter
|
||||
}
|
||||
|
||||
func (w *bufferedReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
|
||||
if w.buffer == nil {
|
||||
return w.upstream.WaitReadBuffer(newBuffer)
|
||||
}
|
||||
if w.buffer.Closed() {
|
||||
return os.ErrClosed
|
||||
}
|
||||
var err error
|
||||
if w.buffer.IsEmpty() {
|
||||
w.buffer.Reset()
|
||||
w.buffer.IncRef()
|
||||
err = w.upstream.WaitReadBuffer(func() *buf.Buffer {
|
||||
return w.buffer
|
||||
})
|
||||
w.buffer.DecRef()
|
||||
if err != nil {
|
||||
w.buffer.Release()
|
||||
return err
|
||||
}
|
||||
}
|
||||
buffer := newBuffer()
|
||||
if w.buffer.Len() > buffer.FreeLen() {
|
||||
err = common.Error(buffer.ReadFullFrom(w.buffer, buffer.FreeLen()))
|
||||
} else {
|
||||
err = common.Error(buffer.ReadFullFrom(w.buffer, w.buffer.Len()))
|
||||
}
|
||||
if err != nil {
|
||||
w.buffer.Release()
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -83,6 +83,17 @@ func (c *CachedConn) WriterReplaceable() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (c *CachedConn) CreateReadWaiter() (ReadWaiter, bool) {
|
||||
reader, created := CreateReadWaiter(c.Conn)
|
||||
if !created {
|
||||
return nil, false
|
||||
}
|
||||
if c.buffer == nil {
|
||||
return reader, true
|
||||
}
|
||||
return &cachedReadWaiter{reader, c.buffer}, true
|
||||
}
|
||||
|
||||
func (c *CachedConn) Close() error {
|
||||
if buffer := c.buffer; buffer != nil {
|
||||
buffer.DecRef()
|
||||
|
@ -207,6 +218,17 @@ func (c *CachedPacketConn) WriterReplaceable() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (c *CachedPacketConn) CreatePacketReadWaiter() (PacketReadWaiter, bool) {
|
||||
reader, created := CreatePacketReadWaiter(c.PacketConn)
|
||||
if !created {
|
||||
return nil, false
|
||||
}
|
||||
if c.buffer == nil {
|
||||
return reader, true
|
||||
}
|
||||
return &cachedPacketReadWaiter{reader, c.buffer, c.destination}, true
|
||||
}
|
||||
|
||||
func (c *CachedPacketConn) Close() error {
|
||||
if buffer := c.buffer; buffer != nil {
|
||||
buffer.DecRef()
|
||||
|
|
49
common/bufio/cache_wait.go
Normal file
49
common/bufio/cache_wait.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type cachedReadWaiter struct {
|
||||
reader ReadWaiter
|
||||
cache *buf.Buffer
|
||||
}
|
||||
|
||||
func (c *cachedReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
|
||||
cache := c.cache
|
||||
if cache != nil {
|
||||
var err error
|
||||
if !cache.IsEmpty() {
|
||||
_, err = newBuffer().ReadOnceFrom(c.cache)
|
||||
}
|
||||
if cache.IsEmpty() {
|
||||
cache.Release()
|
||||
c.cache = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return c.reader.WaitReadBuffer(newBuffer)
|
||||
}
|
||||
|
||||
type cachedPacketReadWaiter struct {
|
||||
reader PacketReadWaiter
|
||||
cache *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
func (c *cachedPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
cache := c.cache
|
||||
if cache != nil {
|
||||
if !cache.IsEmpty() {
|
||||
_, err = newBuffer().ReadOnceFrom(c.cache)
|
||||
}
|
||||
if cache.IsEmpty() {
|
||||
cache.Release()
|
||||
c.cache = nil
|
||||
}
|
||||
destination = c.destination
|
||||
return
|
||||
}
|
||||
return c.reader.WaitReadPacket(newBuffer)
|
||||
}
|
|
@ -76,6 +76,10 @@ func (c *ChunkReader) MTU() int {
|
|||
return c.maxChunkSize
|
||||
}
|
||||
|
||||
func (c *ChunkReader) Upstream() any {
|
||||
return c.upstream
|
||||
}
|
||||
|
||||
type ChunkWriter struct {
|
||||
upstream N.ExtendedWriter
|
||||
maxChunkSize int
|
||||
|
|
|
@ -84,6 +84,11 @@ func CopyExtended(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader)
|
|||
return CopyExtendedWithSrcBuffer(origDst, dst, safeSrc)
|
||||
}
|
||||
}
|
||||
if common.LowMemory {
|
||||
if reader, created := CreateReadWaiter(src); created {
|
||||
return CopyExtendedWithPoolWait(origDst, dst, reader)
|
||||
}
|
||||
}
|
||||
if N.IsUnsafeWriter(dst) {
|
||||
return CopyExtendedWithPool(origDst, dst, src)
|
||||
}
|
||||
|
@ -185,6 +190,48 @@ func CopyExtendedWithPool(origDst io.Writer, dst N.ExtendedWriter, src N.Extende
|
|||
}
|
||||
}
|
||||
|
||||
func CopyExtendedWithPoolWait(origDst io.Writer, dst N.ExtendedWriter, src ReadWaiter) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.BufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
)
|
||||
newBuffer := func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
}
|
||||
var notFirstTime bool
|
||||
for {
|
||||
err = src.WaitReadBuffer(newBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(origDst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
||||
return CopyConnContextList([]context.Context{ctx}, conn, dest)
|
||||
}
|
||||
|
@ -241,6 +288,11 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
|||
return CopyPacketWithSrcBuffer(dst, safeSrc)
|
||||
}
|
||||
}
|
||||
if common.LowMemory {
|
||||
if reader, created := CreatePacketReadWaiter(src); created {
|
||||
return CopyPacketWithPoolWait(dst, reader)
|
||||
}
|
||||
}
|
||||
if N.IsUnsafeWriter(dst) {
|
||||
return CopyPacketWithPool(dst, src)
|
||||
}
|
||||
|
@ -339,6 +391,49 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
|
|||
}
|
||||
}
|
||||
|
||||
func CopyPacketWithPoolWait(dst N.PacketWriter, src PacketReadWaiter) (n int64, err error) {
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
bufferSize := N.CalculateMTU(src, dst)
|
||||
if bufferSize > 0 {
|
||||
bufferSize += frontHeadroom + rearHeadroom
|
||||
} else {
|
||||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
)
|
||||
newBuffer := func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
for {
|
||||
destination, err = src.WaitReadPacket(newBuffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(dst, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dataLen := readBuffer.Len()
|
||||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = dst.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
notFirstTime = true
|
||||
}
|
||||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
||||
return CopyPacketConnContextList([]context.Context{ctx}, conn, dest)
|
||||
}
|
||||
|
|
3
common/bufio/deadline/chunk_wait.go
Normal file
3
common/bufio/deadline/chunk_wait.go
Normal file
|
@ -0,0 +1,3 @@
|
|||
package deadline
|
||||
|
||||
|
|
@ -19,6 +19,9 @@ func NewVectorisedWriter(writer io.Writer) N.VectorisedWriter {
|
|||
}
|
||||
|
||||
func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
|
||||
if ioWriter, isWriter := writer.(io.Writer); isWriter {
|
||||
writer = N.UnwrapWriter(ioWriter)
|
||||
}
|
||||
switch w := writer.(type) {
|
||||
case N.VectorisedWriter:
|
||||
return w, true
|
||||
|
@ -42,6 +45,9 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
|
|||
}
|
||||
|
||||
func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) {
|
||||
if packetWriter, isWriter := writer.(N.PacketWriter); isWriter {
|
||||
writer = N.UnwrapPacketWriter(packetWriter)
|
||||
}
|
||||
switch w := writer.(type) {
|
||||
case N.VectorisedPacketWriter:
|
||||
return w, true
|
||||
|
|
114
common/bufio/wait.go
Normal file
114
common/bufio/wait.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type ReadWaiter interface {
|
||||
WaitReadBuffer(newBuffer func() *buf.Buffer) error
|
||||
}
|
||||
|
||||
type PacketReadWaiter interface {
|
||||
WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error)
|
||||
}
|
||||
|
||||
type ReadWaiterCreator interface {
|
||||
CreateReadWaiter() (ReadWaiter, bool)
|
||||
}
|
||||
|
||||
type PacketReadWaiterCreator interface {
|
||||
CreatePacketReadWaiter() (PacketReadWaiter, bool)
|
||||
}
|
||||
|
||||
func CreateReadWaiter(conn any) (ReadWaiter, bool) {
|
||||
if waiter, loaded := common.Cast[ReadWaiterCreator](conn); loaded {
|
||||
return waiter.CreateReadWaiter()
|
||||
}
|
||||
if waiter, loaded := common.Cast[ReadWaiter](conn); loaded {
|
||||
return waiter, true
|
||||
}
|
||||
if rawConn, loaded := common.Cast[syscall.RawConn](conn); loaded {
|
||||
return &syscallReadWaiter{rawConn}, true
|
||||
}
|
||||
if syscallConn, loaded := common.Cast[syscall.Conn](conn); loaded {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return &syscallReadWaiter{rawConn}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func CreatePacketReadWaiter(conn any) (PacketReadWaiter, bool) {
|
||||
if waiter, loaded := common.Cast[PacketReadWaiterCreator](conn); loaded {
|
||||
return waiter.CreatePacketReadWaiter()
|
||||
}
|
||||
if waiter, loaded := common.Cast[PacketReadWaiter](conn); loaded {
|
||||
return waiter, true
|
||||
}
|
||||
if rawConn, loaded := common.Cast[syscall.RawConn](conn); loaded {
|
||||
return &syscallReadWaiter{rawConn}, true
|
||||
}
|
||||
if syscallConn, loaded := common.Cast[syscall.Conn](conn); loaded {
|
||||
rawConn, err := syscallConn.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return &syscallReadWaiter{rawConn}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
type syscallReadWaiter struct {
|
||||
syscall.RawConn
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
n int
|
||||
err error
|
||||
)
|
||||
err = w.RawConn.Read(func(fd uintptr) (done bool) {
|
||||
buffer = newBuffer()
|
||||
n, err = syscall.Read(int(fd), buffer.FreeBytes())
|
||||
if err == syscall.EAGAIN {
|
||||
buffer.Release()
|
||||
return false
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
return true
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
n int
|
||||
from syscall.Sockaddr
|
||||
)
|
||||
err = w.RawConn.Read(func(fd uintptr) (done bool) {
|
||||
buffer = newBuffer()
|
||||
n, _, _, from, err = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
|
||||
if err == syscall.EAGAIN {
|
||||
buffer.Release()
|
||||
return false
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
return true
|
||||
})
|
||||
switch fromAddr := from.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
destination = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
case *syscall.SockaddrInet6:
|
||||
destination = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port))
|
||||
}
|
||||
return destination, err
|
||||
}
|
|
@ -71,6 +71,7 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
|
|||
c, loaded := s.nat.LoadOrStore(key, func() *conn {
|
||||
c := &conn{
|
||||
data: make(chan packet, 64),
|
||||
newData: make(chan struct{}, 1),
|
||||
localAddr: metadata.Source,
|
||||
remoteAddr: metadata.Destination,
|
||||
}
|
||||
|
@ -97,6 +98,10 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
|
|||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case c.newData <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
c.data <- packet{
|
||||
data: buffer,
|
||||
destination: metadata.Destination,
|
||||
|
@ -116,6 +121,7 @@ type conn struct {
|
|||
ctx context.Context
|
||||
cancel common.ContextCancelCauseFunc
|
||||
data chan packet
|
||||
newData chan struct{}
|
||||
localAddr M.Socksaddr
|
||||
remoteAddr M.Socksaddr
|
||||
source N.PacketWriter
|
||||
|
@ -161,12 +167,24 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
return len(p), c.source.WritePacket(buf.As(p).ToOwned(), M.SocksaddrFromNet(addr))
|
||||
}
|
||||
|
||||
func (c *conn) WaitRead() {
|
||||
select {
|
||||
case <-c.newData:
|
||||
case <-c.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
default:
|
||||
c.cancel(net.ErrClosed)
|
||||
}
|
||||
select {
|
||||
case <-c.newData:
|
||||
default:
|
||||
close(c.newData)
|
||||
}
|
||||
if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser {
|
||||
return sourceCloser.Close()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue