mirror of
https://github.com/SagerNet/sing-mux.git
synced 2025-04-01 19:17:36 +03:00
246 lines
5.7 KiB
Go
246 lines
5.7 KiB
Go
package mux
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"sync"
|
|
|
|
"github.com/sagernet/sing/common"
|
|
"github.com/sagernet/sing/common/bufio"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
"github.com/sagernet/sing/common/logger"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
"github.com/sagernet/sing/common/x/list"
|
|
)
|
|
|
|
type Client struct {
|
|
dialer N.Dialer
|
|
logger logger.Logger
|
|
protocol byte
|
|
maxConnections int
|
|
minStreams int
|
|
maxStreams int
|
|
padding bool
|
|
access sync.Mutex
|
|
connections list.List[abstractSession]
|
|
brutal BrutalOptions
|
|
}
|
|
|
|
type Options struct {
|
|
Dialer N.Dialer
|
|
Logger logger.Logger
|
|
Protocol string
|
|
MaxConnections int
|
|
MinStreams int
|
|
MaxStreams int
|
|
Padding bool
|
|
Brutal BrutalOptions
|
|
}
|
|
|
|
// BrutalOptions configures the TCP Brutal rate negotiation performed on
// every newly dialed session.
type BrutalOptions struct {
	// Enabled turns on the exchange. While enabled, the client keeps reusing
	// the first live session rather than balancing across several.
	Enabled bool
	// SendBPS is the desired client send rate, capped at the rate the server
	// reports it can receive. ReceiveBPS is announced to the server in the
	// brutal request. NOTE(review): "BPS" presumably means bytes per
	// second; confirm against SetBrutalOptions.
	SendBPS, ReceiveBPS uint64
}
|
|
|
|
func NewClient(options Options) (*Client, error) {
|
|
client := &Client{
|
|
dialer: options.Dialer,
|
|
logger: options.Logger,
|
|
maxConnections: options.MaxConnections,
|
|
minStreams: options.MinStreams,
|
|
maxStreams: options.MaxStreams,
|
|
padding: options.Padding,
|
|
brutal: options.Brutal,
|
|
}
|
|
if client.dialer == nil {
|
|
client.dialer = N.SystemDialer
|
|
}
|
|
if client.maxStreams == 0 && client.maxConnections == 0 {
|
|
client.minStreams = 8
|
|
}
|
|
switch options.Protocol {
|
|
case "", "h2mux":
|
|
client.protocol = ProtocolH2Mux
|
|
case "smux":
|
|
client.protocol = ProtocolSmux
|
|
case "yamux":
|
|
client.protocol = ProtocolYAMux
|
|
default:
|
|
return nil, E.New("unknown protocol: " + options.Protocol)
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
|
switch N.NetworkName(network) {
|
|
case N.NetworkTCP:
|
|
stream, err := c.openStream(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &clientConn{Conn: stream, destination: destination}, nil
|
|
case N.NetworkUDP:
|
|
stream, err := c.openStream(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
extendedConn := bufio.NewExtendedConn(stream)
|
|
return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
|
|
default:
|
|
return nil, E.Extend(N.ErrUnknownNetwork, network)
|
|
}
|
|
}
|
|
|
|
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
|
stream, err := c.openStream(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
extendedConn := bufio.NewExtendedConn(stream)
|
|
return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
|
|
}
|
|
|
|
func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
|
|
var (
|
|
session abstractSession
|
|
stream net.Conn
|
|
err error
|
|
)
|
|
for attempts := 0; attempts < 2; attempts++ {
|
|
session, err = c.offer(ctx)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
stream, err = session.Open()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &wrapStream{stream}, nil
|
|
}
|
|
|
|
func (c *Client) offer(ctx context.Context) (abstractSession, error) {
|
|
c.access.Lock()
|
|
defer c.access.Unlock()
|
|
|
|
var sessions []abstractSession
|
|
for element := c.connections.Front(); element != nil; {
|
|
if element.Value.IsClosed() {
|
|
element.Value.Close()
|
|
nextElement := element.Next()
|
|
c.connections.Remove(element)
|
|
element = nextElement
|
|
continue
|
|
}
|
|
sessions = append(sessions, element.Value)
|
|
element = element.Next()
|
|
}
|
|
if c.brutal.Enabled {
|
|
if len(sessions) > 0 {
|
|
return sessions[0], nil
|
|
}
|
|
return c.offerNew(ctx)
|
|
}
|
|
session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams)
|
|
if session == nil {
|
|
return c.offerNew(ctx)
|
|
}
|
|
numStreams := session.NumStreams()
|
|
if numStreams == 0 {
|
|
return session, nil
|
|
}
|
|
if c.maxConnections > 0 {
|
|
if len(sessions) >= c.maxConnections || numStreams < c.minStreams {
|
|
return session, nil
|
|
}
|
|
} else {
|
|
if c.maxStreams > 0 && numStreams < c.maxStreams {
|
|
return session, nil
|
|
}
|
|
}
|
|
return c.offerNew(ctx)
|
|
}
|
|
|
|
func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, TCPTimeout)
|
|
defer cancel()
|
|
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var version byte
|
|
if c.padding {
|
|
version = Version1
|
|
} else {
|
|
version = Version0
|
|
}
|
|
conn = newProtocolConn(conn, Request{
|
|
Version: version,
|
|
Protocol: c.protocol,
|
|
Padding: c.padding,
|
|
})
|
|
if c.padding {
|
|
conn = newPaddingConn(conn)
|
|
}
|
|
session, err := newClientSession(conn, c.protocol)
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
if c.brutal.Enabled {
|
|
err = c.brutalExchange(ctx, conn, session)
|
|
if err != nil {
|
|
conn.Close()
|
|
session.Close()
|
|
return nil, E.Cause(err, "brutal exchange")
|
|
}
|
|
}
|
|
c.connections.PushBack(session)
|
|
return session, nil
|
|
}
|
|
|
|
func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error {
|
|
stream, err := session.Open()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}}
|
|
err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
serverReceiveBPS, err := ReadBrutalResponse(conn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conn.Close()
|
|
sendBPS := c.brutal.SendBPS
|
|
if serverReceiveBPS < sendBPS {
|
|
sendBPS = serverReceiveBPS
|
|
}
|
|
clientBrutalErr := SetBrutalOptions(sessionConn, sendBPS)
|
|
if clientBrutalErr != nil {
|
|
c.logger.Debug(E.Cause(clientBrutalErr, "failed to enable TCP Brutal at client"))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) Reset() {
|
|
c.access.Lock()
|
|
defer c.access.Unlock()
|
|
for _, session := range c.connections.Array() {
|
|
session.Close()
|
|
}
|
|
c.connections.Init()
|
|
}
|
|
|
|
func (c *Client) Close() error {
|
|
c.Reset()
|
|
return nil
|
|
}
|