From 003423745ebd1e038a4be796dcd55e4a42ffc562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 4 Feb 2025 15:00:59 +0800 Subject: [PATCH] mitm: Refactor & Add url --- adapter/script.go | 33 +- constant/script.go | 11 +- experimental/cachefile/cache.go | 16 +- mitm/engine.go | 164 ++++---- option/mitm.go | 25 +- option/script.go | 46 +-- script/jsc/assert.go | 3 + script/jsc/class.go | 192 +++++++++ script/jsc/iterator.go | 36 ++ script/jstest/assert.js | 83 ++++ script/jstest/test.go | 21 + script/manager.go | 29 +- script/modules/boxctx/context.go | 50 +++ script/modules/boxctx/module.go | 35 ++ script/modules/console/console.go | 281 +++++++++++++ script/modules/console/context.go | 3 + script/modules/console/module.go | 108 +---- script/modules/sghttp/module.go | 147 ------- script/modules/sgstore/module.go | 76 ---- script/modules/sgutils/module.go | 45 -- script/modules/surge/environment.go | 65 +++ script/modules/surge/http.go | 150 +++++++ script/modules/surge/module.go | 63 +++ script/modules/surge/notification.go | 120 ++++++ script/modules/surge/persistent_store.go | 78 ++++ script/modules/surge/script.go | 32 ++ script/modules/surge/utils.go | 50 +++ script/modules/url/escape.go | 55 +++ script/modules/url/module.go | 41 ++ script/modules/url/module_test.go | 37 ++ .../url/testdata/url_search_params_test.js | 385 ++++++++++++++++++ script/modules/url/testdata/url_test.js | 229 +++++++++++ script/modules/url/url.go | 315 ++++++++++++++ script/modules/url/url_search_params.go | 244 +++++++++++ script/runtime.go | 47 +++ script/script.go | 10 +- script/script_surge.go | 345 ++++++++++++++++ script/script_surge_cron.go | 119 ------ script/script_surge_generic.go | 183 --------- script/script_surge_http_request.go | 165 -------- script/script_surge_http_response.go | 175 -------- script/source.go | 4 +- 42 files changed, 3152 insertions(+), 1164 deletions(-) create mode 100644 script/jsc/class.go create mode 100644 script/jsc/iterator.go create mode 100644 script/jstest/assert.js create mode 100644 script/jstest/test.go create mode 100644 script/modules/boxctx/context.go create mode 100644 script/modules/boxctx/module.go create mode 100644 script/modules/console/console.go create mode 100644 script/modules/console/context.go delete mode 100644 script/modules/sghttp/module.go delete mode 100644 script/modules/sgstore/module.go delete mode 100644 script/modules/sgutils/module.go create mode 100644 script/modules/surge/environment.go create mode 100644 script/modules/surge/http.go create mode 100644 script/modules/surge/module.go create mode 100644 script/modules/surge/notification.go create mode 100644 script/modules/surge/persistent_store.go create mode 100644 script/modules/surge/script.go create mode 100644 script/modules/surge/utils.go create mode 100644 script/modules/url/escape.go create mode 100644 script/modules/url/module.go create mode 100644 script/modules/url/module_test.go create mode 100644 script/modules/url/testdata/url_search_params_test.js create mode 100644 script/modules/url/testdata/url_test.js create mode 100644 script/modules/url/url.go create mode 100644 script/modules/url/url_search_params.go create mode 100644 script/runtime.go create mode 100644 script/script_surge.go delete mode 100644 script/script_surge_cron.go delete mode 100644 script/script_surge_generic.go delete mode 100644 script/script_surge_http_request.go delete mode 100644 script/script_surge_http_response.go diff --git a/adapter/script.go b/adapter/script.go index 5cd85d63..3967ed92 100644 --- a/adapter/script.go +++ b/adapter/script.go @@ -3,12 +3,20 @@ package adapter import ( "context" "net/http" + "sync" + "time" ) type ScriptManager interface { Lifecycle Scripts() []Script - // Script(name string) (Script, bool) + Script(name string) (Script, bool) + SurgeCache() *SurgeInMemoryCache +} + +type SurgeInMemoryCache struct { + sync.RWMutex + Data map[string]string } type Script interface { @@ -19,21 +27,11 @@ type Script interface { Close() error } -type GenericScript interface { +type SurgeScript interface { Script - Run(ctx context.Context) error -} - -type HTTPScript interface { - Script - Match(requestURL string) bool - RequiresBody() bool - MaxSize() int64 -} - -type HTTPRequestScript interface { - HTTPScript - Run(ctx context.Context, request *http.Request, body []byte) (*HTTPRequestScriptResult, error) + ExecuteGeneric(ctx context.Context, scriptType string, timeout time.Duration, arguments []string) error + ExecuteHTTPRequest(ctx context.Context, timeout time.Duration, request *http.Request, body []byte, binaryBody bool, arguments []string) (*HTTPRequestScriptResult, error) + ExecuteHTTPResponse(ctx context.Context, timeout time.Duration, request *http.Request, response *http.Response, body []byte, binaryBody bool, arguments []string) (*HTTPResponseScriptResult, error) } type HTTPRequestScriptResult struct { @@ -49,11 +47,6 @@ type HTTPRequestScriptResponse struct { Body []byte } -type HTTPResponseScript interface { - HTTPScript - Run(ctx context.Context, request *http.Request, response *http.Response, body []byte) (*HTTPResponseScriptResult, error) -} - type HTTPResponseScriptResult struct { Status int Headers http.Header diff --git a/constant/script.go b/constant/script.go index 41199ea6..45574038 100644 --- a/constant/script.go +++ b/constant/script.go @@ -1,12 +1,7 @@ package constant const ( - ScriptTypeSurgeGeneric = "sg-generic" - ScriptTypeSurgeHTTPRequest = "sg-http-request" - ScriptTypeSurgeHTTPResponse = "sg-http-response" - ScriptTypeSurgeCron = "sg-cron" - ScriptTypeSurgeEvent = "sg-event" - - ScriptSourceLocal = "local" - ScriptSourceRemote = "remote" + ScriptTypeSurge = "surge" + ScriptSourceTypeLocal = "local" + ScriptSourceTypeRemote = "remote" ) diff --git a/experimental/cachefile/cache.go b/experimental/cachefile/cache.go index f4ff2654..fa1e28b2 100644 --- a/experimental/cachefile/cache.go +++ b/experimental/cachefile/cache.go @@ -370,10 +370,18 @@ func (c *CacheFile) SurgePersistentStoreRead(key string) string { func (c *CacheFile) SurgePersistentStoreWrite(key string, value string) error { return c.DB.Batch(func(t *bbolt.Tx) error { - bucket, err := c.createBucket(t, bucketSgPersistentStore) - if err != nil { - return err + if value != "" { + bucket, err := c.createBucket(t, bucketSgPersistentStore) + if err != nil { + return err + } + return bucket.Put([]byte(key), []byte(value)) + } else { + bucket := c.bucket(t, bucketSgPersistentStore) + if bucket == nil { + return nil + } + return bucket.Delete([]byte(key)) } - return bucket.Put([]byte(key), []byte(value)) }) } diff --git a/mitm/engine.go b/mitm/engine.go index 8fbd8b44..30a43a2f 100644 --- a/mitm/engine.go +++ b/mitm/engine.go @@ -195,28 +195,31 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls requestURL := rawRequestURL.String() request.RequestURI = "" var ( - requestMatch bool - requestScript adapter.HTTPRequestScript + requestMatch bool + requestScript adapter.SurgeScript + requestScriptOptions option.MITMRouteSurgeScriptOptions ) - for _, script := range e.script.Scripts() { - if !common.Contains(options.Script, script.Tag()) { +match: + for _, scriptOptions := range options.Script { + script, loaded := e.script.Script(scriptOptions.Tag) + if !loaded { + e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag) continue } - httpScript, isHTTP := script.(adapter.HTTPRequestScript) - if !isHTTP { - _, isHTTP = script.(adapter.HTTPScript) - if !isHTTP { - e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script") + surgeScript, isSurge := script.(adapter.SurgeScript) + if !isSurge { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script") + continue + } + for _, pattern := range scriptOptions.Pattern { + if pattern.Build().MatchString(requestURL) { + e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]") + requestScript = surgeScript + requestScriptOptions = scriptOptions + requestMatch = true + break match } - continue } - if !httpScript.Match(requestURL) { - continue - } - e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]") - requestScript = httpScript - requestMatch = true - break } var body []byte if options.Print && request.ContentLength > 0 && request.ContentLength <= 131072 { @@ -230,7 +233,7 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls e.printRequest(ctx, request, body) } if requestScript != nil { - if body == nil && requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) { + if body == nil && requestScriptOptions.RequiresBody && request.ContentLength > 0 && (requestScriptOptions.MaxSize == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScriptOptions.MaxSize) { body, err = io.ReadAll(request.Body) if err != nil { return E.Cause(err, "read HTTP request body") @@ -238,7 +241,7 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls request.Body = io.NopCloser(bytes.NewReader(body)) } var result *adapter.HTTPRequestScriptResult - result, err = requestScript.Run(ctx, request, body) + result, err = requestScript.ExecuteHTTPRequest(ctx, time.Duration(requestScriptOptions.Timeout), request, body, requestScriptOptions.BinaryBodyMode, requestScriptOptions.Arguments) if err != nil { return E.Cause(err, "execute script/", requestScript.Type(), "[", requestScript.Tag(), "]") } @@ -455,28 +458,31 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls return E.Errors(innerErr.Load(), err) } var ( - responseScript adapter.HTTPResponseScript - responseMatch bool + responseScript adapter.SurgeScript + responseMatch bool + responseScriptOptions option.MITMRouteSurgeScriptOptions ) - for _, script := range e.script.Scripts() { - if !common.Contains(options.Script, script.Tag()) { +matchResponse: + for _, scriptOptions := range options.Script { + script, loaded := e.script.Script(scriptOptions.Tag) + if !loaded { + e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag) continue } - httpScript, isHTTP := script.(adapter.HTTPResponseScript) - if !isHTTP { - _, isHTTP = script.(adapter.HTTPScript) - if !isHTTP { - e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script") + surgeScript, isSurge := script.(adapter.SurgeScript) + if !isSurge { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script") + continue + } + for _, pattern := range scriptOptions.Pattern { + if pattern.Build().MatchString(requestURL) { + e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]") + responseScript = surgeScript + responseScriptOptions = scriptOptions + responseMatch = true + break matchResponse } - continue } - if !httpScript.Match(requestURL) { - continue - } - e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]") - responseScript = httpScript - responseMatch = true - break } var responseBody []byte if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 { @@ -490,7 +496,7 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls e.printResponse(ctx, request, response, responseBody) } if responseScript != nil { - if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) { + if responseBody == nil && responseScriptOptions.RequiresBody && response.ContentLength > 0 && (responseScriptOptions.MaxSize == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScriptOptions.MaxSize) { responseBody, err = io.ReadAll(response.Body) if err != nil { return E.Cause(err, "read HTTP response body") @@ -498,7 +504,7 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls response.Body = io.NopCloser(bytes.NewReader(responseBody)) } var result *adapter.HTTPResponseScriptResult - result, err = responseScript.Run(ctx, request, response, responseBody) + result, err = responseScript.ExecuteHTTPResponse(ctx, time.Duration(responseScriptOptions.Timeout), request, response, responseBody, responseScriptOptions.BinaryBodyMode, responseScriptOptions.Arguments) if err != nil { return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]") } @@ -654,28 +660,31 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite requestURL := rawRequestURL.String() request.RequestURI = "" var ( - requestMatch bool - requestScript adapter.HTTPRequestScript + requestMatch bool + requestScript adapter.SurgeScript + requestScriptOptions option.MITMRouteSurgeScriptOptions ) - for _, script := range e.script.Scripts() { - if !common.Contains(options.Script, script.Tag()) { +match: + for _, scriptOptions := range options.Script { + script, loaded := e.script.Script(scriptOptions.Tag) + if !loaded { + e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag) continue } - httpScript, isHTTP := script.(adapter.HTTPRequestScript) - if !isHTTP { - _, isHTTP = script.(adapter.HTTPScript) - if !isHTTP { - e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script") + surgeScript, isSurge := script.(adapter.SurgeScript) + if !isSurge { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script") + continue + } + for _, pattern := range scriptOptions.Pattern { + if pattern.Build().MatchString(requestURL) { + e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]") + requestScript = surgeScript + requestScriptOptions = scriptOptions + requestMatch = true + break match } - continue } - if !httpScript.Match(requestURL) { - continue - } - e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]") - requestScript = httpScript - requestMatch = true - break } var ( body []byte @@ -693,7 +702,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite e.printRequest(ctx, request, body) } if requestScript != nil { - if body == nil && requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) { + if body == nil && requestScriptOptions.RequiresBody && request.ContentLength > 0 && (requestScriptOptions.MaxSize == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScriptOptions.MaxSize) { body, err = io.ReadAll(request.Body) if err != nil { return E.Cause(err, "read HTTP request body") @@ -701,7 +710,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite request.Body.Close() request.Body = io.NopCloser(bytes.NewReader(body)) } - result, err := requestScript.Run(ctx, request, body) + result, err := requestScript.ExecuteHTTPRequest(ctx, time.Duration(requestScriptOptions.Timeout), request, body, requestScriptOptions.BinaryBodyMode, requestScriptOptions.Arguments) if err != nil { return E.Cause(err, "execute script/", requestScript.Type(), "[", requestScript.Tag(), "]") } @@ -888,28 +897,31 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite return E.Cause(err, "exchange request") } var ( - responseScript adapter.HTTPResponseScript - responseMatch bool + responseScript adapter.SurgeScript + responseMatch bool + responseScriptOptions option.MITMRouteSurgeScriptOptions ) - for _, script := range e.script.Scripts() { - if !common.Contains(options.Script, script.Tag()) { +matchResponse: + for _, scriptOptions := range options.Script { + script, loaded := e.script.Script(scriptOptions.Tag) + if !loaded { + e.logger.WarnContext(ctx, "script not found: ", scriptOptions.Tag) continue } - httpScript, isHTTP := script.(adapter.HTTPResponseScript) - if !isHTTP { - _, isHTTP = script.(adapter.HTTPScript) - if !isHTTP { - e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script") + surgeScript, isSurge := script.(adapter.SurgeScript) + if !isSurge { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a Surge script") + continue + } + for _, pattern := range scriptOptions.Pattern { + if pattern.Build().MatchString(requestURL) { + e.logger.DebugContext(ctx, "match script/", surgeScript.Type(), "[", surgeScript.Tag(), "]") + responseScript = surgeScript + responseScriptOptions = scriptOptions + responseMatch = true + break matchResponse } - continue } - if !httpScript.Match(requestURL) { - continue - } - e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]") - responseScript = httpScript - responseMatch = true - break } var responseBody []byte if options.Print && response.ContentLength > 0 && response.ContentLength <= 131072 { @@ -924,7 +936,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite e.printResponse(ctx, request, response, responseBody) } if responseScript != nil { - if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) { + if responseBody == nil && responseScriptOptions.RequiresBody && response.ContentLength > 0 && (responseScriptOptions.MaxSize == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScriptOptions.MaxSize) { responseBody, err = io.ReadAll(response.Body) if err != nil { return E.Cause(err, "read HTTP response body") @@ -933,7 +945,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite response.Body = io.NopCloser(bytes.NewReader(responseBody)) } var result *adapter.HTTPResponseScriptResult - result, err = responseScript.Run(ctx, request, response, responseBody) + result, err = responseScript.ExecuteHTTPResponse(ctx, time.Duration(responseScriptOptions.Timeout), request, response, responseBody, responseScriptOptions.BinaryBodyMode, responseScriptOptions.Arguments) if err != nil { return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]") } diff --git a/option/mitm.go b/option/mitm.go index 7166f76d..99704da8 100644 --- a/option/mitm.go +++ b/option/mitm.go @@ -17,11 +17,22 @@ type TLSDecryptionOptions struct { } type MITMRouteOptions struct { - Enabled bool `json:"enabled,omitempty"` - Print bool `json:"print,omitempty"` - Script badoption.Listable[string] `json:"script,omitempty"` - SurgeURLRewrite badoption.Listable[SurgeURLRewriteLine] `json:"sg_url_rewrite,omitempty"` - SurgeHeaderRewrite badoption.Listable[SurgeHeaderRewriteLine] `json:"sg_header_rewrite,omitempty"` - SurgeBodyRewrite badoption.Listable[SurgeBodyRewriteLine] `json:"sg_body_rewrite,omitempty"` - SurgeMapLocal badoption.Listable[SurgeMapLocalLine] `json:"sg_map_local,omitempty"` + Enabled bool `json:"enabled,omitempty"` + Print bool `json:"print,omitempty"` + Script badoption.Listable[MITMRouteSurgeScriptOptions] `json:"sg_script,omitempty"` + SurgeURLRewrite badoption.Listable[SurgeURLRewriteLine] `json:"sg_url_rewrite,omitempty"` + SurgeHeaderRewrite badoption.Listable[SurgeHeaderRewriteLine] `json:"sg_header_rewrite,omitempty"` + SurgeBodyRewrite badoption.Listable[SurgeBodyRewriteLine] `json:"sg_body_rewrite,omitempty"` + SurgeMapLocal badoption.Listable[SurgeMapLocalLine] `json:"sg_map_local,omitempty"` +} + +type MITMRouteSurgeScriptOptions struct { + Tag string `json:"tag"` + Type badoption.Listable[string] `json:"type"` + Pattern badoption.Listable[*badoption.Regexp] `json:"pattern"` + Timeout badoption.Duration `json:"timeout,omitempty"` + RequiresBody bool `json:"requires_body,omitempty"` + MaxSize int64 `json:"max_size,omitempty"` + BinaryBodyMode bool `json:"binary_body_mode,omitempty"` + Arguments badoption.Listable[string] `json:"arguments,omitempty"` } diff --git a/option/script.go b/option/script.go index cab9a764..90a3b586 100644 --- a/option/script.go +++ b/option/script.go @@ -29,9 +29,9 @@ type ScriptSourceOptions _ScriptSourceOptions func (o ScriptSourceOptions) MarshalJSON() ([]byte, error) { var source any switch o.Source { - case C.ScriptSourceLocal: + case C.ScriptSourceTypeLocal: source = o.LocalOptions - case C.ScriptSourceRemote: + case C.ScriptSourceTypeRemote: source = o.RemoteOptions default: return nil, E.New("unknown script source: ", o.Source) @@ -46,9 +46,9 @@ func (o *ScriptSourceOptions) UnmarshalJSON(bytes []byte) error { } var source any switch o.Source { - case C.ScriptSourceLocal: + case C.ScriptSourceTypeLocal: source = &o.LocalOptions - case C.ScriptSourceRemote: + case C.ScriptSourceTypeRemote: source = &o.RemoteOptions default: return E.New("unknown script source: ", o.Source) @@ -75,12 +75,9 @@ func (s *Script) UnmarshalJSON(bytes []byte) error { } type _ScriptOptions struct { - Type string `json:"type"` - Tag string `json:"tag"` - Timeout badoption.Duration `json:"timeout,omitempty"` - Arguments []any `json:"arguments,omitempty"` - HTTPOptions HTTPScriptOptions `json:"-"` - CronOptions CronScriptOptions `json:"-"` + Type string `json:"type"` + Tag string `json:"tag"` + SurgeOptions SurgeScriptOptions `json:"-"` } type ScriptOptions _ScriptOptions @@ -88,12 +85,8 @@ type ScriptOptions _ScriptOptions func (o ScriptOptions) MarshalJSON() ([]byte, error) { var v any switch o.Type { - case C.ScriptTypeSurgeGeneric: - v = nil - case C.ScriptTypeSurgeHTTPRequest, C.ScriptTypeSurgeHTTPResponse: - v = o.HTTPOptions - case C.ScriptTypeSurgeCron: - v = o.CronOptions + case C.ScriptTypeSurge: + v = &o.SurgeOptions default: return nil, E.New("unknown script type: ", o.Type) } @@ -110,12 +103,10 @@ func (o *ScriptOptions) UnmarshalJSON(bytes []byte) error { } var v any switch o.Type { - case C.ScriptTypeSurgeGeneric: - v = nil - case C.ScriptTypeSurgeHTTPRequest, C.ScriptTypeSurgeHTTPResponse: - v = &o.HTTPOptions - case C.ScriptTypeSurgeCron: - v = &o.CronOptions + case C.ScriptTypeSurge: + v = &o.SurgeOptions + case "": + return E.New("missing script type") default: return E.New("unknown script type: ", o.Type) } @@ -126,13 +117,12 @@ func (o *ScriptOptions) UnmarshalJSON(bytes []byte) error { return badjson.UnmarshallExcluded(bytes, (*_ScriptOptions)(o), v) } -type HTTPScriptOptions struct { - Pattern string `json:"pattern"` - RequiresBody bool `json:"requires_body,omitempty"` - MaxSize int64 `json:"max_size,omitempty"` - BinaryBodyMode bool `json:"binary_body_mode,omitempty"` +type SurgeScriptOptions struct { + CronOptions *CronScriptOptions `json:"cron,omitempty"` } type CronScriptOptions struct { - Expression string `json:"expression"` + Expression string `json:"expression"` + Arguments []string `json:"arguments,omitempty"` + Timeout badoption.Duration `json:"timeout,omitempty"` } diff --git a/script/jsc/assert.go b/script/jsc/assert.go index a578423a..0b7fe3b6 100644 --- a/script/jsc/assert.go +++ b/script/jsc/assert.go @@ -105,6 +105,9 @@ func AssertStringBinary(vm *goja.Runtime, value goja.Value, name string, nilable } func AssertFunction(vm *goja.Runtime, value goja.Value, name string) goja.Callable { + if IsNil(value) { + panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name))) + } functionValue, isFunction := goja.AssertFunction(value) if !isFunction { panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected function, but got ", value))) diff --git a/script/jsc/class.go b/script/jsc/class.go new file mode 100644 index 00000000..cb949512 --- /dev/null +++ b/script/jsc/class.go @@ -0,0 +1,192 @@ +package jsc + +import ( + "time" + + "github.com/sagernet/sing/common" + + "github.com/dop251/goja" +) + +type Module interface { + Runtime() *goja.Runtime +} + +type Class[M Module, C any] interface { + Module() M + Runtime() *goja.Runtime + DefineField(name string, getter func(this C) any, setter func(this C, value goja.Value)) + DefineMethod(name string, method func(this C, call goja.FunctionCall) any) + DefineStaticMethod(name string, method func(c Class[M, C], call goja.FunctionCall) any) + DefineConstructor(constructor func(c Class[M, C], call goja.ConstructorCall) C) + ToValue() goja.Value + New(instance C) *goja.Object + Prototype() *goja.Object + Is(value goja.Value) bool + As(value goja.Value) C +} + +func GetClass[M Module, C any](runtime *goja.Runtime, exports *goja.Object, className string) Class[M, C] { + objectValue := exports.Get(className) + if objectValue == nil { + panic(runtime.NewTypeError("Missing class: " + className)) + } + object, isObject := objectValue.(*goja.Object) + if !isObject { + panic(runtime.NewTypeError("Invalid class: " + className)) + } + classObject, isClass := object.Get("_class").(*goja.Object) + if !isClass { + panic(runtime.NewTypeError("Invalid class: " + className)) + } + class, isClass := classObject.Export().(Class[M, C]) + if !isClass { + panic(runtime.NewTypeError("Invalid class: " + className)) + } + return class +} + +type goClass[M Module, C any] struct { + m M + prototype *goja.Object + constructor goja.Value +} + +func NewClass[M Module, C any](module M) Class[M, C] { + class := &goClass[M, C]{ + m: module, + prototype: module.Runtime().NewObject(), + } + clazz := module.Runtime().ToValue(class).(*goja.Object) + clazz.Set("toString", module.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + return module.Runtime().ToValue("[sing-box Class]") + })) + class.prototype.DefineAccessorProperty("_class", class.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { return clazz }), nil, goja.FLAG_FALSE, goja.FLAG_TRUE) + return class +} + +func (c *goClass[M, C]) Module() M { + return c.m +} + +func (c *goClass[M, C]) Runtime() *goja.Runtime { + return c.m.Runtime() +} + +func (c *goClass[M, C]) DefineField(name string, getter func(this C) any, setter func(this C, value goja.Value)) { + var ( + getterValue goja.Value + setterValue goja.Value + ) + if getter != nil { + getterValue = c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + this, isThis := call.This.Export().(C) + if !isThis { + panic(c.Runtime().NewTypeError("Illegal this value: " + call.This.ExportType().String())) + } + return c.toValue(getter(this), goja.Null()) + }) + } + if setter != nil { + setterValue = c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + this, isThis := call.This.Export().(C) + if !isThis { + panic(c.Runtime().NewTypeError("Illegal this value: " + call.This.String())) + } + setter(this, call.Argument(0)) + return goja.Undefined() + }) + } + c.prototype.DefineAccessorProperty(name, getterValue, setterValue, goja.FLAG_FALSE, goja.FLAG_TRUE) +} + +func (c *goClass[M, C]) DefineMethod(name string, method func(this C, call goja.FunctionCall) any) { + methodValue := c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + this, isThis := call.This.Export().(C) + if !isThis { + panic(c.Runtime().NewTypeError("Illegal this value: " + call.This.String())) + } + return c.toValue(method(this, call), goja.Undefined()) + }) + c.prototype.Set(name, methodValue) + if name == "entries" { + c.prototype.DefineDataPropertySymbol(goja.SymIterator, methodValue, goja.FLAG_TRUE, goja.FLAG_FALSE, goja.FLAG_TRUE) + } +} + +func (c *goClass[M, C]) DefineStaticMethod(name string, method func(c Class[M, C], call goja.FunctionCall) any) { + c.prototype.Set(name, c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { + return c.toValue(method(c, call), goja.Undefined()) + })) +} + +func (c *goClass[M, C]) DefineConstructor(constructor func(c Class[M, C], call goja.ConstructorCall) C) { + constructorObject := c.Runtime().ToValue(func(call goja.ConstructorCall) *goja.Object { + value := constructor(c, call) + object := c.toValue(value, goja.Undefined()).(*goja.Object) + object.SetPrototype(call.This.Prototype()) + return object + }).(*goja.Object) + constructorObject.SetPrototype(c.prototype) + c.prototype.DefineDataProperty("constructor", constructorObject, goja.FLAG_FALSE, goja.FLAG_FALSE, goja.FLAG_FALSE) + c.constructor = constructorObject +} + +func (c *goClass[M, C]) toValue(rawValue any, defaultValue goja.Value) goja.Value { + switch value := rawValue.(type) { + case nil: + return defaultValue + case time.Time: + return TimeToValue(c.Runtime(), value) + default: + return c.Runtime().ToValue(value) + } +} + +func (c *goClass[M, C]) ToValue() goja.Value { + if c.constructor == nil { + constructorObject := c.Runtime().ToValue(func(call goja.ConstructorCall) *goja.Object { + panic(c.Runtime().NewTypeError("Illegal constructor call")) + }).(*goja.Object) + constructorObject.SetPrototype(c.prototype) + c.prototype.DefineDataProperty("constructor", constructorObject, goja.FLAG_FALSE, goja.FLAG_FALSE, goja.FLAG_FALSE) + c.constructor = constructorObject + } + return c.constructor +} + +func (c *goClass[M, C]) New(instance C) *goja.Object { + object := c.Runtime().ToValue(instance).(*goja.Object) + object.SetPrototype(c.prototype) + return object +} + +func (c *goClass[M, C]) Prototype() *goja.Object { + return c.prototype +} + +func (c *goClass[M, C]) Is(value goja.Value) bool { + object, isObject := value.(*goja.Object) + if !isObject { + return false + } + prototype := object.Prototype() + for prototype != nil { + if prototype == c.prototype { + return true + } + prototype = prototype.Prototype() + } + return false +} + +func (c *goClass[M, C]) As(value goja.Value) C { + object, isObject := value.(*goja.Object) + if !isObject { + return common.DefaultValue[C]() + } + if !c.Is(object) { + return common.DefaultValue[C]() + } + return object.Export().(C) +} diff --git a/script/jsc/iterator.go b/script/jsc/iterator.go new file mode 100644 index 00000000..deb66764 --- /dev/null +++ b/script/jsc/iterator.go @@ -0,0 +1,36 @@ +package jsc + +import "github.com/dop251/goja" + +type Iterator[M Module, T any] struct { + c Class[M, *Iterator[M, T]] + values []T + block func(this T) any +} + +func NewIterator[M Module, T any](class Class[M, *Iterator[M, T]], values []T, block func(this T) any) goja.Value { + return class.New(&Iterator[M, T]{class, values, block}) +} + +func CreateIterator[M Module, T any](module M) Class[M, *Iterator[M, T]] { + class := NewClass[M, *Iterator[M, T]](module) + class.DefineMethod("next", (*Iterator[M, T]).next) + class.DefineMethod("toString", (*Iterator[M, T]).toString) + return class +} + +func (i *Iterator[M, T]) next(call goja.FunctionCall) any { + result := i.c.Runtime().NewObject() + if len(i.values) == 0 { + result.Set("done", true) + } else { + result.Set("done", false) + result.Set("value", i.block(i.values[0])) + i.values = i.values[1:] + } + return result +} + +func (i *Iterator[M, T]) toString(call goja.FunctionCall) any { + return "[sing-box Iterator]" +} diff --git a/script/jstest/assert.js b/script/jstest/assert.js new file mode 100644 index 00000000..b00076dd --- /dev/null +++ b/script/jstest/assert.js @@ -0,0 +1,83 @@ +'use strict'; + +const assert = { + _isSameValue(a, b) { + if (a === b) { + // Handle +/-0 vs. -/+0 + return a !== 0 || 1 / a === 1 / b; + } + + // Handle NaN vs. NaN + return a !== a && b !== b; + }, + + _toString(value) { + try { + if (value === 0 && 1 / value === -Infinity) { + return '-0'; + } + + return String(value); + } catch (err) { + if (err.name === 'TypeError') { + return Object.prototype.toString.call(value); + } + + throw err; + } + }, + + sameValue(actual, expected, message) { + if (assert._isSameValue(actual, expected)) { + return; + } + if (message === undefined) { + message = ''; + } else { + message += ' '; + } + + message += 'Expected SameValue(«' + assert._toString(actual) + '», «' + assert._toString(expected) + '») to be true'; + + throw new Error(message); + }, + + throws(f, ctor, message) { + if (message === undefined) { + message = ''; + } else { + message += ' '; + } + try { + f(); + } catch (e) { + if (e.constructor !== ctor) { + throw new Error(message + "Wrong exception type was thrown: " + e.constructor.name); + } + return; + } + throw new Error(message + "No exception was thrown"); + }, + + throwsNodeError(f, ctor, code, message) { + if (message === undefined) { + message = ''; + } else { + message += ' '; + } + try { + f(); + } catch (e) { + if (e.constructor !== ctor) { + throw new Error(message + "Wrong exception type was thrown: " + e.constructor.name); + } + if (e.code !== code) { + throw new Error(message + "Wrong exception code was thrown: " + e.code); + } + return; + } + throw new Error(message + "No exception was thrown"); + } +} + +module.exports = assert; \ No newline at end of file diff --git a/script/jstest/test.go b/script/jstest/test.go new file mode 100644 index 00000000..e287f8c2 --- /dev/null +++ b/script/jstest/test.go @@ -0,0 +1,21 @@ +package jstest + +import ( + _ "embed" + + "github.com/sagernet/sing-box/script/modules/require" +) + +//go:embed assert.js +var assertJS []byte + +func NewRegistry() *require.Registry { + return require.NewRegistry(require.WithFsEnable(true), require.WithLoader(func(path string) ([]byte, error) { + switch path { + case "assert.js": + return assertJS, nil + default: + return require.DefaultSourceLoader(path) + } + })) +} diff --git a/script/manager.go b/script/manager.go index d234afae..50c48362 100644 --- a/script/manager.go +++ b/script/manager.go @@ -17,17 +17,18 @@ import ( var _ adapter.ScriptManager = (*Manager)(nil) type Manager struct { - ctx context.Context - logger logger.ContextLogger - scripts []adapter.Script - // scriptByName map[string]adapter.Script + ctx context.Context + logger logger.ContextLogger + scripts []adapter.Script + scriptByName map[string]adapter.Script + surgeCache *adapter.SurgeInMemoryCache } func NewManager(ctx context.Context, logFactory log.Factory, scripts []option.Script) (*Manager, error) { manager := &Manager{ - ctx: ctx, - logger: logFactory.NewLogger("script"), - // scriptByName: make(map[string]adapter.Script), + ctx: ctx, + logger: logFactory.NewLogger("script"), + scriptByName: make(map[string]adapter.Script), } for _, scriptOptions := range scripts { script, err := NewScript(ctx, logFactory.NewLogger(F.ToString("script/", scriptOptions.Type, "[", scriptOptions.Tag, "]")), scriptOptions) @@ -35,7 +36,7 @@ func NewManager(ctx context.Context, logFactory log.Factory, scripts []option.Sc return nil, E.Cause(err, "initialize script: ", scriptOptions.Tag) } manager.scripts = append(manager.scripts, script) - // manager.scriptByName[scriptOptions.Tag] = script + manager.scriptByName[scriptOptions.Tag] = script } return manager, nil } @@ -100,8 +101,16 @@ func (m *Manager) Scripts() []adapter.Script { return m.scripts } -/* func (m *Manager) Script(name string) (adapter.Script, bool) { script, loaded := m.scriptByName[name] return script, loaded -}*/ +} + +func (m *Manager) SurgeCache() *adapter.SurgeInMemoryCache { + if m.surgeCache == nil { + m.surgeCache = &adapter.SurgeInMemoryCache{ + Data: make(map[string]string), + } + } + return m.surgeCache +} diff --git a/script/modules/boxctx/context.go b/script/modules/boxctx/context.go new file mode 100644 index 00000000..53e74860 --- /dev/null +++ b/script/modules/boxctx/context.go @@ -0,0 +1,50 @@ +package boxctx + +import ( + "context" + "time" + + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing/common/logger" + + "github.com/dop251/goja" +) + +type Context struct { + class jsc.Class[*Module, *Context] + Context context.Context + Logger logger.ContextLogger + Tag string + StartedAt time.Time + ErrorHandler func(error) +} + +func FromRuntime(runtime *goja.Runtime) *Context { + contextValue := runtime.Get("context") + if contextValue == nil { + return nil + } + context, isContext := contextValue.Export().(*Context) + if !isContext { + return nil + } + return context +} + +func MustFromRuntime(runtime *goja.Runtime) *Context { + context := FromRuntime(runtime) + if context == nil { + panic(runtime.NewTypeError("Missing sing-box context")) + } + return context +} + +func createContext(module *Module) jsc.Class[*Module, *Context] { + class := jsc.NewClass[*Module, *Context](module) + class.DefineMethod("toString", (*Context).toString) + return class +} + +func (c *Context) toString(call goja.FunctionCall) any { + return "[sing-box Context]" +} diff --git a/script/modules/boxctx/module.go b/script/modules/boxctx/module.go new file mode 100644 index 00000000..a18fe844 --- /dev/null +++ b/script/modules/boxctx/module.go @@ -0,0 +1,35 @@ +package boxctx + +import ( + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/require" + + "github.com/dop251/goja" +) + +const ModuleName = "context" + +type Module struct { + runtime *goja.Runtime + classContext jsc.Class[*Module, *Context] +} + +func Require(runtime *goja.Runtime, module *goja.Object) { + m := &Module{ + runtime: runtime, + } + m.classContext = createContext(m) + exports := module.Get("exports").(*goja.Object) + exports.Set("Context", m.classContext.ToValue()) +} + +func Enable(runtime *goja.Runtime, context *Context) { + exports := require.Require(runtime, ModuleName).ToObject(runtime) + classContext := jsc.GetClass[*Module, *Context](runtime, exports, "Context") + context.class = classContext + runtime.Set("context", classContext.New(context)) +} + +func (m *Module) Runtime() *goja.Runtime { + return m.runtime +} diff --git a/script/modules/console/console.go b/script/modules/console/console.go new file mode 100644 index 00000000..4fcfec1f --- /dev/null +++ b/script/modules/console/console.go @@ -0,0 +1,281 @@ +package console + +import ( + "bytes" + "context" + "encoding/xml" + "sync" + "time" + + sLog "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + + "github.com/dop251/goja" +) + +type Console struct { + class jsc.Class[*Module, *Console] + access sync.Mutex + countMap map[string]int + timeMap map[string]time.Time +} + +func NewConsole(class jsc.Class[*Module, *Console]) goja.Value { + return class.New(&Console{ + class: class, + countMap: make(map[string]int), + timeMap: make(map[string]time.Time), + }) +} + +func createConsole(m *Module) jsc.Class[*Module, *Console] { + class := jsc.NewClass[*Module, *Console](m) + class.DefineMethod("assert", (*Console).assert) + class.DefineMethod("clear", (*Console).clear) + class.DefineMethod("count", (*Console).count) + class.DefineMethod("countReset", (*Console).countReset) + class.DefineMethod("debug", (*Console).debug) + class.DefineMethod("dir", (*Console).dir) + class.DefineMethod("dirxml", (*Console).dirxml) + class.DefineMethod("error", (*Console).error) + class.DefineMethod("group", (*Console).stub) + class.DefineMethod("groupCollapsed", (*Console).stub) + class.DefineMethod("groupEnd", (*Console).stub) + class.DefineMethod("info", (*Console).info) + class.DefineMethod("log", (*Console)._log) + class.DefineMethod("profile", (*Console).stub) + class.DefineMethod("profileEnd", (*Console).profileEnd) + class.DefineMethod("table", (*Console).table) + class.DefineMethod("time", (*Console).time) + class.DefineMethod("timeEnd", (*Console).timeEnd) + class.DefineMethod("timeLog", (*Console).timeLog) + class.DefineMethod("timeStamp", (*Console).stub) + class.DefineMethod("trace", (*Console).trace) + class.DefineMethod("warn", (*Console).warn) + return class +} + +func (c *Console) stub(call goja.FunctionCall) any { + return goja.Undefined() +} + +func (c *Console) assert(call goja.FunctionCall) any { + assertion := call.Argument(0).ToBoolean() + if !assertion { + return c.log(logger.ContextLogger.ErrorContext, call.Arguments[1:]) + } + return goja.Undefined() +} + +func (c *Console) clear(call goja.FunctionCall) any { + return nil +} + +func (c *Console) count(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + newValue := c.countMap[label] + 1 + c.countMap[label] = newValue + c.access.Unlock() + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, F.ToString(label, ": ", newValue)) + return goja.Undefined() +} + +func (c *Console) countReset(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + delete(c.countMap, label) + c.access.Unlock() + return goja.Undefined() +} + +func (c *Console) log(logFunc func(logger.ContextLogger, context.Context, ...any), args []goja.Value) any { + var buffer bytes.Buffer + var formatString string + if len(args) > 0 { + formatString = args[0].String() + } + format(c.class.Runtime(), &buffer, formatString, args[1:]...) + writeLog(c.class.Runtime(), logFunc, buffer.String()) + return goja.Undefined() +} + +func (c *Console) debug(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.DebugContext, call.Arguments) +} + +func (c *Console) dir(call goja.FunctionCall) any { + object := jsc.AssertObject(c.class.Runtime(), call.Argument(0), "object", false) + var buffer bytes.Buffer + for _, key := range object.Keys() { + value := object.Get(key) + buffer.WriteString(key) + buffer.WriteString(": ") + buffer.WriteString(value.String()) + buffer.WriteString("\n") + } + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, buffer.String()) + return goja.Undefined() +} + +func (c *Console) dirxml(call goja.FunctionCall) any { + var buffer bytes.Buffer + encoder := xml.NewEncoder(&buffer) + encoder.Indent("", " ") + encoder.Encode(call.Argument(0).Export()) + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, buffer.String()) + return goja.Undefined() +} + +func (c *Console) error(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.ErrorContext, call.Arguments) +} + +func (c *Console) info(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.InfoContext, call.Arguments) +} + +func (c *Console) _log(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.InfoContext, call.Arguments) +} + +func (c *Console) profileEnd(call goja.FunctionCall) any { + return goja.Undefined() +} + +func (c *Console) table(call goja.FunctionCall) any { + return c.dir(call) +} + +func (c *Console) time(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + c.timeMap[label] = time.Now() + c.access.Unlock() + return goja.Undefined() +} + +func (c *Console) timeEnd(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + startTime, ok := c.timeMap[label] + if !ok { + c.access.Unlock() + return goja.Undefined() + } + delete(c.timeMap, label) + c.access.Unlock() + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, F.ToString(label, ": ", time.Since(startTime).String(), " - - timer ended")) + return goja.Undefined() +} + +func (c *Console) timeLog(call goja.FunctionCall) any { + label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true) + if label == "" { + label = "default" + } + c.access.Lock() + startTime, ok := c.timeMap[label] + c.access.Unlock() + if !ok { + writeLog(c.class.Runtime(), logger.ContextLogger.ErrorContext, F.ToString("Timer \"", label, "\" doesn't exist.")) + return goja.Undefined() + } + writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, F.ToString(label, ": ", time.Since(startTime))) + return goja.Undefined() +} + +func (c *Console) trace(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.TraceContext, call.Arguments) +} + +func (c *Console) warn(call goja.FunctionCall) any { + return c.log(logger.ContextLogger.WarnContext, call.Arguments) +} + +func writeLog(runtime *goja.Runtime, logFunc func(logger.ContextLogger, context.Context, ...any), message string) { + var ( + ctx context.Context + sLogger logger.ContextLogger + ) + boxCtx := boxctx.FromRuntime(runtime) + if boxCtx != nil { + ctx = boxCtx.Context + sLogger = boxCtx.Logger + } else { + ctx = context.Background() + sLogger = sLog.StdLogger() + } + logFunc(sLogger, ctx, message) +} + +func format(runtime *goja.Runtime, b *bytes.Buffer, f string, args ...goja.Value) { + pct := false + argNum := 0 + for _, chr := range f { + if pct { + if argNum < len(args) { + if format1(runtime, chr, args[argNum], b) { + argNum++ + } + } else { + b.WriteByte('%') + b.WriteRune(chr) + } + pct = false + } else { + if chr == '%' { + pct = true + } else { + b.WriteRune(chr) + } + } + } + + for _, arg := range args[argNum:] { + b.WriteByte(' ') + b.WriteString(arg.String()) + } +} + +func format1(runtime *goja.Runtime, f rune, val goja.Value, w *bytes.Buffer) bool { + switch f { + case 's': + w.WriteString(val.String()) + case 'd': + w.WriteString(val.ToNumber().String()) + case 'j': + if json, ok := runtime.Get("JSON").(*goja.Object); ok { + if stringify, ok := goja.AssertFunction(json.Get("stringify")); ok { + res, err := stringify(json, val) + if err != nil { + panic(err) + } + w.WriteString(res.String()) + } + } + case '%': + w.WriteByte('%') + return false + default: + w.WriteByte('%') + w.WriteRune(f) + return false + } + return true +} diff --git a/script/modules/console/context.go b/script/modules/console/context.go new file mode 100644 index 00000000..cfe522a5 --- /dev/null +++ b/script/modules/console/context.go @@ -0,0 +1,3 @@ +package console + +type Context struct{} diff --git a/script/modules/console/module.go b/script/modules/console/module.go index d0640034..4e7cf0ee 100644 --- a/script/modules/console/module.go +++ b/script/modules/console/module.go @@ -1,108 +1,34 @@ package console import ( - "bytes" - "context" - + "github.com/sagernet/sing-box/script/jsc" "github.com/sagernet/sing-box/script/modules/require" - "github.com/sagernet/sing/common/logger" "github.com/dop251/goja" ) const ModuleName = "console" -type Console struct { - vm *goja.Runtime +type Module struct { + runtime *goja.Runtime + console jsc.Class[*Module, *Console] } -func (c *Console) log(ctx context.Context, p func(ctx context.Context, values ...any)) func(goja.FunctionCall) goja.Value { - return func(call goja.FunctionCall) goja.Value { - var buffer bytes.Buffer - var format string - if arg := call.Argument(0); !goja.IsUndefined(arg) { - format = arg.String() - } - var args []goja.Value - if len(call.Arguments) > 0 { - args = call.Arguments[1:] - } - c.Format(&buffer, format, args...) - p(ctx, buffer.String()) - return nil - } -} - -func (c *Console) Format(b *bytes.Buffer, f string, args ...goja.Value) { - pct := false - argNum := 0 - for _, chr := range f { - if pct { - if argNum < len(args) { - if c.format(chr, args[argNum], b) { - argNum++ - } - } else { - b.WriteByte('%') - b.WriteRune(chr) - } - pct = false - } else { - if chr == '%' { - pct = true - } else { - b.WriteRune(chr) - } - } - } - - for _, arg := range args[argNum:] { - b.WriteByte(' ') - b.WriteString(arg.String()) - } -} - -func (c *Console) format(f rune, val goja.Value, w *bytes.Buffer) bool { - switch f { - case 's': - w.WriteString(val.String()) - case 'd': - w.WriteString(val.ToNumber().String()) - case 'j': - if json, ok := c.vm.Get("JSON").(*goja.Object); ok { - if stringify, ok := goja.AssertFunction(json.Get("stringify")); ok { - res, err := stringify(json, val) - if err != nil { - panic(err) - } - w.WriteString(res.String()) - } - } - case '%': - w.WriteByte('%') - return false - default: - w.WriteByte('%') - w.WriteRune(f) - return false - } - return true -} - -func Require(ctx context.Context, logger logger.ContextLogger) require.ModuleLoader { - return func(runtime *goja.Runtime, module *goja.Object) { - c := &Console{ - vm: runtime, - } - o := module.Get("exports").(*goja.Object) - o.Set("log", c.log(ctx, logger.DebugContext)) - o.Set("error", c.log(ctx, logger.ErrorContext)) - o.Set("warn", c.log(ctx, logger.WarnContext)) - o.Set("info", c.log(ctx, logger.InfoContext)) - o.Set("debug", c.log(ctx, logger.DebugContext)) +func Require(runtime *goja.Runtime, module *goja.Object) { + m := &Module{ + runtime: runtime, } + m.console = createConsole(m) + exports := module.Get("exports").(*goja.Object) + exports.Set("Console", m.console.ToValue()) } func Enable(runtime *goja.Runtime) { - runtime.Set("console", require.Require(runtime, ModuleName)) + exports := require.Require(runtime, ModuleName).ToObject(runtime) + classConsole := jsc.GetClass[*Module, *Console](runtime, exports, "Console") + runtime.Set("console", NewConsole(classConsole)) +} + +func (m *Module) Runtime() *goja.Runtime { + return m.runtime } diff --git a/script/modules/sghttp/module.go b/script/modules/sghttp/module.go deleted file mode 100644 index f1b9cf8d..00000000 --- a/script/modules/sghttp/module.go +++ /dev/null @@ -1,147 +0,0 @@ -package sghttp - -import ( - "bytes" - "context" - "crypto/tls" - "io" - "net/http" - "net/http/cookiejar" - "sync" - "time" - - "github.com/sagernet/sing-box/script/jsc" - F "github.com/sagernet/sing/common/format" - - "github.com/dop251/goja" - "golang.org/x/net/publicsuffix" -) - -type SurgeHTTP struct { - vm *goja.Runtime - ctx context.Context - cookieAccess sync.RWMutex - cookieJar *cookiejar.Jar - errorHandler func(error) -} - -func Enable(vm *goja.Runtime, ctx context.Context, errorHandler func(error)) { - sgHTTP := &SurgeHTTP{ - vm: vm, - ctx: ctx, - errorHandler: errorHandler, - } - httpObject := vm.NewObject() - httpObject.Set("get", sgHTTP.request(http.MethodGet)) - httpObject.Set("post", sgHTTP.request(http.MethodPost)) - httpObject.Set("put", sgHTTP.request(http.MethodPut)) - httpObject.Set("delete", sgHTTP.request(http.MethodDelete)) - httpObject.Set("head", sgHTTP.request(http.MethodHead)) - httpObject.Set("options", sgHTTP.request(http.MethodOptions)) - httpObject.Set("patch", sgHTTP.request(http.MethodPatch)) - httpObject.Set("trace", sgHTTP.request(http.MethodTrace)) - vm.Set("$http", httpObject) -} - -func (s *SurgeHTTP) request(method string) func(call goja.FunctionCall) goja.Value { - return func(call goja.FunctionCall) goja.Value { - if len(call.Arguments) != 2 { - panic(s.vm.NewTypeError("invalid arguments")) - } - var ( - url string - headers http.Header - body []byte - timeout = 5 * time.Second - insecure bool - autoCookie bool - autoRedirect bool - // policy string - binaryMode bool - ) - switch optionsValue := call.Argument(0).(type) { - case goja.String: - url = optionsValue.String() - case *goja.Object: - url = jsc.AssertString(s.vm, optionsValue.Get("url"), "options.url", false) - headers = jsc.AssertHTTPHeader(s.vm, optionsValue.Get("headers"), "option.headers") - body = jsc.AssertStringBinary(s.vm, optionsValue.Get("body"), "options.body", true) - timeoutInt := jsc.AssertInt(s.vm, optionsValue.Get("timeout"), "options.timeout", true) - if timeoutInt > 0 { - timeout = time.Duration(timeoutInt) * time.Second - } - insecure = jsc.AssertBool(s.vm, optionsValue.Get("insecure"), "options.insecure", true) - autoCookie = jsc.AssertBool(s.vm, optionsValue.Get("auto-cookie"), "options.auto-cookie", true) - autoRedirect = jsc.AssertBool(s.vm, optionsValue.Get("auto-redirect"), "options.auto-redirect", true) - // policy = jsc.AssertString(s.vm, optionsValue.Get("policy"), "options.policy", true) - binaryMode = jsc.AssertBool(s.vm, optionsValue.Get("binary-mode"), "options.binary-mode", true) - default: - panic(s.vm.NewTypeError(F.ToString("invalid argument: options: expected string or object, but got ", optionsValue))) - } - callback := jsc.AssertFunction(s.vm, call.Argument(1), "callback") - httpClient := &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecure, - }, - ForceAttemptHTTP2: true, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if autoRedirect { - return nil - } - return http.ErrUseLastResponse - }, - } - if autoCookie { - s.cookieAccess.Lock() - if s.cookieJar == nil { - s.cookieJar, _ = cookiejar.New(&cookiejar.Options{ - PublicSuffixList: publicsuffix.List, - }) - } - httpClient.Jar = s.cookieJar - s.cookieAccess.Lock() - } - request, err := http.NewRequestWithContext(s.ctx, method, url, bytes.NewReader(body)) - if host := headers.Get("Host"); host != "" { - request.Host = host - headers.Del("Host") - } - request.Header = headers - if err != nil { - panic(s.vm.NewGoError(err)) - } - go func() { - response, executeErr := httpClient.Do(request) - if err != nil { - _, err = callback(nil, s.vm.NewGoError(executeErr), nil, nil) - if err != nil { - s.errorHandler(err) - } - return - } - defer response.Body.Close() - var content []byte - content, err = io.ReadAll(response.Body) - if err != nil { - _, err = callback(nil, s.vm.NewGoError(err), nil, nil) - if err != nil { - s.errorHandler(err) - } - } - responseObject := s.vm.NewObject() - responseObject.Set("status", response.StatusCode) - responseObject.Set("headers", jsc.HeadersToValue(s.vm, response.Header)) - var bodyValue goja.Value - if binaryMode { - bodyValue = jsc.NewUint8Array(s.vm, content) - } else { - bodyValue = s.vm.ToValue(string(content)) - } - _, err = callback(nil, nil, responseObject, bodyValue) - }() - return goja.Undefined() - } -} diff --git a/script/modules/sgstore/module.go b/script/modules/sgstore/module.go deleted file mode 100644 index 8fe4bec9..00000000 --- a/script/modules/sgstore/module.go +++ /dev/null @@ -1,76 +0,0 @@ -package sgstore - -import ( - "context" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/script/jsc" - "github.com/sagernet/sing/service" - - "github.com/dop251/goja" -) - -type SurgePersistentStore struct { - vm *goja.Runtime - cacheFile adapter.CacheFile - data map[string]string - tag string -} - -func Enable(vm *goja.Runtime, ctx context.Context) { - object := vm.NewObject() - cacheFile := service.FromContext[adapter.CacheFile](ctx) - tag := vm.Get("$script").(*goja.Object).Get("name").String() - store := &SurgePersistentStore{ - vm: vm, - cacheFile: cacheFile, - tag: tag, - } - if cacheFile == nil { - store.data = make(map[string]string) - } - object.Set("read", store.js_read) - object.Set("write", store.js_write) - vm.Set("$persistentStore", object) -} - -func (s *SurgePersistentStore) js_read(call goja.FunctionCall) goja.Value { - if len(call.Arguments) > 1 { - panic(s.vm.NewTypeError("invalid arguments")) - } - key := jsc.AssertString(s.vm, call.Argument(0), "key", true) - if key == "" { - key = s.tag - } - var value string - if s.cacheFile != nil { - value = s.cacheFile.SurgePersistentStoreRead(key) - } else { - value = s.data[key] - } - if value == "" { - return goja.Null() - } else { - return s.vm.ToValue(value) - } -} - -func (s *SurgePersistentStore) js_write(call goja.FunctionCall) goja.Value { - if len(call.Arguments) == 0 || len(call.Arguments) > 2 { - panic(s.vm.NewTypeError("invalid arguments")) - } - data := jsc.AssertString(s.vm, call.Argument(0), "data", true) - key := jsc.AssertString(s.vm, call.Argument(1), "key", true) - if key == "" { - key = s.tag - } - if s.cacheFile != nil { - err := s.cacheFile.SurgePersistentStoreWrite(key, data) - if err != nil { - panic(s.vm.NewGoError(err)) - } - } else { - s.data[key] = data - } - return goja.Undefined() -} diff --git a/script/modules/sgutils/module.go b/script/modules/sgutils/module.go deleted file mode 100644 index 15152d9d..00000000 --- a/script/modules/sgutils/module.go +++ /dev/null @@ -1,45 +0,0 @@ -package sgutils - -import ( - "bytes" - "compress/gzip" - "io" - - "github.com/sagernet/sing-box/script/jsc" - E "github.com/sagernet/sing/common/exceptions" - - "github.com/dop251/goja" -) - -type SurgeUtils struct { - vm *goja.Runtime -} - -func Enable(runtime *goja.Runtime) { - utils := &SurgeUtils{runtime} - object := runtime.NewObject() - object.Set("geoip", utils.js_stub) - object.Set("ipasn", utils.js_stub) - object.Set("ipaso", utils.js_stub) - object.Set("ungzip", utils.js_ungzip) -} - -func (u *SurgeUtils) js_stub(call goja.FunctionCall) goja.Value { - panic(u.vm.NewGoError(E.New("not implemented"))) -} - -func (u *SurgeUtils) js_ungzip(call goja.FunctionCall) goja.Value { - if len(call.Arguments) != 1 { - panic(u.vm.NewGoError(E.New("invalid argument"))) - } - binary := jsc.AssertBinary(u.vm, call.Argument(0), "binary", false) - reader, err := gzip.NewReader(bytes.NewReader(binary)) - if err != nil { - panic(u.vm.NewGoError(err)) - } - binary, err = io.ReadAll(reader) - if err != nil { - panic(u.vm.NewGoError(err)) - } - return jsc.NewUint8Array(u.vm, binary) -} diff --git a/script/modules/surge/environment.go b/script/modules/surge/environment.go new file mode 100644 index 00000000..590469c6 --- /dev/null +++ b/script/modules/surge/environment.go @@ -0,0 +1,65 @@ +package surge + +import ( + "runtime" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/locale" + "github.com/sagernet/sing-box/script/jsc" + + "github.com/dop251/goja" +) + +type Environment struct { + class jsc.Class[*Module, *Environment] +} + +func createEnvironment(module *Module) jsc.Class[*Module, *Environment] { + class := jsc.NewClass[*Module, *Environment](module) + class.DefineField("system", (*Environment).getSystem, nil) + class.DefineField("surge-build", (*Environment).getSurgeBuild, nil) + class.DefineField("surge-version", (*Environment).getSurgeVersion, nil) + class.DefineField("language", (*Environment).getLanguage, nil) + class.DefineField("device-model", (*Environment).getDeviceModel, nil) + class.DefineMethod("toString", (*Environment).toString) + return class +} + +func (e *Environment) getSystem() any { + switch runtime.GOOS { + case "ios": + return "iOS" + case "darwin": + return "macOS" + case "tvos": + return "tvOS" + case "linux": + return "Linux" + case "android": + return "Android" + case "windows": + return "Windows" + default: + return runtime.GOOS + } +} + +func (e *Environment) getSurgeBuild() any { + return "N/A" +} + +func (e *Environment) getSurgeVersion() any { + return "sing-box " + C.Version +} + +func (e *Environment) getLanguage() any { + return locale.Current().Locale +} + +func (e *Environment) getDeviceModel() any { + return "N/A" +} + +func (e *Environment) toString(call goja.FunctionCall) any { + return "[sing-box Surge environment" +} diff --git a/script/modules/surge/http.go b/script/modules/surge/http.go new file mode 100644 index 00000000..49aef0d8 --- /dev/null +++ b/script/modules/surge/http.go @@ -0,0 +1,150 @@ +package surge + +import ( + "bytes" + "crypto/tls" + "io" + "net/http" + "net/http/cookiejar" + "time" + + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + "github.com/sagernet/sing/common" + F "github.com/sagernet/sing/common/format" + + "github.com/dop251/goja" + "golang.org/x/net/publicsuffix" +) + +type HTTP struct { + class jsc.Class[*Module, *HTTP] + cookieJar *cookiejar.Jar + httpTransport *http.Transport +} + +func createHTTP(module *Module) jsc.Class[*Module, *HTTP] { + class := jsc.NewClass[*Module, *HTTP](module) + class.DefineConstructor(newHTTP) + class.DefineMethod("get", httpRequest(http.MethodGet)) + class.DefineMethod("post", httpRequest(http.MethodPost)) + class.DefineMethod("put", httpRequest(http.MethodPut)) + class.DefineMethod("delete", httpRequest(http.MethodDelete)) + class.DefineMethod("head", httpRequest(http.MethodHead)) + class.DefineMethod("options", httpRequest(http.MethodOptions)) + class.DefineMethod("patch", httpRequest(http.MethodPatch)) + class.DefineMethod("trace", httpRequest(http.MethodTrace)) + class.DefineMethod("toString", (*HTTP).toString) + return class +} + +func newHTTP(class jsc.Class[*Module, *HTTP], call goja.ConstructorCall) *HTTP { + return &HTTP{ + class: class, + cookieJar: common.Must1(cookiejar.New(&cookiejar.Options{ + PublicSuffixList: publicsuffix.List, + })), + httpTransport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &tls.Config{}, + }, + } +} + +func httpRequest(method string) func(s *HTTP, call goja.FunctionCall) any { + return func(s *HTTP, call goja.FunctionCall) any { + if len(call.Arguments) != 2 { + panic(s.class.Runtime().NewTypeError("invalid arguments")) + } + context := boxctx.MustFromRuntime(s.class.Runtime()) + var ( + url string + headers http.Header + body []byte + timeout = 5 * time.Second + insecure bool + autoCookie bool = true + autoRedirect bool + // policy string + binaryMode bool + ) + switch optionsValue := call.Argument(0).(type) { + case goja.String: + url = optionsValue.String() + case *goja.Object: + url = jsc.AssertString(s.class.Runtime(), optionsValue.Get("url"), "options.url", false) + headers = jsc.AssertHTTPHeader(s.class.Runtime(), optionsValue.Get("headers"), "option.headers") + body = jsc.AssertStringBinary(s.class.Runtime(), optionsValue.Get("body"), "options.body", true) + timeoutInt := jsc.AssertInt(s.class.Runtime(), optionsValue.Get("timeout"), "options.timeout", true) + if timeoutInt > 0 { + timeout = time.Duration(timeoutInt) * time.Second + } + insecure = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("insecure"), "options.insecure", true) + autoCookie = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("auto-cookie"), "options.auto-cookie", true) + autoRedirect = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("auto-redirect"), "options.auto-redirect", true) + // policy = jsc.AssertString(s.class.Runtime(), optionsValue.Get("policy"), "options.policy", true) + binaryMode = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("binary-mode"), "options.binary-mode", true) + default: + panic(s.class.Runtime().NewTypeError(F.ToString("invalid argument: options: expected string or object, but got ", optionsValue))) + } + callback := jsc.AssertFunction(s.class.Runtime(), call.Argument(1), "callback") + s.httpTransport.TLSClientConfig.InsecureSkipVerify = insecure + httpClient := &http.Client{ + Timeout: timeout, + Transport: s.httpTransport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if autoRedirect { + return nil + } + return http.ErrUseLastResponse + }, + } + if autoCookie { + httpClient.Jar = s.cookieJar + } + request, err := http.NewRequestWithContext(context.Context, method, url, bytes.NewReader(body)) + if host := headers.Get("Host"); host != "" { + request.Host = host + headers.Del("Host") + } + request.Header = headers + if err != nil { + panic(s.class.Runtime().NewGoError(err)) + } + go func() { + defer s.httpTransport.CloseIdleConnections() + response, executeErr := httpClient.Do(request) + if err != nil { + _, err = callback(nil, s.class.Runtime().NewGoError(executeErr), nil, nil) + if err != nil { + context.ErrorHandler(err) + } + return + } + defer response.Body.Close() + var content []byte + content, err = io.ReadAll(response.Body) + if err != nil { + _, err = callback(nil, s.class.Runtime().NewGoError(err), nil, nil) + if err != nil { + context.ErrorHandler(err) + } + } + responseObject := s.class.Runtime().NewObject() + responseObject.Set("status", response.StatusCode) + responseObject.Set("headers", jsc.HeadersToValue(s.class.Runtime(), response.Header)) + var bodyValue goja.Value + if binaryMode { + bodyValue = jsc.NewUint8Array(s.class.Runtime(), content) + } else { + bodyValue = s.class.Runtime().ToValue(string(content)) + } + _, err = callback(nil, nil, responseObject, bodyValue) + }() + return nil + } +} + +func (h *HTTP) toString(call goja.FunctionCall) any { + return "[sing-box Surge HTTP]" +} diff --git a/script/modules/surge/module.go b/script/modules/surge/module.go new file mode 100644 index 00000000..f3394426 --- /dev/null +++ b/script/modules/surge/module.go @@ -0,0 +1,63 @@ +package surge + +import ( + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/require" + "github.com/sagernet/sing/common" + + "github.com/dop251/goja" +) + +const ModuleName = "surge" + +type Module struct { + runtime *goja.Runtime + classScript jsc.Class[*Module, *Script] + classEnvironment jsc.Class[*Module, *Environment] + classPersistentStore jsc.Class[*Module, *PersistentStore] + classHTTP jsc.Class[*Module, *HTTP] + classUtils jsc.Class[*Module, *Utils] + classNotification jsc.Class[*Module, *Notification] +} + +func Require(runtime *goja.Runtime, module *goja.Object) { + m := &Module{ + runtime: runtime, + } + m.classScript = createScript(m) + m.classEnvironment = createEnvironment(m) + m.classPersistentStore = createPersistentStore(m) + m.classHTTP = createHTTP(m) + m.classUtils = createUtils(m) + m.classNotification = createNotification(m) + exports := module.Get("exports").(*goja.Object) + exports.Set("Script", m.classScript.ToValue()) + exports.Set("Environment", m.classEnvironment.ToValue()) + exports.Set("PersistentStore", m.classPersistentStore.ToValue()) + exports.Set("HTTP", m.classHTTP.ToValue()) + exports.Set("Utils", m.classUtils.ToValue()) + exports.Set("Notification", m.classNotification.ToValue()) +} + +func Enable(runtime *goja.Runtime, scriptType string, args []string) { + exports := require.Require(runtime, ModuleName).ToObject(runtime) + classScript := jsc.GetClass[*Module, *Script](runtime, exports, "Script") + classEnvironment := jsc.GetClass[*Module, *Environment](runtime, exports, "Environment") + classPersistentStore := jsc.GetClass[*Module, *PersistentStore](runtime, exports, "PersistentStore") + classHTTP := jsc.GetClass[*Module, *HTTP](runtime, exports, "HTTP") + classUtils := jsc.GetClass[*Module, *Utils](runtime, exports, "Utils") + classNotification := jsc.GetClass[*Module, *Notification](runtime, exports, "Notification") + runtime.Set("$script", classScript.New(&Script{class: classScript, ScriptType: scriptType})) + runtime.Set("$environment", classEnvironment.New(&Environment{class: classEnvironment})) + runtime.Set("$persistentStore", newPersistentStore(classPersistentStore)) + runtime.Set("$http", classHTTP.New(newHTTP(classHTTP, goja.ConstructorCall{}))) + runtime.Set("$utils", classUtils.New(&Utils{class: classUtils})) + runtime.Set("$notification", newNotification(classNotification)) + runtime.Set("$argument", runtime.NewArray(common.Map(args, func(it string) any { + return it + })...)) +} + +func (m *Module) Runtime() *goja.Runtime { + return m.runtime +} diff --git a/script/modules/surge/notification.go b/script/modules/surge/notification.go new file mode 100644 index 00000000..4f330388 --- /dev/null +++ b/script/modules/surge/notification.go @@ -0,0 +1,120 @@ +package surge + +import ( + "encoding/base64" + "strings" + + "github.com/sagernet/sing-box/experimental/libbox/platform" + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/service" + + "github.com/dop251/goja" +) + +type Notification struct { + class jsc.Class[*Module, *Notification] + logger logger.ContextLogger + tag string + platformInterface platform.Interface +} + +func createNotification(module *Module) jsc.Class[*Module, *Notification] { + class := jsc.NewClass[*Module, *Notification](module) + class.DefineMethod("post", (*Notification).post) + class.DefineMethod("toString", (*Notification).toString) + return class +} + +func newNotification(class jsc.Class[*Module, *Notification]) goja.Value { + context := boxctx.MustFromRuntime(class.Runtime()) + return class.New(&Notification{ + class: class, + logger: context.Logger, + tag: context.Tag, + platformInterface: service.FromContext[platform.Interface](context.Context), + }) +} + +func (s *Notification) post(call goja.FunctionCall) any { + var ( + title string + subtitle string + body string + openURL string + clipboard string + mediaURL string + mediaData []byte + mediaType string + autoDismiss int + ) + title = jsc.AssertString(s.class.Runtime(), call.Argument(0), "title", true) + subtitle = jsc.AssertString(s.class.Runtime(), call.Argument(1), "subtitle", true) + body = jsc.AssertString(s.class.Runtime(), call.Argument(2), "body", true) + options := jsc.AssertObject(s.class.Runtime(), call.Argument(3), "options", true) + if options != nil { + action := jsc.AssertString(s.class.Runtime(), options.Get("action"), "options.action", true) + switch action { + case "open-url": + openURL = jsc.AssertString(s.class.Runtime(), options.Get("url"), "options.url", false) + case "clipboard": + clipboard = jsc.AssertString(s.class.Runtime(), options.Get("clipboard"), "options.clipboard", false) + } + mediaURL = jsc.AssertString(s.class.Runtime(), options.Get("media-url"), "options.media-url", true) + mediaBase64 := jsc.AssertString(s.class.Runtime(), options.Get("media-base64"), "options.media-base64", true) + if mediaBase64 != "" { + mediaBinary, err := base64.StdEncoding.DecodeString(mediaBase64) + if err != nil { + panic(s.class.Runtime().NewGoError(E.Cause(err, "decode media-base64"))) + } + mediaData = mediaBinary + mediaType = jsc.AssertString(s.class.Runtime(), options.Get("media-base64-mime"), "options.media-base64-mime", false) + } + autoDismiss = int(jsc.AssertInt(s.class.Runtime(), options.Get("auto-dismiss"), "options.auto-dismiss", true)) + } + if title != "" && subtitle == "" && body == "" { + body = title + title = "" + } else if title != "" && subtitle != "" && body == "" { + body = subtitle + subtitle = "" + } + var builder strings.Builder + if title != "" { + builder.WriteString("[") + builder.WriteString(title) + if subtitle != "" { + builder.WriteString(" - ") + builder.WriteString(subtitle) + } + builder.WriteString("]: ") + } + builder.WriteString(body) + s.logger.Info("notification: " + builder.String()) + if s.platformInterface != nil { + err := s.platformInterface.SendNotification(&platform.Notification{ + Identifier: "surge-script-notification-" + s.tag, + TypeName: "Surge Script Notification (" + s.tag + ")", + TypeID: 11, + Title: title, + Subtitle: subtitle, + Body: body, + OpenURL: openURL, + Clipboard: clipboard, + MediaURL: mediaURL, + MediaData: mediaData, + MediaType: mediaType, + Timeout: autoDismiss, + }) + if err != nil { + s.logger.Error(E.Cause(err, "send notification")) + } + } + return nil +} + +func (s *Notification) toString(call goja.FunctionCall) any { + return "[sing-box Surge notification]" +} diff --git a/script/modules/surge/persistent_store.go b/script/modules/surge/persistent_store.go new file mode 100644 index 00000000..7c40f2fa --- /dev/null +++ b/script/modules/surge/persistent_store.go @@ -0,0 +1,78 @@ +package surge + +import ( + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + "github.com/sagernet/sing/service" + + "github.com/dop251/goja" +) + +type PersistentStore struct { + class jsc.Class[*Module, *PersistentStore] + cacheFile adapter.CacheFile + inMemoryCache *adapter.SurgeInMemoryCache + tag string +} + +func createPersistentStore(module *Module) jsc.Class[*Module, *PersistentStore] { + class := jsc.NewClass[*Module, *PersistentStore](module) + class.DefineMethod("get", (*PersistentStore).get) + class.DefineMethod("set", (*PersistentStore).set) + class.DefineMethod("toString", (*PersistentStore).toString) + return class +} + +func newPersistentStore(class jsc.Class[*Module, *PersistentStore]) goja.Value { + boxCtx := boxctx.MustFromRuntime(class.Runtime()) + return class.New(&PersistentStore{ + class: class, + cacheFile: service.FromContext[adapter.CacheFile](boxCtx.Context), + inMemoryCache: service.FromContext[adapter.ScriptManager](boxCtx.Context).SurgeCache(), + tag: boxCtx.Tag, + }) +} + +func (s *PersistentStore) get(call goja.FunctionCall) any { + key := jsc.AssertString(s.class.Runtime(), call.Argument(0), "key", true) + if key == "" { + key = s.tag + } + var value string + if s.cacheFile != nil { + value = s.cacheFile.SurgePersistentStoreRead(key) + } else { + s.inMemoryCache.RLock() + value = s.inMemoryCache.Data[key] + s.inMemoryCache.RUnlock() + } + if value == "" { + return goja.Null() + } else { + return value + } +} + +func (s *PersistentStore) set(call goja.FunctionCall) any { + data := jsc.AssertString(s.class.Runtime(), call.Argument(0), "data", true) + key := jsc.AssertString(s.class.Runtime(), call.Argument(1), "key", true) + if key == "" { + key = s.tag + } + if s.cacheFile != nil { + err := s.cacheFile.SurgePersistentStoreWrite(key, data) + if err != nil { + panic(s.class.Runtime().NewGoError(err)) + } + } else { + s.inMemoryCache.Lock() + s.inMemoryCache.Data[key] = data + s.inMemoryCache.Unlock() + } + return goja.Undefined() +} + +func (s *PersistentStore) toString(call goja.FunctionCall) any { + return "[sing-box Surge persistentStore]" +} diff --git a/script/modules/surge/script.go b/script/modules/surge/script.go new file mode 100644 index 00000000..de106ec8 --- /dev/null +++ b/script/modules/surge/script.go @@ -0,0 +1,32 @@ +package surge + +import ( + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/boxctx" + F "github.com/sagernet/sing/common/format" +) + +type Script struct { + class jsc.Class[*Module, *Script] + ScriptType string +} + +func createScript(module *Module) jsc.Class[*Module, *Script] { + class := jsc.NewClass[*Module, *Script](module) + class.DefineField("name", (*Script).getName, nil) + class.DefineField("type", (*Script).getType, nil) + class.DefineField("startTime", (*Script).getStartTime, nil) + return class +} + +func (s *Script) getName() any { + return F.ToString("script:", boxctx.MustFromRuntime(s.class.Runtime()).Tag) +} + +func (s *Script) getType() any { + return s.ScriptType +} + +func (s *Script) getStartTime() any { + return boxctx.MustFromRuntime(s.class.Runtime()).StartedAt +} diff --git a/script/modules/surge/utils.go b/script/modules/surge/utils.go new file mode 100644 index 00000000..9320ab1c --- /dev/null +++ b/script/modules/surge/utils.go @@ -0,0 +1,50 @@ +package surge + +import ( + "bytes" + "compress/gzip" + "io" + + "github.com/sagernet/sing-box/script/jsc" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/dop251/goja" +) + +type Utils struct { + class jsc.Class[*Module, *Utils] +} + +func createUtils(module *Module) jsc.Class[*Module, *Utils] { + class := jsc.NewClass[*Module, *Utils](module) + class.DefineMethod("geoip", (*Utils).stub) + class.DefineMethod("ipasn", (*Utils).stub) + class.DefineMethod("ipaso", (*Utils).stub) + class.DefineMethod("ungzip", (*Utils).ungzip) + class.DefineMethod("toString", (*Utils).toString) + return class +} + +func (u *Utils) stub(call goja.FunctionCall) any { + return nil +} + +func (u *Utils) ungzip(call goja.FunctionCall) any { + if len(call.Arguments) != 1 { + panic(u.class.Runtime().NewGoError(E.New("invalid argument"))) + } + binary := jsc.AssertBinary(u.class.Runtime(), call.Argument(0), "binary", false) + reader, err := gzip.NewReader(bytes.NewReader(binary)) + if err != nil { + panic(u.class.Runtime().NewGoError(err)) + } + binary, err = io.ReadAll(reader) + if err != nil { + panic(u.class.Runtime().NewGoError(err)) + } + return jsc.NewUint8Array(u.class.Runtime(), binary) +} + +func (u *Utils) toString(call goja.FunctionCall) any { + return "[sing-box Surge utils]" +} diff --git a/script/modules/url/escape.go b/script/modules/url/escape.go new file mode 100644 index 00000000..93c8ab1b --- /dev/null +++ b/script/modules/url/escape.go @@ -0,0 +1,55 @@ +package url + +import "strings" + +var tblEscapeURLQuery = [128]byte{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, +} + +// The code below is mostly borrowed from the standard Go url package + +const upperhex = "0123456789ABCDEF" + +func escape(s string, table *[128]byte, spaceToPlus bool) string { + spaceCount, hexCount := 0, 0 + for i := 0; i < len(s); i++ { + c := s[i] + if c > 127 || table[c] == 0 { + if c == ' ' && spaceToPlus { + spaceCount++ + } else { + hexCount++ + } + } + } + + if spaceCount == 0 && hexCount == 0 { + return s + } + + var sb strings.Builder + hexBuf := [3]byte{'%', 0, 0} + + sb.Grow(len(s) + 2*hexCount) + + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case c == ' ' && spaceToPlus: + sb.WriteByte('+') + case c > 127 || table[c] == 0: + hexBuf[1] = upperhex[c>>4] + hexBuf[2] = upperhex[c&15] + sb.Write(hexBuf[:]) + default: + sb.WriteByte(c) + } + } + return sb.String() +} diff --git a/script/modules/url/module.go b/script/modules/url/module.go new file mode 100644 index 00000000..11b4b6c4 --- /dev/null +++ b/script/modules/url/module.go @@ -0,0 +1,41 @@ +package url + +import ( + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/require" + + "github.com/dop251/goja" +) + +const ModuleName = "url" + +var _ jsc.Module = (*Module)(nil) + +type Module struct { + runtime *goja.Runtime + classURL jsc.Class[*Module, *URL] + classURLSearchParams jsc.Class[*Module, *URLSearchParams] + classURLSearchParamsIterator jsc.Class[*Module, *jsc.Iterator[*Module, searchParam]] +} + +func Require(runtime *goja.Runtime, module *goja.Object) { + m := &Module{ + runtime: runtime, + } + m.classURL = createURL(m) + m.classURLSearchParams = createURLSearchParams(m) + m.classURLSearchParamsIterator = jsc.CreateIterator[*Module, searchParam](m) + exports := module.Get("exports").(*goja.Object) + exports.Set("URL", m.classURL.ToValue()) + exports.Set("URLSearchParams", m.classURLSearchParams.ToValue()) +} + +func Enable(runtime *goja.Runtime) { + exports := require.Require(runtime, ModuleName).ToObject(runtime) + runtime.Set("URL", exports.Get("URL")) + runtime.Set("URLSearchParams", exports.Get("URLSearchParams")) +} + +func (m *Module) Runtime() *goja.Runtime { + return m.runtime +} diff --git a/script/modules/url/module_test.go b/script/modules/url/module_test.go new file mode 100644 index 00000000..2b38a40d --- /dev/null +++ b/script/modules/url/module_test.go @@ -0,0 +1,37 @@ +package url_test + +import ( + _ "embed" + "testing" + + "github.com/sagernet/sing-box/script/jstest" + "github.com/sagernet/sing-box/script/modules/url" + + "github.com/dop251/goja" +) + +var ( + //go:embed testdata/url_test.js + urlTest string + + //go:embed testdata/url_search_params_test.js + urlSearchParamsTest string +) + +func TestURL(t *testing.T) { + registry := jstest.NewRegistry() + registry.RegisterNodeModule(url.ModuleName, url.Require) + vm := goja.New() + registry.Enable(vm) + url.Enable(vm) + vm.RunScript("url_test.js", urlTest) +} + +func TestURLSearchParams(t *testing.T) { + registry := jstest.NewRegistry() + registry.RegisterNodeModule(url.ModuleName, url.Require) + vm := goja.New() + registry.Enable(vm) + url.Enable(vm) + vm.RunScript("url_search_params_test.js", urlSearchParamsTest) +} diff --git a/script/modules/url/testdata/url_search_params_test.js b/script/modules/url/testdata/url_search_params_test.js new file mode 100644 index 00000000..4c4897c3 --- /dev/null +++ b/script/modules/url/testdata/url_search_params_test.js @@ -0,0 +1,385 @@ +"use strict"; + +const assert = require("assert.js"); + +let params; + +function testCtor(value, expected) { + assert.sameValue(new URLSearchParams(value).toString(), expected); +} + +testCtor("user=abc&query=xyz", "user=abc&query=xyz"); +testCtor("?user=abc&query=xyz", "user=abc&query=xyz"); + +testCtor( + { + num: 1, + user: "abc", + query: ["first", "second"], + obj: { prop: "value" }, + b: true, + }, + "num=1&user=abc&query=first%2Csecond&obj=%5Bobject+Object%5D&b=true" +); + +const map = new Map(); +map.set("user", "abc"); +map.set("query", "xyz"); +testCtor(map, "user=abc&query=xyz"); + +testCtor( + [ + ["user", "abc"], + ["query", "first"], + ["query", "second"], + ], + "user=abc&query=first&query=second" +); + +// Each key-value pair must have exactly two elements +assert.throwsNodeError(() => new URLSearchParams([["single_value"]]), TypeError, "ERR_INVALID_TUPLE"); +assert.throwsNodeError(() => new URLSearchParams([["too", "many", "values"]]), TypeError, "ERR_INVALID_TUPLE"); + +params = new URLSearchParams("a=b&cc=d"); +params.forEach((value, name, searchParams) => { + if (name === "a") { + assert.sameValue(value, "b"); + } + if (name === "cc") { + assert.sameValue(value, "d"); + } + assert.sameValue(searchParams, params); +}); + +params.forEach((value, name, searchParams) => { + if (name === "a") { + assert.sameValue(value, "b"); + searchParams.set("cc", "d1"); + } + if (name === "cc") { + assert.sameValue(value, "d1"); + } + assert.sameValue(searchParams, params); +}); + +assert.throwsNodeError(() => params.forEach(123), TypeError, "ERR_INVALID_ARG_TYPE"); + +assert.throwsNodeError(() => params.forEach.call(1, 2), TypeError, "ERR_INVALID_THIS"); + +params = new URLSearchParams("a=1=2&b=3"); +assert.sameValue(params.size, 2); +assert.sameValue(params.get("a"), "1=2"); +assert.sameValue(params.get("b"), "3"); + +params = new URLSearchParams("&"); +assert.sameValue(params.size, 0); + +params = new URLSearchParams("& "); +assert.sameValue(params.size, 1); +assert.sameValue(params.get(" "), ""); + +params = new URLSearchParams(" &"); +assert.sameValue(params.size, 1); +assert.sameValue(params.get(" "), ""); + +params = new URLSearchParams("="); +assert.sameValue(params.size, 1); +assert.sameValue(params.get(""), ""); + +params = new URLSearchParams("&=2"); +assert.sameValue(params.size, 1); +assert.sameValue(params.get(""), "2"); + +params = new URLSearchParams("?user=abc"); +assert.throwsNodeError(() => params.append(), TypeError, "ERR_MISSING_ARGS"); +params.append("query", "first"); +assert.sameValue(params.toString(), "user=abc&query=first"); + +params = new URLSearchParams("first=one&second=two&third=three"); +assert.throwsNodeError(() => params.delete(), TypeError, "ERR_MISSING_ARGS"); +params.delete("second", "fake-value"); +assert.sameValue(params.toString(), "first=one&second=two&third=three"); +params.delete("third", "three"); +assert.sameValue(params.toString(), "first=one&second=two"); +params.delete("second"); +assert.sameValue(params.toString(), "first=one"); + +params = new URLSearchParams("user=abc&query=xyz"); +assert.throwsNodeError(() => params.get(), TypeError, "ERR_MISSING_ARGS"); +assert.sameValue(params.get("user"), "abc"); +assert.sameValue(params.get("non-existant"), null); + +params = new URLSearchParams("query=first&query=second"); +assert.throwsNodeError(() => params.getAll(), TypeError, "ERR_MISSING_ARGS"); +const all = params.getAll("query"); +assert.sameValue(all.includes("first"), true); +assert.sameValue(all.includes("second"), true); +assert.sameValue(all.length, 2); +const getAllUndefined = params.getAll(undefined); +assert.sameValue(getAllUndefined.length, 0); +const getAllNonExistant = params.getAll("does_not_exists"); +assert.sameValue(getAllNonExistant.length, 0); + +params = new URLSearchParams("user=abc&query=xyz"); +assert.throwsNodeError(() => params.has(), TypeError, "ERR_MISSING_ARGS"); +assert.sameValue(params.has(undefined), false); +assert.sameValue(params.has("user"), true); +assert.sameValue(params.has("user", "abc"), true); +assert.sameValue(params.has("user", "abc", "extra-param"), true); +assert.sameValue(params.has("user", "efg"), false); +assert.sameValue(params.has("user", undefined), true); + +params = new URLSearchParams(); +params.append("foo", "bar"); +params.append("foo", "baz"); +params.append("abc", "def"); +assert.sameValue(params.toString(), "foo=bar&foo=baz&abc=def"); +params.set("foo", "def"); +params.set("xyz", "opq"); +assert.sameValue(params.toString(), "foo=def&abc=def&xyz=opq"); + +params = new URLSearchParams("query=first&query=second&user=abc&double=first,second"); +const URLSearchIteratorPrototype = params.entries().__proto__; +assert.sameValue(typeof URLSearchIteratorPrototype, "object"); + +assert.sameValue(params[Symbol.iterator], params.entries); + +{ + const entries = params.entries(); + assert.sameValue(entries.toString(), "[object URLSearchParams Iterator]"); + assert.sameValue(entries.__proto__, URLSearchIteratorPrototype); + + let item = entries.next(); + assert.sameValue(item.value.toString(), ["query", "first"].toString()); + assert.sameValue(item.done, false); + + item = entries.next(); + assert.sameValue(item.value.toString(), ["query", "second"].toString()); + assert.sameValue(item.done, false); + + item = entries.next(); + assert.sameValue(item.value.toString(), ["user", "abc"].toString()); + assert.sameValue(item.done, false); + + item = entries.next(); + assert.sameValue(item.value.toString(), ["double", "first,second"].toString()); + assert.sameValue(item.done, false); + + item = entries.next(); + assert.sameValue(item.value, undefined); + assert.sameValue(item.done, true); +} + +params = new URLSearchParams("query=first&query=second&user=abc"); +{ + const keys = params.keys(); + assert.sameValue(keys.__proto__, URLSearchIteratorPrototype); + + let item = keys.next(); + assert.sameValue(item.value, "query"); + assert.sameValue(item.done, false); + + item = keys.next(); + assert.sameValue(item.value, "query"); + assert.sameValue(item.done, false); + + item = keys.next(); + assert.sameValue(item.value, "user"); + assert.sameValue(item.done, false); + + item = keys.next(); + assert.sameValue(item.value, undefined); + assert.sameValue(item.done, true); +} + +params = new URLSearchParams("query=first&query=second&user=abc"); +{ + const values = params.values(); + assert.sameValue(values.__proto__, URLSearchIteratorPrototype); + + let item = values.next(); + assert.sameValue(item.value, "first"); + assert.sameValue(item.done, false); + + item = values.next(); + assert.sameValue(item.value, "second"); + assert.sameValue(item.done, false); + + item = values.next(); + assert.sameValue(item.value, "abc"); + assert.sameValue(item.done, false); + + item = values.next(); + assert.sameValue(item.value, undefined); + assert.sameValue(item.done, true); +} + + +params = new URLSearchParams("query[]=abc&type=search&query[]=123"); +params.sort(); +assert.sameValue(params.toString(), "query%5B%5D=abc&query%5B%5D=123&type=search"); + +params = new URLSearchParams("query=first&query=second&user=abc"); +assert.sameValue(params.size, 3); + +params = new URLSearchParams("%"); +assert.sameValue(params.has("%"), true); +assert.sameValue(params.toString(), "%25="); + +{ + const params = new URLSearchParams(""); + assert.sameValue(params.size, 0); + assert.sameValue(params.toString(), ""); + assert.sameValue(params.get(undefined), null); + params.set(undefined, true); + assert.sameValue(params.has(undefined), true); + assert.sameValue(params.has("undefined"), true); + assert.sameValue(params.get("undefined"), "true"); + assert.sameValue(params.get(undefined), "true"); + assert.sameValue(params.getAll(undefined).toString(), ["true"].toString()); + params.delete(undefined); + assert.sameValue(params.has(undefined), false); + assert.sameValue(params.has("undefined"), false); + + assert.sameValue(params.has(null), false); + params.set(null, "nullval"); + assert.sameValue(params.has(null), true); + assert.sameValue(params.has("null"), true); + assert.sameValue(params.get(null), "nullval"); + assert.sameValue(params.get("null"), "nullval"); + params.delete(null); + assert.sameValue(params.has(null), false); + assert.sameValue(params.has("null"), false); +} + +function* functionGeneratorExample() { + yield ["user", "abc"]; + yield ["query", "first"]; + yield ["query", "second"]; +} + +params = new URLSearchParams(functionGeneratorExample()); +assert.sameValue(params.toString(), "user=abc&query=first&query=second"); + +assert.sameValue(params.__proto__.constructor, URLSearchParams); +assert.sameValue(params instanceof URLSearchParams, true); + +{ + const params = new URLSearchParams("1=2&1=3"); + assert.sameValue(params.get(1), "2"); + assert.sameValue(params.getAll(1).toString(), ["2", "3"].toString()); + assert.sameValue(params.getAll("x").toString(), [].toString()); +} + +// Sync +{ + const url = new URL("https://test.com/"); + const params = url.searchParams; + assert.sameValue(params.size, 0); + url.search = "a=1"; + assert.sameValue(params.size, 1); + assert.sameValue(params.get("a"), "1"); +} + +{ + const url = new URL("https://test.com/?a=1"); + const params = url.searchParams; + assert.sameValue(params.size, 1); + url.search = ""; + assert.sameValue(params.size, 0); + url.search = "b=2"; + assert.sameValue(params.size, 1); +} + +{ + const url = new URL("https://test.com/"); + const params = url.searchParams; + params.append("a", "1"); + assert.sameValue(url.toString(), "https://test.com/?a=1"); +} + +{ + const url = new URL("https://test.com/"); + url.searchParams.append("a", "1"); + url.searchParams.append("b", "1"); + assert.sameValue(url.toString(), "https://test.com/?a=1&b=1"); +} + +{ + const url = new URL("https://test.com/"); + const params = url.searchParams; + url.searchParams.append("a", "1"); + assert.sameValue(url.search, "?a=1"); +} + +{ + const url = new URL("https://test.com/?a=1"); + const params = url.searchParams; + params.append("a", "2"); + assert.sameValue(url.search, "?a=1&a=2"); +} + +{ + const url = new URL("https://test.com/"); + const params = url.searchParams; + params.set("a", "1"); + assert.sameValue(url.search, "?a=1"); +} + +{ + const url = new URL("https://test.com/"); + url.searchParams.set("a", "1"); + url.searchParams.set("b", "1"); + assert.sameValue(url.toString(), "https://test.com/?a=1&b=1"); +} + +{ + const url = new URL("https://test.com/?a=1&b=2"); + const params = url.searchParams; + params.delete("a"); + assert.sameValue(url.search, "?b=2"); +} + +{ + const url = new URL("https://test.com/?b=2&a=1"); + const params = url.searchParams; + params.sort(); + assert.sameValue(url.search, "?a=1&b=2"); +} + +{ + const url = new URL("https://test.com/?a=1"); + const params = url.searchParams; + params.delete("a"); + assert.sameValue(url.search, ""); + + params.set("a", 2); + assert.sameValue(url.search, "?a=2"); +} + +// FAILING: no custom properties on wrapped Go structs +/* +{ + const params = new URLSearchParams(""); + assert.sameValue(Object.isExtensible(params), true); + assert.sameValue(Reflect.defineProperty(params, "customField", {value: 42, configurable: true}), true); + assert.sameValue(params.customField, 42); + const desc = Reflect.getOwnPropertyDescriptor(params, "customField"); + assert.sameValue(desc.value, 42); + assert.sameValue(desc.writable, false); + assert.sameValue(desc.enumerable, false); + assert.sameValue(desc.configurable, true); +} +*/ + +// Escape +{ + const myURL = new URL('https://example.org/abc?fo~o=~ba r%z'); + + assert.sameValue(myURL.search, "?fo~o=~ba%20r%z"); + + // Modify the URL via searchParams... + myURL.searchParams.sort(); + + assert.sameValue(myURL.search, "?fo%7Eo=%7Eba+r%25z"); +} diff --git a/script/modules/url/testdata/url_test.js b/script/modules/url/testdata/url_test.js new file mode 100644 index 00000000..a6ff43be --- /dev/null +++ b/script/modules/url/testdata/url_test.js @@ -0,0 +1,229 @@ +"use strict"; + +const assert = require("assert.js"); + +function testURLCtor(str, expected) { + assert.sameValue(new URL(str).toString(), expected); +} + +function testURLCtorBase(ref, base, expected, message) { + assert.sameValue(new URL(ref, base).toString(), expected, message); +} + +testURLCtorBase("https://example.org/", undefined, "https://example.org/"); +testURLCtorBase("/foo", "https://example.org/", "https://example.org/foo"); +testURLCtorBase("http://Example.com/", "https://example.org/", "http://example.com/"); +testURLCtorBase("https://Example.com/", "https://example.org/", "https://example.com/"); +testURLCtorBase("foo://Example.com/", "https://example.org/", "foo://Example.com/"); +testURLCtorBase("foo:Example.com/", "https://example.org/", "foo:Example.com/"); +testURLCtorBase("#hash", "https://example.org/", "https://example.org/#hash"); + +testURLCtor("HTTP://test.com", "http://test.com/"); +testURLCtor("HTTPS://á.com", "https://xn--1ca.com/"); +testURLCtor("HTTPS://á.com:123", "https://xn--1ca.com:123/"); +testURLCtor("https://test.com#asdfá", "https://test.com/#asdf%C3%A1"); +testURLCtor("HTTPS://á.com:123/á", "https://xn--1ca.com:123/%C3%A1"); +testURLCtor("fish://á.com", "fish://%C3%A1.com"); +testURLCtor("https://test.com/?a=1 /2", "https://test.com/?a=1%20/2"); +testURLCtor("https://test.com/á=1?á=1&ü=2#é", "https://test.com/%C3%A1=1?%C3%A1=1&%C3%BC=2#%C3%A9"); + +assert.throws(() => new URL("test"), TypeError); +assert.throws(() => new URL("ssh://EEE:ddd"), TypeError); + +{ + let u = new URL("https://example.org/"); + assert.sameValue(u.__proto__.constructor, URL); + assert.sameValue(u instanceof URL, true); +} + +{ + let u = new URL("https://example.org/"); + assert.sameValue(u.searchParams, u.searchParams); +} + +let myURL; + +// Hash +myURL = new URL("https://example.org/foo#bar"); +myURL.hash = "baz"; +assert.sameValue(myURL.href, "https://example.org/foo#baz"); + +myURL.hash = "#baz"; +assert.sameValue(myURL.href, "https://example.org/foo#baz"); + +myURL.hash = "#á=1 2"; +assert.sameValue(myURL.href, "https://example.org/foo#%C3%A1=1%202"); + +myURL.hash = "#a/#b"; +// FAILING: the second # gets escaped +//assert.sameValue(myURL.href, "https://example.org/foo#a/#b"); +assert.sameValue(myURL.search, ""); +// FAILING: the second # gets escaped +//assert.sameValue(myURL.hash, "#a/#b"); + +// Host +myURL = new URL("https://example.org:81/foo"); +myURL.host = "example.com:82"; +assert.sameValue(myURL.href, "https://example.com:82/foo"); + +// Hostname +myURL = new URL("https://example.org:81/foo"); +myURL.hostname = "example.com:82"; +assert.sameValue(myURL.href, "https://example.org:81/foo"); + +myURL.hostname = "á.com"; +assert.sameValue(myURL.href, "https://xn--1ca.com:81/foo"); + +// href +myURL = new URL("https://example.org/foo"); +myURL.href = "https://example.com/bar"; +assert.sameValue(myURL.href, "https://example.com/bar"); + +// Password +myURL = new URL("https://abc:xyz@example.com"); +myURL.password = "123"; +assert.sameValue(myURL.href, "https://abc:123@example.com/"); + +// pathname +myURL = new URL("https://example.org/abc/xyz?123"); +myURL.pathname = "/abcdef"; +assert.sameValue(myURL.href, "https://example.org/abcdef?123"); + +myURL.pathname = ""; +assert.sameValue(myURL.href, "https://example.org/?123"); + +myURL.pathname = "á"; +assert.sameValue(myURL.pathname, "/%C3%A1"); +assert.sameValue(myURL.href, "https://example.org/%C3%A1?123"); + +// port + +myURL = new URL("https://example.org:8888"); +assert.sameValue(myURL.port, "8888"); + +function testSetPort(port, expected) { + const url = new URL("https://example.org:8888"); + url.port = port; + assert.sameValue(url.port, expected); +} + +testSetPort(0, "0"); +testSetPort(-0, "0"); + +// Default ports are automatically transformed to the empty string +// (HTTPS protocol's default port is 443) +testSetPort("443", ""); +testSetPort(443, ""); + +// Empty string is the same as default port +testSetPort("", ""); + +// Completely invalid port strings are ignored +testSetPort("abcd", "8888"); +testSetPort("-123", ""); +testSetPort(-123, ""); +testSetPort(-123.45, ""); +testSetPort(undefined, "8888"); +testSetPort(null, "8888"); +testSetPort(+Infinity, "8888"); +testSetPort(-Infinity, "8888"); +testSetPort(NaN, "8888"); + +// Leading numbers are treated as a port number +testSetPort("5678abcd", "5678"); +testSetPort("a5678abcd", ""); + +// Non-integers are truncated +testSetPort(1234.5678, "1234"); + +// Out-of-range numbers which are not represented in scientific notation +// will be ignored. +testSetPort(1e10, "8888"); +testSetPort("123456", "8888"); +testSetPort(123456, "8888"); +testSetPort(4.567e21, "4"); + +// toString() takes precedence over valueOf(), even if it returns a valid integer +testSetPort( + { + toString() { + return "2"; + }, + valueOf() { + return 1; + }, + }, + "2" +); + +// Protocol +function testSetProtocol(url, protocol, expected) { + url.protocol = protocol; + assert.sameValue(url.protocol, expected); +} +testSetProtocol(new URL("https://example.org"), "ftp", "ftp:"); +testSetProtocol(new URL("https://example.org"), "ftp:", "ftp:"); +testSetProtocol(new URL("https://example.org"), "FTP:", "ftp:"); +testSetProtocol(new URL("https://example.org"), "ftp: blah", "ftp:"); +// special to non-special +testSetProtocol(new URL("https://example.org"), "foo", "https:"); +// non-special to special +testSetProtocol(new URL("fish://example.org"), "https", "fish:"); + +// Search +myURL = new URL("https://example.org/abc?123"); +myURL.search = "abc=xyz"; +assert.sameValue(myURL.href, "https://example.org/abc?abc=xyz"); + +myURL.search = "a=1 2"; +assert.sameValue(myURL.href, "https://example.org/abc?a=1%202"); + +myURL.search = "á=ú"; +assert.sameValue(myURL.search, "?%C3%A1=%C3%BA"); +assert.sameValue(myURL.href, "https://example.org/abc?%C3%A1=%C3%BA"); + +myURL.hash = "hash"; +myURL.search = "a=#b"; +assert.sameValue(myURL.href, "https://example.org/abc?a=%23b#hash"); +assert.sameValue(myURL.search, "?a=%23b"); +assert.sameValue(myURL.hash, "#hash"); + +// Username +myURL = new URL("https://abc:xyz@example.com/"); +myURL.username = "123"; +assert.sameValue(myURL.href, "https://123:xyz@example.com/"); + +// Origin, read-only +assert.throws(() => { + myURL.origin = "abc"; +}, TypeError); + +// href +myURL = new URL("https://example.org"); +myURL.href = "https://example.com"; +assert.sameValue(myURL.href, "https://example.com/"); + +assert.throws(() => { + myURL.href = "test"; +}, TypeError); + +// Search Params +myURL = new URL("https://example.com/"); +myURL.searchParams.append("user", "abc"); +assert.sameValue(myURL.toString(), "https://example.com/?user=abc"); +myURL.searchParams.append("first", "one"); +assert.sameValue(myURL.toString(), "https://example.com/?user=abc&first=one"); +myURL.searchParams.delete("user"); +assert.sameValue(myURL.toString(), "https://example.com/?first=one"); + +{ + const url = require("url"); + + assert.sameValue(url.domainToASCII('español.com'), "xn--espaol-zwa.com"); + assert.sameValue(url.domainToASCII('中文.com'), "xn--fiq228c.com"); + assert.sameValue(url.domainToASCII('xn--iñvalid.com'), ""); + + assert.sameValue(url.domainToUnicode('xn--espaol-zwa.com'), "español.com"); + assert.sameValue(url.domainToUnicode('xn--fiq228c.com'), "中文.com"); + assert.sameValue(url.domainToUnicode('xn--iñvalid.com'), ""); +} diff --git a/script/modules/url/url.go b/script/modules/url/url.go new file mode 100644 index 00000000..7b442ded --- /dev/null +++ b/script/modules/url/url.go @@ -0,0 +1,315 @@ +package url + +import ( + "net" + "net/url" + "strings" + + "github.com/sagernet/sing-box/script/jsc" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/dop251/goja" + "golang.org/x/net/idna" +) + +type URL struct { + class jsc.Class[*Module, *URL] + url *url.URL + params *URLSearchParams + paramsValue goja.Value +} + +func newURL(c jsc.Class[*Module, *URL], call goja.ConstructorCall) *URL { + var ( + u, base *url.URL + err error + ) + switch argURL := call.Argument(0).Export().(type) { + case *URL: + u = argURL.url + default: + u, err = parseURL(call.Argument(0).String()) + if err != nil { + panic(c.Runtime().NewGoError(E.Cause(err, "parse URL"))) + } + } + if len(call.Arguments) == 2 { + switch argBaseURL := call.Argument(1).Export().(type) { + case *URL: + base = argBaseURL.url + default: + base, err = parseURL(call.Argument(1).String()) + if err != nil { + panic(c.Runtime().NewGoError(E.Cause(err, "parse base URL"))) + } + } + } + if base != nil { + u = base.ResolveReference(u) + } + return &URL{class: c, url: u} +} + +func createURL(module *Module) jsc.Class[*Module, *URL] { + class := jsc.NewClass[*Module, *URL](module) + class.DefineConstructor(newURL) + class.DefineField("hash", (*URL).getHash, (*URL).setHash) + class.DefineField("host", (*URL).getHost, (*URL).setHost) + class.DefineField("hostname", (*URL).getHostName, (*URL).setHostName) + class.DefineField("href", (*URL).getHref, (*URL).setHref) + class.DefineField("origin", (*URL).getOrigin, nil) + class.DefineField("password", (*URL).getPassword, (*URL).setPassword) + class.DefineField("pathname", (*URL).getPathname, (*URL).setPathname) + class.DefineField("port", (*URL).getPort, (*URL).setPort) + class.DefineField("protocol", (*URL).getProtocol, (*URL).setProtocol) + class.DefineField("search", (*URL).getSearch, (*URL).setSearch) + class.DefineField("searchParams", (*URL).getSearchParams, (*URL).setSearchParams) + class.DefineField("username", (*URL).getUsername, (*URL).setUsername) + class.DefineMethod("toString", (*URL).toString) + class.DefineMethod("toJSON", (*URL).toJSON) + class.DefineStaticMethod("canParse", canParse) + // class.DefineStaticMethod("createObjectURL", createObjectURL) + class.DefineStaticMethod("parse", parse) + // class.DefineStaticMethod("revokeObjectURL", revokeObjectURL) + return class +} + +func canParse(class jsc.Class[*Module, *URL], call goja.FunctionCall) any { + switch call.Argument(0).Export().(type) { + case *URL: + default: + _, err := parseURL(call.Argument(0).String()) + if err != nil { + return false + } + } + if len(call.Arguments) == 2 { + switch call.Argument(1).Export().(type) { + case *URL: + default: + _, err := parseURL(call.Argument(1).String()) + if err != nil { + return false + } + } + } + return true +} + +func parse(class jsc.Class[*Module, *URL], call goja.FunctionCall) any { + var ( + u, base *url.URL + err error + ) + switch argURL := call.Argument(0).Export().(type) { + case *URL: + u = argURL.url + default: + u, err = parseURL(call.Argument(0).String()) + if err != nil { + return goja.Null() + } + } + if len(call.Arguments) == 2 { + switch argBaseURL := call.Argument(1).Export().(type) { + case *URL: + base = argBaseURL.url + default: + base, err = parseURL(call.Argument(1).String()) + if err != nil { + return goja.Null() + } + } + } + if base != nil { + u = base.ResolveReference(u) + } + return &URL{class: class, url: u} +} + +func (r *URL) getHash() any { + if r.url.Fragment != "" { + return "#" + r.url.EscapedFragment() + } + return "" +} + +func (r *URL) setHash(value goja.Value) { + r.url.RawFragment = strings.TrimPrefix(value.String(), "#") +} + +func (r *URL) getHost() any { + return r.url.Host +} + +func (r *URL) setHost(value goja.Value) { + r.url.Host = strings.TrimSuffix(value.String(), ":") +} + +func (r *URL) getHostName() any { + return r.url.Hostname() +} + +func (r *URL) setHostName(value goja.Value) { + r.url.Host = joinHostPort(value.String(), r.url.Port()) +} + +func (r *URL) getHref() any { + return r.url.String() +} + +func (r *URL) setHref(value goja.Value) { + newURL, err := url.Parse(value.String()) + if err != nil { + panic(r.class.Runtime().NewGoError(err)) + } + r.url = newURL + r.params = nil +} + +func (r *URL) getOrigin() any { + return r.url.Scheme + "://" + r.url.Host +} + +func (r *URL) getPassword() any { + if r.url.User != nil { + password, _ := r.url.User.Password() + return password + } + return "" +} + +func (r *URL) setPassword(value goja.Value) { + if r.url.User == nil { + r.url.User = url.UserPassword("", value.String()) + } else { + r.url.User = url.UserPassword(r.url.User.Username(), value.String()) + } +} + +func (r *URL) getPathname() any { + return r.url.EscapedPath() +} + +func (r *URL) setPathname(value goja.Value) { + r.url.RawPath = value.String() +} + +func (r *URL) getPort() any { + return r.url.Port() +} + +func (r *URL) setPort(value goja.Value) { + r.url.Host = joinHostPort(r.url.Hostname(), value.String()) +} + +func (r *URL) getProtocol() any { + return r.url.Scheme + ":" +} + +func (r *URL) setProtocol(value goja.Value) { + r.url.Scheme = strings.TrimSuffix(value.String(), ":") +} + +func (r *URL) getSearch() any { + if r.params != nil { + if len(r.params.params) > 0 { + return "?" + generateQuery(r.params.params) + } + } else if r.url.RawQuery != "" { + return "?" + r.url.RawQuery + } + return "" +} + +func (r *URL) setSearch(value goja.Value) { + params, err := parseQuery(value.String()) + if err == nil { + if r.params != nil { + r.params.params = params + } else { + r.url.RawQuery = generateQuery(params) + } + } +} + +func (r *URL) getSearchParams() any { + var params []searchParam + if r.url.RawQuery != "" { + params, _ = parseQuery(r.url.RawQuery) + } + if r.params == nil { + r.params = &URLSearchParams{ + class: r.class.Module().classURLSearchParams, + params: params, + } + r.paramsValue = r.class.Module().classURLSearchParams.New(r.params) + } + return r.paramsValue +} + +func (r *URL) setSearchParams(value goja.Value) { + if params, ok := value.Export().(*URLSearchParams); ok { + r.params = params + r.paramsValue = value + } +} + +func (r *URL) getUsername() any { + if r.url.User != nil { + return r.url.User.Username() + } + return "" +} + +func (r *URL) setUsername(value goja.Value) { + if r.url.User == nil { + r.url.User = url.User(value.String()) + } else { + password, _ := r.url.User.Password() + r.url.User = url.UserPassword(value.String(), password) + } +} + +func (r *URL) toString(call goja.FunctionCall) any { + if r.params != nil { + r.url.RawQuery = generateQuery(r.params.params) + } + return r.url.String() +} + +func (r *URL) toJSON(call goja.FunctionCall) any { + return r.toString(call) +} + +func parseURL(s string) (*url.URL, error) { + u, err := url.Parse(s) + if err != nil { + return nil, E.Cause(err, "invalid URL") + } + switch u.Scheme { + case "https", "http", "ftp", "wss", "ws": + if u.Path == "" { + u.Path = "/" + } + hostname := u.Hostname() + asciiHostname, err := idna.Punycode.ToASCII(strings.ToLower(hostname)) + if err != nil { + return nil, E.Cause(err, "invalid hostname") + } + if asciiHostname != hostname { + u.Host = joinHostPort(asciiHostname, u.Port()) + } + } + if u.RawQuery != "" { + u.RawQuery = escape(u.RawQuery, &tblEscapeURLQuery, false) + } + return u, nil +} + +func joinHostPort(hostname, port string) string { + if port == "" { + return hostname + } + return net.JoinHostPort(hostname, port) +} diff --git a/script/modules/url/url_search_params.go b/script/modules/url/url_search_params.go new file mode 100644 index 00000000..945f076f --- /dev/null +++ b/script/modules/url/url_search_params.go @@ -0,0 +1,244 @@ +package url + +import ( + "fmt" + "net/url" + "sort" + "strings" + + "github.com/sagernet/sing-box/script/jsc" + F "github.com/sagernet/sing/common/format" + + "github.com/dop251/goja" +) + +type URLSearchParams struct { + class jsc.Class[*Module, *URLSearchParams] + params []searchParam +} + +func createURLSearchParams(module *Module) jsc.Class[*Module, *URLSearchParams] { + class := jsc.NewClass[*Module, *URLSearchParams](module) + class.DefineConstructor(newURLSearchParams) + class.DefineField("size", (*URLSearchParams).getSize, nil) + class.DefineMethod("append", (*URLSearchParams).append) + class.DefineMethod("delete", (*URLSearchParams).delete) + class.DefineMethod("entries", (*URLSearchParams).entries) + class.DefineMethod("forEach", (*URLSearchParams).forEach) + class.DefineMethod("get", (*URLSearchParams).get) + class.DefineMethod("getAll", (*URLSearchParams).getAll) + class.DefineMethod("has", (*URLSearchParams).has) + class.DefineMethod("keys", (*URLSearchParams).keys) + class.DefineMethod("set", (*URLSearchParams).set) + class.DefineMethod("sort", (*URLSearchParams).sort) + class.DefineMethod("toString", (*URLSearchParams).toString) + class.DefineMethod("values", (*URLSearchParams).values) + return class +} + +func newURLSearchParams(class jsc.Class[*Module, *URLSearchParams], call goja.ConstructorCall) *URLSearchParams { + var ( + params []searchParam + err error + ) + switch argInit := call.Argument(0).Export().(type) { + case *URLSearchParams: + params = argInit.params + case string: + params, err = parseQuery(argInit) + if err != nil { + panic(class.Runtime().NewGoError(err)) + } + case [][]string: + for _, pair := range argInit { + if len(pair) != 2 { + panic(class.Runtime().NewTypeError("Each query pair must be an iterable [name, value] tuple")) + } + params = append(params, searchParam{pair[0], pair[1]}) + } + case map[string]any: + for name, value := range argInit { + stringValue, isString := value.(string) + if !isString { + panic(class.Runtime().NewTypeError("Invalid query value")) + } + params = append(params, searchParam{name, stringValue}) + } + } + return &URLSearchParams{class, params} +} + +func (s *URLSearchParams) getSize() any { + return len(s.params) +} + +func (s *URLSearchParams) append(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + value := call.Argument(1).String() + s.params = append(s.params, searchParam{name, value}) + return goja.Undefined() +} + +func (s *URLSearchParams) delete(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + argValue := call.Argument(1) + if !jsc.IsNil(argValue) { + value := argValue.String() + for i, param := range s.params { + if param.Key == name && param.Value == value { + s.params = append(s.params[:i], s.params[i+1:]...) + break + } + } + } else { + for i, param := range s.params { + if param.Key == name { + s.params = append(s.params[:i], s.params[i+1:]...) + break + } + } + } + return goja.Undefined() +} + +func (s *URLSearchParams) entries(call goja.FunctionCall) any { + return jsc.NewIterator[*Module, searchParam](s.class.Module().classURLSearchParamsIterator, s.params, func(this searchParam) any { + return s.class.Runtime().NewArray(this.Key, this.Value) + }) +} + +func (s *URLSearchParams) forEach(call goja.FunctionCall) any { + callback := jsc.AssertFunction(s.class.Runtime(), call.Argument(0), "callbackFn") + thisValue := call.Argument(1) + for _, param := range s.params { + for _, value := range param.Value { + _, err := callback(thisValue, s.class.Runtime().ToValue(value), s.class.Runtime().ToValue(param.Key), call.This) + if err != nil { + panic(s.class.Runtime().NewGoError(err)) + } + } + } + return goja.Undefined() +} + +func (s *URLSearchParams) get(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + for _, param := range s.params { + if param.Key == name { + return param.Value + } + } + return goja.Null() +} + +func (s *URLSearchParams) getAll(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + var values []any + for _, param := range s.params { + if param.Key == name { + values = append(values, param.Value) + } + } + return s.class.Runtime().NewArray(values...) +} + +func (s *URLSearchParams) has(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + argValue := call.Argument(1) + if !jsc.IsNil(argValue) { + value := argValue.String() + for _, param := range s.params { + if param.Key == name && param.Value == value { + return true + } + } + } else { + for _, param := range s.params { + if param.Key == name { + return true + } + } + } + return false +} + +func (s *URLSearchParams) keys(call goja.FunctionCall) any { + return jsc.NewIterator[*Module, searchParam](s.class.Module().classURLSearchParamsIterator, s.params, func(this searchParam) any { + return this.Key + }) +} + +func (s *URLSearchParams) set(call goja.FunctionCall) any { + name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false) + value := call.Argument(1).String() + for i, param := range s.params { + if param.Key == name { + s.params[i].Value = value + return goja.Undefined() + } + } + s.params = append(s.params, searchParam{name, value}) + return goja.Undefined() +} + +func (s *URLSearchParams) sort(call goja.FunctionCall) any { + sort.SliceStable(s.params, func(i, j int) bool { + return s.params[i].Key < s.params[j].Key + }) + return goja.Undefined() +} + +func (s *URLSearchParams) toString(call goja.FunctionCall) any { + return generateQuery(s.params) +} + +func (s *URLSearchParams) values(call goja.FunctionCall) any { + return jsc.NewIterator[*Module, searchParam](s.class.Module().classURLSearchParamsIterator, s.params, func(this searchParam) any { + return this.Value + }) +} + +type searchParam struct { + Key string + Value string +} + +func parseQuery(query string) (params []searchParam, err error) { + query = strings.TrimPrefix(query, "?") + for query != "" { + var key string + key, query, _ = strings.Cut(query, "&") + if strings.Contains(key, ";") { + err = fmt.Errorf("invalid semicolon separator in query") + continue + } + if key == "" { + continue + } + key, value, _ := strings.Cut(key, "=") + key, err1 := url.QueryUnescape(key) + if err1 != nil { + if err == nil { + err = err1 + } + continue + } + value, err1 = url.QueryUnescape(value) + if err1 != nil { + if err == nil { + err = err1 + } + continue + } + params = append(params, searchParam{key, value}) + } + return +} + +func generateQuery(params []searchParam) string { + var parts []string + for _, param := range params { + parts = append(parts, F.ToString(param.Key, "=", url.QueryEscape(param.Value))) + } + return strings.Join(parts, "&") +} diff --git a/script/runtime.go b/script/runtime.go new file mode 100644 index 00000000..0da49018 --- /dev/null +++ b/script/runtime.go @@ -0,0 +1,47 @@ +package script + +import ( + "context" + + "github.com/sagernet/sing-box/script/modules/boxctx" + "github.com/sagernet/sing-box/script/modules/console" + "github.com/sagernet/sing-box/script/modules/eventloop" + "github.com/sagernet/sing-box/script/modules/require" + "github.com/sagernet/sing-box/script/modules/surge" + "github.com/sagernet/sing-box/script/modules/url" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/common/ntp" + + "github.com/dop251/goja" + "github.com/dop251/goja/parser" +) + +func NewRuntime(ctx context.Context, cancel context.CancelCauseFunc) *goja.Runtime { + vm := goja.New() + if timeFunc := ntp.TimeFuncFromContext(ctx); timeFunc != nil { + vm.SetTimeSource(timeFunc) + } + vm.SetParserOptions(parser.WithDisableSourceMaps) + registry := require.NewRegistry(require.WithLoader(func(path string) ([]byte, error) { + return nil, E.New("unsupported usage") + })) + registry.Enable(vm) + registry.RegisterNodeModule(console.ModuleName, console.Require) + registry.RegisterNodeModule(url.ModuleName, url.Require) + registry.RegisterNativeModule(boxctx.ModuleName, boxctx.Require) + registry.RegisterNativeModule(surge.ModuleName, surge.Require) + console.Enable(vm) + url.Enable(vm) + eventloop.Enable(vm, cancel) + return vm +} + +func SetModules(runtime *goja.Runtime, ctx context.Context, logger logger.ContextLogger, errorHandler func(error), tag string) { + boxctx.Enable(runtime, &boxctx.Context{ + Context: ctx, + Logger: logger, + Tag: tag, + ErrorHandler: errorHandler, + }) +} diff --git a/script/script.go b/script/script.go index 442e2620..7ed49a43 100644 --- a/script/script.go +++ b/script/script.go @@ -12,14 +12,8 @@ import ( func NewScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (adapter.Script, error) { switch options.Type { - case C.ScriptTypeSurgeGeneric: - return NewSurgeGenericScript(ctx, logger, options) - case C.ScriptTypeSurgeHTTPRequest: - return NewSurgeHTTPRequestScript(ctx, logger, options) - case C.ScriptTypeSurgeHTTPResponse: - return NewSurgeHTTPResponseScript(ctx, logger, options) - case C.ScriptTypeSurgeCron: - return NewSurgeCronScript(ctx, logger, options) + case C.ScriptTypeSurge: + return NewSurgeScript(ctx, logger, options) default: return nil, E.New("unknown script type: ", options.Type) } diff --git a/script/script_surge.go b/script/script_surge.go new file mode 100644 index 00000000..6026cbb2 --- /dev/null +++ b/script/script_surge.go @@ -0,0 +1,345 @@ +package script + +import ( + "context" + "net/http" + "sync" + "time" + "unsafe" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/script/jsc" + "github.com/sagernet/sing-box/script/modules/surge" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/logger" + + "github.com/adhocore/gronx" + "github.com/dop251/goja" +) + +const defaultSurgeScriptTimeout = 10 * time.Second + +var _ adapter.SurgeScript = (*SurgeScript)(nil) + +type SurgeScript struct { + ctx context.Context + logger logger.ContextLogger + tag string + source Source + + cronExpression string + cronTimeout time.Duration + cronArguments []string + cronTimer *time.Timer + cronDone chan struct{} +} + +func NewSurgeScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (adapter.Script, error) { + source, err := NewSource(ctx, logger, options) + if err != nil { + return nil, err + } + cronOptions := common.PtrValueOrDefault(options.SurgeOptions.CronOptions) + if cronOptions.Expression != "" { + if !gronx.IsValid(cronOptions.Expression) { + return nil, E.New("invalid cron expression: ", cronOptions.Expression) + } + } + return &SurgeScript{ + ctx: ctx, + logger: logger, + tag: options.Tag, + source: source, + cronExpression: cronOptions.Expression, + cronTimeout: time.Duration(cronOptions.Timeout), + cronArguments: cronOptions.Arguments, + cronDone: make(chan struct{}), + }, nil +} + +func (s *SurgeScript) Type() string { + return C.ScriptTypeSurge +} + +func (s *SurgeScript) Tag() string { + return s.tag +} + +func (s *SurgeScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { + return s.source.StartContext(ctx, startContext) +} + +func (s *SurgeScript) PostStart() error { + err := s.source.PostStart() + if err != nil { + return err + } + if s.cronExpression != "" { + go s.loopCronEvents() + } + return nil +} + +func (s *SurgeScript) loopCronEvents() { + s.logger.Debug("starting event") + err := s.ExecuteGeneric(s.ctx, "cron", s.cronTimeout, s.cronArguments) + if err != nil { + s.logger.Error(E.Cause(err, "running event")) + } + nextTick, err := gronx.NextTick(s.cronExpression, false) + if err != nil { + s.logger.Error(E.Cause(err, "determine next tick")) + return + } + s.cronTimer = time.NewTimer(nextTick.Sub(time.Now())) + s.logger.Debug("next event at: ", nextTick.Format(log.DefaultTimeFormat)) + for { + select { + case <-s.ctx.Done(): + return + case <-s.cronDone: + return + case <-s.cronTimer.C: + s.logger.Debug("starting event") + err = s.ExecuteGeneric(s.ctx, "cron", s.cronTimeout, s.cronArguments) + if err != nil { + s.logger.Error(E.Cause(err, "running event")) + } + nextTick, err = gronx.NextTick(s.cronExpression, false) + if err != nil { + s.logger.Error(E.Cause(err, "determine next tick")) + return + } + s.cronTimer.Reset(nextTick.Sub(time.Now())) + s.logger.Debug("configured next event at: ", nextTick) + } + } +} + +func (s *SurgeScript) Close() error { + err := s.source.Close() + if s.cronTimer != nil { + s.cronTimer.Stop() + close(s.cronDone) + } + return err +} + +func (s *SurgeScript) ExecuteGeneric(ctx context.Context, scriptType string, timeout time.Duration, arguments []string) error { + program := s.source.Program() + if program == nil { + return E.New("invalid script") + } + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + runtime := NewRuntime(ctx, cancel) + SetModules(runtime, ctx, s.logger, cancel, s.tag) + surge.Enable(runtime, scriptType, arguments) + if timeout == 0 { + timeout = defaultSurgeScriptTimeout + } + ctx, timeoutCancel := context.WithTimeout(ctx, timeout) + defer timeoutCancel() + done := make(chan struct{}) + doneFunc := common.OnceFunc(func() { + close(done) + }) + runtime.Set("done", func(call goja.FunctionCall) goja.Value { + doneFunc() + return goja.Undefined() + }) + var ( + access sync.Mutex + scriptErr error + ) + go func() { + _, err := runtime.RunProgram(program) + if err != nil { + access.Lock() + scriptErr = err + access.Unlock() + doneFunc() + } + }() + select { + case <-ctx.Done(): + runtime.Interrupt(ctx.Err()) + return ctx.Err() + case <-done: + access.Lock() + defer access.Unlock() + if scriptErr != nil { + runtime.Interrupt(scriptErr) + } else { + runtime.Interrupt("script done") + } + } + return scriptErr +} + +func (s *SurgeScript) ExecuteHTTPRequest(ctx context.Context, timeout time.Duration, request *http.Request, body []byte, binaryBody bool, arguments []string) (*adapter.HTTPRequestScriptResult, error) { + program := s.source.Program() + if program == nil { + return nil, E.New("invalid script") + } + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + runtime := NewRuntime(ctx, cancel) + SetModules(runtime, ctx, s.logger, cancel, s.tag) + surge.Enable(runtime, "http-request", arguments) + if timeout == 0 { + timeout = defaultSurgeScriptTimeout + } + ctx, timeoutCancel := context.WithTimeout(ctx, timeout) + defer timeoutCancel() + runtime.ClearInterrupt() + requestObject := runtime.NewObject() + requestObject.Set("url", request.URL.String()) + requestObject.Set("method", request.Method) + requestObject.Set("headers", jsc.HeadersToValue(runtime, request.Header)) + if !binaryBody { + requestObject.Set("body", string(body)) + } else { + requestObject.Set("body", jsc.NewUint8Array(runtime, body)) + } + requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request)))) + runtime.Set("request", requestObject) + done := make(chan struct{}) + doneFunc := common.OnceFunc(func() { + close(done) + }) + var ( + access sync.Mutex + result adapter.HTTPRequestScriptResult + scriptErr error + ) + runtime.Set("done", func(call goja.FunctionCall) goja.Value { + defer doneFunc() + resultObject := jsc.AssertObject(runtime, call.Argument(0), "done() argument", true) + if resultObject == nil { + panic(runtime.NewGoError(E.New("request rejected by script"))) + } + access.Lock() + defer access.Unlock() + result.URL = jsc.AssertString(runtime, resultObject.Get("url"), "url", true) + result.Headers = jsc.AssertHTTPHeader(runtime, resultObject.Get("headers"), "headers") + result.Body = jsc.AssertStringBinary(runtime, resultObject.Get("body"), "body", true) + responseObject := jsc.AssertObject(runtime, resultObject.Get("response"), "response", true) + if responseObject != nil { + result.Response = &adapter.HTTPRequestScriptResponse{ + Status: int(jsc.AssertInt(runtime, responseObject.Get("status"), "status", true)), + Headers: jsc.AssertHTTPHeader(runtime, responseObject.Get("headers"), "headers"), + Body: jsc.AssertStringBinary(runtime, responseObject.Get("body"), "body", true), + } + } + return goja.Undefined() + }) + go func() { + _, err := runtime.RunProgram(program) + if err != nil { + access.Lock() + scriptErr = err + access.Unlock() + doneFunc() + } + }() + select { + case <-ctx.Done(): + runtime.Interrupt(ctx.Err()) + return nil, ctx.Err() + case <-done: + access.Lock() + defer access.Unlock() + if scriptErr != nil { + runtime.Interrupt(scriptErr) + } else { + runtime.Interrupt("script done") + } + } + return &result, scriptErr +} + +func (s *SurgeScript) ExecuteHTTPResponse(ctx context.Context, timeout time.Duration, request *http.Request, response *http.Response, body []byte, binaryBody bool, arguments []string) (*adapter.HTTPResponseScriptResult, error) { + program := s.source.Program() + if program == nil { + return nil, E.New("invalid script") + } + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + runtime := NewRuntime(ctx, cancel) + SetModules(runtime, ctx, s.logger, cancel, s.tag) + surge.Enable(runtime, "http-response", arguments) + if timeout == 0 { + timeout = defaultSurgeScriptTimeout + } + ctx, timeoutCancel := context.WithTimeout(ctx, timeout) + defer timeoutCancel() + runtime.ClearInterrupt() + requestObject := runtime.NewObject() + requestObject.Set("url", request.URL.String()) + requestObject.Set("method", request.Method) + requestObject.Set("headers", jsc.HeadersToValue(runtime, request.Header)) + requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request)))) + runtime.Set("request", requestObject) + + responseObject := runtime.NewObject() + responseObject.Set("status", response.StatusCode) + responseObject.Set("headers", jsc.HeadersToValue(runtime, response.Header)) + if !binaryBody { + responseObject.Set("body", string(body)) + } else { + responseObject.Set("body", jsc.NewUint8Array(runtime, body)) + } + runtime.Set("response", responseObject) + + done := make(chan struct{}) + doneFunc := common.OnceFunc(func() { + close(done) + }) + var ( + access sync.Mutex + result adapter.HTTPResponseScriptResult + scriptErr error + ) + runtime.Set("done", func(call goja.FunctionCall) goja.Value { + resultObject := jsc.AssertObject(runtime, call.Argument(0), "done() argument", true) + if resultObject == nil { + panic(runtime.NewGoError(E.New("response rejected by script"))) + } + access.Lock() + defer access.Unlock() + result.Status = int(jsc.AssertInt(runtime, resultObject.Get("status"), "status", true)) + result.Headers = jsc.AssertHTTPHeader(runtime, resultObject.Get("headers"), "headers") + result.Body = jsc.AssertStringBinary(runtime, resultObject.Get("body"), "body", true) + doneFunc() + return goja.Undefined() + }) + go func() { + _, err := runtime.RunProgram(program) + if err != nil { + access.Lock() + scriptErr = err + access.Unlock() + doneFunc() + } + }() + select { + case <-ctx.Done(): + runtime.Interrupt(ctx.Err()) + return nil, ctx.Err() + case <-done: + access.Lock() + defer access.Unlock() + if scriptErr != nil { + runtime.Interrupt(scriptErr) + } else { + runtime.Interrupt("script done") + } + return &result, scriptErr + } +} diff --git a/script/script_surge_cron.go b/script/script_surge_cron.go deleted file mode 100644 index a123ce9e..00000000 --- a/script/script_surge_cron.go +++ /dev/null @@ -1,119 +0,0 @@ -package script - -import ( - "context" - "time" - - "github.com/sagernet/sing-box/adapter" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - - "github.com/adhocore/gronx" -) - -var _ adapter.GenericScript = (*SurgeCronScript)(nil) - -type SurgeCronScript struct { - GenericScript - ctx context.Context - expression string - timer *time.Timer -} - -func NewSurgeCronScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (*SurgeCronScript, error) { - source, err := NewSource(ctx, logger, options) - if err != nil { - return nil, err - } - if !gronx.IsValid(options.CronOptions.Expression) { - return nil, E.New("invalid cron expression: ", options.CronOptions.Expression) - } - return &SurgeCronScript{ - GenericScript: GenericScript{ - logger: logger, - tag: options.Tag, - timeout: time.Duration(options.Timeout), - arguments: options.Arguments, - source: source, - }, - ctx: ctx, - expression: options.CronOptions.Expression, - }, nil -} - -func (s *SurgeCronScript) Type() string { - return C.ScriptTypeSurgeCron -} - -func (s *SurgeCronScript) Tag() string { - return s.tag -} - -func (s *SurgeCronScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { - return s.source.StartContext(ctx, startContext) -} - -func (s *SurgeCronScript) PostStart() error { - err := s.source.PostStart() - if err != nil { - return err - } - go s.loop() - return nil -} - -func (s *SurgeCronScript) loop() { - s.logger.Debug("starting event") - err := s.Run(s.ctx) - if err != nil { - s.logger.Error(E.Cause(err, "running event")) - } - nextTick, err := gronx.NextTick(s.expression, false) - if err != nil { - s.logger.Error(E.Cause(err, "determine next tick")) - return - } - s.timer = time.NewTimer(nextTick.Sub(time.Now())) - s.logger.Debug("next event at: ", nextTick.Format(log.DefaultTimeFormat)) - for { - select { - case <-s.ctx.Done(): - return - case <-s.timer.C: - s.logger.Debug("starting event") - err = s.Run(s.ctx) - if err != nil { - s.logger.Error(E.Cause(err, "running event")) - } - nextTick, err = gronx.NextTick(s.expression, false) - if err != nil { - s.logger.Error(E.Cause(err, "determine next tick")) - return - } - s.timer.Reset(nextTick.Sub(time.Now())) - s.logger.Debug("next event at: ", nextTick) - } - } -} - -func (s *SurgeCronScript) Close() error { - return s.source.Close() -} - -func (s *SurgeCronScript) Run(ctx context.Context) error { - program := s.source.Program() - if program == nil { - return E.New("invalid script") - } - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - vm := NewRuntime(ctx, s.logger, cancel) - err := SetSurgeModules(vm, ctx, s.logger, cancel, s.Tag(), s.Type(), s.arguments) - if err != nil { - return err - } - return ExecuteSurgeGeneral(vm, program, ctx, s.timeout) -} diff --git a/script/script_surge_generic.go b/script/script_surge_generic.go deleted file mode 100644 index b1a8ccbd..00000000 --- a/script/script_surge_generic.go +++ /dev/null @@ -1,183 +0,0 @@ -package script - -import ( - "context" - "runtime" - "time" - - "github.com/sagernet/sing-box/adapter" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/experimental/locale" - "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/script/jsc" - "github.com/sagernet/sing-box/script/modules/console" - "github.com/sagernet/sing-box/script/modules/eventloop" - "github.com/sagernet/sing-box/script/modules/require" - "github.com/sagernet/sing-box/script/modules/sghttp" - "github.com/sagernet/sing-box/script/modules/sgnotification" - "github.com/sagernet/sing-box/script/modules/sgstore" - "github.com/sagernet/sing-box/script/modules/sgutils" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - "github.com/sagernet/sing/common/logger" - "github.com/sagernet/sing/common/ntp" - - "github.com/dop251/goja" - "github.com/dop251/goja/parser" -) - -const defaultScriptTimeout = 10 * time.Second - -var _ adapter.GenericScript = (*GenericScript)(nil) - -type GenericScript struct { - logger logger.ContextLogger - tag string - timeout time.Duration - arguments []any - source Source -} - -func NewSurgeGenericScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (*GenericScript, error) { - source, err := NewSource(ctx, logger, options) - if err != nil { - return nil, err - } - return &GenericScript{ - logger: logger, - tag: options.Tag, - timeout: time.Duration(options.Timeout), - arguments: options.Arguments, - source: source, - }, nil -} - -func (s *GenericScript) Type() string { - return C.ScriptTypeSurgeGeneric -} - -func (s *GenericScript) Tag() string { - return s.tag -} - -func (s *GenericScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { - return s.source.StartContext(ctx, startContext) -} - -func (s *GenericScript) PostStart() error { - return s.source.PostStart() -} - -func (s *GenericScript) Close() error { - return s.source.Close() -} - -func (s *GenericScript) Run(ctx context.Context) error { - program := s.source.Program() - if program == nil { - return E.New("invalid script") - } - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - vm := NewRuntime(ctx, s.logger, cancel) - err := SetSurgeModules(vm, ctx, s.logger, cancel, s.Tag(), s.Type(), s.arguments) - if err != nil { - return err - } - return ExecuteSurgeGeneral(vm, program, ctx, s.timeout) -} - -func NewRuntime(ctx context.Context, logger logger.ContextLogger, cancel context.CancelCauseFunc) *goja.Runtime { - vm := goja.New() - if timeFunc := ntp.TimeFuncFromContext(ctx); timeFunc != nil { - vm.SetTimeSource(timeFunc) - } - vm.SetParserOptions(parser.WithDisableSourceMaps) - registry := require.NewRegistry(require.WithLoader(func(path string) ([]byte, error) { - return nil, E.New("unsupported usage") - })) - registry.Enable(vm) - registry.RegisterNodeModule(console.ModuleName, console.Require(ctx, logger)) - console.Enable(vm) - eventloop.Enable(vm, cancel) - return vm -} - -func SetSurgeModules(vm *goja.Runtime, ctx context.Context, logger logger.Logger, errorHandler func(error), tag string, scriptType string, arguments []any) error { - script := vm.NewObject() - script.Set("name", F.ToString("script:", tag)) - script.Set("startTime", jsc.TimeToValue(vm, time.Now())) - script.Set("type", scriptType) - vm.Set("$script", script) - - environment := vm.NewObject() - var system string - switch runtime.GOOS { - case "ios": - system = "iOS" - case "darwin": - system = "macOS" - case "tvos": - system = "tvOS" - case "linux": - system = "Linux" - case "android": - system = "Android" - case "windows": - system = "Windows" - default: - system = runtime.GOOS - } - environment.Set("system", system) - environment.Set("surge-build", "N/A") - environment.Set("surge-version", "sing-box "+C.Version) - environment.Set("language", locale.Current().Locale) - environment.Set("device-model", "N/A") - vm.Set("$environment", environment) - - sgstore.Enable(vm, ctx) - sgutils.Enable(vm) - sghttp.Enable(vm, ctx, errorHandler) - sgnotification.Enable(vm, ctx, logger) - - vm.Set("$argument", arguments) - return nil -} - -func ExecuteSurgeGeneral(vm *goja.Runtime, program *goja.Program, ctx context.Context, timeout time.Duration) error { - if timeout == 0 { - timeout = defaultScriptTimeout - } - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - vm.ClearInterrupt() - done := make(chan struct{}) - doneFunc := common.OnceFunc(func() { - close(done) - }) - vm.Set("done", func(call goja.FunctionCall) goja.Value { - doneFunc() - return goja.Undefined() - }) - var err error - go func() { - _, err = vm.RunProgram(program) - if err != nil { - doneFunc() - } - }() - select { - case <-ctx.Done(): - vm.Interrupt(ctx.Err()) - return ctx.Err() - case <-done: - if err != nil { - vm.Interrupt(err) - } else { - vm.Interrupt("script done") - } - } - return err -} diff --git a/script/script_surge_http_request.go b/script/script_surge_http_request.go deleted file mode 100644 index ccb5ca67..00000000 --- a/script/script_surge_http_request.go +++ /dev/null @@ -1,165 +0,0 @@ -package script - -import ( - "context" - "net/http" - "regexp" - "time" - "unsafe" - - "github.com/sagernet/sing-box/adapter" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/script/jsc" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - "github.com/sagernet/sing/common/logger" - - "github.com/dop251/goja" -) - -var _ adapter.HTTPRequestScript = (*SurgeHTTPRequestScript)(nil) - -type SurgeHTTPRequestScript struct { - GenericScript - pattern *regexp.Regexp - requiresBody bool - maxSize int64 - binaryBodyMode bool -} - -func NewSurgeHTTPRequestScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (*SurgeHTTPRequestScript, error) { - source, err := NewSource(ctx, logger, options) - if err != nil { - return nil, err - } - pattern, err := regexp.Compile(options.HTTPOptions.Pattern) - if err != nil { - return nil, E.Cause(err, "parse pattern") - } - return &SurgeHTTPRequestScript{ - GenericScript: GenericScript{ - logger: logger, - tag: options.Tag, - timeout: time.Duration(options.Timeout), - arguments: options.Arguments, - source: source, - }, - pattern: pattern, - requiresBody: options.HTTPOptions.RequiresBody, - maxSize: options.HTTPOptions.MaxSize, - binaryBodyMode: options.HTTPOptions.BinaryBodyMode, - }, nil -} - -func (s *SurgeHTTPRequestScript) Type() string { - return C.ScriptTypeSurgeHTTPRequest -} - -func (s *SurgeHTTPRequestScript) Tag() string { - return s.tag -} - -func (s *SurgeHTTPRequestScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { - return s.source.StartContext(ctx, startContext) -} - -func (s *SurgeHTTPRequestScript) PostStart() error { - return s.source.PostStart() -} - -func (s *SurgeHTTPRequestScript) Close() error { - return s.source.Close() -} - -func (s *SurgeHTTPRequestScript) Match(requestURL string) bool { - return s.pattern.MatchString(requestURL) -} - -func (s *SurgeHTTPRequestScript) RequiresBody() bool { - return s.requiresBody -} - -func (s *SurgeHTTPRequestScript) MaxSize() int64 { - return s.maxSize -} - -func (s *SurgeHTTPRequestScript) Run(ctx context.Context, request *http.Request, body []byte) (*adapter.HTTPRequestScriptResult, error) { - program := s.source.Program() - if program == nil { - return nil, E.New("invalid script") - } - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - vm := NewRuntime(ctx, s.logger, cancel) - err := SetSurgeModules(vm, ctx, s.logger, cancel, s.Tag(), s.Type(), s.arguments) - if err != nil { - return nil, err - } - return ExecuteSurgeHTTPRequest(vm, program, ctx, s.timeout, request, body, s.binaryBodyMode) -} - -func ExecuteSurgeHTTPRequest(vm *goja.Runtime, program *goja.Program, ctx context.Context, timeout time.Duration, request *http.Request, body []byte, binaryBody bool) (*adapter.HTTPRequestScriptResult, error) { - if timeout == 0 { - timeout = defaultScriptTimeout - } - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - vm.ClearInterrupt() - requestObject := vm.NewObject() - requestObject.Set("url", request.URL.String()) - requestObject.Set("method", request.Method) - requestObject.Set("headers", jsc.HeadersToValue(vm, request.Header)) - if !binaryBody { - requestObject.Set("body", string(body)) - } else { - requestObject.Set("body", jsc.NewUint8Array(vm, body)) - } - requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request)))) - vm.Set("request", requestObject) - done := make(chan struct{}) - doneFunc := common.OnceFunc(func() { - close(done) - }) - var result adapter.HTTPRequestScriptResult - vm.Set("done", func(call goja.FunctionCall) goja.Value { - defer doneFunc() - resultObject := jsc.AssertObject(vm, call.Argument(0), "done() argument", true) - if resultObject == nil { - panic(vm.NewGoError(E.New("request rejected by script"))) - } - result.URL = jsc.AssertString(vm, resultObject.Get("url"), "url", true) - result.Headers = jsc.AssertHTTPHeader(vm, resultObject.Get("headers"), "headers") - result.Body = jsc.AssertStringBinary(vm, resultObject.Get("body"), "body", true) - responseObject := jsc.AssertObject(vm, resultObject.Get("response"), "response", true) - if responseObject != nil { - result.Response = &adapter.HTTPRequestScriptResponse{ - Status: int(jsc.AssertInt(vm, responseObject.Get("status"), "status", true)), - Headers: jsc.AssertHTTPHeader(vm, responseObject.Get("headers"), "headers"), - Body: jsc.AssertStringBinary(vm, responseObject.Get("body"), "body", true), - } - } - return goja.Undefined() - }) - var err error - go func() { - _, err = vm.RunProgram(program) - if err != nil { - doneFunc() - } - }() - select { - case <-ctx.Done(): - vm.Interrupt(ctx.Err()) - return nil, ctx.Err() - case <-done: - if err != nil { - vm.Interrupt(err) - } else { - vm.Interrupt("script done") - } - } - return &result, err -} diff --git a/script/script_surge_http_response.go b/script/script_surge_http_response.go deleted file mode 100644 index 8d2f06b2..00000000 --- a/script/script_surge_http_response.go +++ /dev/null @@ -1,175 +0,0 @@ -package script - -import ( - "context" - "net/http" - "regexp" - "sync" - "time" - "unsafe" - - "github.com/sagernet/sing-box/adapter" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/script/jsc" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - "github.com/sagernet/sing/common/logger" - - "github.com/dop251/goja" -) - -var _ adapter.HTTPResponseScript = (*SurgeHTTPResponseScript)(nil) - -type SurgeHTTPResponseScript struct { - GenericScript - pattern *regexp.Regexp - requiresBody bool - maxSize int64 - binaryBodyMode bool -} - -func NewSurgeHTTPResponseScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (*SurgeHTTPResponseScript, error) { - source, err := NewSource(ctx, logger, options) - if err != nil { - return nil, err - } - pattern, err := regexp.Compile(options.HTTPOptions.Pattern) - if err != nil { - return nil, E.Cause(err, "parse pattern") - } - return &SurgeHTTPResponseScript{ - GenericScript: GenericScript{ - logger: logger, - tag: options.Tag, - timeout: time.Duration(options.Timeout), - arguments: options.Arguments, - source: source, - }, - pattern: pattern, - requiresBody: options.HTTPOptions.RequiresBody, - maxSize: options.HTTPOptions.MaxSize, - binaryBodyMode: options.HTTPOptions.BinaryBodyMode, - }, nil -} - -func (s *SurgeHTTPResponseScript) Type() string { - return C.ScriptTypeSurgeHTTPResponse -} - -func (s *SurgeHTTPResponseScript) Tag() string { - return s.tag -} - -func (s *SurgeHTTPResponseScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error { - return s.source.StartContext(ctx, startContext) -} - -func (s *SurgeHTTPResponseScript) PostStart() error { - return s.source.PostStart() -} - -func (s *SurgeHTTPResponseScript) Close() error { - return s.source.Close() -} - -func (s *SurgeHTTPResponseScript) Match(requestURL string) bool { - return s.pattern.MatchString(requestURL) -} - -func (s *SurgeHTTPResponseScript) RequiresBody() bool { - return s.requiresBody -} - -func (s *SurgeHTTPResponseScript) MaxSize() int64 { - return s.maxSize -} - -func (s *SurgeHTTPResponseScript) Run(ctx context.Context, request *http.Request, response *http.Response, body []byte) (*adapter.HTTPResponseScriptResult, error) { - program := s.source.Program() - if program == nil { - return nil, E.New("invalid script") - } - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - vm := NewRuntime(ctx, s.logger, cancel) - err := SetSurgeModules(vm, ctx, s.logger, cancel, s.Tag(), s.Type(), s.arguments) - if err != nil { - return nil, err - } - return ExecuteSurgeHTTPResponse(vm, program, ctx, s.timeout, request, response, body, s.binaryBodyMode) -} - -func ExecuteSurgeHTTPResponse(vm *goja.Runtime, program *goja.Program, ctx context.Context, timeout time.Duration, request *http.Request, response *http.Response, body []byte, binaryBody bool) (*adapter.HTTPResponseScriptResult, error) { - if timeout == 0 { - timeout = defaultScriptTimeout - } - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - vm.ClearInterrupt() - requestObject := vm.NewObject() - requestObject.Set("url", request.URL.String()) - requestObject.Set("method", request.Method) - requestObject.Set("headers", jsc.HeadersToValue(vm, request.Header)) - requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request)))) - vm.Set("request", requestObject) - - responseObject := vm.NewObject() - responseObject.Set("status", response.StatusCode) - responseObject.Set("headers", jsc.HeadersToValue(vm, response.Header)) - if !binaryBody { - responseObject.Set("body", string(body)) - } else { - responseObject.Set("body", jsc.NewUint8Array(vm, body)) - } - vm.Set("response", responseObject) - - done := make(chan struct{}) - doneFunc := common.OnceFunc(func() { - close(done) - }) - var ( - access sync.Mutex - result adapter.HTTPResponseScriptResult - ) - vm.Set("done", func(call goja.FunctionCall) goja.Value { - resultObject := jsc.AssertObject(vm, call.Argument(0), "done() argument", true) - if resultObject == nil { - panic(vm.NewGoError(E.New("response rejected by script"))) - } - access.Lock() - defer access.Unlock() - result.Status = int(jsc.AssertInt(vm, resultObject.Get("status"), "status", true)) - result.Headers = jsc.AssertHTTPHeader(vm, resultObject.Get("headers"), "headers") - result.Body = jsc.AssertStringBinary(vm, resultObject.Get("body"), "body", true) - doneFunc() - return goja.Undefined() - }) - var scriptErr error - go func() { - _, err := vm.RunProgram(program) - if err != nil { - access.Lock() - scriptErr = err - access.Unlock() - doneFunc() - } - }() - select { - case <-ctx.Done(): - println(1) - vm.Interrupt(ctx.Err()) - return nil, ctx.Err() - case <-done: - access.Lock() - defer access.Unlock() - if scriptErr != nil { - vm.Interrupt(scriptErr) - } else { - vm.Interrupt("script done") - } - return &result, scriptErr - } -} diff --git a/script/source.go b/script/source.go index f601d8e1..a7f9f2be 100644 --- a/script/source.go +++ b/script/source.go @@ -21,9 +21,9 @@ type Source interface { func NewSource(ctx context.Context, logger logger.Logger, options option.Script) (Source, error) { switch options.Source { - case C.ScriptSourceLocal: + case C.ScriptSourceTypeLocal: return NewLocalSource(ctx, logger, options) - case C.ScriptSourceRemote: + case C.ScriptSourceTypeRemote: return NewRemoteSource(ctx, logger, options) default: return nil, E.New("unknown source type: ", options.Source)