Refactor read waiter interface

This commit is contained in:
世界 2023-12-06 20:15:40 +08:00
parent 05c71c99d1
commit ae8098ad39
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 110 additions and 27 deletions

View file

@ -8,51 +8,76 @@ import (
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
type BindPacketConn struct { type BindPacketConn interface {
N.NetPacketConn N.NetPacketConn
Addr net.Addr net.Conn
} }
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn { type bindPacketConn struct {
return &BindPacketConn{ N.NetPacketConn
addr net.Addr
}
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
return &bindPacketConn{
NewPacketConn(conn), NewPacketConn(conn),
addr, addr,
} }
} }
func (c *BindPacketConn) Read(b []byte) (n int, err error) { func (c *bindPacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b) n, _, err = c.ReadFrom(b)
return return
} }
func (c *BindPacketConn) Write(b []byte) (n int, err error) { func (c *bindPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.Addr) return c.WriteTo(b, c.addr)
} }
func (c *BindPacketConn) RemoteAddr() net.Addr { func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
return c.Addr readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &BindPacketReadWaiter{readWaiter}, true
} }
func (c *BindPacketConn) Upstream() any { func (c *bindPacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *bindPacketConn) Upstream() any {
return c.NetPacketConn return c.NetPacketConn
} }
var (
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
)
type UnbindPacketConn struct { type UnbindPacketConn struct {
N.ExtendedConn N.ExtendedConn
Addr M.Socksaddr addr M.Socksaddr
} }
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn { func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
return &UnbindPacketConn{ return &UnbindPacketConn{
NewExtendedConn(conn), NewExtendedConn(conn),
M.SocksaddrFromNet(conn.RemoteAddr()), M.SocksaddrFromNet(conn.RemoteAddr()),
} }
} }
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
addr,
}
}
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.ExtendedConn.Read(p) n, err = c.ExtendedConn.Read(p)
if err == nil { if err == nil {
addr = c.Addr.UDPAddr() addr = c.addr.UDPAddr()
} }
return return
} }
@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
if err != nil { if err != nil {
return return
} }
destination = c.Addr destination = c.addr
return return
} }
@ -74,6 +99,14 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error
return c.ExtendedConn.WriteBuffer(buffer) return c.ExtendedConn.WriteBuffer(buffer)
} }
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
if !isReadWaiter {
return nil, false
}
return &UnbindPacketReadWaiter{readWaiter, c.addr}, true
}
func (c *UnbindPacketConn) Upstream() any { func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn return c.ExtendedConn
} }

42
common/bufio/bind_wait.go Normal file
View file

@ -0,0 +1,42 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ N.ReadWaiter = (*BindPacketReadWaiter)(nil)
type BindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter
}
func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
}
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, _, err = w.readWaiter.WaitReadPacket()
return
}
var _ N.PacketReadWaiter = (*UnbindPacketReadWaiter)(nil)
type UnbindPacketReadWaiter struct {
readWaiter N.ReadWaiter
addr M.Socksaddr
}
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
}
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, err = w.readWaiter.WaitReadBuffer()
if err != nil {
return
}
destination = w.addr
return
}

View file

@ -39,7 +39,7 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
}) })
defer source.InitializeReadWaiter(nil) defer source.InitializeReadWaiter(nil)
for { for {
err = source.WaitReadBuffer() _, err = source.WaitReadBuffer()
if err != nil { if err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
err = nil err = nil
@ -92,7 +92,7 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe
}) })
defer source.InitializeReadWaiter(nil) defer source.InitializeReadWaiter(nil)
for { for {
destination, err = source.WaitReadPacket() _, destination, err = source.WaitReadPacket()
if err != nil { if err != nil {
return return
} }
@ -123,6 +123,7 @@ type syscallReadWaiter struct {
rawConn syscall.RawConn rawConn syscall.RawConn
readErr error readErr error
readFunc func(fd uintptr) (done bool) readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
} }
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
@ -156,26 +157,29 @@ func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
if readN == 0 { if readN == 0 {
w.readErr = io.EOF w.readErr = io.EOF
} }
w.buffer = buffer
return true return true
} }
} }
} }
func (w *syscallReadWaiter) WaitReadBuffer() error { func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if w.readFunc == nil { if w.readFunc == nil {
return os.ErrInvalid return nil, os.ErrInvalid
} }
err := w.rawConn.Read(w.readFunc) err = w.rawConn.Read(w.readFunc)
if err != nil { if err != nil {
return err return
} }
if w.readErr != nil { if w.readErr != nil {
if w.readErr == io.EOF { if w.readErr == io.EOF {
return io.EOF return nil, io.EOF
} }
return E.Cause(w.readErr, "raw read") return nil, E.Cause(w.readErr, "raw read")
} }
return nil buffer = w.buffer
w.buffer = nil
return
} }
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil) var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
@ -185,6 +189,7 @@ type syscallPacketReadWaiter struct {
readErr error readErr error
readFrom M.Socksaddr readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool) readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
} }
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) { func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
@ -225,14 +230,15 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap() w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
} }
} }
w.buffer = buffer
return true return true
} }
} }
} }
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) { func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if w.readFunc == nil { if w.readFunc == nil {
return M.Socksaddr{}, os.ErrInvalid return nil, M.Socksaddr{}, os.ErrInvalid
} }
err = w.rawConn.Read(w.readFunc) err = w.rawConn.Read(w.readFunc)
if err != nil { if err != nil {
@ -242,6 +248,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err
err = E.Cause(w.readErr, "raw read") err = E.Cause(w.readErr, "raw read")
return return
} }
buffer = w.buffer
w.buffer = nil
destination = w.readFrom destination = w.readFrom
return return
} }

View file

@ -7,7 +7,7 @@ import (
type ReadWaiter interface { type ReadWaiter interface {
InitializeReadWaiter(newBuffer func() *buf.Buffer) InitializeReadWaiter(newBuffer func() *buf.Buffer)
WaitReadBuffer() error WaitReadBuffer() (buffer *buf.Buffer, err error)
} }
type ReadWaitCreator interface { type ReadWaitCreator interface {
@ -16,7 +16,7 @@ type ReadWaitCreator interface {
type PacketReadWaiter interface { type PacketReadWaiter interface {
InitializeReadWaiter(newBuffer func() *buf.Buffer) InitializeReadWaiter(newBuffer func() *buf.Buffer)
WaitReadPacket() (destination M.Socksaddr, err error) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error)
} }
type PacketReadWaitCreator interface { type PacketReadWaitCreator interface {