Refactor read waiter interface

This commit is contained in:
世界 2023-12-06 20:15:40 +08:00
parent 05c71c99d1
commit ae8098ad39
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 110 additions and 27 deletions

View file

@ -8,51 +8,76 @@ import (
N "github.com/sagernet/sing/common/network"
)
type BindPacketConn struct {
type BindPacketConn interface {
N.NetPacketConn
Addr net.Addr
net.Conn
}
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn {
return &BindPacketConn{
type bindPacketConn struct {
N.NetPacketConn
addr net.Addr
}
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
return &bindPacketConn{
NewPacketConn(conn),
addr,
}
}
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
func (c *bindPacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.Addr)
func (c *bindPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.addr)
}
func (c *BindPacketConn) RemoteAddr() net.Addr {
return c.Addr
func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &BindPacketReadWaiter{readWaiter}, true
}
func (c *BindPacketConn) Upstream() any {
func (c *bindPacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *bindPacketConn) Upstream() any {
return c.NetPacketConn
}
var (
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
)
type UnbindPacketConn struct {
N.ExtendedConn
Addr M.Socksaddr
addr M.Socksaddr
}
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn {
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
M.SocksaddrFromNet(conn.RemoteAddr()),
}
}
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
return &UnbindPacketConn{
NewExtendedConn(conn),
addr,
}
}
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.ExtendedConn.Read(p)
if err == nil {
addr = c.Addr.UDPAddr()
addr = c.addr.UDPAddr()
}
return
}
@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
if err != nil {
return
}
destination = c.Addr
destination = c.addr
return
}
@ -74,6 +99,14 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
if !isReadWaiter {
return nil, false
}
return &UnbindPacketReadWaiter{readWaiter, c.addr}, true
}
func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn
}

42
common/bufio/bind_wait.go Normal file
View file

@ -0,0 +1,42 @@
package bufio
import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ N.ReadWaiter = (*BindPacketReadWaiter)(nil)
type BindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter
}
func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
}
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, _, err = w.readWaiter.WaitReadPacket()
return
}
var _ N.PacketReadWaiter = (*UnbindPacketReadWaiter)(nil)
type UnbindPacketReadWaiter struct {
readWaiter N.ReadWaiter
addr M.Socksaddr
}
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
}
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, err = w.readWaiter.WaitReadBuffer()
if err != nil {
return
}
destination = w.addr
return
}

View file

@ -39,7 +39,7 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
})
defer source.InitializeReadWaiter(nil)
for {
err = source.WaitReadBuffer()
_, err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
@ -92,7 +92,7 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe
})
defer source.InitializeReadWaiter(nil)
for {
destination, err = source.WaitReadPacket()
_, destination, err = source.WaitReadPacket()
if err != nil {
return
}
@ -123,6 +123,7 @@ type syscallReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
}
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
@ -156,26 +157,29 @@ func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
if readN == 0 {
w.readErr = io.EOF
}
w.buffer = buffer
return true
}
}
}
func (w *syscallReadWaiter) WaitReadBuffer() error {
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if w.readFunc == nil {
return os.ErrInvalid
return nil, os.ErrInvalid
}
err := w.rawConn.Read(w.readFunc)
err = w.rawConn.Read(w.readFunc)
if err != nil {
return err
return
}
if w.readErr != nil {
if w.readErr == io.EOF {
return io.EOF
return nil, io.EOF
}
return E.Cause(w.readErr, "raw read")
return nil, E.Cause(w.readErr, "raw read")
}
return nil
buffer = w.buffer
w.buffer = nil
return
}
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
@ -185,6 +189,7 @@ type syscallPacketReadWaiter struct {
readErr error
readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
}
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
@ -225,14 +230,15 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
}
w.buffer = buffer
return true
}
}
}
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if w.readFunc == nil {
return M.Socksaddr{}, os.ErrInvalid
return nil, M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
@ -242,6 +248,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err
err = E.Cause(w.readErr, "raw read")
return
}
buffer = w.buffer
w.buffer = nil
destination = w.readFrom
return
}

View file

@ -7,7 +7,7 @@ import (
type ReadWaiter interface {
InitializeReadWaiter(newBuffer func() *buf.Buffer)
WaitReadBuffer() error
WaitReadBuffer() (buffer *buf.Buffer, err error)
}
type ReadWaitCreator interface {
@ -16,7 +16,7 @@ type ReadWaitCreator interface {
type PacketReadWaiter interface {
InitializeReadWaiter(newBuffer func() *buf.Buffer)
WaitReadPacket() (destination M.Socksaddr, err error)
WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error)
}
type PacketReadWaitCreator interface {