package mux

import (
	"context"
	"net"
	"sync"

	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/bufio"
	E "github.com/sagernet/sing/common/exceptions"
	"github.com/sagernet/sing/common/logger"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/x/list"
)

// Client carries outbound TCP streams and UDP packet connections over a pool
// of mux sessions, each session running on one connection opened through
// dialer. The pool (connections) is guarded by access.
type Client struct {
	dialer         N.Dialer
	logger         logger.Logger
	protocol       byte
	maxConnections int
	minStreams     int
	maxStreams     int
	padding        bool
	access         sync.Mutex
	connections    list.List[abstractSession]
	brutal         BrutalOptions
}

// Options configures a Client created by NewClient.
//
// Dialer opens the underlying connections (N.SystemDialer when nil). Logger
// receives debug messages and may be nil. Protocol selects the session
// protocol: "" or "h2mux", "smux", or "yamux". MaxConnections, MinStreams and
// MaxStreams control when an existing session is reused versus a new one
// dialed (see Client.offer). Padding selects protocol Version1 and wraps the
// connection with padding. Brutal configures the TCP Brutal negotiation.
type Options struct {
	Dialer         N.Dialer
	Logger         logger.Logger
	Protocol       string
	MaxConnections int
	MinStreams     int
	MaxStreams     int
	Padding        bool
	Brutal         BrutalOptions
}

// BrutalOptions configures TCP Brutal. When Enabled, every new session
// performs an exchange: ReceiveBPS is sent to the server, and the local send
// rate is SendBPS capped at the receive rate the server reports back.
type BrutalOptions struct {
	Enabled    bool
	SendBPS    uint64
	ReceiveBPS uint64
}

// NewClient builds a Client from options.
//
// A nil Dialer falls back to N.SystemDialer. When neither MaxStreams nor
// MaxConnections is set, minStreams defaults to 8. An unrecognized Protocol
// yields an error.
func NewClient(options Options) (*Client, error) {
	client := &Client{
		dialer:         options.Dialer,
		logger:         options.Logger,
		maxConnections: options.MaxConnections,
		minStreams:     options.MinStreams,
		maxStreams:     options.MaxStreams,
		padding:        options.Padding,
		brutal:         options.Brutal,
	}
	if client.dialer == nil {
		client.dialer = N.SystemDialer
	}
	if client.maxStreams == 0 && client.maxConnections == 0 {
		// NOTE(review): offer only consults minStreams when maxConnections > 0,
		// so this default is not read on the maxConnections == 0 path; confirm
		// whether that is intended.
		client.minStreams = 8
	}
	switch options.Protocol {
	case "", "h2mux":
		client.protocol = ProtocolH2Mux
	case "smux":
		client.protocol = ProtocolSmux
	case "yamux":
		client.protocol = ProtocolYAMux
	default:
		return nil, E.New("unknown protocol: " + options.Protocol)
	}
	return client, nil
}

// DialContext opens a new mux stream to destination. TCP yields a stream
// conn. UDP yields a packet conn bound to destination. Any other network
// returns an error wrapping N.ErrUnknownNetwork.
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	switch N.NetworkName(network) {
	case N.NetworkTCP:
		stream, err := c.openStream(ctx)
		if err != nil {
			return nil, err
		}
		return &clientConn{Conn: stream, destination: destination}, nil
	case N.NetworkUDP:
		stream, err := c.openStream(ctx)
		if err != nil {
			return nil, err
		}
		extendedConn := bufio.NewExtendedConn(stream)
		return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
	default:
		return nil, E.Extend(N.ErrUnknownNetwork, network)
	}
}

// ListenPacket opens a new mux stream and returns it as an address-aware
// packet conn for destination.
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	stream, err := c.openStream(ctx)
	if err != nil {
		return nil, err
	}
	extendedConn := bufio.NewExtendedConn(stream)
	return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
}

// openStream picks (or creates) a session via offer and opens a stream on it.
// It makes up to two attempts and returns the last error if both fail.
func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
	var (
		session abstractSession
		stream  net.Conn
		err     error
	)
	for attempts := 0; attempts < 2; attempts++ {
		session, err = c.offer(ctx)
		if err != nil {
			continue
		}
		stream, err = session.Open()
		if err != nil {
			continue
		}
		break
	}
	if err != nil {
		return nil, err
	}
	return &wrapStream{stream}, nil
}

// offer returns a session to open the next stream on, dialing a new one via
// offerNew when the pool policy requires it. It holds c.access for its whole
// duration, including any call to offerNew.
func (c *Client) offer(ctx context.Context) (abstractSession, error) {
	c.access.Lock()
	defer c.access.Unlock()

	// Prune closed sessions from the pool and collect the live ones.
	var sessions []abstractSession
	for element := c.connections.Front(); element != nil; {
		if element.Value.IsClosed() {
			element.Value.Close()
			nextElement := element.Next()
			c.connections.Remove(element)
			element = nextElement
			continue
		}
		sessions = append(sessions, element.Value)
		element = element.Next()
	}

	// With Brutal enabled, streams go to the first live session.
	if c.brutal.Enabled {
		if len(sessions) > 0 {
			return sessions[0], nil
		}
		return c.offerNew(ctx)
	}

	// Otherwise pick, among sessions that can take a new request, the one
	// with the fewest streams.
	session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams)
	if session == nil {
		return c.offerNew(ctx)
	}
	numStreams := session.NumStreams()
	if numStreams == 0 {
		return session, nil
	}
	if c.maxConnections > 0 {
		// Reuse when the pool is full, or while the session is below minStreams.
		if len(sessions) >= c.maxConnections || numStreams < c.minStreams {
			return session, nil
		}
	} else {
		// Reuse while the session is below maxStreams, if a cap is configured.
		if c.maxStreams > 0 && numStreams < c.maxStreams {
			return session, nil
		}
	}
	return c.offerNew(ctx)
}

// offerNew dials Destination (bounded by TCPTimeout), and wraps the
// connection with the protocol request header and, optionally, padding. It
// then starts a client session, performs the Brutal exchange when enabled,
// and appends the session to the pool. It is called from offer with c.access
// held.
func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
	ctx, cancel := context.WithTimeout(ctx, TCPTimeout)
	defer cancel()
	conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination)
	if err != nil {
		return nil, err
	}
	var version byte
	if c.padding {
		version = Version1
	} else {
		version = Version0
	}
	conn = newProtocolConn(conn, Request{
		Version:  version,
		Protocol: c.protocol,
		Padding:  c.padding,
	})
	if c.padding {
		conn = newPaddingConn(conn)
	}
	session, err := newClientSession(conn, c.protocol)
	if err != nil {
		conn.Close()
		return nil, err
	}
	if c.brutal.Enabled {
		err = c.brutalExchange(ctx, conn, session)
		if err != nil {
			conn.Close()
			session.Close()
			return nil, E.Cause(err, "brutal exchange")
		}
	}
	c.connections.PushBack(session)
	return session, nil
}

// brutalExchange negotiates TCP Brutal on a dedicated stream of session. It
// sends c.brutal.ReceiveBPS and reads the server's receive rate, then enables
// Brutal on sessionConn with a send rate of min(c.brutal.SendBPS, server rate).
// A failure to enable Brutal locally is only logged (when a logger is set).
// Errors from opening the stream or from the request/response are returned.
//
// The exchange stream's deadline is bounded by ctx (best effort): offer holds
// c.access while this runs, so an unresponsive server must not block the
// whole client indefinitely.
func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error {
	stream, err := session.Open()
	if err != nil {
		return err
	}
	if deadline, loaded := ctx.Deadline(); loaded {
		// Best effort: a stream implementation may not support deadlines. The
		// stream is closed below, so the deadline does not outlive the exchange.
		_ = stream.SetDeadline(deadline)
	}
	conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}}
	err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS)
	if err != nil {
		conn.Close()
		return err
	}
	serverReceiveBPS, err := ReadBrutalResponse(conn)
	// The exchange stream is single-use; release it whether or not the read
	// succeeded.
	conn.Close()
	if err != nil {
		return err
	}
	sendBPS := c.brutal.SendBPS
	if serverReceiveBPS < sendBPS {
		sendBPS = serverReceiveBPS
	}
	clientBrutalErr := SetBrutalOptions(sessionConn, sendBPS)
	if clientBrutalErr != nil && c.logger != nil {
		c.logger.Debug(E.Cause(clientBrutalErr, "failed to enable TCP Brutal at client"))
	}
	return nil
}

// Reset closes every pooled session and empties the pool.
func (c *Client) Reset() {
	c.access.Lock()
	defer c.access.Unlock()
	for _, session := range c.connections.Array() {
		session.Close()
	}
	c.connections.Init()
}

// Close resets the client (see Reset) and always returns nil.
func (c *Client) Close() error {
	c.Reset()
	return nil
}