mirror of
https://github.com/SagerNet/sing-mux.git
synced 2025-04-01 19:17:36 +03:00
Implement read waiter for UDP
This commit is contained in:
parent
6be79e969e
commit
aa458ed011
5 changed files with 127 additions and 42 deletions
|
@ -86,7 +86,8 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M.
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bufio.NewUnbindPacketConn(&clientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil
|
||||
extendedConn := bufio.NewExtendedConn(stream)
|
||||
return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
|
||||
default:
|
||||
return nil, E.Extend(N.ErrUnknownNetwork, network)
|
||||
}
|
||||
|
@ -97,7 +98,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &clientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil
|
||||
extendedConn := bufio.NewExtendedConn(stream)
|
||||
return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
|
||||
}
|
||||
|
||||
func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
|
||||
|
|
|
@ -93,12 +93,16 @@ func (c *clientConn) Upstream() any {
|
|||
return c.Conn
|
||||
}
|
||||
|
||||
var _ N.NetPacketConn = (*clientPacketConn)(nil)
|
||||
|
||||
type clientPacketConn struct {
|
||||
N.ExtendedConn
|
||||
access sync.Mutex
|
||||
destination M.Socksaddr
|
||||
requestWritten bool
|
||||
responseRead bool
|
||||
N.AbstractConn
|
||||
conn N.ExtendedConn
|
||||
access sync.Mutex
|
||||
destination M.Socksaddr
|
||||
requestWritten bool
|
||||
responseRead bool
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) NeedHandshake() bool {
|
||||
|
@ -106,7 +110,7 @@ func (c *clientPacketConn) NeedHandshake() bool {
|
|||
}
|
||||
|
||||
func (c *clientPacketConn) readResponse() error {
|
||||
response, err := ReadStreamResponse(c.ExtendedConn)
|
||||
response, err := ReadStreamResponse(c.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -125,14 +129,14 @@ func (c *clientPacketConn) Read(b []byte) (n int, err error) {
|
|||
c.responseRead = true
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
err = binary.Read(c.conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cap(b) < int(length) {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
return io.ReadFull(c.ExtendedConn, b[:length])
|
||||
return io.ReadFull(c.conn, b[:length])
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
|
||||
|
@ -156,7 +160,7 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
|
|||
common.Error(buffer.Write(payload)),
|
||||
)
|
||||
}
|
||||
_, err = c.ExtendedConn.Write(buffer.Bytes())
|
||||
_, err = c.conn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -174,11 +178,11 @@ func (c *clientPacketConn) Write(b []byte) (n int, err error) {
|
|||
return c.writeRequest(b)
|
||||
}
|
||||
}
|
||||
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
|
||||
err = binary.Write(c.conn, binary.BigEndian, uint16(len(b)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.Write(b)
|
||||
return c.conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
|
||||
|
@ -190,11 +194,11 @@ func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
|
|||
c.responseRead = true
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
err = binary.Read(c.conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
|
||||
_, err = buffer.ReadFullFrom(c.conn, int(length))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -211,7 +215,7 @@ func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
|
|||
}
|
||||
bLen := buffer.Len()
|
||||
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
return c.conn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) FrontHeadroom() int {
|
||||
|
@ -227,14 +231,14 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
|
|||
c.responseRead = true
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
err = binary.Read(c.conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cap(p) < int(length) {
|
||||
return 0, nil, io.ErrShortBuffer
|
||||
}
|
||||
n, err = io.ReadFull(c.ExtendedConn, p[:length])
|
||||
n, err = io.ReadFull(c.conn, p[:length])
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -248,11 +252,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
return c.writeRequest(p)
|
||||
}
|
||||
}
|
||||
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
|
||||
err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.Write(p)
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
|
@ -265,7 +269,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
|
|||
}
|
||||
|
||||
func (c *clientPacketConn) LocalAddr() net.Addr {
|
||||
return c.ExtendedConn.LocalAddr()
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) RemoteAddr() net.Addr {
|
||||
|
@ -277,17 +281,19 @@ func (c *clientPacketConn) NeedAdditionalReadDeadline() bool {
|
|||
}
|
||||
|
||||
func (c *clientPacketConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
return c.conn
|
||||
}
|
||||
|
||||
var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)
|
||||
|
||||
type clientPacketAddrConn struct {
|
||||
N.ExtendedConn
|
||||
access sync.Mutex
|
||||
destination M.Socksaddr
|
||||
requestWritten bool
|
||||
responseRead bool
|
||||
N.AbstractConn
|
||||
conn N.ExtendedConn
|
||||
access sync.Mutex
|
||||
destination M.Socksaddr
|
||||
requestWritten bool
|
||||
responseRead bool
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) NeedHandshake() bool {
|
||||
|
@ -295,7 +301,7 @@ func (c *clientPacketAddrConn) NeedHandshake() bool {
|
|||
}
|
||||
|
||||
func (c *clientPacketAddrConn) readResponse() error {
|
||||
response, err := ReadStreamResponse(c.ExtendedConn)
|
||||
response, err := ReadStreamResponse(c.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -313,7 +319,7 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
|
|||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -323,14 +329,14 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
|
|||
addr = destination.UDPAddr()
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
err = binary.Read(c.conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cap(p) < int(length) {
|
||||
return 0, nil, io.ErrShortBuffer
|
||||
}
|
||||
n, err = io.ReadFull(c.ExtendedConn, p[:length])
|
||||
n, err = io.ReadFull(c.conn, p[:length])
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -360,7 +366,7 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
|
|||
common.Error(buffer.Write(payload)),
|
||||
)
|
||||
}
|
||||
_, err = c.ExtendedConn.Write(buffer.Bytes())
|
||||
_, err = c.conn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -378,15 +384,15 @@ func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err erro
|
|||
return c.writeRequest(p, M.SocksaddrFromNet(addr))
|
||||
}
|
||||
}
|
||||
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
|
||||
err = M.SocksaddrSerializer.WriteAddrPort(c.conn, M.SocksaddrFromNet(addr))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
|
||||
err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.Write(p)
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
|
@ -397,16 +403,16 @@ func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc
|
|||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
|
||||
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
err = binary.Read(c.conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
|
||||
_, err = buffer.ReadFullFrom(c.conn, int(length))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -428,11 +434,11 @@ func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc
|
|||
return err
|
||||
}
|
||||
common.Must(binary.Write(header, binary.BigEndian, uint16(bLen)))
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
return c.conn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) LocalAddr() net.Addr {
|
||||
return c.ExtendedConn.LocalAddr()
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) FrontHeadroom() int {
|
||||
|
@ -444,5 +450,5 @@ func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool {
|
|||
}
|
||||
|
||||
func (c *clientPacketAddrConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
return c.conn
|
||||
}
|
||||
|
|
73
client_conn_wait.go
Normal file
73
client_conn_wait.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.PacketReadWaiter = (*clientPacketConn)(nil)
|
||||
|
||||
func (c *clientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if !c.responseRead {
|
||||
err = c.readResponse()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = buffer.ReadFullFrom(c.conn, int(length))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return nil, M.Socksaddr{}, err
|
||||
}
|
||||
c.readWaitOptions.PostReturn(buffer)
|
||||
return
|
||||
}
|
||||
|
||||
var _ N.PacketReadWaiter = (*clientPacketAddrConn)(nil)
|
||||
|
||||
func (c *clientPacketAddrConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
if !c.responseRead {
|
||||
err = c.readResponse()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.conn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buffer = c.readWaitOptions.NewPacketBuffer()
|
||||
_, err = buffer.ReadFullFrom(c.conn, int(length))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return nil, M.Socksaddr{}, err
|
||||
}
|
||||
c.readWaitOptions.PostReturn(buffer)
|
||||
return
|
||||
}
|
2
go.mod
2
go.mod
|
@ -4,7 +4,7 @@ go 1.18
|
|||
|
||||
require (
|
||||
github.com/hashicorp/yamux v0.1.1
|
||||
github.com/sagernet/sing v0.2.20
|
||||
github.com/sagernet/sing v0.3.0-rc.2
|
||||
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
|
||||
golang.org/x/net v0.19.0
|
||||
golang.org/x/sys v0.15.0
|
||||
|
|
4
go.sum
4
go.sum
|
@ -3,6 +3,10 @@ github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbg
|
|||
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
|
||||
github.com/sagernet/sing v0.2.20 h1:ckcCB/5xu8G8wElNeH74IF6Soac5xWN+eQUXRuonjPQ=
|
||||
github.com/sagernet/sing v0.2.20/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
|
||||
github.com/sagernet/sing v0.3.0-rc.1 h1:XcdCC9CcLNfMSlObIQPjxyzenGQT2R1sGLHvdwDmQFU=
|
||||
github.com/sagernet/sing v0.3.0-rc.1/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
|
||||
github.com/sagernet/sing v0.3.0-rc.2 h1:l5rq+bTrNhpAPd2Vjzi/sEhil4O6Bb1CKv6LdPLJKug=
|
||||
github.com/sagernet/sing v0.3.0-rc.2/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g=
|
||||
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=
|
||||
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0=
|
||||
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue