From 5cca8893c9a351a1834b03b891cb46ff00d1e5e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 20 Mar 2025 09:12:48 +0800 Subject: [PATCH] Add Surge MITM and scripts --- .gitignore | 1 + .golangci.yml | 1 + .goreleaser.yaml | 2 + Makefile | 2 +- adapter/certificate.go | 3 + adapter/experimental.go | 4 + adapter/inbound.go | 5 + adapter/lifecycle.go | 7 +- adapter/mitm.go | 13 + adapter/script.go | 54 + box.go | 51 +- cmd/internal/build_libbox/main.go | 2 +- cmd/sing-box/cmd_generate_ca.go | 121 ++ cmd/sing-box/cmd_tools.go | 38 - cmd/sing-box/cmd_tools_connect.go | 73 -- cmd/sing-box/cmd_tools_fetch.go | 115 -- cmd/sing-box/cmd_tools_fetch_http3.go | 36 - cmd/sing-box/cmd_tools_fetch_http3_stub.go | 18 - cmd/sing-box/cmd_tools_install_ca.go | 108 ++ cmd/sing-box/cmd_tools_synctime.go | 12 +- common/certificate/store.go | 31 + common/sniff/http.go | 1 + common/sniff/tls.go | 1 + common/tls/mkcert.go | 38 +- constant/script.go | 7 + experimental/cachefile/cache.go | 77 +- experimental/clashapi/mitm.go | 186 +++ experimental/clashapi/server.go | 1 + experimental/libbox/platform/interface.go | 5 + go.mod | 13 +- go.sum | 30 +- log/log.go | 6 +- mitm/constants.go | 11 + mitm/engine.go | 1099 +++++++++++++++++ option/certificate.go | 7 + option/mitm.go | 31 + option/mitm_surge_urlrewrite.go | 449 +++++++ option/options.go | 6 +- option/rule_action.go | 2 + option/script.go | 128 ++ route/conn.go | 18 +- route/route.go | 3 + route/rule/rule_action.go | 32 +- script/jsc/array.go | 23 + script/jsc/array_test.go | 18 + script/jsc/assert.go | 124 ++ script/jsc/class.go | 192 +++ script/jsc/headers.go | 56 + script/jsc/headers_test.go | 31 + script/jsc/iterator.go | 36 + script/jsc/time.go | 18 + script/jsc/time_test.go | 20 + script/jstest/assert.js | 83 ++ script/jstest/test.go | 21 + script/manager.go | 118 ++ script/manager_stub.go | 43 + script/modules/boxctx/context.go | 50 + script/modules/boxctx/module.go | 35 + script/modules/console/console.go | 281 +++++ script/modules/console/context.go | 3 + script/modules/console/module.go | 34 + script/modules/eventloop/eventloop.go | 489 ++++++++ script/modules/require/module.go | 231 ++++ script/modules/require/resolve.go | 277 +++++ script/modules/sgnotification/module.go | 111 ++ script/modules/surge/environment.go | 65 + script/modules/surge/http.go | 150 +++ script/modules/surge/module.go | 63 + script/modules/surge/notification.go | 120 ++ script/modules/surge/persistent_store.go | 78 ++ script/modules/surge/script.go | 32 + script/modules/surge/utils.go | 50 + script/modules/url/escape.go | 55 + script/modules/url/module.go | 41 + script/modules/url/module_test.go | 37 + .../url/testdata/url_search_params_test.js | 385 ++++++ script/modules/url/testdata/url_test.js | 229 ++++ script/modules/url/url.go | 315 +++++ script/modules/url/url_search_params.go | 244 ++++ script/runtime.go | 49 + script/script.go | 22 + script/script_surge.go | 347 ++++++ script/source.go | 33 + script/source_local.go | 94 ++ script/source_remote.go | 226 ++++ 85 files changed, 7422 insertions(+), 355 deletions(-) create mode 100644 adapter/mitm.go create mode 100644 adapter/script.go create mode 100644 cmd/sing-box/cmd_generate_ca.go delete mode 100644 cmd/sing-box/cmd_tools_connect.go delete mode 100644 cmd/sing-box/cmd_tools_fetch.go delete mode 100644 cmd/sing-box/cmd_tools_fetch_http3.go delete mode 100644 cmd/sing-box/cmd_tools_fetch_http3_stub.go create mode 100644 cmd/sing-box/cmd_tools_install_ca.go create mode 100644 constant/script.go create mode 100644 experimental/clashapi/mitm.go create mode 100644 mitm/constants.go create mode 100644 mitm/engine.go create mode 100644 option/mitm.go create mode 100644 option/mitm_surge_urlrewrite.go create mode 100644 option/script.go create mode 100644 script/jsc/array.go create mode 100644 script/jsc/array_test.go create mode 100644 script/jsc/assert.go create mode 100644 script/jsc/class.go create mode 100644 script/jsc/headers.go create mode 100644 script/jsc/headers_test.go create mode 100644 script/jsc/iterator.go create mode 100644 script/jsc/time.go create mode 100644 script/jsc/time_test.go create mode 100644 script/jstest/assert.js create mode 100644 script/jstest/test.go create mode 100644 script/manager.go create mode 100644 script/manager_stub.go create mode 100644 script/modules/boxctx/context.go create mode 100644 script/modules/boxctx/module.go create mode 100644 script/modules/console/console.go create mode 100644 script/modules/console/context.go create mode 100644 script/modules/console/module.go create mode 100644 script/modules/eventloop/eventloop.go create mode 100644 script/modules/require/module.go create mode 100644 script/modules/require/resolve.go create mode 100644 script/modules/sgnotification/module.go create mode 100644 script/modules/surge/environment.go create mode 100644 script/modules/surge/http.go create mode 100644 script/modules/surge/module.go create mode 100644 script/modules/surge/notification.go create mode 100644 script/modules/surge/persistent_store.go create mode 100644 script/modules/surge/script.go create mode 100644 script/modules/surge/utils.go create mode 100644 script/modules/url/escape.go create mode 100644 script/modules/url/module.go create mode 100644 script/modules/url/module_test.go create mode 100644 script/modules/url/testdata/url_search_params_test.js create mode 100644 script/modules/url/testdata/url_test.js create mode 100644 script/modules/url/url.go create mode 100644 script/modules/url/url_search_params.go create mode 100644 script/runtime.go create mode 100644 script/script.go create mode 100644 script/script_surge.go create mode 100644 script/source.go create mode 100644 script/source_local.go create mode 100644 script/source_remote.go diff --git a/.gitignore b/.gitignore index 60eb851e..b0742392 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ /.idea/ /vendor/ /*.json +/*.js /*.srs /*.db /site/ diff --git a/.golangci.yml b/.golangci.yml index d212ebb2..ca0cbe14 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -31,6 +31,7 @@ run: - with_reality_server - with_acme - with_clash_api + - with_script issues: exclude-dirs: diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 4a7efcf4..ba2faaa5 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -21,6 +21,7 @@ builds: - with_acme - with_clash_api - with_tailscale + - with_script env: - CGO_ENABLED=0 - GOTOOLCHAIN=local @@ -51,6 +52,7 @@ builds: - with_acme - with_clash_api - with_tailscale + - with_script env: - CGO_ENABLED=0 - GOROOT={{ .Env.GOPATH }}/go_legacy diff --git a/Makefile b/Makefile index cfd8cc12..0fd920e9 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ NAME = sing-box COMMIT = $(shell git rev-parse --short HEAD) -TAGS ?= with_gvisor,with_dhcp,with_wireguard,with_reality_server,with_clash_api,with_quic,with_utls,with_tailscale +TAGS ?= with_gvisor,with_dhcp,with_wireguard,with_reality_server,with_clash_api,with_quic,with_utls,with_tailscale,with_script TAGS_TEST ?= with_gvisor,with_quic,with_wireguard,with_grpc,with_utls,with_reality_server GOHOSTOS = $(shell go env GOHOSTOS) diff --git a/adapter/certificate.go b/adapter/certificate.go index 0998e130..11eeb196 100644 --- a/adapter/certificate.go +++ b/adapter/certificate.go @@ -10,6 +10,9 @@ import ( type CertificateStore interface { LifecycleService Pool() *x509.CertPool + TLSDecryptionEnabled() bool + TLSDecryptionCertificate() *x509.Certificate + TLSDecryptionPrivateKey() any } func RootPoolFromContext(ctx context.Context) *x509.CertPool { diff --git a/adapter/experimental.go b/adapter/experimental.go index de01d7be..ec29a443 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -52,6 +52,10 @@ type CacheFile interface { StoreGroupExpand(group string, expand bool) error LoadRuleSet(tag string) *SavedBinary SaveRuleSet(tag string, set *SavedBinary) error + LoadScript(tag string) *SavedBinary + SaveScript(tag string, script *SavedBinary) error + SurgePersistentStoreRead(key string) string + SurgePersistentStoreWrite(key string, value string) error } type SavedBinary struct { diff --git a/adapter/inbound.go b/adapter/inbound.go index 1218c049..869cdec9 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -2,6 +2,8 @@ package adapter import ( "context" + "crypto/tls" + "net/http" "net/netip" "time" @@ -58,6 +60,8 @@ type InboundContext struct { Client string SniffContext any PacketSniffError error + HTTPRequest *http.Request + ClientHello *tls.ClientHelloInfo // cache @@ -74,6 +78,7 @@ type InboundContext struct { UDPTimeout time.Duration TLSFragment bool TLSFragmentFallbackDelay time.Duration + MITM *option.MITMRouteOptions NetworkStrategy *C.NetworkStrategy NetworkType []C.InterfaceType diff --git a/adapter/lifecycle.go b/adapter/lifecycle.go index aff9fadb..9e522141 100644 --- a/adapter/lifecycle.go +++ b/adapter/lifecycle.go @@ -1,6 +1,8 @@ package adapter -import E "github.com/sagernet/sing/common/exceptions" +import ( + E "github.com/sagernet/sing/common/exceptions" +) type StartStage uint8 @@ -45,6 +47,9 @@ type LifecycleService interface { func Start(stage StartStage, services ...Lifecycle) error { for _, service := range services { + if service == nil { + continue + } err := service.Start(stage) if err != nil { return err diff --git a/adapter/mitm.go b/adapter/mitm.go new file mode 100644 index 00000000..450468a9 --- /dev/null +++ b/adapter/mitm.go @@ -0,0 +1,13 @@ +package adapter + +import ( + "context" + "net" + + N "github.com/sagernet/sing/common/network" +) + +type MITMEngine interface { + Lifecycle + NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata InboundContext, onClose N.CloseHandlerFunc) +} diff --git a/adapter/script.go b/adapter/script.go new file mode 100644 index 00000000..3967ed92 --- /dev/null +++ b/adapter/script.go @@ -0,0 +1,54 @@ +package adapter + +import ( + "context" + "net/http" + "sync" + "time" +) + +type ScriptManager interface { + Lifecycle + Scripts() []Script + Script(name string) (Script, bool) + SurgeCache() *SurgeInMemoryCache +} + +type SurgeInMemoryCache struct { + sync.RWMutex + Data map[string]string +} + +type Script interface { + Type() string + Tag() string + StartContext(ctx context.Context, startContext *HTTPStartContext) error + PostStart() error + Close() error +} + +type SurgeScript interface { + Script + ExecuteGeneric(ctx context.Context, scriptType string, timeout time.Duration, arguments []string) error + ExecuteHTTPRequest(ctx context.Context, timeout time.Duration, request *http.Request, body []byte, binaryBody bool, arguments []string) (*HTTPRequestScriptResult, error) + ExecuteHTTPResponse(ctx context.Context, timeout time.Duration, request *http.Request, response *http.Response, body []byte, binaryBody bool, arguments []string) (*HTTPResponseScriptResult, error) +} + +type HTTPRequestScriptResult struct { + URL string + Headers http.Header + Body []byte + Response *HTTPRequestScriptResponse +} + +type HTTPRequestScriptResponse struct { + Status int + Headers http.Header + Body []byte +} + +type HTTPResponseScriptResult struct { + Status int + Headers http.Header + Body []byte +} diff --git a/box.go b/box.go index 0f176474..a7356ae3 100644 --- a/box.go +++ b/box.go @@ -23,9 +23,11 @@ import ( "github.com/sagernet/sing-box/experimental/cachefile" "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/mitm" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/protocol/direct" "github.com/sagernet/sing-box/route" + "github.com/sagernet/sing-box/script" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" @@ -48,6 +50,8 @@ type Box struct { dnsRouter *dns.Router connection *route.ConnectionManager router *route.Router + script *script.Manager + mitm adapter.MITMEngine //*mitm.Engine services []adapter.LifecycleService done chan struct{} } @@ -143,18 +147,12 @@ func New(options Options) (*Box, error) { } var services []adapter.LifecycleService - certificateOptions := common.PtrValueOrDefault(options.Certificate) - if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem || - len(certificateOptions.Certificate) > 0 || - len(certificateOptions.CertificatePath) > 0 || - len(certificateOptions.CertificateDirectoryPath) > 0 { - certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), certificateOptions) - if err != nil { - return nil, err - } - service.MustRegister[adapter.CertificateStore](ctx, certificateStore) - services = append(services, certificateStore) + certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), common.PtrValueOrDefault(options.Certificate)) + if err != nil { + return nil, err } + service.MustRegister[adapter.CertificateStore](ctx, certificateStore) + services = append(services, certificateStore) routeOptions := common.PtrValueOrDefault(options.Route) dnsOptions := common.PtrValueOrDefault(options.DNS) @@ -173,7 +171,7 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "initialize network manager") } service.MustRegister[adapter.NetworkManager](ctx, networkManager) - connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection")) + connectionManager := route.NewConnectionManager(ctx, logFactory.NewLogger("connection")) service.MustRegister[adapter.ConnectionManager](ctx, connectionManager) router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions) service.MustRegister[adapter.Router](ctx, router) @@ -181,8 +179,8 @@ func New(options Options) (*Box, error) { if err != nil { return nil, E.Cause(err, "initialize router") } - ntpOptions := common.PtrValueOrDefault(options.NTP) var timeService *tls.TimeServiceWrapper + ntpOptions := common.PtrValueOrDefault(options.NTP) if ntpOptions.Enabled { timeService = new(tls.TimeServiceWrapper) service.MustRegister[ntp.TimeService](ctx, timeService) @@ -296,6 +294,11 @@ func New(options Options) (*Box, error) { "local", option.LocalDNSServerOptions{}, ))) + scriptManager, err := script.NewManager(ctx, logFactory, options.Scripts) + if err != nil { + return nil, E.Cause(err, "initialize script manager") + } + service.MustRegister[adapter.ScriptManager](ctx, scriptManager) if platformInterface != nil { err = platformInterface.Initialize(networkManager) if err != nil { @@ -345,6 +348,16 @@ func New(options Options) (*Box, error) { timeService.TimeService = ntpService services = append(services, adapter.NewLifecycleService(ntpService, "ntp service")) } + mitmOptions := common.PtrValueOrDefault(options.MITM) + var mitmEngine adapter.MITMEngine + if mitmOptions.Enabled { + engine, err := mitm.NewEngine(ctx, logFactory.NewLogger("mitm"), mitmOptions) + if err != nil { + return nil, E.Cause(err, "create MITM engine") + } + service.MustRegister[adapter.MITMEngine](ctx, engine) + mitmEngine = engine + } return &Box{ network: networkManager, endpoint: endpointManager, @@ -354,6 +367,8 @@ func New(options Options) (*Box, error) { dnsRouter: dnsRouter, connection: connectionManager, router: router, + script: scriptManager, + mitm: mitmEngine, createdAt: createdAt, logFactory: logFactory, logger: logFactory.Logger(), @@ -412,11 +427,11 @@ func (s *Box) preStart() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.script, s.mitm, s.outbound, s.inbound, s.endpoint) if err != nil { return err } - err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router) + err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router, s.script, s.mitm) if err != nil { return err } @@ -440,7 +455,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.script, s.mitm, s.inbound, s.endpoint) if err != nil { return err } @@ -448,7 +463,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.script, s.mitm, s.outbound, s.inbound, s.endpoint) if err != nil { return err } @@ -467,7 +482,7 @@ func (s *Box) Close() error { close(s.done) } err := common.Close( - s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network, + s.inbound, s.outbound, s.endpoint, s.mitm, s.script, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network, ) for _, lifecycleService := range s.services { err = E.Append(err, lifecycleService.Close(), func(err error) error { diff --git a/cmd/internal/build_libbox/main.go b/cmd/internal/build_libbox/main.go index 8f82ff36..f349612b 100644 --- a/cmd/internal/build_libbox/main.go +++ b/cmd/internal/build_libbox/main.go @@ -59,7 +59,7 @@ func init() { sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -s -w -buildid=") debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag) - sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_utls", "with_clash_api") + sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_utls", "with_clash_api", "with_script") iosTags = append(iosTags, "with_dhcp", "with_low_memory", "with_conntrack") memcTags = append(memcTags, "with_tailscale") debugTags = append(debugTags, "debug") diff --git a/cmd/sing-box/cmd_generate_ca.go b/cmd/sing-box/cmd_generate_ca.go new file mode 100644 index 00000000..11bd6ceb --- /dev/null +++ b/cmd/sing-box/cmd_generate_ca.go @@ -0,0 +1,121 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json" + + "github.com/spf13/cobra" + "software.sslmate.com/src/go-pkcs12" +) + +var ( + flagGenerateCAName string + flagGenerateCAPKCS12Password string + flagGenerateOutput string +) + +var commandGenerateCAKeyPair = &cobra.Command{ + Use: "ca-keypair", + Short: "Generate CA key pair", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + err := generateCAKeyPair() + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandGenerateCAKeyPair.Flags().StringVarP(&flagGenerateCAName, "name", "n", "", "Set custom CA name") + commandGenerateCAKeyPair.Flags().StringVarP(&flagGenerateCAPKCS12Password, "p12-password", "p", "", "Set custom PKCS12 password") + commandGenerateCAKeyPair.Flags().StringVarP(&flagGenerateOutput, "output", "o", ".", "Set output directory") + commandGenerate.AddCommand(commandGenerateCAKeyPair) +} + +func generateCAKeyPair() error { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return err + } + spkiASN1, err := x509.MarshalPKIXPublicKey(privateKey.Public()) + var spki struct { + Algorithm pkix.AlgorithmIdentifier + SubjectPublicKey asn1.BitString + } + _, err = asn1.Unmarshal(spkiASN1, &spki) + if err != nil { + return err + } + skid := sha1.Sum(spki.SubjectPublicKey.Bytes) + var caName string + if flagGenerateCAName != "" { + caName = flagGenerateCAName + } else { + caName = "sing-box Generated CA " + strings.ToUpper(hex.EncodeToString(skid[:4])) + } + caTpl := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{caName}, + CommonName: caName, + }, + SubjectKeyId: skid[:], + NotAfter: time.Now().AddDate(10, 0, 0), + NotBefore: time.Now(), + KeyUsage: x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLenZero: true, + } + publicDer, err := x509.CreateCertificate(rand.Reader, caTpl, caTpl, privateKey.Public(), privateKey) + var caPassword string + if flagGenerateCAPKCS12Password != "" { + caPassword = flagGenerateCAPKCS12Password + } else { + caPassword = strings.ToUpper(hex.EncodeToString(skid[:4])) + } + caTpl.Raw = publicDer + p12Bytes, err := pkcs12.Modern.Encode(privateKey, caTpl, nil, caPassword) + if err != nil { + return err + } + privateDer, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return err + } + os.WriteFile(filepath.Join(flagGenerateOutput, caName+".pem"), pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicDer}), 0o644) + os.WriteFile(filepath.Join(flagGenerateOutput, caName+".private.pem"), pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateDer}), 0o644) + os.WriteFile(filepath.Join(flagGenerateOutput, caName+".crt"), publicDer, 0o644) + os.WriteFile(filepath.Join(flagGenerateOutput, caName+".p12"), p12Bytes, 0o644) + var tlsDecryptionOptions option.TLSDecryptionOptions + tlsDecryptionOptions.Enabled = true + tlsDecryptionOptions.KeyPair = base64.StdEncoding.EncodeToString(p12Bytes) + tlsDecryptionOptions.KeyPairPassword = caPassword + var certificateOptions option.CertificateOptions + certificateOptions.TLSDecryption = &tlsDecryptionOptions + encoder := json.NewEncoder(os.Stdout) + encoder.SetIndent("", " ") + return encoder.Encode(certificateOptions) +} diff --git a/cmd/sing-box/cmd_tools.go b/cmd/sing-box/cmd_tools.go index 55e5b458..5f5a0d71 100644 --- a/cmd/sing-box/cmd_tools.go +++ b/cmd/sing-box/cmd_tools.go @@ -1,13 +1,6 @@ package main import ( - "errors" - "os" - - "github.com/sagernet/sing-box" - E "github.com/sagernet/sing/common/exceptions" - N "github.com/sagernet/sing/common/network" - "github.com/spf13/cobra" ) @@ -19,36 +12,5 @@ var commandTools = &cobra.Command{ } func init() { - commandTools.PersistentFlags().StringVarP(&commandToolsFlagOutbound, "outbound", "o", "", "Use specified tag instead of default outbound") mainCommand.AddCommand(commandTools) } - -func createPreStartedClient() (*box.Box, error) { - options, err := readConfigAndMerge() - if err != nil { - if !(errors.Is(err, os.ErrNotExist) && len(configDirectories) == 0 && len(configPaths) == 1) || configPaths[0] != "config.json" { - return nil, err - } - } - instance, err := box.New(box.Options{Context: globalCtx, Options: options}) - if err != nil { - return nil, E.Cause(err, "create service") - } - err = instance.PreStart() - if err != nil { - return nil, E.Cause(err, "start service") - } - return instance, nil -} - -func createDialer(instance *box.Box, outboundTag string) (N.Dialer, error) { - if outboundTag == "" { - return instance.Outbound().Default(), nil - } else { - outbound, loaded := instance.Outbound().Outbound(outboundTag) - if !loaded { - return nil, E.New("outbound not found: ", outboundTag) - } - return outbound, nil - } -} diff --git a/cmd/sing-box/cmd_tools_connect.go b/cmd/sing-box/cmd_tools_connect.go deleted file mode 100644 index d352d533..00000000 --- a/cmd/sing-box/cmd_tools_connect.go +++ /dev/null @@ -1,73 +0,0 @@ -package main - -import ( - "context" - "os" - - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/task" - - "github.com/spf13/cobra" -) - -var commandConnectFlagNetwork string - -var commandConnect = &cobra.Command{ - Use: "connect
", - Short: "Connect to an address", - Args: cobra.ExactArgs(1), - Run: func(cmd *cobra.Command, args []string) { - err := connect(args[0]) - if err != nil { - log.Fatal(err) - } - }, -} - -func init() { - commandConnect.Flags().StringVarP(&commandConnectFlagNetwork, "network", "n", "tcp", "network type") - commandTools.AddCommand(commandConnect) -} - -func connect(address string) error { - switch N.NetworkName(commandConnectFlagNetwork) { - case N.NetworkTCP, N.NetworkUDP: - default: - return E.Cause(N.ErrUnknownNetwork, commandConnectFlagNetwork) - } - instance, err := createPreStartedClient() - if err != nil { - return err - } - defer instance.Close() - dialer, err := createDialer(instance, commandToolsFlagOutbound) - if err != nil { - return err - } - conn, err := dialer.DialContext(context.Background(), commandConnectFlagNetwork, M.ParseSocksaddr(address)) - if err != nil { - return E.Cause(err, "connect to server") - } - var group task.Group - group.Append("upload", func(ctx context.Context) error { - return common.Error(bufio.Copy(conn, os.Stdin)) - }) - group.Append("download", func(ctx context.Context) error { - return common.Error(bufio.Copy(os.Stdout, conn)) - }) - group.Cleanup(func() { - conn.Close() - }) - err = group.Run(context.Background()) - if E.IsClosed(err) { - log.Info(err) - } else { - log.Error(err) - } - return nil -} diff --git a/cmd/sing-box/cmd_tools_fetch.go b/cmd/sing-box/cmd_tools_fetch.go deleted file mode 100644 index 5ee3b875..00000000 --- a/cmd/sing-box/cmd_tools_fetch.go +++ /dev/null @@ -1,115 +0,0 @@ -package main - -import ( - "context" - "errors" - "io" - "net" - "net/http" - "net/url" - "os" - - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - - "github.com/spf13/cobra" -) - -var commandFetch = &cobra.Command{ - Use: "fetch", - Short: "Fetch an URL", - Args: cobra.MinimumNArgs(1), - Run: func(cmd *cobra.Command, args []string) { - err := fetch(args) - if err != nil { - log.Fatal(err) - } - }, -} - -func init() { - commandTools.AddCommand(commandFetch) -} - -var ( - httpClient *http.Client - http3Client *http.Client -) - -func fetch(args []string) error { - instance, err := createPreStartedClient() - if err != nil { - return err - } - defer instance.Close() - httpClient = &http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer, err := createDialer(instance, commandToolsFlagOutbound) - if err != nil { - return nil, err - } - return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - ForceAttemptHTTP2: true, - }, - } - defer httpClient.CloseIdleConnections() - if C.WithQUIC { - err = initializeHTTP3Client(instance) - if err != nil { - return err - } - defer http3Client.CloseIdleConnections() - } - for _, urlString := range args { - var parsedURL *url.URL - parsedURL, err = url.Parse(urlString) - if err != nil { - return err - } - switch parsedURL.Scheme { - case "": - parsedURL.Scheme = "http" - fallthrough - case "http", "https": - err = fetchHTTP(httpClient, parsedURL) - if err != nil { - return err - } - case "http3": - if !C.WithQUIC { - return C.ErrQUICNotIncluded - } - parsedURL.Scheme = "https" - err = fetchHTTP(http3Client, parsedURL) - if err != nil { - return err - } - default: - return E.New("unsupported scheme: ", parsedURL.Scheme) - } - } - return nil -} - -func fetchHTTP(httpClient *http.Client, parsedURL *url.URL) error { - request, err := http.NewRequest("GET", parsedURL.String(), nil) - if err != nil { - return err - } - request.Header.Add("User-Agent", "curl/7.88.0") - response, err := httpClient.Do(request) - if err != nil { - return err - } - defer response.Body.Close() - _, err = bufio.Copy(os.Stdout, response.Body) - if errors.Is(err, io.EOF) { - return nil - } - return err -} diff --git a/cmd/sing-box/cmd_tools_fetch_http3.go b/cmd/sing-box/cmd_tools_fetch_http3.go deleted file mode 100644 index b7a31a72..00000000 --- a/cmd/sing-box/cmd_tools_fetch_http3.go +++ /dev/null @@ -1,36 +0,0 @@ -//go:build with_quic - -package main - -import ( - "context" - "crypto/tls" - "net/http" - - "github.com/sagernet/quic-go" - "github.com/sagernet/quic-go/http3" - box "github.com/sagernet/sing-box" - "github.com/sagernet/sing/common/bufio" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -func initializeHTTP3Client(instance *box.Box) error { - dialer, err := createDialer(instance, commandToolsFlagOutbound) - if err != nil { - return err - } - http3Client = &http.Client{ - Transport: &http3.Transport{ - Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - destination := M.ParseSocksaddr(addr) - udpConn, dErr := dialer.DialContext(ctx, N.NetworkUDP, destination) - if dErr != nil { - return nil, dErr - } - return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsCfg, cfg) - }, - }, - } - return nil -} diff --git a/cmd/sing-box/cmd_tools_fetch_http3_stub.go b/cmd/sing-box/cmd_tools_fetch_http3_stub.go deleted file mode 100644 index ae13f54c..00000000 --- a/cmd/sing-box/cmd_tools_fetch_http3_stub.go +++ /dev/null @@ -1,18 +0,0 @@ -//go:build !with_quic - -package main - -import ( - "net/url" - "os" - - box "github.com/sagernet/sing-box" -) - -func initializeHTTP3Client(instance *box.Box) error { - return os.ErrInvalid -} - -func fetchHTTP3(parsedURL *url.URL) error { - return os.ErrInvalid -} diff --git a/cmd/sing-box/cmd_tools_install_ca.go b/cmd/sing-box/cmd_tools_install_ca.go new file mode 100644 index 00000000..80ec14b1 --- /dev/null +++ b/cmd/sing-box/cmd_tools_install_ca.go @@ -0,0 +1,108 @@ +package main + +import ( + "encoding/pem" + "errors" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/shell" + + "github.com/spf13/cobra" +) + +var commandInstallCACertificate = &cobra.Command{ + Use: "install-ca ", + Short: "Install CA certificate to system", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + err := installCACertificate(args[0]) + if err != nil { + log.Fatal(err) + } + }, +} + +func init() { + commandTools.AddCommand(commandInstallCACertificate) +} + +func installCACertificate(path string) error { + switch runtime.GOOS { + case "windows": + return shell.Exec("powershell", "-Command", "Import-Certificate -FilePath \""+path+"\" -CertStoreLocation Cert:\\LocalMachine\\Root").Attach().Run() + case "darwin": + return shell.Exec("sudo", "security", "add-trusted-cert", "-d", "-r", "trustRoot", "-k", "/Library/Keychains/System.keychain", path).Attach().Run() + case "linux": + updateCertPath, updateCertPathNotFoundErr := exec.LookPath("update-ca-certificates") + if updateCertPathNotFoundErr == nil { + publicDer, err := os.ReadFile(path) + if err != nil { + return err + } + err = os.MkdirAll("/usr/local/share/ca-certificates", 0o755) + if err != nil { + if errors.Is(err, os.ErrPermission) { + log.Info("Try running with sudo") + return shell.Exec("sudo", os.Args...).Attach().Run() + } + return err + } + fileName := filepath.Base(updateCertPath) + if !strings.HasSuffix(fileName, ".crt") { + fileName = fileName + ".crt" + } + filePath, _ := filepath.Abs(filepath.Join("/usr/local/share/ca-certificates", fileName)) + err = os.WriteFile(filePath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicDer}), 0o644) + if err != nil { + if errors.Is(err, os.ErrPermission) { + log.Info("Try running with sudo") + return shell.Exec("sudo", os.Args...).Attach().Run() + } + return err + } + log.Info("certificate written to " + filePath + "\n") + err = shell.Exec(updateCertPath).Attach().Run() + if err != nil { + return err + } + log.Info("certificate installed") + return nil + } + updateTrustPath, updateTrustPathNotFoundErr := exec.LookPath("update-ca-trust") + if updateTrustPathNotFoundErr == nil { + publicDer, err := os.ReadFile(path) + if err != nil { + return err + } + fileName := filepath.Base(updateTrustPath) + fileExt := filepath.Ext(path) + if fileExt != "" { + fileName = fileName[:len(fileName)-len(fileExt)] + } + filePath, _ := filepath.Abs(filepath.Join("/etc/pki/ca-trust/source/anchors/", fileName+".pem")) + err = os.WriteFile(filePath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicDer}), 0o644) + if err != nil { + if errors.Is(err, os.ErrPermission) { + log.Info("Try running with sudo") + return shell.Exec("sudo", os.Args...).Attach().Run() + } + return err + } + log.Info("certificate written to " + filePath + "\n") + err = shell.Exec(updateTrustPath, "extract").Attach().Run() + if err != nil { + return err + } + log.Info("certificate installed") + } + return E.New("update-ca-certificates or update-ca-trust not found") + default: + return E.New("unsupported operating system: ", runtime.GOOS) + } +} diff --git a/cmd/sing-box/cmd_tools_synctime.go b/cmd/sing-box/cmd_tools_synctime.go index 09d487ef..33e38743 100644 --- a/cmd/sing-box/cmd_tools_synctime.go +++ b/cmd/sing-box/cmd_tools_synctime.go @@ -8,6 +8,7 @@ import ( "github.com/sagernet/sing-box/log" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" "github.com/spf13/cobra" @@ -39,20 +40,11 @@ func init() { } func syncTime() error { - instance, err := createPreStartedClient() - if err != nil { - return err - } - dialer, err := createDialer(instance, commandToolsFlagOutbound) - if err != nil { - return err - } - defer instance.Close() serverAddress := M.ParseSocksaddr(commandSyncTimeFlagServer) if serverAddress.Port == 0 { serverAddress.Port = 123 } - response, err := ntp.Exchange(context.Background(), dialer, serverAddress) + response, err := ntp.Exchange(context.Background(), N.SystemDialer, serverAddress) if err != nil { return err } diff --git a/common/certificate/store.go b/common/certificate/store.go index 34f20019..d17f77b5 100644 --- a/common/certificate/store.go +++ b/common/certificate/store.go @@ -3,6 +3,7 @@ package certificate import ( "context" "crypto/x509" + "encoding/base64" "io/fs" "os" "path/filepath" @@ -16,6 +17,8 @@ import ( E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/service" + + "software.sslmate.com/src/go-pkcs12" ) var _ adapter.CertificateStore = (*Store)(nil) @@ -27,6 +30,9 @@ type Store struct { certificatePaths []string certificateDirectoryPaths []string watcher *fswatch.Watcher + tlsDecryptionEnabled bool + tlsDecryptionPrivateKey any + tlsDecryptionCertificate *x509.Certificate } func NewStore(ctx context.Context, logger logger.Logger, options option.CertificateOptions) (*Store, error) { @@ -90,6 +96,19 @@ func NewStore(ctx context.Context, logger logger.Logger, options option.Certific if err != nil { return nil, E.Cause(err, "initializing certificate store") } + if options.TLSDecryption != nil && options.TLSDecryption.Enabled { + pfxBytes, err := base64.StdEncoding.DecodeString(options.TLSDecryption.KeyPair) + if err != nil { + return nil, E.Cause(err, "decode key pair base64 bytes") + } + privateKey, certificate, err := pkcs12.Decode(pfxBytes, options.TLSDecryption.KeyPairPassword) + if err != nil { + return nil, E.Cause(err, "decode key pair") + } + store.tlsDecryptionEnabled = true + store.tlsDecryptionPrivateKey = privateKey + store.tlsDecryptionCertificate = certificate + } return store, nil } @@ -183,3 +202,15 @@ func isSameDirSymlink(f fs.DirEntry, dir string) bool { target, err := os.Readlink(filepath.Join(dir, f.Name())) return err == nil && !strings.Contains(target, "/") } + +func (s *Store) TLSDecryptionEnabled() bool { + return s.tlsDecryptionEnabled +} + +func (s *Store) TLSDecryptionCertificate() *x509.Certificate { + return s.tlsDecryptionCertificate +} + +func (s *Store) TLSDecryptionPrivateKey() any { + return s.tlsDecryptionPrivateKey +} diff --git a/common/sniff/http.go b/common/sniff/http.go index 0e6ab406..e7c6eb8c 100644 --- a/common/sniff/http.go +++ b/common/sniff/http.go @@ -18,5 +18,6 @@ func HTTPHost(_ context.Context, metadata *adapter.InboundContext, reader io.Rea } metadata.Protocol = C.ProtocolHTTP metadata.Domain = M.ParseSocksaddr(request.Host).AddrString() + metadata.HTTPRequest = request return nil } diff --git a/common/sniff/tls.go b/common/sniff/tls.go index 6fe430e2..27729fa2 100644 --- a/common/sniff/tls.go +++ b/common/sniff/tls.go @@ -21,6 +21,7 @@ func TLSClientHello(ctx context.Context, metadata *adapter.InboundContext, reade if clientHello != nil { metadata.Protocol = C.ProtocolTLS metadata.Domain = clientHello.ServerName + metadata.ClientHello = clientHello return nil } return err diff --git a/common/tls/mkcert.go b/common/tls/mkcert.go index 4e0ed102..fc9c4ab9 100644 --- a/common/tls/mkcert.go +++ b/common/tls/mkcert.go @@ -8,7 +8,10 @@ import ( "crypto/x509/pkix" "encoding/pem" "math/big" + "net" "time" + + M "github.com/sagernet/sing/common/metadata" ) func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string) (*tls.Certificate, error) { @@ -35,17 +38,30 @@ func GenerateCertificate(parent *x509.Certificate, parentKey any, timeFunc func( if err != nil { return } - template := &x509.Certificate{ - SerialNumber: serialNumber, - NotBefore: timeFunc().Add(time.Hour * -1), - NotAfter: expire, - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - Subject: pkix.Name{ - CommonName: serverName, - }, - DNSNames: []string{serverName}, + var template *x509.Certificate + if serverAddress := M.ParseAddr(serverName); serverAddress.IsValid() { + template = &x509.Certificate{ + SerialNumber: serialNumber, + IPAddresses: []net.IP{serverAddress.AsSlice()}, + NotBefore: timeFunc().Add(time.Hour * -1), + NotAfter: expire, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + } else { + template = &x509.Certificate{ + SerialNumber: serialNumber, + NotBefore: timeFunc().Add(time.Hour * -1), + NotAfter: expire, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + Subject: pkix.Name{ + CommonName: serverName, + }, + DNSNames: []string{serverName}, + } } if parent == nil { parent = template diff --git a/constant/script.go b/constant/script.go new file mode 100644 index 00000000..45574038 --- /dev/null +++ b/constant/script.go @@ -0,0 +1,7 @@ +package constant + +const ( + ScriptTypeSurge = "surge" + ScriptSourceTypeLocal = "local" + ScriptSourceTypeRemote = "remote" +) diff --git a/experimental/cachefile/cache.go b/experimental/cachefile/cache.go index 88cffdbe..fa1e28b2 100644 --- a/experimental/cachefile/cache.go +++ b/experimental/cachefile/cache.go @@ -19,10 +19,12 @@ import ( ) var ( - bucketSelected = []byte("selected") - bucketExpand = []byte("group_expand") - bucketMode = []byte("clash_mode") - bucketRuleSet = []byte("rule_set") + bucketSelected = []byte("selected") + bucketExpand = []byte("group_expand") + bucketMode = []byte("clash_mode") + bucketRuleSet = []byte("rule_set") + bucketScript = []byte("script") + bucketSgPersistentStore = []byte("sg_persistent_store") bucketNameList = []string{ string(bucketSelected), @@ -316,3 +318,70 @@ func (c *CacheFile) SaveRuleSet(tag string, set *adapter.SavedBinary) error { return bucket.Put([]byte(tag), setBinary) }) } + +func (c *CacheFile) LoadScript(tag string) *adapter.SavedBinary { + var savedSet adapter.SavedBinary + err := c.DB.View(func(t *bbolt.Tx) error { + bucket := c.bucket(t, bucketScript) + if bucket == nil { + return os.ErrNotExist + } + scriptBinary := bucket.Get([]byte(tag)) + if len(scriptBinary) == 0 { + return os.ErrInvalid + } + return savedSet.UnmarshalBinary(scriptBinary) + }) + if err != nil { + return nil + } + return &savedSet +} + +func (c *CacheFile) SaveScript(tag string, set *adapter.SavedBinary) error { + return c.DB.Batch(func(t *bbolt.Tx) error { + bucket, err := c.createBucket(t, bucketScript) + if err != nil { + return err + } + scriptBinary, err := set.MarshalBinary() + if err != nil { + return err + } + return bucket.Put([]byte(tag), scriptBinary) + }) +} + +func (c *CacheFile) SurgePersistentStoreRead(key string) string { + var value string + _ = c.DB.View(func(t *bbolt.Tx) error { + bucket := c.bucket(t, bucketSgPersistentStore) + if bucket == nil { + return nil + } + valueBinary := bucket.Get([]byte(key)) + if len(valueBinary) > 0 { + value = string(valueBinary) + } + return nil + }) + return value +} + +func (c *CacheFile) SurgePersistentStoreWrite(key string, value string) error { + return c.DB.Batch(func(t *bbolt.Tx) error { + if value != "" { + bucket, err := c.createBucket(t, bucketSgPersistentStore) + if err != nil { + return err + } + return bucket.Put([]byte(key), []byte(value)) + } else { + bucket := c.bucket(t, bucketSgPersistentStore) + if bucket == nil { + return nil + } + return bucket.Delete([]byte(key)) + } + }) +} diff --git a/experimental/clashapi/mitm.go b/experimental/clashapi/mitm.go new file mode 100644 index 00000000..8d64081d --- /dev/null +++ b/experimental/clashapi/mitm.go @@ -0,0 +1,186 @@ +package clashapi + +import ( + "archive/zip" + "context" + "crypto/x509" + "encoding/pem" + "io" + "net/http" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/service" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "github.com/gofrs/uuid/v5" + "howett.net/plist" +) + +func mitmRouter(ctx context.Context) http.Handler { + r := chi.NewRouter() + r.Get("/mobileconfig", getMobileConfig(ctx)) + r.Get("/crt", getCertificate(ctx)) + r.Get("/pem", getCertificatePEM(ctx)) + r.Get("/magisk", getMagiskModule(ctx)) + return r +} + +func getMobileConfig(ctx context.Context) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + store := service.FromContext[adapter.CertificateStore](ctx) + if !store.TLSDecryptionEnabled() { + http.NotFound(writer, request) + render.PlainText(writer, request, "TLS decryption not enabled") + return + } + certificate := store.TLSDecryptionCertificate() + writer.Header().Set("Content-Type", "application/x-apple-aspen-config") + uuidGen := common.Must1(uuid.NewV4()).String() + mobileConfig := map[string]interface{}{ + "PayloadContent": []interface{}{ + map[string]interface{}{ + "PayloadCertificateFileName": "Certificates.cer", + "PayloadContent": certificate.Raw, + "PayloadDescription": "Adds a root certificate", + "PayloadDisplayName": certificate.Subject.CommonName, + "PayloadIdentifier": "com.apple.security.root." + uuidGen, + "PayloadType": "com.apple.security.root", + "PayloadUUID": uuidGen, + "PayloadVersion": 1, + }, + }, + "PayloadDisplayName": certificate.Subject.CommonName, + "PayloadIdentifier": "io.nekohasekai.sfa.ca.profile." + uuidGen, + "PayloadRemovalDisallowed": false, + "PayloadType": "Configuration", + "PayloadUUID": uuidGen, + "PayloadVersion": 1, + } + encoder := plist.NewEncoder(writer) + encoder.Indent("\t") + encoder.Encode(mobileConfig) + } +} + +func getCertificate(ctx context.Context) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + store := service.FromContext[adapter.CertificateStore](ctx) + if !store.TLSDecryptionEnabled() { + http.NotFound(writer, request) + render.PlainText(writer, request, "TLS decryption not enabled") + return + } + writer.Header().Set("Content-Type", "application/x-x509-ca-cert") + writer.Header().Set("Content-Disposition", "attachment; filename=Certificate.crt") + writer.Write(store.TLSDecryptionCertificate().Raw) + } +} + +func getCertificatePEM(ctx context.Context) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + store := service.FromContext[adapter.CertificateStore](ctx) + if !store.TLSDecryptionEnabled() { + http.NotFound(writer, request) + render.PlainText(writer, request, "TLS decryption not enabled") + return + } + writer.Header().Set("Content-Type", "application/x-pem-file") + writer.Header().Set("Content-Disposition", "attachment; filename=Certificate.pem") + writer.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: store.TLSDecryptionCertificate().Raw})) + } +} + +func getMagiskModule(ctx context.Context) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + store := service.FromContext[adapter.CertificateStore](ctx) + if !store.TLSDecryptionEnabled() { + http.NotFound(writer, request) + render.PlainText(writer, request, "TLS decryption not enabled") + return + } + writer.Header().Set("Content-Type", "application/zip") + writer.Header().Set("Content-Disposition", "attachment; filename="+store.TLSDecryptionCertificate().Subject.CommonName+".zip") + createMagiskModule(writer, store.TLSDecryptionCertificate()) + } +} + +func createMagiskModule(writer io.Writer, certificate *x509.Certificate) error { + zipWriter := zip.NewWriter(writer) + defer zipWriter.Close() + moduleProp, err := zipWriter.Create("module.prop") + if err != nil { + return err + } + _, err = moduleProp.Write([]byte(` +id=sing-box-certificate +name=` + certificate.Subject.CommonName + ` +version=v0.0.1 +versionCode=1 +author=sing-box +description=This module adds ` + certificate.Subject.CommonName + ` to the system trust store. +`)) + if err != nil { + return err + } + certificateFile, err := zipWriter.Create("system/etc/security/cacerts/" + certificate.Subject.CommonName + ".pem") + if err != nil { + return err + } + err = pem.Encode(certificateFile, &pem.Block{Type: "CERTIFICATE", Bytes: certificate.Raw}) + if err != nil { + return err + } + updateBinary, err := zipWriter.Create("META-INF/com/google/android/update-binary") + if err != nil { + return err + } + _, err = updateBinary.Write([]byte(` +#!/sbin/sh + +################# +# Initialization +################# + +umask 022 + +# echo before loading util_functions +ui_print() { echo "$1"; } + +require_new_magisk() { + ui_print "*******************************" + ui_print " Please install Magisk v20.4+! " + ui_print "*******************************" + exit 1 +} + +######################### +# Load util_functions.sh +######################### + +OUTFD=$2 +ZIPFILE=$3 + +mount /data 2>/dev/null + +[ -f /data/adb/magisk/util_functions.sh ] || require_new_magisk +. /data/adb/magisk/util_functions.sh +[ $MAGISK_VER_CODE -lt 20400 ] && require_new_magisk + +install_module +exit 0 +`)) + if err != nil { + return err + } + updaterScript, err := zipWriter.Create("META-INF/com/google/android/updater-script") + if err != nil { + return err + } + _, err = updaterScript.Write([]byte("#MAGISK")) + if err != nil { + return err + } + return nil +} diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index e6d8c4cf..e471a13f 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -124,6 +124,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op r.Mount("/profile", profileRouter()) r.Mount("/cache", cacheRouter(ctx)) r.Mount("/dns", dnsRouter(s.dnsRouter)) + r.Mount("/mitm", mitmRouter(ctx)) s.setupMetaAPI(r) }) diff --git a/experimental/libbox/platform/interface.go b/experimental/libbox/platform/interface.go index 35b0830b..fef3fb11 100644 --- a/experimental/libbox/platform/interface.go +++ b/experimental/libbox/platform/interface.go @@ -32,4 +32,9 @@ type Notification struct { Subtitle string Body string OpenURL string + Clipboard string + MediaURL string + MediaData []byte + MediaType string + Timeout int } diff --git a/go.mod b/go.mod index 78f04e56..1829be47 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,12 @@ module github.com/sagernet/sing-box go 1.23.1 require ( + github.com/adhocore/gronx v1.19.5 github.com/anytls/sing-anytls v0.0.6 github.com/caddyserver/certmagic v0.21.7 github.com/cloudflare/circl v1.6.0 github.com/cretz/bine v0.2.0 + github.com/dop251/goja v0.0.0-20250125213203-5ef83b82af17 github.com/go-chi/chi/v5 v5.2.1 github.com/go-chi/render v1.0.3 github.com/gofrs/uuid/v5 v5.3.1 @@ -53,6 +55,7 @@ require ( google.golang.org/grpc v1.70.0 google.golang.org/protobuf v1.36.5 howett.net/plist v1.0.1 + software.sslmate.com/src/go-pkcs12 v0.4.0 ) //replace github.com/sagernet/sing => ../sing @@ -72,12 +75,14 @@ require ( github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 // indirect github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/gaissmai/bart v0.11.1 // indirect github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288 // indirect github.com/go-ole/go-ole v1.3.0 // indirect - github.com/go-task/slim-sprig/v3 v3.0.0 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 // indirect @@ -86,7 +91,7 @@ require ( github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 // indirect - github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 // indirect + github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30 // indirect github.com/gorilla/securecookie v1.1.2 // indirect @@ -104,7 +109,7 @@ require ( github.com/mdlayher/sdnotify v1.0.0 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/mitchellh/go-ps v1.0.0 // indirect - github.com/onsi/ginkgo/v2 v2.17.2 // indirect + github.com/onsi/ginkgo/v2 v2.9.7 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus-community/pro-bing v0.4.0 // indirect @@ -139,5 +144,3 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.3.0 // indirect ) - -//replace github.com/sagernet/sing => ../sing diff --git a/go.sum b/go.sum index 02718a70..686c2646 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= +github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= +github.com/adhocore/gronx v1.19.5 h1:cwIG4nT1v9DvadxtHBe6MzE+FZ1JDvAUC45U2fl4eSQ= +github.com/adhocore/gronx v1.19.5/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/akutz/memconn v0.1.0 h1:NawI0TORU4hcOMsMr11g7vwlCdkYeLKXBcxWu2W/P8A= @@ -30,6 +34,7 @@ github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/cretz/bine v0.2.0 h1:8GiDRGlTgz+o8H9DSnsl+5MeBK4HsExxgl6WgzOCuZo= github.com/cretz/bine v0.2.0/go.mod h1:WU4o9QR9wWp8AVKtTM1XD5vUHkEqnf2vVSo6dBqbetI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa h1:h8TfIT1xc8FWbwwpmHn1J5i43Y0uZP97GqasGCzSRJk= @@ -38,6 +43,10 @@ github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 h1:CaO/zOnF8VvUfEbhRatPcwKVWamvbY github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1/go.mod h1:+hnT3ywWDTAFrW5aE+u2Sa/wT555ZqwoCS+pk3p6ry4= github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q= github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e/go.mod h1:YTIHhz/QFSYnu/EhlF2SpU2Uk+32abacUYA5ZPljz1A= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20250125213203-5ef83b82af17 h1:spJaibPy2sZNwo6Q0HjBVufq7hBUj5jNFOKRoogCBow= +github.com/dop251/goja v0.0.0-20250125213203-5ef83b82af17/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= @@ -58,8 +67,10 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= -github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= -github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= @@ -83,8 +94,8 @@ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 h1:wG8RYIyctLhdFk6Vl1yPGtSRtwGpVkWyZww1OCil2MI= github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= -github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 h1:k7nVchz72niMH6YLQNvHSdIE7iqsQxK1P41mySCvssg= -github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= +github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a h1:fEBsGL/sjAuJrgah5XqmmYsTLzJp/TO9Lhy39gkverk= +github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30 h1:fiJdrgVBkjZ5B1HJ2WQwNOaXB+QyYcNXTA3t1XYLz0M= @@ -137,10 +148,10 @@ github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= -github.com/onsi/ginkgo/v2 v2.17.2 h1:7eMhcy3GimbsA3hEnVKdw/PQM9XN9krpKVXsZdph0/g= -github.com/onsi/ginkgo/v2 v2.17.2/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc= -github.com/onsi/gomega v1.33.0 h1:snPCflnZrpMsy94p4lXVEkHo12lmPnc3vY5XBbreexE= -github.com/onsi/gomega v1.33.0/go.mod h1:+925n5YtiFsLzzafLUHzVMBpvvRAzrydIBiSIxjX3wY= +github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= +github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0= +github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= +github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= @@ -209,6 +220,7 @@ github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3k github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= @@ -321,6 +333,8 @@ google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojt gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/log/log.go b/log/log.go index 7b8f2843..6ee79677 100644 --- a/log/log.go +++ b/log/log.go @@ -10,6 +10,10 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) +const ( + DefaultTimeFormat = "-0700 2006-01-02 15:04:05" +) + type Options struct { Context context.Context Options option.LogOptions @@ -47,7 +51,7 @@ func New(options Options) (Factory, error) { DisableColors: logOptions.DisableColor || logFilePath != "", DisableTimestamp: !logOptions.Timestamp && logFilePath != "", FullTimestamp: logOptions.Timestamp, - TimestampFormat: "-0700 2006-01-02 15:04:05", + TimestampFormat: DefaultTimeFormat, } factory := NewDefaultFactory( options.Context, diff --git a/mitm/constants.go b/mitm/constants.go new file mode 100644 index 00000000..e3a9a46b --- /dev/null +++ b/mitm/constants.go @@ -0,0 +1,11 @@ +package mitm + +import ( + "encoding/base64" + + "github.com/sagernet/sing/common" +) + +var surgeTinyGif = common.OnceValue(func() []byte { + return common.Must1(base64.StdEncoding.DecodeString("R0lGODlhAQABAAAAACH5BAEAAAAALAAAAAABAAEAAAIBAAA=")) +}) diff --git a/mitm/engine.go b/mitm/engine.go new file mode 100644 index 00000000..2f083393 --- /dev/null +++ b/mitm/engine.go @@ -0,0 +1,1099 @@ +package mitm + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "io" + "math" + "mime" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" + "unicode" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + sTLS "github.com/sagernet/sing-box/common/tls" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" + sHTTP "github.com/sagernet/sing/protocol/http" + "github.com/sagernet/sing/service" + + "golang.org/x/net/http2" +) + +var _ adapter.MITMEngine = (*Engine)(nil) + +type Engine struct { + ctx context.Context + logger logger.ContextLogger + connection adapter.ConnectionManager + certificate adapter.CertificateStore + script adapter.ScriptManager + timeFunc func() time.Time + http2Enabled bool +} + +func NewEngine(ctx context.Context, logger logger.ContextLogger, options option.MITMOptions) (*Engine, error) { + engine := &Engine{ + ctx: ctx, + logger: logger, + http2Enabled: options.HTTP2Enabled, + } + return engine, nil +} + +func (e *Engine) Start(stage adapter.StartStage) error { + switch stage { + case adapter.StartStateInitialize: + e.connection = service.FromContext[adapter.ConnectionManager](e.ctx) + e.certificate = service.FromContext[adapter.CertificateStore](e.ctx) + e.script = service.FromContext[adapter.ScriptManager](e.ctx) + e.timeFunc = ntp.TimeFuncFromContext(e.ctx) + if e.timeFunc == nil { + e.timeFunc = time.Now + } + } + return nil +} + +func (e *Engine) Close() error { + return nil +} + +func (e *Engine) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + if e.certificate.TLSDecryptionEnabled() && metadata.ClientHello != nil { + err := e.newTLS(ctx, this, conn, metadata, onClose) + if err != nil { + e.logger.ErrorContext(ctx, err) + } else { + e.logger.DebugContext(ctx, "connection closed") + } + if onClose != nil { + onClose(err) + } + return + } else if metadata.HTTPRequest != nil { + err := e.newHTTP1(ctx, this, conn, nil, metadata) + if err != nil { + e.logger.ErrorContext(ctx, err) + } else { + e.logger.DebugContext(ctx, "connection closed") + } + if onClose != nil { + onClose(err) + } + return + } else { + e.logger.DebugContext(ctx, "HTTP and TLS not detected, skipped") + } + metadata.MITM = nil + e.connection.NewConnection(ctx, this, conn, metadata, onClose) +} + +func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + acceptHTTP := len(metadata.ClientHello.SupportedProtos) == 0 || common.Contains(metadata.ClientHello.SupportedProtos, "http/1.1") + acceptH2 := e.http2Enabled && common.Contains(metadata.ClientHello.SupportedProtos, "h2") + if !acceptHTTP && !acceptH2 { + metadata.MITM = nil + e.logger.DebugContext(ctx, "unsupported application protocol: ", strings.Join(metadata.ClientHello.SupportedProtos, ",")) + e.connection.NewConnection(ctx, this, conn, metadata, onClose) + return nil + } + var nextProtos []string + if acceptH2 { + nextProtos = append(nextProtos, "h2") + } else if acceptHTTP { + nextProtos = append(nextProtos, "http/1.1") + } + var ( + maxVersion uint16 + minVersion uint16 + ) + for _, version := range metadata.ClientHello.SupportedVersions { + maxVersion = common.Max(maxVersion, version) + minVersion = common.Min(minVersion, version) + } + serverName := metadata.ClientHello.ServerName + if serverName == "" && metadata.Destination.IsIP() { + serverName = metadata.Destination.Addr.String() + } + tlsConfig := &tls.Config{ + Time: e.timeFunc, + ServerName: serverName, + NextProtos: nextProtos, + MinVersion: minVersion, + MaxVersion: maxVersion, + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + return sTLS.GenerateKeyPair(e.certificate.TLSDecryptionCertificate(), e.certificate.TLSDecryptionPrivateKey(), e.timeFunc, serverName) + }, + } + tlsConn := tls.Server(conn, tlsConfig) + err := tlsConn.HandshakeContext(ctx) + if err != nil { + return E.Cause(err, "TLS handshake failed for ", metadata.ClientHello.ServerName, ", ", strings.Join(metadata.ClientHello.SupportedProtos, ", ")) + } + if tlsConn.ConnectionState().NegotiatedProtocol == "h2" { + return e.newHTTP2(ctx, this, tlsConn, tlsConfig, metadata, onClose) + } else { + return e.newHTTP1(ctx, this, tlsConn, tlsConfig, metadata) + } +} + +func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext) error { + options := metadata.MITM + defer conn.Close() + reader := bufio.NewReader(conn) + request, err := sHTTP.ReadRequest(reader) + if err != nil { + return E.Cause(err, "read HTTP request") + } + rawRequestURL := request.URL + if tlsConfig != nil { + rawRequestURL.Scheme = "https" + } else { + rawRequestURL.Scheme = "http" + } + if rawRequestURL.Host == "" { + rawRequestURL.Host = request.Host + } + requestURL := rawRequestURL.String() + request.RequestURI = "" + var ( + requestMatch bool + requestScript adapter.SurgeScript + requestScriptOptions option.MITMRouteSurgeScriptOptions + ) +match: + for _, scriptOptions := range options.Script { + script, loaded := e.script.Script(scriptOptions.Tag) + if !loaded { + e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag) + continue + } + surgeScript, isSurge := script.(adapter.SurgeScript) + if !isSurge { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script") + continue + } + for _, pattern := range scriptOptions.Pattern { + if pattern.Build().MatchString(requestURL) { + e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]") + requestScript = surgeScript + requestScriptOptions = scriptOptions + requestMatch = true + break match + } + } + } + var body []byte + if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 { + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body = io.NopCloser(bytes.NewReader(body)) + } + if options.Print { + e.printRequest(ctx, request, body) + } + if requestScript != nil { + if body == nil && requestScriptOptions.RequiresBody && request.ContentLength > 0 && (requestScriptOptions.MaxSize == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScriptOptions.MaxSize) { + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body = io.NopCloser(bytes.NewReader(body)) + } + var result *adapter.HTTPRequestScriptResult + result, err = requestScript.ExecuteHTTPRequest(ctx, time.Duration(requestScriptOptions.Timeout), request, body, requestScriptOptions.BinaryBodyMode, requestScriptOptions.Arguments) + if err != nil { + return E.Cause(err, "execute script/", requestScript.Type(), "[", requestScript.Tag(), "]") + } + if result.Response != nil { + if result.Response.Status == 0 { + result.Response.Status = http.StatusOK + } + response := &http.Response{ + StatusCode: result.Response.Status, + Status: http.StatusText(result.Response.Status), + Proto: request.Proto, + ProtoMajor: request.ProtoMajor, + ProtoMinor: request.ProtoMinor, + Header: result.Response.Headers, + Body: io.NopCloser(bytes.NewReader(result.Response.Body)), + } + err = response.Write(conn) + if err != nil { + return E.Cause(err, "write fake response body") + } + return nil + } else { + if result.URL != "" { + var newURL *url.URL + newURL, err = url.Parse(result.URL) + if err != nil { + return E.Cause(err, "parse updated request URL") + } + request.URL = newURL + newDestination := M.ParseSocksaddrHostPortStr(newURL.Hostname(), newURL.Port()) + if newDestination.Port == 0 { + newDestination.Port = metadata.Destination.Port + } + metadata.Destination = newDestination + if tlsConfig != nil { + tlsConfig.ServerName = newURL.Hostname() + } + } + for key, values := range result.Headers { + request.Header[key] = values + } + if newHost := result.Headers.Get("Host"); newHost != "" { + request.Host = newHost + request.Header.Del("Host") + } + if result.Body != nil { + body = result.Body + request.Body = io.NopCloser(bytes.NewReader(body)) + request.ContentLength = int64(len(body)) + } + } + } + if !requestMatch { + for i, rule := range options.SurgeURLRewrite { + if !rule.Pattern.MatchString(requestURL) { + continue + } + e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String()) + if rule.Reject { + return E.New("request rejected by url_rewrite") + } else if rule.Redirect { + w := new(simpleResponseWriter) + http.Redirect(w, request, rule.Destination.String(), http.StatusFound) + err = w.Build(request).Write(conn) + if err != nil { + return E.Cause(err, "write url_rewrite 302 response") + } + return nil + } + requestMatch = true + request.URL = rule.Destination + newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port()) + if newDestination.Port == 0 { + newDestination.Port = metadata.Destination.Port + } + metadata.Destination = newDestination + if tlsConfig != nil { + tlsConfig.ServerName = rule.Destination.Hostname() + } + break + } + for i, rule := range options.SurgeHeaderRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + if strings.ToLower(rule.Key) == "host" { + request.Host = rule.Value + continue + } + request.Header.Add(rule.Key, rule.Value) + case rule.Delete: + request.Header.Del(rule.Key) + case rule.Replace: + if request.Header.Get(rule.Key) != "" { + request.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := request.Header.Get(rule.Key); value != "" { + request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + if body == nil { + if request.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if request.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + } + for mi := 0; i < len(rule.Match); i++ { + body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i])) + } + request.Body = io.NopCloser(bytes.NewReader(body)) + request.ContentLength = int64(len(body)) + } + } + if !requestMatch { + for i, rule := range options.SurgeMapLocal { + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String()) + var ( + statusCode = http.StatusOK + headers = make(http.Header) + ) + if rule.StatusCode > 0 { + statusCode = rule.StatusCode + } + switch { + case rule.File: + resource, err := os.ReadFile(rule.Data) + if err != nil { + return E.Cause(err, "open map local source") + } + mimeType := mime.TypeByExtension(filepath.Ext(rule.Data)) + if mimeType == "" { + mimeType = "application/octet-stream" + } + headers.Set("Content-Type", mimeType) + body = resource + case rule.Text: + headers.Set("Content-Type", "text/plain") + body = []byte(rule.Data) + case rule.TinyGif: + headers.Set("Content-Type", "image/gif") + body = surgeTinyGif() + case rule.Base64: + headers.Set("Content-Type", "application/octet-stream") + body = rule.Base64Data + } + response := &http.Response{ + StatusCode: statusCode, + Status: http.StatusText(statusCode), + Proto: request.Proto, + ProtoMajor: request.ProtoMajor, + ProtoMinor: request.ProtoMinor, + Header: headers, + Body: io.NopCloser(bytes.NewReader(body)), + } + err = response.Write(conn) + if err != nil { + return E.Cause(err, "write map local response") + } + return nil + } + } + ctx = adapter.WithContext(ctx, &metadata) + var innerErr atomic.TypedValue[error] + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { + return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + return this.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + }, + TLSClientConfig: tlsConfig, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + defer httpClient.CloseIdleConnections() + requestCtx, cancel := context.WithCancel(ctx) + defer cancel() + response, err := httpClient.Do(request.WithContext(requestCtx)) + if err != nil { + cancel() + return E.Errors(innerErr.Load(), err) + } + var ( + responseScript adapter.SurgeScript + responseMatch bool + responseScriptOptions option.MITMRouteSurgeScriptOptions + ) +matchResponse: + for _, scriptOptions := range options.Script { + script, loaded := e.script.Script(scriptOptions.Tag) + if !loaded { + e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag) + continue + } + surgeScript, isSurge := script.(adapter.SurgeScript) + if !isSurge { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script") + continue + } + for _, pattern := range scriptOptions.Pattern { + if pattern.Build().MatchString(requestURL) { + e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]") + responseScript = surgeScript + responseScriptOptions = scriptOptions + responseMatch = true + break matchResponse + } + } + } + var responseBody []byte + if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 { + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP response body") + } + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + } + if options.Print { + e.printResponse(ctx, request, response, responseBody) + } + if responseScript != nil { + if responseBody == nil && responseScriptOptions.RequiresBody && response.ContentLength > 0 && (responseScriptOptions.MaxSize == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScriptOptions.MaxSize) { + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP response body") + } + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + } + var result *adapter.HTTPResponseScriptResult + result, err = responseScript.ExecuteHTTPResponse(ctx, time.Duration(responseScriptOptions.Timeout), request, response, responseBody, responseScriptOptions.BinaryBodyMode, responseScriptOptions.Arguments) + if err != nil { + return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]") + } + if result.Status > 0 { + response.Status = http.StatusText(result.Status) + response.StatusCode = result.Status + } + for key, values := range result.Headers { + response.Header[key] = values + } + if result.Body != nil { + response.Body.Close() + responseBody = result.Body + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + response.ContentLength = int64(len(responseBody)) + } + } + if !responseMatch { + for i, rule := range options.SurgeHeaderRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + response.Header.Add(rule.Key, rule.Value) + case rule.Delete: + response.Header.Del(rule.Key) + case rule.Replace: + if response.Header.Get(rule.Key) != "" { + response.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := response.Header.Get(rule.Key); value != "" { + response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + if responseBody == nil { + if response.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if response.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + } + for mi := 0; i < len(rule.Match); i++ { + responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i])) + } + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + response.ContentLength = int64(len(responseBody)) + } + } + if !options.Print && !requestMatch && !responseMatch { + e.logger.WarnContext(ctx, "request not modified") + } + err = response.Write(conn) + if err != nil { + return E.Errors(E.Cause(err, "write HTTP response"), innerErr.Load()) + } else if innerErr.Load() != nil { + return E.Cause(innerErr.Load(), "write HTTP response") + } + return nil +} + +func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + httpTransport := &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + ctx = adapter.WithContext(ctx, &metadata) + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { + return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + return this.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + }, + TLSClientConfig: tlsConfig, + } + err := http2.ConfigureTransport(httpTransport) + if err != nil { + return E.Cause(err, "configure HTTP/2 transport") + } + handler := &engineHandler{ + Engine: e, + conn: conn, + tlsConfig: tlsConfig, + dialer: this, + metadata: metadata, + httpClient: &http.Client{ + Transport: httpTransport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + onClose: onClose, + } + http2Server := &http2.Server{ + MaxReadFrameSize: math.MaxUint32, + } + http2Server.ServeConn(conn, &http2.ServeConnOpts{ + Context: ctx, + Handler: handler, + }) + return nil +} + +type engineHandler struct { + *Engine + conn net.Conn + tlsConfig *tls.Config + dialer N.Dialer + metadata adapter.InboundContext + onClose N.CloseHandlerFunc + httpClient *http.Client +} + +func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + err := e.serveHTTP(request.Context(), writer, request) + if err != nil { + if E.IsClosedOrCanceled(err) { + e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed")) + } else { + e.logger.ErrorContext(request.Context(), err) + } + } +} + +func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { + options := e.metadata.MITM + rawRequestURL := request.URL + rawRequestURL.Scheme = "https" + if rawRequestURL.Host == "" { + rawRequestURL.Host = request.Host + } + requestURL := rawRequestURL.String() + request.RequestURI = "" + var ( + requestMatch bool + requestScript adapter.SurgeScript + requestScriptOptions option.MITMRouteSurgeScriptOptions + ) +match: + for _, scriptOptions := range options.Script { + script, loaded := e.script.Script(scriptOptions.Tag) + if !loaded { + e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag) + continue + } + surgeScript, isSurge := script.(adapter.SurgeScript) + if !isSurge { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script") + continue + } + for _, pattern := range scriptOptions.Pattern { + if pattern.Build().MatchString(requestURL) { + e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]") + requestScript = surgeScript + requestScriptOptions = scriptOptions + requestMatch = true + break match + } + } + } + var ( + body []byte + err error + ) + if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 { + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body.Close() + request.Body = io.NopCloser(bytes.NewReader(body)) + } + if options.Print { + e.printRequest(ctx, request, body) + } + if requestScript != nil { + if body == nil && requestScriptOptions.RequiresBody && request.ContentLength > 0 && (requestScriptOptions.MaxSize == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScriptOptions.MaxSize) { + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body.Close() + request.Body = io.NopCloser(bytes.NewReader(body)) + } + result, err := requestScript.ExecuteHTTPRequest(ctx, time.Duration(requestScriptOptions.Timeout), request, body, requestScriptOptions.BinaryBodyMode, requestScriptOptions.Arguments) + if err != nil { + return E.Cause(err, "execute script/", requestScript.Type(), "[", requestScript.Tag(), "]") + } + if result.Response != nil { + if result.Response.Status == 0 { + result.Response.Status = http.StatusOK + } + for key, values := range result.Response.Headers { + writer.Header()[key] = values + } + writer.WriteHeader(result.Response.Status) + if result.Response.Body != nil { + _, err = writer.Write(result.Response.Body) + if err != nil { + return E.Cause(err, "write fake response body") + } + } + return nil + } else { + if result.URL != "" { + var newURL *url.URL + newURL, err = url.Parse(result.URL) + if err != nil { + return E.Cause(err, "parse updated request URL") + } + request.URL = newURL + newDestination := M.ParseSocksaddrHostPortStr(newURL.Hostname(), newURL.Port()) + if newDestination.Port == 0 { + newDestination.Port = e.metadata.Destination.Port + } + e.metadata.Destination = newDestination + e.tlsConfig.ServerName = newURL.Hostname() + } + for key, values := range result.Headers { + request.Header[key] = values + } + if newHost := result.Headers.Get("Host"); newHost != "" { + request.Host = newHost + request.Header.Del("Host") + } + if result.Body != nil { + io.Copy(io.Discard, request.Body) + request.Body = io.NopCloser(bytes.NewReader(result.Body)) + request.ContentLength = int64(len(result.Body)) + } + } + } + if !requestMatch { + for i, rule := range options.SurgeURLRewrite { + if !rule.Pattern.MatchString(requestURL) { + continue + } + e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String()) + if rule.Reject { + return E.New("request rejected by url_rewrite") + } else if rule.Redirect { + http.Redirect(writer, request, rule.Destination.String(), http.StatusFound) + return nil + } + requestMatch = true + request.URL = rule.Destination + newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port()) + if newDestination.Port == 0 { + newDestination.Port = e.metadata.Destination.Port + } + e.metadata.Destination = newDestination + e.tlsConfig.ServerName = rule.Destination.Hostname() + break + } + for i, rule := range options.SurgeHeaderRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + if strings.ToLower(rule.Key) == "host" { + request.Host = rule.Value + continue + } + request.Header.Add(rule.Key, rule.Value) + case rule.Delete: + request.Header.Del(rule.Key) + case rule.Replace: + if request.Header.Get(rule.Key) != "" { + request.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := request.Header.Get(rule.Key); value != "" { + request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + var body []byte + if request.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if request.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + body, err := io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body.Close() + for mi := 0; i < len(rule.Match); i++ { + body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i])) + } + request.Body = io.NopCloser(bytes.NewReader(body)) + request.ContentLength = int64(len(body)) + } + } + if !requestMatch { + for i, rule := range options.SurgeMapLocal { + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String()) + go func() { + io.Copy(io.Discard, request.Body) + request.Body.Close() + }() + var ( + statusCode = http.StatusOK + headers = make(http.Header) + body []byte + ) + if rule.StatusCode > 0 { + statusCode = rule.StatusCode + } + switch { + case rule.File: + resource, err := os.ReadFile(rule.Data) + if err != nil { + return E.Cause(err, "open map local source") + } + mimeType := mime.TypeByExtension(filepath.Ext(rule.Data)) + if mimeType == "" { + mimeType = "application/octet-stream" + } + headers.Set("Content-Type", mimeType) + body = resource + case rule.Text: + headers.Set("Content-Type", "text/plain") + body = []byte(rule.Data) + case rule.TinyGif: + headers.Set("Content-Type", "image/gif") + body = surgeTinyGif() + case rule.Base64: + headers.Set("Content-Type", "application/octet-stream") + body = rule.Base64Data + } + for key, values := range headers { + writer.Header()[key] = values + } + writer.WriteHeader(statusCode) + _, err = writer.Write(body) + if err != nil { + return E.Cause(err, "write map local response") + } + return nil + } + } + requestCtx, cancel := context.WithCancel(ctx) + defer cancel() + response, err := e.httpClient.Do(request.WithContext(requestCtx)) + if err != nil { + cancel() + return E.Cause(err, "exchange request") + } + var ( + responseScript adapter.SurgeScript + responseMatch bool + responseScriptOptions option.MITMRouteSurgeScriptOptions + ) +matchResponse: + for _, scriptOptions := range options.Script { + script, loaded := e.script.Script(scriptOptions.Tag) + if !loaded { + e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag) + continue + } + surgeScript, isSurge := script.(adapter.SurgeScript) + if !isSurge { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script") + continue + } + for _, pattern := range scriptOptions.Pattern { + if pattern.Build().MatchString(requestURL) { + e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]") + responseScript = surgeScript + responseScriptOptions = scriptOptions + responseMatch = true + break matchResponse + } + } + } + var responseBody []byte + if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 { + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP response body") + } + response.Body.Close() + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + } + if options.Print { + e.printResponse(ctx, request, response, responseBody) + } + if responseScript != nil { + if responseBody == nil && responseScriptOptions.RequiresBody && response.ContentLength > 0 && (responseScriptOptions.MaxSize == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScriptOptions.MaxSize) { + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP response body") + } + response.Body.Close() + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + } + var result *adapter.HTTPResponseScriptResult + result, err = responseScript.ExecuteHTTPResponse(ctx, time.Duration(responseScriptOptions.Timeout), request, response, responseBody, responseScriptOptions.BinaryBodyMode, responseScriptOptions.Arguments) + if err != nil { + return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]") + } + if result.Status > 0 { + response.Status = http.StatusText(result.Status) + response.StatusCode = result.Status + } + for key, values := range result.Headers { + response.Header[key] = values + } + if result.Body != nil { + response.Body.Close() + response.Body = io.NopCloser(bytes.NewReader(result.Body)) + response.ContentLength = int64(len(result.Body)) + } + } + if !responseMatch { + for i, rule := range options.SurgeHeaderRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + response.Header.Add(rule.Key, rule.Value) + case rule.Delete: + response.Header.Del(rule.Key) + case rule.Replace: + if response.Header.Get(rule.Key) != "" { + response.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := response.Header.Get(rule.Key); value != "" { + response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + if responseBody == nil { + if response.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if response.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + responseBody, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + response.Body.Close() + } + for mi := 0; i < len(rule.Match); i++ { + responseBody = rule.Match[mi].ReplaceAll(responseBody, []byte(rule.Replace[i])) + } + response.Body = io.NopCloser(bytes.NewReader(responseBody)) + response.ContentLength = int64(len(responseBody)) + } + } + if !options.Print && !requestMatch && !responseMatch { + e.logger.WarnContext(ctx, "request not modified") + } + for key, values := range response.Header { + writer.Header()[key] = values + } + writer.WriteHeader(response.StatusCode) + _, err = io.Copy(writer, response.Body) + response.Body.Close() + if err != nil { + return E.Cause(err, "write HTTP response") + } + return nil +} + +func (e *Engine) printRequest(ctx context.Context, request *http.Request, body []byte) { + var builder strings.Builder + builder.WriteString(F.ToString(request.Proto, " ", request.Method, " ", request.URL)) + builder.WriteString("\n") + if request.URL.Hostname() != "" && request.URL.Hostname() != request.Host { + builder.WriteString("Host: ") + builder.WriteString(request.Host) + builder.WriteString("\n") + } + for key, values := range request.Header { + for _, value := range values { + builder.WriteString(key) + builder.WriteString(": ") + builder.WriteString(value) + builder.WriteString("\n") + } + } + if len(body) > 0 { + builder.WriteString("\n") + if !bytes.ContainsFunc(body, func(r rune) bool { + return !unicode.IsPrint(r) && !unicode.IsSpace(r) + }) { + builder.Write(body) + } else { + builder.WriteString("(body not printable)") + } + } + e.logger.InfoContext(ctx, "request: ", builder.String()) +} + +func (e *Engine) printResponse(ctx context.Context, request *http.Request, response *http.Response, body []byte) { + var builder strings.Builder + builder.WriteString(F.ToString(response.Proto, " ", response.Status, " ", request.URL)) + builder.WriteString("\n") + for key, values := range response.Header { + for _, value := range values { + builder.WriteString(key) + builder.WriteString(": ") + builder.WriteString(value) + builder.WriteString("\n") + } + } + if len(body) > 0 { + builder.WriteString("\n") + if !bytes.ContainsFunc(body, func(r rune) bool { + return !unicode.IsPrint(r) && !unicode.IsSpace(r) + }) { + builder.Write(body) + } else { + builder.WriteString("(body not printable)") + } + } + e.logger.InfoContext(ctx, "response: ", builder.String()) +} + +type simpleResponseWriter struct { + statusCode int + header http.Header + body bytes.Buffer +} + +func (w *simpleResponseWriter) Build(request *http.Request) *http.Response { + return &http.Response{ + StatusCode: w.statusCode, + Status: http.StatusText(w.statusCode), + Proto: request.Proto, + ProtoMajor: request.ProtoMajor, + ProtoMinor: request.ProtoMinor, + Header: w.header, + Body: io.NopCloser(&w.body), + } +} + +func (w *simpleResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *simpleResponseWriter) Write(b []byte) (int, error) { + return w.body.Write(b) +} + +func (w *simpleResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode +} diff --git a/option/certificate.go b/option/certificate.go index ab524b99..2206a787 100644 --- a/option/certificate.go +++ b/option/certificate.go @@ -11,6 +11,13 @@ type _CertificateOptions struct { Certificate badoption.Listable[string] `json:"certificate,omitempty"` CertificatePath badoption.Listable[string] `json:"certificate_path,omitempty"` CertificateDirectoryPath badoption.Listable[string] `json:"certificate_directory_path,omitempty"` + TLSDecryption *TLSDecryptionOptions `json:"tls_decryption,omitempty"` +} + +type TLSDecryptionOptions struct { + Enabled bool `json:"enabled,omitempty"` + KeyPair string `json:"key_pair_p12,omitempty"` + KeyPairPassword string `json:"key_pair_p12_password,omitempty"` } type CertificateOptions _CertificateOptions diff --git a/option/mitm.go b/option/mitm.go new file mode 100644 index 00000000..501f099e --- /dev/null +++ b/option/mitm.go @@ -0,0 +1,31 @@ +package option + +import ( + "github.com/sagernet/sing/common/json/badoption" +) + +type MITMOptions struct { + Enabled bool `json:"enabled,omitempty"` + HTTP2Enabled bool `json:"http2_enabled,omitempty"` +} + +type MITMRouteOptions struct { + Enabled bool `json:"enabled,omitempty"` + Print bool `json:"print,omitempty"` + Script badoption.Listable[MITMRouteSurgeScriptOptions] `json:"surge_script,omitempty"` + SurgeURLRewrite badoption.Listable[SurgeURLRewriteLine] `json:"surge_url_rewrite,omitempty"` + SurgeHeaderRewrite badoption.Listable[SurgeHeaderRewriteLine] `json:"surge_header_rewrite,omitempty"` + SurgeBodyRewrite badoption.Listable[SurgeBodyRewriteLine] `json:"surge_body_rewrite,omitempty"` + SurgeMapLocal badoption.Listable[SurgeMapLocalLine] `json:"surge_map_local,omitempty"` +} + +type MITMRouteSurgeScriptOptions struct { + Tag string `json:"tag"` + Type badoption.Listable[string] `json:"type"` + Pattern badoption.Listable[*badoption.Regexp] `json:"pattern"` + Timeout badoption.Duration `json:"timeout,omitempty"` + RequiresBody bool `json:"requires_body,omitempty"` + MaxSize int64 `json:"max_size,omitempty"` + BinaryBodyMode bool `json:"binary_body_mode,omitempty"` + Arguments badoption.Listable[string] `json:"arguments,omitempty"` +} diff --git a/option/mitm_surge_urlrewrite.go b/option/mitm_surge_urlrewrite.go new file mode 100644 index 00000000..ffec1917 --- /dev/null +++ b/option/mitm_surge_urlrewrite.go @@ -0,0 +1,449 @@ +package option + +import ( + "encoding/base64" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "unicode" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/json" +) + +type SurgeURLRewriteLine struct { + Pattern *regexp.Regexp + Destination *url.URL + Redirect bool + Reject bool +} + +func (l SurgeURLRewriteLine) String() string { + var fields []string + fields = append(fields, l.Pattern.String()) + if l.Reject { + fields = append(fields, "_") + } else { + fields = append(fields, l.Destination.String()) + } + switch { + case l.Redirect: + fields = append(fields, "302") + case l.Reject: + fields = append(fields, "reject") + default: + fields = append(fields, "header") + } + return encodeSurgeKeys(fields) +} + +func (l SurgeURLRewriteLine) MarshalJSON() ([]byte, error) { + return json.Marshal(l.String()) +} + +func (l *SurgeURLRewriteLine) UnmarshalJSON(bytes []byte) error { + var stringValue string + err := json.Unmarshal(bytes, &stringValue) + if err != nil { + return err + } + fields, err := surgeFields(stringValue) + if err != nil { + return E.Cause(err, "invalid surge_url_rewrite line: ", stringValue) + } else if len(fields) < 2 || len(fields) > 3 { + return E.New("invalid surge_url_rewrite line: ", stringValue) + } + pattern, err := regexp.Compile(fields[0].Key) + if err != nil { + return E.Cause(err, "invalid surge_url_rewrite line: invalid pattern: ", stringValue) + } + l.Pattern = pattern + l.Destination, err = url.Parse(fields[1].Key) + if err != nil { + return E.Cause(err, "invalid surge_url_rewrite line: invalid destination: ", stringValue) + } + if len(fields) == 3 { + switch fields[2].Key { + case "header": + case "302": + l.Redirect = true + case "reject": + l.Reject = true + default: + return E.New("invalid surge_url_rewrite line: invalid action: ", stringValue) + } + } + return nil +} + +type SurgeHeaderRewriteLine struct { + Response bool + Pattern *regexp.Regexp + Add bool + Delete bool + Replace bool + ReplaceRegex bool + Key string + Match *regexp.Regexp + Value string +} + +func (l SurgeHeaderRewriteLine) String() string { + var fields []string + if !l.Response { + fields = append(fields, "http-request") + } else { + fields = append(fields, "http-response") + } + fields = append(fields, l.Pattern.String()) + if l.Add { + fields = append(fields, "header-add") + } else if l.Delete { + fields = append(fields, "header-del") + } else if l.Replace { + fields = append(fields, "header-replace") + } else if l.ReplaceRegex { + fields = append(fields, "header-replace-regex") + } + fields = append(fields, l.Key) + if l.Add || l.Replace { + fields = append(fields, l.Value) + } else if l.ReplaceRegex { + fields = append(fields, l.Match.String(), l.Value) + } + return encodeSurgeKeys(fields) +} + +func (l SurgeHeaderRewriteLine) MarshalJSON() ([]byte, error) { + return json.Marshal(l.String()) +} + +func (l *SurgeHeaderRewriteLine) UnmarshalJSON(bytes []byte) error { + var stringValue string + err := json.Unmarshal(bytes, &stringValue) + if err != nil { + return err + } + fields, err := surgeFields(stringValue) + if err != nil { + return E.Cause(err, "invalid surge_header_rewrite line: ", stringValue) + } else if len(fields) < 4 { + return E.New("invalid surge_header_rewrite line: ", stringValue) + } + switch fields[0].Key { + case "http-request": + case "http-response": + l.Response = true + default: + return E.New("invalid surge_header_rewrite line: invalid type: ", stringValue) + } + l.Pattern, err = regexp.Compile(fields[1].Key) + if err != nil { + return E.Cause(err, "invalid surge_header_rewrite line: invalid pattern: ", stringValue) + } + switch fields[2].Key { + case "header-add": + l.Add = true + if len(fields) != 5 { + return E.New("invalid surge_header_rewrite line: " + stringValue) + } + l.Key = fields[3].Key + l.Value = fields[4].Key + case "header-del": + l.Delete = true + l.Key = fields[3].Key + case "header-replace": + l.Replace = true + if len(fields) != 5 { + return E.New("invalid surge_header_rewrite line: " + stringValue) + } + l.Key = fields[3].Key + l.Value = fields[4].Key + case "header-replace-regex": + l.ReplaceRegex = true + if len(fields) != 6 { + return E.New("invalid surge_header_rewrite line: " + stringValue) + } + l.Key = fields[3].Key + l.Match, err = regexp.Compile(fields[4].Key) + if err != nil { + return E.Cause(err, "invalid surge_header_rewrite line: invalid match: ", stringValue) + } + l.Value = fields[5].Key + default: + return E.New("invalid surge_header_rewrite line: invalid action: ", stringValue) + } + return nil +} + +type SurgeBodyRewriteLine struct { + Response bool + Pattern *regexp.Regexp + Match []*regexp.Regexp + Replace []string +} + +func (l SurgeBodyRewriteLine) String() string { + var fields []string + if !l.Response { + fields = append(fields, "http-request") + } else { + fields = append(fields, "http-response") + } + for i := 0; i < len(l.Match); i += 2 { + fields = append(fields, l.Match[i].String(), l.Replace[i]) + } + return strings.Join(fields, " ") +} + +func (l SurgeBodyRewriteLine) MarshalJSON() ([]byte, error) { + return json.Marshal(l.String()) +} + +func (l *SurgeBodyRewriteLine) UnmarshalJSON(bytes []byte) error { + var stringValue string + err := json.Unmarshal(bytes, &stringValue) + if err != nil { + return err + } + fields, err := surgeFields(stringValue) + if err != nil { + return E.Cause(err, "invalid surge_body_rewrite line: ", stringValue) + } else if len(fields) < 4 { + return E.New("invalid surge_body_rewrite line: ", stringValue) + } else if len(fields)%2 != 0 { + return E.New("invalid surge_body_rewrite line: ", stringValue) + } + switch fields[0].Key { + case "http-request": + case "http-response": + l.Response = true + default: + return E.New("invalid surge_body_rewrite line: invalid type: ", stringValue) + } + l.Pattern, err = regexp.Compile(fields[1].Key) + for i := 2; i < len(fields); i += 2 { + var match *regexp.Regexp + match, err = regexp.Compile(fields[i].Key) + if err != nil { + return E.Cause(err, "invalid surge_body_rewrite line: invalid match: ", stringValue) + } + l.Match = append(l.Match, match) + l.Replace = append(l.Replace, fields[i+1].Key) + } + return nil +} + +type SurgeMapLocalLine struct { + Pattern *regexp.Regexp + StatusCode int + File bool + Text bool + TinyGif bool + Base64 bool + Data string + Base64Data []byte + Headers http.Header +} + +func (l SurgeMapLocalLine) String() string { + var fields []surgeField + fields = append(fields, surgeField{Key: l.Pattern.String()}) + if l.File { + fields = append(fields, surgeField{Key: "data-type", Value: "file"}) + fields = append(fields, surgeField{Key: "data", Value: l.Data}) + } else if l.Text { + fields = append(fields, surgeField{Key: "data-type", Value: "text"}) + fields = append(fields, surgeField{Key: "data", Value: l.Data}) + } else if l.TinyGif { + fields = append(fields, surgeField{Key: "data-type", Value: "tiny-gif"}) + } else if l.Base64 { + fields = append(fields, surgeField{Key: "data-type", Value: "base64"}) + fields = append(fields, surgeField{Key: "data-type", Value: base64.StdEncoding.EncodeToString(l.Base64Data)}) + } + if l.StatusCode != 0 { + fields = append(fields, surgeField{Key: "status-code", Value: F.ToString(l.StatusCode), ValueSet: true}) + } + if len(l.Headers) > 0 { + var headers []string + for key, values := range l.Headers { + for _, value := range values { + headers = append(headers, key+":"+value) + } + } + fields = append(fields, surgeField{Key: "headers", Value: strings.Join(headers, "|")}) + } + return encodeSurgeFields(fields) +} + +func (l SurgeMapLocalLine) MarshalJSON() ([]byte, error) { + return json.Marshal(l.String()) +} + +func (l *SurgeMapLocalLine) UnmarshalJSON(bytes []byte) error { + var stringValue string + err := json.Unmarshal(bytes, &stringValue) + if err != nil { + return err + } + fields, err := surgeFields(stringValue) + if err != nil { + return E.Cause(err, "invalid surge_map_local line: ", stringValue) + } else if len(fields) < 1 { + return E.New("invalid surge_map_local line: ", stringValue) + } + l.Pattern, err = regexp.Compile(fields[0].Key) + if err != nil { + return E.Cause(err, "invalid surge_map_local line: invalid pattern: ", stringValue) + } + dataTypeField := common.Find(fields, func(it surgeField) bool { + return it.Key == "data-type" + }) + if !dataTypeField.ValueSet { + return E.New("invalid surge_map_local line: missing data-type: ", stringValue) + } + switch dataTypeField.Value { + case "file": + l.File = true + case "text": + l.Text = true + case "tiny-gif": + l.TinyGif = true + case "base64": + l.Base64 = true + default: + return E.New("unsupported data-type ", dataTypeField.Value) + } + for i := 1; i < len(fields); i++ { + switch fields[i].Key { + case "data-type": + continue + case "data": + if l.File { + l.Data = fields[i].Value + } else if l.Text { + l.Data = fields[i].Value + } else if l.Base64 { + l.Base64Data, err = base64.StdEncoding.DecodeString(fields[i].Value) + if err != nil { + return E.New("invalid surge_map_local line: invalid base64 data: ", stringValue) + } + } + case "status-code": + statusCode, err := strconv.ParseInt(fields[i].Value, 10, 16) + if err != nil { + return E.New("invalid surge_map_local line: invalid status code: ", stringValue) + } + l.StatusCode = int(statusCode) + case "header": + headers := make(http.Header) + for _, headerLine := range strings.Split(fields[i].Value, "|") { + if !strings.Contains(headerLine, ":") { + return E.New("invalid surge_map_local line: headers: missing `:` in item: ", stringValue, ": ", headerLine) + } + headers.Add(common.SubstringBefore(headerLine, ":"), common.SubstringAfter(headerLine, ":")) + } + l.Headers = headers + default: + return E.New("invalid surge_map_local line: unknown options: ", fields[i].Key) + } + } + return nil +} + +type surgeField struct { + Key string + Value string + ValueSet bool +} + +func encodeSurgeKeys(keys []string) string { + keys = common.Map(keys, func(it string) string { + if strings.ContainsFunc(it, unicode.IsSpace) { + return "\"" + it + "\"" + } else { + return it + } + }) + return strings.Join(keys, " ") +} + +func encodeSurgeFields(fields []surgeField) string { + return strings.Join(common.Map(fields, func(it surgeField) string { + if !it.ValueSet { + if strings.ContainsFunc(it.Key, unicode.IsSpace) { + return "\"" + it.Key + "\"" + } else { + return it.Key + } + } else { + if strings.ContainsFunc(it.Value, unicode.IsSpace) { + return it.Key + "=\"" + it.Value + "\"" + } else { + return it.Key + "=" + it.Value + } + } + }), " ") +} + +func surgeFields(s string) ([]surgeField, error) { + var ( + fields []surgeField + currentField *surgeField + ) + for _, field := range strings.Fields(s) { + if currentField != nil { + field = " " + field + if strings.HasSuffix(field, "\"") { + field = field[:len(field)-1] + if !currentField.ValueSet { + currentField.Key += field + } else { + currentField.Value += field + } + fields = append(fields, *currentField) + currentField = nil + } else { + if !currentField.ValueSet { + currentField.Key += field + } else { + currentField.Value += field + } + } + continue + } + if !strings.Contains(field, "=") { + if strings.HasPrefix(field, "\"") { + field = field[1:] + if strings.HasSuffix(field, "\"") { + field = field[:len(field)-1] + } else { + currentField = &surgeField{Key: field} + continue + } + } + fields = append(fields, surgeField{Key: field}) + } else { + key := common.SubstringBefore(field, "=") + value := common.SubstringAfter(field, "=") + if strings.HasPrefix(value, "\"") { + value = value[1:] + if strings.HasSuffix(field, "\"") { + value = value[:len(value)-1] + } else { + currentField = &surgeField{Key: key, Value: value, ValueSet: true} + continue + } + } + fields = append(fields, surgeField{Key: key, Value: value, ValueSet: true}) + } + } + if currentField != nil { + return nil, E.New("invalid surge fields line: ", s) + } + return fields, nil +} diff --git a/option/options.go b/option/options.go index 168074ed..f454afd3 100644 --- a/option/options.go +++ b/option/options.go @@ -12,13 +12,15 @@ type _Options struct { Schema string `json:"$schema,omitempty"` Log *LogOptions `json:"log,omitempty"` DNS *DNSOptions `json:"dns,omitempty"` - NTP *NTPOptions `json:"ntp,omitempty"` - Certificate *CertificateOptions `json:"certificate,omitempty"` Endpoints []Endpoint `json:"endpoints,omitempty"` Inbounds []Inbound `json:"inbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"` Route *RouteOptions `json:"route,omitempty"` Experimental *ExperimentalOptions `json:"experimental,omitempty"` + NTP *NTPOptions `json:"ntp,omitempty"` + Certificate *CertificateOptions `json:"certificate,omitempty"` + MITM *MITMOptions `json:"mitm,omitempty"` + Scripts []Script `json:"scripts,omitempty"` } type Options _Options diff --git a/option/rule_action.go b/option/rule_action.go index 7c05dce6..473d371f 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -158,6 +158,8 @@ type RawRouteOptionsActionOptions struct { TLSFragment bool `json:"tls_fragment,omitempty"` TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"` + + MITM *MITMRouteOptions `json:"mitm,omitempty"` } type RouteOptionsActionOptions RawRouteOptionsActionOptions diff --git a/option/script.go b/option/script.go new file mode 100644 index 00000000..90a3b586 --- /dev/null +++ b/option/script.go @@ -0,0 +1,128 @@ +package option + +import ( + C "github.com/sagernet/sing-box/constant" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" + "github.com/sagernet/sing/common/json/badoption" +) + +type _ScriptSourceOptions struct { + Source string `json:"source"` + LocalOptions LocalScriptSource `json:"-"` + RemoteOptions RemoteScriptSource `json:"-"` +} + +type LocalScriptSource struct { + Path string `json:"path"` +} + +type RemoteScriptSource struct { + URL string `json:"url"` + DownloadDetour string `json:"download_detour,omitempty"` + UpdateInterval badoption.Duration `json:"update_interval,omitempty"` +} + +type ScriptSourceOptions _ScriptSourceOptions + +func (o ScriptSourceOptions) MarshalJSON() ([]byte, error) { + var source any + switch o.Source { + case C.ScriptSourceTypeLocal: + source = o.LocalOptions + case C.ScriptSourceTypeRemote: + source = o.RemoteOptions + default: + return nil, E.New("unknown script source: ", o.Source) + } + return badjson.MarshallObjects((_ScriptSourceOptions)(o), source) +} + +func (o *ScriptSourceOptions) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_ScriptSourceOptions)(o)) + if err != nil { + return err + } + var source any + switch o.Source { + case C.ScriptSourceTypeLocal: + source = &o.LocalOptions + case C.ScriptSourceTypeRemote: + source = &o.RemoteOptions + default: + return E.New("unknown script source: ", o.Source) + } + return json.Unmarshal(bytes, source) +} + +// TODO: make struct in order +type Script struct { + ScriptSourceOptions + ScriptOptions +} + +func (s Script) MarshalJSON() ([]byte, error) { + return badjson.MarshallObjects(s.ScriptSourceOptions, s.ScriptOptions) +} + +func (s *Script) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, &s.ScriptSourceOptions) + if err != nil { + return err + } + return badjson.UnmarshallExcluded(bytes, &s.ScriptSourceOptions, &s.ScriptOptions) +} + +type _ScriptOptions struct { + Type string `json:"type"` + Tag string `json:"tag"` + SurgeOptions SurgeScriptOptions `json:"-"` +} + +type ScriptOptions _ScriptOptions + +func (o ScriptOptions) MarshalJSON() ([]byte, error) { + var v any + switch o.Type { + case C.ScriptTypeSurge: + v = &o.SurgeOptions + default: + return nil, E.New("unknown script type: ", o.Type) + } + if v == nil { + return badjson.MarshallObjects((_ScriptOptions)(o)) + } + return badjson.MarshallObjects((_ScriptOptions)(o), v) +} + +func (o *ScriptOptions) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_ScriptOptions)(o)) + if err != nil { + return err + } + var v any + switch o.Type { + case C.ScriptTypeSurge: + v = &o.SurgeOptions + case "": + return E.New("missing script type") + default: + return E.New("unknown script type: ", o.Type) + } + if v == nil { + // check unknown fields + return json.UnmarshalDisallowUnknownFields(bytes, &_ScriptOptions{}) + } + return badjson.UnmarshallExcluded(bytes, (*_ScriptOptions)(o), v) +} + +type SurgeScriptOptions struct { + CronOptions *CronScriptOptions `json:"cron,omitempty"` +} + +type CronScriptOptions struct { + Expression string `json:"expression"` + Arguments []string `json:"arguments,omitempty"` + Timeout badoption.Duration `json:"timeout,omitempty"` +} diff --git a/route/conn.go b/route/conn.go index c2a2eab9..33c02a44 100644 --- a/route/conn.go +++ b/route/conn.go @@ -24,23 +24,31 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" + "github.com/sagernet/sing/service" ) var _ adapter.ConnectionManager = (*ConnectionManager)(nil) type ConnectionManager struct { + ctx context.Context logger logger.ContextLogger + mitm adapter.MITMEngine access sync.Mutex connections list.List[io.Closer] } -func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager { +func NewConnectionManager(ctx context.Context, logger logger.ContextLogger) *ConnectionManager { return &ConnectionManager{ + ctx: ctx, logger: logger, } } func (m *ConnectionManager) Start(stage adapter.StartStage) error { + switch stage { + case adapter.StartStateInitialize: + m.mitm = service.FromContext[adapter.MITMEngine](m.ctx) + } return nil } @@ -55,6 +63,14 @@ func (m *ConnectionManager) Close() error { } func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + if metadata.MITM != nil && metadata.MITM.Enabled { + if m.mitm == nil { + m.logger.WarnContext(ctx, "MITM disabled") + } else { + m.mitm.NewConnection(ctx, this, conn, metadata, onClose) + return + } + } ctx = adapter.WithContext(ctx, &metadata) var ( remoteConn net.Conn diff --git a/route/route.go b/route/route.go index 531ad039..72ecbd34 100644 --- a/route/route.go +++ b/route/route.go @@ -458,6 +458,9 @@ match: metadata.TLSFragment = true metadata.TLSFragmentFallbackDelay = routeOptions.TLSFragmentFallbackDelay } + if routeOptions.MITM != nil && routeOptions.MITM.Enabled { + metadata.MITM = routeOptions.MITM + } } switch action := currentRule.Action().(type) { case *rule.RuleActionSniff: diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index f49baca6..73bc2bae 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -40,6 +40,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti UDPConnect: action.RouteOptions.UDPConnect, TLSFragment: action.RouteOptions.TLSFragment, TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay), + MITM: action.RouteOptions.MITM, }, }, nil case C.RuleActionTypeRouteOptions: @@ -53,6 +54,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout), TLSFragment: action.RouteOptionsOptions.TLSFragment, TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay), + MITM: action.RouteOptionsOptions.MITM, }, nil case C.RuleActionTypeDirect: directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false) @@ -152,15 +154,7 @@ func (r *RuleActionRoute) Type() string { func (r *RuleActionRoute) String() string { var descriptions []string descriptions = append(descriptions, r.Outbound) - if r.UDPDisableDomainUnmapping { - descriptions = append(descriptions, "udp-disable-domain-unmapping") - } - if r.UDPConnect { - descriptions = append(descriptions, "udp-connect") - } - if r.TLSFragment { - descriptions = append(descriptions, "tls-fragment") - } + descriptions = append(descriptions, r.Descriptions()...) return F.ToString("route(", strings.Join(descriptions, ","), ")") } @@ -176,13 +170,14 @@ type RuleActionRouteOptions struct { UDPTimeout time.Duration TLSFragment bool TLSFragmentFallbackDelay time.Duration + MITM *option.MITMRouteOptions } func (r *RuleActionRouteOptions) Type() string { return C.RuleActionTypeRouteOptions } -func (r *RuleActionRouteOptions) String() string { +func (r *RuleActionRouteOptions) Descriptions() []string { var descriptions []string if r.OverrideAddress.IsValid() { descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.AddrString())) @@ -209,9 +204,22 @@ func (r *RuleActionRouteOptions) String() string { descriptions = append(descriptions, "udp-connect") } if r.UDPTimeout > 0 { - descriptions = append(descriptions, "udp-timeout") + descriptions = append(descriptions, F.ToString("udp-timeout=", r.UDPTimeout)) } - return F.ToString("route-options(", strings.Join(descriptions, ","), ")") + if r.TLSFragment { + descriptions = append(descriptions, "tls-fragment") + if r.TLSFragmentFallbackDelay > 0 { + descriptions = append(descriptions, F.ToString("tls-fragment-fallbac-delay=", r.TLSFragmentFallbackDelay.String())) + } + } + if r.MITM != nil && r.MITM.Enabled { + descriptions = append(descriptions, "mitm") + } + return descriptions +} + +func (r *RuleActionRouteOptions) String() string { + return F.ToString("route-options(", strings.Join(r.Descriptions(), ","), ")") } type RuleActionDNSRoute struct { diff --git a/script/jsc/array.go b/script/jsc/array.go new file mode 100644 index 00000000..a113e02c --- /dev/null +++ b/script/jsc/array.go @@ -0,0 +1,23 @@ +package jsc + +import ( + _ "unsafe" + + "github.com/dop251/goja" +) + +func NewUint8Array(runtime *goja.Runtime, data []byte) goja.Value { + buffer := runtime.NewArrayBuffer(data) + ctor, loaded := goja.AssertConstructor(runtimeGetUint8Array(runtime)) + if !loaded { + panic(runtime.NewTypeError("missing UInt8Array constructor")) + } + array, err := ctor(nil, runtime.ToValue(buffer)) + if err != nil { + panic(runtime.NewGoError(err)) + } + return array +} + +//go:linkname runtimeGetUint8Array github.com/dop251/goja.(*Runtime).getUint8Array +func runtimeGetUint8Array(r *goja.Runtime) *goja.Object diff --git a/script/jsc/array_test.go b/script/jsc/array_test.go new file mode 100644 index 00000000..77f43dc5 --- /dev/null +++ b/script/jsc/array_test.go @@ -0,0 +1,18 @@ +package jsc_test + +import ( + "testing" + + "github.com/sagernet/sing-box/script/jsc" + + "github.com/dop251/goja" + "github.com/stretchr/testify/require" +) + +func TestNewUInt8Array(t *testing.T) { + runtime := goja.New() + runtime.Set("hello", jsc.NewUint8Array(runtime, []byte("world"))) + result, err := runtime.RunString("hello instanceof Uint8Array") + require.NoError(t, err) + require.True(t, result.ToBoolean()) +} diff --git a/script/jsc/assert.go b/script/jsc/assert.go new file mode 100644 index 00000000..0b7fe3b6 --- /dev/null +++ b/script/jsc/assert.go @@ -0,0 +1,124 @@ +package jsc + +import ( + "net/http" + + F "github.com/sagernet/sing/common/format" + + "github.com/dop251/goja" +) + +func IsNil(value goja.Value) bool { + return value == nil || goja.IsUndefined(value) || goja.IsNull(value) +} + +func AssertObject(vm *goja.Runtime, value goja.Value, name string, nilable bool) *goja.Object { + if IsNil(value) { + if nilable { + return nil + } + panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name))) + } + objectValue, isObject := value.(*goja.Object) + if !isObject { + panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected object, but got ", value))) + } + return objectValue +} + +func AssertString(vm *goja.Runtime, value goja.Value, name string, nilable bool) string { + if IsNil(value) { + if nilable { + return "" + } + panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name))) + } + stringValue, isString := value.Export().(string) + if !isString { + panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected string, but got ", value))) + } + return stringValue +} + +func AssertInt(vm *goja.Runtime, value goja.Value, name string, nilable bool) int64 { + if IsNil(value) { + if nilable { + return 0 + } + panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name))) + } + integerValue, isNumber := value.Export().(int64) + if !isNumber { + panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected integer, but got ", value))) + } + return integerValue +} + +func AssertBool(vm *goja.Runtime, value goja.Value, name string, nilable bool) bool { + if IsNil(value) { + if nilable { + return false + } + panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name))) + } + boolValue, isBool := value.Export().(bool) + if !isBool { + panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected boolean, but got ", value))) + } + return boolValue +} + +func AssertBinary(vm *goja.Runtime, value goja.Value, name string, nilable bool) []byte { + if IsNil(value) { + if nilable { + return nil + } + panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name))) + } + switch exportedValue := value.Export().(type) { + case []byte: + return exportedValue + case goja.ArrayBuffer: + return exportedValue.Bytes() + default: + panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected Uint8Array or ArrayBuffer, but got ", value))) + } +} + +func AssertStringBinary(vm *goja.Runtime, value goja.Value, name string, nilable bool) []byte { + if IsNil(value) { + if nilable { + return nil + } + panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name))) + } + switch exportedValue := value.Export().(type) { + case string: + return []byte(exportedValue) + case []byte: + return exportedValue + case goja.ArrayBuffer: + return exportedValue.Bytes() + default: + panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected string, Uint8Array or ArrayBuffer, but got ", value))) + } +} + +func AssertFunction(vm *goja.Runtime, value goja.Value, name string) goja.Callable { + if IsNil(value) { + panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name))) + } + functionValue, isFunction := goja.AssertFunction(value) + if !isFunction { + panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected function, but got ", value))) + } + return functionValue +} + +func AssertHTTPHeader(vm *goja.Runtime, value goja.Value, name string) http.Header { + headersObject := AssertObject(vm, value, name, true) + if headersObject == nil { + return nil + } + return ObjectToHeaders(vm, headersObject, name) +} diff --git a/script/jsc/class.go b/script/jsc/class.go new file mode 100644 index 00000000..cb949512 --- /dev/null +++ b/script/jsc/class.go @@ -0,0 +1,192 @@ +package jsc + +import ( + "time" + + "github.com/sagernet/sing/common" + + "github.com/dop251/goja" +) + +type Module interface { + Runtime() *goja.Runtime +} + +type Class[M Module, C any] interface { + Module() M + Runtime() *goja.Runtime + DefineField(name string, getter func(this C) any, setter func(this C, value goja.Value)) + DefineMethod(name string, method func(this C, call goja.FunctionCall) any) + DefineStaticMethod(name string, method func(c Class[M, C], call goja.FunctionCall) any) + DefineConstructor(constructor func(c Class[M, C], call goja.ConstructorCall) C) + ToValue() goja.Value + New(instance C) *goja.Object + Prototype() *goja.Object + Is(value goja.Value) bool + As(value goja.Value) C +} + +func GetClass[M Module, C any](runtime *goja.Runtime, exports *goja.Object, className string) Class[M, C] { + objectValue := exports.Get(className) + if objectValue == nil { + panic(runtime.NewTypeError("Missing class: " + className)) + } + object, isObject := objectValue.(*goja.Object) + if !isObject { + panic(runtime.NewTypeError("Invalid class: " + className)) + } + classObject, isClass := object.Get("_class").(*goja.Object) + if !isClass { + panic(runtime.NewTypeError("Invalid class: " + className)) + } + class, isClass := classObject.Export().(Class[M, C]) + if !isClass { + panic(runtime.NewTypeError("Invalid class: " + className)) + } + return class +} + +type goClass[M Module, C any] struct { + m M + prototype *goja.Object + constructor goja.Value +} + +func NewClass[M Module, C any](module M) Class[M, C] { + class := &goClass[M, C]{ + m: module, + prototype: module.Runtime().NewObject(), + } + clazz := module.Runtime().ToValue(class).(*goja.Object) + clazz.Set("toString", module.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + return module.Runtime().ToValue("[sing-box Class]") + })) + class.prototype.DefineAccessorProperty("_class", class.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { return clazz }), nil, goja.FLAG_FALSE, goja.FLAG_TRUE) + return class +} + +func (c *goClass[M, C]) Module() M { + return c.m +} + +func (c *goClass[M, C]) Runtime() *goja.Runtime { + return c.m.Runtime() +} + +func (c *goClass[M, C]) DefineField(name string, getter func(this C) any, setter func(this C, value goja.Value)) { + var ( + getterValue goja.Value + setterValue goja.Value + ) + if getter != nil { + getterValue = c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + this, isThis := call.This.Export().(C) + if !isThis { + panic(c.Runtime().NewTypeError("Illegal this value: " + call.This.ExportType().String())) + } + return c.toValue(getter(this), goja.Null()) + }) + } + if setter != nil { + setterValue = c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + this, isThis := call.This.Export().(C) + if !isThis { + panic(c.Runtime().NewTypeError("Illegal this value: " + call.This.String())) + } + setter(this, call.Argument(0)) + return goja.Undefined() + }) + } + c.prototype.DefineAccessorProperty(name, getterValue, setterValue, goja.FLAG_FALSE, goja.FLAG_TRUE) +} + +func (c *goClass[M, C]) DefineMethod(name string, method func(this C, call goja.FunctionCall) any) { + methodValue := c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + this, isThis := call.This.Export().(C) + if !isThis { + panic(c.Runtime().NewTypeError("Illegal this value: " + call.This.String())) + } + return c.toValue(method(this, call), goja.Undefined()) + }) + c.prototype.Set(name, methodValue) + if name == "entries" { + c.prototype.DefineDataPropertySymbol(goja.SymIterator, methodValue, goja.FLAG_TRUE, goja.FLAG_FALSE, goja.FLAG_TRUE) + } +} + +func (c *goClass[M, C]) DefineStaticMethod(name string, method func(c Class[M, C], call goja.FunctionCall) any) { + c.prototype.Set(name, c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + return c.toValue(method(c, call), goja.Undefined()) + })) +} + +func (c *goClass[M, C]) DefineConstructor(constructor func(c Class[M, C], call goja.ConstructorCall) C) { + constructorObject := c.Runtime().ToValue(func(call goja.ConstructorCall) *goja.Object { + value := constructor(c, call) + object := c.toValue(value, goja.Undefined()).(*goja.Object) + object.SetPrototype(call.This.Prototype()) + return object + }).(*goja.Object) + constructorObject.SetPrototype(c.prototype) + c.prototype.DefineDataProperty("constructor", constructorObject, goja.FLAG_FALSE, goja.FLAG_FALSE, goja.FLAG_FALSE) + c.constructor = constructorObject +} + +func (c *goClass[M, C]) toValue(rawValue any, defaultValue goja.Value) goja.Value { + switch value := rawValue.(type) { + case nil: + return defaultValue + case time.Time: + return TimeToValue(c.Runtime(), value) + default: + return c.Runtime().ToValue(value) + } +} + +func (c *goClass[M, C]) ToValue() goja.Value { + if c.constructor == nil { + constructorObject := c.Runtime().ToValue(func(call goja.ConstructorCall) *goja.Object { + panic(c.Runtime().NewTypeError("Illegal constructor call")) + }).(*goja.Object) + constructorObject.SetPrototype(c.prototype) + c.prototype.DefineDataProperty("constructor", constructorObject, goja.FLAG_FALSE, goja.FLAG_FALSE, goja.FLAG_FALSE) + c.constructor = constructorObject + } + return c.constructor +} + +func (c *goClass[M, C]) New(instance C) *goja.Object { + object := c.Runtime().ToValue(instance).(*goja.Object) + object.SetPrototype(c.prototype) + return object +} + +func (c *goClass[M, C]) Prototype() *goja.Object { + return c.prototype +} + +func (c *goClass[M, C]) Is(value goja.Value) bool { + object, isObject := value.(*goja.Object) + if !isObject { + return false + } + prototype := object.Prototype() + for prototype != nil { + if prototype == c.prototype { + return true + } + prototype = prototype.Prototype() + } + return false +} + +func (c *goClass[M, C]) As(value goja.Value) C { + object, isObject := value.(*goja.Object) + if !isObject { + return common.DefaultValue[C]() + } + if !c.Is(object) { + return common.DefaultValue[C]() + } + return object.Export().(C) +} diff --git a/script/jsc/headers.go b/script/jsc/headers.go new file mode 100644 index 00000000..dcbbb516 --- /dev/null +++ b/script/jsc/headers.go @@ -0,0 +1,56 @@ +package jsc + +import ( + "net/http" + "reflect" + + "github.com/sagernet/sing/common" + F "github.com/sagernet/sing/common/format" + + "github.com/dop251/goja" +) + +func HeadersToValue(runtime *goja.Runtime, headers http.Header) goja.Value { + object := runtime.NewObject() + for key, value := range headers { + if len(value) == 1 { + object.Set(key, value[0]) + } else { + object.Set(key, ArrayToValue(runtime, value)) + } + } + return object +} + +func ArrayToValue[T any](runtime *goja.Runtime, values []T) goja.Value { + return runtime.NewArray(common.Map(values, func(it T) any { return it })...) +} + +func ObjectToHeaders(vm *goja.Runtime, object *goja.Object, name string) http.Header { + headers := make(http.Header) + for _, key := range object.Keys() { + valueObject := object.Get(key) + switch headerValue := valueObject.(type) { + case goja.String: + headers.Set(key, headerValue.String()) + case *goja.Object: + values := headerValue.Export() + valueArray, isArray := values.([]any) + if !isArray { + panic(vm.NewTypeError(F.ToString("invalid value: ", name, ".", key, "expected string or string array, got ", valueObject.String()))) + } + newValues := make([]string, 0, len(valueArray)) + for _, value := range valueArray { + stringValue, isString := value.(string) + if !isString { + panic(vm.NewTypeError(F.ToString("invalid value: ", name, ".", key, " expected string or string array, got array item type: ", reflect.TypeOf(value)))) + } + newValues = append(newValues, stringValue) + } + headers[key] = newValues + default: + panic(vm.NewTypeError(F.ToString("invalid value: ", name, ".", key, " expected string or string array, got ", valueObject.String()))) + } + } + return headers +} diff --git a/script/jsc/headers_test.go b/script/jsc/headers_test.go new file mode 100644 index 00000000..ecbae23b --- /dev/null +++ b/script/jsc/headers_test.go @@ -0,0 +1,31 @@ +package jsc_test + +import ( + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/sagernet/sing-box/script/jsc" + + "github.com/dop251/goja" + "github.com/stretchr/testify/require" +) + +func TestHeaders(t *testing.T) { + runtime := goja.New() + runtime.Set("headers", jsc.HeadersToValue(runtime, http.Header{ + "My-Header": []string{"My-Value1", "My-Value2"}, + })) + headers := runtime.Get("headers").(*goja.Object).Get("My-Header").(*goja.Object) + fmt.Println(reflect.ValueOf(headers.Export()).Type().String()) +} + +func TestBody(t *testing.T) { + runtime := goja.New() + _, err := runtime.RunString(` +var responseBody = new Uint8Array([1, 2, 3, 4, 5]) +`) + require.NoError(t, err) + fmt.Println(reflect.TypeOf(runtime.Get("responseBody").Export())) +} diff --git a/script/jsc/iterator.go b/script/jsc/iterator.go new file mode 100644 index 00000000..deb66764 --- /dev/null +++ b/script/jsc/iterator.go @@ -0,0 +1,36 @@ +package jsc + +import "github.com/dop251/goja" + +type Iterator[M Module, T any] struct { + c Class[M, *Iterator[M, T]] + values []T + block func(this T) any +} + +func NewIterator[M Module, T any](class Class[M, *Iterator[M, T]], values []T, block func(this T) any) goja.Value { + return class.New(&Iterator[M, T]{class, values, block}) +} + +func CreateIterator[M Module, T any](module M) Class[M, *Iterator[M, T]] { + class := NewClass[M, *Iterator[M, T]](module) + class.DefineMethod("next", (*Iterator[M, T]).next) + class.DefineMethod("toString", (*Iterator[M, T]).toString) + return class +} + +func (i *Iterator[M, T]) next(call goja.FunctionCall) any { + result := i.c.Runtime().NewObject() + if len(i.values) == 0 { + result.Set("done", true) + } else { + result.Set("done", false) + result.Set("value", i.block(i.values[0])) + i.values = i.values[1:] + } + return result +} + +func (i *Iterator[M, T]) toString(call goja.FunctionCall) any { + return "[sing-box Iterator]" +} diff --git a/script/jsc/time.go b/script/jsc/time.go new file mode 100644 index 00000000..7879f84c --- /dev/null +++ b/script/jsc/time.go @@ -0,0 +1,18 @@ +package jsc + +import ( + "time" + _ "unsafe" + + "github.com/dop251/goja" +) + +func TimeToValue(runtime *goja.Runtime, time time.Time) goja.Value { + return runtimeNewDateObject(runtime, time, true, runtimeGetDatePrototype(runtime)) +} + +//go:linkname runtimeNewDateObject github.com/dop251/goja.(*Runtime).newDateObject +func runtimeNewDateObject(r *goja.Runtime, t time.Time, isSet bool, proto *goja.Object) *goja.Object + +//go:linkname runtimeGetDatePrototype github.com/dop251/goja.(*Runtime).getDatePrototype +func runtimeGetDatePrototype(r *goja.Runtime) *goja.Object diff --git a/script/jsc/time_test.go b/script/jsc/time_test.go new file mode 100644 index 00000000..5ef86e75 --- /dev/null +++ b/script/jsc/time_test.go @@ -0,0 +1,20 @@ +package jsc_test + +import ( + "testing" + "time" + + "github.com/sagernet/sing-box/script/jsc" + + "github.com/dop251/goja" + "github.com/stretchr/testify/require" +) + +func TestTimeToValue(t *testing.T) { + t.Parallel() + runtime := goja.New() + now := time.Now() + err := runtime.Set("now", jsc.TimeToValue(runtime, now)) + require.NoError(t, err) + println(runtime.Get("now").String()) +} diff --git a/script/jstest/assert.js b/script/jstest/assert.js new file mode 100644 index 00000000..b00076dd --- /dev/null +++ b/script/jstest/assert.js @@ -0,0 +1,83 @@ +'use strict'; + +const assert = { + _isSameValue(a, b) { + if (a === b) { + // Handle +/-0 vs. -/+0 + return a !== 0 || 1 / a === 1 / b; + } + + // Handle NaN vs. NaN + return a !== a && b !== b; + }, + + _toString(value) { + try { + if (value === 0 && 1 / value === -Infinity) { + return '-0'; + } + + return String(value); + } catch (err) { + if (err.name === 'TypeError') { + return Object.prototype.toString.call(value); + } + + throw err; + } + }, + + sameValue(actual, expected, message) { + if (assert._isSameValue(actual, expected)) { + return; + } + if (message === undefined) { + message = ''; + } else { + message += ' '; + } + + message += 'Expected SameValue(«' + assert._toString(actual) + '», «' + assert._toString(expected) + '») to be true'; + + throw new Error(message); + }, + + throws(f, ctor, message) { + if (message === undefined) { + message = ''; + } else { + message += ' '; + } + try { + f(); + } catch (e) { + if (e.constructor !== ctor) { + throw new Error(message + "Wrong exception type was thrown: " + e.constructor.name); + } + return; + } + throw new Error(message + "No exception was thrown"); + }, + + throwsNodeError(f, ctor, code, message) { + if (message === undefined) { + message = ''; + } else { + message += ' '; + } + try { + f(); + } catch (e) { + if (e.constructor !== ctor) { + throw new Error(message + "Wrong exception type was thrown: " + e.constructor.name); + } + if (e.code !== code) { + throw new Error(message + "Wrong exception code was thrown: " + e.code); + } + return; + } + throw new Error(message + "No exception was thrown"); + } +} + +module.exports = assert; \ No newline at end of file diff --git a/script/jstest/test.go b/script/jstest/test.go new file mode 100644 index 00000000..e287f8c2 --- /dev/null +++ b/script/jstest/test.go @@ -0,0 +1,21 @@ +package jstest + +import ( + _ "embed" + + "github.com/sagernet/sing-box/script/modules/require" +) + +//go:embed assert.js +var assertJS []byte + +func NewRegistry() *require.Registry { + return require.NewRegistry(require.WithFsEnable(true), require.WithLoader(func(path string) ([]byte, error) { + switch path { + case "assert.js": + return assertJS, nil + default: + return require.DefaultSourceLoader(path) + } + })) +} diff --git a/script/manager.go b/script/manager.go new file mode 100644 index 00000000..c21af640 --- /dev/null +++ b/script/manager.go @@ -0,0 +1,118 @@ +//go:build with_script + +package script + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/common/task" +) + +var _ adapter.ScriptManager = (*Manager)(nil) + +type Manager struct { + ctx context.Context + logger logger.ContextLogger + scripts []adapter.Script + scriptByName map[string]adapter.Script + surgeCache *adapter.SurgeInMemoryCache +} + +func NewManager(ctx context.Context, logFactory log.Factory, scripts []option.Script) (*Manager, error) { + manager := &Manager{ + ctx: ctx, + logger: logFactory.NewLogger("script"), + scriptByName: make(map[string]adapter.Script), + } + for _, scriptOptions := range scripts { + script, err := NewScript(ctx, logFactory.NewLogger(F.ToString("script/", scriptOptions.Type, "[", scriptOptions.Tag, "]")), scriptOptions) + if err != nil { + return nil, E.Cause(err, "initialize script: ", scriptOptions.Tag) + } + manager.scripts = append(manager.scripts, script) + manager.scriptByName[scriptOptions.Tag] = script + } + return manager, nil +} + +func (m *Manager) Start(stage adapter.StartStage) error { + monitor := taskmonitor.New(m.logger, C.StartTimeout) + switch stage { + case adapter.StartStateStart: + var cacheContext *adapter.HTTPStartContext + if len(m.scripts) > 0 { + monitor.Start("initialize rule-set") + cacheContext = adapter.NewHTTPStartContext(m.ctx) + var scriptStartGroup task.Group + for _, script := range m.scripts { + scriptInPlace := script + scriptStartGroup.Append0(func(ctx context.Context) error { + err := scriptInPlace.StartContext(ctx, cacheContext) + if err != nil { + return E.Cause(err, "initialize script/", scriptInPlace.Type(), "[", scriptInPlace.Tag(), "]") + } + return nil + }) + } + scriptStartGroup.Concurrency(5) + scriptStartGroup.FastFail() + err := scriptStartGroup.Run(m.ctx) + monitor.Finish() + if err != nil { + return err + } + } + if cacheContext != nil { + cacheContext.Close() + } + case adapter.StartStatePostStart: + for _, script := range m.scripts { + monitor.Start(F.ToString("post start script/", script.Type(), "[", script.Tag(), "]")) + err := script.PostStart() + monitor.Finish() + if err != nil { + return E.Cause(err, "post start script/", script.Type(), "[", script.Tag(), "]") + } + } + } + return nil +} + +func (m *Manager) Close() error { + monitor := taskmonitor.New(m.logger, C.StopTimeout) + var err error + for _, script := range m.scripts { + monitor.Start(F.ToString("close start script/", script.Type(), "[", script.Tag(), "]")) + err = E.Append(err, script.Close(), func(err error) error { + return E.Cause(err, "close script/", script.Type(), "[", script.Tag(), "]") + }) + monitor.Finish() + } + return err +} + +func (m *Manager) Scripts() []adapter.Script { + return m.scripts +} + +func (m *Manager) Script(name string) (adapter.Script, bool) { + script, loaded := m.scriptByName[name] + return script, loaded +} + +func (m *Manager) SurgeCache() *adapter.SurgeInMemoryCache { + if m.surgeCache == nil { + m.surgeCache = &adapter.SurgeInMemoryCache{ + Data: make(map[string]string), + } + } + return m.surgeCache +} diff --git a/script/manager_stub.go b/script/manager_stub.go new file mode 100644 index 00000000..eae7ed66 --- /dev/null +++ b/script/manager_stub.go @@ -0,0 +1,43 @@ +//go:build !with_script + +package script + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +var _ adapter.ScriptManager = (*Manager)(nil) + +type Manager struct{} + +func NewManager(ctx context.Context, logFactory log.Factory, scripts []option.Script) (*Manager, error) { + if len(scripts) > 0 { + return nil, E.New(`script is not included in this build, rebuild with -tags with_script`) + } + return (*Manager)(nil), nil +} + +func (m *Manager) Start(stage adapter.StartStage) error { + return nil +} + +func (m *Manager) Close() error { + return nil +} + +func (m *Manager) Scripts() []adapter.Script { + return nil +} + +func (m *Manager) Script(name string) (adapter.Script, bool) { + return nil, false +} + +func (m *Manager) SurgeCache() *adapter.SurgeInMemoryCache { + return nil +} diff --git a/script/modules/boxctx/context.go b/script/modules/boxctx/context.go new file mode 100644 index 00000000..53e74860 --- /dev/null +++ b/script/modules/boxctx/context.go @@ -0,0 +1,50 @@ +package boxctx + +import ( + "context" + "time" + + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing/common/logger" + + "github.com/dop251/goja" +) + +type Context struct { + class jsc.Class[*Module, *Context] + Context context.Context + Logger logger.ContextLogger + Tag string + StartedAt time.Time + ErrorHandler func(error) +} + +func FromRuntime(runtime *goja.Runtime) *Context { + contextValue := runtime.Get("context") + if contextValue == nil { + return nil + } + context, isContext := contextValue.Export().(*Context) + if !isContext { + return nil + } + return context +} + +func MustFromRuntime(runtime *goja.Runtime) *Context { + context := FromRuntime(runtime) + if context == nil { + panic(runtime.NewTypeError("Missing sing-box context")) + } + return context +} + +func createContext(module *Module) jsc.Class[*Module, *Context] { + class := jsc.NewClass[*Module, *Context](module) + class.DefineMethod("toString", (*Context).toString) + return class +} + +func (c *Context) toString(call goja.FunctionCall) any { + return "[sing-box Context]" +} diff --git a/script/modules/boxctx/module.go b/script/modules/boxctx/module.go new file mode 100644 index 00000000..a18fe844 --- /dev/null +++ b/script/modules/boxctx/module.go @@ -0,0 +1,35 @@ +package boxctx + +import ( + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/require" + + "github.com/dop251/goja" +) + +const ModuleName = "context" + +type Module struct { + runtime *goja.Runtime + classContext jsc.Class[*Module, *Context] +} + +func Require(runtime *goja.Runtime, module *goja.Object) { + m := &Module{ + runtime: runtime, + } + m.classContext = createContext(m) + exports := module.Get("exports").(*goja.Object) + exports.Set("Context", m.classContext.ToValue()) +} + +func Enable(runtime *goja.Runtime, context *Context) { + exports := require.Require(runtime, ModuleName).ToObject(runtime) + classContext := jsc.GetClass[*Module, *Context](runtime, exports, "Context") + context.class = classContext + runtime.Set("context", classContext.New(context)) +} + +func (m *Module) Runtime() *goja.Runtime { + return m.runtime +} diff --git a/script/modules/console/console.go b/script/modules/console/console.go new file mode 100644 index 00000000..4fcfec1f --- /dev/null +++ b/script/modules/console/console.go @@ -0,0 +1,281 @@ +package console + +import ( + "bytes" + "context" + "encoding/xml" + "sync" + "time" + + sLog "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + + "github.com/dop251/goja" +) + +type Console struct { + class jsc.Class[*Module, *Console] + access sync.Mutex + countMap map[string]int + timeMap map[string]time.Time +} + +func NewConsole(class jsc.Class[*Module, *Console]) goja.Value { + return class.New(&Console{ + class: class, + countMap: make(map[string]int), + timeMap: make(map[string]time.Time), + }) +} + +func createConsole(m *Module) jsc.Class[*Module, *Console] { + class := jsc.NewClass[*Module, *Console](m) + class.DefineMethod("assert", (*Console).assert) + class.DefineMethod("clear", (*Console).clear) + class.DefineMethod("count", (*Console).count) + class.DefineMethod("countReset", (*Console).countReset) + class.DefineMethod("debug", (*Console).debug) + class.DefineMethod("dir", (*Console).dir) + class.DefineMethod("dirxml", (*Console).dirxml) + class.DefineMethod("error", (*Console).error) + class.DefineMethod("group", (*Console).stub) + class.DefineMethod("groupCollapsed", (*Console).stub) + class.DefineMethod("groupEnd", (*Console).stub) + class.DefineMethod("info", (*Console).info) + class.DefineMethod("log", (*Console)._log) + class.DefineMethod("profile", (*Console).stub) + class.DefineMethod("profileEnd", (*Console).profileEnd) + class.DefineMethod("table", (*Console).table) + class.DefineMethod("time", (*Console).time) + class.DefineMethod("timeEnd", (*Console).timeEnd) + class.DefineMethod("timeLog", (*Console).timeLog) + class.DefineMethod("timeStamp", (*Console).stub) + class.DefineMethod("trace", (*Console).trace) + class.DefineMethod("warn", (*Console).warn) + return class +} + +func (c *Console) stub(call goja.FunctionCall) any { + return goja.Undefined() +} + +func (c *Console) assert(call goja.FunctionCall) any { + assertion := call.Argument(0).ToBoolean() + if !assertion { + return c.log(logger.ContextLogger.ErrorContext, call.Arguments[1:]) + } + return goja.Undefined() +} + +func (c *Console) clear(call goja.FunctionCall) any { + return nil +} + +func (c *Console) count(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + newValue := c.countMap[label] + 1 + c.countMap[label] = newValue + c.access.Unlock() + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, F.ToString(label, ": ", newValue)) + return goja.Undefined() +} + +func (c *Console) countReset(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + delete(c.countMap, label) + c.access.Unlock() + return goja.Undefined() +} + +func (c *Console) log(logFunc func(logger.ContextLogger, context.Context, ...any), args []goja.Value) any { + var buffer bytes.Buffer + var formatString string + if len(args) > 0 { + formatString = args[0].String() + } + format(c.class.Runtime(), &buffer, formatString, args[1:]...) + writeLog(c.class.Runtime(), logFunc, buffer.String()) + return goja.Undefined() +} + +func (c *Console) debug(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.DebugContext, call.Arguments) +} + +func (c *Console) dir(call goja.FunctionCall) any { + object := jsc.AssertObject(c.class.Runtime(), call.Argument(0), "object", false) + var buffer bytes.Buffer + for _, key := range object.Keys() { + value := object.Get(key) + buffer.WriteString(key) + buffer.WriteString(": ") + buffer.WriteString(value.String()) + buffer.WriteString("\n") + } + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, buffer.String()) + return goja.Undefined() +} + +func (c *Console) dirxml(call goja.FunctionCall) any { + var buffer bytes.Buffer + encoder := xml.NewEncoder(&buffer) + encoder.Indent("", " ") + encoder.Encode(call.Argument(0).Export()) + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, buffer.String()) + return goja.Undefined() +} + +func (c *Console) error(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.ErrorContext, call.Arguments) +} + +func (c *Console) info(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.InfoContext, call.Arguments) +} + +func (c *Console) _log(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.InfoContext, call.Arguments) +} + +func (c *Console) profileEnd(call goja.FunctionCall) any { + return goja.Undefined() +} + +func (c *Console) table(call goja.FunctionCall) any { + return c.dir(call) +} + +func (c *Console) time(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + c.timeMap[label] = time.Now() + c.access.Unlock() + return goja.Undefined() +} + +func (c *Console) timeEnd(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + startTime, ok := c.timeMap[label] + if !ok { + c.access.Unlock() + return goja.Undefined() + } + delete(c.timeMap, label) + c.access.Unlock() + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, F.ToString(label, ": ", time.Since(startTime).String(), " - - timer ended")) + return goja.Undefined() +} + +func (c *Console) timeLog(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + startTime, ok := c.timeMap[label] + c.access.Unlock() + if !ok { + writeLog(c.class.Runtime(), logger.ContextLogger.ErrorContext, F.ToString("Timer \"", label, "\" doesn't exist.")) + return goja.Undefined() + } + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, F.ToString(label, ": ", time.Since(startTime))) + return goja.Undefined() +} + +func (c *Console) trace(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.TraceContext, call.Arguments) +} + +func (c *Console) warn(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.WarnContext, call.Arguments) +} + +func writeLog(runtime *goja.Runtime, logFunc func(logger.ContextLogger, context.Context, ...any), message string) { + var ( + ctx context.Context + sLogger logger.ContextLogger + ) + boxCtx := boxctx.FromRuntime(runtime) + if boxCtx != nil { + ctx = boxCtx.Context + sLogger = boxCtx.Logger + } else { + ctx = context.Background() + sLogger = sLog.StdLogger() + } + logFunc(sLogger, ctx, message) +} + +func format(runtime *goja.Runtime, b *bytes.Buffer, f string, args ...goja.Value) { + pct := false + argNum := 0 + for _, chr := range f { + if pct { + if argNum < len(args) { + if format1(runtime, chr, args[argNum], b) { + argNum++ + } + } else { + b.WriteByte('%') + b.WriteRune(chr) + } + pct = false + } else { + if chr == '%' { + pct = true + } else { + b.WriteRune(chr) + } + } + } + + for _, arg := range args[argNum:] { + b.WriteByte(' ') + b.WriteString(arg.String()) + } +} + +func format1(runtime *goja.Runtime, f rune, val goja.Value, w *bytes.Buffer) bool { + switch f { + case 's': + w.WriteString(val.String()) + case 'd': + w.WriteString(val.ToNumber().String()) + case 'j': + if json, ok := runtime.Get("JSON").(*goja.Object); ok { + if stringify, ok := goja.AssertFunction(json.Get("stringify")); ok { + res, err := stringify(json, val) + if err != nil { + panic(err) + } + w.WriteString(res.String()) + } + } + case '%': + w.WriteByte('%') + return false + default: + w.WriteByte('%') + w.WriteRune(f) + return false + } + return true +} diff --git a/script/modules/console/context.go b/script/modules/console/context.go new file mode 100644 index 00000000..cfe522a5 --- /dev/null +++ b/script/modules/console/context.go @@ -0,0 +1,3 @@ +package console + +type Context struct{} diff --git a/script/modules/console/module.go b/script/modules/console/module.go new file mode 100644 index 00000000..4e7cf0ee --- /dev/null +++ b/script/modules/console/module.go @@ -0,0 +1,34 @@ +package console + +import ( + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/require" + + "github.com/dop251/goja" +) + +const ModuleName = "console" + +type Module struct { + runtime *goja.Runtime + console jsc.Class[*Module, *Console] +} + +func Require(runtime *goja.Runtime, module *goja.Object) { + m := &Module{ + runtime: runtime, + } + m.console = createConsole(m) + exports := module.Get("exports").(*goja.Object) + exports.Set("Console", m.console.ToValue()) +} + +func Enable(runtime *goja.Runtime) { + exports := require.Require(runtime, ModuleName).ToObject(runtime) + classConsole := jsc.GetClass[*Module, *Console](runtime, exports, "Console") + runtime.Set("console", NewConsole(classConsole)) +} + +func (m *Module) Runtime() *goja.Runtime { + return m.runtime +} diff --git a/script/modules/eventloop/eventloop.go b/script/modules/eventloop/eventloop.go new file mode 100644 index 00000000..33766bf9 --- /dev/null +++ b/script/modules/eventloop/eventloop.go @@ -0,0 +1,489 @@ +package eventloop + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/dop251/goja" +) + +type job struct { + cancel func() bool + fn func() + idx int + + cancelled bool +} + +type Timer struct { + job + timer *time.Timer +} + +type Interval struct { + job + ticker *time.Ticker + stopChan chan struct{} +} + +type Immediate struct { + job +} + +type EventLoop struct { + vm *goja.Runtime + jobChan chan func() + jobs []*job + jobCount int32 + canRun int32 + + auxJobsLock sync.Mutex + wakeupChan chan struct{} + + auxJobsSpare, auxJobs []func() + + stopLock sync.Mutex + stopCond *sync.Cond + running bool + terminated bool + + errorHandler func(error) +} + +func Enable(runtime *goja.Runtime, errorHandler func(error)) *EventLoop { + loop := &EventLoop{ + vm: runtime, + jobChan: make(chan func()), + wakeupChan: make(chan struct{}, 1), + errorHandler: errorHandler, + } + loop.stopCond = sync.NewCond(&loop.stopLock) + runtime.Set("setTimeout", loop.setTimeout) + runtime.Set("setInterval", loop.setInterval) + runtime.Set("setImmediate", loop.setImmediate) + runtime.Set("clearTimeout", loop.clearTimeout) + runtime.Set("clearInterval", loop.clearInterval) + runtime.Set("clearImmediate", loop.clearImmediate) + return loop +} + +func (loop *EventLoop) schedule(call goja.FunctionCall, repeating bool) goja.Value { + if fn, ok := goja.AssertFunction(call.Argument(0)); ok { + delay := call.Argument(1).ToInteger() + var args []goja.Value + if len(call.Arguments) > 2 { + args = append(args, call.Arguments[2:]...) + } + f := func() { + _, err := fn(nil, args...) + if err != nil { + loop.errorHandler(err) + } + } + loop.jobCount++ + var job *job + var ret goja.Value + if repeating { + interval := loop.newInterval(f) + interval.start(loop, time.Duration(delay)*time.Millisecond) + job = &interval.job + ret = loop.vm.ToValue(interval) + } else { + timeout := loop.newTimeout(f) + timeout.start(loop, time.Duration(delay)*time.Millisecond) + job = &timeout.job + ret = loop.vm.ToValue(timeout) + } + job.idx = len(loop.jobs) + loop.jobs = append(loop.jobs, job) + return ret + } + return nil +} + +func (loop *EventLoop) setTimeout(call goja.FunctionCall) goja.Value { + return loop.schedule(call, false) +} + +func (loop *EventLoop) setInterval(call goja.FunctionCall) goja.Value { + return loop.schedule(call, true) +} + +func (loop *EventLoop) setImmediate(call goja.FunctionCall) goja.Value { + if fn, ok := goja.AssertFunction(call.Argument(0)); ok { + var args []goja.Value + if len(call.Arguments) > 1 { + args = append(args, call.Arguments[1:]...) + } + f := func() { + _, err := fn(nil, args...) + if err != nil { + loop.errorHandler(err) + } + } + loop.jobCount++ + return loop.vm.ToValue(loop.addImmediate(f)) + } + return nil +} + +// SetTimeout schedules to run the specified function in the context +// of the loop as soon as possible after the specified timeout period. +// SetTimeout returns a Timer which can be passed to ClearTimeout. +// The instance of goja.Runtime that is passed to the function and any Values derived +// from it must not be used outside the function. SetTimeout is +// safe to call inside or outside the loop. +// If the loop is terminated (see Terminate()) returns nil. +func (loop *EventLoop) SetTimeout(fn func(*goja.Runtime), timeout time.Duration) *Timer { + t := loop.newTimeout(func() { fn(loop.vm) }) + if loop.addAuxJob(func() { + t.start(loop, timeout) + loop.jobCount++ + t.idx = len(loop.jobs) + loop.jobs = append(loop.jobs, &t.job) + }) { + return t + } + return nil +} + +// ClearTimeout cancels a Timer returned by SetTimeout if it has not run yet. +// ClearTimeout is safe to call inside or outside the loop. +func (loop *EventLoop) ClearTimeout(t *Timer) { + loop.addAuxJob(func() { + loop.clearTimeout(t) + }) +} + +// SetInterval schedules to repeatedly run the specified function in +// the context of the loop as soon as possible after every specified +// timeout period. SetInterval returns an Interval which can be +// passed to ClearInterval. The instance of goja.Runtime that is passed to the +// function and any Values derived from it must not be used outside +// the function. SetInterval is safe to call inside or outside the +// loop. +// If the loop is terminated (see Terminate()) returns nil. +func (loop *EventLoop) SetInterval(fn func(*goja.Runtime), timeout time.Duration) *Interval { + i := loop.newInterval(func() { fn(loop.vm) }) + if loop.addAuxJob(func() { + i.start(loop, timeout) + loop.jobCount++ + i.idx = len(loop.jobs) + loop.jobs = append(loop.jobs, &i.job) + }) { + return i + } + return nil +} + +// ClearInterval cancels an Interval returned by SetInterval. +// ClearInterval is safe to call inside or outside the loop. +func (loop *EventLoop) ClearInterval(i *Interval) { + loop.addAuxJob(func() { + loop.clearInterval(i) + }) +} + +func (loop *EventLoop) setRunning() { + loop.stopLock.Lock() + defer loop.stopLock.Unlock() + if loop.running { + panic("Loop is already started") + } + loop.running = true + atomic.StoreInt32(&loop.canRun, 1) + loop.auxJobsLock.Lock() + loop.terminated = false + loop.auxJobsLock.Unlock() +} + +// Run calls the specified function, starts the event loop and waits until there are no more delayed jobs to run +// after which it stops the loop and returns. +// The instance of goja.Runtime that is passed to the function and any Values derived from it must not be used +// outside the function. +// Do NOT use this function while the loop is already running. Use RunOnLoop() instead. +// If the loop is already started it will panic. +func (loop *EventLoop) Run(fn func(*goja.Runtime)) { + loop.setRunning() + fn(loop.vm) + loop.run(false) +} + +// Start the event loop in the background. The loop continues to run until Stop() is called. +// If the loop is already started it will panic. +func (loop *EventLoop) Start() { + loop.setRunning() + go loop.run(true) +} + +// StartInForeground starts the event loop in the current goroutine. The loop continues to run until Stop() is called. +// If the loop is already started it will panic. +// Use this instead of Start if you want to recover from panics that may occur while calling native Go functions from +// within setInterval and setTimeout callbacks. +func (loop *EventLoop) StartInForeground() { + loop.setRunning() + loop.run(true) +} + +// Stop the loop that was started with Start(). After this function returns there will be no more jobs executed +// by the loop. It is possible to call Start() or Run() again after this to resume the execution. +// Note, it does not cancel active timeouts (use Terminate() instead if you want this). +// It is not allowed to run Start() (or Run()) and Stop() or Terminate() concurrently. +// Calling Stop() on a non-running loop has no effect. +// It is not allowed to call Stop() from the loop, because it is synchronous and cannot complete until the loop +// is not running any jobs. Use StopNoWait() instead. +// return number of jobs remaining +func (loop *EventLoop) Stop() int { + loop.stopLock.Lock() + for loop.running { + atomic.StoreInt32(&loop.canRun, 0) + loop.wakeup() + loop.stopCond.Wait() + } + loop.stopLock.Unlock() + return int(loop.jobCount) +} + +// StopNoWait tells the loop to stop and returns immediately. Can be used inside the loop. Calling it on a +// non-running loop has no effect. +func (loop *EventLoop) StopNoWait() { + loop.stopLock.Lock() + if loop.running { + atomic.StoreInt32(&loop.canRun, 0) + loop.wakeup() + } + loop.stopLock.Unlock() +} + +// Terminate stops the loop and clears all active timeouts and intervals. After it returns there are no +// active timers or goroutines associated with the loop. Any attempt to submit a task (by using RunOnLoop(), +// SetTimeout() or SetInterval()) will not succeed. +// After being terminated the loop can be restarted again by using Start() or Run(). +// This method must not be called concurrently with Stop*(), Start(), or Run(). +func (loop *EventLoop) Terminate() { + loop.Stop() + + loop.auxJobsLock.Lock() + loop.terminated = true + loop.auxJobsLock.Unlock() + + loop.runAux() + + for i := 0; i < len(loop.jobs); i++ { + job := loop.jobs[i] + if !job.cancelled { + job.cancelled = true + if job.cancel() { + loop.removeJob(job) + i-- + } + } + } + + for len(loop.jobs) > 0 { + (<-loop.jobChan)() + } +} + +// RunOnLoop schedules to run the specified function in the context of the loop as soon as possible. +// The order of the runs is preserved (i.e. the functions will be called in the same order as calls to RunOnLoop()) +// The instance of goja.Runtime that is passed to the function and any Values derived from it must not be used +// outside the function. It is safe to call inside or outside the loop. +// Returns true on success or false if the loop is terminated (see Terminate()). +func (loop *EventLoop) RunOnLoop(fn func(*goja.Runtime)) bool { + return loop.addAuxJob(func() { fn(loop.vm) }) +} + +func (loop *EventLoop) runAux() { + loop.auxJobsLock.Lock() + jobs := loop.auxJobs + loop.auxJobs = loop.auxJobsSpare + loop.auxJobsLock.Unlock() + for i, job := range jobs { + job() + jobs[i] = nil + } + loop.auxJobsSpare = jobs[:0] +} + +func (loop *EventLoop) run(inBackground bool) { + loop.runAux() + if inBackground { + loop.jobCount++ + } +LOOP: + for loop.jobCount > 0 { + select { + case job := <-loop.jobChan: + job() + case <-loop.wakeupChan: + loop.runAux() + if atomic.LoadInt32(&loop.canRun) == 0 { + break LOOP + } + } + } + if inBackground { + loop.jobCount-- + } + + loop.stopLock.Lock() + loop.running = false + loop.stopLock.Unlock() + loop.stopCond.Broadcast() +} + +func (loop *EventLoop) wakeup() { + select { + case loop.wakeupChan <- struct{}{}: + default: + } +} + +func (loop *EventLoop) addAuxJob(fn func()) bool { + loop.auxJobsLock.Lock() + if loop.terminated { + loop.auxJobsLock.Unlock() + return false + } + loop.auxJobs = append(loop.auxJobs, fn) + loop.auxJobsLock.Unlock() + loop.wakeup() + return true +} + +func (loop *EventLoop) newTimeout(f func()) *Timer { + t := &Timer{ + job: job{fn: f}, + } + t.cancel = t.doCancel + + return t +} + +func (t *Timer) start(loop *EventLoop, timeout time.Duration) { + t.timer = time.AfterFunc(timeout, func() { + loop.jobChan <- func() { + loop.doTimeout(t) + } + }) +} + +func (loop *EventLoop) newInterval(f func()) *Interval { + i := &Interval{ + job: job{fn: f}, + stopChan: make(chan struct{}), + } + i.cancel = i.doCancel + + return i +} + +func (i *Interval) start(loop *EventLoop, timeout time.Duration) { + // https://nodejs.org/api/timers.html#timers_setinterval_callback_delay_args + if timeout <= 0 { + timeout = time.Millisecond + } + i.ticker = time.NewTicker(timeout) + go i.run(loop) +} + +func (loop *EventLoop) addImmediate(f func()) *Immediate { + i := &Immediate{ + job: job{fn: f}, + } + loop.addAuxJob(func() { + loop.doImmediate(i) + }) + return i +} + +func (loop *EventLoop) doTimeout(t *Timer) { + loop.removeJob(&t.job) + if !t.cancelled { + t.cancelled = true + loop.jobCount-- + t.fn() + } +} + +func (loop *EventLoop) doInterval(i *Interval) { + if !i.cancelled { + i.fn() + } +} + +func (loop *EventLoop) doImmediate(i *Immediate) { + if !i.cancelled { + i.cancelled = true + loop.jobCount-- + i.fn() + } +} + +func (loop *EventLoop) clearTimeout(t *Timer) { + if t != nil && !t.cancelled { + t.cancelled = true + loop.jobCount-- + if t.doCancel() { + loop.removeJob(&t.job) + } + } +} + +func (loop *EventLoop) clearInterval(i *Interval) { + if i != nil && !i.cancelled { + i.cancelled = true + loop.jobCount-- + i.doCancel() + } +} + +func (loop *EventLoop) removeJob(job *job) { + idx := job.idx + if idx < 0 { + return + } + if idx < len(loop.jobs)-1 { + loop.jobs[idx] = loop.jobs[len(loop.jobs)-1] + loop.jobs[idx].idx = idx + } + loop.jobs[len(loop.jobs)-1] = nil + loop.jobs = loop.jobs[:len(loop.jobs)-1] + job.idx = -1 +} + +func (loop *EventLoop) clearImmediate(i *Immediate) { + if i != nil && !i.cancelled { + i.cancelled = true + loop.jobCount-- + } +} + +func (i *Interval) doCancel() bool { + close(i.stopChan) + return false +} + +func (t *Timer) doCancel() bool { + return t.timer.Stop() +} + +func (i *Interval) run(loop *EventLoop) { +L: + for { + select { + case <-i.stopChan: + i.ticker.Stop() + break L + case <-i.ticker.C: + loop.jobChan <- func() { + loop.doInterval(i) + } + } + } + loop.jobChan <- func() { + loop.removeJob(&i.job) + } +} diff --git a/script/modules/require/module.go b/script/modules/require/module.go new file mode 100644 index 00000000..62ea3bfe --- /dev/null +++ b/script/modules/require/module.go @@ -0,0 +1,231 @@ +package require + +import ( + "errors" + "io" + "io/fs" + "os" + "path" + "path/filepath" + "runtime" + "sync" + "syscall" + "text/template" + + js "github.com/dop251/goja" + "github.com/dop251/goja/parser" +) + +type ModuleLoader func(*js.Runtime, *js.Object) + +// SourceLoader represents a function that returns a file data at a given path. +// The function should return ModuleFileDoesNotExistError if the file either doesn't exist or is a directory. +// This error will be ignored by the resolver and the search will continue. Any other errors will be propagated. +type SourceLoader func(path string) ([]byte, error) + +var ( + InvalidModuleError = errors.New("Invalid module") + IllegalModuleNameError = errors.New("Illegal module name") + NoSuchBuiltInModuleError = errors.New("No such built-in module") + ModuleFileDoesNotExistError = errors.New("module file does not exist") +) + +// Registry contains a cache of compiled modules which can be used by multiple Runtimes +type Registry struct { + sync.Mutex + native map[string]ModuleLoader + builtin map[string]ModuleLoader + compiled map[string]*js.Program + + srcLoader SourceLoader + globalFolders []string + fsEnabled bool +} + +type RequireModule struct { + r *Registry + runtime *js.Runtime + modules map[string]*js.Object + nodeModules map[string]*js.Object +} + +func NewRegistry(opts ...Option) *Registry { + r := &Registry{} + + for _, opt := range opts { + opt(r) + } + + return r +} + +type Option func(*Registry) + +// WithLoader sets a function which will be called by the require() function in order to get a source code for a +// module at the given path. The same function will be used to get external source maps. +// Note, this only affects the modules loaded by the require() function. If you need to use it as a source map +// loader for code parsed in a different way (such as runtime.RunString() or eval()), use (*Runtime).SetParserOptions() +func WithLoader(srcLoader SourceLoader) Option { + return func(r *Registry) { + r.srcLoader = srcLoader + } +} + +// WithGlobalFolders appends the given paths to the registry's list of +// global folders to search if the requested module is not found +// elsewhere. By default, a registry's global folders list is empty. +// In the reference Node.js implementation, the default global folders +// list is $NODE_PATH, $HOME/.node_modules, $HOME/.node_libraries and +// $PREFIX/lib/node, see +// https://nodejs.org/api/modules.html#modules_loading_from_the_global_folders. +func WithGlobalFolders(globalFolders ...string) Option { + return func(r *Registry) { + r.globalFolders = globalFolders + } +} + +func WithFsEnable(enabled bool) Option { + return func(r *Registry) { + r.fsEnabled = enabled + } +} + +// Enable adds the require() function to the specified runtime. +func (r *Registry) Enable(runtime *js.Runtime) *RequireModule { + rrt := &RequireModule{ + r: r, + runtime: runtime, + modules: make(map[string]*js.Object), + nodeModules: make(map[string]*js.Object), + } + + runtime.Set("require", rrt.require) + return rrt +} + +func (r *Registry) RegisterNodeModule(name string, loader ModuleLoader) { + r.Lock() + defer r.Unlock() + + if r.builtin == nil { + r.builtin = make(map[string]ModuleLoader) + } + name = filepathClean(name) + r.builtin[name] = loader +} + +func (r *Registry) RegisterNativeModule(name string, loader ModuleLoader) { + r.Lock() + defer r.Unlock() + + if r.native == nil { + r.native = make(map[string]ModuleLoader) + } + name = filepathClean(name) + r.native[name] = loader +} + +// DefaultSourceLoader is used if none was set (see WithLoader()). It simply loads files from the host's filesystem. +func DefaultSourceLoader(filename string) ([]byte, error) { + fp := filepath.FromSlash(filename) + f, err := os.Open(fp) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + err = ModuleFileDoesNotExistError + } else if runtime.GOOS == "windows" { + if errors.Is(err, syscall.Errno(0x7b)) { // ERROR_INVALID_NAME, The filename, directory name, or volume label syntax is incorrect. + err = ModuleFileDoesNotExistError + } + } + return nil, err + } + + defer f.Close() + // On some systems (e.g. plan9 and FreeBSD) it is possible to use the standard read() call on directories + // which means we cannot rely on read() returning an error, we have to do stat() instead. + if fi, err := f.Stat(); err == nil { + if fi.IsDir() { + return nil, ModuleFileDoesNotExistError + } + } else { + return nil, err + } + return io.ReadAll(f) +} + +func (r *Registry) getSource(p string) ([]byte, error) { + srcLoader := r.srcLoader + if srcLoader == nil { + srcLoader = DefaultSourceLoader + } + return srcLoader(p) +} + +func (r *Registry) getCompiledSource(p string) (*js.Program, error) { + r.Lock() + defer r.Unlock() + + prg := r.compiled[p] + if prg == nil { + buf, err := r.getSource(p) + if err != nil { + return nil, err + } + s := string(buf) + + if path.Ext(p) == ".json" { + s = "module.exports = JSON.parse('" + template.JSEscapeString(s) + "')" + } + + source := "(function(exports, require, module) {" + s + "\n})" + parsed, err := js.Parse(p, source, parser.WithSourceMapLoader(r.srcLoader)) + if err != nil { + return nil, err + } + prg, err = js.CompileAST(parsed, false) + if err == nil { + if r.compiled == nil { + r.compiled = make(map[string]*js.Program) + } + r.compiled[p] = prg + } + return prg, err + } + return prg, nil +} + +func (r *RequireModule) require(call js.FunctionCall) js.Value { + ret, err := r.Require(call.Argument(0).String()) + if err != nil { + if _, ok := err.(*js.Exception); !ok { + panic(r.runtime.NewGoError(err)) + } + panic(err) + } + return ret +} + +func filepathClean(p string) string { + return path.Clean(p) +} + +// Require can be used to import modules from Go source (similar to JS require() function). +func (r *RequireModule) Require(p string) (ret js.Value, err error) { + module, err := r.resolve(p) + if err != nil { + return + } + ret = module.Get("exports") + return +} + +func Require(runtime *js.Runtime, name string) js.Value { + if r, ok := js.AssertFunction(runtime.Get("require")); ok { + mod, err := r(js.Undefined(), runtime.ToValue(name)) + if err != nil { + panic(err) + } + return mod + } + panic(runtime.NewTypeError("Please enable require for this runtime using new(require.Registry).Enable(runtime)")) +} diff --git a/script/modules/require/resolve.go b/script/modules/require/resolve.go new file mode 100644 index 00000000..e52fc2ff --- /dev/null +++ b/script/modules/require/resolve.go @@ -0,0 +1,277 @@ +package require + +import ( + "encoding/json" + "errors" + "path" + "path/filepath" + "runtime" + "strings" + + js "github.com/dop251/goja" +) + +const NodePrefix = "node:" + +// NodeJS module search algorithm described by +// https://nodejs.org/api/modules.html#modules_all_together +func (r *RequireModule) resolve(modpath string) (module *js.Object, err error) { + origPath, modpath := modpath, filepathClean(modpath) + if modpath == "" { + return nil, IllegalModuleNameError + } + + var start string + err = nil + if path.IsAbs(origPath) { + start = "/" + } else { + start = r.getCurrentModulePath() + } + + p := path.Join(start, modpath) + if isFileOrDirectoryPath(origPath) && r.r.fsEnabled { + if module = r.modules[p]; module != nil { + return + } + module, err = r.loadAsFileOrDirectory(p) + if err == nil && module != nil { + r.modules[p] = module + } + } else { + module, err = r.loadNative(origPath) + if err == nil { + return + } else { + if err == InvalidModuleError { + err = nil + } else { + return + } + } + if module = r.nodeModules[p]; module != nil { + return + } + if r.r.fsEnabled { + module, err = r.loadNodeModules(modpath, start) + if err == nil && module != nil { + r.nodeModules[p] = module + } + } + } + + if module == nil && err == nil { + err = InvalidModuleError + } + return +} + +func (r *RequireModule) loadNative(path string) (*js.Object, error) { + module := r.modules[path] + if module != nil { + return module, nil + } + + var ldr ModuleLoader + if r.r.native != nil { + ldr = r.r.native[path] + } + var isBuiltIn, withPrefix bool + if ldr == nil { + if r.r.builtin != nil { + ldr = r.r.builtin[path] + } + if ldr == nil && strings.HasPrefix(path, NodePrefix) { + ldr = r.r.builtin[path[len(NodePrefix):]] + if ldr == nil { + return nil, NoSuchBuiltInModuleError + } + withPrefix = true + } + isBuiltIn = true + } + + if ldr != nil { + module = r.createModuleObject() + r.modules[path] = module + if isBuiltIn { + if withPrefix { + r.modules[path[len(NodePrefix):]] = module + } else { + if !strings.HasPrefix(path, NodePrefix) { + r.modules[NodePrefix+path] = module + } + } + } + ldr(r.runtime, module) + return module, nil + } + + return nil, InvalidModuleError +} + +func (r *RequireModule) loadAsFileOrDirectory(path string) (module *js.Object, err error) { + if module, err = r.loadAsFile(path); module != nil || err != nil { + return + } + + return r.loadAsDirectory(path) +} + +func (r *RequireModule) loadAsFile(path string) (module *js.Object, err error) { + if module, err = r.loadModule(path); module != nil || err != nil { + return + } + + p := path + ".js" + if module, err = r.loadModule(p); module != nil || err != nil { + return + } + + p = path + ".json" + return r.loadModule(p) +} + +func (r *RequireModule) loadIndex(modpath string) (module *js.Object, err error) { + p := path.Join(modpath, "index.js") + if module, err = r.loadModule(p); module != nil || err != nil { + return + } + + p = path.Join(modpath, "index.json") + return r.loadModule(p) +} + +func (r *RequireModule) loadAsDirectory(modpath string) (module *js.Object, err error) { + p := path.Join(modpath, "package.json") + buf, err := r.r.getSource(p) + if err != nil { + return r.loadIndex(modpath) + } + var pkg struct { + Main string + } + err = json.Unmarshal(buf, &pkg) + if err != nil || len(pkg.Main) == 0 { + return r.loadIndex(modpath) + } + + m := path.Join(modpath, pkg.Main) + if module, err = r.loadAsFile(m); module != nil || err != nil { + return + } + + return r.loadIndex(m) +} + +func (r *RequireModule) loadNodeModule(modpath, start string) (*js.Object, error) { + return r.loadAsFileOrDirectory(path.Join(start, modpath)) +} + +func (r *RequireModule) loadNodeModules(modpath, start string) (module *js.Object, err error) { + for _, dir := range r.r.globalFolders { + if module, err = r.loadNodeModule(modpath, dir); module != nil || err != nil { + return + } + } + for { + var p string + if path.Base(start) != "node_modules" { + p = path.Join(start, "node_modules") + } else { + p = start + } + if module, err = r.loadNodeModule(modpath, p); module != nil || err != nil { + return + } + if start == ".." { // Dir('..') is '.' + break + } + parent := path.Dir(start) + if parent == start { + break + } + start = parent + } + + return +} + +func (r *RequireModule) getCurrentModulePath() string { + var buf [2]js.StackFrame + frames := r.runtime.CaptureCallStack(2, buf[:0]) + if len(frames) < 2 { + return "." + } + return path.Dir(frames[1].SrcName()) +} + +func (r *RequireModule) createModuleObject() *js.Object { + module := r.runtime.NewObject() + module.Set("exports", r.runtime.NewObject()) + return module +} + +func (r *RequireModule) loadModule(path string) (*js.Object, error) { + module := r.modules[path] + if module == nil { + module = r.createModuleObject() + r.modules[path] = module + err := r.loadModuleFile(path, module) + if err != nil { + module = nil + delete(r.modules, path) + if errors.Is(err, ModuleFileDoesNotExistError) { + err = nil + } + } + return module, err + } + return module, nil +} + +func (r *RequireModule) loadModuleFile(path string, jsModule *js.Object) error { + prg, err := r.r.getCompiledSource(path) + if err != nil { + return err + } + + f, err := r.runtime.RunProgram(prg) + if err != nil { + return err + } + + if call, ok := js.AssertFunction(f); ok { + jsExports := jsModule.Get("exports") + jsRequire := r.runtime.Get("require") + + // Run the module source, with "jsExports" as "this", + // "jsExports" as the "exports" variable, "jsRequire" + // as the "require" variable and "jsModule" as the + // "module" variable (Nodejs capable). + _, err = call(jsExports, jsExports, jsRequire, jsModule) + if err != nil { + return err + } + } else { + return InvalidModuleError + } + + return nil +} + +func isFileOrDirectoryPath(path string) bool { + result := path == "." || path == ".." || + strings.HasPrefix(path, "/") || + strings.HasPrefix(path, "./") || + strings.HasPrefix(path, "../") + + if runtime.GOOS == "windows" { + result = result || + strings.HasPrefix(path, `.\`) || + strings.HasPrefix(path, `..\`) || + filepath.IsAbs(path) + } + + return result +} diff --git a/script/modules/sgnotification/module.go b/script/modules/sgnotification/module.go new file mode 100644 index 00000000..918d07fd --- /dev/null +++ b/script/modules/sgnotification/module.go @@ -0,0 +1,111 @@ +package sgnotification + +import ( + "context" + "encoding/base64" + "strings" + + "github.com/sagernet/sing-box/experimental/libbox/platform" + "github.com/sagernet/sing-box/script/jsc" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/service" + + "github.com/dop251/goja" +) + +type SurgeNotification struct { + vm *goja.Runtime + logger logger.Logger + platformInterface platform.Interface + scriptTag string +} + +func Enable(vm *goja.Runtime, ctx context.Context, logger logger.Logger) { + platformInterface := service.FromContext[platform.Interface](ctx) + notification := &SurgeNotification{ + vm: vm, + logger: logger, + platformInterface: platformInterface, + } + notificationObject := vm.NewObject() + notificationObject.Set("post", notification.js_post) + vm.Set("$notification", notificationObject) +} + +func (s *SurgeNotification) js_post(call goja.FunctionCall) goja.Value { + var ( + title string + subtitle string + body string + openURL string + clipboard string + mediaURL string + mediaData []byte + mediaType string + autoDismiss int + ) + title = jsc.AssertString(s.vm, call.Argument(0), "title", true) + subtitle = jsc.AssertString(s.vm, call.Argument(1), "subtitle", true) + body = jsc.AssertString(s.vm, call.Argument(2), "body", true) + options := jsc.AssertObject(s.vm, call.Argument(3), "options", true) + if options != nil { + action := jsc.AssertString(s.vm, options.Get("action"), "options.action", true) + switch action { + case "open-url": + openURL = jsc.AssertString(s.vm, options.Get("url"), "options.url", false) + case "clipboard": + clipboard = jsc.AssertString(s.vm, options.Get("clipboard"), "options.clipboard", false) + } + mediaURL = jsc.AssertString(s.vm, options.Get("media-url"), "options.media-url", true) + mediaBase64 := jsc.AssertString(s.vm, options.Get("media-base64"), "options.media-base64", true) + if mediaBase64 != "" { + mediaBinary, err := base64.StdEncoding.DecodeString(mediaBase64) + if err != nil { + panic(s.vm.NewGoError(E.Cause(err, "decode media-base64"))) + } + mediaData = mediaBinary + mediaType = jsc.AssertString(s.vm, options.Get("media-base64-mime"), "options.media-base64-mime", false) + } + autoDismiss = int(jsc.AssertInt(s.vm, options.Get("auto-dismiss"), "options.auto-dismiss", true)) + } + if title != "" && subtitle == "" && body == "" { + body = title + title = "" + } else if title != "" && subtitle != "" && body == "" { + body = subtitle + subtitle = "" + } + var builder strings.Builder + if title != "" { + builder.WriteString("[") + builder.WriteString(title) + if subtitle != "" { + builder.WriteString(" - ") + builder.WriteString(subtitle) + } + builder.WriteString("]: ") + } + builder.WriteString(body) + s.logger.Info("notification: " + builder.String()) + if s.platformInterface != nil { + err := s.platformInterface.SendNotification(&platform.Notification{ + Identifier: "surge-script-notification-" + s.scriptTag, + TypeName: "Surge Script Notification (" + s.scriptTag + ")", + TypeID: 11, + Title: title, + Subtitle: subtitle, + Body: body, + OpenURL: openURL, + Clipboard: clipboard, + MediaURL: mediaURL, + MediaData: mediaData, + MediaType: mediaType, + Timeout: autoDismiss, + }) + if err != nil { + s.logger.Error(E.Cause(err, "send notification")) + } + } + return goja.Undefined() +} diff --git a/script/modules/surge/environment.go b/script/modules/surge/environment.go new file mode 100644 index 00000000..590469c6 --- /dev/null +++ b/script/modules/surge/environment.go @@ -0,0 +1,65 @@ +package surge + +import ( + "runtime" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/locale" + "github.com/sagernet/sing-box/script/jsc" + + "github.com/dop251/goja" +) + +type Environment struct { + class jsc.Class[*Module, *Environment] +} + +func createEnvironment(module *Module) jsc.Class[*Module, *Environment] { + class := jsc.NewClass[*Module, *Environment](module) + class.DefineField("system", (*Environment).getSystem, nil) + class.DefineField("surge-build", (*Environment).getSurgeBuild, nil) + class.DefineField("surge-version", (*Environment).getSurgeVersion, nil) + class.DefineField("language", (*Environment).getLanguage, nil) + class.DefineField("device-model", (*Environment).getDeviceModel, nil) + class.DefineMethod("toString", (*Environment).toString) + return class +} + +func (e *Environment) getSystem() any { + switch runtime.GOOS { + case "ios": + return "iOS" + case "darwin": + return "macOS" + case "tvos": + return "tvOS" + case "linux": + return "Linux" + case "android": + return "Android" + case "windows": + return "Windows" + default: + return runtime.GOOS + } +} + +func (e *Environment) getSurgeBuild() any { + return "N/A" +} + +func (e *Environment) getSurgeVersion() any { + return "sing-box " + C.Version +} + +func (e *Environment) getLanguage() any { + return locale.Current().Locale +} + +func (e *Environment) getDeviceModel() any { + return "N/A" +} + +func (e *Environment) toString(call goja.FunctionCall) any { + return "[sing-box Surge environment" +} diff --git a/script/modules/surge/http.go b/script/modules/surge/http.go new file mode 100644 index 00000000..49aef0d8 --- /dev/null +++ b/script/modules/surge/http.go @@ -0,0 +1,150 @@ +package surge + +import ( + "bytes" + "crypto/tls" + "io" + "net/http" + "net/http/cookiejar" + "time" + + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + "github.com/sagernet/sing/common" + F "github.com/sagernet/sing/common/format" + + "github.com/dop251/goja" + "golang.org/x/net/publicsuffix" +) + +type HTTP struct { + class jsc.Class[*Module, *HTTP] + cookieJar *cookiejar.Jar + httpTransport *http.Transport +} + +func createHTTP(module *Module) jsc.Class[*Module, *HTTP] { + class := jsc.NewClass[*Module, *HTTP](module) + class.DefineConstructor(newHTTP) + class.DefineMethod("get", httpRequest(http.MethodGet)) + class.DefineMethod("post", httpRequest(http.MethodPost)) + class.DefineMethod("put", httpRequest(http.MethodPut)) + class.DefineMethod("delete", httpRequest(http.MethodDelete)) + class.DefineMethod("head", httpRequest(http.MethodHead)) + class.DefineMethod("options", httpRequest(http.MethodOptions)) + class.DefineMethod("patch", httpRequest(http.MethodPatch)) + class.DefineMethod("trace", httpRequest(http.MethodTrace)) + class.DefineMethod("toString", (*HTTP).toString) + return class +} + +func newHTTP(class jsc.Class[*Module, *HTTP], call goja.ConstructorCall) *HTTP { + return &HTTP{ + class: class, + cookieJar: common.Must1(cookiejar.New(&cookiejar.Options{ + PublicSuffixList: publicsuffix.List, + })), + httpTransport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &tls.Config{}, + }, + } +} + +func httpRequest(method string) func(s *HTTP, call goja.FunctionCall) any { + return func(s *HTTP, call goja.FunctionCall) any { + if len(call.Arguments) != 2 { + panic(s.class.Runtime().NewTypeError("invalid arguments")) + } + context := boxctx.MustFromRuntime(s.class.Runtime()) + var ( + url string + headers http.Header + body []byte + timeout = 5 * time.Second + insecure bool + autoCookie bool = true + autoRedirect bool + // policy string + binaryMode bool + ) + switch optionsValue := call.Argument(0).(type) { + case goja.String: + url = optionsValue.String() + case *goja.Object: + url = jsc.AssertString(s.class.Runtime(), optionsValue.Get("url"), "options.url", false) + headers = jsc.AssertHTTPHeader(s.class.Runtime(), optionsValue.Get("headers"), "option.headers") + body = jsc.AssertStringBinary(s.class.Runtime(), optionsValue.Get("body"), "options.body", true) + timeoutInt := jsc.AssertInt(s.class.Runtime(), optionsValue.Get("timeout"), "options.timeout", true) + if timeoutInt > 0 { + timeout = time.Duration(timeoutInt) * time.Second + } + insecure = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("insecure"), "options.insecure", true) + autoCookie = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("auto-cookie"), "options.auto-cookie", true) + autoRedirect = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("auto-redirect"), "options.auto-redirect", true) + // policy = jsc.AssertString(s.class.Runtime(), optionsValue.Get("policy"), "options.policy", true) + binaryMode = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("binary-mode"), "options.binary-mode", true) + default: + panic(s.class.Runtime().NewTypeError(F.ToString("invalid argument: options: expected string or object, but got ", optionsValue))) + } + callback := jsc.AssertFunction(s.class.Runtime(), call.Argument(1), "callback") + s.httpTransport.TLSClientConfig.InsecureSkipVerify = insecure + httpClient := &http.Client{ + Timeout: timeout, + Transport: s.httpTransport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if autoRedirect { + return nil + } + return http.ErrUseLastResponse + }, + } + if autoCookie { + httpClient.Jar = s.cookieJar + } + request, err := http.NewRequestWithContext(context.Context, method, url, bytes.NewReader(body)) + if host := headers.Get("Host"); host != "" { + request.Host = host + headers.Del("Host") + } + request.Header = headers + if err != nil { + panic(s.class.Runtime().NewGoError(err)) + } + go func() { + defer s.httpTransport.CloseIdleConnections() + response, executeErr := httpClient.Do(request) + if err != nil { + _, err = callback(nil, s.class.Runtime().NewGoError(executeErr), nil, nil) + if err != nil { + context.ErrorHandler(err) + } + return + } + defer response.Body.Close() + var content []byte + content, err = io.ReadAll(response.Body) + if err != nil { + _, err = callback(nil, s.class.Runtime().NewGoError(err), nil, nil) + if err != nil { + context.ErrorHandler(err) + } + } + responseObject := s.class.Runtime().NewObject() + responseObject.Set("status", response.StatusCode) + responseObject.Set("headers", jsc.HeadersToValue(s.class.Runtime(), response.Header)) + var bodyValue goja.Value + if binaryMode { + bodyValue = jsc.NewUint8Array(s.class.Runtime(), content) + } else { + bodyValue = s.class.Runtime().ToValue(string(content)) + } + _, err = callback(nil, nil, responseObject, bodyValue) + }() + return nil + } +} + +func (h *HTTP) toString(call goja.FunctionCall) any { + return "[sing-box Surge HTTP]" +} diff --git a/script/modules/surge/module.go b/script/modules/surge/module.go new file mode 100644 index 00000000..f3394426 --- /dev/null +++ b/script/modules/surge/module.go @@ -0,0 +1,63 @@ +package surge + +import ( + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/require" + "github.com/sagernet/sing/common" + + "github.com/dop251/goja" +) + +const ModuleName = "surge" + +type Module struct { + runtime *goja.Runtime + classScript jsc.Class[*Module, *Script] + classEnvironment jsc.Class[*Module, *Environment] + classPersistentStore jsc.Class[*Module, *PersistentStore] + classHTTP jsc.Class[*Module, *HTTP] + classUtils jsc.Class[*Module, *Utils] + classNotification jsc.Class[*Module, *Notification] +} + +func Require(runtime *goja.Runtime, module *goja.Object) { + m := &Module{ + runtime: runtime, + } + m.classScript = createScript(m) + m.classEnvironment = createEnvironment(m) + m.classPersistentStore = createPersistentStore(m) + m.classHTTP = createHTTP(m) + m.classUtils = createUtils(m) + m.classNotification = createNotification(m) + exports := module.Get("exports").(*goja.Object) + exports.Set("Script", m.classScript.ToValue()) + exports.Set("Environment", m.classEnvironment.ToValue()) + exports.Set("PersistentStore", m.classPersistentStore.ToValue()) + exports.Set("HTTP", m.classHTTP.ToValue()) + exports.Set("Utils", m.classUtils.ToValue()) + exports.Set("Notification", m.classNotification.ToValue()) +} + +func Enable(runtime *goja.Runtime, scriptType string, args []string) { + exports := require.Require(runtime, ModuleName).ToObject(runtime) + classScript := jsc.GetClass[*Module, *Script](runtime, exports, "Script") + classEnvironment := jsc.GetClass[*Module, *Environment](runtime, exports, "Environment") + classPersistentStore := jsc.GetClass[*Module, *PersistentStore](runtime, exports, "PersistentStore") + classHTTP := jsc.GetClass[*Module, *HTTP](runtime, exports, "HTTP") + classUtils := jsc.GetClass[*Module, *Utils](runtime, exports, "Utils") + classNotification := jsc.GetClass[*Module, *Notification](runtime, exports, "Notification") + runtime.Set("$script", classScript.New(&Script{class: classScript, ScriptType: scriptType})) + runtime.Set("$environment", classEnvironment.New(&Environment{class: classEnvironment})) + runtime.Set("$persistentStore", newPersistentStore(classPersistentStore)) + runtime.Set("$http", classHTTP.New(newHTTP(classHTTP, goja.ConstructorCall{}))) + runtime.Set("$utils", classUtils.New(&Utils{class: classUtils})) + runtime.Set("$notification", newNotification(classNotification)) + runtime.Set("$argument", runtime.NewArray(common.Map(args, func(it string) any { + return it + })...)) +} + +func (m *Module) Runtime() *goja.Runtime { + return m.runtime +} diff --git a/script/modules/surge/notification.go b/script/modules/surge/notification.go new file mode 100644 index 00000000..4f330388 --- /dev/null +++ b/script/modules/surge/notification.go @@ -0,0 +1,120 @@ +package surge + +import ( + "encoding/base64" + "strings" + + "github.com/sagernet/sing-box/experimental/libbox/platform" + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/service" + + "github.com/dop251/goja" +) + +type Notification struct { + class jsc.Class[*Module, *Notification] + logger logger.ContextLogger + tag string + platformInterface platform.Interface +} + +func createNotification(module *Module) jsc.Class[*Module, *Notification] { + class := jsc.NewClass[*Module, *Notification](module) + class.DefineMethod("post", (*Notification).post) + class.DefineMethod("toString", (*Notification).toString) + return class +} + +func newNotification(class jsc.Class[*Module, *Notification]) goja.Value { + context := boxctx.MustFromRuntime(class.Runtime()) + return class.New(&Notification{ + class: class, + logger: context.Logger, + tag: context.Tag, + platformInterface: service.FromContext[platform.Interface](context.Context), + }) +} + +func (s *Notification) post(call goja.FunctionCall) any { + var ( + title string + subtitle string + body string + openURL string + clipboard string + mediaURL string + mediaData []byte + mediaType string + autoDismiss int + ) + title = jsc.AssertString(s.class.Runtime(), call.Argument(0), "title", true) + subtitle = jsc.AssertString(s.class.Runtime(), call.Argument(1), "subtitle", true) + body = jsc.AssertString(s.class.Runtime(), call.Argument(2), "body", true) + options := jsc.AssertObject(s.class.Runtime(), call.Argument(3), "options", true) + if options != nil { + action := jsc.AssertString(s.class.Runtime(), options.Get("action"), "options.action", true) + switch action { + case "open-url": + openURL = jsc.AssertString(s.class.Runtime(), options.Get("url"), "options.url", false) + case "clipboard": + clipboard = jsc.AssertString(s.class.Runtime(), options.Get("clipboard"), "options.clipboard", false) + } + mediaURL = jsc.AssertString(s.class.Runtime(), options.Get("media-url"), "options.media-url", true) + mediaBase64 := jsc.AssertString(s.class.Runtime(), options.Get("media-base64"), "options.media-base64", true) + if mediaBase64 != "" { + mediaBinary, err := base64.StdEncoding.DecodeString(mediaBase64) + if err != nil { + panic(s.class.Runtime().NewGoError(E.Cause(err, "decode media-base64"))) + } + mediaData = mediaBinary + mediaType = jsc.AssertString(s.class.Runtime(), options.Get("media-base64-mime"), "options.media-base64-mime", false) + } + autoDismiss = int(jsc.AssertInt(s.class.Runtime(), options.Get("auto-dismiss"), "options.auto-dismiss", true)) + } + if title != "" && subtitle == "" && body == "" { + body = title + title = "" + } else if title != "" && subtitle != "" && body == "" { + body = subtitle + subtitle = "" + } + var builder strings.Builder + if title != "" { + builder.WriteString("[") + builder.WriteString(title) + if subtitle != "" { + builder.WriteString(" - ") + builder.WriteString(subtitle) + } + builder.WriteString("]: ") + } + builder.WriteString(body) + s.logger.Info("notification: " + builder.String()) + if s.platformInterface != nil { + err := s.platformInterface.SendNotification(&platform.Notification{ + Identifier: "surge-script-notification-" + s.tag, + TypeName: "Surge Script Notification (" + s.tag + ")", + TypeID: 11, + Title: title, + Subtitle: subtitle, + Body: body, + OpenURL: openURL, + Clipboard: clipboard, + MediaURL: mediaURL, + MediaData: mediaData, + MediaType: mediaType, + Timeout: autoDismiss, + }) + if err != nil { + s.logger.Error(E.Cause(err, "send notification")) + } + } + return nil +} + +func (s *Notification) toString(call goja.FunctionCall) any { + return "[sing-box Surge notification]" +} diff --git a/script/modules/surge/persistent_store.go b/script/modules/surge/persistent_store.go new file mode 100644 index 00000000..7c40f2fa --- /dev/null +++ b/script/modules/surge/persistent_store.go @@ -0,0 +1,78 @@ +package surge + +import ( + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + "github.com/sagernet/sing/service" + + "github.com/dop251/goja" +) + +type PersistentStore struct { + class jsc.Class[*Module, *PersistentStore] + cacheFile adapter.CacheFile + inMemoryCache *adapter.SurgeInMemoryCache + tag string +} + +func createPersistentStore(module *Module) jsc.Class[*Module, *PersistentStore] { + class := jsc.NewClass[*Module, *PersistentStore](module) + class.DefineMethod("get", (*PersistentStore).get) + class.DefineMethod("set", (*PersistentStore).set) + class.DefineMethod("toString", (*PersistentStore).toString) + return class +} + +func newPersistentStore(class jsc.Class[*Module, *PersistentStore]) goja.Value { + boxCtx := boxctx.MustFromRuntime(class.Runtime()) + return class.New(&PersistentStore{ + class: class, + cacheFile: service.FromContext[adapter.CacheFile](boxCtx.Context), + inMemoryCache: service.FromContext[adapter.ScriptManager](boxCtx.Context).SurgeCache(), + tag: boxCtx.Tag, + }) +} + +func (s *PersistentStore) get(call goja.FunctionCall) any { + key := jsc.AssertString(s.class.Runtime(), call.Argument(0), "key", true) + if key == "" { + key = s.tag + } + var value string + if s.cacheFile != nil { + value = s.cacheFile.SurgePersistentStoreRead(key) + } else { + s.inMemoryCache.RLock() + value = s.inMemoryCache.Data[key] + s.inMemoryCache.RUnlock() + } + if value == "" { + return goja.Null() + } else { + return value + } +} + +func (s *PersistentStore) set(call goja.FunctionCall) any { + data := jsc.AssertString(s.class.Runtime(), call.Argument(0), "data", true) + key := jsc.AssertString(s.class.Runtime(), call.Argument(1), "key", true) + if key == "" { + key = s.tag + } + if s.cacheFile != nil { + err := s.cacheFile.SurgePersistentStoreWrite(key, data) + if err != nil { + panic(s.class.Runtime().NewGoError(err)) + } + } else { + s.inMemoryCache.Lock() + s.inMemoryCache.Data[key] = data + s.inMemoryCache.Unlock() + } + return goja.Undefined() +} + +func (s *PersistentStore) toString(call goja.FunctionCall) any { + return "[sing-box Surge persistentStore]" +} diff --git a/script/modules/surge/script.go b/script/modules/surge/script.go new file mode 100644 index 00000000..de106ec8 --- /dev/null +++ b/script/modules/surge/script.go @@ -0,0 +1,32 @@ +package surge + +import ( + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + F "github.com/sagernet/sing/common/format" +) + +type Script struct { + class jsc.Class[*Module, *Script] + ScriptType string +} + +func createScript(module *Module) jsc.Class[*Module, *Script] { + class := jsc.NewClass[*Module, *Script](module) + class.DefineField("name", (*Script).getName, nil) + class.DefineField("type", (*Script).getType, nil) + class.DefineField("startTime", (*Script).getStartTime, nil) + return class +} + +func (s *Script) getName() any { + return F.ToString("script:", boxctx.MustFromRuntime(s.class.Runtime()).Tag) +} + +func (s *Script) getType() any { + return s.ScriptType +} + +func (s *Script) getStartTime() any { + return boxctx.MustFromRuntime(s.class.Runtime()).StartedAt +} diff --git a/script/modules/surge/utils.go b/script/modules/surge/utils.go new file mode 100644 index 00000000..9320ab1c --- /dev/null +++ b/script/modules/surge/utils.go @@ -0,0 +1,50 @@ +package surge + +import ( + "bytes" + "compress/gzip" + "io" + + "github.com/sagernet/sing-box/script/jsc" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/dop251/goja" +) + +type Utils struct { + class jsc.Class[*Module, *Utils] +} + +func createUtils(module *Module) jsc.Class[*Module, *Utils] { + class := jsc.NewClass[*Module, *Utils](module) + class.DefineMethod("geoip", (*Utils).stub) + class.DefineMethod("ipasn", (*Utils).stub) + class.DefineMethod("ipaso", (*Utils).stub) + class.DefineMethod("ungzip", (*Utils).ungzip) + class.DefineMethod("toString", (*Utils).toString) + return class +} + +func (u *Utils) stub(call goja.FunctionCall) any { + return nil +} + +func (u *Utils) ungzip(call goja.FunctionCall) any { + if len(call.Arguments) != 1 { + panic(u.class.Runtime().NewGoError(E.New("invalid argument"))) + } + binary := jsc.AssertBinary(u.class.Runtime(), call.Argument(0), "binary", false) + reader, err := gzip.NewReader(bytes.NewReader(binary)) + if err != nil { + panic(u.class.Runtime().NewGoError(err)) + } + binary, err = io.ReadAll(reader) + if err != nil { + panic(u.class.Runtime().NewGoError(err)) + } + return jsc.NewUint8Array(u.class.Runtime(), binary) +} + +func (u *Utils) toString(call goja.FunctionCall) any { + return "[sing-box Surge utils]" +} diff --git a/script/modules/url/escape.go b/script/modules/url/escape.go new file mode 100644 index 00000000..93c8ab1b --- /dev/null +++ b/script/modules/url/escape.go @@ -0,0 +1,55 @@ +package url + +import "strings" + +var tblEscapeURLQuery = [128]byte{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, +} + +// The code below is mostly borrowed from the standard Go url package + +const upperhex = "0123456789ABCDEF" + +func escape(s string, table *[128]byte, spaceToPlus bool) string { + spaceCount, hexCount := 0, 0 + for i := 0; i < len(s); i++ { + c := s[i] + if c > 127 || table[c] == 0 { + if c == ' ' && spaceToPlus { + spaceCount++ + } else { + hexCount++ + } + } + } + + if spaceCount == 0 && hexCount == 0 { + return s + } + + var sb strings.Builder + hexBuf := [3]byte{'%', 0, 0} + + sb.Grow(len(s) + 2*hexCount) + + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case c == ' ' && spaceToPlus: + sb.WriteByte('+') + case c > 127 || table[c] == 0: + hexBuf[1] = upperhex[c>>4] + hexBuf[2] = upperhex[c&15] + sb.Write(hexBuf[:]) + default: + sb.WriteByte(c) + } + } + return sb.String() +} diff --git a/script/modules/url/module.go b/script/modules/url/module.go new file mode 100644 index 00000000..11b4b6c4 --- /dev/null +++ b/script/modules/url/module.go @@ -0,0 +1,41 @@ +package url + +import ( + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/require" + + "github.com/dop251/goja" +) + +const ModuleName = "url" + +var _ jsc.Module = (*Module)(nil) + +type Module struct { + runtime *goja.Runtime + classURL jsc.Class[*Module, *URL] + classURLSearchParams jsc.Class[*Module, *URLSearchParams] + classURLSearchParamsIterator jsc.Class[*Module, *jsc.Iterator[*Module, searchParam]] +} + +func Require(runtime *goja.Runtime, module *goja.Object) { + m := &Module{ + runtime: runtime, + } + m.classURL = createURL(m) + m.classURLSearchParams = createURLSearchParams(m) + m.classURLSearchParamsIterator = jsc.CreateIterator[*Module, searchParam](m) + exports := module.Get("exports").(*goja.Object) + exports.Set("URL", m.classURL.ToValue()) + exports.Set("URLSearchParams", m.classURLSearchParams.ToValue()) +} + +func Enable(runtime *goja.Runtime) { + exports := require.Require(runtime, ModuleName).ToObject(runtime) + runtime.Set("URL", exports.Get("URL")) + runtime.Set("URLSearchParams", exports.Get("URLSearchParams")) +} + +func (m *Module) Runtime() *goja.Runtime { + return m.runtime +} diff --git a/script/modules/url/module_test.go b/script/modules/url/module_test.go new file mode 100644 index 00000000..2b38a40d --- /dev/null +++ b/script/modules/url/module_test.go @@ -0,0 +1,37 @@ +package url_test + +import ( + _ "embed" + "testing" + + "github.com/sagernet/sing-box/script/jstest" + "github.com/sagernet/sing-box/script/modules/url" + + "github.com/dop251/goja" +) + +var ( + //go:embed testdata/url_test.js + urlTest string + + //go:embed testdata/url_search_params_test.js + urlSearchParamsTest string +) + +func TestURL(t *testing.T) { + registry := jstest.NewRegistry() + registry.RegisterNodeModule(url.ModuleName, url.Require) + vm := goja.New() + registry.Enable(vm) + url.Enable(vm) + vm.RunScript("url_test.js", urlTest) +} + +func TestURLSearchParams(t *testing.T) { + registry := jstest.NewRegistry() + registry.RegisterNodeModule(url.ModuleName, url.Require) + vm := goja.New() + registry.Enable(vm) + url.Enable(vm) + vm.RunScript("url_search_params_test.js", urlSearchParamsTest) +} diff --git a/script/modules/url/testdata/url_search_params_test.js b/script/modules/url/testdata/url_search_params_test.js new file mode 100644 index 00000000..4c4897c3 --- /dev/null +++ b/script/modules/url/testdata/url_search_params_test.js @@ -0,0 +1,385 @@ +"use strict"; + +const assert = require("assert.js"); + +let params; + +function testCtor(value, expected) { + assert.sameValue(new URLSearchParams(value).toString(), expected); +} + +testCtor("user=abc&query=xyz", "user=abc&query=xyz"); +testCtor("?user=abc&query=xyz", "user=abc&query=xyz"); + +testCtor( + { + num: 1, + user: "abc", + query: ["first", "second"], + obj: { prop: "value" }, + b: true, + }, + "num=1&user=abc&query=first%2Csecond&obj=%5Bobject+Object%5D&b=true" +); + +const map = new Map(); +map.set("user", "abc"); +map.set("query", "xyz"); +testCtor(map, "user=abc&query=xyz"); + +testCtor( + [ + ["user", "abc"], + ["query", "first"], + ["query", "second"], + ], + "user=abc&query=first&query=second" +); + +// Each key-value pair must have exactly two elements +assert.throwsNodeError(() => new URLSearchParams([["single_value"]]), TypeError, "ERR_INVALID_TUPLE"); +assert.throwsNodeError(() => new URLSearchParams([["too", "many", "values"]]), TypeError, "ERR_INVALID_TUPLE"); + +params = new URLSearchParams("a=b&cc=d"); +params.forEach((value, name, searchParams) => { + if (name === "a") { + assert.sameValue(value, "b"); + } + if (name === "cc") { + assert.sameValue(value, "d"); + } + assert.sameValue(searchParams, params); +}); + +params.forEach((value, name, searchParams) => { + if (name === "a") { + assert.sameValue(value, "b"); + searchParams.set("cc", "d1"); + } + if (name === "cc") { + assert.sameValue(value, "d1"); + } + assert.sameValue(searchParams, params); +}); + +assert.throwsNodeError(() => params.forEach(123), TypeError, "ERR_INVALID_ARG_TYPE"); + +assert.throwsNodeError(() => params.forEach.call(1, 2), TypeError, "ERR_INVALID_THIS"); + +params = new URLSearchParams("a=1=2&b=3"); +assert.sameValue(params.size, 2); +assert.sameValue(params.get("a"), "1=2"); +assert.sameValue(params.get("b"), "3"); + +params = new URLSearchParams("&"); +assert.sameValue(params.size, 0); + +params = new URLSearchParams("& "); +assert.sameValue(params.size, 1); +assert.sameValue(params.get(" "), ""); + +params = new URLSearchParams(" &"); +assert.sameValue(params.size, 1); +assert.sameValue(params.get(" "), ""); + +params = new URLSearchParams("="); +assert.sameValue(params.size, 1); +assert.sameValue(params.get(""), ""); + +params = new URLSearchParams("&=2"); +assert.sameValue(params.size, 1); +assert.sameValue(params.get(""), "2"); + +params = new URLSearchParams("?user=abc"); +assert.throwsNodeError(() => params.append(), TypeError, "ERR_MISSING_ARGS"); +params.append("query", "first"); +assert.sameValue(params.toString(), "user=abc&query=first"); + +params = new URLSearchParams("first=one&second=two&third=three"); +assert.throwsNodeError(() => params.delete(), TypeError, "ERR_MISSING_ARGS"); +params.delete("second", "fake-value"); +assert.sameValue(params.toString(), "first=one&second=two&third=three"); +params.delete("third", "three"); +assert.sameValue(params.toString(), "first=one&second=two"); +params.delete("second"); +assert.sameValue(params.toString(), "first=one"); + +params = new URLSearchParams("user=abc&query=xyz"); +assert.throwsNodeError(() => params.get(), TypeError, "ERR_MISSING_ARGS"); +assert.sameValue(params.get("user"), "abc"); +assert.sameValue(params.get("non-existant"), null); + +params = new URLSearchParams("query=first&query=second"); +assert.throwsNodeError(() => params.getAll(), TypeError, "ERR_MISSING_ARGS"); +const all = params.getAll("query"); +assert.sameValue(all.includes("first"), true); +assert.sameValue(all.includes("second"), true); +assert.sameValue(all.length, 2); +const getAllUndefined = params.getAll(undefined); +assert.sameValue(getAllUndefined.length, 0); +const getAllNonExistant = params.getAll("does_not_exists"); +assert.sameValue(getAllNonExistant.length, 0); + +params = new URLSearchParams("user=abc&query=xyz"); +assert.throwsNodeError(() => params.has(), TypeError, "ERR_MISSING_ARGS"); +assert.sameValue(params.has(undefined), false); +assert.sameValue(params.has("user"), true); +assert.sameValue(params.has("user", "abc"), true); +assert.sameValue(params.has("user", "abc", "extra-param"), true); +assert.sameValue(params.has("user", "efg"), false); +assert.sameValue(params.has("user", undefined), true); + +params = new URLSearchParams(); +params.append("foo", "bar"); +params.append("foo", "baz"); +params.append("abc", "def"); +assert.sameValue(params.toString(), "foo=bar&foo=baz&abc=def"); +params.set("foo", "def"); +params.set("xyz", "opq"); +assert.sameValue(params.toString(), "foo=def&abc=def&xyz=opq"); + +params = new URLSearchParams("query=first&query=second&user=abc&double=first,second"); +const URLSearchIteratorPrototype = params.entries().__proto__; +assert.sameValue(typeof URLSearchIteratorPrototype, "object"); + +assert.sameValue(params[Symbol.iterator], params.entries); + +{ + const entries = params.entries(); + assert.sameValue(entries.toString(), "[object URLSearchParams Iterator]"); + assert.sameValue(entries.__proto__, URLSearchIteratorPrototype); + + let item = entries.next(); + assert.sameValue(item.value.toString(), ["query", "first"].toString()); + assert.sameValue(item.done, false); + + item = entries.next(); + assert.sameValue(item.value.toString(), ["query", "second"].toString()); + assert.sameValue(item.done, false); + + item = entries.next(); + assert.sameValue(item.value.toString(), ["user", "abc"].toString()); + assert.sameValue(item.done, false); + + item = entries.next(); + assert.sameValue(item.value.toString(), ["double", "first,second"].toString()); + assert.sameValue(item.done, false); + + item = entries.next(); + assert.sameValue(item.value, undefined); + assert.sameValue(item.done, true); +} + +params = new URLSearchParams("query=first&query=second&user=abc"); +{ + const keys = params.keys(); + assert.sameValue(keys.__proto__, URLSearchIteratorPrototype); + + let item = keys.next(); + assert.sameValue(item.value, "query"); + assert.sameValue(item.done, false); + + item = keys.next(); + assert.sameValue(item.value, "query"); + assert.sameValue(item.done, false); + + item = keys.next(); + assert.sameValue(item.value, "user"); + assert.sameValue(item.done, false); + + item = keys.next(); + assert.sameValue(item.value, undefined); + assert.sameValue(item.done, true); +} + +params = new URLSearchParams("query=first&query=second&user=abc"); +{ + const values = params.values(); + assert.sameValue(values.__proto__, URLSearchIteratorPrototype); + + let item = values.next(); + assert.sameValue(item.value, "first"); + assert.sameValue(item.done, false); + + item = values.next(); + assert.sameValue(item.value, "second"); + assert.sameValue(item.done, false); + + item = values.next(); + assert.sameValue(item.value, "abc"); + assert.sameValue(item.done, false); + + item = values.next(); + assert.sameValue(item.value, undefined); + assert.sameValue(item.done, true); +} + + +params = new URLSearchParams("query[]=abc&type=search&query[]=123"); +params.sort(); +assert.sameValue(params.toString(), "query%5B%5D=abc&query%5B%5D=123&type=search"); + +params = new URLSearchParams("query=first&query=second&user=abc"); +assert.sameValue(params.size, 3); + +params = new URLSearchParams("%"); +assert.sameValue(params.has("%"), true); +assert.sameValue(params.toString(), "%25="); + +{ + const params = new URLSearchParams(""); + assert.sameValue(params.size, 0); + assert.sameValue(params.toString(), ""); + assert.sameValue(params.get(undefined), null); + params.set(undefined, true); + assert.sameValue(params.has(undefined), true); + assert.sameValue(params.has("undefined"), true); + assert.sameValue(params.get("undefined"), "true"); + assert.sameValue(params.get(undefined), "true"); + assert.sameValue(params.getAll(undefined).toString(), ["true"].toString()); + params.delete(undefined); + assert.sameValue(params.has(undefined), false); + assert.sameValue(params.has("undefined"), false); + + assert.sameValue(params.has(null), false); + params.set(null, "nullval"); + assert.sameValue(params.has(null), true); + assert.sameValue(params.has("null"), true); + assert.sameValue(params.get(null), "nullval"); + assert.sameValue(params.get("null"), "nullval"); + params.delete(null); + assert.sameValue(params.has(null), false); + assert.sameValue(params.has("null"), false); +} + +function* functionGeneratorExample() { + yield ["user", "abc"]; + yield ["query", "first"]; + yield ["query", "second"]; +} + +params = new URLSearchParams(functionGeneratorExample()); +assert.sameValue(params.toString(), "user=abc&query=first&query=second"); + +assert.sameValue(params.__proto__.constructor, URLSearchParams); +assert.sameValue(params instanceof URLSearchParams, true); + +{ + const params = new URLSearchParams("1=2&1=3"); + assert.sameValue(params.get(1), "2"); + assert.sameValue(params.getAll(1).toString(), ["2", "3"].toString()); + assert.sameValue(params.getAll("x").toString(), [].toString()); +} + +// Sync +{ + const url = new URL("https://test.com/"); + const params = url.searchParams; + assert.sameValue(params.size, 0); + url.search = "a=1"; + assert.sameValue(params.size, 1); + assert.sameValue(params.get("a"), "1"); +} + +{ + const url = new URL("https://test.com/?a=1"); + const params = url.searchParams; + assert.sameValue(params.size, 1); + url.search = ""; + assert.sameValue(params.size, 0); + url.search = "b=2"; + assert.sameValue(params.size, 1); +} + +{ + const url = new URL("https://test.com/"); + const params = url.searchParams; + params.append("a", "1"); + assert.sameValue(url.toString(), "https://test.com/?a=1"); +} + +{ + const url = new URL("https://test.com/"); + url.searchParams.append("a", "1"); + url.searchParams.append("b", "1"); + assert.sameValue(url.toString(), "https://test.com/?a=1&b=1"); +} + +{ + const url = new URL("https://test.com/"); + const params = url.searchParams; + url.searchParams.append("a", "1"); + assert.sameValue(url.search, "?a=1"); +} + +{ + const url = new URL("https://test.com/?a=1"); + const params = url.searchParams; + params.append("a", "2"); + assert.sameValue(url.search, "?a=1&a=2"); +} + +{ + const url = new URL("https://test.com/"); + const params = url.searchParams; + params.set("a", "1"); + assert.sameValue(url.search, "?a=1"); +} + +{ + const url = new URL("https://test.com/"); + url.searchParams.set("a", "1"); + url.searchParams.set("b", "1"); + assert.sameValue(url.toString(), "https://test.com/?a=1&b=1"); +} + +{ + const url = new URL("https://test.com/?a=1&b=2"); + const params = url.searchParams; + params.delete("a"); + assert.sameValue(url.search, "?b=2"); +} + +{ + const url = new URL("https://test.com/?b=2&a=1"); + const params = url.searchParams; + params.sort(); + assert.sameValue(url.search, "?a=1&b=2"); +} + +{ + const url = new URL("https://test.com/?a=1"); + const params = url.searchParams; + params.delete("a"); + assert.sameValue(url.search, ""); + + params.set("a", 2); + assert.sameValue(url.search, "?a=2"); +} + +// FAILING: no custom properties on wrapped Go structs +/* +{ + const params = new URLSearchParams(""); + assert.sameValue(Object.isExtensible(params), true); + assert.sameValue(Reflect.defineProperty(params, "customField", {value: 42, configurable: true}), true); + assert.sameValue(params.customField, 42); + const desc = Reflect.getOwnPropertyDescriptor(params, "customField"); + assert.sameValue(desc.value, 42); + assert.sameValue(desc.writable, false); + assert.sameValue(desc.enumerable, false); + assert.sameValue(desc.configurable, true); +} +*/ + +// Escape +{ + const myURL = new URL('https://example.org/abc?fo~o=~ba r%z'); + + assert.sameValue(myURL.search, "?fo~o=~ba%20r%z"); + + // Modify the URL via searchParams... + myURL.searchParams.sort(); + + assert.sameValue(myURL.search, "?fo%7Eo=%7Eba+r%25z"); +} diff --git a/script/modules/url/testdata/url_test.js b/script/modules/url/testdata/url_test.js new file mode 100644 index 00000000..a6ff43be --- /dev/null +++ b/script/modules/url/testdata/url_test.js @@ -0,0 +1,229 @@ +"use strict"; + +const assert = require("assert.js"); + +function testURLCtor(str, expected) { + assert.sameValue(new URL(str).toString(), expected); +} + +function testURLCtorBase(ref, base, expected, message) { + assert.sameValue(new URL(ref, base).toString(), expected, message); +} + +testURLCtorBase("https://example.org/", undefined, "https://example.org/"); +testURLCtorBase("/foo", "https://example.org/", "https://example.org/foo"); +testURLCtorBase("http://Example.com/", "https://example.org/", "http://example.com/"); +testURLCtorBase("https://Example.com/", "https://example.org/", "https://example.com/"); +testURLCtorBase("foo://Example.com/", "https://example.org/", "foo://Example.com/"); +testURLCtorBase("foo:Example.com/", "https://example.org/", "foo:Example.com/"); +testURLCtorBase("#hash", "https://example.org/", "https://example.org/#hash"); + +testURLCtor("HTTP://test.com", "http://test.com/"); +testURLCtor("HTTPS://á.com", "https://xn--1ca.com/"); +testURLCtor("HTTPS://á.com:123", "https://xn--1ca.com:123/"); +testURLCtor("https://test.com#asdfá", "https://test.com/#asdf%C3%A1"); +testURLCtor("HTTPS://á.com:123/á", "https://xn--1ca.com:123/%C3%A1"); +testURLCtor("fish://á.com", "fish://%C3%A1.com"); +testURLCtor("https://test.com/?a=1 /2", "https://test.com/?a=1%20/2"); +testURLCtor("https://test.com/á=1?á=1&ü=2#é", "https://test.com/%C3%A1=1?%C3%A1=1&%C3%BC=2#%C3%A9"); + +assert.throws(() => new URL("test"), TypeError); +assert.throws(() => new URL("ssh://EEE:ddd"), TypeError); + +{ + let u = new URL("https://example.org/"); + assert.sameValue(u.__proto__.constructor, URL); + assert.sameValue(u instanceof URL, true); +} + +{ + let u = new URL("https://example.org/"); + assert.sameValue(u.searchParams, u.searchParams); +} + +let myURL; + +// Hash +myURL = new URL("https://example.org/foo#bar"); +myURL.hash = "baz"; +assert.sameValue(myURL.href, "https://example.org/foo#baz"); + +myURL.hash = "#baz"; +assert.sameValue(myURL.href, "https://example.org/foo#baz"); + +myURL.hash = "#á=1 2"; +assert.sameValue(myURL.href, "https://example.org/foo#%C3%A1=1%202"); + +myURL.hash = "#a/#b"; +// FAILING: the second # gets escaped +//assert.sameValue(myURL.href, "https://example.org/foo#a/#b"); +assert.sameValue(myURL.search, ""); +// FAILING: the second # gets escaped +//assert.sameValue(myURL.hash, "#a/#b"); + +// Host +myURL = new URL("https://example.org:81/foo"); +myURL.host = "example.com:82"; +assert.sameValue(myURL.href, "https://example.com:82/foo"); + +// Hostname +myURL = new URL("https://example.org:81/foo"); +myURL.hostname = "example.com:82"; +assert.sameValue(myURL.href, "https://example.org:81/foo"); + +myURL.hostname = "á.com"; +assert.sameValue(myURL.href, "https://xn--1ca.com:81/foo"); + +// href +myURL = new URL("https://example.org/foo"); +myURL.href = "https://example.com/bar"; +assert.sameValue(myURL.href, "https://example.com/bar"); + +// Password +myURL = new URL("https://abc:xyz@example.com"); +myURL.password = "123"; +assert.sameValue(myURL.href, "https://abc:123@example.com/"); + +// pathname +myURL = new URL("https://example.org/abc/xyz?123"); +myURL.pathname = "/abcdef"; +assert.sameValue(myURL.href, "https://example.org/abcdef?123"); + +myURL.pathname = ""; +assert.sameValue(myURL.href, "https://example.org/?123"); + +myURL.pathname = "á"; +assert.sameValue(myURL.pathname, "/%C3%A1"); +assert.sameValue(myURL.href, "https://example.org/%C3%A1?123"); + +// port + +myURL = new URL("https://example.org:8888"); +assert.sameValue(myURL.port, "8888"); + +function testSetPort(port, expected) { + const url = new URL("https://example.org:8888"); + url.port = port; + assert.sameValue(url.port, expected); +} + +testSetPort(0, "0"); +testSetPort(-0, "0"); + +// Default ports are automatically transformed to the empty string +// (HTTPS protocol's default port is 443) +testSetPort("443", ""); +testSetPort(443, ""); + +// Empty string is the same as default port +testSetPort("", ""); + +// Completely invalid port strings are ignored +testSetPort("abcd", "8888"); +testSetPort("-123", ""); +testSetPort(-123, ""); +testSetPort(-123.45, ""); +testSetPort(undefined, "8888"); +testSetPort(null, "8888"); +testSetPort(+Infinity, "8888"); +testSetPort(-Infinity, "8888"); +testSetPort(NaN, "8888"); + +// Leading numbers are treated as a port number +testSetPort("5678abcd", "5678"); +testSetPort("a5678abcd", ""); + +// Non-integers are truncated +testSetPort(1234.5678, "1234"); + +// Out-of-range numbers which are not represented in scientific notation +// will be ignored. +testSetPort(1e10, "8888"); +testSetPort("123456", "8888"); +testSetPort(123456, "8888"); +testSetPort(4.567e21, "4"); + +// toString() takes precedence over valueOf(), even if it returns a valid integer +testSetPort( + { + toString() { + return "2"; + }, + valueOf() { + return 1; + }, + }, + "2" +); + +// Protocol +function testSetProtocol(url, protocol, expected) { + url.protocol = protocol; + assert.sameValue(url.protocol, expected); +} +testSetProtocol(new URL("https://example.org"), "ftp", "ftp:"); +testSetProtocol(new URL("https://example.org"), "ftp:", "ftp:"); +testSetProtocol(new URL("https://example.org"), "FTP:", "ftp:"); +testSetProtocol(new URL("https://example.org"), "ftp: blah", "ftp:"); +// special to non-special +testSetProtocol(new URL("https://example.org"), "foo", "https:"); +// non-special to special +testSetProtocol(new URL("fish://example.org"), "https", "fish:"); + +// Search +myURL = new URL("https://example.org/abc?123"); +myURL.search = "abc=xyz"; +assert.sameValue(myURL.href, "https://example.org/abc?abc=xyz"); + +myURL.search = "a=1 2"; +assert.sameValue(myURL.href, "https://example.org/abc?a=1%202"); + +myURL.search = "á=ú"; +assert.sameValue(myURL.search, "?%C3%A1=%C3%BA"); +assert.sameValue(myURL.href, "https://example.org/abc?%C3%A1=%C3%BA"); + +myURL.hash = "hash"; +myURL.search = "a=#b"; +assert.sameValue(myURL.href, "https://example.org/abc?a=%23b#hash"); +assert.sameValue(myURL.search, "?a=%23b"); +assert.sameValue(myURL.hash, "#hash"); + +// Username +myURL = new URL("https://abc:xyz@example.com/"); +myURL.username = "123"; +assert.sameValue(myURL.href, "https://123:xyz@example.com/"); + +// Origin, read-only +assert.throws(() => { + myURL.origin = "abc"; +}, TypeError); + +// href +myURL = new URL("https://example.org"); +myURL.href = "https://example.com"; +assert.sameValue(myURL.href, "https://example.com/"); + +assert.throws(() => { + myURL.href = "test"; +}, TypeError); + +// Search Params +myURL = new URL("https://example.com/"); +myURL.searchParams.append("user", "abc"); +assert.sameValue(myURL.toString(), "https://example.com/?user=abc"); +myURL.searchParams.append("first", "one"); +assert.sameValue(myURL.toString(), "https://example.com/?user=abc&first=one"); +myURL.searchParams.delete("user"); +assert.sameValue(myURL.toString(), "https://example.com/?first=one"); + +{ + const url = require("url"); + + assert.sameValue(url.domainToASCII('español.com'), "xn--espaol-zwa.com"); + assert.sameValue(url.domainToASCII('中文.com'), "xn--fiq228c.com"); + assert.sameValue(url.domainToASCII('xn--iñvalid.com'), ""); + + assert.sameValue(url.domainToUnicode('xn--espaol-zwa.com'), "español.com"); + assert.sameValue(url.domainToUnicode('xn--fiq228c.com'), "中文.com"); + assert.sameValue(url.domainToUnicode('xn--iñvalid.com'), ""); +} diff --git a/script/modules/url/url.go b/script/modules/url/url.go new file mode 100644 index 00000000..7b442ded --- /dev/null +++ b/script/modules/url/url.go @@ -0,0 +1,315 @@ +package url + +import ( + "net" + "net/url" + "strings" + + "github.com/sagernet/sing-box/script/jsc" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/dop251/goja" + "golang.org/x/net/idna" +) + +type URL struct { + class jsc.Class[*Module, *URL] + url *url.URL + params *URLSearchParams + paramsValue goja.Value +} + +func newURL(c jsc.Class[*Module, *URL], call goja.ConstructorCall) *URL { + var ( + u, base *url.URL + err error + ) + switch argURL := call.Argument(0).Export().(type) { + case *URL: + u = argURL.url + default: + u, err = parseURL(call.Argument(0).String()) + if err != nil { + panic(c.Runtime().NewGoError(E.Cause(err, "parse URL"))) + } + } + if len(call.Arguments) == 2 { + switch argBaseURL := call.Argument(1).Export().(type) { + case *URL: + base = argBaseURL.url + default: + base, err = parseURL(call.Argument(1).String()) + if err != nil { + panic(c.Runtime().NewGoError(E.Cause(err, "parse base URL"))) + } + } + } + if base != nil { + u = base.ResolveReference(u) + } + return &URL{class: c, url: u} +} + +func createURL(module *Module) jsc.Class[*Module, *URL] { + class := jsc.NewClass[*Module, *URL](module) + class.DefineConstructor(newURL) + class.DefineField("hash", (*URL).getHash, (*URL).setHash) + class.DefineField("host", (*URL).getHost, (*URL).setHost) + class.DefineField("hostname", (*URL).getHostName, (*URL).setHostName) + class.DefineField("href", (*URL).getHref, (*URL).setHref) + class.DefineField("origin", (*URL).getOrigin, nil) + class.DefineField("password", (*URL).getPassword, (*URL).setPassword) + class.DefineField("pathname", (*URL).getPathname, (*URL).setPathname) + class.DefineField("port", (*URL).getPort, (*URL).setPort) + class.DefineField("protocol", (*URL).getProtocol, (*URL).setProtocol) + class.DefineField("search", (*URL).getSearch, (*URL).setSearch) + class.DefineField("searchParams", (*URL).getSearchParams, (*URL).setSearchParams) + class.DefineField("username", (*URL).getUsername, (*URL).setUsername) + class.DefineMethod("toString", (*URL).toString) + class.DefineMethod("toJSON", (*URL).toJSON) + class.DefineStaticMethod("canParse", canParse) + // class.DefineStaticMethod("createObjectURL", createObjectURL) + class.DefineStaticMethod("parse", parse) + // class.DefineStaticMethod("revokeObjectURL", revokeObjectURL) + return class +} + +func canParse(class jsc.Class[*Module, *URL], call goja.FunctionCall) any { + switch call.Argument(0).Export().(type) { + case *URL: + default: + _, err := parseURL(call.Argument(0).String()) + if err != nil { + return false + } + } + if len(call.Arguments) == 2 { + switch call.Argument(1).Export().(type) { + case *URL: + default: + _, err := parseURL(call.Argument(1).String()) + if err != nil { + return false + } + } + } + return true +} + +func parse(class jsc.Class[*Module, *URL], call goja.FunctionCall) any { + var ( + u, base *url.URL + err error + ) + switch argURL := call.Argument(0).Export().(type) { + case *URL: + u = argURL.url + default: + u, err = parseURL(call.Argument(0).String()) + if err != nil { + return goja.Null() + } + } + if len(call.Arguments) == 2 { + switch argBaseURL := call.Argument(1).Export().(type) { + case *URL: + base = argBaseURL.url + default: + base, err = parseURL(call.Argument(1).String()) + if err != nil { + return goja.Null() + } + } + } + if base != nil { + u = base.ResolveReference(u) + } + return &URL{class: class, url: u} +} + +func (r *URL) getHash() any { + if r.url.Fragment != "" { + return "#" + r.url.EscapedFragment() + } + return "" +} + +func (r *URL) setHash(value goja.Value) { + r.url.RawFragment = strings.TrimPrefix(value.String(), "#") +} + +func (r *URL) getHost() any { + return r.url.Host +} + +func (r *URL) setHost(value goja.Value) { + r.url.Host = strings.TrimSuffix(value.String(), ":") +} + +func (r *URL) getHostName() any { + return r.url.Hostname() +} + +func (r *URL) setHostName(value goja.Value) { + r.url.Host = joinHostPort(value.String(), r.url.Port()) +} + +func (r *URL) getHref() any { + return r.url.String() +} + +func (r *URL) setHref(value goja.Value) { + newURL, err := url.Parse(value.String()) + if err != nil { + panic(r.class.Runtime().NewGoError(err)) + } + r.url = newURL + r.params = nil +} + +func (r *URL) getOrigin() any { + return r.url.Scheme + "://" + r.url.Host +} + +func (r *URL) getPassword() any { + if r.url.User != nil { + password, _ := r.url.User.Password() + return password + } + return "" +} + +func (r *URL) setPassword(value goja.Value) { + if r.url.User == nil { + r.url.User = url.UserPassword("", value.String()) + } else { + r.url.User = url.UserPassword(r.url.User.Username(), value.String()) + } +} + +func (r *URL) getPathname() any { + return r.url.EscapedPath() +} + +func (r *URL) setPathname(value goja.Value) { + r.url.RawPath = value.String() +} + +func (r *URL) getPort() any { + return r.url.Port() +} + +func (r *URL) setPort(value goja.Value) { + r.url.Host = joinHostPort(r.url.Hostname(), value.String()) +} + +func (r *URL) getProtocol() any { + return r.url.Scheme + ":" +} + +func (r *URL) setProtocol(value goja.Value) { + r.url.Scheme = strings.TrimSuffix(value.String(), ":") +} + +func (r *URL) getSearch() any { + if r.params != nil { + if len(r.params.params) > 0 { + return "?" + generateQuery(r.params.params) + } + } else if r.url.RawQuery != "" { + return "?" + r.url.RawQuery + } + return "" +} + +func (r *URL) setSearch(value goja.Value) { + params, err := parseQuery(value.String()) + if err == nil { + if r.params != nil { + r.params.params = params + } else { + r.url.RawQuery = generateQuery(params) + } + } +} + +func (r *URL) getSearchParams() any { + var params []searchParam + if r.url.RawQuery != "" { + params, _ = parseQuery(r.url.RawQuery) + } + if r.params == nil { + r.params = &URLSearchParams{ + class: r.class.Module().classURLSearchParams, + params: params, + } + r.paramsValue = r.class.Module().classURLSearchParams.New(r.params) + } + return r.paramsValue +} + +func (r *URL) setSearchParams(value goja.Value) { + if params, ok := value.Export().(*URLSearchParams); ok { + r.params = params + r.paramsValue = value + } +} + +func (r *URL) getUsername() any { + if r.url.User != nil { + return r.url.User.Username() + } + return "" +} + +func (r *URL) setUsername(value goja.Value) { + if r.url.User == nil { + r.url.User = url.User(value.String()) + } else { + password, _ := r.url.User.Password() + r.url.User = url.UserPassword(value.String(), password) + } +} + +func (r *URL) toString(call goja.FunctionCall) any { + if r.params != nil { + r.url.RawQuery = generateQuery(r.params.params) + } + return r.url.String() +} + +func (r *URL) toJSON(call goja.FunctionCall) any { + return r.toString(call) +} + +func parseURL(s string) (*url.URL, error) { + u, err := url.Parse(s) + if err != nil { + return nil, E.Cause(err, "invalid URL") + } + switch u.Scheme { + case "https", "http", "ftp", "wss", "ws": + if u.Path == "" { + u.Path = "/" + } + hostname := u.Hostname() + asciiHostname, err := idna.Punycode.ToASCII(strings.ToLower(hostname)) + if err != nil { + return nil, E.Cause(err, "invalid hostname") + } + if asciiHostname != hostname { + u.Host = joinHostPort(asciiHostname, u.Port()) + } + } + if u.RawQuery != "" { + u.RawQuery = escape(u.RawQuery, &tblEscapeURLQuery, false) + } + return u, nil +} + +func joinHostPort(hostname, port string) string { + if port == "" { + return hostname + } + return net.JoinHostPort(hostname, port) +} diff --git a/script/modules/url/url_search_params.go b/script/modules/url/url_search_params.go new file mode 100644 index 00000000..945f076f --- /dev/null +++ b/script/modules/url/url_search_params.go @@ -0,0 +1,244 @@ +package url + +import ( + "fmt" + "net/url" + "sort" + "strings" + + "github.com/sagernet/sing-box/script/jsc" + F "github.com/sagernet/sing/common/format" + + "github.com/dop251/goja" +) + +type URLSearchParams struct { + class jsc.Class[*Module, *URLSearchParams] + params []searchParam +} + +func createURLSearchParams(module *Module) jsc.Class[*Module, *URLSearchParams] { + class := jsc.NewClass[*Module, *URLSearchParams](module) + class.DefineConstructor(newURLSearchParams) + class.DefineField("size", (*URLSearchParams).getSize, nil) + class.DefineMethod("append", (*URLSearchParams).append) + class.DefineMethod("delete", (*URLSearchParams).delete) + class.DefineMethod("entries", (*URLSearchParams).entries) + class.DefineMethod("forEach", (*URLSearchParams).forEach) + class.DefineMethod("get", (*URLSearchParams).get) + class.DefineMethod("getAll", (*URLSearchParams).getAll) + class.DefineMethod("has", (*URLSearchParams).has) + class.DefineMethod("keys", (*URLSearchParams).keys) + class.DefineMethod("set", (*URLSearchParams).set) + class.DefineMethod("sort", (*URLSearchParams).sort) + class.DefineMethod("toString", (*URLSearchParams).toString) + class.DefineMethod("values", (*URLSearchParams).values) + return class +} + +func newURLSearchParams(class jsc.Class[*Module, *URLSearchParams], call goja.ConstructorCall) *URLSearchParams { + var ( + params []searchParam + err error + ) + switch argInit := call.Argument(0).Export().(type) { + case *URLSearchParams: + params = argInit.params + case string: + params, err = parseQuery(argInit) + if err != nil { + panic(class.Runtime().NewGoError(err)) + } + case [][]string: + for _, pair := range argInit { + if len(pair) != 2 { + panic(class.Runtime().NewTypeError("Each query pair must be an iterable [name, value] tuple")) + } + params = append(params, searchParam{pair[0], pair[1]}) + } + case map[string]any: + for name, value := range argInit { + stringValue, isString := value.(string) + if !isString { + panic(class.Runtime().NewTypeError("Invalid query value")) + } + params = append(params, searchParam{name, stringValue}) + } + } + return &URLSearchParams{class, params} +} + +func (s *URLSearchParams) getSize() any { + return len(s.params) +} + +func (s *URLSearchParams) append(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + value := call.Argument(1).String() + s.params = append(s.params, searchParam{name, value}) + return goja.Undefined() +} + +func (s *URLSearchParams) delete(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + argValue := call.Argument(1) + if !jsc.IsNil(argValue) { + value := argValue.String() + for i, param := range s.params { + if param.Key == name && param.Value == value { + s.params = append(s.params[:i], s.params[i+1:]...) + break + } + } + } else { + for i, param := range s.params { + if param.Key == name { + s.params = append(s.params[:i], s.params[i+1:]...) + break + } + } + } + return goja.Undefined() +} + +func (s *URLSearchParams) entries(call goja.FunctionCall) any { + return jsc.NewIterator[*Module, searchParam](s.class.Module().classURLSearchParamsIterator, s.params, func(this searchParam) any { + return s.class.Runtime().NewArray(this.Key, this.Value) + }) +} + +func (s *URLSearchParams) forEach(call goja.FunctionCall) any { + callback := jsc.AssertFunction(s.class.Runtime(), call.Argument(0), "callbackFn") + thisValue := call.Argument(1) + for _, param := range s.params { + for _, value := range param.Value { + _, err := callback(thisValue, s.class.Runtime().ToValue(value), s.class.Runtime().ToValue(param.Key), call.This) + if err != nil { + panic(s.class.Runtime().NewGoError(err)) + } + } + } + return goja.Undefined() +} + +func (s *URLSearchParams) get(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + for _, param := range s.params { + if param.Key == name { + return param.Value + } + } + return goja.Null() +} + +func (s *URLSearchParams) getAll(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + var values []any + for _, param := range s.params { + if param.Key == name { + values = append(values, param.Value) + } + } + return s.class.Runtime().NewArray(values...) +} + +func (s *URLSearchParams) has(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + argValue := call.Argument(1) + if !jsc.IsNil(argValue) { + value := argValue.String() + for _, param := range s.params { + if param.Key == name && param.Value == value { + return true + } + } + } else { + for _, param := range s.params { + if param.Key == name { + return true + } + } + } + return false +} + +func (s *URLSearchParams) keys(call goja.FunctionCall) any { + return jsc.NewIterator[*Module, searchParam](s.class.Module().classURLSearchParamsIterator, s.params, func(this searchParam) any { + return this.Key + }) +} + +func (s *URLSearchParams) set(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + value := call.Argument(1).String() + for i, param := range s.params { + if param.Key == name { + s.params[i].Value = value + return goja.Undefined() + } + } + s.params = append(s.params, searchParam{name, value}) + return goja.Undefined() +} + +func (s *URLSearchParams) sort(call goja.FunctionCall) any { + sort.SliceStable(s.params, func(i, j int) bool { + return s.params[i].Key < s.params[j].Key + }) + return goja.Undefined() +} + +func (s *URLSearchParams) toString(call goja.FunctionCall) any { + return generateQuery(s.params) +} + +func (s *URLSearchParams) values(call goja.FunctionCall) any { + return jsc.NewIterator[*Module, searchParam](s.class.Module().classURLSearchParamsIterator, s.params, func(this searchParam) any { + return this.Value + }) +} + +type searchParam struct { + Key string + Value string +} + +func parseQuery(query string) (params []searchParam, err error) { + query = strings.TrimPrefix(query, "?") + for query != "" { + var key string + key, query, _ = strings.Cut(query, "&") + if strings.Contains(key, ";") { + err = fmt.Errorf("invalid semicolon separator in query") + continue + } + if key == "" { + continue + } + key, value, _ := strings.Cut(key, "=") + key, err1 := url.QueryUnescape(key) + if err1 != nil { + if err == nil { + err = err1 + } + continue + } + value, err1 = url.QueryUnescape(value) + if err1 != nil { + if err == nil { + err = err1 + } + continue + } + params = append(params, searchParam{key, value}) + } + return +} + +func generateQuery(params []searchParam) string { + var parts []string + for _, param := range params { + parts = append(parts, F.ToString(param.Key, "=", url.QueryEscape(param.Value))) + } + return strings.Join(parts, "&") +} diff --git a/script/runtime.go b/script/runtime.go new file mode 100644 index 00000000..a5961dd0 --- /dev/null +++ b/script/runtime.go @@ -0,0 +1,49 @@ +//go:build with_script + +package script + +import ( + "context" + + "github.com/sagernet/sing-box/script/modules/boxctx" + "github.com/sagernet/sing-box/script/modules/console" + "github.com/sagernet/sing-box/script/modules/eventloop" + "github.com/sagernet/sing-box/script/modules/require" + "github.com/sagernet/sing-box/script/modules/surge" + "github.com/sagernet/sing-box/script/modules/url" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/common/ntp" + + "github.com/dop251/goja" + "github.com/dop251/goja/parser" +) + +func NewRuntime(ctx context.Context, cancel context.CancelCauseFunc) *goja.Runtime { + vm := goja.New() + if timeFunc := ntp.TimeFuncFromContext(ctx); timeFunc != nil { + vm.SetTimeSource(timeFunc) + } + vm.SetParserOptions(parser.WithDisableSourceMaps) + registry := require.NewRegistry(require.WithLoader(func(path string) ([]byte, error) { + return nil, E.New("unsupported usage") + })) + registry.Enable(vm) + registry.RegisterNodeModule(console.ModuleName, console.Require) + registry.RegisterNodeModule(url.ModuleName, url.Require) + registry.RegisterNativeModule(boxctx.ModuleName, boxctx.Require) + registry.RegisterNativeModule(surge.ModuleName, surge.Require) + console.Enable(vm) + url.Enable(vm) + eventloop.Enable(vm, cancel) + return vm +} + +func SetModules(runtime *goja.Runtime, ctx context.Context, logger logger.ContextLogger, errorHandler func(error), tag string) { + boxctx.Enable(runtime, &boxctx.Context{ + Context: ctx, + Logger: logger, + Tag: tag, + ErrorHandler: errorHandler, + }) +} diff --git a/script/script.go b/script/script.go new file mode 100644 index 00000000..72112004 --- /dev/null +++ b/script/script.go @@ -0,0 +1,22 @@ +//go:build with_script + +package script + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" +) + +func NewScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (adapter.Script, error) { + switch options.Type { + case C.ScriptTypeSurge: + return NewSurgeScript(ctx, logger, options) + default: + return nil, E.New("unknown script type: ", options.Type) + } +} diff --git a/script/script_surge.go b/script/script_surge.go new file mode 100644 index 00000000..0bd604e4 --- /dev/null +++ b/script/script_surge.go @@ -0,0 +1,347 @@ +//go:build with_script + +package script + +import ( + "context" + "net/http" + "sync" + "time" + "unsafe" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/surge" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + + "github.com/adhocore/gronx" + "github.com/dop251/goja" +) + +const defaultSurgeScriptTimeout = 10 * time.Second + +var _ adapter.SurgeScript = (*SurgeScript)(nil) + +type SurgeScript struct { + ctx context.Context + logger logger.ContextLogger + tag string + source Source + + cronExpression string + cronTimeout time.Duration + cronArguments []string + cronTimer *time.Timer + cronDone chan struct{} +} + +func NewSurgeScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (adapter.Script, error) { + source, err := NewSource(ctx, logger, options) + if err != nil { + return nil, err + } + cronOptions := common.PtrValueOrDefault(options.SurgeOptions.CronOptions) + if cronOptions.Expression != "" { + if !gronx.IsValid(cronOptions.Expression) { + return nil, E.New("invalid cron expression: ", cronOptions.Expression) + } + } + return &SurgeScript{ + ctx: ctx, + logger: logger, + tag: options.Tag, + source: source, + cronExpression: cronOptions.Expression, + cronTimeout: time.Duration(cronOptions.Timeout), + cronArguments: cronOptions.Arguments, + cronDone: make(chan struct{}), + }, nil +} + +func (s *SurgeScript) Type() string { + return C.ScriptTypeSurge +} + +func (s *SurgeScript) Tag() string { + return s.tag +} + +func (s *SurgeScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { + return s.source.StartContext(ctx, startContext) +} + +func (s *SurgeScript) PostStart() error { + err := s.source.PostStart() + if err != nil { + return err + } + if s.cronExpression != "" { + go s.loopCronEvents() + } + return nil +} + +func (s *SurgeScript) loopCronEvents() { + s.logger.Debug("starting event") + err := s.ExecuteGeneric(s.ctx, "cron", s.cronTimeout, s.cronArguments) + if err != nil { + s.logger.Error(E.Cause(err, "running event")) + } + nextTick, err := gronx.NextTick(s.cronExpression, false) + if err != nil { + s.logger.Error(E.Cause(err, "determine next tick")) + return + } + s.cronTimer = time.NewTimer(nextTick.Sub(time.Now())) + s.logger.Debug("next event at: ", nextTick.Format(log.DefaultTimeFormat)) + for { + select { + case <-s.ctx.Done(): + return + case <-s.cronDone: + return + case <-s.cronTimer.C: + s.logger.Debug("starting event") + err = s.ExecuteGeneric(s.ctx, "cron", s.cronTimeout, s.cronArguments) + if err != nil { + s.logger.Error(E.Cause(err, "running event")) + } + nextTick, err = gronx.NextTick(s.cronExpression, false) + if err != nil { + s.logger.Error(E.Cause(err, "determine next tick")) + return + } + s.cronTimer.Reset(nextTick.Sub(time.Now())) + s.logger.Debug("configured next event at: ", nextTick) + } + } +} + +func (s *SurgeScript) Close() error { + err := s.source.Close() + if s.cronTimer != nil { + s.cronTimer.Stop() + close(s.cronDone) + } + return err +} + +func (s *SurgeScript) ExecuteGeneric(ctx context.Context, scriptType string, timeout time.Duration, arguments []string) error { + program := s.source.Program() + if program == nil { + return E.New("invalid script") + } + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + runtime := NewRuntime(ctx, cancel) + SetModules(runtime, ctx, s.logger, cancel, s.tag) + surge.Enable(runtime, scriptType, arguments) + if timeout == 0 { + timeout = defaultSurgeScriptTimeout + } + ctx, timeoutCancel := context.WithTimeout(ctx, timeout) + defer timeoutCancel() + done := make(chan struct{}) + doneFunc := common.OnceFunc(func() { + close(done) + }) + runtime.Set("done", func(call goja.FunctionCall) goja.Value { + doneFunc() + return goja.Undefined() + }) + var ( + access sync.Mutex + scriptErr error + ) + go func() { + _, err := runtime.RunProgram(program) + if err != nil { + access.Lock() + scriptErr = err + access.Unlock() + doneFunc() + } + }() + select { + case <-ctx.Done(): + runtime.Interrupt(ctx.Err()) + return ctx.Err() + case <-done: + access.Lock() + defer access.Unlock() + if scriptErr != nil { + runtime.Interrupt(scriptErr) + } else { + runtime.Interrupt("script done") + } + } + return scriptErr +} + +func (s *SurgeScript) ExecuteHTTPRequest(ctx context.Context, timeout time.Duration, request *http.Request, body []byte, binaryBody bool, arguments []string) (*adapter.HTTPRequestScriptResult, error) { + program := s.source.Program() + if program == nil { + return nil, E.New("invalid script") + } + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + runtime := NewRuntime(ctx, cancel) + SetModules(runtime, ctx, s.logger, cancel, s.tag) + surge.Enable(runtime, "http-request", arguments) + if timeout == 0 { + timeout = defaultSurgeScriptTimeout + } + ctx, timeoutCancel := context.WithTimeout(ctx, timeout) + defer timeoutCancel() + runtime.ClearInterrupt() + requestObject := runtime.NewObject() + requestObject.Set("url", request.URL.String()) + requestObject.Set("method", request.Method) + requestObject.Set("headers", jsc.HeadersToValue(runtime, request.Header)) + if !binaryBody { + requestObject.Set("body", string(body)) + } else { + requestObject.Set("body", jsc.NewUint8Array(runtime, body)) + } + requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request)))) + runtime.Set("request", requestObject) + done := make(chan struct{}) + doneFunc := common.OnceFunc(func() { + close(done) + }) + var ( + access sync.Mutex + result adapter.HTTPRequestScriptResult + scriptErr error + ) + runtime.Set("done", func(call goja.FunctionCall) goja.Value { + defer doneFunc() + resultObject := jsc.AssertObject(runtime, call.Argument(0), "done() argument", true) + if resultObject == nil { + panic(runtime.NewGoError(E.New("request rejected by script"))) + } + access.Lock() + defer access.Unlock() + result.URL = jsc.AssertString(runtime, resultObject.Get("url"), "url", true) + result.Headers = jsc.AssertHTTPHeader(runtime, resultObject.Get("headers"), "headers") + result.Body = jsc.AssertStringBinary(runtime, resultObject.Get("body"), "body", true) + responseObject := jsc.AssertObject(runtime, resultObject.Get("response"), "response", true) + if responseObject != nil { + result.Response = &adapter.HTTPRequestScriptResponse{ + Status: int(jsc.AssertInt(runtime, responseObject.Get("status"), "status", true)), + Headers: jsc.AssertHTTPHeader(runtime, responseObject.Get("headers"), "headers"), + Body: jsc.AssertStringBinary(runtime, responseObject.Get("body"), "body", true), + } + } + return goja.Undefined() + }) + go func() { + _, err := runtime.RunProgram(program) + if err != nil { + access.Lock() + scriptErr = err + access.Unlock() + doneFunc() + } + }() + select { + case <-ctx.Done(): + runtime.Interrupt(ctx.Err()) + return nil, ctx.Err() + case <-done: + access.Lock() + defer access.Unlock() + if scriptErr != nil { + runtime.Interrupt(scriptErr) + } else { + runtime.Interrupt("script done") + } + } + return &result, scriptErr +} + +func (s *SurgeScript) ExecuteHTTPResponse(ctx context.Context, timeout time.Duration, request *http.Request, response *http.Response, body []byte, binaryBody bool, arguments []string) (*adapter.HTTPResponseScriptResult, error) { + program := s.source.Program() + if program == nil { + return nil, E.New("invalid script") + } + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + runtime := NewRuntime(ctx, cancel) + SetModules(runtime, ctx, s.logger, cancel, s.tag) + surge.Enable(runtime, "http-response", arguments) + if timeout == 0 { + timeout = defaultSurgeScriptTimeout + } + ctx, timeoutCancel := context.WithTimeout(ctx, timeout) + defer timeoutCancel() + runtime.ClearInterrupt() + requestObject := runtime.NewObject() + requestObject.Set("url", request.URL.String()) + requestObject.Set("method", request.Method) + requestObject.Set("headers", jsc.HeadersToValue(runtime, request.Header)) + requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request)))) + runtime.Set("request", requestObject) + + responseObject := runtime.NewObject() + responseObject.Set("status", response.StatusCode) + responseObject.Set("headers", jsc.HeadersToValue(runtime, response.Header)) + if !binaryBody { + responseObject.Set("body", string(body)) + } else { + responseObject.Set("body", jsc.NewUint8Array(runtime, body)) + } + runtime.Set("response", responseObject) + + done := make(chan struct{}) + doneFunc := common.OnceFunc(func() { + close(done) + }) + var ( + access sync.Mutex + result adapter.HTTPResponseScriptResult + scriptErr error + ) + runtime.Set("done", func(call goja.FunctionCall) goja.Value { + resultObject := jsc.AssertObject(runtime, call.Argument(0), "done() argument", true) + if resultObject == nil { + panic(runtime.NewGoError(E.New("response rejected by script"))) + } + access.Lock() + defer access.Unlock() + result.Status = int(jsc.AssertInt(runtime, resultObject.Get("status"), "status", true)) + result.Headers = jsc.AssertHTTPHeader(runtime, resultObject.Get("headers"), "headers") + result.Body = jsc.AssertStringBinary(runtime, resultObject.Get("body"), "body", true) + doneFunc() + return goja.Undefined() + }) + go func() { + _, err := runtime.RunProgram(program) + if err != nil { + access.Lock() + scriptErr = err + access.Unlock() + doneFunc() + } + }() + select { + case <-ctx.Done(): + runtime.Interrupt(ctx.Err()) + return nil, ctx.Err() + case <-done: + access.Lock() + defer access.Unlock() + if scriptErr != nil { + runtime.Interrupt(scriptErr) + } else { + runtime.Interrupt("script done") + } + return &result, scriptErr + } +} diff --git a/script/source.go b/script/source.go new file mode 100644 index 00000000..6f10b734 --- /dev/null +++ b/script/source.go @@ -0,0 +1,33 @@ +//go:build with_script + +package script + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + + "github.com/dop251/goja" +) + +type Source interface { + StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error + PostStart() error + Program() *goja.Program + Close() error +} + +func NewSource(ctx context.Context, logger logger.Logger, options option.Script) (Source, error) { + switch options.Source { + case C.ScriptSourceTypeLocal: + return NewLocalSource(ctx, logger, options) + case C.ScriptSourceTypeRemote: + return NewRemoteSource(ctx, logger, options) + default: + return nil, E.New("unknown source type: ", options.Source) + } +} diff --git a/script/source_local.go b/script/source_local.go new file mode 100644 index 00000000..649a22b6 --- /dev/null +++ b/script/source_local.go @@ -0,0 +1,94 @@ +//go:build with_script + +package script + +import ( + "context" + "os" + "path/filepath" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/service/filemanager" + + "github.com/dop251/goja" +) + +var _ Source = (*LocalSource)(nil) + +type LocalSource struct { + ctx context.Context + logger logger.Logger + tag string + program *goja.Program + watcher *fswatch.Watcher +} + +func NewLocalSource(ctx context.Context, logger logger.Logger, options option.Script) (*LocalSource, error) { + script := &LocalSource{ + ctx: ctx, + logger: logger, + tag: options.Tag, + } + filePath := filemanager.BasePath(ctx, options.LocalOptions.Path) + filePath, _ = filepath.Abs(options.LocalOptions.Path) + err := script.reloadFile(filePath) + if err != nil { + return nil, err + } + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: []string{filePath}, + Callback: func(path string) { + uErr := script.reloadFile(path) + if uErr != nil { + logger.Error(E.Cause(uErr, "reload script ", path)) + } + }, + }) + if err != nil { + return nil, err + } + script.watcher = watcher + return script, nil +} + +func (s *LocalSource) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { + if s.watcher != nil { + err := s.watcher.Start() + if err != nil { + s.logger.Error(E.Cause(err, "watch script file")) + } + } + return nil +} + +func (s *LocalSource) reloadFile(path string) error { + content, err := os.ReadFile(path) + if err != nil { + return err + } + program, err := goja.Compile("script:"+s.tag, string(content), false) + if err != nil { + return E.Cause(err, "compile ", path) + } + if s.program != nil { + s.logger.Info("reloaded from ", path) + } + s.program = program + return nil +} + +func (s *LocalSource) PostStart() error { + return nil +} + +func (s *LocalSource) Program() *goja.Program { + return s.program +} + +func (s *LocalSource) Close() error { + return s.watcher.Close() +} diff --git a/script/source_remote.go b/script/source_remote.go new file mode 100644 index 00000000..2075bf92 --- /dev/null +++ b/script/source_remote.go @@ -0,0 +1,226 @@ +//go:build with_script + +package script + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "runtime" + "time" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/service" + "github.com/sagernet/sing/service/pause" + + "github.com/dop251/goja" +) + +var _ Source = (*RemoteSource)(nil) + +type RemoteSource struct { + ctx context.Context + cancel context.CancelFunc + logger logger.Logger + outbound adapter.OutboundManager + options option.Script + updateInterval time.Duration + dialer N.Dialer + program *goja.Program + lastUpdated time.Time + lastEtag string + updateTicker *time.Ticker + cacheFile adapter.CacheFile + pauseManager pause.Manager +} + +func NewRemoteSource(ctx context.Context, logger logger.Logger, options option.Script) (*RemoteSource, error) { + ctx, cancel := context.WithCancel(ctx) + var updateInterval time.Duration + if options.RemoteOptions.UpdateInterval > 0 { + updateInterval = time.Duration(options.RemoteOptions.UpdateInterval) + } else { + updateInterval = 24 * time.Hour + } + return &RemoteSource{ + ctx: ctx, + cancel: cancel, + logger: logger, + outbound: service.FromContext[adapter.OutboundManager](ctx), + options: options, + updateInterval: updateInterval, + pauseManager: service.FromContext[pause.Manager](ctx), + }, nil +} + +func (s *RemoteSource) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { + s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx) + var dialer N.Dialer + if s.options.RemoteOptions.DownloadDetour != "" { + outbound, loaded := s.outbound.Outbound(s.options.RemoteOptions.DownloadDetour) + if !loaded { + return E.New("download detour not found: ", s.options.RemoteOptions.DownloadDetour) + } + dialer = outbound + } else { + dialer = s.outbound.Default() + } + s.dialer = dialer + if s.cacheFile != nil { + if savedSet := s.cacheFile.LoadScript(s.options.Tag); savedSet != nil { + err := s.loadBytes(savedSet.Content) + if err != nil { + return E.Cause(err, "restore cached rule-set") + } + s.lastUpdated = savedSet.LastUpdated + s.lastEtag = savedSet.LastEtag + } + } + if s.lastUpdated.IsZero() { + err := s.fetchOnce(ctx, startContext) + if err != nil { + return E.Cause(err, "initial rule-set: ", s.options.Tag) + } + } + s.updateTicker = time.NewTicker(s.updateInterval) + return nil +} + +func (s *RemoteSource) PostStart() error { + go s.loopUpdate() + return nil +} + +func (s *RemoteSource) Program() *goja.Program { + return s.program +} + +func (s *RemoteSource) loadBytes(content []byte) error { + program, err := goja.Compile(F.ToString("script:", s.options.Tag), string(content), false) + if err != nil { + return err + } + s.program = program + return nil +} + +func (s *RemoteSource) loopUpdate() { + if time.Since(s.lastUpdated) > s.updateInterval { + err := s.fetchOnce(s.ctx, nil) + if err != nil { + s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err) + } + } + for { + runtime.GC() + select { + case <-s.ctx.Done(): + return + case <-s.updateTicker.C: + s.pauseManager.WaitActive() + err := s.fetchOnce(s.ctx, nil) + if err != nil { + s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err) + } + } + } +} + +func (s *RemoteSource) fetchOnce(ctx context.Context, startContext *adapter.HTTPStartContext) error { + s.logger.Debug("updating script ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL) + var httpClient *http.Client + if startContext != nil { + httpClient = startContext.HTTPClient(s.options.RemoteOptions.DownloadDetour, s.dialer) + } else { + httpClient = &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: C.TCPTimeout, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + TLSClientConfig: &tls.Config{ + Time: ntp.TimeFuncFromContext(s.ctx), + RootCAs: adapter.RootPoolFromContext(s.ctx), + }, + }, + } + } + request, err := http.NewRequest("GET", s.options.RemoteOptions.URL, nil) + if err != nil { + return err + } + if s.lastEtag != "" { + request.Header.Set("If-None-Match", s.lastEtag) + } + response, err := httpClient.Do(request.WithContext(ctx)) + if err != nil { + return err + } + switch response.StatusCode { + case http.StatusOK: + case http.StatusNotModified: + s.lastUpdated = time.Now() + if s.cacheFile != nil { + savedRuleSet := s.cacheFile.LoadScript(s.options.Tag) + if savedRuleSet != nil { + savedRuleSet.LastUpdated = s.lastUpdated + err = s.cacheFile.SaveScript(s.options.Tag, savedRuleSet) + if err != nil { + s.logger.Error("save script updated time: ", err) + return nil + } + } + } + s.logger.Info("update script ", s.options.Tag, ": not modified") + return nil + default: + return E.New("unexpected status: ", response.Status) + } + content, err := io.ReadAll(response.Body) + if err != nil { + response.Body.Close() + return err + } + err = s.loadBytes(content) + if err != nil { + response.Body.Close() + return err + } + response.Body.Close() + eTagHeader := response.Header.Get("Etag") + if eTagHeader != "" { + s.lastEtag = eTagHeader + } + s.lastUpdated = time.Now() + if s.cacheFile != nil { + err = s.cacheFile.SaveScript(s.options.Tag, &adapter.SavedBinary{ + LastUpdated: s.lastUpdated, + Content: content, + LastEtag: s.lastEtag, + }) + if err != nil { + s.logger.Error("save script cache: ", err) + } + } + s.logger.Info("updated script ", s.options.Tag) + return nil +} + +func (s *RemoteSource) Close() error { + if s.updateTicker != nil { + s.updateTicker.Stop() + } + s.cancel() + return nil +}