From ae8098ad3939cf51d4f0fb8e6d6624a8e22a694d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Dec 2023 20:15:40 +0800 Subject: [PATCH] Refactor read waiter interface --- common/bufio/bind.go | 61 ++++++++++++++++++++++++------- common/bufio/bind_wait.go | 42 +++++++++++++++++++++ common/bufio/copy_direct_posix.go | 30 +++++++++------ common/network/direct.go | 4 +- 4 files changed, 110 insertions(+), 27 deletions(-) create mode 100644 common/bufio/bind_wait.go diff --git a/common/bufio/bind.go b/common/bufio/bind.go index 4c84320..caa5bbd 100644 --- a/common/bufio/bind.go +++ b/common/bufio/bind.go @@ -8,51 +8,76 @@ import ( N "github.com/sagernet/sing/common/network" ) -type BindPacketConn struct { +type BindPacketConn interface { N.NetPacketConn - Addr net.Addr + net.Conn } -func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn { - return &BindPacketConn{ +type bindPacketConn struct { + N.NetPacketConn + addr net.Addr +} + +func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn { + return &bindPacketConn{ NewPacketConn(conn), addr, } } -func (c *BindPacketConn) Read(b []byte) (n int, err error) { +func (c *bindPacketConn) Read(b []byte) (n int, err error) { n, _, err = c.ReadFrom(b) return } -func (c *BindPacketConn) Write(b []byte) (n int, err error) { - return c.WriteTo(b, c.Addr) +func (c *bindPacketConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.addr) } -func (c *BindPacketConn) RemoteAddr() net.Addr { - return c.Addr +func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) { + readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn) + if !isReadWaiter { + return nil, false + } + return &BindPacketReadWaiter{readWaiter}, true } -func (c *BindPacketConn) Upstream() any { +func (c *bindPacketConn) RemoteAddr() net.Addr { + return c.addr +} + +func (c *bindPacketConn) Upstream() any { return c.NetPacketConn } +var ( + _ N.NetPacketConn = (*UnbindPacketConn)(nil) + _ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil) +) + type UnbindPacketConn struct { N.ExtendedConn - Addr M.Socksaddr + addr M.Socksaddr } -func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn { +func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn { return &UnbindPacketConn{ NewExtendedConn(conn), M.SocksaddrFromNet(conn.RemoteAddr()), } } +func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn { + return &UnbindPacketConn{ + NewExtendedConn(conn), + addr, + } +} + func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, err = c.ExtendedConn.Read(p) if err == nil { - addr = c.Addr.UDPAddr() + addr = c.addr.UDPAddr() } return } @@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad if err != nil { return } - destination = c.Addr + destination = c.addr return } @@ -74,6 +99,14 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error return c.ExtendedConn.WriteBuffer(buffer) } +func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { + readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn) + if !isReadWaiter { + return nil, false + } + return &UnbindPacketReadWaiter{readWaiter, c.addr}, true +} + func (c *UnbindPacketConn) Upstream() any { return c.ExtendedConn } diff --git a/common/bufio/bind_wait.go b/common/bufio/bind_wait.go new file mode 100644 index 0000000..724a76e --- /dev/null +++ b/common/bufio/bind_wait.go @@ -0,0 +1,42 @@ +package bufio + +import ( + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ N.ReadWaiter = (*BindPacketReadWaiter)(nil) + +type BindPacketReadWaiter struct { + readWaiter N.PacketReadWaiter +} + +func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { + w.readWaiter.InitializeReadWaiter(newBuffer) +} + +func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { + buffer, _, err = w.readWaiter.WaitReadPacket() + return +} + +var _ N.PacketReadWaiter = (*UnbindPacketReadWaiter)(nil) + +type UnbindPacketReadWaiter struct { + readWaiter N.ReadWaiter + addr M.Socksaddr +} + +func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { + w.readWaiter.InitializeReadWaiter(newBuffer) +} + +func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + buffer, err = w.readWaiter.WaitReadBuffer() + if err != nil { + return + } + destination = w.addr + return +} diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index 3501e66..06da27d 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -39,7 +39,7 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour }) defer source.InitializeReadWaiter(nil) for { - err = source.WaitReadBuffer() + _, err = source.WaitReadBuffer() if err != nil { if errors.Is(err, io.EOF) { err = nil @@ -92,7 +92,7 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe }) defer source.InitializeReadWaiter(nil) for { - destination, err = source.WaitReadPacket() + _, destination, err = source.WaitReadPacket() if err != nil { return } @@ -123,6 +123,7 @@ type syscallReadWaiter struct { rawConn syscall.RawConn readErr error readFunc func(fd uintptr) (done bool) + buffer *buf.Buffer } func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { @@ -156,26 +157,29 @@ func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { if readN == 0 { w.readErr = io.EOF } + w.buffer = buffer return true } } } -func (w *syscallReadWaiter) WaitReadBuffer() error { +func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { if w.readFunc == nil { - return os.ErrInvalid + return nil, os.ErrInvalid } - err := w.rawConn.Read(w.readFunc) + err = w.rawConn.Read(w.readFunc) if err != nil { - return err + return } if w.readErr != nil { if w.readErr == io.EOF { - return io.EOF + return nil, io.EOF } - return E.Cause(w.readErr, "raw read") + return nil, E.Cause(w.readErr, "raw read") } - return nil + buffer = w.buffer + w.buffer = nil + return } var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil) @@ -185,6 +189,7 @@ type syscallPacketReadWaiter struct { readErr error readFrom M.Socksaddr readFunc func(fd uintptr) (done bool) + buffer *buf.Buffer } func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) { @@ -225,14 +230,15 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap() } } + w.buffer = buffer return true } } } -func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) { +func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { if w.readFunc == nil { - return M.Socksaddr{}, os.ErrInvalid + return nil, M.Socksaddr{}, os.ErrInvalid } err = w.rawConn.Read(w.readFunc) if err != nil { @@ -242,6 +248,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err err = E.Cause(w.readErr, "raw read") return } + buffer = w.buffer + w.buffer = nil destination = w.readFrom return } diff --git a/common/network/direct.go b/common/network/direct.go index a40275c..b645676 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -7,7 +7,7 @@ import ( type ReadWaiter interface { InitializeReadWaiter(newBuffer func() *buf.Buffer) - WaitReadBuffer() error + WaitReadBuffer() (buffer *buf.Buffer, err error) } type ReadWaitCreator interface { @@ -16,7 +16,7 @@ type ReadWaitCreator interface { type PacketReadWaiter interface { InitializeReadWaiter(newBuffer func() *buf.Buffer) - WaitReadPacket() (destination M.Socksaddr, err error) + WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) } type PacketReadWaitCreator interface {