diff --git a/adapter/certificate.go b/adapter/certificate.go
index 0998e130..11eeb196 100644
--- a/adapter/certificate.go
+++ b/adapter/certificate.go
@@ -10,6 +10,9 @@ import (
type CertificateStore interface {
LifecycleService
Pool() *x509.CertPool
+ TLSDecryptionEnabled() bool
+ TLSDecryptionCertificate() *x509.Certificate
+ TLSDecryptionPrivateKey() any
}
func RootPoolFromContext(ctx context.Context) *x509.CertPool {
diff --git a/adapter/inbound.go b/adapter/inbound.go
index 1218c049..869cdec9 100644
--- a/adapter/inbound.go
+++ b/adapter/inbound.go
@@ -2,6 +2,8 @@ package adapter
import (
"context"
+ "crypto/tls"
+ "net/http"
"net/netip"
"time"
@@ -58,6 +60,8 @@ type InboundContext struct {
Client string
SniffContext any
PacketSniffError error
+ HTTPRequest *http.Request
+ ClientHello *tls.ClientHelloInfo
// cache
@@ -74,6 +78,7 @@ type InboundContext struct {
UDPTimeout time.Duration
TLSFragment bool
TLSFragmentFallbackDelay time.Duration
+ MITM *option.MITMRouteOptions
NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType
diff --git a/adapter/lifecycle.go b/adapter/lifecycle.go
index aff9fadb..9e522141 100644
--- a/adapter/lifecycle.go
+++ b/adapter/lifecycle.go
@@ -1,6 +1,8 @@
package adapter
-import E "github.com/sagernet/sing/common/exceptions"
+import (
+ E "github.com/sagernet/sing/common/exceptions"
+)
type StartStage uint8
@@ -45,6 +47,9 @@ type LifecycleService interface {
func Start(stage StartStage, services ...Lifecycle) error {
for _, service := range services {
+ if service == nil {
+ continue
+ }
err := service.Start(stage)
if err != nil {
return err
diff --git a/adapter/mitm.go b/adapter/mitm.go
new file mode 100644
index 00000000..450468a9
--- /dev/null
+++ b/adapter/mitm.go
@@ -0,0 +1,13 @@
+package adapter
+
+import (
+ "context"
+ "net"
+
+ N "github.com/sagernet/sing/common/network"
+)
+
+type MITMEngine interface {
+ Lifecycle
+ NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata InboundContext, onClose N.CloseHandlerFunc)
+}
diff --git a/box.go b/box.go
index 0f176474..664764a6 100644
--- a/box.go
+++ b/box.go
@@ -23,6 +23,7 @@ import (
"github.com/sagernet/sing-box/experimental/cachefile"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/log"
+ "github.com/sagernet/sing-box/mitm"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/protocol/direct"
"github.com/sagernet/sing-box/route"
@@ -48,6 +49,7 @@ type Box struct {
dnsRouter *dns.Router
connection *route.ConnectionManager
router *route.Router
+ mitm adapter.MITMEngine //*mitm.Engine
services []adapter.LifecycleService
done chan struct{}
}
@@ -143,18 +145,12 @@ func New(options Options) (*Box, error) {
}
var services []adapter.LifecycleService
- certificateOptions := common.PtrValueOrDefault(options.Certificate)
- if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem ||
- len(certificateOptions.Certificate) > 0 ||
- len(certificateOptions.CertificatePath) > 0 ||
- len(certificateOptions.CertificateDirectoryPath) > 0 {
- certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), certificateOptions)
- if err != nil {
- return nil, err
- }
- service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
- services = append(services, certificateStore)
+ certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), common.PtrValueOrDefault(options.Certificate))
+ if err != nil {
+ return nil, err
}
+ service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
+ services = append(services, certificateStore)
routeOptions := common.PtrValueOrDefault(options.Route)
dnsOptions := common.PtrValueOrDefault(options.DNS)
@@ -173,7 +169,7 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "initialize network manager")
}
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
- connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
+ connectionManager := route.NewConnectionManager(ctx, logFactory.NewLogger("connection"))
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
service.MustRegister[adapter.Router](ctx, router)
@@ -181,8 +177,8 @@ func New(options Options) (*Box, error) {
if err != nil {
return nil, E.Cause(err, "initialize router")
}
- ntpOptions := common.PtrValueOrDefault(options.NTP)
var timeService *tls.TimeServiceWrapper
+ ntpOptions := common.PtrValueOrDefault(options.NTP)
if ntpOptions.Enabled {
timeService = new(tls.TimeServiceWrapper)
service.MustRegister[ntp.TimeService](ctx, timeService)
@@ -345,6 +341,16 @@ func New(options Options) (*Box, error) {
timeService.TimeService = ntpService
services = append(services, adapter.NewLifecycleService(ntpService, "ntp service"))
}
+ mitmOptions := common.PtrValueOrDefault(options.MITM)
+ var mitmEngine adapter.MITMEngine
+ if mitmOptions.Enabled {
+ engine, err := mitm.NewEngine(ctx, logFactory.NewLogger("mitm"), mitmOptions)
+ if err != nil {
+ return nil, E.Cause(err, "create MITM engine")
+ }
+ service.MustRegister[adapter.MITMEngine](ctx, engine)
+ mitmEngine = engine
+ }
return &Box{
network: networkManager,
endpoint: endpointManager,
@@ -354,6 +360,7 @@ func New(options Options) (*Box, error) {
dnsRouter: dnsRouter,
connection: connectionManager,
router: router,
+ mitm: mitmEngine,
createdAt: createdAt,
logFactory: logFactory,
logger: logFactory.Logger(),
@@ -412,11 +419,11 @@ func (s *Box) preStart() error {
if err != nil {
return err
}
- err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
+ err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.mitm, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
- err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router)
+ err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router, s.mitm)
if err != nil {
return err
}
@@ -440,7 +447,7 @@ func (s *Box) start() error {
if err != nil {
return err
}
- err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint)
+ err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.mitm, s.inbound, s.endpoint)
if err != nil {
return err
}
@@ -448,7 +455,7 @@ func (s *Box) start() error {
if err != nil {
return err
}
- err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
+ err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.mitm, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
@@ -467,7 +474,7 @@ func (s *Box) Close() error {
close(s.done)
}
err := common.Close(
- s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network,
+ s.inbound, s.outbound, s.endpoint, s.mitm, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network,
)
for _, lifecycleService := range s.services {
err = E.Append(err, lifecycleService.Close(), func(err error) error {
diff --git a/cmd/sing-box/cmd_generate_ca.go b/cmd/sing-box/cmd_generate_ca.go
new file mode 100644
index 00000000..fcda2dd6
--- /dev/null
+++ b/cmd/sing-box/cmd_generate_ca.go
@@ -0,0 +1,120 @@
+package main
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/sha1"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/asn1"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/pem"
+ "math/big"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/sagernet/sing-box/log"
+ "github.com/sagernet/sing-box/option"
+ "github.com/sagernet/sing/common/json"
+
+ "github.com/spf13/cobra"
+ "software.sslmate.com/src/go-pkcs12"
+)
+
+var (
+ flagGenerateCAName string
+ flagGenerateCAPKCS12Password string
+ flagGenerateOutput string
+)
+
+var commandGenerateCAKeyPair = &cobra.Command{
+ Use: "ca-keypair",
+ Short: "Generate CA key pair",
+ Args: cobra.NoArgs,
+ Run: func(cmd *cobra.Command, args []string) {
+ err := generateCAKeyPair()
+ if err != nil {
+ log.Fatal(err)
+ }
+ },
+}
+
+func init() {
+ commandGenerateCAKeyPair.Flags().StringVarP(&flagGenerateCAName, "name", "n", "", "Set custom CA name")
+ commandGenerateCAKeyPair.Flags().StringVarP(&flagGenerateCAPKCS12Password, "p12-password", "p", "", "Set custom PKCS12 password")
+ commandGenerateCAKeyPair.Flags().StringVarP(&flagGenerateOutput, "output", "o", ".", "Set output directory")
+ commandGenerate.AddCommand(commandGenerateCAKeyPair)
+}
+
+func generateCAKeyPair() error {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return err
+ }
+ serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+ serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+ if err != nil {
+ return err
+ }
+ spkiASN1, err := x509.MarshalPKIXPublicKey(privateKey.Public())
+ var spki struct {
+ Algorithm pkix.AlgorithmIdentifier
+ SubjectPublicKey asn1.BitString
+ }
+ _, err = asn1.Unmarshal(spkiASN1, &spki)
+ if err != nil {
+ return err
+ }
+ skid := sha1.Sum(spki.SubjectPublicKey.Bytes)
+ var caName string
+ if flagGenerateCAName != "" {
+ caName = flagGenerateCAName
+ } else {
+ caName = "sing-box Generated CA " + strings.ToUpper(hex.EncodeToString(skid[:4]))
+ }
+ caTpl := &x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{
+ Organization: []string{caName},
+ CommonName: caName,
+ },
+ SubjectKeyId: skid[:],
+ NotAfter: time.Now().AddDate(10, 0, 0),
+ NotBefore: time.Now(),
+ KeyUsage: x509.KeyUsageCertSign,
+ BasicConstraintsValid: true,
+ IsCA: true,
+ MaxPathLenZero: true,
+ }
+ publicDer, err := x509.CreateCertificate(rand.Reader, caTpl, caTpl, privateKey.Public(), privateKey)
+ var caPassword string
+ if flagGenerateCAPKCS12Password != "" {
+ caPassword = flagGenerateCAPKCS12Password
+ } else {
+ caPassword = strings.ToUpper(hex.EncodeToString(skid[:4]))
+ }
+ caTpl.Raw = publicDer
+ p12Bytes, err := pkcs12.Modern.Encode(privateKey, caTpl, nil, caPassword)
+ if err != nil {
+ return err
+ }
+ privateDer, err := x509.MarshalPKCS8PrivateKey(privateKey)
+ if err != nil {
+ return err
+ }
+ os.WriteFile(filepath.Join(flagGenerateOutput, caName+".pem"), []byte(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicDer}))+string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateDer}))), 0o644)
+ os.WriteFile(filepath.Join(flagGenerateOutput, caName+".crt"), publicDer, 0o644)
+ os.WriteFile(filepath.Join(flagGenerateOutput, caName+".p12"), p12Bytes, 0o644)
+ var tlsDecryptionOptions option.TLSDecryptionOptions
+ tlsDecryptionOptions.Enabled = true
+ tlsDecryptionOptions.KeyPair = base64.StdEncoding.EncodeToString(p12Bytes)
+ tlsDecryptionOptions.KeyPairPassword = caPassword
+ var certificateOptions option.CertificateOptions
+ certificateOptions.TLSDecryption = &tlsDecryptionOptions
+ encoder := json.NewEncoder(os.Stdout)
+ encoder.SetIndent("", " ")
+ return encoder.Encode(certificateOptions)
+}
diff --git a/cmd/sing-box/cmd_tools.go b/cmd/sing-box/cmd_tools.go
index 55e5b458..5f5a0d71 100644
--- a/cmd/sing-box/cmd_tools.go
+++ b/cmd/sing-box/cmd_tools.go
@@ -1,13 +1,6 @@
package main
import (
- "errors"
- "os"
-
- "github.com/sagernet/sing-box"
- E "github.com/sagernet/sing/common/exceptions"
- N "github.com/sagernet/sing/common/network"
-
"github.com/spf13/cobra"
)
@@ -19,36 +12,5 @@ var commandTools = &cobra.Command{
}
func init() {
- commandTools.PersistentFlags().StringVarP(&commandToolsFlagOutbound, "outbound", "o", "", "Use specified tag instead of default outbound")
mainCommand.AddCommand(commandTools)
}
-
-func createPreStartedClient() (*box.Box, error) {
- options, err := readConfigAndMerge()
- if err != nil {
- if !(errors.Is(err, os.ErrNotExist) && len(configDirectories) == 0 && len(configPaths) == 1) || configPaths[0] != "config.json" {
- return nil, err
- }
- }
- instance, err := box.New(box.Options{Context: globalCtx, Options: options})
- if err != nil {
- return nil, E.Cause(err, "create service")
- }
- err = instance.PreStart()
- if err != nil {
- return nil, E.Cause(err, "start service")
- }
- return instance, nil
-}
-
-func createDialer(instance *box.Box, outboundTag string) (N.Dialer, error) {
- if outboundTag == "" {
- return instance.Outbound().Default(), nil
- } else {
- outbound, loaded := instance.Outbound().Outbound(outboundTag)
- if !loaded {
- return nil, E.New("outbound not found: ", outboundTag)
- }
- return outbound, nil
- }
-}
diff --git a/cmd/sing-box/cmd_tools_connect.go b/cmd/sing-box/cmd_tools_connect.go
deleted file mode 100644
index d352d533..00000000
--- a/cmd/sing-box/cmd_tools_connect.go
+++ /dev/null
@@ -1,73 +0,0 @@
-package main
-
-import (
- "context"
- "os"
-
- "github.com/sagernet/sing-box/log"
- "github.com/sagernet/sing/common"
- "github.com/sagernet/sing/common/bufio"
- E "github.com/sagernet/sing/common/exceptions"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "github.com/sagernet/sing/common/task"
-
- "github.com/spf13/cobra"
-)
-
-var commandConnectFlagNetwork string
-
-var commandConnect = &cobra.Command{
- Use: "connect
",
- Short: "Connect to an address",
- Args: cobra.ExactArgs(1),
- Run: func(cmd *cobra.Command, args []string) {
- err := connect(args[0])
- if err != nil {
- log.Fatal(err)
- }
- },
-}
-
-func init() {
- commandConnect.Flags().StringVarP(&commandConnectFlagNetwork, "network", "n", "tcp", "network type")
- commandTools.AddCommand(commandConnect)
-}
-
-func connect(address string) error {
- switch N.NetworkName(commandConnectFlagNetwork) {
- case N.NetworkTCP, N.NetworkUDP:
- default:
- return E.Cause(N.ErrUnknownNetwork, commandConnectFlagNetwork)
- }
- instance, err := createPreStartedClient()
- if err != nil {
- return err
- }
- defer instance.Close()
- dialer, err := createDialer(instance, commandToolsFlagOutbound)
- if err != nil {
- return err
- }
- conn, err := dialer.DialContext(context.Background(), commandConnectFlagNetwork, M.ParseSocksaddr(address))
- if err != nil {
- return E.Cause(err, "connect to server")
- }
- var group task.Group
- group.Append("upload", func(ctx context.Context) error {
- return common.Error(bufio.Copy(conn, os.Stdin))
- })
- group.Append("download", func(ctx context.Context) error {
- return common.Error(bufio.Copy(os.Stdout, conn))
- })
- group.Cleanup(func() {
- conn.Close()
- })
- err = group.Run(context.Background())
- if E.IsClosed(err) {
- log.Info(err)
- } else {
- log.Error(err)
- }
- return nil
-}
diff --git a/cmd/sing-box/cmd_tools_fetch.go b/cmd/sing-box/cmd_tools_fetch.go
deleted file mode 100644
index 5ee3b875..00000000
--- a/cmd/sing-box/cmd_tools_fetch.go
+++ /dev/null
@@ -1,115 +0,0 @@
-package main
-
-import (
- "context"
- "errors"
- "io"
- "net"
- "net/http"
- "net/url"
- "os"
-
- C "github.com/sagernet/sing-box/constant"
- "github.com/sagernet/sing-box/log"
- "github.com/sagernet/sing/common/bufio"
- E "github.com/sagernet/sing/common/exceptions"
- M "github.com/sagernet/sing/common/metadata"
-
- "github.com/spf13/cobra"
-)
-
-var commandFetch = &cobra.Command{
- Use: "fetch",
- Short: "Fetch an URL",
- Args: cobra.MinimumNArgs(1),
- Run: func(cmd *cobra.Command, args []string) {
- err := fetch(args)
- if err != nil {
- log.Fatal(err)
- }
- },
-}
-
-func init() {
- commandTools.AddCommand(commandFetch)
-}
-
-var (
- httpClient *http.Client
- http3Client *http.Client
-)
-
-func fetch(args []string) error {
- instance, err := createPreStartedClient()
- if err != nil {
- return err
- }
- defer instance.Close()
- httpClient = &http.Client{
- Transport: &http.Transport{
- DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- dialer, err := createDialer(instance, commandToolsFlagOutbound)
- if err != nil {
- return nil, err
- }
- return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
- },
- ForceAttemptHTTP2: true,
- },
- }
- defer httpClient.CloseIdleConnections()
- if C.WithQUIC {
- err = initializeHTTP3Client(instance)
- if err != nil {
- return err
- }
- defer http3Client.CloseIdleConnections()
- }
- for _, urlString := range args {
- var parsedURL *url.URL
- parsedURL, err = url.Parse(urlString)
- if err != nil {
- return err
- }
- switch parsedURL.Scheme {
- case "":
- parsedURL.Scheme = "http"
- fallthrough
- case "http", "https":
- err = fetchHTTP(httpClient, parsedURL)
- if err != nil {
- return err
- }
- case "http3":
- if !C.WithQUIC {
- return C.ErrQUICNotIncluded
- }
- parsedURL.Scheme = "https"
- err = fetchHTTP(http3Client, parsedURL)
- if err != nil {
- return err
- }
- default:
- return E.New("unsupported scheme: ", parsedURL.Scheme)
- }
- }
- return nil
-}
-
-func fetchHTTP(httpClient *http.Client, parsedURL *url.URL) error {
- request, err := http.NewRequest("GET", parsedURL.String(), nil)
- if err != nil {
- return err
- }
- request.Header.Add("User-Agent", "curl/7.88.0")
- response, err := httpClient.Do(request)
- if err != nil {
- return err
- }
- defer response.Body.Close()
- _, err = bufio.Copy(os.Stdout, response.Body)
- if errors.Is(err, io.EOF) {
- return nil
- }
- return err
-}
diff --git a/cmd/sing-box/cmd_tools_fetch_http3.go b/cmd/sing-box/cmd_tools_fetch_http3.go
deleted file mode 100644
index b7a31a72..00000000
--- a/cmd/sing-box/cmd_tools_fetch_http3.go
+++ /dev/null
@@ -1,36 +0,0 @@
-//go:build with_quic
-
-package main
-
-import (
- "context"
- "crypto/tls"
- "net/http"
-
- "github.com/sagernet/quic-go"
- "github.com/sagernet/quic-go/http3"
- box "github.com/sagernet/sing-box"
- "github.com/sagernet/sing/common/bufio"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
-)
-
-func initializeHTTP3Client(instance *box.Box) error {
- dialer, err := createDialer(instance, commandToolsFlagOutbound)
- if err != nil {
- return err
- }
- http3Client = &http.Client{
- Transport: &http3.Transport{
- Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
- destination := M.ParseSocksaddr(addr)
- udpConn, dErr := dialer.DialContext(ctx, N.NetworkUDP, destination)
- if dErr != nil {
- return nil, dErr
- }
- return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsCfg, cfg)
- },
- },
- }
- return nil
-}
diff --git a/cmd/sing-box/cmd_tools_fetch_http3_stub.go b/cmd/sing-box/cmd_tools_fetch_http3_stub.go
deleted file mode 100644
index ae13f54c..00000000
--- a/cmd/sing-box/cmd_tools_fetch_http3_stub.go
+++ /dev/null
@@ -1,18 +0,0 @@
-//go:build !with_quic
-
-package main
-
-import (
- "net/url"
- "os"
-
- box "github.com/sagernet/sing-box"
-)
-
-func initializeHTTP3Client(instance *box.Box) error {
- return os.ErrInvalid
-}
-
-func fetchHTTP3(parsedURL *url.URL) error {
- return os.ErrInvalid
-}
diff --git a/cmd/sing-box/cmd_tools_install_ca.go b/cmd/sing-box/cmd_tools_install_ca.go
new file mode 100644
index 00000000..80ec14b1
--- /dev/null
+++ b/cmd/sing-box/cmd_tools_install_ca.go
@@ -0,0 +1,108 @@
+package main
+
+import (
+ "encoding/pem"
+ "errors"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+
+ "github.com/sagernet/sing-box/log"
+ E "github.com/sagernet/sing/common/exceptions"
+ "github.com/sagernet/sing/common/shell"
+
+ "github.com/spf13/cobra"
+)
+
+var commandInstallCACertificate = &cobra.Command{
+ Use: "install-ca ",
+ Short: "Install CA certificate to system",
+ Args: cobra.ExactArgs(1),
+ Run: func(cmd *cobra.Command, args []string) {
+ err := installCACertificate(args[0])
+ if err != nil {
+ log.Fatal(err)
+ }
+ },
+}
+
+func init() {
+ commandTools.AddCommand(commandInstallCACertificate)
+}
+
+func installCACertificate(path string) error {
+ switch runtime.GOOS {
+ case "windows":
+ return shell.Exec("powershell", "-Command", "Import-Certificate -FilePath \""+path+"\" -CertStoreLocation Cert:\\LocalMachine\\Root").Attach().Run()
+ case "darwin":
+ return shell.Exec("sudo", "security", "add-trusted-cert", "-d", "-r", "trustRoot", "-k", "/Library/Keychains/System.keychain", path).Attach().Run()
+ case "linux":
+ updateCertPath, updateCertPathNotFoundErr := exec.LookPath("update-ca-certificates")
+ if updateCertPathNotFoundErr == nil {
+ publicDer, err := os.ReadFile(path)
+ if err != nil {
+ return err
+ }
+ err = os.MkdirAll("/usr/local/share/ca-certificates", 0o755)
+ if err != nil {
+ if errors.Is(err, os.ErrPermission) {
+ log.Info("Try running with sudo")
+ return shell.Exec("sudo", os.Args...).Attach().Run()
+ }
+ return err
+ }
+ fileName := filepath.Base(updateCertPath)
+ if !strings.HasSuffix(fileName, ".crt") {
+ fileName = fileName + ".crt"
+ }
+ filePath, _ := filepath.Abs(filepath.Join("/usr/local/share/ca-certificates", fileName))
+ err = os.WriteFile(filePath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicDer}), 0o644)
+ if err != nil {
+ if errors.Is(err, os.ErrPermission) {
+ log.Info("Try running with sudo")
+ return shell.Exec("sudo", os.Args...).Attach().Run()
+ }
+ return err
+ }
+ log.Info("certificate written to " + filePath + "\n")
+ err = shell.Exec(updateCertPath).Attach().Run()
+ if err != nil {
+ return err
+ }
+ log.Info("certificate installed")
+ return nil
+ }
+ updateTrustPath, updateTrustPathNotFoundErr := exec.LookPath("update-ca-trust")
+ if updateTrustPathNotFoundErr == nil {
+ publicDer, err := os.ReadFile(path)
+ if err != nil {
+ return err
+ }
+ fileName := filepath.Base(updateTrustPath)
+ fileExt := filepath.Ext(path)
+ if fileExt != "" {
+ fileName = fileName[:len(fileName)-len(fileExt)]
+ }
+ filePath, _ := filepath.Abs(filepath.Join("/etc/pki/ca-trust/source/anchors/", fileName+".pem"))
+ err = os.WriteFile(filePath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicDer}), 0o644)
+ if err != nil {
+ if errors.Is(err, os.ErrPermission) {
+ log.Info("Try running with sudo")
+ return shell.Exec("sudo", os.Args...).Attach().Run()
+ }
+ return err
+ }
+ log.Info("certificate written to " + filePath + "\n")
+ err = shell.Exec(updateTrustPath, "extract").Attach().Run()
+ if err != nil {
+ return err
+ }
+ log.Info("certificate installed")
+ }
+ return E.New("update-ca-certificates or update-ca-trust not found")
+ default:
+ return E.New("unsupported operating system: ", runtime.GOOS)
+ }
+}
diff --git a/cmd/sing-box/cmd_tools_synctime.go b/cmd/sing-box/cmd_tools_synctime.go
index 09d487ef..33e38743 100644
--- a/cmd/sing-box/cmd_tools_synctime.go
+++ b/cmd/sing-box/cmd_tools_synctime.go
@@ -8,6 +8,7 @@ import (
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
+ N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/spf13/cobra"
@@ -39,20 +40,11 @@ func init() {
}
func syncTime() error {
- instance, err := createPreStartedClient()
- if err != nil {
- return err
- }
- dialer, err := createDialer(instance, commandToolsFlagOutbound)
- if err != nil {
- return err
- }
- defer instance.Close()
serverAddress := M.ParseSocksaddr(commandSyncTimeFlagServer)
if serverAddress.Port == 0 {
serverAddress.Port = 123
}
- response, err := ntp.Exchange(context.Background(), dialer, serverAddress)
+ response, err := ntp.Exchange(context.Background(), N.SystemDialer, serverAddress)
if err != nil {
return err
}
diff --git a/common/certificate/store.go b/common/certificate/store.go
index 34f20019..d17f77b5 100644
--- a/common/certificate/store.go
+++ b/common/certificate/store.go
@@ -3,6 +3,7 @@ package certificate
import (
"context"
"crypto/x509"
+ "encoding/base64"
"io/fs"
"os"
"path/filepath"
@@ -16,6 +17,8 @@ import (
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
+
+ "software.sslmate.com/src/go-pkcs12"
)
var _ adapter.CertificateStore = (*Store)(nil)
@@ -27,6 +30,9 @@ type Store struct {
certificatePaths []string
certificateDirectoryPaths []string
watcher *fswatch.Watcher
+ tlsDecryptionEnabled bool
+ tlsDecryptionPrivateKey any
+ tlsDecryptionCertificate *x509.Certificate
}
func NewStore(ctx context.Context, logger logger.Logger, options option.CertificateOptions) (*Store, error) {
@@ -90,6 +96,19 @@ func NewStore(ctx context.Context, logger logger.Logger, options option.Certific
if err != nil {
return nil, E.Cause(err, "initializing certificate store")
}
+ if options.TLSDecryption != nil && options.TLSDecryption.Enabled {
+ pfxBytes, err := base64.StdEncoding.DecodeString(options.TLSDecryption.KeyPair)
+ if err != nil {
+ return nil, E.Cause(err, "decode key pair base64 bytes")
+ }
+ privateKey, certificate, err := pkcs12.Decode(pfxBytes, options.TLSDecryption.KeyPairPassword)
+ if err != nil {
+ return nil, E.Cause(err, "decode key pair")
+ }
+ store.tlsDecryptionEnabled = true
+ store.tlsDecryptionPrivateKey = privateKey
+ store.tlsDecryptionCertificate = certificate
+ }
return store, nil
}
@@ -183,3 +202,15 @@ func isSameDirSymlink(f fs.DirEntry, dir string) bool {
target, err := os.Readlink(filepath.Join(dir, f.Name()))
return err == nil && !strings.Contains(target, "/")
}
+
+func (s *Store) TLSDecryptionEnabled() bool {
+ return s.tlsDecryptionEnabled
+}
+
+func (s *Store) TLSDecryptionCertificate() *x509.Certificate {
+ return s.tlsDecryptionCertificate
+}
+
+func (s *Store) TLSDecryptionPrivateKey() any {
+ return s.tlsDecryptionPrivateKey
+}
diff --git a/common/sniff/http.go b/common/sniff/http.go
index 0e6ab406..e7c6eb8c 100644
--- a/common/sniff/http.go
+++ b/common/sniff/http.go
@@ -18,5 +18,6 @@ func HTTPHost(_ context.Context, metadata *adapter.InboundContext, reader io.Rea
}
metadata.Protocol = C.ProtocolHTTP
metadata.Domain = M.ParseSocksaddr(request.Host).AddrString()
+ metadata.HTTPRequest = request
return nil
}
diff --git a/common/sniff/tls.go b/common/sniff/tls.go
index 6fe430e2..27729fa2 100644
--- a/common/sniff/tls.go
+++ b/common/sniff/tls.go
@@ -21,6 +21,7 @@ func TLSClientHello(ctx context.Context, metadata *adapter.InboundContext, reade
if clientHello != nil {
metadata.Protocol = C.ProtocolTLS
metadata.Domain = clientHello.ServerName
+ metadata.ClientHello = clientHello
return nil
}
return err
diff --git a/common/tls/mkcert.go b/common/tls/mkcert.go
index 4e0ed102..fc9c4ab9 100644
--- a/common/tls/mkcert.go
+++ b/common/tls/mkcert.go
@@ -8,7 +8,10 @@ import (
"crypto/x509/pkix"
"encoding/pem"
"math/big"
+ "net"
"time"
+
+ M "github.com/sagernet/sing/common/metadata"
)
func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string) (*tls.Certificate, error) {
@@ -35,17 +38,30 @@ func GenerateCertificate(parent *x509.Certificate, parentKey any, timeFunc func(
if err != nil {
return
}
- template := &x509.Certificate{
- SerialNumber: serialNumber,
- NotBefore: timeFunc().Add(time.Hour * -1),
- NotAfter: expire,
- KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
- ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
- BasicConstraintsValid: true,
- Subject: pkix.Name{
- CommonName: serverName,
- },
- DNSNames: []string{serverName},
+ var template *x509.Certificate
+ if serverAddress := M.ParseAddr(serverName); serverAddress.IsValid() {
+ template = &x509.Certificate{
+ SerialNumber: serialNumber,
+ IPAddresses: []net.IP{serverAddress.AsSlice()},
+ NotBefore: timeFunc().Add(time.Hour * -1),
+ NotAfter: expire,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ }
+ } else {
+ template = &x509.Certificate{
+ SerialNumber: serialNumber,
+ NotBefore: timeFunc().Add(time.Hour * -1),
+ NotAfter: expire,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ Subject: pkix.Name{
+ CommonName: serverName,
+ },
+ DNSNames: []string{serverName},
+ }
}
if parent == nil {
parent = template
diff --git a/experimental/clashapi/mitm.go b/experimental/clashapi/mitm.go
new file mode 100644
index 00000000..8d64081d
--- /dev/null
+++ b/experimental/clashapi/mitm.go
@@ -0,0 +1,186 @@
+package clashapi
+
+import (
+ "archive/zip"
+ "context"
+ "crypto/x509"
+ "encoding/pem"
+ "io"
+ "net/http"
+
+ "github.com/sagernet/sing-box/adapter"
+ "github.com/sagernet/sing/common"
+ "github.com/sagernet/sing/service"
+
+ "github.com/go-chi/chi/v5"
+ "github.com/go-chi/render"
+ "github.com/gofrs/uuid/v5"
+ "howett.net/plist"
+)
+
+func mitmRouter(ctx context.Context) http.Handler {
+ r := chi.NewRouter()
+ r.Get("/mobileconfig", getMobileConfig(ctx))
+ r.Get("/crt", getCertificate(ctx))
+ r.Get("/pem", getCertificatePEM(ctx))
+ r.Get("/magisk", getMagiskModule(ctx))
+ return r
+}
+
+func getMobileConfig(ctx context.Context) http.HandlerFunc {
+ return func(writer http.ResponseWriter, request *http.Request) {
+ store := service.FromContext[adapter.CertificateStore](ctx)
+ if !store.TLSDecryptionEnabled() {
+ http.NotFound(writer, request)
+ render.PlainText(writer, request, "TLS decryption not enabled")
+ return
+ }
+ certificate := store.TLSDecryptionCertificate()
+ writer.Header().Set("Content-Type", "application/x-apple-aspen-config")
+ uuidGen := common.Must1(uuid.NewV4()).String()
+ mobileConfig := map[string]interface{}{
+ "PayloadContent": []interface{}{
+ map[string]interface{}{
+ "PayloadCertificateFileName": "Certificates.cer",
+ "PayloadContent": certificate.Raw,
+ "PayloadDescription": "Adds a root certificate",
+ "PayloadDisplayName": certificate.Subject.CommonName,
+ "PayloadIdentifier": "com.apple.security.root." + uuidGen,
+ "PayloadType": "com.apple.security.root",
+ "PayloadUUID": uuidGen,
+ "PayloadVersion": 1,
+ },
+ },
+ "PayloadDisplayName": certificate.Subject.CommonName,
+ "PayloadIdentifier": "io.nekohasekai.sfa.ca.profile." + uuidGen,
+ "PayloadRemovalDisallowed": false,
+ "PayloadType": "Configuration",
+ "PayloadUUID": uuidGen,
+ "PayloadVersion": 1,
+ }
+ encoder := plist.NewEncoder(writer)
+ encoder.Indent("\t")
+ encoder.Encode(mobileConfig)
+ }
+}
+
+func getCertificate(ctx context.Context) http.HandlerFunc {
+ return func(writer http.ResponseWriter, request *http.Request) {
+ store := service.FromContext[adapter.CertificateStore](ctx)
+ if !store.TLSDecryptionEnabled() {
+ http.NotFound(writer, request)
+ render.PlainText(writer, request, "TLS decryption not enabled")
+ return
+ }
+ writer.Header().Set("Content-Type", "application/x-x509-ca-cert")
+ writer.Header().Set("Content-Disposition", "attachment; filename=Certificate.crt")
+ writer.Write(store.TLSDecryptionCertificate().Raw)
+ }
+}
+
+func getCertificatePEM(ctx context.Context) http.HandlerFunc {
+ return func(writer http.ResponseWriter, request *http.Request) {
+ store := service.FromContext[adapter.CertificateStore](ctx)
+ if !store.TLSDecryptionEnabled() {
+ http.NotFound(writer, request)
+ render.PlainText(writer, request, "TLS decryption not enabled")
+ return
+ }
+ writer.Header().Set("Content-Type", "application/x-pem-file")
+ writer.Header().Set("Content-Disposition", "attachment; filename=Certificate.pem")
+ writer.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: store.TLSDecryptionCertificate().Raw}))
+ }
+}
+
+func getMagiskModule(ctx context.Context) http.HandlerFunc {
+ return func(writer http.ResponseWriter, request *http.Request) {
+ store := service.FromContext[adapter.CertificateStore](ctx)
+ if !store.TLSDecryptionEnabled() {
+ http.NotFound(writer, request)
+ render.PlainText(writer, request, "TLS decryption not enabled")
+ return
+ }
+ writer.Header().Set("Content-Type", "application/zip")
+ writer.Header().Set("Content-Disposition", "attachment; filename="+store.TLSDecryptionCertificate().Subject.CommonName+".zip")
+ createMagiskModule(writer, store.TLSDecryptionCertificate())
+ }
+}
+
+func createMagiskModule(writer io.Writer, certificate *x509.Certificate) error {
+ zipWriter := zip.NewWriter(writer)
+ defer zipWriter.Close()
+ moduleProp, err := zipWriter.Create("module.prop")
+ if err != nil {
+ return err
+ }
+ _, err = moduleProp.Write([]byte(`
+id=sing-box-certificate
+name=` + certificate.Subject.CommonName + `
+version=v0.0.1
+versionCode=1
+author=sing-box
+description=This module adds ` + certificate.Subject.CommonName + ` to the system trust store.
+`))
+ if err != nil {
+ return err
+ }
+ certificateFile, err := zipWriter.Create("system/etc/security/cacerts/" + certificate.Subject.CommonName + ".pem")
+ if err != nil {
+ return err
+ }
+ err = pem.Encode(certificateFile, &pem.Block{Type: "CERTIFICATE", Bytes: certificate.Raw})
+ if err != nil {
+ return err
+ }
+ updateBinary, err := zipWriter.Create("META-INF/com/google/android/update-binary")
+ if err != nil {
+ return err
+ }
+ _, err = updateBinary.Write([]byte(`
+#!/sbin/sh
+
+#################
+# Initialization
+#################
+
+umask 022
+
+# echo before loading util_functions
+ui_print() { echo "$1"; }
+
+require_new_magisk() {
+ ui_print "*******************************"
+ ui_print " Please install Magisk v20.4+! "
+ ui_print "*******************************"
+ exit 1
+}
+
+#########################
+# Load util_functions.sh
+#########################
+
+OUTFD=$2
+ZIPFILE=$3
+
+mount /data 2>/dev/null
+
+[ -f /data/adb/magisk/util_functions.sh ] || require_new_magisk
+. /data/adb/magisk/util_functions.sh
+[ $MAGISK_VER_CODE -lt 20400 ] && require_new_magisk
+
+install_module
+exit 0
+`))
+ if err != nil {
+ return err
+ }
+ updaterScript, err := zipWriter.Create("META-INF/com/google/android/updater-script")
+ if err != nil {
+ return err
+ }
+ _, err = updaterScript.Write([]byte("#MAGISK"))
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go
index e6d8c4cf..e471a13f 100644
--- a/experimental/clashapi/server.go
+++ b/experimental/clashapi/server.go
@@ -124,6 +124,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op
r.Mount("/profile", profileRouter())
r.Mount("/cache", cacheRouter(ctx))
r.Mount("/dns", dnsRouter(s.dnsRouter))
+ r.Mount("/mitm", mitmRouter(ctx))
s.setupMetaAPI(r)
})
diff --git a/go.mod b/go.mod
index 9a970b46..d27b976d 100644
--- a/go.mod
+++ b/go.mod
@@ -53,6 +53,7 @@ require (
google.golang.org/grpc v1.70.0
google.golang.org/protobuf v1.36.5
howett.net/plist v1.0.1
+ software.sslmate.com/src/go-pkcs12 v0.4.0
)
//replace github.com/sagernet/sing => ../sing
diff --git a/log/log.go b/log/log.go
index 7b8f2843..6ee79677 100644
--- a/log/log.go
+++ b/log/log.go
@@ -10,6 +10,10 @@ import (
E "github.com/sagernet/sing/common/exceptions"
)
+const (
+ DefaultTimeFormat = "-0700 2006-01-02 15:04:05"
+)
+
type Options struct {
Context context.Context
Options option.LogOptions
@@ -47,7 +51,7 @@ func New(options Options) (Factory, error) {
DisableColors: logOptions.DisableColor || logFilePath != "",
DisableTimestamp: !logOptions.Timestamp && logFilePath != "",
FullTimestamp: logOptions.Timestamp,
- TimestampFormat: "-0700 2006-01-02 15:04:05",
+ TimestampFormat: DefaultTimeFormat,
}
factory := NewDefaultFactory(
options.Context,
diff --git a/mitm/constants.go b/mitm/constants.go
new file mode 100644
index 00000000..e3a9a46b
--- /dev/null
+++ b/mitm/constants.go
@@ -0,0 +1,11 @@
+package mitm
+
+import (
+ "encoding/base64"
+
+ "github.com/sagernet/sing/common"
+)
+
+var surgeTinyGif = common.OnceValue(func() []byte {
+ return common.Must1(base64.StdEncoding.DecodeString("R0lGODlhAQABAAAAACH5BAEAAAAALAAAAAABAAEAAAIBAAA="))
+})
diff --git a/mitm/engine.go b/mitm/engine.go
new file mode 100644
index 00000000..a9acbc47
--- /dev/null
+++ b/mitm/engine.go
@@ -0,0 +1,811 @@
+package mitm
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/tls"
+ "io"
+ "math"
+ "mime"
+ "net"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+ "unicode"
+
+ "github.com/sagernet/sing-box/adapter"
+ "github.com/sagernet/sing-box/common/dialer"
+ sTLS "github.com/sagernet/sing-box/common/tls"
+ "github.com/sagernet/sing-box/option"
+ "github.com/sagernet/sing/common"
+ "github.com/sagernet/sing/common/atomic"
+ E "github.com/sagernet/sing/common/exceptions"
+ F "github.com/sagernet/sing/common/format"
+ "github.com/sagernet/sing/common/logger"
+ M "github.com/sagernet/sing/common/metadata"
+ N "github.com/sagernet/sing/common/network"
+ "github.com/sagernet/sing/common/ntp"
+ sHTTP "github.com/sagernet/sing/protocol/http"
+ "github.com/sagernet/sing/service"
+
+ "golang.org/x/net/http2"
+)
+
+var _ adapter.MITMEngine = (*Engine)(nil)
+
+type Engine struct {
+ ctx context.Context
+ logger logger.ContextLogger
+ connection adapter.ConnectionManager
+ certificate adapter.CertificateStore
+ timeFunc func() time.Time
+ http2Enabled bool
+}
+
+func NewEngine(ctx context.Context, logger logger.ContextLogger, options option.MITMOptions) (*Engine, error) {
+ engine := &Engine{
+ ctx: ctx,
+ logger: logger,
+ http2Enabled: options.HTTP2Enabled,
+ }
+ return engine, nil
+}
+
+func (e *Engine) Start(stage adapter.StartStage) error {
+ switch stage {
+ case adapter.StartStateInitialize:
+ e.connection = service.FromContext[adapter.ConnectionManager](e.ctx)
+ e.certificate = service.FromContext[adapter.CertificateStore](e.ctx)
+ e.timeFunc = ntp.TimeFuncFromContext(e.ctx)
+ if e.timeFunc == nil {
+ e.timeFunc = time.Now
+ }
+ }
+ return nil
+}
+
+func (e *Engine) Close() error {
+ return nil
+}
+
+func (e *Engine) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
+ if e.certificate.TLSDecryptionEnabled() && metadata.ClientHello != nil {
+ err := e.newTLS(ctx, this, conn, metadata, onClose)
+ if err != nil {
+ e.logger.ErrorContext(ctx, err)
+ } else {
+ e.logger.DebugContext(ctx, "connection closed")
+ }
+ if onClose != nil {
+ onClose(err)
+ }
+ return
+ } else if metadata.HTTPRequest != nil {
+ err := e.newHTTP1(ctx, this, conn, nil, metadata)
+ if err != nil {
+ e.logger.ErrorContext(ctx, err)
+ } else {
+ e.logger.DebugContext(ctx, "connection closed")
+ }
+ if onClose != nil {
+ onClose(err)
+ }
+ return
+ } else {
+ e.logger.DebugContext(ctx, "HTTP and TLS not detected, skipped")
+ }
+ metadata.MITM = nil
+ e.connection.NewConnection(ctx, this, conn, metadata, onClose)
+}
+
+func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
+ acceptHTTP := len(metadata.ClientHello.SupportedProtos) == 0 || common.Contains(metadata.ClientHello.SupportedProtos, "http/1.1")
+ acceptH2 := e.http2Enabled && common.Contains(metadata.ClientHello.SupportedProtos, "h2")
+ if !acceptHTTP && !acceptH2 {
+ metadata.MITM = nil
+ e.logger.DebugContext(ctx, "unsupported application protocol: ", strings.Join(metadata.ClientHello.SupportedProtos, ","))
+ e.connection.NewConnection(ctx, this, conn, metadata, onClose)
+ return nil
+ }
+ var nextProtos []string
+ if acceptH2 {
+ nextProtos = append(nextProtos, "h2")
+ } else if acceptHTTP {
+ nextProtos = append(nextProtos, "http/1.1")
+ }
+ var (
+ maxVersion uint16
+ minVersion uint16
+ )
+ for _, version := range metadata.ClientHello.SupportedVersions {
+ maxVersion = common.Max(maxVersion, version)
+ minVersion = common.Min(minVersion, version)
+ }
+ serverName := metadata.ClientHello.ServerName
+ if serverName == "" && metadata.Destination.IsIP() {
+ serverName = metadata.Destination.Addr.String()
+ }
+ tlsConfig := &tls.Config{
+ Time: e.timeFunc,
+ ServerName: serverName,
+ NextProtos: nextProtos,
+ MinVersion: minVersion,
+ MaxVersion: maxVersion,
+ GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return sTLS.GenerateKeyPair(e.certificate.TLSDecryptionCertificate(), e.certificate.TLSDecryptionPrivateKey(), e.timeFunc, serverName)
+ },
+ }
+ tlsConn := tls.Server(conn, tlsConfig)
+ err := tlsConn.HandshakeContext(ctx)
+ if err != nil {
+ return E.Cause(err, "TLS handshake failed for ", metadata.ClientHello.ServerName, ", ", strings.Join(metadata.ClientHello.SupportedProtos, ", "))
+ }
+ if tlsConn.ConnectionState().NegotiatedProtocol == "h2" {
+ return e.newHTTP2(ctx, this, tlsConn, tlsConfig, metadata, onClose)
+ } else {
+ return e.newHTTP1(ctx, this, tlsConn, tlsConfig, metadata)
+ }
+}
+
+func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext) error {
+ options := metadata.MITM
+ defer conn.Close()
+ reader := bufio.NewReader(conn)
+ request, err := sHTTP.ReadRequest(reader)
+ if err != nil {
+ return E.Cause(err, "read HTTP request")
+ }
+ rawRequestURL := request.URL
+ if tlsConfig != nil {
+ rawRequestURL.Scheme = "https"
+ } else {
+ rawRequestURL.Scheme = "http"
+ }
+ if rawRequestURL.Host == "" {
+ rawRequestURL.Host = request.Host
+ }
+ requestURL := rawRequestURL.String()
+ request.RequestURI = ""
+ var requestMatch bool
+ var body []byte
+ if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 {
+ body, err = io.ReadAll(request.Body)
+ if err != nil {
+ return E.Cause(err, "read HTTP request body")
+ }
+ request.Body = io.NopCloser(bytes.NewReader(body))
+ }
+ if options.Print {
+ e.printRequest(ctx, request, body)
+ }
+ for i, rule := range options.SurgeURLRewrite {
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String())
+ if rule.Reject {
+ return E.New("request rejected by url_rewrite")
+ } else if rule.Redirect {
+ w := new(simpleResponseWriter)
+ http.Redirect(w, request, rule.Destination.String(), http.StatusFound)
+ err = w.Build(request).Write(conn)
+ if err != nil {
+ return E.Cause(err, "write url_rewrite 302 response")
+ }
+ return nil
+ }
+ requestMatch = true
+ request.URL = rule.Destination
+ newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port())
+ if newDestination.Port == 0 {
+ newDestination.Port = metadata.Destination.Port
+ }
+ metadata.Destination = newDestination
+ if tlsConfig != nil {
+ tlsConfig.ServerName = rule.Destination.Hostname()
+ }
+ break
+ }
+ for i, rule := range options.SurgeHeaderRewrite {
+ if rule.Response {
+ continue
+ }
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ requestMatch = true
+ e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
+ switch {
+ case rule.Add:
+ if strings.ToLower(rule.Key) == "host" {
+ request.Host = rule.Value
+ continue
+ }
+ request.Header.Add(rule.Key, rule.Value)
+ case rule.Delete:
+ request.Header.Del(rule.Key)
+ case rule.Replace:
+ if request.Header.Get(rule.Key) != "" {
+ request.Header.Set(rule.Key, rule.Value)
+ }
+ case rule.ReplaceRegex:
+ if value := request.Header.Get(rule.Key); value != "" {
+ request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
+ }
+ }
+ }
+ for i, rule := range options.SurgeBodyRewrite {
+ if rule.Response {
+ continue
+ }
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ requestMatch = true
+ e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
+ if body == nil {
+ if request.ContentLength <= 0 {
+ e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
+ break
+ } else if request.ContentLength > 131072 {
+ e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
+ break
+ }
+ body, err = io.ReadAll(request.Body)
+ if err != nil {
+ return E.Cause(err, "read HTTP request body")
+ }
+ }
+ for mi := 0; i < len(rule.Match); i++ {
+ body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
+ }
+ request.Body = io.NopCloser(bytes.NewReader(body))
+ request.ContentLength = int64(len(body))
+ }
+ if !requestMatch {
+ for i, rule := range options.SurgeMapLocal {
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ requestMatch = true
+ e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String())
+ var (
+ statusCode = http.StatusOK
+ headers = make(http.Header)
+ )
+ if rule.StatusCode > 0 {
+ statusCode = rule.StatusCode
+ }
+ switch {
+ case rule.File:
+ resource, err := os.ReadFile(rule.Data)
+ if err != nil {
+ return E.Cause(err, "open map local source")
+ }
+ mimeType := mime.TypeByExtension(filepath.Ext(rule.Data))
+ if mimeType == "" {
+ mimeType = "application/octet-stream"
+ }
+ headers.Set("Content-Type", mimeType)
+ body = resource
+ case rule.Text:
+ headers.Set("Content-Type", "text/plain")
+ body = []byte(rule.Data)
+ case rule.TinyGif:
+ headers.Set("Content-Type", "image/gif")
+ body = surgeTinyGif()
+ case rule.Base64:
+ headers.Set("Content-Type", "application/octet-stream")
+ body = rule.Base64Data
+ }
+ response := &http.Response{
+ StatusCode: statusCode,
+ Status: http.StatusText(statusCode),
+ Proto: request.Proto,
+ ProtoMajor: request.ProtoMajor,
+ ProtoMinor: request.ProtoMinor,
+ Header: headers,
+ Body: io.NopCloser(bytes.NewReader(body)),
+ }
+ err = response.Write(conn)
+ if err != nil {
+ return E.Cause(err, "write map local response")
+ }
+ return nil
+ }
+ }
+ ctx = adapter.WithContext(ctx, &metadata)
+ var innerErr atomic.TypedValue[error]
+ httpClient := &http.Client{
+ Transport: &http.Transport{
+ DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
+ if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
+ return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
+ } else {
+ return this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
+ }
+ },
+ TLSClientConfig: tlsConfig,
+ },
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+ defer httpClient.CloseIdleConnections()
+ requestCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ response, err := httpClient.Do(request.WithContext(requestCtx))
+ if err != nil {
+ cancel()
+ return E.Errors(innerErr.Load(), err)
+ }
+ var responseMatch bool
+ var responseBody []byte
+ if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 {
+ responseBody, err = io.ReadAll(response.Body)
+ if err != nil {
+ return E.Cause(err, "read HTTP response body")
+ }
+ response.Body = io.NopCloser(bytes.NewReader(responseBody))
+ }
+ if options.Print {
+ e.printResponse(ctx, request, response, responseBody)
+ }
+ for i, rule := range options.SurgeHeaderRewrite {
+ if !rule.Response {
+ continue
+ }
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ responseMatch = true
+ e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
+ switch {
+ case rule.Add:
+ response.Header.Add(rule.Key, rule.Value)
+ case rule.Delete:
+ response.Header.Del(rule.Key)
+ case rule.Replace:
+ if response.Header.Get(rule.Key) != "" {
+ response.Header.Set(rule.Key, rule.Value)
+ }
+ case rule.ReplaceRegex:
+ if value := response.Header.Get(rule.Key); value != "" {
+ response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
+ }
+ }
+ }
+ for i, rule := range options.SurgeBodyRewrite {
+ if !rule.Response {
+ continue
+ }
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ responseMatch = true
+ e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
+ if responseBody == nil {
+ if response.ContentLength <= 0 {
+ e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
+ break
+ } else if response.ContentLength > 131072 {
+ e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
+ break
+ }
+ responseBody, err = io.ReadAll(response.Body)
+ if err != nil {
+ return E.Cause(err, "read HTTP request body")
+ }
+ }
+ for mi := 0; i < len(rule.Match); i++ {
+ responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i]))
+ }
+ response.Body = io.NopCloser(bytes.NewReader(responseBody))
+ response.ContentLength = int64(len(responseBody))
+ }
+ if !options.Print && !requestMatch && !responseMatch {
+ e.logger.WarnContext(ctx, "request not modified")
+ }
+ err = response.Write(conn)
+ if err != nil {
+ return E.Errors(E.Cause(err, "write HTTP response"), innerErr.Load())
+ } else if innerErr.Load() != nil {
+ return E.Cause(innerErr.Load(), "write HTTP response")
+ }
+ return nil
+}
+
+func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
+ httpTransport := &http.Transport{
+ ForceAttemptHTTP2: true,
+ DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
+ ctx = adapter.WithContext(ctx, &metadata)
+ if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
+ return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
+ } else {
+ return this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
+ }
+ },
+ TLSClientConfig: tlsConfig,
+ }
+ err := http2.ConfigureTransport(httpTransport)
+ if err != nil {
+ return E.Cause(err, "configure HTTP/2 transport")
+ }
+ handler := &engineHandler{
+ Engine: e,
+ conn: conn,
+ tlsConfig: tlsConfig,
+ dialer: this,
+ metadata: metadata,
+ httpClient: &http.Client{
+ Transport: httpTransport,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ },
+ onClose: onClose,
+ }
+ http2Server := &http2.Server{
+ MaxReadFrameSize: math.MaxUint32,
+ }
+ http2Server.ServeConn(conn, &http2.ServeConnOpts{
+ Context: ctx,
+ Handler: handler,
+ })
+ return nil
+}
+
+type engineHandler struct {
+ *Engine
+ conn net.Conn
+ tlsConfig *tls.Config
+ dialer N.Dialer
+ metadata adapter.InboundContext
+ onClose N.CloseHandlerFunc
+ httpClient *http.Client
+}
+
+func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
+ err := e.serveHTTP(request.Context(), writer, request)
+ if err != nil {
+ if E.IsClosedOrCanceled(err) {
+ e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed"))
+ } else {
+ e.logger.ErrorContext(request.Context(), err)
+ }
+ }
+}
+
+func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWriter, request *http.Request) error {
+ options := e.metadata.MITM
+ rawRequestURL := request.URL
+ rawRequestURL.Scheme = "https"
+ if rawRequestURL.Host == "" {
+ rawRequestURL.Host = request.Host
+ }
+ requestURL := rawRequestURL.String()
+ request.RequestURI = ""
+ var requestMatch bool
+ var (
+ body []byte
+ err error
+ )
+ if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 {
+ body, err = io.ReadAll(request.Body)
+ if err != nil {
+ return E.Cause(err, "read HTTP request body")
+ }
+ request.Body.Close()
+ request.Body = io.NopCloser(bytes.NewReader(body))
+ }
+ if options.Print {
+ e.printRequest(ctx, request, body)
+ }
+ for i, rule := range options.SurgeURLRewrite {
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String())
+ if rule.Reject {
+ return E.New("request rejected by url_rewrite")
+ } else if rule.Redirect {
+ http.Redirect(writer, request, rule.Destination.String(), http.StatusFound)
+ return nil
+ }
+ requestMatch = true
+ request.URL = rule.Destination
+ newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port())
+ if newDestination.Port == 0 {
+ newDestination.Port = e.metadata.Destination.Port
+ }
+ e.metadata.Destination = newDestination
+ e.tlsConfig.ServerName = rule.Destination.Hostname()
+ break
+ }
+ for i, rule := range options.SurgeHeaderRewrite {
+ if rule.Response {
+ continue
+ }
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ requestMatch = true
+ e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
+ switch {
+ case rule.Add:
+ if strings.ToLower(rule.Key) == "host" {
+ request.Host = rule.Value
+ continue
+ }
+ request.Header.Add(rule.Key, rule.Value)
+ case rule.Delete:
+ request.Header.Del(rule.Key)
+ case rule.Replace:
+ if request.Header.Get(rule.Key) != "" {
+ request.Header.Set(rule.Key, rule.Value)
+ }
+ case rule.ReplaceRegex:
+ if value := request.Header.Get(rule.Key); value != "" {
+ request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
+ }
+ }
+ }
+ for i, rule := range options.SurgeBodyRewrite {
+ if rule.Response {
+ continue
+ }
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ requestMatch = true
+ e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
+ var body []byte
+ if request.ContentLength <= 0 {
+ e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
+ break
+ } else if request.ContentLength > 131072 {
+ e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
+ break
+ }
+ body, err := io.ReadAll(request.Body)
+ if err != nil {
+ return E.Cause(err, "read HTTP request body")
+ }
+ request.Body.Close()
+ for mi := 0; i < len(rule.Match); i++ {
+ body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
+ }
+ request.Body = io.NopCloser(bytes.NewReader(body))
+ request.ContentLength = int64(len(body))
+ }
+ if !requestMatch {
+ for i, rule := range options.SurgeMapLocal {
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ requestMatch = true
+ e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String())
+ go func() {
+ io.Copy(io.Discard, request.Body)
+ request.Body.Close()
+ }()
+ var (
+ statusCode = http.StatusOK
+ headers = make(http.Header)
+ body []byte
+ )
+ if rule.StatusCode > 0 {
+ statusCode = rule.StatusCode
+ }
+ switch {
+ case rule.File:
+ resource, err := os.ReadFile(rule.Data)
+ if err != nil {
+ return E.Cause(err, "open map local source")
+ }
+ mimeType := mime.TypeByExtension(filepath.Ext(rule.Data))
+ if mimeType == "" {
+ mimeType = "application/octet-stream"
+ }
+ headers.Set("Content-Type", mimeType)
+ body = resource
+ case rule.Text:
+ headers.Set("Content-Type", "text/plain")
+ body = []byte(rule.Data)
+ case rule.TinyGif:
+ headers.Set("Content-Type", "image/gif")
+ body = surgeTinyGif()
+ case rule.Base64:
+ headers.Set("Content-Type", "application/octet-stream")
+ body = rule.Base64Data
+ }
+ for key, values := range headers {
+ writer.Header()[key] = values
+ }
+ writer.WriteHeader(statusCode)
+ _, err = writer.Write(body)
+ if err != nil {
+ return E.Cause(err, "write map local response")
+ }
+ return nil
+ }
+ }
+ requestCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ response, err := e.httpClient.Do(request.WithContext(requestCtx))
+ if err != nil {
+ cancel()
+ return E.Cause(err, "exchange request")
+ }
+ var responseMatch bool
+ var responseBody []byte
+ if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 {
+ responseBody, err = io.ReadAll(response.Body)
+ if err != nil {
+ return E.Cause(err, "read HTTP response body")
+ }
+ response.Body.Close()
+ response.Body = io.NopCloser(bytes.NewReader(responseBody))
+ }
+ if options.Print {
+ e.printResponse(ctx, request, response, responseBody)
+ }
+ for i, rule := range options.SurgeHeaderRewrite {
+ if !rule.Response {
+ continue
+ }
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ responseMatch = true
+ e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
+ switch {
+ case rule.Add:
+ response.Header.Add(rule.Key, rule.Value)
+ case rule.Delete:
+ response.Header.Del(rule.Key)
+ case rule.Replace:
+ if response.Header.Get(rule.Key) != "" {
+ response.Header.Set(rule.Key, rule.Value)
+ }
+ case rule.ReplaceRegex:
+ if value := response.Header.Get(rule.Key); value != "" {
+ response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
+ }
+ }
+ }
+ for i, rule := range options.SurgeBodyRewrite {
+ if !rule.Response {
+ continue
+ }
+ if !rule.Pattern.MatchString(requestURL) {
+ continue
+ }
+ responseMatch = true
+ e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
+ if responseBody == nil {
+ if response.ContentLength <= 0 {
+ e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
+ break
+ } else if response.ContentLength > 131072 {
+ e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
+ break
+ }
+ responseBody, err = io.ReadAll(response.Body)
+ if err != nil {
+ return E.Cause(err, "read HTTP request body")
+ }
+ response.Body.Close()
+ }
+ for mi := 0; i < len(rule.Match); i++ {
+ responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i]))
+ }
+ response.Body = io.NopCloser(bytes.NewReader(responseBody))
+ response.ContentLength = int64(len(responseBody))
+ }
+ if !options.Print && !requestMatch && !responseMatch {
+ e.logger.WarnContext(ctx, "request not modified")
+ }
+ for key, values := range response.Header {
+ writer.Header()[key] = values
+ }
+ writer.WriteHeader(response.StatusCode)
+ _, err = io.Copy(writer, response.Body)
+ response.Body.Close()
+ if err != nil {
+ return E.Cause(err, "write HTTP response")
+ }
+ return nil
+}
+
+func (e *Engine) printRequest(ctx context.Context, request *http.Request, body []byte) {
+ var builder strings.Builder
+ builder.WriteString(F.ToString(request.Proto, " ", request.Method, " ", request.URL))
+ builder.WriteString("\n")
+ if request.URL.Hostname() != "" && request.URL.Hostname() != request.Host {
+ builder.WriteString("Host: ")
+ builder.WriteString(request.Host)
+ builder.WriteString("\n")
+ }
+ for key, values := range request.Header {
+ for _, value := range values {
+ builder.WriteString(key)
+ builder.WriteString(": ")
+ builder.WriteString(value)
+ builder.WriteString("\n")
+ }
+ }
+ if len(body) > 0 {
+ builder.WriteString("\n")
+ if !bytes.ContainsFunc(body, func(r rune) bool {
+ return !unicode.IsPrint(r) && !unicode.IsSpace(r)
+ }) {
+ builder.Write(body)
+ } else {
+ builder.WriteString("(body not printable)")
+ }
+ }
+ e.logger.InfoContext(ctx, "request: ", builder.String())
+}
+
+func (e *Engine) printResponse(ctx context.Context, request *http.Request, response *http.Response, body []byte) {
+ var builder strings.Builder
+ builder.WriteString(F.ToString(response.Proto, " ", response.Status, " ", request.URL))
+ builder.WriteString("\n")
+ for key, values := range response.Header {
+ for _, value := range values {
+ builder.WriteString(key)
+ builder.WriteString(": ")
+ builder.WriteString(value)
+ builder.WriteString("\n")
+ }
+ }
+ if len(body) > 0 {
+ builder.WriteString("\n")
+ if !bytes.ContainsFunc(body, func(r rune) bool {
+ return !unicode.IsPrint(r) && !unicode.IsSpace(r)
+ }) {
+ builder.Write(body)
+ } else {
+ builder.WriteString("(body not printable)")
+ }
+ }
+ e.logger.InfoContext(ctx, "response: ", builder.String())
+}
+
+type simpleResponseWriter struct {
+ statusCode int
+ header http.Header
+ body bytes.Buffer
+}
+
+func (w *simpleResponseWriter) Build(request *http.Request) *http.Response {
+ return &http.Response{
+ StatusCode: w.statusCode,
+ Status: http.StatusText(w.statusCode),
+ Proto: request.Proto,
+ ProtoMajor: request.ProtoMajor,
+ ProtoMinor: request.ProtoMinor,
+ Header: w.header,
+ Body: io.NopCloser(&w.body),
+ }
+}
+
+func (w *simpleResponseWriter) Header() http.Header {
+ if w.header == nil {
+ w.header = make(http.Header)
+ }
+ return w.header
+}
+
+func (w *simpleResponseWriter) Write(b []byte) (int, error) {
+ return w.body.Write(b)
+}
+
+func (w *simpleResponseWriter) WriteHeader(statusCode int) {
+ w.statusCode = statusCode
+}
diff --git a/option/certificate.go b/option/certificate.go
index ab524b99..2206a787 100644
--- a/option/certificate.go
+++ b/option/certificate.go
@@ -11,6 +11,13 @@ type _CertificateOptions struct {
Certificate badoption.Listable[string] `json:"certificate,omitempty"`
CertificatePath badoption.Listable[string] `json:"certificate_path,omitempty"`
CertificateDirectoryPath badoption.Listable[string] `json:"certificate_directory_path,omitempty"`
+ TLSDecryption *TLSDecryptionOptions `json:"tls_decryption,omitempty"`
+}
+
+type TLSDecryptionOptions struct {
+ Enabled bool `json:"enabled,omitempty"`
+ KeyPair string `json:"key_pair_p12,omitempty"`
+ KeyPairPassword string `json:"key_pair_p12_password,omitempty"`
}
type CertificateOptions _CertificateOptions
diff --git a/option/mitm.go b/option/mitm.go
new file mode 100644
index 00000000..171c1057
--- /dev/null
+++ b/option/mitm.go
@@ -0,0 +1,19 @@
+package option
+
+import (
+ "github.com/sagernet/sing/common/json/badoption"
+)
+
+type MITMOptions struct {
+ Enabled bool `json:"enabled,omitempty"`
+ HTTP2Enabled bool `json:"http2_enabled,omitempty"`
+}
+
+type MITMRouteOptions struct {
+ Enabled bool `json:"enabled,omitempty"`
+ Print bool `json:"print,omitempty"`
+ SurgeURLRewrite badoption.Listable[SurgeURLRewriteLine] `json:"surge_url_rewrite,omitempty"`
+ SurgeHeaderRewrite badoption.Listable[SurgeHeaderRewriteLine] `json:"surge_header_rewrite,omitempty"`
+ SurgeBodyRewrite badoption.Listable[SurgeBodyRewriteLine] `json:"surge_body_rewrite,omitempty"`
+ SurgeMapLocal badoption.Listable[SurgeMapLocalLine] `json:"surge_map_local,omitempty"`
+}
diff --git a/option/mitm_surge_urlrewrite.go b/option/mitm_surge_urlrewrite.go
new file mode 100644
index 00000000..ffec1917
--- /dev/null
+++ b/option/mitm_surge_urlrewrite.go
@@ -0,0 +1,449 @@
+package option
+
+import (
+ "encoding/base64"
+ "net/http"
+ "net/url"
+ "regexp"
+ "strconv"
+ "strings"
+ "unicode"
+
+ "github.com/sagernet/sing/common"
+ E "github.com/sagernet/sing/common/exceptions"
+ F "github.com/sagernet/sing/common/format"
+ "github.com/sagernet/sing/common/json"
+)
+
+type SurgeURLRewriteLine struct {
+ Pattern *regexp.Regexp
+ Destination *url.URL
+ Redirect bool
+ Reject bool
+}
+
+func (l SurgeURLRewriteLine) String() string {
+ var fields []string
+ fields = append(fields, l.Pattern.String())
+ if l.Reject {
+ fields = append(fields, "_")
+ } else {
+ fields = append(fields, l.Destination.String())
+ }
+ switch {
+ case l.Redirect:
+ fields = append(fields, "302")
+ case l.Reject:
+ fields = append(fields, "reject")
+ default:
+ fields = append(fields, "header")
+ }
+ return encodeSurgeKeys(fields)
+}
+
+func (l SurgeURLRewriteLine) MarshalJSON() ([]byte, error) {
+ return json.Marshal(l.String())
+}
+
+func (l *SurgeURLRewriteLine) UnmarshalJSON(bytes []byte) error {
+ var stringValue string
+ err := json.Unmarshal(bytes, &stringValue)
+ if err != nil {
+ return err
+ }
+ fields, err := surgeFields(stringValue)
+ if err != nil {
+ return E.Cause(err, "invalid surge_url_rewrite line: ", stringValue)
+ } else if len(fields) < 2 || len(fields) > 3 {
+ return E.New("invalid surge_url_rewrite line: ", stringValue)
+ }
+ pattern, err := regexp.Compile(fields[0].Key)
+ if err != nil {
+ return E.Cause(err, "invalid surge_url_rewrite line: invalid pattern: ", stringValue)
+ }
+ l.Pattern = pattern
+ l.Destination, err = url.Parse(fields[1].Key)
+ if err != nil {
+ return E.Cause(err, "invalid surge_url_rewrite line: invalid destination: ", stringValue)
+ }
+ if len(fields) == 3 {
+ switch fields[2].Key {
+ case "header":
+ case "302":
+ l.Redirect = true
+ case "reject":
+ l.Reject = true
+ default:
+ return E.New("invalid surge_url_rewrite line: invalid action: ", stringValue)
+ }
+ }
+ return nil
+}
+
+type SurgeHeaderRewriteLine struct {
+ Response bool
+ Pattern *regexp.Regexp
+ Add bool
+ Delete bool
+ Replace bool
+ ReplaceRegex bool
+ Key string
+ Match *regexp.Regexp
+ Value string
+}
+
+func (l SurgeHeaderRewriteLine) String() string {
+ var fields []string
+ if !l.Response {
+ fields = append(fields, "http-request")
+ } else {
+ fields = append(fields, "http-response")
+ }
+ fields = append(fields, l.Pattern.String())
+ if l.Add {
+ fields = append(fields, "header-add")
+ } else if l.Delete {
+ fields = append(fields, "header-del")
+ } else if l.Replace {
+ fields = append(fields, "header-replace")
+ } else if l.ReplaceRegex {
+ fields = append(fields, "header-replace-regex")
+ }
+ fields = append(fields, l.Key)
+ if l.Add || l.Replace {
+ fields = append(fields, l.Value)
+ } else if l.ReplaceRegex {
+ fields = append(fields, l.Match.String(), l.Value)
+ }
+ return encodeSurgeKeys(fields)
+}
+
+func (l SurgeHeaderRewriteLine) MarshalJSON() ([]byte, error) {
+ return json.Marshal(l.String())
+}
+
+func (l *SurgeHeaderRewriteLine) UnmarshalJSON(bytes []byte) error {
+ var stringValue string
+ err := json.Unmarshal(bytes, &stringValue)
+ if err != nil {
+ return err
+ }
+ fields, err := surgeFields(stringValue)
+ if err != nil {
+ return E.Cause(err, "invalid surge_header_rewrite line: ", stringValue)
+ } else if len(fields) < 4 {
+ return E.New("invalid surge_header_rewrite line: ", stringValue)
+ }
+ switch fields[0].Key {
+ case "http-request":
+ case "http-response":
+ l.Response = true
+ default:
+ return E.New("invalid surge_header_rewrite line: invalid type: ", stringValue)
+ }
+ l.Pattern, err = regexp.Compile(fields[1].Key)
+ if err != nil {
+ return E.Cause(err, "invalid surge_header_rewrite line: invalid pattern: ", stringValue)
+ }
+ switch fields[2].Key {
+ case "header-add":
+ l.Add = true
+ if len(fields) != 5 {
+ return E.New("invalid surge_header_rewrite line: " + stringValue)
+ }
+ l.Key = fields[3].Key
+ l.Value = fields[4].Key
+ case "header-del":
+ l.Delete = true
+ l.Key = fields[3].Key
+ case "header-replace":
+ l.Replace = true
+ if len(fields) != 5 {
+ return E.New("invalid surge_header_rewrite line: " + stringValue)
+ }
+ l.Key = fields[3].Key
+ l.Value = fields[4].Key
+ case "header-replace-regex":
+ l.ReplaceRegex = true
+ if len(fields) != 6 {
+ return E.New("invalid surge_header_rewrite line: " + stringValue)
+ }
+ l.Key = fields[3].Key
+ l.Match, err = regexp.Compile(fields[4].Key)
+ if err != nil {
+ return E.Cause(err, "invalid surge_header_rewrite line: invalid match: ", stringValue)
+ }
+ l.Value = fields[5].Key
+ default:
+ return E.New("invalid surge_header_rewrite line: invalid action: ", stringValue)
+ }
+ return nil
+}
+
+type SurgeBodyRewriteLine struct {
+ Response bool
+ Pattern *regexp.Regexp
+ Match []*regexp.Regexp
+ Replace []string
+}
+
+func (l SurgeBodyRewriteLine) String() string {
+ var fields []string
+ if !l.Response {
+ fields = append(fields, "http-request")
+ } else {
+ fields = append(fields, "http-response")
+ }
+ for i := 0; i < len(l.Match); i += 2 {
+ fields = append(fields, l.Match[i].String(), l.Replace[i])
+ }
+ return strings.Join(fields, " ")
+}
+
+func (l SurgeBodyRewriteLine) MarshalJSON() ([]byte, error) {
+ return json.Marshal(l.String())
+}
+
+func (l *SurgeBodyRewriteLine) UnmarshalJSON(bytes []byte) error {
+ var stringValue string
+ err := json.Unmarshal(bytes, &stringValue)
+ if err != nil {
+ return err
+ }
+ fields, err := surgeFields(stringValue)
+ if err != nil {
+ return E.Cause(err, "invalid surge_body_rewrite line: ", stringValue)
+ } else if len(fields) < 4 {
+ return E.New("invalid surge_body_rewrite line: ", stringValue)
+ } else if len(fields)%2 != 0 {
+ return E.New("invalid surge_body_rewrite line: ", stringValue)
+ }
+ switch fields[0].Key {
+ case "http-request":
+ case "http-response":
+ l.Response = true
+ default:
+ return E.New("invalid surge_body_rewrite line: invalid type: ", stringValue)
+ }
+ l.Pattern, err = regexp.Compile(fields[1].Key)
+ for i := 2; i < len(fields); i += 2 {
+ var match *regexp.Regexp
+ match, err = regexp.Compile(fields[i].Key)
+ if err != nil {
+ return E.Cause(err, "invalid surge_body_rewrite line: invalid match: ", stringValue)
+ }
+ l.Match = append(l.Match, match)
+ l.Replace = append(l.Replace, fields[i+1].Key)
+ }
+ return nil
+}
+
+type SurgeMapLocalLine struct {
+ Pattern *regexp.Regexp
+ StatusCode int
+ File bool
+ Text bool
+ TinyGif bool
+ Base64 bool
+ Data string
+ Base64Data []byte
+ Headers http.Header
+}
+
+func (l SurgeMapLocalLine) String() string {
+ var fields []surgeField
+ fields = append(fields, surgeField{Key: l.Pattern.String()})
+ if l.File {
+ fields = append(fields, surgeField{Key: "data-type", Value: "file"})
+ fields = append(fields, surgeField{Key: "data", Value: l.Data})
+ } else if l.Text {
+ fields = append(fields, surgeField{Key: "data-type", Value: "text"})
+ fields = append(fields, surgeField{Key: "data", Value: l.Data})
+ } else if l.TinyGif {
+ fields = append(fields, surgeField{Key: "data-type", Value: "tiny-gif"})
+ } else if l.Base64 {
+ fields = append(fields, surgeField{Key: "data-type", Value: "base64"})
+ fields = append(fields, surgeField{Key: "data-type", Value: base64.StdEncoding.EncodeToString(l.Base64Data)})
+ }
+ if l.StatusCode != 0 {
+ fields = append(fields, surgeField{Key: "status-code", Value: F.ToString(l.StatusCode), ValueSet: true})
+ }
+ if len(l.Headers) > 0 {
+ var headers []string
+ for key, values := range l.Headers {
+ for _, value := range values {
+ headers = append(headers, key+":"+value)
+ }
+ }
+ fields = append(fields, surgeField{Key: "headers", Value: strings.Join(headers, "|")})
+ }
+ return encodeSurgeFields(fields)
+}
+
+func (l SurgeMapLocalLine) MarshalJSON() ([]byte, error) {
+ return json.Marshal(l.String())
+}
+
+func (l *SurgeMapLocalLine) UnmarshalJSON(bytes []byte) error {
+ var stringValue string
+ err := json.Unmarshal(bytes, &stringValue)
+ if err != nil {
+ return err
+ }
+ fields, err := surgeFields(stringValue)
+ if err != nil {
+ return E.Cause(err, "invalid surge_map_local line: ", stringValue)
+ } else if len(fields) < 1 {
+ return E.New("invalid surge_map_local line: ", stringValue)
+ }
+ l.Pattern, err = regexp.Compile(fields[0].Key)
+ if err != nil {
+ return E.Cause(err, "invalid surge_map_local line: invalid pattern: ", stringValue)
+ }
+ dataTypeField := common.Find(fields, func(it surgeField) bool {
+ return it.Key == "data-type"
+ })
+ if !dataTypeField.ValueSet {
+ return E.New("invalid surge_map_local line: missing data-type: ", stringValue)
+ }
+ switch dataTypeField.Value {
+ case "file":
+ l.File = true
+ case "text":
+ l.Text = true
+ case "tiny-gif":
+ l.TinyGif = true
+ case "base64":
+ l.Base64 = true
+ default:
+ return E.New("unsupported data-type ", dataTypeField.Value)
+ }
+ for i := 1; i < len(fields); i++ {
+ switch fields[i].Key {
+ case "data-type":
+ continue
+ case "data":
+ if l.File {
+ l.Data = fields[i].Value
+ } else if l.Text {
+ l.Data = fields[i].Value
+ } else if l.Base64 {
+ l.Base64Data, err = base64.StdEncoding.DecodeString(fields[i].Value)
+ if err != nil {
+ return E.New("invalid surge_map_local line: invalid base64 data: ", stringValue)
+ }
+ }
+ case "status-code":
+ statusCode, err := strconv.ParseInt(fields[i].Value, 10, 16)
+ if err != nil {
+ return E.New("invalid surge_map_local line: invalid status code: ", stringValue)
+ }
+ l.StatusCode = int(statusCode)
+ case "header":
+ headers := make(http.Header)
+ for _, headerLine := range strings.Split(fields[i].Value, "|") {
+ if !strings.Contains(headerLine, ":") {
+ return E.New("invalid surge_map_local line: headers: missing `:` in item: ", stringValue, ": ", headerLine)
+ }
+ headers.Add(common.SubstringBefore(headerLine, ":"), common.SubstringAfter(headerLine, ":"))
+ }
+ l.Headers = headers
+ default:
+ return E.New("invalid surge_map_local line: unknown options: ", fields[i].Key)
+ }
+ }
+ return nil
+}
+
+type surgeField struct {
+ Key string
+ Value string
+ ValueSet bool
+}
+
+func encodeSurgeKeys(keys []string) string {
+ keys = common.Map(keys, func(it string) string {
+ if strings.ContainsFunc(it, unicode.IsSpace) {
+ return "\"" + it + "\""
+ } else {
+ return it
+ }
+ })
+ return strings.Join(keys, " ")
+}
+
+func encodeSurgeFields(fields []surgeField) string {
+ return strings.Join(common.Map(fields, func(it surgeField) string {
+ if !it.ValueSet {
+ if strings.ContainsFunc(it.Key, unicode.IsSpace) {
+ return "\"" + it.Key + "\""
+ } else {
+ return it.Key
+ }
+ } else {
+ if strings.ContainsFunc(it.Value, unicode.IsSpace) {
+ return it.Key + "=\"" + it.Value + "\""
+ } else {
+ return it.Key + "=" + it.Value
+ }
+ }
+ }), " ")
+}
+
+func surgeFields(s string) ([]surgeField, error) {
+ var (
+ fields []surgeField
+ currentField *surgeField
+ )
+ for _, field := range strings.Fields(s) {
+ if currentField != nil {
+ field = " " + field
+ if strings.HasSuffix(field, "\"") {
+ field = field[:len(field)-1]
+ if !currentField.ValueSet {
+ currentField.Key += field
+ } else {
+ currentField.Value += field
+ }
+ fields = append(fields, *currentField)
+ currentField = nil
+ } else {
+ if !currentField.ValueSet {
+ currentField.Key += field
+ } else {
+ currentField.Value += field
+ }
+ }
+ continue
+ }
+ if !strings.Contains(field, "=") {
+ if strings.HasPrefix(field, "\"") {
+ field = field[1:]
+ if strings.HasSuffix(field, "\"") {
+ field = field[:len(field)-1]
+ } else {
+ currentField = &surgeField{Key: field}
+ continue
+ }
+ }
+ fields = append(fields, surgeField{Key: field})
+ } else {
+ key := common.SubstringBefore(field, "=")
+ value := common.SubstringAfter(field, "=")
+ if strings.HasPrefix(value, "\"") {
+ value = value[1:]
+ if strings.HasSuffix(field, "\"") {
+ value = value[:len(value)-1]
+ } else {
+ currentField = &surgeField{Key: key, Value: value, ValueSet: true}
+ continue
+ }
+ }
+ fields = append(fields, surgeField{Key: key, Value: value, ValueSet: true})
+ }
+ }
+ if currentField != nil {
+ return nil, E.New("invalid surge fields line: ", s)
+ }
+ return fields, nil
+}
diff --git a/option/options.go b/option/options.go
index 168074ed..ebd49fba 100644
--- a/option/options.go
+++ b/option/options.go
@@ -12,13 +12,14 @@ type _Options struct {
Schema string `json:"$schema,omitempty"`
Log *LogOptions `json:"log,omitempty"`
DNS *DNSOptions `json:"dns,omitempty"`
- NTP *NTPOptions `json:"ntp,omitempty"`
- Certificate *CertificateOptions `json:"certificate,omitempty"`
Endpoints []Endpoint `json:"endpoints,omitempty"`
Inbounds []Inbound `json:"inbounds,omitempty"`
Outbounds []Outbound `json:"outbounds,omitempty"`
Route *RouteOptions `json:"route,omitempty"`
Experimental *ExperimentalOptions `json:"experimental,omitempty"`
+ NTP *NTPOptions `json:"ntp,omitempty"`
+ Certificate *CertificateOptions `json:"certificate,omitempty"`
+ MITM *MITMOptions `json:"mitm,omitempty"`
}
type Options _Options
diff --git a/option/rule_action.go b/option/rule_action.go
index 7c05dce6..473d371f 100644
--- a/option/rule_action.go
+++ b/option/rule_action.go
@@ -158,6 +158,8 @@ type RawRouteOptionsActionOptions struct {
TLSFragment bool `json:"tls_fragment,omitempty"`
TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"`
+
+ MITM *MITMRouteOptions `json:"mitm,omitempty"`
}
type RouteOptionsActionOptions RawRouteOptionsActionOptions
diff --git a/route/conn.go b/route/conn.go
index c2a2eab9..33c02a44 100644
--- a/route/conn.go
+++ b/route/conn.go
@@ -24,23 +24,31 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
+ "github.com/sagernet/sing/service"
)
var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
type ConnectionManager struct {
+ ctx context.Context
logger logger.ContextLogger
+ mitm adapter.MITMEngine
access sync.Mutex
connections list.List[io.Closer]
}
-func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
+func NewConnectionManager(ctx context.Context, logger logger.ContextLogger) *ConnectionManager {
return &ConnectionManager{
+ ctx: ctx,
logger: logger,
}
}
func (m *ConnectionManager) Start(stage adapter.StartStage) error {
+ switch stage {
+ case adapter.StartStateInitialize:
+ m.mitm = service.FromContext[adapter.MITMEngine](m.ctx)
+ }
return nil
}
@@ -55,6 +63,14 @@ func (m *ConnectionManager) Close() error {
}
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
+ if metadata.MITM != nil && metadata.MITM.Enabled {
+ if m.mitm == nil {
+ m.logger.WarnContext(ctx, "MITM disabled")
+ } else {
+ m.mitm.NewConnection(ctx, this, conn, metadata, onClose)
+ return
+ }
+ }
ctx = adapter.WithContext(ctx, &metadata)
var (
remoteConn net.Conn
diff --git a/route/route.go b/route/route.go
index 531ad039..72ecbd34 100644
--- a/route/route.go
+++ b/route/route.go
@@ -458,6 +458,9 @@ match:
metadata.TLSFragment = true
metadata.TLSFragmentFallbackDelay = routeOptions.TLSFragmentFallbackDelay
}
+ if routeOptions.MITM != nil && routeOptions.MITM.Enabled {
+ metadata.MITM = routeOptions.MITM
+ }
}
switch action := currentRule.Action().(type) {
case *rule.RuleActionSniff:
diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go
index f49baca6..73bc2bae 100644
--- a/route/rule/rule_action.go
+++ b/route/rule/rule_action.go
@@ -40,6 +40,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPConnect: action.RouteOptions.UDPConnect,
TLSFragment: action.RouteOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay),
+ MITM: action.RouteOptions.MITM,
},
}, nil
case C.RuleActionTypeRouteOptions:
@@ -53,6 +54,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout),
TLSFragment: action.RouteOptionsOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay),
+ MITM: action.RouteOptionsOptions.MITM,
}, nil
case C.RuleActionTypeDirect:
directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false)
@@ -152,15 +154,7 @@ func (r *RuleActionRoute) Type() string {
func (r *RuleActionRoute) String() string {
var descriptions []string
descriptions = append(descriptions, r.Outbound)
- if r.UDPDisableDomainUnmapping {
- descriptions = append(descriptions, "udp-disable-domain-unmapping")
- }
- if r.UDPConnect {
- descriptions = append(descriptions, "udp-connect")
- }
- if r.TLSFragment {
- descriptions = append(descriptions, "tls-fragment")
- }
+ descriptions = append(descriptions, r.Descriptions()...)
return F.ToString("route(", strings.Join(descriptions, ","), ")")
}
@@ -176,13 +170,14 @@ type RuleActionRouteOptions struct {
UDPTimeout time.Duration
TLSFragment bool
TLSFragmentFallbackDelay time.Duration
+ MITM *option.MITMRouteOptions
}
func (r *RuleActionRouteOptions) Type() string {
return C.RuleActionTypeRouteOptions
}
-func (r *RuleActionRouteOptions) String() string {
+func (r *RuleActionRouteOptions) Descriptions() []string {
var descriptions []string
if r.OverrideAddress.IsValid() {
descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.AddrString()))
@@ -209,9 +204,22 @@ func (r *RuleActionRouteOptions) String() string {
descriptions = append(descriptions, "udp-connect")
}
if r.UDPTimeout > 0 {
- descriptions = append(descriptions, "udp-timeout")
+ descriptions = append(descriptions, F.ToString("udp-timeout=", r.UDPTimeout))
}
- return F.ToString("route-options(", strings.Join(descriptions, ","), ")")
+ if r.TLSFragment {
+ descriptions = append(descriptions, "tls-fragment")
+ if r.TLSFragmentFallbackDelay > 0 {
+ descriptions = append(descriptions, F.ToString("tls-fragment-fallbac-delay=", r.TLSFragmentFallbackDelay.String()))
+ }
+ }
+ if r.MITM != nil && r.MITM.Enabled {
+ descriptions = append(descriptions, "mitm")
+ }
+ return descriptions
+}
+
+func (r *RuleActionRouteOptions) String() string {
+ return F.ToString("route-options(", strings.Join(r.Descriptions(), ","), ")")
}
type RuleActionDNSRoute struct {