package mux import ( "context" "encoding/binary" "net" "sync" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" ) var ( _ N.Dialer = (*Client)(nil) _ N.PayloadDialer = (*Client)(nil) ) type Client struct { dialer N.Dialer logger logger.Logger protocol byte maxConnections int minStreams int maxStreams int padding bool access sync.Mutex connections list.List[abstractSession] brutal BrutalOptions } type Options struct { Dialer N.Dialer Logger logger.Logger Protocol string MaxConnections int MinStreams int MaxStreams int Padding bool Brutal BrutalOptions } type BrutalOptions struct { Enabled bool SendBPS uint64 ReceiveBPS uint64 } func NewClient(options Options) (*Client, error) { client := &Client{ dialer: options.Dialer, logger: options.Logger, maxConnections: options.MaxConnections, minStreams: options.MinStreams, maxStreams: options.MaxStreams, padding: options.Padding, brutal: options.Brutal, } if client.dialer == nil { client.dialer = N.SystemDialer } if client.maxStreams == 0 && client.maxConnections == 0 { client.minStreams = 8 } switch options.Protocol { case "", "h2mux": client.protocol = ProtocolH2Mux case "smux": client.protocol = ProtocolSmux case "yamux": client.protocol = ProtocolYAMux default: return nil, E.New("unknown protocol: " + options.Protocol) } return client, nil } func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { return c.DialPayloadContext(ctx, network, destination, nil) } func (c *Client) DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payloads []*buf.Buffer) (net.Conn, error) { switch N.NetworkName(network) { case N.NetworkTCP: stream, err := c.openStream(ctx) if err != nil { buf.ReleaseMulti(payloads) return nil, err } request := StreamRequest{ Network: N.NetworkTCP, Destination: destination, } buffer := buf.NewSize(streamRequestLen(request) + buf.LenMulti(payloads)) defer buffer.Release() EncodeStreamRequest(request, buffer) for _, payload := range payloads { buffer.Write(payload.Bytes()) payload.Release() } _, err = stream.Write(buffer.Bytes()) if err != nil { stream.Close() return nil, E.Cause(err, "write multiplex handshake request") } response, err := ReadStreamResponse(stream) if err != nil { return nil, E.Cause(err, "read multiplex handshake response") } if response.Status == statusError { return nil, E.New("remote error: " + response.Message) } return stream, nil case N.NetworkUDP: stream, err := c.openStream(ctx) if err != nil { buf.ReleaseMulti(payloads) return nil, err } request := StreamRequest{ Network: N.NetworkUDP, Destination: destination, } buffer := buf.NewSize(streamRequestLen(request) + 2*len(payloads) + buf.LenMulti(payloads)) defer buffer.Release() EncodeStreamRequest(request, buffer) for _, packetPayload := range payloads { binary.Write(buffer, binary.BigEndian, uint16(packetPayload.Len())) buffer.Write(packetPayload.Bytes()) packetPayload.Release() } _, err = stream.Write(buffer.Bytes()) if err != nil { stream.Close() return nil, E.Cause(err, "write multiplex handshake request") } response, err := ReadStreamResponse(stream) if err != nil { return nil, E.Cause(err, "read multiplex handshake response") } if response.Status == statusError { return nil, E.New("remote error: " + response.Message) } extendedConn := bufio.NewExtendedConn(stream) return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil default: return nil, E.Extend(N.ErrUnknownNetwork, network) } } func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { stream, err := c.openStream(ctx) if err != nil { return nil, err } request := StreamRequest{ Network: N.NetworkUDP, Destination: destination, PacketAddr: true, } buffer := buf.NewSize(streamRequestLen(request)) defer buffer.Release() EncodeStreamRequest(request, buffer) _, err = stream.Write(buffer.Bytes()) if err != nil { stream.Close() return nil, E.Cause(err, "write multiplex handshake request") } response, err := ReadStreamResponse(stream) if err != nil { return nil, E.Cause(err, "read multiplex handshake response") } if response.Status == statusError { return nil, E.New("remote error: " + response.Message) } extendedConn := bufio.NewExtendedConn(stream) return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil } func (c *Client) openStream(ctx context.Context) (net.Conn, error) { var ( session abstractSession stream net.Conn err error ) for attempts := 0; attempts < 2; attempts++ { session, err = c.offer(ctx) if err != nil { continue } stream, err = session.Open() if err != nil { continue } break } if err != nil { return nil, err } return &wrapStream{stream}, nil } func (c *Client) offer(ctx context.Context) (abstractSession, error) { c.access.Lock() defer c.access.Unlock() var sessions []abstractSession for element := c.connections.Front(); element != nil; { if element.Value.IsClosed() { element.Value.Close() nextElement := element.Next() c.connections.Remove(element) element = nextElement continue } sessions = append(sessions, element.Value) element = element.Next() } if c.brutal.Enabled { if len(sessions) > 0 { return sessions[0], nil } return c.offerNew(ctx) } session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams) if session == nil { return c.offerNew(ctx) } numStreams := session.NumStreams() if numStreams == 0 { return session, nil } if c.maxConnections > 0 { if len(sessions) >= c.maxConnections || numStreams < c.minStreams { return session, nil } } else { if c.maxStreams > 0 && numStreams < c.maxStreams { return session, nil } } return c.offerNew(ctx) } func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { ctx, cancel := context.WithTimeout(ctx, TCPTimeout) defer cancel() conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination) if err != nil { return nil, err } var version byte if c.padding { version = Version1 } else { version = Version0 } conn = newProtocolConn(conn, Request{ Version: version, Protocol: c.protocol, Padding: c.padding, }) if c.padding { conn = newPaddingConn(conn) } session, err := newClientSession(conn, c.protocol) if err != nil { conn.Close() return nil, err } if c.brutal.Enabled { err = c.brutalExchange(ctx, conn) if err != nil { conn.Close() session.Close() return nil, E.Cause(err, "brutal exchange") } } c.connections.PushBack(session) return session, nil } func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn) error { stream, err := c.DialPayloadContext(ctx, N.NetworkTCP, M.Socksaddr{Fqdn: BrutalExchangeDomain}, []*buf.Buffer{EncodeBrutalRequest(c.brutal.SendBPS)}) if err != nil { return err } serverReceiveBPS, err := ReadBrutalResponse(stream) if err != nil { return err } stream.Close() sendBPS := c.brutal.SendBPS if serverReceiveBPS < sendBPS { sendBPS = serverReceiveBPS } clientBrutalErr := SetBrutalOptions(sessionConn, sendBPS) if clientBrutalErr != nil { c.logger.Debug(E.Cause(clientBrutalErr, "failed to enable TCP Brutal at client")) } return nil } func (c *Client) Reset() { c.access.Lock() defer c.access.Unlock() for _, session := range c.connections.Array() { session.Close() } c.connections.Init() } func (c *Client) Close() error { c.Reset() return nil }