From 6b086ed6bb0790160de73b16683e75efe2220a79 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu>
Date: Thu, 2 Nov 2023 15:24:48 +0800
Subject: [PATCH] Add tcp-brutal support

---
 brutal.go         |  61 +++++++++++++++++++++++++++
 brutal_linux.go   |  57 +++++++++++++++++++++++++
 brutal_stub.go    |  15 +++++++
 client.go         |  53 ++++++++++++++++++++++++
 go.mod            |   8 ++--
 go.sum            |   4 +-
 padding.go        |   4 ++
 server.go         | 103 ++++++++++++++++++++++++++++++++++++++--------
 server_default.go |  36 ----------------
 9 files changed, 280 insertions(+), 61 deletions(-)
 create mode 100644 brutal.go
 create mode 100644 brutal_linux.go
 create mode 100644 brutal_stub.go
 delete mode 100644 server_default.go

diff --git a/brutal.go b/brutal.go
new file mode 100644
index 0000000..93e76b3
--- /dev/null
+++ b/brutal.go
@@ -0,0 +1,61 @@
+package mux
+
+import (
+	"encoding/binary"
+	"io"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/rw"
+)
+
+const (
+	BrutalExchangeDomain = "_BrutalBwExchange"
+	BrutalMinSpeedBPS    = 65536
+)
+
+func WriteBrutalRequest(writer io.Writer, receiveBPS uint64) error {
+	return binary.Write(writer, binary.BigEndian, receiveBPS)
+}
+
+func ReadBrutalRequest(reader io.Reader) (uint64, error) {
+	var receiveBPS uint64
+	err := binary.Read(reader, binary.BigEndian, &receiveBPS)
+	return receiveBPS, err
+}
+
+func WriteBrutalResponse(writer io.Writer, receiveBPS uint64, ok bool, message string) error {
+	buffer := buf.New()
+	defer buffer.Release()
+	common.Must(binary.Write(buffer, binary.BigEndian, ok))
+	if ok {
+		common.Must(binary.Write(buffer, binary.BigEndian, receiveBPS))
+	} else {
+		err := rw.WriteVString(buffer, message)
+		if err != nil {
+			return err
+		}
+	}
+	return common.Error(writer.Write(buffer.Bytes()))
+}
+
+func ReadBrutalResponse(reader io.Reader) (uint64, error) {
+	var ok bool
+	err := binary.Read(reader, binary.BigEndian, &ok)
+	if err != nil {
+		return 0, err
+	}
+	if ok {
+		var receiveBPS uint64
+		err = binary.Read(reader, binary.BigEndian, &receiveBPS)
+		return receiveBPS, err
+	} else {
+		var message string
+		message, err = rw.ReadVString(reader)
+		if err != nil {
+			return 0, err
+		}
+		return 0, E.New("remote error: ", message)
+	}
+}
diff --git a/brutal_linux.go b/brutal_linux.go
new file mode 100644
index 0000000..6a2e770
--- /dev/null
+++ b/brutal_linux.go
@@ -0,0 +1,57 @@
+package mux
+
+import (
+	"net"
+	"os"
+	"reflect"
+	"syscall"
+	"unsafe"
+	_ "unsafe"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/control"
+	E "github.com/sagernet/sing/common/exceptions"
+
+	"golang.org/x/sys/unix"
+)
+
+const (
+	BrutalAvailable   = true
+	TCP_BRUTAL_PARAMS = 23301
+)
+
+type TCPBrutalParams struct {
+	Rate     uint64
+	CwndGain uint32
+}
+
+//go:linkname setsockopt syscall.setsockopt
+func setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) (err error)
+
+func SetBrutalOptions(conn net.Conn, sendBPS uint64) error {
+	syscallConn, loaded := common.Cast[syscall.Conn](conn)
+	if !loaded {
+		return E.New(
+			"brutal: nested multiplexing is not supported: ",
+			"cannot convert ", reflect.TypeOf(conn), " to syscall.Conn, final type: ", reflect.TypeOf(common.Top(conn)),
+		)
+	}
+	return control.Conn(syscallConn, func(fd uintptr) error {
+		err := unix.SetsockoptString(int(fd), unix.IPPROTO_TCP, unix.TCP_CONGESTION, "brutal")
+		if err != nil {
+			return E.Extend(
+				os.NewSyscallError("setsockopt IPPROTO_TCP TCP_CONGESTION brutal", err),
+				"please make sure you have installed the tcp-brutal kernel module",
+			)
+		}
+		params := TCPBrutalParams{
+			Rate:     sendBPS,
+			CwndGain: 20, // hysteria2 default
+		}
+		err = setsockopt(int(fd), unix.IPPROTO_TCP, TCP_BRUTAL_PARAMS, unsafe.Pointer(&params), unsafe.Sizeof(params))
+		if err != nil {
+			return os.NewSyscallError("setsockopt IPPROTO_TCP TCP_BRUTAL_PARAMS", err)
+		}
+		return nil
+	})
+}
diff --git a/brutal_stub.go b/brutal_stub.go
new file mode 100644
index 0000000..67c82da
--- /dev/null
+++ b/brutal_stub.go
@@ -0,0 +1,15 @@
+//go:build !linux
+
+package mux
+
+import (
+	"net"
+
+	E "github.com/sagernet/sing/common/exceptions"
+)
+
+const BrutalAvailable = false
+
+func SetBrutalOptions(conn net.Conn, sendBPS uint64) error {
+	return E.New("TCP Brutal is only supported on Linux")
+}
diff --git a/client.go b/client.go
index c574d95..35459a5 100644
--- a/client.go
+++ b/client.go
@@ -8,6 +8,7 @@ import (
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/logger"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/x/list"
@@ -15,6 +16,7 @@ import (
 
 type Client struct {
 	dialer         N.Dialer
+	logger         logger.Logger
 	protocol       byte
 	maxConnections int
 	minStreams     int
@@ -22,24 +24,35 @@ type Client struct {
 	padding        bool
 	access         sync.Mutex
 	connections    list.List[abstractSession]
+	brutal         BrutalOptions
 }
 
 type Options struct {
 	Dialer         N.Dialer
+	Logger         logger.Logger
 	Protocol       string
 	MaxConnections int
 	MinStreams     int
 	MaxStreams     int
 	Padding        bool
+	Brutal         BrutalOptions
+}
+
+type BrutalOptions struct {
+	Enabled    bool
+	SendBPS    uint64
+	ReceiveBPS uint64
 }
 
 func NewClient(options Options) (*Client, error) {
 	client := &Client{
 		dialer:         options.Dialer,
+		logger:         options.Logger,
 		maxConnections: options.MaxConnections,
 		minStreams:     options.MinStreams,
 		maxStreams:     options.MaxStreams,
 		padding:        options.Padding,
+		brutal:         options.Brutal,
 	}
 	if client.dialer == nil {
 		client.dialer = N.SystemDialer
@@ -126,6 +139,12 @@ func (c *Client) offer(ctx context.Context) (abstractSession, error) {
 		sessions = append(sessions, element.Value)
 		element = element.Next()
 	}
+	if c.brutal.Enabled {
+		if len(sessions) > 0 {
+			return sessions[0], nil
+		}
+		return c.offerNew(ctx)
+	}
 	session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams)
 	if session == nil {
 		return c.offerNew(ctx)
@@ -170,10 +189,44 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
 		conn.Close()
 		return nil, err
 	}
+	if c.brutal.Enabled {
+		err = c.brutalExchange(conn, session)
+		if err != nil {
+			conn.Close()
+			session.Close()
+			return nil, E.Cause(err, "brutal exchange")
+		}
+	}
 	c.connections.PushBack(session)
 	return session, nil
 }
 
+func (c *Client) brutalExchange(sessionConn net.Conn, session abstractSession) error {
+	stream, err := session.Open()
+	if err != nil {
+		return err
+	}
+	conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}}
+	err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS)
+	if err != nil {
+		return err
+	}
+	serverReceiveBPS, err := ReadBrutalResponse(conn)
+	if err != nil {
+		return err
+	}
+	conn.Close()
+	sendBPS := c.brutal.SendBPS
+	if serverReceiveBPS < sendBPS {
+		sendBPS = serverReceiveBPS
+	}
+	clientBrutalErr := SetBrutalOptions(sessionConn, sendBPS)
+	if clientBrutalErr != nil {
+		c.logger.Debug(E.Cause(clientBrutalErr, "failed to enable TCP Brutal at client"))
+	}
+	return nil
+}
+
 func (c *Client) Reset() {
 	c.access.Lock()
 	defer c.access.Unlock()
diff --git a/go.mod b/go.mod
index 97b15c2..4f141fb 100644
--- a/go.mod
+++ b/go.mod
@@ -4,12 +4,10 @@ go 1.18
 
 require (
 	github.com/hashicorp/yamux v0.1.1
-	github.com/sagernet/sing v0.2.17
+	github.com/sagernet/sing v0.2.18-0.20231108041402-4fbbd193203c
 	github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
 	golang.org/x/net v0.18.0
+	golang.org/x/sys v0.14.0
 )
 
-require (
-	golang.org/x/sys v0.14.0 // indirect
-	golang.org/x/text v0.14.0 // indirect
-)
+require golang.org/x/text v0.14.0 // indirect
diff --git a/go.sum b/go.sum
index a336b45..be848a3 100644
--- a/go.sum
+++ b/go.sum
@@ -1,8 +1,8 @@
 github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
 github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
 github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
-github.com/sagernet/sing v0.2.17 h1:vMPKb3MV0Aa5ws4dCJkRI8XEjrsUcDn810czd0FwmzI=
-github.com/sagernet/sing v0.2.17/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
+github.com/sagernet/sing v0.2.18-0.20231108041402-4fbbd193203c h1:uask61Pxc3nGqsOSjqnBKrwfODWRoEa80lXm04LNk0E=
+github.com/sagernet/sing v0.2.18-0.20231108041402-4fbbd193203c/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
 github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=
 github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0=
 golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg=
diff --git a/padding.go b/padding.go
index d902aae..870deb5 100644
--- a/padding.go
+++ b/padding.go
@@ -201,6 +201,10 @@ func (c *paddingConn) FrontHeadroom() int {
 	return 4 + 256 + 1024
 }
 
+func (c *paddingConn) Upstream() any {
+	return c.ExtendedConn
+}
+
 type vectorisedPaddingConn struct {
 	paddingConn
 	writer N.VectorisedWriter
diff --git a/server.go b/server.go
index a805254..b97a97d 100644
--- a/server.go
+++ b/server.go
@@ -5,6 +5,7 @@ import (
 	"net"
 
 	"github.com/sagernet/sing/common/bufio"
+	"github.com/sagernet/sing/common/debug"
 	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/logger"
 	M "github.com/sagernet/sing/common/metadata"
@@ -12,19 +13,49 @@ import (
 	"github.com/sagernet/sing/common/task"
 )
 
-type ServerHandler interface {
+type ServiceHandler interface {
 	N.TCPConnectionHandler
 	N.UDPConnectionHandler
-	E.Handler
 }
 
-func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, conn net.Conn, metadata M.Metadata) error {
+type Service struct {
+	newStreamContext func(context.Context, net.Conn) context.Context
+	logger           logger.ContextLogger
+	handler          ServiceHandler
+	padding          bool
+	brutal           BrutalOptions
+}
+
+type ServiceOptions struct {
+	NewStreamContext func(context.Context, net.Conn) context.Context
+	Logger           logger.ContextLogger
+	Handler          ServiceHandler
+	Padding          bool
+	Brutal           BrutalOptions
+}
+
+func NewService(options ServiceOptions) (*Service, error) {
+	if options.Brutal.Enabled && !BrutalAvailable && !debug.Enabled {
+		return nil, E.New("TCP Brutal is only supported on Linux")
+	}
+	return &Service{
+		newStreamContext: options.NewStreamContext,
+		logger:           options.Logger,
+		handler:          options.Handler,
+		padding:          options.Padding,
+		brutal:           options.Brutal,
+	}, nil
+}
+
+func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
 	request, err := ReadRequest(conn)
 	if err != nil {
 		return err
 	}
 	if request.Padding {
 		conn = newPaddingConn(conn)
+	} else if s.padding {
+		return E.New("non-padded connection rejected")
 	}
 	session, err := newServerSession(conn, request.Protocol)
 	if err != nil {
@@ -38,7 +69,13 @@ func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.
 			if err != nil {
 				return err
 			}
-			go newConnection(ctx, handler, logger, stream, metadata)
+			streamCtx := s.newStreamContext(ctx, stream)
+			go func() {
+				hErr := s.newConnection(streamCtx, conn, stream, metadata)
+				if hErr != nil {
+					s.logger.ErrorContext(streamCtx, E.Cause(hErr, "handle connection"))
+				}
+			}()
 		}
 	})
 	group.Cleanup(func() {
@@ -47,34 +84,64 @@ func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.
 	return group.Run(ctx)
 }
 
-func newConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, stream net.Conn, metadata M.Metadata) {
+func (s *Service) newConnection(ctx context.Context, sessionConn net.Conn, stream net.Conn, metadata M.Metadata) error {
 	stream = &wrapStream{stream}
 	request, err := ReadStreamRequest(stream)
 	if err != nil {
-		logger.ErrorContext(ctx, err)
-		return
+		return E.Cause(err, "read multiplex stream request")
 	}
 	metadata.Destination = request.Destination
 	if request.Network == N.NetworkTCP {
-		logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
-		hErr := handler.NewConnection(ctx, &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata)
-		stream.Close()
-		if hErr != nil {
-			handler.NewError(ctx, hErr)
+		conn := &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}
+		if request.Destination.Fqdn == BrutalExchangeDomain {
+			defer stream.Close()
+			var clientReceiveBPS uint64
+			clientReceiveBPS, err = ReadBrutalRequest(conn)
+			if err != nil {
+				return E.Cause(err, "read brutal request")
+			}
+			if !s.brutal.Enabled {
+				err = WriteBrutalResponse(conn, 0, false, "brutal is not enabled by the server")
+				if err != nil {
+					return E.Cause(err, "write brutal response")
+				}
+				return nil
+			}
+			sendBPS := s.brutal.SendBPS
+			if clientReceiveBPS < sendBPS {
+				sendBPS = clientReceiveBPS
+			}
+			err = SetBrutalOptions(sessionConn, sendBPS)
+			if err != nil {
+				// ignore error in test
+				if !debug.Enabled {
+					err = WriteBrutalResponse(conn, 0, false, E.Cause(err, "enable TCP Brutal").Error())
+					if err != nil {
+						return E.Cause(err, "write brutal response")
+					}
+					return nil
+				}
+			}
+			err = WriteBrutalResponse(conn, s.brutal.ReceiveBPS, true, "")
+			if err != nil {
+				return E.Cause(err, "write brutal response")
+			}
+			return nil
 		}
+		s.logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
+		s.handler.NewConnection(ctx, conn, metadata)
+		stream.Close()
 	} else {
 		var packetConn N.PacketConn
 		if !request.PacketAddr {
-			logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
+			s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
 			packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination}
 		} else {
-			logger.InfoContext(ctx, "inbound multiplex packet connection")
+			s.logger.InfoContext(ctx, "inbound multiplex packet connection")
 			packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)}
 		}
-		hErr := handler.NewPacketConnection(ctx, packetConn, metadata)
+		s.handler.NewPacketConnection(ctx, packetConn, metadata)
 		stream.Close()
-		if hErr != nil {
-			handler.NewError(ctx, hErr)
-		}
 	}
+	return nil
 }
diff --git a/server_default.go b/server_default.go
deleted file mode 100644
index f10247e..0000000
--- a/server_default.go
+++ /dev/null
@@ -1,36 +0,0 @@
-package mux
-
-import (
-	"context"
-	"net"
-
-	"github.com/sagernet/sing/common/bufio"
-	"github.com/sagernet/sing/common/logger"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
-)
-
-func HandleConnectionDefault(ctx context.Context, conn net.Conn) error {
-	return HandleConnection(ctx, (*defaultServerHandler)(nil), logger.NOP(), conn, M.Metadata{})
-}
-
-type defaultServerHandler struct{}
-
-func (h *defaultServerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
-	remoteConn, err := N.SystemDialer.DialContext(ctx, N.NetworkTCP, metadata.Destination)
-	if err != nil {
-		return err
-	}
-	return bufio.CopyConn(ctx, conn, remoteConn)
-}
-
-func (h *defaultServerHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
-	remoteConn, err := N.SystemDialer.ListenPacket(ctx, metadata.Destination)
-	if err != nil {
-		return err
-	}
-	return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(remoteConn))
-}
-
-func (h *defaultServerHandler) NewError(ctx context.Context, err error) {
-}