mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Implementation read waiter for socks5 UDP and UoT
This commit is contained in:
parent
ae8098ad39
commit
aa34723225
4 changed files with 97 additions and 1 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
|
@ -13,11 +14,17 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var (
|
||||
_ N.NetPacketConn = (*Conn)(nil)
|
||||
_ N.PacketReadWaiter = (*Conn)(nil)
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
isConnect bool
|
||||
destination M.Socksaddr
|
||||
writer N.VectorisedWriter
|
||||
newBuffer func() *buf.Buffer
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, request Request) *Conn {
|
||||
|
@ -141,6 +148,36 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|||
return c.writer.WriteVectorised([]*buf.Buffer{header, buffer})
|
||||
}
|
||||
|
||||
func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
c.newBuffer = newBuffer
|
||||
}
|
||||
|
||||
func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if c.newBuffer == nil {
|
||||
return nil, M.Socksaddr{}, os.ErrInvalid
|
||||
}
|
||||
if c.isConnect {
|
||||
destination = c.destination
|
||||
} else {
|
||||
destination, err = AddrParser.ReadAddrPort(c.Conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.Conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buffer = c.newBuffer()
|
||||
_, err = buffer.ReadFullFrom(c.Conn, int(length))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return nil, M.Socksaddr{}, E.Cause(err, "UoT read")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) NeedAdditionalReadDeadline() bool {
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
@ -147,7 +148,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock
|
|||
tcpConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return NewAssociateConn(udpConn, address, tcpConn), nil
|
||||
return NewAssociatePacketConn(bufio.NewUnbindPacketConn(udpConn), address, tcpConn), nil
|
||||
}
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
@ -17,6 +18,8 @@ import (
|
|||
// | 2 | 1 | 1 | Variable | 2 | Variable |
|
||||
// +----+------+------+----------+----------+----------+
|
||||
|
||||
var ErrInvalidPacket = E.New("socks5: invalid packet")
|
||||
|
||||
type AssociatePacketConn struct {
|
||||
N.NetPacketConn
|
||||
remoteAddr M.Socksaddr
|
||||
|
@ -31,6 +34,7 @@ func NewAssociatePacketConn(conn net.PacketConn, remoteAddr M.Socksaddr, underly
|
|||
}
|
||||
}
|
||||
|
||||
// Deprecated: NewAssociatePacketConn(bufio.NewUnbindPacketConn(conn), remoteAddr, underlying) instead.
|
||||
func NewAssociateConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
|
||||
return &AssociatePacketConn{
|
||||
NetPacketConn: bufio.NewUnbindPacketConn(conn),
|
||||
|
@ -49,6 +53,9 @@ func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err erro
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n < 3 {
|
||||
return 0, nil, ErrInvalidPacket
|
||||
}
|
||||
c.remoteAddr = M.SocksaddrFromNet(addr)
|
||||
reader := bytes.NewReader(p[3:n])
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
|
||||
|
@ -92,6 +99,9 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Sock
|
|||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
if buffer.Len() < 3 {
|
||||
return M.Socksaddr{}, ErrInvalidPacket
|
||||
}
|
||||
c.remoteAddr = destination
|
||||
buffer.Advance(3)
|
||||
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
|
||||
|
|
48
protocol/socks/packet_wait.go
Normal file
48
protocol/socks/packet_wait.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package socks
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.PacketReadWaitCreator = (*AssociatePacketConn)(nil)
|
||||
|
||||
func (c *AssociatePacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
|
||||
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(c.NetPacketConn)
|
||||
if !isReadWaiter {
|
||||
return nil, false
|
||||
}
|
||||
return &AssociatePacketReadWaiter{c, readWaiter}, true
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*AssociatePacketReadWaiter)(nil)
|
||||
|
||||
type AssociatePacketReadWaiter struct {
|
||||
conn *AssociatePacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
|
||||
w.readWaiter.InitializeReadWaiter(newBuffer)
|
||||
}
|
||||
|
||||
func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
buffer, destination, err = w.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if buffer.Len() < 3 {
|
||||
buffer.Release()
|
||||
return nil, M.Socksaddr{}, ErrInvalidPacket
|
||||
}
|
||||
w.conn.remoteAddr = destination
|
||||
buffer.Advance(3)
|
||||
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return nil, M.Socksaddr{}, err
|
||||
}
|
||||
return
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue