Update dialer implementation

This commit is contained in:
世界 2024-11-15 10:37:06 +08:00
parent 54badfa885
commit 9a245c7e12
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 156 additions and 361 deletions

View file

@ -15,8 +15,10 @@ const (
BrutalMinSpeedBPS = 65536
)
func WriteBrutalRequest(writer io.Writer, receiveBPS uint64) error {
return binary.Write(writer, binary.BigEndian, receiveBPS)
func EncodeBrutalRequest(receiveBPS uint64) *buf.Buffer {
buffer := buf.NewSize(8)
common.Must(binary.Write(buffer, binary.BigEndian, receiveBPS))
return buffer
}
func ReadBrutalRequest(reader io.Reader) (uint64, error) {

View file

@ -2,10 +2,12 @@ package mux
import (
"context"
"encoding/binary"
"net"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
@ -14,6 +16,11 @@ import (
"github.com/sagernet/sing/common/x/list"
)
var (
_ N.Dialer = (*Client)(nil)
_ N.PayloadDialer = (*Client)(nil)
)
type Client struct {
dialer N.Dialer
logger logger.Logger
@ -74,18 +81,71 @@ func NewClient(options Options) (*Client, error) {
}
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return c.DialPayloadContext(ctx, network, destination, nil)
}
func (c *Client) DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payloads []*buf.Buffer) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
stream, err := c.openStream(ctx)
if err != nil {
buf.ReleaseMulti(payloads)
return nil, err
}
return &clientConn{Conn: stream, destination: destination}, nil
request := StreamRequest{
Network: N.NetworkTCP,
Destination: destination,
}
buffer := buf.NewSize(streamRequestLen(request) + buf.LenMulti(payloads))
defer buffer.Release()
EncodeStreamRequest(request, buffer)
for _, payload := range payloads {
buffer.Write(payload.Bytes())
payload.Release()
}
_, err = stream.Write(buffer.Bytes())
if err != nil {
stream.Close()
return nil, E.Cause(err, "write multiplex handshake request")
}
response, err := ReadStreamResponse(stream)
if err != nil {
return nil, E.Cause(err, "read multiplex handshake response")
}
if response.Status == statusError {
return nil, E.New("remote error: " + response.Message)
}
return stream, nil
case N.NetworkUDP:
stream, err := c.openStream(ctx)
if err != nil {
buf.ReleaseMulti(payloads)
return nil, err
}
request := StreamRequest{
Network: N.NetworkUDP,
Destination: destination,
}
buffer := buf.NewSize(streamRequestLen(request) + 2*len(payloads) + buf.LenMulti(payloads))
defer buffer.Release()
EncodeStreamRequest(request, buffer)
for _, packetPayload := range payloads {
binary.Write(buffer, binary.BigEndian, uint16(packetPayload.Len()))
buffer.Write(packetPayload.Bytes())
packetPayload.Release()
}
_, err = stream.Write(buffer.Bytes())
if err != nil {
stream.Close()
return nil, E.Cause(err, "write multiplex handshake request")
}
response, err := ReadStreamResponse(stream)
if err != nil {
return nil, E.Cause(err, "read multiplex handshake response")
}
if response.Status == statusError {
return nil, E.New("remote error: " + response.Message)
}
extendedConn := bufio.NewExtendedConn(stream)
return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
default:
@ -98,6 +158,26 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
if err != nil {
return nil, err
}
request := StreamRequest{
Network: N.NetworkUDP,
Destination: destination,
PacketAddr: true,
}
buffer := buf.NewSize(streamRequestLen(request))
defer buffer.Release()
EncodeStreamRequest(request, buffer)
_, err = stream.Write(buffer.Bytes())
if err != nil {
stream.Close()
return nil, E.Cause(err, "write multiplex handshake request")
}
response, err := ReadStreamResponse(stream)
if err != nil {
return nil, E.Cause(err, "read multiplex handshake response")
}
if response.Status == statusError {
return nil, E.New("remote error: " + response.Message)
}
extendedConn := bufio.NewExtendedConn(stream)
return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
}
@ -194,7 +274,7 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
return nil, err
}
if c.brutal.Enabled {
err = c.brutalExchange(ctx, conn, session)
err = c.brutalExchange(ctx, conn)
if err != nil {
conn.Close()
session.Close()
@ -205,21 +285,16 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
return session, nil
}
func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error {
stream, err := session.Open()
func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn) error {
stream, err := c.DialPayloadContext(ctx, N.NetworkTCP, M.Socksaddr{Fqdn: BrutalExchangeDomain}, []*buf.Buffer{EncodeBrutalRequest(c.brutal.SendBPS)})
if err != nil {
return err
}
conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}}
err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS)
serverReceiveBPS, err := ReadBrutalResponse(stream)
if err != nil {
return err
}
serverReceiveBPS, err := ReadBrutalResponse(conn)
if err != nil {
return err
}
conn.Close()
stream.Close()
sendBPS := c.brutal.SendBPS
if serverReceiveBPS < sendBPS {
sendBPS = serverReceiveBPS

View file

@ -8,126 +8,24 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type clientConn struct {
net.Conn
destination M.Socksaddr
requestWritten bool
responseRead bool
}
func (c *clientConn) NeedHandshake() bool {
return !c.requestWritten
}
func (c *clientConn) readResponse() error {
response, err := ReadStreamResponse(c.Conn)
if err != nil {
return err
}
if response.Status == statusError {
return E.New("remote error: ", response.Message)
}
return nil
}
func (c *clientConn) Read(b []byte) (n int, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
return c.Conn.Read(b)
}
func (c *clientConn) Write(b []byte) (n int, err error) {
if c.requestWritten {
return c.Conn.Write(b)
}
request := StreamRequest{
Network: N.NetworkTCP,
Destination: c.destination,
}
buffer := buf.NewSize(streamRequestLen(request) + len(b))
defer buffer.Release()
err = EncodeStreamRequest(request, buffer)
if err != nil {
return
}
buffer.Write(b)
_, err = c.Conn.Write(buffer.Bytes())
if err != nil {
return
}
c.requestWritten = true
return len(b), nil
}
func (c *clientConn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
func (c *clientConn) RemoteAddr() net.Addr {
return c.destination.TCPAddr()
}
func (c *clientConn) ReaderReplaceable() bool {
return c.responseRead
}
func (c *clientConn) WriterReplaceable() bool {
return c.requestWritten
}
func (c *clientConn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *clientConn) Upstream() any {
return c.Conn
}
var _ N.NetPacketConn = (*clientPacketConn)(nil)
var (
_ N.NetPacketConn = (*clientPacketConn)(nil)
_ N.PacketReadWaiter = (*clientPacketConn)(nil)
)
type clientPacketConn struct {
N.AbstractConn
conn N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
readWaitOptions N.ReadWaitOptions
}
func (c *clientPacketConn) NeedHandshake() bool {
return !c.requestWritten
}
func (c *clientPacketConn) readResponse() error {
response, err := ReadStreamResponse(c.conn)
if err != nil {
return err
}
if response.Status == statusError {
return E.New("remote error: ", response.Message)
}
return nil
}
func (c *clientPacketConn) Read(b []byte) (n int, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
@ -139,45 +37,7 @@ func (c *clientPacketConn) Read(b []byte) (n int, err error) {
return io.ReadFull(c.conn, b[:length])
}
func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
request := StreamRequest{
Network: N.NetworkUDP,
Destination: c.destination,
}
rLen := streamRequestLen(request)
if len(payload) > 0 {
rLen += 2 + len(payload)
}
buffer := buf.NewSize(rLen)
defer buffer.Release()
err = EncodeStreamRequest(request, buffer)
if err != nil {
return
}
if len(payload) > 0 {
common.Must(
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
common.Error(buffer.Write(payload)),
)
}
_, err = c.conn.Write(buffer.Bytes())
if err != nil {
return
}
c.requestWritten = true
return len(payload), nil
}
func (c *clientPacketConn) Write(b []byte) (n int, err error) {
if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
return c.writeRequest(b)
}
}
err = binary.Write(c.conn, binary.BigEndian, uint16(len(b)))
if err != nil {
return
@ -186,13 +46,6 @@ func (c *clientPacketConn) Write(b []byte) (n int, err error) {
}
func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
@ -203,16 +56,6 @@ func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
}
func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes()))
}
}
bLen := buffer.Len()
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
return c.conn.WriteBuffer(buffer)
@ -223,13 +66,6 @@ func (c *clientPacketConn) FrontHeadroom() int {
}
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
@ -243,15 +79,6 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
}
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
return c.writeRequest(p)
}
}
err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
if err != nil {
return
@ -268,6 +95,27 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
return c.WriteBuffer(buffer)
}
func (c *clientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *clientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
c.readWaitOptions.PostReturn(buffer)
return
}
func (c *clientPacketConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
@ -284,41 +132,20 @@ func (c *clientPacketConn) Upstream() any {
return c.conn
}
var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)
var (
_ N.NetPacketConn = (*clientPacketAddrConn)(nil)
_ N.PacketReadWaiter = (*clientPacketAddrConn)(nil)
)
type clientPacketAddrConn struct {
N.AbstractConn
conn N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
readWaitOptions N.ReadWaitOptions
}
func (c *clientPacketAddrConn) NeedHandshake() bool {
return !c.requestWritten
}
func (c *clientPacketAddrConn) readResponse() error {
response, err := ReadStreamResponse(c.conn)
if err != nil {
return err
}
if response.Status == statusError {
return E.New("remote error: ", response.Message)
}
return nil
}
func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
@ -340,50 +167,7 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
return
}
func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
request := StreamRequest{
Network: N.NetworkUDP,
Destination: c.destination,
PacketAddr: true,
}
rLen := streamRequestLen(request)
if len(payload) > 0 {
rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload)
}
buffer := buf.NewSize(rLen)
defer buffer.Release()
err = EncodeStreamRequest(request, buffer)
if err != nil {
return
}
if len(payload) > 0 {
err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
if err != nil {
return
}
common.Must(
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
common.Error(buffer.Write(payload)),
)
}
_, err = c.conn.Write(buffer.Bytes())
if err != nil {
return
}
c.requestWritten = true
return len(payload), nil
}
func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
return c.writeRequest(p, M.SocksaddrFromNet(addr))
}
}
err = M.SocksaddrSerializer.WriteAddrPort(c.conn, M.SocksaddrFromNet(addr))
if err != nil {
return
@ -396,13 +180,6 @@ func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err erro
}
func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
@ -417,16 +194,6 @@ func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc
}
func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes(), destination))
}
}
bLen := buffer.Len()
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
@ -437,6 +204,31 @@ func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc
return c.conn.WriteBuffer(buffer)
}
func (c *clientPacketAddrConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *clientPacketAddrConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
c.readWaitOptions.PostReturn(buffer)
return
}
func (c *clientPacketAddrConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}

View file

@ -1,73 +0,0 @@
package mux
import (
"encoding/binary"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var _ N.PacketReadWaiter = (*clientPacketConn)(nil)
func (c *clientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *clientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
c.readWaitOptions.PostReturn(buffer)
return
}
var _ N.PacketReadWaiter = (*clientPacketAddrConn)(nil)
func (c *clientPacketAddrConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *clientPacketAddrConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
c.readWaitOptions.PostReturn(buffer)
return
}

2
go.mod
View file

@ -4,7 +4,7 @@ go 1.18
require (
github.com/hashicorp/yamux v0.1.2
github.com/sagernet/sing v0.6.0-alpha.3
github.com/sagernet/sing v0.6.0-alpha.12
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7
golang.org/x/net v0.31.0
golang.org/x/sys v0.27.0

6
go.sum
View file

@ -3,10 +3,8 @@ github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8
github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/sing v0.5.1-0.20241109034027-099899991126 h1:pLMpV9pEAinrS9R1n1JLcbNesCl369RfvyxnYCPrkbw=
github.com/sagernet/sing v0.5.1-0.20241109034027-099899991126/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.6.0-alpha.3 h1:GLp9d6Gbt+Ioeplauuzojz1nY2J6moceVGYIOv/h5gA=
github.com/sagernet/sing v0.6.0-alpha.3/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.6.0-alpha.12 h1:RqTvSLcgnpcAVz+jzW9UE4IdqUMIxMJwZKRt+d6XDnU=
github.com/sagernet/sing v0.6.0-alpha.12/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ=
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View file

@ -107,13 +107,15 @@ func EncodeRequest(request Request, payload []byte) *buf.Buffer {
}
const (
flagUDP = 1
flagAddr = 2
statusSuccess = 0
statusError = 1
StreamVersion1 = 1
flagUDP = 1
flagAddr = 2
statusSuccess = 0
statusError = 1
)
type StreamRequest struct {
Version byte
Network string
Destination M.Socksaddr
PacketAddr bool
@ -137,18 +139,17 @@ func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
network = N.NetworkUDP
udpAddr = flags&flagAddr != 0
}
return &StreamRequest{network, destination, udpAddr}, nil
return &StreamRequest{Network: network, Destination: destination, PacketAddr: udpAddr}, nil
}
func streamRequestLen(request StreamRequest) int {
var rLen int
rLen += 1 // version
rLen += 2 // flags
rLen += M.SocksaddrSerializer.AddrPortLen(request.Destination)
return rLen
}
func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) error {
func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
destination := request.Destination
var flags uint16
if request.Network == N.NetworkUDP {
@ -161,7 +162,7 @@ func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) error {
}
}
common.Must(binary.Write(buffer, binary.BigEndian, flags))
return M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
common.Must(M.SocksaddrSerializer.WriteAddrPort(buffer, destination))
}
type StreamResponse struct {