diff --git a/adapter/certificate.go b/adapter/certificate.go index 0998e130..11eeb196 100644 --- a/adapter/certificate.go +++ b/adapter/certificate.go @@ -10,6 +10,9 @@ import ( type CertificateStore interface { LifecycleService Pool() *x509.CertPool + TLSDecryptionEnabled() bool + TLSDecryptionCertificate() *x509.Certificate + TLSDecryptionPrivateKey() any } func RootPoolFromContext(ctx context.Context) *x509.CertPool { diff --git a/adapter/inbound.go b/adapter/inbound.go index 1218c049..869cdec9 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -2,6 +2,8 @@ package adapter import ( "context" + "crypto/tls" + "net/http" "net/netip" "time" @@ -58,6 +60,8 @@ type InboundContext struct { Client string SniffContext any PacketSniffError error + HTTPRequest *http.Request + ClientHello *tls.ClientHelloInfo // cache @@ -74,6 +78,7 @@ type InboundContext struct { UDPTimeout time.Duration TLSFragment bool TLSFragmentFallbackDelay time.Duration + MITM *option.MITMRouteOptions NetworkStrategy *C.NetworkStrategy NetworkType []C.InterfaceType diff --git a/adapter/lifecycle.go b/adapter/lifecycle.go index aff9fadb..9e522141 100644 --- a/adapter/lifecycle.go +++ b/adapter/lifecycle.go @@ -1,6 +1,8 @@ package adapter -import E "github.com/sagernet/sing/common/exceptions" +import ( + E "github.com/sagernet/sing/common/exceptions" +) type StartStage uint8 @@ -45,6 +47,9 @@ type LifecycleService interface { func Start(stage StartStage, services ...Lifecycle) error { for _, service := range services { + if service == nil { + continue + } err := service.Start(stage) if err != nil { return err diff --git a/adapter/mitm.go b/adapter/mitm.go new file mode 100644 index 00000000..450468a9 --- /dev/null +++ b/adapter/mitm.go @@ -0,0 +1,13 @@ +package adapter + +import ( + "context" + "net" + + N "github.com/sagernet/sing/common/network" +) + +type MITMEngine interface { + Lifecycle + NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata InboundContext, onClose N.CloseHandlerFunc) +} diff --git a/box.go b/box.go index 0f176474..664764a6 100644 --- a/box.go +++ b/box.go @@ -23,6 +23,7 @@ import ( "github.com/sagernet/sing-box/experimental/cachefile" "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/mitm" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/protocol/direct" "github.com/sagernet/sing-box/route" @@ -48,6 +49,7 @@ type Box struct { dnsRouter *dns.Router connection *route.ConnectionManager router *route.Router + mitm adapter.MITMEngine //*mitm.Engine services []adapter.LifecycleService done chan struct{} } @@ -143,18 +145,12 @@ func New(options Options) (*Box, error) { } var services []adapter.LifecycleService - certificateOptions := common.PtrValueOrDefault(options.Certificate) - if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem || - len(certificateOptions.Certificate) > 0 || - len(certificateOptions.CertificatePath) > 0 || - len(certificateOptions.CertificateDirectoryPath) > 0 { - certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), certificateOptions) - if err != nil { - return nil, err - } - service.MustRegister[adapter.CertificateStore](ctx, certificateStore) - services = append(services, certificateStore) + certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), common.PtrValueOrDefault(options.Certificate)) + if err != nil { + return nil, err } + service.MustRegister[adapter.CertificateStore](ctx, certificateStore) + services = append(services, certificateStore) routeOptions := common.PtrValueOrDefault(options.Route) dnsOptions := common.PtrValueOrDefault(options.DNS) @@ -173,7 +169,7 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "initialize network manager") } service.MustRegister[adapter.NetworkManager](ctx, networkManager) - connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection")) + connectionManager := route.NewConnectionManager(ctx, logFactory.NewLogger("connection")) service.MustRegister[adapter.ConnectionManager](ctx, connectionManager) router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions) service.MustRegister[adapter.Router](ctx, router) @@ -181,8 +177,8 @@ func New(options Options) (*Box, error) { if err != nil { return nil, E.Cause(err, "initialize router") } - ntpOptions := common.PtrValueOrDefault(options.NTP) var timeService *tls.TimeServiceWrapper + ntpOptions := common.PtrValueOrDefault(options.NTP) if ntpOptions.Enabled { timeService = new(tls.TimeServiceWrapper) service.MustRegister[ntp.TimeService](ctx, timeService) @@ -345,6 +341,16 @@ func New(options Options) (*Box, error) { timeService.TimeService = ntpService services = append(services, adapter.NewLifecycleService(ntpService, "ntp service")) } + mitmOptions := common.PtrValueOrDefault(options.MITM) + var mitmEngine adapter.MITMEngine + if mitmOptions.Enabled { + engine, err := mitm.NewEngine(ctx, logFactory.NewLogger("mitm"), mitmOptions) + if err != nil { + return nil, E.Cause(err, "create MITM engine") + } + service.MustRegister[adapter.MITMEngine](ctx, engine) + mitmEngine = engine + } return &Box{ network: networkManager, endpoint: endpointManager, @@ -354,6 +360,7 @@ func New(options Options) (*Box, error) { dnsRouter: dnsRouter, connection: connectionManager, router: router, + mitm: mitmEngine, createdAt: createdAt, logFactory: logFactory, logger: logFactory.Logger(), @@ -412,11 +419,11 @@ func (s *Box) preStart() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.mitm, s.outbound, s.inbound, s.endpoint) if err != nil { return err } - err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router) + err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router, s.mitm) if err != nil { return err } @@ -440,7 +447,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.mitm, s.inbound, s.endpoint) if err != nil { return err } @@ -448,7 +455,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.mitm, s.outbound, s.inbound, s.endpoint) if err != nil { return err } @@ -467,7 +474,7 @@ func (s *Box) Close() error { close(s.done) } err := common.Close( - s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network, + s.inbound, s.outbound, s.endpoint, s.mitm, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network, ) for _, lifecycleService := range s.services { err = E.Append(err, lifecycleService.Close(), func(err error) error { diff --git a/cmd/sing-box/cmd_generate_ca.go b/cmd/sing-box/cmd_generate_ca.go new file mode 100644 index 00000000..fcda2dd6 --- /dev/null +++ b/cmd/sing-box/cmd_generate_ca.go @@ -0,0 +1,120 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json" + + "github.com/spf13/cobra" + "software.sslmate.com/src/go-pkcs12" +) + +var ( + flagGenerateCAName string + flagGenerateCAPKCS12Password string + flagGenerateOutput string +) + +var commandGenerateCAKeyPair = &cobra.Command{ + Use: "ca-keypair", + Short: "Generate CA key pair", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + err := generateCAKeyPair() + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandGenerateCAKeyPair.Flags().StringVarP(&flagGenerateCAName, "name", "n", "", "Set custom CA name") + commandGenerateCAKeyPair.Flags().StringVarP(&flagGenerateCAPKCS12Password, "p12-password", "p", "", "Set custom PKCS12 password") + commandGenerateCAKeyPair.Flags().StringVarP(&flagGenerateOutput, "output", "o", ".", "Set output directory") + commandGenerate.AddCommand(commandGenerateCAKeyPair) +} + +func generateCAKeyPair() error { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return err + } + spkiASN1, err := x509.MarshalPKIXPublicKey(privateKey.Public()) + var spki struct { + Algorithm pkix.AlgorithmIdentifier + SubjectPublicKey asn1.BitString + } + _, err = asn1.Unmarshal(spkiASN1, &spki) + if err != nil { + return err + } + skid := sha1.Sum(spki.SubjectPublicKey.Bytes) + var caName string + if flagGenerateCAName != "" { + caName = flagGenerateCAName + } else { + caName = "sing-box Generated CA " + strings.ToUpper(hex.EncodeToString(skid[:4])) + } + caTpl := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{caName}, + CommonName: caName, + }, + SubjectKeyId: skid[:], + NotAfter: time.Now().AddDate(10, 0, 0), + NotBefore: time.Now(), + KeyUsage: x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLenZero: true, + } + publicDer, err := x509.CreateCertificate(rand.Reader, caTpl, caTpl, privateKey.Public(), privateKey) + var caPassword string + if flagGenerateCAPKCS12Password != "" { + caPassword = flagGenerateCAPKCS12Password + } else { + caPassword = strings.ToUpper(hex.EncodeToString(skid[:4])) + } + caTpl.Raw = publicDer + p12Bytes, err := pkcs12.Modern.Encode(privateKey, caTpl, nil, caPassword) + if err != nil { + return err + } + privateDer, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return err + } + os.WriteFile(filepath.Join(flagGenerateOutput, caName+".pem"), []byte(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicDer}))+string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateDer}))), 0o644) + os.WriteFile(filepath.Join(flagGenerateOutput, caName+".crt"), publicDer, 0o644) + os.WriteFile(filepath.Join(flagGenerateOutput, caName+".p12"), p12Bytes, 0o644) + var tlsDecryptionOptions option.TLSDecryptionOptions + tlsDecryptionOptions.Enabled = true + tlsDecryptionOptions.KeyPair = base64.StdEncoding.EncodeToString(p12Bytes) + tlsDecryptionOptions.KeyPairPassword = caPassword + var certificateOptions option.CertificateOptions + certificateOptions.TLSDecryption = &tlsDecryptionOptions + encoder := json.NewEncoder(os.Stdout) + encoder.SetIndent("", " ") + return encoder.Encode(certificateOptions) +} diff --git a/cmd/sing-box/cmd_tools.go b/cmd/sing-box/cmd_tools.go index 55e5b458..5f5a0d71 100644 --- a/cmd/sing-box/cmd_tools.go +++ b/cmd/sing-box/cmd_tools.go @@ -1,13 +1,6 @@ package main import ( - "errors" - "os" - - "github.com/sagernet/sing-box" - E "github.com/sagernet/sing/common/exceptions" - N "github.com/sagernet/sing/common/network" - "github.com/spf13/cobra" ) @@ -19,36 +12,5 @@ var commandTools = &cobra.Command{ } func init() { - commandTools.PersistentFlags().StringVarP(&commandToolsFlagOutbound, "outbound", "o", "", "Use specified tag instead of default outbound") mainCommand.AddCommand(commandTools) } - -func createPreStartedClient() (*box.Box, error) { - options, err := readConfigAndMerge() - if err != nil { - if !(errors.Is(err, os.ErrNotExist) && len(configDirectories) == 0 && len(configPaths) == 1) || configPaths[0] != "config.json" { - return nil, err - } - } - instance, err := box.New(box.Options{Context: globalCtx, Options: options}) - if err != nil { - return nil, E.Cause(err, "create service") - } - err = instance.PreStart() - if err != nil { - return nil, E.Cause(err, "start service") - } - return instance, nil -} - -func createDialer(instance *box.Box, outboundTag string) (N.Dialer, error) { - if outboundTag == "" { - return instance.Outbound().Default(), nil - } else { - outbound, loaded := instance.Outbound().Outbound(outboundTag) - if !loaded { - return nil, E.New("outbound not found: ", outboundTag) - } - return outbound, nil - } -} diff --git a/cmd/sing-box/cmd_tools_connect.go b/cmd/sing-box/cmd_tools_connect.go deleted file mode 100644 index d352d533..00000000 --- a/cmd/sing-box/cmd_tools_connect.go +++ /dev/null @@ -1,73 +0,0 @@ -package main - -import ( - "context" - "os" - - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/task" - - "github.com/spf13/cobra" -) - -var commandConnectFlagNetwork string - -var commandConnect = &cobra.Command{ - Use: "connect
", - Short: "Connect to an address", - Args: cobra.ExactArgs(1), - Run: func(cmd *cobra.Command, args []string) { - err := connect(args[0]) - if err != nil { - log.Fatal(err) - } - }, -} - -func init() { - commandConnect.Flags().StringVarP(&commandConnectFlagNetwork, "network", "n", "tcp", "network type") - commandTools.AddCommand(commandConnect) -} - -func connect(address string) error { - switch N.NetworkName(commandConnectFlagNetwork) { - case N.NetworkTCP, N.NetworkUDP: - default: - return E.Cause(N.ErrUnknownNetwork, commandConnectFlagNetwork) - } - instance, err := createPreStartedClient() - if err != nil { - return err - } - defer instance.Close() - dialer, err := createDialer(instance, commandToolsFlagOutbound) - if err != nil { - return err - } - conn, err := dialer.DialContext(context.Background(), commandConnectFlagNetwork, M.ParseSocksaddr(address)) - if err != nil { - return E.Cause(err, "connect to server") - } - var group task.Group - group.Append("upload", func(ctx context.Context) error { - return common.Error(bufio.Copy(conn, os.Stdin)) - }) - group.Append("download", func(ctx context.Context) error { - return common.Error(bufio.Copy(os.Stdout, conn)) - }) - group.Cleanup(func() { - conn.Close() - }) - err = group.Run(context.Background()) - if E.IsClosed(err) { - log.Info(err) - } else { - log.Error(err) - } - return nil -} diff --git a/cmd/sing-box/cmd_tools_fetch.go b/cmd/sing-box/cmd_tools_fetch.go deleted file mode 100644 index 5ee3b875..00000000 --- a/cmd/sing-box/cmd_tools_fetch.go +++ /dev/null @@ -1,115 +0,0 @@ -package main - -import ( - "context" - "errors" - "io" - "net" - "net/http" - "net/url" - "os" - - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - - "github.com/spf13/cobra" -) - -var commandFetch = &cobra.Command{ - Use: "fetch", - Short: "Fetch an URL", - Args: cobra.MinimumNArgs(1), - Run: func(cmd *cobra.Command, args []string) { - err := fetch(args) - if err != nil { - log.Fatal(err) - } - }, -} - -func init() { - commandTools.AddCommand(commandFetch) -} - -var ( - httpClient *http.Client - http3Client *http.Client -) - -func fetch(args []string) error { - instance, err := createPreStartedClient() - if err != nil { - return err - } - defer instance.Close() - httpClient = &http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer, err := createDialer(instance, commandToolsFlagOutbound) - if err != nil { - return nil, err - } - return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - ForceAttemptHTTP2: true, - }, - } - defer httpClient.CloseIdleConnections() - if C.WithQUIC { - err = initializeHTTP3Client(instance) - if err != nil { - return err - } - defer http3Client.CloseIdleConnections() - } - for _, urlString := range args { - var parsedURL *url.URL - parsedURL, err = url.Parse(urlString) - if err != nil { - return err - } - switch parsedURL.Scheme { - case "": - parsedURL.Scheme = "http" - fallthrough - case "http", "https": - err = fetchHTTP(httpClient, parsedURL) - if err != nil { - return err - } - case "http3": - if !C.WithQUIC { - return C.ErrQUICNotIncluded - } - parsedURL.Scheme = "https" - err = fetchHTTP(http3Client, parsedURL) - if err != nil { - return err - } - default: - return E.New("unsupported scheme: ", parsedURL.Scheme) - } - } - return nil -} - -func fetchHTTP(httpClient *http.Client, parsedURL *url.URL) error { - request, err := http.NewRequest("GET", parsedURL.String(), nil) - if err != nil { - return err - } - request.Header.Add("User-Agent", "curl/7.88.0") - response, err := httpClient.Do(request) - if err != nil { - return err - } - defer response.Body.Close() - _, err = bufio.Copy(os.Stdout, response.Body) - if errors.Is(err, io.EOF) { - return nil - } - return err -} diff --git a/cmd/sing-box/cmd_tools_fetch_http3.go b/cmd/sing-box/cmd_tools_fetch_http3.go deleted file mode 100644 index b7a31a72..00000000 --- a/cmd/sing-box/cmd_tools_fetch_http3.go +++ /dev/null @@ -1,36 +0,0 @@ -//go:build with_quic - -package main - -import ( - "context" - "crypto/tls" - "net/http" - - "github.com/sagernet/quic-go" - "github.com/sagernet/quic-go/http3" - box "github.com/sagernet/sing-box" - "github.com/sagernet/sing/common/bufio" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -func initializeHTTP3Client(instance *box.Box) error { - dialer, err := createDialer(instance, commandToolsFlagOutbound) - if err != nil { - return err - } - http3Client = &http.Client{ - Transport: &http3.Transport{ - Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - destination := M.ParseSocksaddr(addr) - udpConn, dErr := dialer.DialContext(ctx, N.NetworkUDP, destination) - if dErr != nil { - return nil, dErr - } - return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsCfg, cfg) - }, - }, - } - return nil -} diff --git a/cmd/sing-box/cmd_tools_fetch_http3_stub.go b/cmd/sing-box/cmd_tools_fetch_http3_stub.go deleted file mode 100644 index ae13f54c..00000000 --- a/cmd/sing-box/cmd_tools_fetch_http3_stub.go +++ /dev/null @@ -1,18 +0,0 @@ -//go:build !with_quic - -package main - -import ( - "net/url" - "os" - - box "github.com/sagernet/sing-box" -) - -func initializeHTTP3Client(instance *box.Box) error { - return os.ErrInvalid -} - -func fetchHTTP3(parsedURL *url.URL) error { - return os.ErrInvalid -} diff --git a/cmd/sing-box/cmd_tools_install_ca.go b/cmd/sing-box/cmd_tools_install_ca.go new file mode 100644 index 00000000..80ec14b1 --- /dev/null +++ b/cmd/sing-box/cmd_tools_install_ca.go @@ -0,0 +1,108 @@ +package main + +import ( + "encoding/pem" + "errors" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/shell" + + "github.com/spf13/cobra" +) + +var commandInstallCACertificate = &cobra.Command{ + Use: "install-ca ", + Short: "Install CA certificate to system", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + err := installCACertificate(args[0]) + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandTools.AddCommand(commandInstallCACertificate) +} + +func installCACertificate(path string) error { + switch runtime.GOOS { + case "windows": + return shell.Exec("powershell", "-Command", "Import-Certificate -FilePath \""+path+"\" -CertStoreLocation Cert:\\LocalMachine\\Root").Attach().Run() + case "darwin": + return shell.Exec("sudo", "security", "add-trusted-cert", "-d", "-r", "trustRoot", "-k", "/Library/Keychains/System.keychain", path).Attach().Run() + case "linux": + updateCertPath, updateCertPathNotFoundErr := exec.LookPath("update-ca-certificates") + if updateCertPathNotFoundErr == nil { + publicDer, err := os.ReadFile(path) + if err != nil { + return err + } + err = os.MkdirAll("/usr/local/share/ca-certificates", 0o755) + if err != nil { + if errors.Is(err, os.ErrPermission) { + log.Info("Try running with sudo") + return shell.Exec("sudo", os.Args...).Attach().Run() + } + return err + } + fileName := filepath.Base(updateCertPath) + if !strings.HasSuffix(fileName, ".crt") { + fileName = fileName + ".crt" + } + filePath, _ := filepath.Abs(filepath.Join("/usr/local/share/ca-certificates", fileName)) + err = os.WriteFile(filePath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicDer}), 0o644) + if err != nil { + if errors.Is(err, os.ErrPermission) { + log.Info("Try running with sudo") + return shell.Exec("sudo", os.Args...).Attach().Run() + } + return err + } + log.Info("certificate written to " + filePath + "\n") + err = shell.Exec(updateCertPath).Attach().Run() + if err != nil { + return err + } + log.Info("certificate installed") + return nil + } + updateTrustPath, updateTrustPathNotFoundErr := exec.LookPath("update-ca-trust") + if updateTrustPathNotFoundErr == nil { + publicDer, err := os.ReadFile(path) + if err != nil { + return err + } + fileName := filepath.Base(updateTrustPath) + fileExt := filepath.Ext(path) + if fileExt != "" { + fileName = fileName[:len(fileName)-len(fileExt)] + } + filePath, _ := filepath.Abs(filepath.Join("/etc/pki/ca-trust/source/anchors/", fileName+".pem")) + err = os.WriteFile(filePath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicDer}), 0o644) + if err != nil { + if errors.Is(err, os.ErrPermission) { + log.Info("Try running with sudo") + return shell.Exec("sudo", os.Args...).Attach().Run() + } + return err + } + log.Info("certificate written to " + filePath + "\n") + err = shell.Exec(updateTrustPath, "extract").Attach().Run() + if err != nil { + return err + } + log.Info("certificate installed") + } + return E.New("update-ca-certificates or update-ca-trust not found") + default: + return E.New("unsupported operating system: ", runtime.GOOS) + } +} diff --git a/cmd/sing-box/cmd_tools_synctime.go b/cmd/sing-box/cmd_tools_synctime.go index 09d487ef..33e38743 100644 --- a/cmd/sing-box/cmd_tools_synctime.go +++ b/cmd/sing-box/cmd_tools_synctime.go @@ -8,6 +8,7 @@ import ( "github.com/sagernet/sing-box/log" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" "github.com/spf13/cobra" @@ -39,20 +40,11 @@ func init() { } func syncTime() error { - instance, err := createPreStartedClient() - if err != nil { - return err - } - dialer, err := createDialer(instance, commandToolsFlagOutbound) - if err != nil { - return err - } - defer instance.Close() serverAddress := M.ParseSocksaddr(commandSyncTimeFlagServer) if serverAddress.Port == 0 { serverAddress.Port = 123 } - response, err := ntp.Exchange(context.Background(), dialer, serverAddress) + response, err := ntp.Exchange(context.Background(), N.SystemDialer, serverAddress) if err != nil { return err } diff --git a/common/certificate/store.go b/common/certificate/store.go index 34f20019..d17f77b5 100644 --- a/common/certificate/store.go +++ b/common/certificate/store.go @@ -3,6 +3,7 @@ package certificate import ( "context" "crypto/x509" + "encoding/base64" "io/fs" "os" "path/filepath" @@ -16,6 +17,8 @@ import ( E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/service" + + "software.sslmate.com/src/go-pkcs12" ) var _ adapter.CertificateStore = (*Store)(nil) @@ -27,6 +30,9 @@ type Store struct { certificatePaths []string certificateDirectoryPaths []string watcher *fswatch.Watcher + tlsDecryptionEnabled bool + tlsDecryptionPrivateKey any + tlsDecryptionCertificate *x509.Certificate } func NewStore(ctx context.Context, logger logger.Logger, options option.CertificateOptions) (*Store, error) { @@ -90,6 +96,19 @@ func NewStore(ctx context.Context, logger logger.Logger, options option.Certific if err != nil { return nil, E.Cause(err, "initializing certificate store") } + if options.TLSDecryption != nil && options.TLSDecryption.Enabled { + pfxBytes, err := base64.StdEncoding.DecodeString(options.TLSDecryption.KeyPair) + if err != nil { + return nil, E.Cause(err, "decode key pair base64 bytes") + } + privateKey, certificate, err := pkcs12.Decode(pfxBytes, options.TLSDecryption.KeyPairPassword) + if err != nil { + return nil, E.Cause(err, "decode key pair") + } + store.tlsDecryptionEnabled = true + store.tlsDecryptionPrivateKey = privateKey + store.tlsDecryptionCertificate = certificate + } return store, nil } @@ -183,3 +202,15 @@ func isSameDirSymlink(f fs.DirEntry, dir string) bool { target, err := os.Readlink(filepath.Join(dir, f.Name())) return err == nil && !strings.Contains(target, "/") } + +func (s *Store) TLSDecryptionEnabled() bool { + return s.tlsDecryptionEnabled +} + +func (s *Store) TLSDecryptionCertificate() *x509.Certificate { + return s.tlsDecryptionCertificate +} + +func (s *Store) TLSDecryptionPrivateKey() any { + return s.tlsDecryptionPrivateKey +} diff --git a/common/sniff/http.go b/common/sniff/http.go index 0e6ab406..e7c6eb8c 100644 --- a/common/sniff/http.go +++ b/common/sniff/http.go @@ -18,5 +18,6 @@ func HTTPHost(_ context.Context, metadata *adapter.InboundContext, reader io.Rea } metadata.Protocol = C.ProtocolHTTP metadata.Domain = M.ParseSocksaddr(request.Host).AddrString() + metadata.HTTPRequest = request return nil } diff --git a/common/sniff/tls.go b/common/sniff/tls.go index 6fe430e2..27729fa2 100644 --- a/common/sniff/tls.go +++ b/common/sniff/tls.go @@ -21,6 +21,7 @@ func TLSClientHello(ctx context.Context, metadata *adapter.InboundContext, reade if clientHello != nil { metadata.Protocol = C.ProtocolTLS metadata.Domain = clientHello.ServerName + metadata.ClientHello = clientHello return nil } return err diff --git a/common/tls/mkcert.go b/common/tls/mkcert.go index 4e0ed102..fc9c4ab9 100644 --- a/common/tls/mkcert.go +++ b/common/tls/mkcert.go @@ -8,7 +8,10 @@ import ( "crypto/x509/pkix" "encoding/pem" "math/big" + "net" "time" + + M "github.com/sagernet/sing/common/metadata" ) func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string) (*tls.Certificate, error) { @@ -35,17 +38,30 @@ func GenerateCertificate(parent *x509.Certificate, parentKey any, timeFunc func( if err != nil { return } - template := &x509.Certificate{ - SerialNumber: serialNumber, - NotBefore: timeFunc().Add(time.Hour * -1), - NotAfter: expire, - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - Subject: pkix.Name{ - CommonName: serverName, - }, - DNSNames: []string{serverName}, + var template *x509.Certificate + if serverAddress := M.ParseAddr(serverName); serverAddress.IsValid() { + template = &x509.Certificate{ + SerialNumber: serialNumber, + IPAddresses: []net.IP{serverAddress.AsSlice()}, + NotBefore: timeFunc().Add(time.Hour * -1), + NotAfter: expire, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + } else { + template = &x509.Certificate{ + SerialNumber: serialNumber, + NotBefore: timeFunc().Add(time.Hour * -1), + NotAfter: expire, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + Subject: pkix.Name{ + CommonName: serverName, + }, + DNSNames: []string{serverName}, + } } if parent == nil { parent = template diff --git a/experimental/clashapi/mitm.go b/experimental/clashapi/mitm.go new file mode 100644 index 00000000..8d64081d --- /dev/null +++ b/experimental/clashapi/mitm.go @@ -0,0 +1,186 @@ +package clashapi + +import ( + "archive/zip" + "context" + "crypto/x509" + "encoding/pem" + "io" + "net/http" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/service" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "github.com/gofrs/uuid/v5" + "howett.net/plist" +) + +func mitmRouter(ctx context.Context) http.Handler { + r := chi.NewRouter() + r.Get("/mobileconfig", getMobileConfig(ctx)) + r.Get("/crt", getCertificate(ctx)) + r.Get("/pem", getCertificatePEM(ctx)) + r.Get("/magisk", getMagiskModule(ctx)) + return r +} + +func getMobileConfig(ctx context.Context) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + store := service.FromContext[adapter.CertificateStore](ctx) + if !store.TLSDecryptionEnabled() { + http.NotFound(writer, request) + render.PlainText(writer, request, "TLS decryption not enabled") + return + } + certificate := store.TLSDecryptionCertificate() + writer.Header().Set("Content-Type", "application/x-apple-aspen-config") + uuidGen := common.Must1(uuid.NewV4()).String() + mobileConfig := map[string]interface{}{ + "PayloadContent": []interface{}{ + map[string]interface{}{ + "PayloadCertificateFileName": "Certificates.cer", + "PayloadContent": certificate.Raw, + "PayloadDescription": "Adds a root certificate", + "PayloadDisplayName": certificate.Subject.CommonName, + "PayloadIdentifier": "com.apple.security.root." + uuidGen, + "PayloadType": "com.apple.security.root", + "PayloadUUID": uuidGen, + "PayloadVersion": 1, + }, + }, + "PayloadDisplayName": certificate.Subject.CommonName, + "PayloadIdentifier": "io.nekohasekai.sfa.ca.profile." + uuidGen, + "PayloadRemovalDisallowed": false, + "PayloadType": "Configuration", + "PayloadUUID": uuidGen, + "PayloadVersion": 1, + } + encoder := plist.NewEncoder(writer) + encoder.Indent("\t") + encoder.Encode(mobileConfig) + } +} + +func getCertificate(ctx context.Context) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + store := service.FromContext[adapter.CertificateStore](ctx) + if !store.TLSDecryptionEnabled() { + http.NotFound(writer, request) + render.PlainText(writer, request, "TLS decryption not enabled") + return + } + writer.Header().Set("Content-Type", "application/x-x509-ca-cert") + writer.Header().Set("Content-Disposition", "attachment; filename=Certificate.crt") + writer.Write(store.TLSDecryptionCertificate().Raw) + } +} + +func getCertificatePEM(ctx context.Context) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + store := service.FromContext[adapter.CertificateStore](ctx) + if !store.TLSDecryptionEnabled() { + http.NotFound(writer, request) + render.PlainText(writer, request, "TLS decryption not enabled") + return + } + writer.Header().Set("Content-Type", "application/x-pem-file") + writer.Header().Set("Content-Disposition", "attachment; filename=Certificate.pem") + writer.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: store.TLSDecryptionCertificate().Raw})) + } +} + +func getMagiskModule(ctx context.Context) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + store := service.FromContext[adapter.CertificateStore](ctx) + if !store.TLSDecryptionEnabled() { + http.NotFound(writer, request) + render.PlainText(writer, request, "TLS decryption not enabled") + return + } + writer.Header().Set("Content-Type", "application/zip") + writer.Header().Set("Content-Disposition", "attachment; filename="+store.TLSDecryptionCertificate().Subject.CommonName+".zip") + createMagiskModule(writer, store.TLSDecryptionCertificate()) + } +} + +func createMagiskModule(writer io.Writer, certificate *x509.Certificate) error { + zipWriter := zip.NewWriter(writer) + defer zipWriter.Close() + moduleProp, err := zipWriter.Create("module.prop") + if err != nil { + return err + } + _, err = moduleProp.Write([]byte(` +id=sing-box-certificate +name=` + certificate.Subject.CommonName + ` +version=v0.0.1 +versionCode=1 +author=sing-box +description=This module adds ` + certificate.Subject.CommonName + ` to the system trust store. +`)) + if err != nil { + return err + } + certificateFile, err := zipWriter.Create("system/etc/security/cacerts/" + certificate.Subject.CommonName + ".pem") + if err != nil { + return err + } + err = pem.Encode(certificateFile, &pem.Block{Type: "CERTIFICATE", Bytes: certificate.Raw}) + if err != nil { + return err + } + updateBinary, err := zipWriter.Create("META-INF/com/google/android/update-binary") + if err != nil { + return err + } + _, err = updateBinary.Write([]byte(` +#!/sbin/sh + +################# +# Initialization +################# + +umask 022 + +# echo before loading util_functions +ui_print() { echo "$1"; } + +require_new_magisk() { + ui_print "*******************************" + ui_print " Please install Magisk v20.4+! " + ui_print "*******************************" + exit 1 +} + +######################### +# Load util_functions.sh +######################### + +OUTFD=$2 +ZIPFILE=$3 + +mount /data 2>/dev/null + +[ -f /data/adb/magisk/util_functions.sh ] || require_new_magisk +. /data/adb/magisk/util_functions.sh +[ $MAGISK_VER_CODE -lt 20400 ] && require_new_magisk + +install_module +exit 0 +`)) + if err != nil { + return err + } + updaterScript, err := zipWriter.Create("META-INF/com/google/android/updater-script") + if err != nil { + return err + } + _, err = updaterScript.Write([]byte("#MAGISK")) + if err != nil { + return err + } + return nil +} diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index e6d8c4cf..e471a13f 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -124,6 +124,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op r.Mount("/profile", profileRouter()) r.Mount("/cache", cacheRouter(ctx)) r.Mount("/dns", dnsRouter(s.dnsRouter)) + r.Mount("/mitm", mitmRouter(ctx)) s.setupMetaAPI(r) }) diff --git a/go.mod b/go.mod index 9a970b46..d27b976d 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( google.golang.org/grpc v1.70.0 google.golang.org/protobuf v1.36.5 howett.net/plist v1.0.1 + software.sslmate.com/src/go-pkcs12 v0.4.0 ) //replace github.com/sagernet/sing => ../sing diff --git a/log/log.go b/log/log.go index 7b8f2843..6ee79677 100644 --- a/log/log.go +++ b/log/log.go @@ -10,6 +10,10 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) +const ( + DefaultTimeFormat = "-0700 2006-01-02 15:04:05" +) + type Options struct { Context context.Context Options option.LogOptions @@ -47,7 +51,7 @@ func New(options Options) (Factory, error) { DisableColors: logOptions.DisableColor || logFilePath != "", DisableTimestamp: !logOptions.Timestamp && logFilePath != "", FullTimestamp: logOptions.Timestamp, - TimestampFormat: "-0700 2006-01-02 15:04:05", + TimestampFormat: DefaultTimeFormat, } factory := NewDefaultFactory( options.Context, diff --git a/mitm/constants.go b/mitm/constants.go new file mode 100644 index 00000000..e3a9a46b --- /dev/null +++ b/mitm/constants.go @@ -0,0 +1,11 @@ +package mitm + +import ( + "encoding/base64" + + "github.com/sagernet/sing/common" +) + +var surgeTinyGif = common.OnceValue(func() []byte { + return common.Must1(base64.StdEncoding.DecodeString("R0lGODlhAQABAAAAACH5BAEAAAAALAAAAAABAAEAAAIBAAA=")) +}) diff --git a/mitm/engine.go b/mitm/engine.go new file mode 100644 index 00000000..a9acbc47 --- /dev/null +++ b/mitm/engine.go @@ -0,0 +1,811 @@ +package mitm + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "io" + "math" + "mime" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "time" + "unicode" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + sTLS "github.com/sagernet/sing-box/common/tls" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" + sHTTP "github.com/sagernet/sing/protocol/http" + "github.com/sagernet/sing/service" + + "golang.org/x/net/http2" +) + +var _ adapter.MITMEngine = (*Engine)(nil) + +type Engine struct { + ctx context.Context + logger logger.ContextLogger + connection adapter.ConnectionManager + certificate adapter.CertificateStore + timeFunc func() time.Time + http2Enabled bool +} + +func NewEngine(ctx context.Context, logger logger.ContextLogger, options option.MITMOptions) (*Engine, error) { + engine := &Engine{ + ctx: ctx, + logger: logger, + http2Enabled: options.HTTP2Enabled, + } + return engine, nil +} + +func (e *Engine) Start(stage adapter.StartStage) error { + switch stage { + case adapter.StartStateInitialize: + e.connection = service.FromContext[adapter.ConnectionManager](e.ctx) + e.certificate = service.FromContext[adapter.CertificateStore](e.ctx) + e.timeFunc = ntp.TimeFuncFromContext(e.ctx) + if e.timeFunc == nil { + e.timeFunc = time.Now + } + } + return nil +} + +func (e *Engine) Close() error { + return nil +} + +func (e *Engine) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + if e.certificate.TLSDecryptionEnabled() && metadata.ClientHello != nil { + err := e.newTLS(ctx, this, conn, metadata, onClose) + if err != nil { + e.logger.ErrorContext(ctx, err) + } else { + e.logger.DebugContext(ctx, "connection closed") + } + if onClose != nil { + onClose(err) + } + return + } else if metadata.HTTPRequest != nil { + err := e.newHTTP1(ctx, this, conn, nil, metadata) + if err != nil { + e.logger.ErrorContext(ctx, err) + } else { + e.logger.DebugContext(ctx, "connection closed") + } + if onClose != nil { + onClose(err) + } + return + } else { + e.logger.DebugContext(ctx, "HTTP and TLS not detected, skipped") + } + metadata.MITM = nil + e.connection.NewConnection(ctx, this, conn, metadata, onClose) +} + +func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + acceptHTTP := len(metadata.ClientHello.SupportedProtos) == 0 || common.Contains(metadata.ClientHello.SupportedProtos, "http/1.1") + acceptH2 := e.http2Enabled && common.Contains(metadata.ClientHello.SupportedProtos, "h2") + if !acceptHTTP && !acceptH2 { + metadata.MITM = nil + e.logger.DebugContext(ctx, "unsupported application protocol: ", strings.Join(metadata.ClientHello.SupportedProtos, ",")) + e.connection.NewConnection(ctx, this, conn, metadata, onClose) + return nil + } + var nextProtos []string + if acceptH2 { + nextProtos = append(nextProtos, "h2") + } else if acceptHTTP { + nextProtos = append(nextProtos, "http/1.1") + } + var ( + maxVersion uint16 + minVersion uint16 + ) + for _, version := range metadata.ClientHello.SupportedVersions { + maxVersion = common.Max(maxVersion, version) + minVersion = common.Min(minVersion, version) + } + serverName := metadata.ClientHello.ServerName + if serverName == "" && metadata.Destination.IsIP() { + serverName = metadata.Destination.Addr.String() + } + tlsConfig := &tls.Config{ + Time: e.timeFunc, + ServerName: serverName, + NextProtos: nextProtos, + MinVersion: minVersion, + MaxVersion: maxVersion, + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + return sTLS.GenerateKeyPair(e.certificate.TLSDecryptionCertificate(), e.certificate.TLSDecryptionPrivateKey(), e.timeFunc, serverName) + }, + } + tlsConn := tls.Server(conn, tlsConfig) + err := tlsConn.HandshakeContext(ctx) + if err != nil { + return E.Cause(err, "TLS handshake failed for ", metadata.ClientHello.ServerName, ", ", strings.Join(metadata.ClientHello.SupportedProtos, ", ")) + } + if tlsConn.ConnectionState().NegotiatedProtocol == "h2" { + return e.newHTTP2(ctx, this, tlsConn, tlsConfig, metadata, onClose) + } else { + return e.newHTTP1(ctx, this, tlsConn, tlsConfig, metadata) + } +} + +func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext) error { + options := metadata.MITM + defer conn.Close() + reader := bufio.NewReader(conn) + request, err := sHTTP.ReadRequest(reader) + if err != nil { + return E.Cause(err, "read HTTP request") + } + rawRequestURL := request.URL + if tlsConfig != nil { + rawRequestURL.Scheme = "https" + } else { + rawRequestURL.Scheme = "http" + } + if rawRequestURL.Host == "" { + rawRequestURL.Host = request.Host + } + requestURL := rawRequestURL.String() + request.RequestURI = "" + var requestMatch bool + var body []byte + if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 { + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body = io.NopCloser(bytes.NewReader(body)) + } + if options.Print { + e.printRequest(ctx, request, body) + } + for i, rule := range options.SurgeURLRewrite { + if !rule.Pattern.MatchString(requestURL) { + continue + } + e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String()) + if rule.Reject { + return E.New("request rejected by url_rewrite") + } else if rule.Redirect { + w := new(simpleResponseWriter) + http.Redirect(w, request, rule.Destination.String(), http.StatusFound) + err = w.Build(request).Write(conn) + if err != nil { + return E.Cause(err, "write url_rewrite 302 response") + } + return nil + } + requestMatch = true + request.URL = rule.Destination + newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port()) + if newDestination.Port == 0 { + newDestination.Port = metadata.Destination.Port + } + metadata.Destination = newDestination + if tlsConfig != nil { + tlsConfig.ServerName = rule.Destination.Hostname() + } + break + } + for i, rule := range options.SurgeHeaderRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + if strings.ToLower(rule.Key) == "host" { + request.Host = rule.Value + continue + } + request.Header.Add(rule.Key, rule.Value) + case rule.Delete: + request.Header.Del(rule.Key) + case rule.Replace: + if request.Header.Get(rule.Key) != "" { + request.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := request.Header.Get(rule.Key); value != "" { + request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + if body == nil { + if request.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if request.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + } + for mi := 0; i < len(rule.Match); i++ { + body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i])) + } + request.Body = io.NopCloser(bytes.NewReader(body)) + request.ContentLength = int64(len(body)) + } + if !requestMatch { + for i, rule := range options.SurgeMapLocal { + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String()) + var ( + statusCode = http.StatusOK + headers = make(http.Header) + ) + if rule.StatusCode > 0 { + statusCode = rule.StatusCode + } + switch { + case rule.File: + resource, err := os.ReadFile(rule.Data) + if err != nil { + return E.Cause(err, "open map local source") + } + mimeType := mime.TypeByExtension(filepath.Ext(rule.Data)) + if mimeType == "" { + mimeType = "application/octet-stream" + } + headers.Set("Content-Type", mimeType) + body = resource + case rule.Text: + headers.Set("Content-Type", "text/plain") + body = []byte(rule.Data) + case rule.TinyGif: + headers.Set("Content-Type", "image/gif") + body = surgeTinyGif() + case rule.Base64: + headers.Set("Content-Type", "application/octet-stream") + body = rule.Base64Data + } + response := &http.Response{ + StatusCode: statusCode, + Status: http.StatusText(statusCode), + Proto: request.Proto, + ProtoMajor: request.ProtoMajor, + ProtoMinor: request.ProtoMinor, + Header: headers, + Body: io.NopCloser(bytes.NewReader(body)), + } + err = response.Write(conn) + if err != nil { + return E.Cause(err, "write map local response") + } + return nil + } + } + ctx = adapter.WithContext(ctx, &metadata) + var innerErr atomic.TypedValue[error] + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { + return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + return this.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + }, + TLSClientConfig: tlsConfig, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + defer httpClient.CloseIdleConnections() + requestCtx, cancel := context.WithCancel(ctx) + defer cancel() + response, err := httpClient.Do(request.WithContext(requestCtx)) + if err != nil { + cancel() + return E.Errors(innerErr.Load(), err) + } + var responseMatch bool + var responseBody []byte + if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 { + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP response body") + } + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + } + if options.Print { + e.printResponse(ctx, request, response, responseBody) + } + for i, rule := range options.SurgeHeaderRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + response.Header.Add(rule.Key, rule.Value) + case rule.Delete: + response.Header.Del(rule.Key) + case rule.Replace: + if response.Header.Get(rule.Key) != "" { + response.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := response.Header.Get(rule.Key); value != "" { + response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + if responseBody == nil { + if response.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if response.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + } + for mi := 0; i < len(rule.Match); i++ { + responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i])) + } + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + response.ContentLength = int64(len(responseBody)) + } + if !options.Print && !requestMatch && !responseMatch { + e.logger.WarnContext(ctx, "request not modified") + } + err = response.Write(conn) + if err != nil { + return E.Errors(E.Cause(err, "write HTTP response"), innerErr.Load()) + } else if innerErr.Load() != nil { + return E.Cause(innerErr.Load(), "write HTTP response") + } + return nil +} + +func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + httpTransport := &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + ctx = adapter.WithContext(ctx, &metadata) + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { + return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + return this.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + }, + TLSClientConfig: tlsConfig, + } + err := http2.ConfigureTransport(httpTransport) + if err != nil { + return E.Cause(err, "configure HTTP/2 transport") + } + handler := &engineHandler{ + Engine: e, + conn: conn, + tlsConfig: tlsConfig, + dialer: this, + metadata: metadata, + httpClient: &http.Client{ + Transport: httpTransport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + onClose: onClose, + } + http2Server := &http2.Server{ + MaxReadFrameSize: math.MaxUint32, + } + http2Server.ServeConn(conn, &http2.ServeConnOpts{ + Context: ctx, + Handler: handler, + }) + return nil +} + +type engineHandler struct { + *Engine + conn net.Conn + tlsConfig *tls.Config + dialer N.Dialer + metadata adapter.InboundContext + onClose N.CloseHandlerFunc + httpClient *http.Client +} + +func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + err := e.serveHTTP(request.Context(), writer, request) + if err != nil { + if E.IsClosedOrCanceled(err) { + e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed")) + } else { + e.logger.ErrorContext(request.Context(), err) + } + } +} + +func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { + options := e.metadata.MITM + rawRequestURL := request.URL + rawRequestURL.Scheme = "https" + if rawRequestURL.Host == "" { + rawRequestURL.Host = request.Host + } + requestURL := rawRequestURL.String() + request.RequestURI = "" + var requestMatch bool + var ( + body []byte + err error + ) + if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 { + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body.Close() + request.Body = io.NopCloser(bytes.NewReader(body)) + } + if options.Print { + e.printRequest(ctx, request, body) + } + for i, rule := range options.SurgeURLRewrite { + if !rule.Pattern.MatchString(requestURL) { + continue + } + e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String()) + if rule.Reject { + return E.New("request rejected by url_rewrite") + } else if rule.Redirect { + http.Redirect(writer, request, rule.Destination.String(), http.StatusFound) + return nil + } + requestMatch = true + request.URL = rule.Destination + newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port()) + if newDestination.Port == 0 { + newDestination.Port = e.metadata.Destination.Port + } + e.metadata.Destination = newDestination + e.tlsConfig.ServerName = rule.Destination.Hostname() + break + } + for i, rule := range options.SurgeHeaderRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + if strings.ToLower(rule.Key) == "host" { + request.Host = rule.Value + continue + } + request.Header.Add(rule.Key, rule.Value) + case rule.Delete: + request.Header.Del(rule.Key) + case rule.Replace: + if request.Header.Get(rule.Key) != "" { + request.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := request.Header.Get(rule.Key); value != "" { + request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + var body []byte + if request.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if request.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + body, err := io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body.Close() + for mi := 0; i < len(rule.Match); i++ { + body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i])) + } + request.Body = io.NopCloser(bytes.NewReader(body)) + request.ContentLength = int64(len(body)) + } + if !requestMatch { + for i, rule := range options.SurgeMapLocal { + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String()) + go func() { + io.Copy(io.Discard, request.Body) + request.Body.Close() + }() + var ( + statusCode = http.StatusOK + headers = make(http.Header) + body []byte + ) + if rule.StatusCode > 0 { + statusCode = rule.StatusCode + } + switch { + case rule.File: + resource, err := os.ReadFile(rule.Data) + if err != nil { + return E.Cause(err, "open map local source") + } + mimeType := mime.TypeByExtension(filepath.Ext(rule.Data)) + if mimeType == "" { + mimeType = "application/octet-stream" + } + headers.Set("Content-Type", mimeType) + body = resource + case rule.Text: + headers.Set("Content-Type", "text/plain") + body = []byte(rule.Data) + case rule.TinyGif: + headers.Set("Content-Type", "image/gif") + body = surgeTinyGif() + case rule.Base64: + headers.Set("Content-Type", "application/octet-stream") + body = rule.Base64Data + } + for key, values := range headers { + writer.Header()[key] = values + } + writer.WriteHeader(statusCode) + _, err = writer.Write(body) + if err != nil { + return E.Cause(err, "write map local response") + } + return nil + } + } + requestCtx, cancel := context.WithCancel(ctx) + defer cancel() + response, err := e.httpClient.Do(request.WithContext(requestCtx)) + if err != nil { + cancel() + return E.Cause(err, "exchange request") + } + var responseMatch bool + var responseBody []byte + if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 { + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP response body") + } + response.Body.Close() + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + } + if options.Print { + e.printResponse(ctx, request, response, responseBody) + } + for i, rule := range options.SurgeHeaderRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + response.Header.Add(rule.Key, rule.Value) + case rule.Delete: + response.Header.Del(rule.Key) + case rule.Replace: + if response.Header.Get(rule.Key) != "" { + response.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := response.Header.Get(rule.Key); value != "" { + response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + if responseBody == nil { + if response.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if response.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + response.Body.Close() + } + for mi := 0; i < len(rule.Match); i++ { + responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i])) + } + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + response.ContentLength = int64(len(responseBody)) + } + if !options.Print && !requestMatch && !responseMatch { + e.logger.WarnContext(ctx, "request not modified") + } + for key, values := range response.Header { + writer.Header()[key] = values + } + writer.WriteHeader(response.StatusCode) + _, err = io.Copy(writer, response.Body) + response.Body.Close() + if err != nil { + return E.Cause(err, "write HTTP response") + } + return nil +} + +func (e *Engine) printRequest(ctx context.Context, request *http.Request, body []byte) { + var builder strings.Builder + builder.WriteString(F.ToString(request.Proto, " ", request.Method, " ", request.URL)) + builder.WriteString("\n") + if request.URL.Hostname() != "" && request.URL.Hostname() != request.Host { + builder.WriteString("Host: ") + builder.WriteString(request.Host) + builder.WriteString("\n") + } + for key, values := range request.Header { + for _, value := range values { + builder.WriteString(key) + builder.WriteString(": ") + builder.WriteString(value) + builder.WriteString("\n") + } + } + if len(body) > 0 { + builder.WriteString("\n") + if !bytes.ContainsFunc(body, func(r rune) bool { + return !unicode.IsPrint(r) && !unicode.IsSpace(r) + }) { + builder.Write(body) + } else { + builder.WriteString("(body not printable)") + } + } + e.logger.InfoContext(ctx, "request: ", builder.String()) +} + +func (e *Engine) printResponse(ctx context.Context, request *http.Request, response *http.Response, body []byte) { + var builder strings.Builder + builder.WriteString(F.ToString(response.Proto, " ", response.Status, " ", request.URL)) + builder.WriteString("\n") + for key, values := range response.Header { + for _, value := range values { + builder.WriteString(key) + builder.WriteString(": ") + builder.WriteString(value) + builder.WriteString("\n") + } + } + if len(body) > 0 { + builder.WriteString("\n") + if !bytes.ContainsFunc(body, func(r rune) bool { + return !unicode.IsPrint(r) && !unicode.IsSpace(r) + }) { + builder.Write(body) + } else { + builder.WriteString("(body not printable)") + } + } + e.logger.InfoContext(ctx, "response: ", builder.String()) +} + +type simpleResponseWriter struct { + statusCode int + header http.Header + body bytes.Buffer +} + +func (w *simpleResponseWriter) Build(request *http.Request) *http.Response { + return &http.Response{ + StatusCode: w.statusCode, + Status: http.StatusText(w.statusCode), + Proto: request.Proto, + ProtoMajor: request.ProtoMajor, + ProtoMinor: request.ProtoMinor, + Header: w.header, + Body: io.NopCloser(&w.body), + } +} + +func (w *simpleResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *simpleResponseWriter) Write(b []byte) (int, error) { + return w.body.Write(b) +} + +func (w *simpleResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode +} diff --git a/option/certificate.go b/option/certificate.go index ab524b99..2206a787 100644 --- a/option/certificate.go +++ b/option/certificate.go @@ -11,6 +11,13 @@ type _CertificateOptions struct { Certificate badoption.Listable[string] `json:"certificate,omitempty"` CertificatePath badoption.Listable[string] `json:"certificate_path,omitempty"` CertificateDirectoryPath badoption.Listable[string] `json:"certificate_directory_path,omitempty"` + TLSDecryption *TLSDecryptionOptions `json:"tls_decryption,omitempty"` +} + +type TLSDecryptionOptions struct { + Enabled bool `json:"enabled,omitempty"` + KeyPair string `json:"key_pair_p12,omitempty"` + KeyPairPassword string `json:"key_pair_p12_password,omitempty"` } type CertificateOptions _CertificateOptions diff --git a/option/mitm.go b/option/mitm.go new file mode 100644 index 00000000..171c1057 --- /dev/null +++ b/option/mitm.go @@ -0,0 +1,19 @@ +package option + +import ( + "github.com/sagernet/sing/common/json/badoption" +) + +type MITMOptions struct { + Enabled bool `json:"enabled,omitempty"` + HTTP2Enabled bool `json:"http2_enabled,omitempty"` +} + +type MITMRouteOptions struct { + Enabled bool `json:"enabled,omitempty"` + Print bool `json:"print,omitempty"` + SurgeURLRewrite badoption.Listable[SurgeURLRewriteLine] `json:"surge_url_rewrite,omitempty"` + SurgeHeaderRewrite badoption.Listable[SurgeHeaderRewriteLine] `json:"surge_header_rewrite,omitempty"` + SurgeBodyRewrite badoption.Listable[SurgeBodyRewriteLine] `json:"surge_body_rewrite,omitempty"` + SurgeMapLocal badoption.Listable[SurgeMapLocalLine] `json:"surge_map_local,omitempty"` +} diff --git a/option/mitm_surge_urlrewrite.go b/option/mitm_surge_urlrewrite.go new file mode 100644 index 00000000..ffec1917 --- /dev/null +++ b/option/mitm_surge_urlrewrite.go @@ -0,0 +1,449 @@ +package option + +import ( + "encoding/base64" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "unicode" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/json" +) + +type SurgeURLRewriteLine struct { + Pattern *regexp.Regexp + Destination *url.URL + Redirect bool + Reject bool +} + +func (l SurgeURLRewriteLine) String() string { + var fields []string + fields = append(fields, l.Pattern.String()) + if l.Reject { + fields = append(fields, "_") + } else { + fields = append(fields, l.Destination.String()) + } + switch { + case l.Redirect: + fields = append(fields, "302") + case l.Reject: + fields = append(fields, "reject") + default: + fields = append(fields, "header") + } + return encodeSurgeKeys(fields) +} + +func (l SurgeURLRewriteLine) MarshalJSON() ([]byte, error) { + return json.Marshal(l.String()) +} + +func (l *SurgeURLRewriteLine) UnmarshalJSON(bytes []byte) error { + var stringValue string + err := json.Unmarshal(bytes, &stringValue) + if err != nil { + return err + } + fields, err := surgeFields(stringValue) + if err != nil { + return E.Cause(err, "invalid surge_url_rewrite line: ", stringValue) + } else if len(fields) < 2 || len(fields) > 3 { + return E.New("invalid surge_url_rewrite line: ", stringValue) + } + pattern, err := regexp.Compile(fields[0].Key) + if err != nil { + return E.Cause(err, "invalid surge_url_rewrite line: invalid pattern: ", stringValue) + } + l.Pattern = pattern + l.Destination, err = url.Parse(fields[1].Key) + if err != nil { + return E.Cause(err, "invalid surge_url_rewrite line: invalid destination: ", stringValue) + } + if len(fields) == 3 { + switch fields[2].Key { + case "header": + case "302": + l.Redirect = true + case "reject": + l.Reject = true + default: + return E.New("invalid surge_url_rewrite line: invalid action: ", stringValue) + } + } + return nil +} + +type SurgeHeaderRewriteLine struct { + Response bool + Pattern *regexp.Regexp + Add bool + Delete bool + Replace bool + ReplaceRegex bool + Key string + Match *regexp.Regexp + Value string +} + +func (l SurgeHeaderRewriteLine) String() string { + var fields []string + if !l.Response { + fields = append(fields, "http-request") + } else { + fields = append(fields, "http-response") + } + fields = append(fields, l.Pattern.String()) + if l.Add { + fields = append(fields, "header-add") + } else if l.Delete { + fields = append(fields, "header-del") + } else if l.Replace { + fields = append(fields, "header-replace") + } else if l.ReplaceRegex { + fields = append(fields, "header-replace-regex") + } + fields = append(fields, l.Key) + if l.Add || l.Replace { + fields = append(fields, l.Value) + } else if l.ReplaceRegex { + fields = append(fields, l.Match.String(), l.Value) + } + return encodeSurgeKeys(fields) +} + +func (l SurgeHeaderRewriteLine) MarshalJSON() ([]byte, error) { + return json.Marshal(l.String()) +} + +func (l *SurgeHeaderRewriteLine) UnmarshalJSON(bytes []byte) error { + var stringValue string + err := json.Unmarshal(bytes, &stringValue) + if err != nil { + return err + } + fields, err := surgeFields(stringValue) + if err != nil { + return E.Cause(err, "invalid surge_header_rewrite line: ", stringValue) + } else if len(fields) < 4 { + return E.New("invalid surge_header_rewrite line: ", stringValue) + } + switch fields[0].Key { + case "http-request": + case "http-response": + l.Response = true + default: + return E.New("invalid surge_header_rewrite line: invalid type: ", stringValue) + } + l.Pattern, err = regexp.Compile(fields[1].Key) + if err != nil { + return E.Cause(err, "invalid surge_header_rewrite line: invalid pattern: ", stringValue) + } + switch fields[2].Key { + case "header-add": + l.Add = true + if len(fields) != 5 { + return E.New("invalid surge_header_rewrite line: " + stringValue) + } + l.Key = fields[3].Key + l.Value = fields[4].Key + case "header-del": + l.Delete = true + l.Key = fields[3].Key + case "header-replace": + l.Replace = true + if len(fields) != 5 { + return E.New("invalid surge_header_rewrite line: " + stringValue) + } + l.Key = fields[3].Key + l.Value = fields[4].Key + case "header-replace-regex": + l.ReplaceRegex = true + if len(fields) != 6 { + return E.New("invalid surge_header_rewrite line: " + stringValue) + } + l.Key = fields[3].Key + l.Match, err = regexp.Compile(fields[4].Key) + if err != nil { + return E.Cause(err, "invalid surge_header_rewrite line: invalid match: ", stringValue) + } + l.Value = fields[5].Key + default: + return E.New("invalid surge_header_rewrite line: invalid action: ", stringValue) + } + return nil +} + +type SurgeBodyRewriteLine struct { + Response bool + Pattern *regexp.Regexp + Match []*regexp.Regexp + Replace []string +} + +func (l SurgeBodyRewriteLine) String() string { + var fields []string + if !l.Response { + fields = append(fields, "http-request") + } else { + fields = append(fields, "http-response") + } + for i := 0; i < len(l.Match); i += 2 { + fields = append(fields, l.Match[i].String(), l.Replace[i]) + } + return strings.Join(fields, " ") +} + +func (l SurgeBodyRewriteLine) MarshalJSON() ([]byte, error) { + return json.Marshal(l.String()) +} + +func (l *SurgeBodyRewriteLine) UnmarshalJSON(bytes []byte) error { + var stringValue string + err := json.Unmarshal(bytes, &stringValue) + if err != nil { + return err + } + fields, err := surgeFields(stringValue) + if err != nil { + return E.Cause(err, "invalid surge_body_rewrite line: ", stringValue) + } else if len(fields) < 4 { + return E.New("invalid surge_body_rewrite line: ", stringValue) + } else if len(fields)%2 != 0 { + return E.New("invalid surge_body_rewrite line: ", stringValue) + } + switch fields[0].Key { + case "http-request": + case "http-response": + l.Response = true + default: + return E.New("invalid surge_body_rewrite line: invalid type: ", stringValue) + } + l.Pattern, err = regexp.Compile(fields[1].Key) + for i := 2; i < len(fields); i += 2 { + var match *regexp.Regexp + match, err = regexp.Compile(fields[i].Key) + if err != nil { + return E.Cause(err, "invalid surge_body_rewrite line: invalid match: ", stringValue) + } + l.Match = append(l.Match, match) + l.Replace = append(l.Replace, fields[i+1].Key) + } + return nil +} + +type SurgeMapLocalLine struct { + Pattern *regexp.Regexp + StatusCode int + File bool + Text bool + TinyGif bool + Base64 bool + Data string + Base64Data []byte + Headers http.Header +} + +func (l SurgeMapLocalLine) String() string { + var fields []surgeField + fields = append(fields, surgeField{Key: l.Pattern.String()}) + if l.File { + fields = append(fields, surgeField{Key: "data-type", Value: "file"}) + fields = append(fields, surgeField{Key: "data", Value: l.Data}) + } else if l.Text { + fields = append(fields, surgeField{Key: "data-type", Value: "text"}) + fields = append(fields, surgeField{Key: "data", Value: l.Data}) + } else if l.TinyGif { + fields = append(fields, surgeField{Key: "data-type", Value: "tiny-gif"}) + } else if l.Base64 { + fields = append(fields, surgeField{Key: "data-type", Value: "base64"}) + fields = append(fields, surgeField{Key: "data-type", Value: base64.StdEncoding.EncodeToString(l.Base64Data)}) + } + if l.StatusCode != 0 { + fields = append(fields, surgeField{Key: "status-code", Value: F.ToString(l.StatusCode), ValueSet: true}) + } + if len(l.Headers) > 0 { + var headers []string + for key, values := range l.Headers { + for _, value := range values { + headers = append(headers, key+":"+value) + } + } + fields = append(fields, surgeField{Key: "headers", Value: strings.Join(headers, "|")}) + } + return encodeSurgeFields(fields) +} + +func (l SurgeMapLocalLine) MarshalJSON() ([]byte, error) { + return json.Marshal(l.String()) +} + +func (l *SurgeMapLocalLine) UnmarshalJSON(bytes []byte) error { + var stringValue string + err := json.Unmarshal(bytes, &stringValue) + if err != nil { + return err + } + fields, err := surgeFields(stringValue) + if err != nil { + return E.Cause(err, "invalid surge_map_local line: ", stringValue) + } else if len(fields) < 1 { + return E.New("invalid surge_map_local line: ", stringValue) + } + l.Pattern, err = regexp.Compile(fields[0].Key) + if err != nil { + return E.Cause(err, "invalid surge_map_local line: invalid pattern: ", stringValue) + } + dataTypeField := common.Find(fields, func(it surgeField) bool { + return it.Key == "data-type" + }) + if !dataTypeField.ValueSet { + return E.New("invalid surge_map_local line: missing data-type: ", stringValue) + } + switch dataTypeField.Value { + case "file": + l.File = true + case "text": + l.Text = true + case "tiny-gif": + l.TinyGif = true + case "base64": + l.Base64 = true + default: + return E.New("unsupported data-type ", dataTypeField.Value) + } + for i := 1; i < len(fields); i++ { + switch fields[i].Key { + case "data-type": + continue + case "data": + if l.File { + l.Data = fields[i].Value + } else if l.Text { + l.Data = fields[i].Value + } else if l.Base64 { + l.Base64Data, err = base64.StdEncoding.DecodeString(fields[i].Value) + if err != nil { + return E.New("invalid surge_map_local line: invalid base64 data: ", stringValue) + } + } + case "status-code": + statusCode, err := strconv.ParseInt(fields[i].Value, 10, 16) + if err != nil { + return E.New("invalid surge_map_local line: invalid status code: ", stringValue) + } + l.StatusCode = int(statusCode) + case "header": + headers := make(http.Header) + for _, headerLine := range strings.Split(fields[i].Value, "|") { + if !strings.Contains(headerLine, ":") { + return E.New("invalid surge_map_local line: headers: missing `:` in item: ", stringValue, ": ", headerLine) + } + headers.Add(common.SubstringBefore(headerLine, ":"), common.SubstringAfter(headerLine, ":")) + } + l.Headers = headers + default: + return E.New("invalid surge_map_local line: unknown options: ", fields[i].Key) + } + } + return nil +} + +type surgeField struct { + Key string + Value string + ValueSet bool +} + +func encodeSurgeKeys(keys []string) string { + keys = common.Map(keys, func(it string) string { + if strings.ContainsFunc(it, unicode.IsSpace) { + return "\"" + it + "\"" + } else { + return it + } + }) + return strings.Join(keys, " ") +} + +func encodeSurgeFields(fields []surgeField) string { + return strings.Join(common.Map(fields, func(it surgeField) string { + if !it.ValueSet { + if strings.ContainsFunc(it.Key, unicode.IsSpace) { + return "\"" + it.Key + "\"" + } else { + return it.Key + } + } else { + if strings.ContainsFunc(it.Value, unicode.IsSpace) { + return it.Key + "=\"" + it.Value + "\"" + } else { + return it.Key + "=" + it.Value + } + } + }), " ") +} + +func surgeFields(s string) ([]surgeField, error) { + var ( + fields []surgeField + currentField *surgeField + ) + for _, field := range strings.Fields(s) { + if currentField != nil { + field = " " + field + if strings.HasSuffix(field, "\"") { + field = field[:len(field)-1] + if !currentField.ValueSet { + currentField.Key += field + } else { + currentField.Value += field + } + fields = append(fields, *currentField) + currentField = nil + } else { + if !currentField.ValueSet { + currentField.Key += field + } else { + currentField.Value += field + } + } + continue + } + if !strings.Contains(field, "=") { + if strings.HasPrefix(field, "\"") { + field = field[1:] + if strings.HasSuffix(field, "\"") { + field = field[:len(field)-1] + } else { + currentField = &surgeField{Key: field} + continue + } + } + fields = append(fields, surgeField{Key: field}) + } else { + key := common.SubstringBefore(field, "=") + value := common.SubstringAfter(field, "=") + if strings.HasPrefix(value, "\"") { + value = value[1:] + if strings.HasSuffix(field, "\"") { + value = value[:len(value)-1] + } else { + currentField = &surgeField{Key: key, Value: value, ValueSet: true} + continue + } + } + fields = append(fields, surgeField{Key: key, Value: value, ValueSet: true}) + } + } + if currentField != nil { + return nil, E.New("invalid surge fields line: ", s) + } + return fields, nil +} diff --git a/option/options.go b/option/options.go index 168074ed..ebd49fba 100644 --- a/option/options.go +++ b/option/options.go @@ -12,13 +12,14 @@ type _Options struct { Schema string `json:"$schema,omitempty"` Log *LogOptions `json:"log,omitempty"` DNS *DNSOptions `json:"dns,omitempty"` - NTP *NTPOptions `json:"ntp,omitempty"` - Certificate *CertificateOptions `json:"certificate,omitempty"` Endpoints []Endpoint `json:"endpoints,omitempty"` Inbounds []Inbound `json:"inbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"` Route *RouteOptions `json:"route,omitempty"` Experimental *ExperimentalOptions `json:"experimental,omitempty"` + NTP *NTPOptions `json:"ntp,omitempty"` + Certificate *CertificateOptions `json:"certificate,omitempty"` + MITM *MITMOptions `json:"mitm,omitempty"` } type Options _Options diff --git a/option/rule_action.go b/option/rule_action.go index 7c05dce6..473d371f 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -158,6 +158,8 @@ type RawRouteOptionsActionOptions struct { TLSFragment bool `json:"tls_fragment,omitempty"` TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"` + + MITM *MITMRouteOptions `json:"mitm,omitempty"` } type RouteOptionsActionOptions RawRouteOptionsActionOptions diff --git a/route/conn.go b/route/conn.go index c2a2eab9..33c02a44 100644 --- a/route/conn.go +++ b/route/conn.go @@ -24,23 +24,31 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" + "github.com/sagernet/sing/service" ) var _ adapter.ConnectionManager = (*ConnectionManager)(nil) type ConnectionManager struct { + ctx context.Context logger logger.ContextLogger + mitm adapter.MITMEngine access sync.Mutex connections list.List[io.Closer] } -func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager { +func NewConnectionManager(ctx context.Context, logger logger.ContextLogger) *ConnectionManager { return &ConnectionManager{ + ctx: ctx, logger: logger, } } func (m *ConnectionManager) Start(stage adapter.StartStage) error { + switch stage { + case adapter.StartStateInitialize: + m.mitm = service.FromContext[adapter.MITMEngine](m.ctx) + } return nil } @@ -55,6 +63,14 @@ func (m *ConnectionManager) Close() error { } func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + if metadata.MITM != nil && metadata.MITM.Enabled { + if m.mitm == nil { + m.logger.WarnContext(ctx, "MITM disabled") + } else { + m.mitm.NewConnection(ctx, this, conn, metadata, onClose) + return + } + } ctx = adapter.WithContext(ctx, &metadata) var ( remoteConn net.Conn diff --git a/route/route.go b/route/route.go index 531ad039..72ecbd34 100644 --- a/route/route.go +++ b/route/route.go @@ -458,6 +458,9 @@ match: metadata.TLSFragment = true metadata.TLSFragmentFallbackDelay = routeOptions.TLSFragmentFallbackDelay } + if routeOptions.MITM != nil && routeOptions.MITM.Enabled { + metadata.MITM = routeOptions.MITM + } } switch action := currentRule.Action().(type) { case *rule.RuleActionSniff: diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index f49baca6..73bc2bae 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -40,6 +40,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti UDPConnect: action.RouteOptions.UDPConnect, TLSFragment: action.RouteOptions.TLSFragment, TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay), + MITM: action.RouteOptions.MITM, }, }, nil case C.RuleActionTypeRouteOptions: @@ -53,6 +54,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout), TLSFragment: action.RouteOptionsOptions.TLSFragment, TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay), + MITM: action.RouteOptionsOptions.MITM, }, nil case C.RuleActionTypeDirect: directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false) @@ -152,15 +154,7 @@ func (r *RuleActionRoute) Type() string { func (r *RuleActionRoute) String() string { var descriptions []string descriptions = append(descriptions, r.Outbound) - if r.UDPDisableDomainUnmapping { - descriptions = append(descriptions, "udp-disable-domain-unmapping") - } - if r.UDPConnect { - descriptions = append(descriptions, "udp-connect") - } - if r.TLSFragment { - descriptions = append(descriptions, "tls-fragment") - } + descriptions = append(descriptions, r.Descriptions()...) return F.ToString("route(", strings.Join(descriptions, ","), ")") } @@ -176,13 +170,14 @@ type RuleActionRouteOptions struct { UDPTimeout time.Duration TLSFragment bool TLSFragmentFallbackDelay time.Duration + MITM *option.MITMRouteOptions } func (r *RuleActionRouteOptions) Type() string { return C.RuleActionTypeRouteOptions } -func (r *RuleActionRouteOptions) String() string { +func (r *RuleActionRouteOptions) Descriptions() []string { var descriptions []string if r.OverrideAddress.IsValid() { descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.AddrString())) @@ -209,9 +204,22 @@ func (r *RuleActionRouteOptions) String() string { descriptions = append(descriptions, "udp-connect") } if r.UDPTimeout > 0 { - descriptions = append(descriptions, "udp-timeout") + descriptions = append(descriptions, F.ToString("udp-timeout=", r.UDPTimeout)) } - return F.ToString("route-options(", strings.Join(descriptions, ","), ")") + if r.TLSFragment { + descriptions = append(descriptions, "tls-fragment") + if r.TLSFragmentFallbackDelay > 0 { + descriptions = append(descriptions, F.ToString("tls-fragment-fallbac-delay=", r.TLSFragmentFallbackDelay.String())) + } + } + if r.MITM != nil && r.MITM.Enabled { + descriptions = append(descriptions, "mitm") + } + return descriptions +} + +func (r *RuleActionRouteOptions) String() string { + return F.ToString("route-options(", strings.Join(r.Descriptions(), ","), ")") } type RuleActionDNSRoute struct {