mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-05 04:47:40 +03:00
Refactor read waiter interface
This commit is contained in:
parent
05c71c99d1
commit
ae8098ad39
4 changed files with 110 additions and 27 deletions
|
@ -8,51 +8,76 @@ import (
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BindPacketConn struct {
|
type BindPacketConn interface {
|
||||||
N.NetPacketConn
|
N.NetPacketConn
|
||||||
Addr net.Addr
|
net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) *BindPacketConn {
|
type bindPacketConn struct {
|
||||||
return &BindPacketConn{
|
N.NetPacketConn
|
||||||
|
addr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBindPacketConn(conn net.PacketConn, addr net.Addr) BindPacketConn {
|
||||||
|
return &bindPacketConn{
|
||||||
NewPacketConn(conn),
|
NewPacketConn(conn),
|
||||||
addr,
|
addr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
|
func (c *bindPacketConn) Read(b []byte) (n int, err error) {
|
||||||
n, _, err = c.ReadFrom(b)
|
n, _, err = c.ReadFrom(b)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
|
func (c *bindPacketConn) Write(b []byte) (n int, err error) {
|
||||||
return c.WriteTo(b, c.Addr)
|
return c.WriteTo(b, c.addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BindPacketConn) RemoteAddr() net.Addr {
|
func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
|
||||||
return c.Addr
|
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
|
||||||
|
if !isReadWaiter {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return &BindPacketReadWaiter{readWaiter}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BindPacketConn) Upstream() any {
|
func (c *bindPacketConn) RemoteAddr() net.Addr {
|
||||||
|
return c.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *bindPacketConn) Upstream() any {
|
||||||
return c.NetPacketConn
|
return c.NetPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
|
||||||
|
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
type UnbindPacketConn struct {
|
type UnbindPacketConn struct {
|
||||||
N.ExtendedConn
|
N.ExtendedConn
|
||||||
Addr M.Socksaddr
|
addr M.Socksaddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUnbindPacketConn(conn net.Conn) *UnbindPacketConn {
|
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
|
||||||
return &UnbindPacketConn{
|
return &UnbindPacketConn{
|
||||||
NewExtendedConn(conn),
|
NewExtendedConn(conn),
|
||||||
M.SocksaddrFromNet(conn.RemoteAddr()),
|
M.SocksaddrFromNet(conn.RemoteAddr()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewUnbindPacketConnWithAddr(conn net.Conn, addr M.Socksaddr) N.NetPacketConn {
|
||||||
|
return &UnbindPacketConn{
|
||||||
|
NewExtendedConn(conn),
|
||||||
|
addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
n, err = c.ExtendedConn.Read(p)
|
n, err = c.ExtendedConn.Read(p)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
addr = c.Addr.UDPAddr()
|
addr = c.addr.UDPAddr()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -66,7 +91,7 @@ func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
destination = c.Addr
|
destination = c.addr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,6 +99,14 @@ func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error
|
||||||
return c.ExtendedConn.WriteBuffer(buffer)
|
return c.ExtendedConn.WriteBuffer(buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
|
||||||
|
readWaiter, isReadWaiter := CreateReadWaiter(c.ExtendedConn)
|
||||||
|
if !isReadWaiter {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return &UnbindPacketReadWaiter{readWaiter, c.addr}, true
|
||||||
|
}
|
||||||
|
|
||||||
func (c *UnbindPacketConn) Upstream() any {
|
func (c *UnbindPacketConn) Upstream() any {
|
||||||
return c.ExtendedConn
|
return c.ExtendedConn
|
||||||
}
|
}
|
||||||
|
|
42
common/bufio/bind_wait.go
Normal file
42
common/bufio/bind_wait.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package bufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ N.ReadWaiter = (*BindPacketReadWaiter)(nil)
|
||||||
|
|
||||||
|
type BindPacketReadWaiter struct {
|
||||||
|
readWaiter N.PacketReadWaiter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||||
|
w.readWaiter.InitializeReadWaiter(newBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||||
|
buffer, _, err = w.readWaiter.WaitReadPacket()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ N.PacketReadWaiter = (*UnbindPacketReadWaiter)(nil)
|
||||||
|
|
||||||
|
type UnbindPacketReadWaiter struct {
|
||||||
|
readWaiter N.ReadWaiter
|
||||||
|
addr M.Socksaddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||||
|
w.readWaiter.InitializeReadWaiter(newBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||||
|
buffer, err = w.readWaiter.WaitReadBuffer()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
destination = w.addr
|
||||||
|
return
|
||||||
|
}
|
|
@ -39,7 +39,7 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
|
||||||
})
|
})
|
||||||
defer source.InitializeReadWaiter(nil)
|
defer source.InitializeReadWaiter(nil)
|
||||||
for {
|
for {
|
||||||
err = source.WaitReadBuffer()
|
_, err = source.WaitReadBuffer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, io.EOF) {
|
if errors.Is(err, io.EOF) {
|
||||||
err = nil
|
err = nil
|
||||||
|
@ -92,7 +92,7 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe
|
||||||
})
|
})
|
||||||
defer source.InitializeReadWaiter(nil)
|
defer source.InitializeReadWaiter(nil)
|
||||||
for {
|
for {
|
||||||
destination, err = source.WaitReadPacket()
|
_, destination, err = source.WaitReadPacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -123,6 +123,7 @@ type syscallReadWaiter struct {
|
||||||
rawConn syscall.RawConn
|
rawConn syscall.RawConn
|
||||||
readErr error
|
readErr error
|
||||||
readFunc func(fd uintptr) (done bool)
|
readFunc func(fd uintptr) (done bool)
|
||||||
|
buffer *buf.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
|
||||||
|
@ -156,26 +157,29 @@ func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||||
if readN == 0 {
|
if readN == 0 {
|
||||||
w.readErr = io.EOF
|
w.readErr = io.EOF
|
||||||
}
|
}
|
||||||
|
w.buffer = buffer
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *syscallReadWaiter) WaitReadBuffer() error {
|
func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||||
if w.readFunc == nil {
|
if w.readFunc == nil {
|
||||||
return os.ErrInvalid
|
return nil, os.ErrInvalid
|
||||||
}
|
}
|
||||||
err := w.rawConn.Read(w.readFunc)
|
err = w.rawConn.Read(w.readFunc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if w.readErr != nil {
|
if w.readErr != nil {
|
||||||
if w.readErr == io.EOF {
|
if w.readErr == io.EOF {
|
||||||
return io.EOF
|
return nil, io.EOF
|
||||||
}
|
}
|
||||||
return E.Cause(w.readErr, "raw read")
|
return nil, E.Cause(w.readErr, "raw read")
|
||||||
}
|
}
|
||||||
return nil
|
buffer = w.buffer
|
||||||
|
w.buffer = nil
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
|
||||||
|
@ -185,6 +189,7 @@ type syscallPacketReadWaiter struct {
|
||||||
readErr error
|
readErr error
|
||||||
readFrom M.Socksaddr
|
readFrom M.Socksaddr
|
||||||
readFunc func(fd uintptr) (done bool)
|
readFunc func(fd uintptr) (done bool)
|
||||||
|
buffer *buf.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
|
||||||
|
@ -225,14 +230,15 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf
|
||||||
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
w.buffer = buffer
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
|
func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||||
if w.readFunc == nil {
|
if w.readFunc == nil {
|
||||||
return M.Socksaddr{}, os.ErrInvalid
|
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||||
}
|
}
|
||||||
err = w.rawConn.Read(w.readFunc)
|
err = w.rawConn.Read(w.readFunc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -242,6 +248,8 @@ func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err
|
||||||
err = E.Cause(w.readErr, "raw read")
|
err = E.Cause(w.readErr, "raw read")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
buffer = w.buffer
|
||||||
|
w.buffer = nil
|
||||||
destination = w.readFrom
|
destination = w.readFrom
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
|
|
||||||
type ReadWaiter interface {
|
type ReadWaiter interface {
|
||||||
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
||||||
WaitReadBuffer() error
|
WaitReadBuffer() (buffer *buf.Buffer, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReadWaitCreator interface {
|
type ReadWaitCreator interface {
|
||||||
|
@ -16,7 +16,7 @@ type ReadWaitCreator interface {
|
||||||
|
|
||||||
type PacketReadWaiter interface {
|
type PacketReadWaiter interface {
|
||||||
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
InitializeReadWaiter(newBuffer func() *buf.Buffer)
|
||||||
WaitReadPacket() (destination M.Socksaddr, err error)
|
WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type PacketReadWaitCreator interface {
|
type PacketReadWaitCreator interface {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue