mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 11:57:39 +03:00
Refactor read waiter interface
This commit is contained in:
parent
05c71c99d1
commit
ae8098ad39
4 changed files with 110 additions and 27 deletions
|
@ -8,51 +8,76 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type BindPacketConn struct {
|
||||
type BindPacketConn interface {
|
||||
N.NetPacketConn
|
||||
Addr net.Addr
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn {
|
||||
return &BindPacketConn{
|
||||
type bindPacketConn struct {
|
||||
N.NetPacketConn
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
|
||||
return &bindPacketConn{
|
||||
NewPacketConn(conn),
|
||||
addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
|
||||
func (c *bindPacketConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.Addr)
|
||||
func (c *bindPacketConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.addr)
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) RemoteAddr() net.Addr {
|
||||
return c.Addr
|
||||
func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &BindPacketReadWaiter{readWaiter}, true
|
||||
}
|
||||
|
||||
func (c *BindPacketConn) Upstream() any {
|
||||
func (c *bindPacketConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *bindPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
||||
|
||||
var (
|
||||
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
|
||||
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
|
||||
)
|
||||
|
||||
type UnbindPacketConn struct {
|
||||
N.ExtendedConn
|
||||
Addr M.Socksaddr
|
||||
addr M.Socksaddr
|
||||
}
|
||||
|
||||
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn {
|
||||
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
|
||||
return &UnbindPacketConn{
|
||||
NewExtendedConn(conn),
|
||||
M.SocksaddrFromNet(conn.RemoteAddr()),
|
||||
}
|
||||
}
|
||||
|
||||
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
|
||||
return &UnbindPacketConn{
|
||||
NewExtendedConn(conn),
|
||||
addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if err == nil {
|
||||
addr = c.Addr.UDPAddr()
|
||||
addr = c.addr.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = c.Addr
|
||||
destination = c.addr
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -74,6 +99,14 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error
|
|||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &UnbindPacketReadWaiter{readWaiter, c.addr}, true
|
||||
}
|
||||
|
||||
func (c *UnbindPacketConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
|
42
common/bufio/bind_wait.go
Normal file
42
common/bufio/bind_wait.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.ReadWaiter = (*BindPacketReadWaiter)(nil)
|
||||
|
||||
type BindPacketReadWaiter struct {
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readWaiter.InitializeReadWaiter(newBuffer)
|
||||
}
|
||||
|
||||
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
buffer, _, err = w.readWaiter.WaitReadPacket()
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*UnbindPacketReadWaiter)(nil)
|
||||
|
||||
type UnbindPacketReadWaiter struct {
|
||||
readWaiter N.ReadWaiter
|
||||
addr M.Socksaddr
|
||||
}
|
||||
|
||||
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readWaiter.InitializeReadWaiter(newBuffer)
|
||||
}
|
||||
|
||||
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
buffer, err = w.readWaiter.WaitReadBuffer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = w.addr
|
||||
return
|
||||
}
|
|
@ -39,7 +39,7 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
|
|||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
err = source.WaitReadBuffer()
|
||||
_, err = source.WaitReadBuffer()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
|
@ -92,7 +92,7 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe
|
|||
})
|
||||
defer source.InitializeReadWaiter(nil)
|
||||
for {
|
||||
destination, err = source.WaitReadPacket()
|
||||
_, destination, err = source.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -123,6 +123,7 @@ type syscallReadWaiter struct {
|
|||
rawConn syscall.RawConn
|
||||
readErr error
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||
|
@ -156,26 +157,29 @@ func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
|||
if readN == 0 {
|
||||
w.readErr = io.EOF
|
||||
}
|
||||
w.buffer = buffer
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() error {
|
||||
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
if w.readFunc == nil {
|
||||
return os.ErrInvalid
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
err := w.rawConn.Read(w.readFunc)
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
return err
|
||||
return
|
||||
}
|
||||
if w.readErr != nil {
|
||||
if w.readErr == io.EOF {
|
||||
return io.EOF
|
||||
return nil, io.EOF
|
||||
}
|
||||
return E.Cause(w.readErr, "raw read")
|
||||
return nil, E.Cause(w.readErr, "raw read")
|
||||
}
|
||||
return nil
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
||||
|
@ -185,6 +189,7 @@ type syscallPacketReadWaiter struct {
|
|||
readErr error
|
||||
readFrom M.Socksaddr
|
||||
readFunc func(fd uintptr) (done bool)
|
||||
buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||
|
@ -225,14 +230,15 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf
|
|||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||
}
|
||||
}
|
||||
w.buffer = buffer
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
|
||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if w.readFunc == nil {
|
||||
return M.Socksaddr{}, os.ErrInvalid
|
||||
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
err = w.rawConn.Read(w.readFunc)
|
||||
if err != nil {
|
||||
|
@ -242,6 +248,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err
|
|||
err = E.Cause(w.readErr, "raw read")
|
||||
return
|
||||
}
|
||||
buffer = w.buffer
|
||||
w.buffer = nil
|
||||
destination = w.readFrom
|
||||
return
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
|
||||
type ReadWaiter interface {
|
||||
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
||||
WaitReadBuffer() error
|
||||
WaitReadBuffer() (buffer *buf.Buffer, err error)
|
||||
}
|
||||
|
||||
type ReadWaitCreator interface {
|
||||
|
@ -16,7 +16,7 @@ type ReadWaitCreator interface {
|
|||
|
||||
type PacketReadWaiter interface {
|
||||
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
||||
WaitReadPacket() (destination M.Socksaddr, err error)
|
||||
WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error)
|
||||
}
|
||||
|
||||
type PacketReadWaitCreator interface {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue