mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 11:57:39 +03:00
Improve read waiter interface
This commit is contained in:
parent
ab3e4694cb
commit
9be7806bab
2 changed files with 51 additions and 32 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
|
@ -25,24 +26,21 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
|
|||
bufferSize = buf.BufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
notFirstTime bool
|
||||
)
|
||||
newBuffer := func() *buf.Buffer {
|
||||
if buffer != nil {
|
||||
buffer.Release()
|
||||
}
|
||||
source.InitializeReadWaiter(func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
}
|
||||
var notFirstTime bool
|
||||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
err = source.WaitReadBuffer(newBuffer)
|
||||
err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
return
|
||||
|
@ -56,9 +54,7 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
|
|||
buffer.Resize(readBuffer.Start(), dataLen)
|
||||
err = destination.WriteBuffer(buffer)
|
||||
if err != nil {
|
||||
if buffer != nil {
|
||||
buffer.Release()
|
||||
}
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
n += int64(dataLen)
|
||||
|
@ -83,25 +79,22 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
|
|||
bufferSize = buf.UDPBufferSize
|
||||
}
|
||||
var (
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
buffer *buf.Buffer
|
||||
readBuffer *buf.Buffer
|
||||
destination M.Socksaddr
|
||||
notFirstTime bool
|
||||
)
|
||||
newBuffer := func() *buf.Buffer {
|
||||
if buffer != nil {
|
||||
buffer.Release()
|
||||
}
|
||||
source.InitializeReadWaiter(func() *buf.Buffer {
|
||||
buffer = buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
return readBuffer
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
destination, err = source.WaitReadPacket(newBuffer)
|
||||
destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !notFirstTime {
|
||||
err = N.HandshakeFailure(destinationConn, err)
|
||||
}
|
||||
|
@ -113,8 +106,6 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
|
|||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
} else {
|
||||
buffer = nil
|
||||
}
|
||||
n += int64(dataLen)
|
||||
for _, counter := range readCounters {
|
||||
|
@ -127,6 +118,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
|
|||
}
|
||||
}
|
||||
|
||||
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
|
||||
|
||||
type syscallReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
|
@ -143,8 +136,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
|
||||
if w.readFunc == nil {
|
||||
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readErr = nil
|
||||
if newBuffer == nil {
|
||||
w.readFunc = nil
|
||||
} else {
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := newBuffer()
|
||||
var readN int
|
||||
|
@ -164,16 +160,27 @@ func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
|
|||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() error {
|
||||
if w.readFunc == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
err := w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if w.readErr != nil {
|
||||
if w.readErr == io.EOF {
|
||||
return io.EOF
|
||||
}
|
||||
return E.Cause(w.readErr, "raw read")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
||||
|
||||
type syscallPacketReadWaiter struct {
|
||||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
|
@ -191,8 +198,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
if w.readFunc == nil {
|
||||
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readErr = nil
|
||||
w.readFrom = M.Socksaddr{}
|
||||
if newBuffer == nil {
|
||||
w.readFunc = nil
|
||||
} else {
|
||||
w.readFunc = func(fd uintptr) (done bool) {
|
||||
buffer := newBuffer()
|
||||
var readN int
|
||||
|
@ -221,6 +232,12 @@ func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (
|
|||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
|
||||
if w.readFunc == nil {
|
||||
return M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
@ -6,7 +6,8 @@ import (
|
|||
)
|
||||
|
||||
type ReadWaiter interface {
|
||||
WaitReadBuffer(newBuffer func() *buf.Buffer) error
|
||||
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
||||
WaitReadBuffer() error
|
||||
}
|
||||
|
||||
type ReadWaitCreator interface {
|
||||
|
@ -14,7 +15,8 @@ type ReadWaitCreator interface {
|
|||
}
|
||||
|
||||
type PacketReadWaiter interface {
|
||||
WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error)
|
||||
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
||||
WaitReadPacket() (destination M.Socksaddr, err error)
|
||||
}
|
||||
|
||||
type PacketReadWaitCreator interface {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue