Add Surge MITM and scripts

This commit is contained in:
世界 2025-02-02 17:03:27 +08:00
parent b55bfca7de
commit 5e28a80e63
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
53 changed files with 4437 additions and 15842 deletions

View file

@ -52,6 +52,10 @@ type CacheFile interface {
StoreGroupExpand(group string, expand bool) error
LoadRuleSet(tag string) *SavedBinary
SaveRuleSet(tag string, set *SavedBinary) error
LoadScript(tag string) *SavedBinary
SaveScript(tag string, script *SavedBinary) error
SurgePersistentStoreRead(key string) string
SurgePersistentStoreWrite(key string, value string) error
}
type SavedBinary struct {

View file

@ -2,6 +2,8 @@ package adapter
import (
"context"
"crypto/tls"
"net/http"
"net/netip"
"time"
@ -57,6 +59,8 @@ type InboundContext struct {
Domain string
Client string
SniffContext any
HTTPRequest *http.Request
ClientHello *tls.ClientHelloInfo
// cache
@ -73,6 +77,7 @@ type InboundContext struct {
UDPTimeout time.Duration
TLSFragment bool
TLSFragmentFallbackDelay time.Duration
MITM *option.MITMRouteOptions
NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType

View file

@ -1,6 +1,8 @@
package adapter
import E "github.com/sagernet/sing/common/exceptions"
import (
E "github.com/sagernet/sing/common/exceptions"
)
type StartStage uint8
@ -45,6 +47,9 @@ type LifecycleService interface {
func Start(stage StartStage, services ...Lifecycle) error {
for _, service := range services {
if service == nil {
continue
}
err := service.Start(stage)
if err != nil {
return err

13
adapter/mitm.go Normal file
View file

@ -0,0 +1,13 @@
package adapter
import (
"context"
"net"
N "github.com/sagernet/sing/common/network"
)
type MITMEngine interface {
Lifecycle
NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata InboundContext, onClose N.CloseHandlerFunc)
}

61
adapter/script.go Normal file
View file

@ -0,0 +1,61 @@
package adapter
import (
"context"
"net/http"
)
type ScriptManager interface {
Lifecycle
Scripts() []Script
// Script(name string) (Script, bool)
}
type Script interface {
Type() string
Tag() string
StartContext(ctx context.Context, startContext *HTTPStartContext) error
PostStart() error
Close() error
}
type GenericScript interface {
Script
Run(ctx context.Context) error
}
type HTTPScript interface {
Script
Match(requestURL string) bool
RequiresBody() bool
MaxSize() int64
}
type HTTPRequestScript interface {
HTTPScript
Run(ctx context.Context, request *http.Request, body []byte) (*HTTPRequestScriptResult, error)
}
type HTTPRequestScriptResult struct {
URL string
Headers http.Header
Body []byte
Response *HTTPRequestScriptResponse
}
type HTTPRequestScriptResponse struct {
Status int
Headers http.Header
Body []byte
}
type HTTPResponseScript interface {
HTTPScript
Run(ctx context.Context, request *http.Request, response *http.Response, body []byte) (*HTTPResponseScriptResult, error)
}
type HTTPResponseScriptResult struct {
Status int
Headers http.Header
Body []byte
}

35
box.go
View file

@ -23,9 +23,11 @@ import (
"github.com/sagernet/sing-box/experimental/cachefile"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/mitm"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/protocol/direct"
"github.com/sagernet/sing-box/route"
"github.com/sagernet/sing-box/script"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
@ -48,6 +50,8 @@ type Box struct {
dnsRouter *dns.Router
connection *route.ConnectionManager
router *route.Router
script *script.Manager
mitm adapter.MITMEngine //*mitm.Engine
services []adapter.LifecycleService
done chan struct{}
}
@ -173,7 +177,7 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "initialize network manager")
}
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
connectionManager := route.NewConnectionManager(ctx, logFactory.NewLogger("connection"))
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
service.MustRegister[adapter.Router](ctx, router)
@ -181,8 +185,8 @@ func New(options Options) (*Box, error) {
if err != nil {
return nil, E.Cause(err, "initialize router")
}
ntpOptions := common.PtrValueOrDefault(options.NTP)
var timeService *tls.TimeServiceWrapper
ntpOptions := common.PtrValueOrDefault(options.NTP)
if ntpOptions.Enabled {
timeService = new(tls.TimeServiceWrapper)
service.MustRegister[ntp.TimeService](ctx, timeService)
@ -289,6 +293,11 @@ func New(options Options) (*Box, error) {
"local",
option.LocalDNSServerOptions{},
)))
scriptManager, err := script.NewManager(ctx, logFactory, options.Scripts)
if err != nil {
return nil, E.Cause(err, "initialize script manager")
}
service.MustRegister[adapter.ScriptManager](ctx, scriptManager)
if platformInterface != nil {
err = platformInterface.Initialize(networkManager)
if err != nil {
@ -338,6 +347,16 @@ func New(options Options) (*Box, error) {
timeService.TimeService = ntpService
services = append(services, adapter.NewLifecycleService(ntpService, "ntp service"))
}
mitmOptions := common.PtrValueOrDefault(options.MITM)
var mitmEngine *mitm.Engine
if mitmOptions.Enabled {
engine, err := mitm.NewEngine(ctx, logFactory.NewLogger("mitm"), mitmOptions)
if err != nil {
return nil, E.Cause(err, "create MITM engine")
}
service.MustRegister[adapter.MITMEngine](ctx, engine)
mitmEngine = engine
}
return &Box{
network: networkManager,
endpoint: endpointManager,
@ -347,6 +366,8 @@ func New(options Options) (*Box, error) {
dnsRouter: dnsRouter,
connection: connectionManager,
router: router,
script: scriptManager,
mitm: mitmEngine,
createdAt: createdAt,
logFactory: logFactory,
logger: logFactory.Logger(),
@ -405,11 +426,11 @@ func (s *Box) preStart() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.script, s.mitm, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router)
err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router, s.script, s.mitm)
if err != nil {
return err
}
@ -433,7 +454,7 @@ func (s *Box) start() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.script, s.mitm, s.inbound, s.endpoint)
if err != nil {
return err
}
@ -441,7 +462,7 @@ func (s *Box) start() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.script, s.mitm, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
@ -460,7 +481,7 @@ func (s *Box) Close() error {
close(s.done)
}
err := common.Close(
s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network,
s.inbound, s.outbound, s.endpoint, s.mitm, s.script, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network,
)
for _, lifecycleService := range s.services {
err = E.Append(err, lifecycleService.Close(), func(err error) error {

View file

@ -18,5 +18,6 @@ func HTTPHost(_ context.Context, metadata *adapter.InboundContext, reader io.Rea
}
metadata.Protocol = C.ProtocolHTTP
metadata.Domain = M.ParseSocksaddr(request.Host).AddrString()
metadata.HTTPRequest = request
return nil
}

View file

@ -21,6 +21,7 @@ func TLSClientHello(ctx context.Context, metadata *adapter.InboundContext, reade
if clientHello != nil {
metadata.Protocol = C.ProtocolTLS
metadata.Domain = clientHello.ServerName
metadata.ClientHello = clientHello
return nil
}
return err

View file

@ -8,7 +8,10 @@ import (
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"time"
M "github.com/sagernet/sing/common/metadata"
)
func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string) (*tls.Certificate, error) {
@ -35,17 +38,30 @@ func GenerateCertificate(parent *x509.Certificate, parentKey any, timeFunc func(
if err != nil {
return
}
template := &x509.Certificate{
SerialNumber: serialNumber,
NotBefore: timeFunc().Add(time.Hour * -1),
NotAfter: expire,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
Subject: pkix.Name{
CommonName: serverName,
},
DNSNames: []string{serverName},
var template *x509.Certificate
if serverAddress := M.ParseAddr(serverName); serverAddress.IsValid() {
template = &x509.Certificate{
SerialNumber: serialNumber,
IPAddresses: []net.IP{serverAddress.AsSlice()},
NotBefore: timeFunc().Add(time.Hour * -1),
NotAfter: expire,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
} else {
template = &x509.Certificate{
SerialNumber: serialNumber,
NotBefore: timeFunc().Add(time.Hour * -1),
NotAfter: expire,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
Subject: pkix.Name{
CommonName: serverName,
},
DNSNames: []string{serverName},
}
}
if parent == nil {
parent = template

View file

@ -7,8 +7,9 @@ import (
"strings"
"time"
"github.com/sagernet/sing/common"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/publicsuffix"
)
type Conn struct {
@ -42,30 +43,12 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return
}
}
splits := strings.Split(string(b[serverName.Index:serverName.Index+serverName.Length]), ".")
splits := strings.Split(serverName.ServerName, ".")
currentIndex := serverName.Index
var striped bool
if len(splits) > 3 {
suffix := splits[len(splits)-3] + "." + splits[len(splits)-2] + "." + splits[len(splits)-1]
if publicSuffixMatcher().Match(suffix) {
splits = splits[:len(splits)-3]
}
striped = true
if publicSuffix := publicsuffix.List.PublicSuffix(serverName.ServerName); publicSuffix != "" {
splits = splits[:len(splits)-strings.Count(serverName.ServerName, ".")]
}
if !striped && len(splits) > 2 {
suffix := splits[len(splits)-2] + "." + splits[len(splits)-1]
if publicSuffixMatcher().Match(suffix) {
splits = splits[:len(splits)-2]
}
striped = true
}
if !striped && len(splits) > 1 {
suffix := splits[len(splits)-1]
if publicSuffixMatcher().Match(suffix) {
splits = splits[:len(splits)-1]
}
}
if len(splits) > 1 && common.Contains(publicPrefix, splits[0]) {
if len(splits) > 1 && splits[0] == "..." {
currentIndex += len(splits[0]) + 1
splits = splits[1:]
}

View file

@ -23,9 +23,9 @@ const (
)
type myServerName struct {
Index int
Length int
sex []byte
Index int
Length int
ServerName string
}
func indexTLSServerName(payload []byte) *myServerName {
@ -119,9 +119,9 @@ func indexTLSServerNameFromExtensions(exs []byte) *myServerName {
sniLen := uint16(sex[3])<<8 | uint16(sex[4])
sex = sex[sniExtensionHeaderLen:]
return &myServerName{
Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen,
Length: int(sniLen),
sex: sex,
Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen,
Length: int(sniLen),
ServerName: string(sex),
}
}
exs = exs[4+exLen:]

View file

@ -1,55 +0,0 @@
package tf
import (
"bufio"
"bytes"
_ "embed"
"io"
"strings"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/domain"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
var publicPrefix = []string{
"www",
}
//go:generate wget -O public_suffix_list.dat https://publicsuffix.org/list/public_suffix_list.dat
//go:embed public_suffix_list.dat
var publicSuffix []byte
var publicSuffixMatcher = common.OnceValue(func() *domain.Matcher {
matcher, err := initPublicSuffixMatcher()
if err != nil {
panic(F.ToString("error in initialize public suffix matcher"))
}
return matcher
})
func initPublicSuffixMatcher() (*domain.Matcher, error) {
reader := bufio.NewReader(bytes.NewReader(publicSuffix))
var domainList []string
for {
line, isPrefix, err := reader.ReadLine()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if isPrefix {
return nil, E.New("unexpected prefix line")
}
lineStr := string(line)
lineStr = strings.TrimSpace(lineStr)
if lineStr == "" || strings.HasPrefix(lineStr, "//") {
continue
}
domainList = append(domainList, lineStr)
}
return domain.NewMatcher(domainList, nil, false), nil
}

File diff suppressed because it is too large Load diff

12
constant/script.go Normal file
View file

@ -0,0 +1,12 @@
package constant
const (
ScriptTypeSurgeGeneric = "sg-generic"
ScriptTypeSurgeHTTPRequest = "sg-http-request"
ScriptTypeSurgeHTTPResponse = "sg-http-response"
ScriptTypeSurgeCron = "sg-cron"
ScriptTypeSurgeEvent = "sg-event"
ScriptSourceLocal = "local"
ScriptSourceRemote = "remote"
)

View file

@ -19,10 +19,12 @@ import (
)
var (
bucketSelected = []byte("selected")
bucketExpand = []byte("group_expand")
bucketMode = []byte("clash_mode")
bucketRuleSet = []byte("rule_set")
bucketSelected = []byte("selected")
bucketExpand = []byte("group_expand")
bucketMode = []byte("clash_mode")
bucketRuleSet = []byte("rule_set")
bucketScript = []byte("script")
bucketSgPersistentStore = []byte("sg_persistent_store")
bucketNameList = []string{
string(bucketSelected),
@ -316,3 +318,62 @@ func (c *CacheFile) SaveRuleSet(tag string, set *adapter.SavedBinary) error {
return bucket.Put([]byte(tag), setBinary)
})
}
func (c *CacheFile) LoadScript(tag string) *adapter.SavedBinary {
var savedSet adapter.SavedBinary
err := c.DB.View(func(t *bbolt.Tx) error {
bucket := c.bucket(t, bucketScript)
if bucket == nil {
return os.ErrNotExist
}
scriptBinary := bucket.Get([]byte(tag))
if len(scriptBinary) == 0 {
return os.ErrInvalid
}
return savedSet.UnmarshalBinary(scriptBinary)
})
if err != nil {
return nil
}
return &savedSet
}
func (c *CacheFile) SaveScript(tag string, set *adapter.SavedBinary) error {
return c.DB.Batch(func(t *bbolt.Tx) error {
bucket, err := c.createBucket(t, bucketScript)
if err != nil {
return err
}
scriptBinary, err := set.MarshalBinary()
if err != nil {
return err
}
return bucket.Put([]byte(tag), scriptBinary)
})
}
func (c *CacheFile) SurgePersistentStoreRead(key string) string {
var value string
_ = c.DB.View(func(t *bbolt.Tx) error {
bucket := c.bucket(t, bucketSgPersistentStore)
if bucket == nil {
return nil
}
valueBinary := bucket.Get([]byte(key))
if len(valueBinary) > 0 {
value = string(valueBinary)
}
return nil
})
return value
}
func (c *CacheFile) SurgePersistentStoreWrite(key string, value string) error {
return c.DB.Batch(func(t *bbolt.Tx) error {
bucket, err := c.createBucket(t, bucketSgPersistentStore)
if err != nil {
return err
}
return bucket.Put([]byte(key), []byte(value))
})
}

View file

@ -32,4 +32,9 @@ type Notification struct {
Subtitle string
Body string
OpenURL string
Clipboard string
MediaURL string
MediaData []byte
MediaType string
Timeout int
}

8
go.mod
View file

@ -3,9 +3,11 @@ module github.com/sagernet/sing-box
go 1.20
require (
github.com/adhocore/gronx v1.19.5
github.com/caddyserver/certmagic v0.20.0
github.com/cloudflare/circl v1.3.7
github.com/cretz/bine v0.2.0
github.com/dop251/goja v0.0.0-20250125213203-5ef83b82af17
github.com/go-chi/chi/v5 v5.1.0
github.com/go-chi/render v1.0.3
github.com/gofrs/uuid/v5 v5.3.0
@ -61,15 +63,17 @@ require (
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a // indirect
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect
github.com/hashicorp/yamux v0.1.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/native v1.1.0 // indirect
@ -78,7 +82,6 @@ require (
github.com/libdns/libdns v0.2.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/onsi/ginkgo/v2 v2.9.7 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@ -97,7 +100,6 @@ require (
golang.org/x/tools v0.24.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
lukechampine.com/blake3 v1.3.0 // indirect
)

32
go.sum
View file

@ -1,3 +1,6 @@
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
github.com/adhocore/gronx v1.19.5 h1:cwIG4nT1v9DvadxtHBe6MzE+FZ1JDvAUC45U2fl4eSQ=
github.com/adhocore/gronx v1.19.5/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
@ -16,6 +19,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 h1:CaO/zOnF8VvUfEbhRatPcwKVWamvbYd8tQGRWacE9kU=
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1/go.mod h1:+hnT3ywWDTAFrW5aE+u2Sa/wT555ZqwoCS+pk3p6ry4=
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dop251/goja v0.0.0-20250125213203-5ef83b82af17 h1:spJaibPy2sZNwo6Q0HjBVufq7hBUj5jNFOKRoogCBow=
github.com/dop251/goja v0.0.0-20250125213203-5ef83b82af17/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
@ -25,6 +32,8 @@ github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIo
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/go-sourcemap/sourcemap v2.1.4+incompatible h1:a+iTbH5auLKxaNwQFg0B+TCYl6lbukKPc7b5x0n1s6Q=
github.com/go-sourcemap/sourcemap v2.1.4+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
@ -41,8 +50,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a h1:fEBsGL/sjAuJrgah5XqmmYsTLzJp/TO9Lhy39gkverk=
github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k=
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo=
github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8=
github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
@ -58,9 +67,6 @@ github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6K
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/libdns/alidns v1.0.3 h1:LFHuGnbseq5+HCeGa1aW8awyX/4M2psB9962fdD2+yQ=
github.com/libdns/alidns v1.0.3/go.mod h1:e18uAG6GanfRhcJj6/tps2rCMzQJaYVcGKT+ELjdjGE=
github.com/libdns/cloudflare v0.1.1 h1:FVPfWwP8zZCqj268LZjmkDleXlHPlFU9KC4OJ3yn054=
@ -80,8 +86,6 @@ github.com/mholt/acmez v1.2.0 h1:1hhLxSgY5FvH5HCnGUuwbKY2VQVo8IU7rxXKSnZ7F30=
github.com/mholt/acmez v1.2.0/go.mod h1:VT9YwH1xgNX1kmYY89gY8xPJC84BFAisjo8Egigt4kE=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=
github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0=
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
@ -114,8 +118,6 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/quic-go v0.48.2-beta.1 h1:W0plrLWa1XtOWDTdX3CJwxmQuxkya12nN5BRGZ87kEg=
github.com/sagernet/quic-go v0.48.2-beta.1/go.mod h1:1WgdDIVD1Gybp40JTWketeSfKA/+or9YMLaG5VeTk4k=
github.com/sagernet/quic-go v0.49.0-beta.1 h1:3LdoCzVVfYRibZns1tYWSIoB65fpTmrwy+yfK8DQ8Jk=
github.com/sagernet/quic-go v0.49.0-beta.1/go.mod h1:uesWD1Ihrldq1M3XtjuEvIUqi8WHNsRs71b3Lt1+p/U=
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc=
@ -172,8 +174,6 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
@ -182,8 +182,6 @@ golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo=
golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
@ -195,12 +193,10 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
@ -222,10 +218,10 @@ google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM=
google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -10,6 +10,10 @@ import (
E "github.com/sagernet/sing/common/exceptions"
)
const (
DefaultTimeFormat = "-0700 2006-01-02 15:04:05"
)
type Options struct {
Context context.Context
Options option.LogOptions
@ -47,7 +51,7 @@ func New(options Options) (Factory, error) {
DisableColors: logOptions.DisableColor || logFilePath != "",
DisableTimestamp: !logOptions.Timestamp && logFilePath != "",
FullTimestamp: logOptions.Timestamp,
TimestampFormat: "-0700 2006-01-02 15:04:05",
TimestampFormat: DefaultTimeFormat,
}
factory := NewDefaultFactory(
options.Context,

11
mitm/constants.go Normal file
View file

@ -0,0 +1,11 @@
package mitm
import (
"encoding/base64"
"github.com/sagernet/sing/common"
)
var surgeTinyGif = common.OnceValue(func() []byte {
return common.Must1(base64.StdEncoding.DecodeString("R0lGODlhAQABAAAAACH5BAEAAAAALAAAAAABAAEAAAIBAAA="))
})

597
mitm/engine.go Normal file
View file

@ -0,0 +1,597 @@
package mitm
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"io"
"mime"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
sTLS "github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
sHTTP "github.com/sagernet/sing/protocol/http"
"github.com/sagernet/sing/service"
"golang.org/x/crypto/pkcs12"
)
var _ adapter.MITMEngine = (*Engine)(nil)
type Engine struct {
ctx context.Context
logger logger.ContextLogger
connection adapter.ConnectionManager
script adapter.ScriptManager
timeFunc func() time.Time
http2Enabled bool
tlsDecryptionEnabled bool
tlsPrivateKey any
tlsCertificate *x509.Certificate
}
func NewEngine(ctx context.Context, logger logger.ContextLogger, options option.MITMOptions) (*Engine, error) {
engine := &Engine{
ctx: ctx,
logger: logger,
// http2Enabled: options.HTTP2Enabled,
}
if options.TLSDecryptionOptions != nil && options.TLSDecryptionOptions.Enabled {
pfxBytes, err := base64.StdEncoding.DecodeString(options.TLSDecryptionOptions.KeyPair)
if err != nil {
return nil, E.Cause(err, "decode key pair base64 bytes")
}
privateKey, certificate, err := pkcs12.Decode(pfxBytes, options.TLSDecryptionOptions.KeyPassword)
if err != nil {
return nil, E.Cause(err, "decode key pair")
}
engine.tlsDecryptionEnabled = true
engine.tlsPrivateKey = privateKey
engine.tlsCertificate = certificate
}
return engine, nil
}
func (e *Engine) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateInitialize:
e.connection = service.FromContext[adapter.ConnectionManager](e.ctx)
e.script = service.FromContext[adapter.ScriptManager](e.ctx)
e.timeFunc = ntp.TimeFuncFromContext(e.ctx)
if e.timeFunc == nil {
e.timeFunc = time.Now
}
}
return nil
}
func (e *Engine) Close() error {
return nil
}
func (e *Engine) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if e.tlsDecryptionEnabled && metadata.ClientHello != nil {
err := e.newTLS(ctx, this, conn, metadata, onClose)
if err != nil {
e.logger.ErrorContext(ctx, err)
} else {
e.logger.DebugContext(ctx, "connection closed")
}
if onClose != nil {
onClose(err)
}
return
} else if metadata.HTTPRequest != nil {
err := e.newHTTP1(ctx, this, conn, nil, metadata)
if err != nil {
e.logger.ErrorContext(ctx, err)
} else {
e.logger.DebugContext(ctx, "connection closed")
}
if onClose != nil {
onClose(err)
}
return
} else {
e.logger.DebugContext(ctx, "HTTP and TLS not detected, skipped")
}
metadata.MITM = nil
e.connection.NewConnection(ctx, this, conn, metadata, onClose)
}
func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
acceptHTTP := len(metadata.ClientHello.SupportedProtos) == 0 || common.Contains(metadata.ClientHello.SupportedProtos, "http/1.1")
acceptH2 := e.http2Enabled && common.Contains(metadata.ClientHello.SupportedProtos, "h2")
if !acceptHTTP && !acceptH2 {
e.logger.DebugContext(ctx, "unsupported application protocol: ", strings.Join(metadata.ClientHello.SupportedProtos, ","))
e.connection.NewConnection(ctx, this, conn, metadata, onClose)
return nil
}
var nextProtos []string
if acceptH2 {
nextProtos = append(nextProtos, "h2")
} else if acceptHTTP {
nextProtos = append(nextProtos, "http/1.1")
}
var (
maxVersion uint16
minVersion uint16
)
for _, version := range metadata.ClientHello.SupportedVersions {
maxVersion = common.Max(maxVersion, version)
minVersion = common.Min(minVersion, version)
}
serverName := metadata.ClientHello.ServerName
if serverName == "" && metadata.Destination.IsIP() {
serverName = metadata.Destination.Addr.String()
}
tlsConfig := &tls.Config{
Time: e.timeFunc,
CipherSuites: metadata.ClientHello.CipherSuites,
ServerName: serverName,
CurvePreferences: metadata.ClientHello.SupportedCurves,
NextProtos: nextProtos,
MinVersion: minVersion,
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return sTLS.GenerateKeyPair(e.tlsCertificate, e.tlsPrivateKey, e.timeFunc, serverName)
},
}
tlsConn := tls.Server(conn, tlsConfig)
err := tlsConn.HandshakeContext(ctx)
if err != nil {
return E.Cause(err, "TLS handshake")
}
if tlsConn.ConnectionState().NegotiatedProtocol == "h2" {
return e.newHTTP2(ctx, this, tlsConn, metadata, onClose)
} else {
return e.newHTTP1(ctx, this, tlsConn, tlsConfig, metadata)
}
}
func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext) error {
options := metadata.MITM
metadata.MITM = nil
defer conn.Close()
reader := bufio.NewReader(conn)
request, err := sHTTP.ReadRequest(reader)
if err != nil {
return E.Cause(err, "read HTTP request")
}
rawRequestURL := request.URL
rawRequestURL.Scheme = "https"
if rawRequestURL.Host == "" {
rawRequestURL.Host = request.Host
}
requestURL := rawRequestURL.String()
request.RequestURI = ""
var (
requestMatch bool
requestScript adapter.HTTPRequestScript
)
for _, script := range e.script.Scripts() {
if !common.Contains(options.Script, script.Tag()) {
continue
}
httpScript, isHTTP := script.(adapter.HTTPRequestScript)
if !isHTTP {
_, isHTTP = script.(adapter.HTTPScript)
if !isHTTP {
e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script")
}
continue
}
if !httpScript.Match(requestURL) {
continue
}
e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]")
requestScript = httpScript
requestMatch = true
break
}
if requestScript != nil {
var body []byte
if requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) {
body, err = io.ReadAll(request.Body)
if err != nil {
return E.Cause(err, "read HTTP request body")
}
request.Body = io.NopCloser(bytes.NewReader(body))
}
var result *adapter.HTTPRequestScriptResult
result, err = requestScript.Run(ctx, request, body)
if err != nil {
return E.Cause(err, "execute script/", requestScript.Type(), "[", requestScript.Tag(), "]")
}
if result.Response != nil {
if result.Response.Status == 0 {
result.Response.Status = http.StatusOK
}
response := &http.Response{
StatusCode: result.Response.Status,
Status: http.StatusText(result.Response.Status),
Proto: request.Proto,
ProtoMajor: request.ProtoMajor,
ProtoMinor: request.ProtoMinor,
Header: result.Response.Headers,
Body: io.NopCloser(bytes.NewReader(result.Response.Body)),
}
err = response.Write(conn)
if err != nil {
return E.Cause(err, "write fake response body")
}
return nil
} else {
if result.URL != "" {
var newURL *url.URL
newURL, err = url.Parse(result.URL)
if err != nil {
return E.Cause(err, "parse updated request URL")
}
request.URL = newURL
newDestination := M.ParseSocksaddrHostPortStr(newURL.Hostname(), newURL.Port())
if newDestination.Port == 0 {
newDestination.Port = metadata.Destination.Port
}
metadata.Destination = newDestination
if tlsConfig != nil {
tlsConfig.ServerName = newURL.Hostname()
}
}
for key, values := range result.Headers {
request.Header[key] = values
}
if newHost := result.Headers.Get("Host"); newHost != "" {
request.Host = newHost
request.Header.Del("Host")
}
if result.Body != nil {
request.Body = io.NopCloser(bytes.NewReader(result.Body))
request.ContentLength = int64(len(body))
}
}
}
if !requestMatch {
for i, rule := range options.SurgeURLRewrite {
if !rule.Pattern.MatchString(requestURL) {
continue
}
e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String())
if rule.Reject {
return E.New("request rejected by url_rewrite")
} else if rule.Redirect {
w := new(simpleResponseWriter)
http.Redirect(w, request, rule.Destination.String(), http.StatusFound)
err = w.Build(request).Write(conn)
if err != nil {
return E.Cause(err, "write url_rewrite 302 response")
}
return nil
}
requestMatch = true
request.URL = rule.Destination
newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port())
if newDestination.Port == 0 {
newDestination.Port = metadata.Destination.Port
}
metadata.Destination = newDestination
if tlsConfig != nil {
tlsConfig.ServerName = rule.Destination.Hostname()
}
break
}
for i, rule := range options.SurgeHeaderRewrite {
if rule.Response {
continue
}
if !rule.Pattern.MatchString(requestURL) {
continue
}
requestMatch = true
e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
switch {
case rule.Add:
if strings.ToLower(rule.Key) == "host" {
request.Host = rule.Value
continue
}
request.Header.Add(rule.Key, rule.Value)
case rule.Delete:
request.Header.Del(rule.Key)
case rule.Replace:
if request.Header.Get(rule.Key) != "" {
request.Header.Set(rule.Key, rule.Value)
}
case rule.ReplaceRegex:
if value := request.Header.Get(rule.Key); value != "" {
request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
}
}
}
for i, rule := range options.SurgeBodyRewrite {
if rule.Response {
continue
}
if !rule.Pattern.MatchString(requestURL) {
continue
}
requestMatch = true
e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
var body []byte
if request.ContentLength <= 0 {
e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
break
} else if request.ContentLength > 131072 {
e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
break
}
body, err = io.ReadAll(request.Body)
if err != nil {
return E.Cause(err, "read HTTP request body")
}
for mi := 0; i < len(rule.Match); i++ {
body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
}
request.Body = io.NopCloser(bytes.NewReader(body))
request.ContentLength = int64(len(body))
}
}
if !requestMatch {
for i, rule := range options.SurgeMapLocal {
if !rule.Pattern.MatchString(requestURL) {
continue
}
requestMatch = true
e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String())
var (
statusCode = http.StatusOK
headers = make(http.Header)
body []byte
)
if rule.StatusCode > 0 {
statusCode = rule.StatusCode
}
switch {
case rule.File:
resource, err := os.ReadFile(rule.Data)
if err != nil {
return E.Cause(err, "open map local source")
}
mimeType := mime.TypeByExtension(filepath.Ext(rule.Data))
if mimeType == "" {
mimeType = "application/octet-stream"
}
headers.Set("Content-Type", mimeType)
body = resource
case rule.Text:
headers.Set("Content-Type", "text/plain")
body = []byte(rule.Data)
case rule.TinyGif:
headers.Set("Content-Type", "image/gif")
body = surgeTinyGif()
case rule.Base64:
headers.Set("Content-Type", "application/octet-stream")
body = rule.Base64Data
}
response := &http.Response{
StatusCode: statusCode,
Status: http.StatusText(statusCode),
Proto: request.Proto,
ProtoMajor: request.ProtoMajor,
ProtoMinor: request.ProtoMinor,
Header: headers,
Body: io.NopCloser(bytes.NewReader(body)),
}
err = response.Write(conn)
if err != nil {
return E.Cause(err, "write map local response")
}
return nil
}
}
ctx = adapter.WithContext(ctx, &metadata)
var remoteConn net.Conn
if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
} else {
remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
if err != nil {
return E.Cause(err, "open outbound connection")
}
defer remoteConn.Close()
var innerErr atomic.TypedValue[error]
httpClient := &http.Client{
Transport: &http.Transport{
DisableCompression: true,
DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
if tlsConfig != nil {
return tls.Client(remoteConn, tlsConfig), nil
} else {
return remoteConn, nil
}
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
defer httpClient.CloseIdleConnections()
requestCtx, cancel := context.WithCancel(ctx)
defer cancel()
response, err := httpClient.Do(request.WithContext(requestCtx))
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err)
}
var (
responseScript adapter.HTTPResponseScript
responseMatch bool
)
for _, script := range e.script.Scripts() {
if !common.Contains(options.Script, script.Tag()) {
continue
}
httpScript, isHTTP := script.(adapter.HTTPResponseScript)
if !isHTTP {
_, isHTTP = script.(adapter.HTTPScript)
if !isHTTP {
e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script")
}
continue
}
if !httpScript.Match(requestURL) {
continue
}
e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]")
responseScript = httpScript
responseMatch = true
break
}
if responseScript != nil {
var body []byte
if responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
body, err = io.ReadAll(response.Body)
if err != nil {
return E.Cause(err, "read HTTP response body")
}
response.Body = io.NopCloser(bytes.NewReader(body))
}
var result *adapter.HTTPResponseScriptResult
result, err = responseScript.Run(ctx, request, response, body)
if err != nil {
return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]")
}
if result.Status > 0 {
response.Status = http.StatusText(result.Status)
response.StatusCode = result.Status
}
for key, values := range result.Headers {
response.Header[key] = values
}
if result.Body != nil {
response.Body.Close()
response.Body = io.NopCloser(bytes.NewReader(result.Body))
response.ContentLength = int64(len(result.Body))
}
}
if !responseMatch {
for i, rule := range options.SurgeHeaderRewrite {
if !rule.Response {
continue
}
if !rule.Pattern.MatchString(requestURL) {
continue
}
responseMatch = true
e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String())
switch {
case rule.Add:
response.Header.Add(rule.Key, rule.Value)
case rule.Delete:
response.Header.Del(rule.Key)
case rule.Replace:
if response.Header.Get(rule.Key) != "" {
response.Header.Set(rule.Key, rule.Value)
}
case rule.ReplaceRegex:
if value := response.Header.Get(rule.Key); value != "" {
response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value))
}
}
}
for i, rule := range options.SurgeBodyRewrite {
if !rule.Response {
continue
}
if !rule.Pattern.MatchString(requestURL) {
continue
}
responseMatch = true
e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String())
var body []byte
if response.ContentLength <= 0 {
e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length")
break
} else if response.ContentLength > 131072 {
e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength)
break
}
body, err = io.ReadAll(response.Body)
if err != nil {
return E.Cause(err, "read HTTP request body")
}
for mi := 0; i < len(rule.Match); i++ {
body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i]))
}
response.Body = io.NopCloser(bytes.NewReader(body))
response.ContentLength = int64(len(body))
}
}
if !requestMatch && !responseMatch {
e.logger.WarnContext(ctx, "request not modified")
}
err = response.Write(conn)
if err != nil {
return E.Errors(E.Cause(err, "write HTTP response"), innerErr.Load())
} else if innerErr.Load() != nil {
return E.Cause(innerErr.Load(), "write HTTP response")
}
return nil
}
func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn *tls.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
// TODO: implement http2 support
return nil
}
type simpleResponseWriter struct {
statusCode int
header http.Header
body bytes.Buffer
}
func (w *simpleResponseWriter) Build(request *http.Request) *http.Response {
return &http.Response{
StatusCode: w.statusCode,
Status: http.StatusText(w.statusCode),
Proto: request.Proto,
ProtoMajor: request.ProtoMajor,
ProtoMinor: request.ProtoMinor,
Header: w.header,
Body: io.NopCloser(&w.body),
}
}
func (w *simpleResponseWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *simpleResponseWriter) Write(b []byte) (int, error) {
return w.body.Write(b)
}
func (w *simpleResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
}

26
option/mitm.go Normal file
View file

@ -0,0 +1,26 @@
package option
import (
"github.com/sagernet/sing/common/json/badoption"
)
type MITMOptions struct {
Enabled bool `json:"enabled,omitempty"`
// HTTP2Enabled bool `json:"http2_enabled,omitempty"`
TLSDecryptionOptions *TLSDecryptionOptions `json:"tls_decryption,omitempty"`
}
type TLSDecryptionOptions struct {
Enabled bool `json:"enabled,omitempty"`
KeyPair string `json:"key_pair_p12,omitempty"`
KeyPassword string `json:"key_password,omitempty"`
}
type MITMRouteOptions struct {
Enabled bool `json:"enabled,omitempty"`
Script badoption.Listable[string] `json:"script,omitempty"`
SurgeURLRewrite badoption.Listable[SurgeURLRewriteLine] `json:"sg_url_rewrite,omitempty"`
SurgeHeaderRewrite badoption.Listable[SurgeHeaderRewriteLine] `json:"sg_header_rewrite,omitempty"`
SurgeBodyRewrite badoption.Listable[SurgeBodyRewriteLine] `json:"sg_body_rewrite,omitempty"`
SurgeMapLocal badoption.Listable[SurgeMapLocalLine] `json:"sg_map_local,omitempty"`
}

View file

@ -0,0 +1,444 @@
package option
import (
"encoding/base64"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"unicode"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/json"
)
type SurgeURLRewriteLine struct {
Pattern *regexp.Regexp
Destination *url.URL
Redirect bool
Reject bool
}
func (l SurgeURLRewriteLine) String() string {
var fields []string
fields = append(fields, l.Pattern.String())
if l.Reject {
fields = append(fields, "_")
} else {
fields = append(fields, l.Destination.String())
}
switch {
case l.Redirect:
fields = append(fields, "302")
case l.Reject:
fields = append(fields, "reject")
default:
fields = append(fields, "header")
}
return encodeSurgeKeys(fields)
}
func (l SurgeURLRewriteLine) MarshalJSON() ([]byte, error) {
return json.Marshal(l.String())
}
func (l *SurgeURLRewriteLine) UnmarshalJSON(bytes []byte) error {
var stringValue string
err := json.Unmarshal(bytes, &stringValue)
if err != nil {
return err
}
fields, err := surgeFields(stringValue)
if err != nil {
return E.Cause(err, "invalid surge_url_rewrite line: ", stringValue)
} else if len(fields) < 2 || len(fields) > 3 {
return E.New("invalid surge_url_rewrite line: ", stringValue)
}
pattern, err := regexp.Compile(fields[0].Key)
if err != nil {
return E.Cause(err, "invalid surge_url_rewrite line: invalid pattern: ", stringValue)
}
l.Pattern = pattern
l.Destination, err = url.Parse(fields[1].Key)
if err != nil {
return E.Cause(err, "invalid surge_url_rewrite line: invalid destination: ", stringValue)
}
if len(fields) == 3 {
switch fields[2].Key {
case "header":
case "302":
l.Redirect = true
case "reject":
l.Reject = true
default:
return E.New("invalid surge_url_rewrite line: invalid action: ", stringValue)
}
}
return nil
}
type SurgeHeaderRewriteLine struct {
Response bool
Pattern *regexp.Regexp
Add bool
Delete bool
Replace bool
ReplaceRegex bool
Key string
Match *regexp.Regexp
Value string
}
func (l SurgeHeaderRewriteLine) String() string {
var fields []string
if !l.Response {
fields = append(fields, "http-request")
} else {
fields = append(fields, "http-response")
}
fields = append(fields, l.Pattern.String())
if l.Add {
fields = append(fields, "header-add")
} else if l.Delete {
fields = append(fields, "header-del")
} else if l.Replace {
fields = append(fields, "header-replace")
} else if l.ReplaceRegex {
fields = append(fields, "header-replace-regex")
}
fields = append(fields, l.Key)
if l.Add || l.Replace {
fields = append(fields, l.Value)
} else if l.ReplaceRegex {
fields = append(fields, l.Match.String(), l.Value)
}
return encodeSurgeKeys(fields)
}
func (l SurgeHeaderRewriteLine) MarshalJSON() ([]byte, error) {
return json.Marshal(l.String())
}
func (l *SurgeHeaderRewriteLine) UnmarshalJSON(bytes []byte) error {
var stringValue string
err := json.Unmarshal(bytes, &stringValue)
if err != nil {
return err
}
fields, err := surgeFields(stringValue)
if err != nil {
return E.Cause(err, "invalid surge_header_rewrite line: ", stringValue)
} else if len(fields) < 4 {
return E.New("invalid surge_header_rewrite line: ", stringValue)
}
switch fields[0].Key {
case "http-request":
case "http-response":
l.Response = true
default:
return E.New("invalid surge_header_rewrite line: invalid type: ", stringValue)
}
l.Pattern, err = regexp.Compile(fields[1].Key)
if err != nil {
return E.Cause(err, "invalid surge_header_rewrite line: invalid pattern: ", stringValue)
}
switch fields[2].Key {
case "header-add":
l.Add = true
if len(fields) != 5 {
return E.New("invalid surge_header_rewrite line: " + stringValue)
}
l.Key = fields[3].Key
l.Value = fields[4].Key
case "header-del":
l.Delete = true
l.Key = fields[3].Key
case "header-replace":
l.Replace = true
if len(fields) != 5 {
return E.New("invalid surge_header_rewrite line: " + stringValue)
}
l.Key = fields[3].Key
l.Value = fields[4].Key
case "header-replace-regex":
l.ReplaceRegex = true
if len(fields) != 6 {
return E.New("invalid surge_header_rewrite line: " + stringValue)
}
l.Key = fields[3].Key
l.Match, err = regexp.Compile(fields[4].Key)
if err != nil {
return E.Cause(err, "invalid surge_header_rewrite line: invalid match: ", stringValue)
}
l.Value = fields[5].Key
default:
return E.New("invalid surge_header_rewrite line: invalid action: ", stringValue)
}
return nil
}
type SurgeBodyRewriteLine struct {
Response bool
Pattern *regexp.Regexp
Match []*regexp.Regexp
Replace []string
}
func (l SurgeBodyRewriteLine) String() string {
var fields []string
if !l.Response {
fields = append(fields, "http-request")
} else {
fields = append(fields, "http-response")
}
for i := 0; i < len(l.Match); i += 2 {
fields = append(fields, l.Match[i].String(), l.Replace[i])
}
return strings.Join(fields, " ")
}
func (l SurgeBodyRewriteLine) MarshalJSON() ([]byte, error) {
return json.Marshal(l.String())
}
func (l *SurgeBodyRewriteLine) UnmarshalJSON(bytes []byte) error {
var stringValue string
err := json.Unmarshal(bytes, &stringValue)
if err != nil {
return err
}
fields, err := surgeFields(stringValue)
if err != nil {
return E.Cause(err, "invalid surge_body_rewrite line: ", stringValue)
} else if len(fields) < 4 {
return E.New("invalid surge_body_rewrite line: ", stringValue)
} else if len(fields)%2 != 0 {
return E.New("invalid surge_body_rewrite line: ", stringValue)
}
switch fields[0].Key {
case "http-request":
case "http-response":
l.Response = true
default:
return E.New("invalid surge_body_rewrite line: invalid type: ", stringValue)
}
l.Pattern, err = regexp.Compile(fields[1].Key)
for i := 2; i < len(fields); i += 2 {
var match *regexp.Regexp
match, err = regexp.Compile(fields[i].Key)
if err != nil {
return E.Cause(err, "invalid surge_body_rewrite line: invalid match: ", stringValue)
}
l.Match = append(l.Match, match)
l.Replace = append(l.Replace, fields[i+1].Key)
}
return nil
}
type SurgeMapLocalLine struct {
Pattern *regexp.Regexp
StatusCode int
File bool
Text bool
TinyGif bool
Base64 bool
Data string
Base64Data []byte
Headers http.Header
}
func (l SurgeMapLocalLine) String() string {
var fields []surgeField
fields = append(fields, surgeField{Key: l.Pattern.String()})
if l.File {
fields = append(fields, surgeField{Key: "data-type", Value: "file"})
fields = append(fields, surgeField{Key: "data", Value: l.Data})
} else if l.Text {
fields = append(fields, surgeField{Key: "data-type", Value: "text"})
fields = append(fields, surgeField{Key: "data", Value: l.Data})
} else if l.TinyGif {
fields = append(fields, surgeField{Key: "data-type", Value: "tiny-gif"})
} else if l.Base64 {
fields = append(fields, surgeField{Key: "data-type", Value: "base64"})
fields = append(fields, surgeField{Key: "data-type", Value: base64.StdEncoding.EncodeToString(l.Base64Data)})
}
fields = append(fields, surgeField{Key: "status-code", Value: F.ToString(l.StatusCode), ValueSet: true})
if len(l.Headers) > 0 {
var headers []string
for key, values := range l.Headers {
for _, value := range values {
headers = append(headers, key+":"+value)
}
}
fields = append(fields, surgeField{Key: "headers", Value: strings.Join(headers, "|")})
}
return encodeSurgeFields(fields)
}
func (l SurgeMapLocalLine) MarshalJSON() ([]byte, error) {
return json.Marshal(l.String())
}
func (l *SurgeMapLocalLine) UnmarshalJSON(bytes []byte) error {
var stringValue string
err := json.Unmarshal(bytes, &stringValue)
if err != nil {
return err
}
fields, err := surgeFields(stringValue)
if err != nil {
return E.Cause(err, "invalid surge_map_local line: ", stringValue)
} else if len(fields) < 1 {
return E.New("invalid surge_map_local line: ", stringValue)
}
l.Pattern, err = regexp.Compile(fields[0].Key)
if err != nil {
return E.Cause(err, "invalid surge_map_local line: invalid pattern: ", stringValue)
}
dataTypeField := common.Find(fields, func(it surgeField) bool {
return it.Key == "data-type"
})
if !dataTypeField.ValueSet {
return E.New("invalid surge_map_local line: missing data-type: ", stringValue)
}
switch dataTypeField.Value {
case "file":
l.File = true
case "text":
l.Text = true
case "tiny-gif":
l.TinyGif = true
case "base64":
l.Base64 = true
}
for i := 1; i < len(fields); i++ {
switch fields[i].Key {
case "data-type":
continue
case "data":
if l.File {
l.Data = fields[i].Value
} else if l.Text {
l.Data = fields[i].Value
} else if l.Base64 {
l.Base64Data, err = base64.StdEncoding.DecodeString(fields[i].Value)
if err != nil {
return E.New("invalid surge_map_local line: invalid base64 data: ", stringValue)
}
}
case "status-code":
statusCode, err := strconv.ParseInt(fields[i].Value, 10, 16)
if err != nil {
return E.New("invalid surge_map_local line: invalid status code: ", stringValue)
}
l.StatusCode = int(statusCode)
case "headers":
headers := make(http.Header)
for _, headerLine := range strings.Split(fields[i].Value, "|") {
if !strings.Contains(headerLine, ":") {
return E.New("invalid surge_map_local line: headers: missing `:` in item: ", stringValue, ": ", headerLine)
}
headers.Add(common.SubstringBefore(headerLine, ":"), common.SubstringAfter(headerLine, ":"))
}
l.Headers = headers
default:
return E.New("invalid surge_map_local line: unknown options: ", stringValue)
}
}
return nil
}
type surgeField struct {
Key string
Value string
ValueSet bool
}
func encodeSurgeKeys(keys []string) string {
keys = common.Map(keys, func(it string) string {
if strings.ContainsFunc(it, unicode.IsSpace) {
return "\"" + it + "\""
} else {
return it
}
})
return strings.Join(keys, " ")
}
func encodeSurgeFields(fields []surgeField) string {
return strings.Join(common.Map(fields, func(it surgeField) string {
if !it.ValueSet {
if strings.ContainsFunc(it.Key, unicode.IsSpace) {
return "\"" + it.Key + "\""
} else {
return it.Key
}
} else {
if strings.ContainsFunc(it.Value, unicode.IsSpace) {
return it.Key + "=\"" + it.Value + "\""
} else {
return it.Key + "=" + it.Value
}
}
}), " ")
}
func surgeFields(s string) ([]surgeField, error) {
var (
fields []surgeField
currentField *surgeField
)
for _, field := range strings.Fields(s) {
if currentField != nil {
field = " " + field
if strings.HasSuffix(field, "\"") {
field = field[:len(field)-1]
if !currentField.ValueSet {
currentField.Key += field
} else {
currentField.Value += field
}
fields = append(fields, *currentField)
currentField = nil
} else {
if !currentField.ValueSet {
currentField.Key += field
} else {
currentField.Value += " " + field
}
}
}
if !strings.Contains(field, "=") {
if strings.HasPrefix(field, "\"") {
field = field[1:]
if strings.HasSuffix(field, "\"") {
field = field[:len(field)-1]
} else {
currentField = &surgeField{Key: field}
continue
}
}
fields = append(fields, surgeField{Key: field})
} else {
key := common.SubstringBefore(field, "=")
value := common.SubstringAfter(field, "=")
if strings.HasPrefix(value, "\"") {
value = value[1:]
if strings.HasSuffix(field, "\"") {
value = value[:len(value)-1]
} else {
currentField = &surgeField{Key: key, Value: field, ValueSet: true}
continue
}
}
fields = append(fields, surgeField{Key: key, Value: value, ValueSet: true})
}
}
if currentField != nil {
return nil, E.New("invalid surge fields line: ", s)
}
return fields, nil
}

View file

@ -12,13 +12,15 @@ type _Options struct {
Schema string `json:"$schema,omitempty"`
Log *LogOptions `json:"log,omitempty"`
DNS *DNSOptions `json:"dns,omitempty"`
NTP *NTPOptions `json:"ntp,omitempty"`
Certificate *CertificateOptions `json:"certificate,omitempty"`
Endpoints []Endpoint `json:"endpoints,omitempty"`
Inbounds []Inbound `json:"inbounds,omitempty"`
Outbounds []Outbound `json:"outbounds,omitempty"`
Route *RouteOptions `json:"route,omitempty"`
Experimental *ExperimentalOptions `json:"experimental,omitempty"`
NTP *NTPOptions `json:"ntp,omitempty"`
Certificate *CertificateOptions `json:"certificate,omitempty"`
MITM *MITMOptions `json:"mitm,omitempty"`
Scripts []Script `json:"scripts,omitempty"`
}
type Options _Options

View file

@ -153,6 +153,8 @@ type RawRouteOptionsActionOptions struct {
TLSFragment bool `json:"tls_fragment,omitempty"`
TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"`
MITM *MITMRouteOptions `json:"mitm,omitempty"`
}
type RouteOptionsActionOptions RawRouteOptionsActionOptions

138
option/script.go Normal file
View file

@ -0,0 +1,138 @@
package option
import (
C "github.com/sagernet/sing-box/constant"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
"github.com/sagernet/sing/common/json/badoption"
)
type _ScriptSourceOptions struct {
Source string `json:"source"`
LocalOptions LocalScriptSource `json:"-"`
RemoteOptions RemoteScriptSource `json:"-"`
}
type LocalScriptSource struct {
Path string `json:"path"`
}
type RemoteScriptSource struct {
URL string `json:"url"`
DownloadDetour string `json:"download_detour,omitempty"`
UpdateInterval badoption.Duration `json:"update_interval,omitempty"`
}
type ScriptSourceOptions _ScriptSourceOptions
func (o ScriptSourceOptions) MarshalJSON() ([]byte, error) {
var source any
switch o.Source {
case C.ScriptSourceLocal:
source = o.LocalOptions
case C.ScriptSourceRemote:
source = o.RemoteOptions
default:
return nil, E.New("unknown script source: ", o.Source)
}
return badjson.MarshallObjects((_ScriptSourceOptions)(o), source)
}
func (o *ScriptSourceOptions) UnmarshalJSON(bytes []byte) error {
err := json.Unmarshal(bytes, (*_ScriptSourceOptions)(o))
if err != nil {
return err
}
var source any
switch o.Source {
case C.ScriptSourceLocal:
source = &o.LocalOptions
case C.ScriptSourceRemote:
source = &o.RemoteOptions
default:
return E.New("unknown script source: ", o.Source)
}
return json.Unmarshal(bytes, source)
}
// TODO: make struct in order
type Script struct {
ScriptSourceOptions
ScriptOptions
}
func (s Script) MarshalJSON() ([]byte, error) {
return badjson.MarshallObjects(s.ScriptSourceOptions, s.ScriptOptions)
}
func (s *Script) UnmarshalJSON(bytes []byte) error {
err := json.Unmarshal(bytes, &s.ScriptSourceOptions)
if err != nil {
return err
}
return badjson.UnmarshallExcluded(bytes, &s.ScriptSourceOptions, &s.ScriptOptions)
}
type _ScriptOptions struct {
Type string `json:"type"`
Tag string `json:"tag"`
Timeout badoption.Duration `json:"timeout,omitempty"`
Arguments []any `json:"arguments,omitempty"`
HTTPOptions HTTPScriptOptions `json:"-"`
CronOptions CronScriptOptions `json:"-"`
}
type ScriptOptions _ScriptOptions
func (o ScriptOptions) MarshalJSON() ([]byte, error) {
var v any
switch o.Type {
case C.ScriptTypeSurgeGeneric:
v = nil
case C.ScriptTypeSurgeHTTPRequest, C.ScriptTypeSurgeHTTPResponse:
v = o.HTTPOptions
case C.ScriptTypeSurgeCron:
v = o.CronOptions
default:
return nil, E.New("unknown script type: ", o.Type)
}
if v == nil {
return badjson.MarshallObjects((_ScriptOptions)(o))
}
return badjson.MarshallObjects((_ScriptOptions)(o), v)
}
func (o *ScriptOptions) UnmarshalJSON(bytes []byte) error {
err := json.Unmarshal(bytes, (*_ScriptOptions)(o))
if err != nil {
return err
}
var v any
switch o.Type {
case C.ScriptTypeSurgeGeneric:
v = nil
case C.ScriptTypeSurgeHTTPRequest, C.ScriptTypeSurgeHTTPResponse:
v = &o.HTTPOptions
case C.ScriptTypeSurgeCron:
v = &o.CronOptions
default:
return E.New("unknown script type: ", o.Type)
}
if v == nil {
// check unknown fields
return json.UnmarshalDisallowUnknownFields(bytes, &_ScriptOptions{})
}
return badjson.UnmarshallExcluded(bytes, (*_ScriptOptions)(o), v)
}
type HTTPScriptOptions struct {
Pattern string `json:"pattern"`
RequiresBody bool `json:"requires_body,omitempty"`
MaxSize int64 `json:"max_size,omitempty"`
BinaryBodyMode bool `json:"binary_body_mode,omitempty"`
}
type CronScriptOptions struct {
Expression string `json:"expression"`
}

View file

@ -21,23 +21,31 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
)
var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
type ConnectionManager struct {
ctx context.Context
logger logger.ContextLogger
mitm adapter.MITMEngine
access sync.Mutex
connections list.List[io.Closer]
}
func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
func NewConnectionManager(ctx context.Context, logger logger.ContextLogger) *ConnectionManager {
return &ConnectionManager{
ctx: ctx,
logger: logger,
}
}
func (m *ConnectionManager) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateInitialize:
m.mitm = service.FromContext[adapter.MITMEngine](m.ctx)
}
return nil
}
@ -52,6 +60,14 @@ func (m *ConnectionManager) Close() error {
}
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if metadata.MITM != nil && metadata.MITM.Enabled {
if m.mitm == nil {
m.logger.WarnContext(ctx, "MITM disabled")
} else {
m.mitm.NewConnection(ctx, this, conn, metadata, onClose)
return
}
}
ctx = adapter.WithContext(ctx, &metadata)
var (
remoteConn net.Conn

View file

@ -458,6 +458,9 @@ match:
metadata.TLSFragment = true
metadata.TLSFragmentFallbackDelay = routeOptions.TLSFragmentFallbackDelay
}
if routeOptions.MITM != nil && routeOptions.MITM.Enabled {
metadata.MITM = routeOptions.MITM
}
}
switch action := currentRule.Action().(type) {
case *rule.RuleActionSniff:

View file

@ -38,6 +38,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPConnect: action.RouteOptions.UDPConnect,
TLSFragment: action.RouteOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay),
MITM: action.RouteOptions.MITM,
},
}, nil
case C.RuleActionTypeRouteOptions:
@ -51,6 +52,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout),
TLSFragment: action.RouteOptionsOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay),
MITM: action.RouteOptionsOptions.MITM,
}, nil
case C.RuleActionTypeDirect:
directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false)
@ -140,15 +142,7 @@ func (r *RuleActionRoute) Type() string {
func (r *RuleActionRoute) String() string {
var descriptions []string
descriptions = append(descriptions, r.Outbound)
if r.UDPDisableDomainUnmapping {
descriptions = append(descriptions, "udp-disable-domain-unmapping")
}
if r.UDPConnect {
descriptions = append(descriptions, "udp-connect")
}
if r.TLSFragment {
descriptions = append(descriptions, "tls-fragment")
}
descriptions = append(descriptions, r.Descriptions()...)
return F.ToString("route(", strings.Join(descriptions, ","), ")")
}
@ -164,14 +158,33 @@ type RuleActionRouteOptions struct {
UDPTimeout time.Duration
TLSFragment bool
TLSFragmentFallbackDelay time.Duration
MITM *option.MITMRouteOptions
}
func (r *RuleActionRouteOptions) Type() string {
return C.RuleActionTypeRouteOptions
}
func (r *RuleActionRouteOptions) String() string {
func (r *RuleActionRouteOptions) Descriptions() []string {
var descriptions []string
if r.OverrideAddress.IsValid() {
descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.String()))
}
if r.OverridePort > 0 {
descriptions = append(descriptions, F.ToString("override-port=", r.OverridePort))
}
if r.NetworkStrategy != nil {
descriptions = append(descriptions, F.ToString("network-strategy=", r.NetworkStrategy))
}
if r.NetworkType != nil {
descriptions = append(descriptions, F.ToString("network-type=", strings.Join(common.Map(r.NetworkType, C.InterfaceType.String), ",")))
}
if r.FallbackNetworkType != nil {
descriptions = append(descriptions, F.ToString("fallback-network-type="+strings.Join(common.Map(r.NetworkType, C.InterfaceType.String), ",")))
}
if r.FallbackDelay > 0 {
descriptions = append(descriptions, F.ToString("fallback-delay=", r.FallbackDelay.String()))
}
if r.UDPDisableDomainUnmapping {
descriptions = append(descriptions, "udp-disable-domain-unmapping")
}
@ -179,9 +192,22 @@ func (r *RuleActionRouteOptions) String() string {
descriptions = append(descriptions, "udp-connect")
}
if r.UDPTimeout > 0 {
descriptions = append(descriptions, "udp-timeout")
descriptions = append(descriptions, F.ToString("udp-timeout=", r.UDPTimeout))
}
return F.ToString("route-options(", strings.Join(descriptions, ","), ")")
if r.TLSFragment {
descriptions = append(descriptions, "tls-fragment")
if r.TLSFragmentFallbackDelay > 0 {
descriptions = append(descriptions, F.ToString("tls-fragment-fallbac-delay=", r.TLSFragmentFallbackDelay.String()))
}
}
if r.MITM != nil && r.MITM.Enabled {
descriptions = append(descriptions, "mitm")
}
return descriptions
}
func (r *RuleActionRouteOptions) String() string {
return F.ToString("route-options(", strings.Join(r.Descriptions(), ","), ")")
}
type RuleActionDNSRoute struct {

23
script/jsc/array.go Normal file
View file

@ -0,0 +1,23 @@
package jsc
import (
_ "unsafe"
"github.com/dop251/goja"
)
func NewUint8Array(runtime *goja.Runtime, data []byte) goja.Value {
buffer := runtime.NewArrayBuffer(data)
ctor, loaded := goja.AssertConstructor(runtimeGetUint8Array(runtime))
if !loaded {
panic(runtime.NewTypeError("missing UInt8Array constructor"))
}
array, err := ctor(nil, runtime.ToValue(buffer))
if err != nil {
panic(runtime.NewGoError(err))
}
return array
}
//go:linkname runtimeGetUint8Array github.com/dop251/goja.(*Runtime).getUint8Array
func runtimeGetUint8Array(r *goja.Runtime) *goja.Object

18
script/jsc/array_test.go Normal file
View file

@ -0,0 +1,18 @@
package jsc_test
import (
"testing"
"github.com/sagernet/sing-box/script/jsc"
"github.com/dop251/goja"
"github.com/stretchr/testify/require"
)
func TestNewUInt8Array(t *testing.T) {
runtime := goja.New()
runtime.Set("hello", jsc.NewUint8Array(runtime, []byte("world")))
result, err := runtime.RunString("hello instanceof Uint8Array")
require.NoError(t, err)
require.True(t, result.ToBoolean())
}

121
script/jsc/assert.go Normal file
View file

@ -0,0 +1,121 @@
package jsc
import (
"net/http"
F "github.com/sagernet/sing/common/format"
"github.com/dop251/goja"
)
func IsNil(value goja.Value) bool {
return value == nil || goja.IsUndefined(value) || goja.IsNull(value)
}
func AssertObject(vm *goja.Runtime, value goja.Value, name string, nilable bool) *goja.Object {
if IsNil(value) {
if nilable {
return nil
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
objectValue, isObject := value.(*goja.Object)
if !isObject {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected object, but got ", value)))
}
return objectValue
}
func AssertString(vm *goja.Runtime, value goja.Value, name string, nilable bool) string {
if IsNil(value) {
if nilable {
return ""
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
stringValue, isString := value.Export().(string)
if !isString {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected string, but got ", value)))
}
return stringValue
}
func AssertInt(vm *goja.Runtime, value goja.Value, name string, nilable bool) int64 {
if IsNil(value) {
if nilable {
return 0
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
integerValue, isNumber := value.Export().(int64)
if !isNumber {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected integer, but got ", value)))
}
return integerValue
}
func AssertBool(vm *goja.Runtime, value goja.Value, name string, nilable bool) bool {
if IsNil(value) {
if nilable {
return false
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
boolValue, isBool := value.Export().(bool)
if !isBool {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected boolean, but got ", value)))
}
return boolValue
}
func AssertBinary(vm *goja.Runtime, value goja.Value, name string, nilable bool) []byte {
if IsNil(value) {
if nilable {
return nil
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
switch exportedValue := value.Export().(type) {
case []byte:
return exportedValue
case goja.ArrayBuffer:
return exportedValue.Bytes()
default:
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected Uint8Array or ArrayBuffer, but got ", value)))
}
}
func AssertStringBinary(vm *goja.Runtime, value goja.Value, name string, nilable bool) []byte {
if IsNil(value) {
if nilable {
return nil
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
switch exportedValue := value.Export().(type) {
case string:
return []byte(exportedValue)
case []byte:
return exportedValue
case goja.ArrayBuffer:
return exportedValue.Bytes()
default:
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected string, Uint8Array or ArrayBuffer, but got ", value)))
}
}
func AssertFunction(vm *goja.Runtime, value goja.Value, name string) goja.Callable {
functionValue, isFunction := goja.AssertFunction(value)
if !isFunction {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected function, but got ", value)))
}
return functionValue
}
func AssertHTTPHeader(vm *goja.Runtime, value goja.Value, name string) http.Header {
headersObject := AssertObject(vm, value, name, true)
if headersObject == nil {
return nil
}
return ObjectToHeaders(vm, headersObject, name)
}

56
script/jsc/headers.go Normal file
View file

@ -0,0 +1,56 @@
package jsc
import (
"net/http"
"reflect"
"github.com/sagernet/sing/common"
F "github.com/sagernet/sing/common/format"
"github.com/dop251/goja"
)
func HeadersToValue(runtime *goja.Runtime, headers http.Header) goja.Value {
object := runtime.NewObject()
for key, value := range headers {
if len(value) == 1 {
object.Set(key, value[0])
} else {
object.Set(key, ArrayToValue(runtime, value))
}
}
return object
}
func ArrayToValue[T any](runtime *goja.Runtime, values []T) goja.Value {
return runtime.NewArray(common.Map(values, func(it T) any { return it })...)
}
func ObjectToHeaders(vm *goja.Runtime, object *goja.Object, name string) http.Header {
headers := make(http.Header)
for _, key := range object.Keys() {
valueObject := object.Get(key)
switch headerValue := valueObject.(type) {
case goja.String:
headers.Set(key, headerValue.String())
case *goja.Object:
values := headerValue.Export()
valueArray, isArray := values.([]any)
if !isArray {
panic(vm.NewTypeError(F.ToString("invalid value: ", name, ".", key, "expected string or string array, got ", valueObject.String())))
}
newValues := make([]string, 0, len(valueArray))
for _, value := range valueArray {
stringValue, isString := value.(string)
if !isString {
panic(vm.NewTypeError(F.ToString("invalid value: ", name, ".", key, " expected string or string array, got array item type: ", reflect.TypeOf(value))))
}
newValues = append(newValues, stringValue)
}
headers[key] = newValues
default:
panic(vm.NewTypeError(F.ToString("invalid value: ", name, ".", key, " expected string or string array, got ", valueObject.String())))
}
}
return headers
}

View file

@ -0,0 +1,31 @@
package jsc_test
import (
"fmt"
"net/http"
"reflect"
"testing"
"github.com/sagernet/sing-box/script/jsc"
"github.com/dop251/goja"
"github.com/stretchr/testify/require"
)
func TestHeaders(t *testing.T) {
runtime := goja.New()
runtime.Set("headers", jsc.HeadersToValue(runtime, http.Header{
"My-Header": []string{"My-Value1", "My-Value2"},
}))
headers := runtime.Get("headers").(*goja.Object).Get("My-Header").(*goja.Object)
fmt.Println(reflect.ValueOf(headers.Export()).Type().String())
}
func TestBody(t *testing.T) {
runtime := goja.New()
_, err := runtime.RunString(`
var responseBody = new Uint8Array([1, 2, 3, 4, 5])
`)
require.NoError(t, err)
fmt.Println(reflect.TypeOf(runtime.Get("responseBody").Export()))
}

18
script/jsc/time.go Normal file
View file

@ -0,0 +1,18 @@
package jsc
import (
"time"
_ "unsafe"
"github.com/dop251/goja"
)
func TimeToValue(runtime *goja.Runtime, time time.Time) goja.Value {
return runtimeNewDateObject(runtime, time, true, runtimeGetDatePrototype(runtime))
}
//go:linkname runtimeNewDateObject github.com/dop251/goja.(*Runtime).newDateObject
func runtimeNewDateObject(r *goja.Runtime, t time.Time, isSet bool, proto *goja.Object) *goja.Object
//go:linkname runtimeGetDatePrototype github.com/dop251/goja.(*Runtime).getDatePrototype
func runtimeGetDatePrototype(r *goja.Runtime) *goja.Object

20
script/jsc/time_test.go Normal file
View file

@ -0,0 +1,20 @@
package jsc_test
import (
"testing"
"time"
"github.com/sagernet/sing-box/script/jsc"
"github.com/dop251/goja"
"github.com/stretchr/testify/require"
)
func TestTimeToValue(t *testing.T) {
t.Parallel()
runtime := goja.New()
now := time.Now()
err := runtime.Set("now", jsc.TimeToValue(runtime, now))
require.NoError(t, err)
println(runtime.Get("now").String())
}

107
script/manager.go Normal file
View file

@ -0,0 +1,107 @@
package script
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/task"
)
var _ adapter.ScriptManager = (*Manager)(nil)
type Manager struct {
ctx context.Context
logger logger.ContextLogger
scripts []adapter.Script
// scriptByName map[string]adapter.Script
}
func NewManager(ctx context.Context, logFactory log.Factory, scripts []option.Script) (*Manager, error) {
manager := &Manager{
ctx: ctx,
logger: logFactory.NewLogger("script"),
// scriptByName: make(map[string]adapter.Script),
}
for _, scriptOptions := range scripts {
script, err := NewScript(ctx, logFactory.NewLogger(F.ToString("script/", scriptOptions.Type, "[", scriptOptions.Tag, "]")), scriptOptions)
if err != nil {
return nil, E.Cause(err, "initialize script: ", scriptOptions.Tag)
}
manager.scripts = append(manager.scripts, script)
// manager.scriptByName[scriptOptions.Tag] = script
}
return manager, nil
}
func (m *Manager) Start(stage adapter.StartStage) error {
monitor := taskmonitor.New(m.logger, C.StartTimeout)
switch stage {
case adapter.StartStateStart:
var cacheContext *adapter.HTTPStartContext
if len(m.scripts) > 0 {
monitor.Start("initialize rule-set")
cacheContext = adapter.NewHTTPStartContext(m.ctx)
var scriptStartGroup task.Group
for _, script := range m.scripts {
scriptInPlace := script
scriptStartGroup.Append0(func(ctx context.Context) error {
err := scriptInPlace.StartContext(ctx, cacheContext)
if err != nil {
return E.Cause(err, "initialize script/", scriptInPlace.Type(), "[", scriptInPlace.Tag(), "]")
}
return nil
})
}
scriptStartGroup.Concurrency(5)
scriptStartGroup.FastFail()
err := scriptStartGroup.Run(m.ctx)
monitor.Finish()
if err != nil {
return err
}
}
if cacheContext != nil {
cacheContext.Close()
}
case adapter.StartStatePostStart:
for _, script := range m.scripts {
monitor.Start(F.ToString("post start script/", script.Type(), "[", script.Tag(), "]"))
err := script.PostStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "post start script/", script.Type(), "[", script.Tag(), "]")
}
}
}
return nil
}
func (m *Manager) Close() error {
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, script := range m.scripts {
monitor.Start(F.ToString("close start script/", script.Type(), "[", script.Tag(), "]"))
err = E.Append(err, script.Close(), func(err error) error {
return E.Cause(err, "close script/", script.Type(), "[", script.Tag(), "]")
})
monitor.Finish()
}
return err
}
func (m *Manager) Scripts() []adapter.Script {
return m.scripts
}
/*
func (m *Manager) Script(name string) (adapter.Script, bool) {
script, loaded := m.scriptByName[name]
return script, loaded
}*/

View file

@ -0,0 +1,108 @@
package console
import (
"bytes"
"context"
"github.com/sagernet/sing-box/script/modules/require"
"github.com/sagernet/sing/common/logger"
"github.com/dop251/goja"
)
const ModuleName = "console"
type Console struct {
vm *goja.Runtime
}
func (c *Console) log(ctx context.Context, p func(ctx context.Context, values ...any)) func(goja.FunctionCall) goja.Value {
return func(call goja.FunctionCall) goja.Value {
var buffer bytes.Buffer
var format string
if arg := call.Argument(0); !goja.IsUndefined(arg) {
format = arg.String()
}
var args []goja.Value
if len(call.Arguments) > 0 {
args = call.Arguments[1:]
}
c.Format(&buffer, format, args...)
p(ctx, buffer.String())
return nil
}
}
func (c *Console) Format(b *bytes.Buffer, f string, args ...goja.Value) {
pct := false
argNum := 0
for _, chr := range f {
if pct {
if argNum < len(args) {
if c.format(chr, args[argNum], b) {
argNum++
}
} else {
b.WriteByte('%')
b.WriteRune(chr)
}
pct = false
} else {
if chr == '%' {
pct = true
} else {
b.WriteRune(chr)
}
}
}
for _, arg := range args[argNum:] {
b.WriteByte(' ')
b.WriteString(arg.String())
}
}
func (c *Console) format(f rune, val goja.Value, w *bytes.Buffer) bool {
switch f {
case 's':
w.WriteString(val.String())
case 'd':
w.WriteString(val.ToNumber().String())
case 'j':
if json, ok := c.vm.Get("JSON").(*goja.Object); ok {
if stringify, ok := goja.AssertFunction(json.Get("stringify")); ok {
res, err := stringify(json, val)
if err != nil {
panic(err)
}
w.WriteString(res.String())
}
}
case '%':
w.WriteByte('%')
return false
default:
w.WriteByte('%')
w.WriteRune(f)
return false
}
return true
}
func Require(ctx context.Context, logger logger.ContextLogger) require.ModuleLoader {
return func(runtime *goja.Runtime, module *goja.Object) {
c := &Console{
vm: runtime,
}
o := module.Get("exports").(*goja.Object)
o.Set("log", c.log(ctx, logger.DebugContext))
o.Set("error", c.log(ctx, logger.ErrorContext))
o.Set("warn", c.log(ctx, logger.WarnContext))
o.Set("info", c.log(ctx, logger.InfoContext))
o.Set("debug", c.log(ctx, logger.DebugContext))
}
}
func Enable(runtime *goja.Runtime) {
runtime.Set("console", require.Require(runtime, ModuleName))
}

View file

@ -0,0 +1,489 @@
package eventloop
import (
"sync"
"sync/atomic"
"time"
"github.com/dop251/goja"
)
type job struct {
cancel func() bool
fn func()
idx int
cancelled bool
}
type Timer struct {
job
timer *time.Timer
}
type Interval struct {
job
ticker *time.Ticker
stopChan chan struct{}
}
type Immediate struct {
job
}
type EventLoop struct {
vm *goja.Runtime
jobChan chan func()
jobs []*job
jobCount int32
canRun int32
auxJobsLock sync.Mutex
wakeupChan chan struct{}
auxJobsSpare, auxJobs []func()
stopLock sync.Mutex
stopCond *sync.Cond
running bool
terminated bool
errorHandler func(error)
}
func Enable(runtime *goja.Runtime, errorHandler func(error)) *EventLoop {
loop := &EventLoop{
vm: runtime,
jobChan: make(chan func()),
wakeupChan: make(chan struct{}, 1),
errorHandler: errorHandler,
}
loop.stopCond = sync.NewCond(&loop.stopLock)
runtime.Set("setTimeout", loop.setTimeout)
runtime.Set("setInterval", loop.setInterval)
runtime.Set("setImmediate", loop.setImmediate)
runtime.Set("clearTimeout", loop.clearTimeout)
runtime.Set("clearInterval", loop.clearInterval)
runtime.Set("clearImmediate", loop.clearImmediate)
return loop
}
func (loop *EventLoop) schedule(call goja.FunctionCall, repeating bool) goja.Value {
if fn, ok := goja.AssertFunction(call.Argument(0)); ok {
delay := call.Argument(1).ToInteger()
var args []goja.Value
if len(call.Arguments) > 2 {
args = append(args, call.Arguments[2:]...)
}
f := func() {
_, err := fn(nil, args...)
if err != nil {
loop.errorHandler(err)
}
}
loop.jobCount++
var job *job
var ret goja.Value
if repeating {
interval := loop.newInterval(f)
interval.start(loop, time.Duration(delay)*time.Millisecond)
job = &interval.job
ret = loop.vm.ToValue(interval)
} else {
timeout := loop.newTimeout(f)
timeout.start(loop, time.Duration(delay)*time.Millisecond)
job = &timeout.job
ret = loop.vm.ToValue(timeout)
}
job.idx = len(loop.jobs)
loop.jobs = append(loop.jobs, job)
return ret
}
return nil
}
func (loop *EventLoop) setTimeout(call goja.FunctionCall) goja.Value {
return loop.schedule(call, false)
}
func (loop *EventLoop) setInterval(call goja.FunctionCall) goja.Value {
return loop.schedule(call, true)
}
func (loop *EventLoop) setImmediate(call goja.FunctionCall) goja.Value {
if fn, ok := goja.AssertFunction(call.Argument(0)); ok {
var args []goja.Value
if len(call.Arguments) > 1 {
args = append(args, call.Arguments[1:]...)
}
f := func() {
_, err := fn(nil, args...)
if err != nil {
loop.errorHandler(err)
}
}
loop.jobCount++
return loop.vm.ToValue(loop.addImmediate(f))
}
return nil
}
// SetTimeout schedules to run the specified function in the context
// of the loop as soon as possible after the specified timeout period.
// SetTimeout returns a Timer which can be passed to ClearTimeout.
// The instance of goja.Runtime that is passed to the function and any Values derived
// from it must not be used outside the function. SetTimeout is
// safe to call inside or outside the loop.
// If the loop is terminated (see Terminate()) returns nil.
func (loop *EventLoop) SetTimeout(fn func(*goja.Runtime), timeout time.Duration) *Timer {
t := loop.newTimeout(func() { fn(loop.vm) })
if loop.addAuxJob(func() {
t.start(loop, timeout)
loop.jobCount++
t.idx = len(loop.jobs)
loop.jobs = append(loop.jobs, &t.job)
}) {
return t
}
return nil
}
// ClearTimeout cancels a Timer returned by SetTimeout if it has not run yet.
// ClearTimeout is safe to call inside or outside the loop.
func (loop *EventLoop) ClearTimeout(t *Timer) {
loop.addAuxJob(func() {
loop.clearTimeout(t)
})
}
// SetInterval schedules to repeatedly run the specified function in
// the context of the loop as soon as possible after every specified
// timeout period. SetInterval returns an Interval which can be
// passed to ClearInterval. The instance of goja.Runtime that is passed to the
// function and any Values derived from it must not be used outside
// the function. SetInterval is safe to call inside or outside the
// loop.
// If the loop is terminated (see Terminate()) returns nil.
func (loop *EventLoop) SetInterval(fn func(*goja.Runtime), timeout time.Duration) *Interval {
i := loop.newInterval(func() { fn(loop.vm) })
if loop.addAuxJob(func() {
i.start(loop, timeout)
loop.jobCount++
i.idx = len(loop.jobs)
loop.jobs = append(loop.jobs, &i.job)
}) {
return i
}
return nil
}
// ClearInterval cancels an Interval returned by SetInterval.
// ClearInterval is safe to call inside or outside the loop.
func (loop *EventLoop) ClearInterval(i *Interval) {
loop.addAuxJob(func() {
loop.clearInterval(i)
})
}
func (loop *EventLoop) setRunning() {
loop.stopLock.Lock()
defer loop.stopLock.Unlock()
if loop.running {
panic("Loop is already started")
}
loop.running = true
atomic.StoreInt32(&loop.canRun, 1)
loop.auxJobsLock.Lock()
loop.terminated = false
loop.auxJobsLock.Unlock()
}
// Run calls the specified function, starts the event loop and waits until there are no more delayed jobs to run
// after which it stops the loop and returns.
// The instance of goja.Runtime that is passed to the function and any Values derived from it must not be used
// outside the function.
// Do NOT use this function while the loop is already running. Use RunOnLoop() instead.
// If the loop is already started it will panic.
func (loop *EventLoop) Run(fn func(*goja.Runtime)) {
loop.setRunning()
fn(loop.vm)
loop.run(false)
}
// Start the event loop in the background. The loop continues to run until Stop() is called.
// If the loop is already started it will panic.
func (loop *EventLoop) Start() {
loop.setRunning()
go loop.run(true)
}
// StartInForeground starts the event loop in the current goroutine. The loop continues to run until Stop() is called.
// If the loop is already started it will panic.
// Use this instead of Start if you want to recover from panics that may occur while calling native Go functions from
// within setInterval and setTimeout callbacks.
func (loop *EventLoop) StartInForeground() {
loop.setRunning()
loop.run(true)
}
// Stop the loop that was started with Start(). After this function returns there will be no more jobs executed
// by the loop. It is possible to call Start() or Run() again after this to resume the execution.
// Note, it does not cancel active timeouts (use Terminate() instead if you want this).
// It is not allowed to run Start() (or Run()) and Stop() or Terminate() concurrently.
// Calling Stop() on a non-running loop has no effect.
// It is not allowed to call Stop() from the loop, because it is synchronous and cannot complete until the loop
// is not running any jobs. Use StopNoWait() instead.
// return number of jobs remaining
func (loop *EventLoop) Stop() int {
loop.stopLock.Lock()
for loop.running {
atomic.StoreInt32(&loop.canRun, 0)
loop.wakeup()
loop.stopCond.Wait()
}
loop.stopLock.Unlock()
return int(loop.jobCount)
}
// StopNoWait tells the loop to stop and returns immediately. Can be used inside the loop. Calling it on a
// non-running loop has no effect.
func (loop *EventLoop) StopNoWait() {
loop.stopLock.Lock()
if loop.running {
atomic.StoreInt32(&loop.canRun, 0)
loop.wakeup()
}
loop.stopLock.Unlock()
}
// Terminate stops the loop and clears all active timeouts and intervals. After it returns there are no
// active timers or goroutines associated with the loop. Any attempt to submit a task (by using RunOnLoop(),
// SetTimeout() or SetInterval()) will not succeed.
// After being terminated the loop can be restarted again by using Start() or Run().
// This method must not be called concurrently with Stop*(), Start(), or Run().
func (loop *EventLoop) Terminate() {
loop.Stop()
loop.auxJobsLock.Lock()
loop.terminated = true
loop.auxJobsLock.Unlock()
loop.runAux()
for i := 0; i < len(loop.jobs); i++ {
job := loop.jobs[i]
if !job.cancelled {
job.cancelled = true
if job.cancel() {
loop.removeJob(job)
i--
}
}
}
for len(loop.jobs) > 0 {
(<-loop.jobChan)()
}
}
// RunOnLoop schedules to run the specified function in the context of the loop as soon as possible.
// The order of the runs is preserved (i.e. the functions will be called in the same order as calls to RunOnLoop())
// The instance of goja.Runtime that is passed to the function and any Values derived from it must not be used
// outside the function. It is safe to call inside or outside the loop.
// Returns true on success or false if the loop is terminated (see Terminate()).
func (loop *EventLoop) RunOnLoop(fn func(*goja.Runtime)) bool {
return loop.addAuxJob(func() { fn(loop.vm) })
}
func (loop *EventLoop) runAux() {
loop.auxJobsLock.Lock()
jobs := loop.auxJobs
loop.auxJobs = loop.auxJobsSpare
loop.auxJobsLock.Unlock()
for i, job := range jobs {
job()
jobs[i] = nil
}
loop.auxJobsSpare = jobs[:0]
}
func (loop *EventLoop) run(inBackground bool) {
loop.runAux()
if inBackground {
loop.jobCount++
}
LOOP:
for loop.jobCount > 0 {
select {
case job := <-loop.jobChan:
job()
case <-loop.wakeupChan:
loop.runAux()
if atomic.LoadInt32(&loop.canRun) == 0 {
break LOOP
}
}
}
if inBackground {
loop.jobCount--
}
loop.stopLock.Lock()
loop.running = false
loop.stopLock.Unlock()
loop.stopCond.Broadcast()
}
func (loop *EventLoop) wakeup() {
select {
case loop.wakeupChan <- struct{}{}:
default:
}
}
func (loop *EventLoop) addAuxJob(fn func()) bool {
loop.auxJobsLock.Lock()
if loop.terminated {
loop.auxJobsLock.Unlock()
return false
}
loop.auxJobs = append(loop.auxJobs, fn)
loop.auxJobsLock.Unlock()
loop.wakeup()
return true
}
func (loop *EventLoop) newTimeout(f func()) *Timer {
t := &Timer{
job: job{fn: f},
}
t.cancel = t.doCancel
return t
}
func (t *Timer) start(loop *EventLoop, timeout time.Duration) {
t.timer = time.AfterFunc(timeout, func() {
loop.jobChan <- func() {
loop.doTimeout(t)
}
})
}
func (loop *EventLoop) newInterval(f func()) *Interval {
i := &Interval{
job: job{fn: f},
stopChan: make(chan struct{}),
}
i.cancel = i.doCancel
return i
}
func (i *Interval) start(loop *EventLoop, timeout time.Duration) {
// https://nodejs.org/api/timers.html#timers_setinterval_callback_delay_args
if timeout <= 0 {
timeout = time.Millisecond
}
i.ticker = time.NewTicker(timeout)
go i.run(loop)
}
func (loop *EventLoop) addImmediate(f func()) *Immediate {
i := &Immediate{
job: job{fn: f},
}
loop.addAuxJob(func() {
loop.doImmediate(i)
})
return i
}
func (loop *EventLoop) doTimeout(t *Timer) {
loop.removeJob(&t.job)
if !t.cancelled {
t.cancelled = true
loop.jobCount--
t.fn()
}
}
func (loop *EventLoop) doInterval(i *Interval) {
if !i.cancelled {
i.fn()
}
}
func (loop *EventLoop) doImmediate(i *Immediate) {
if !i.cancelled {
i.cancelled = true
loop.jobCount--
i.fn()
}
}
func (loop *EventLoop) clearTimeout(t *Timer) {
if t != nil && !t.cancelled {
t.cancelled = true
loop.jobCount--
if t.doCancel() {
loop.removeJob(&t.job)
}
}
}
func (loop *EventLoop) clearInterval(i *Interval) {
if i != nil && !i.cancelled {
i.cancelled = true
loop.jobCount--
i.doCancel()
}
}
func (loop *EventLoop) removeJob(job *job) {
idx := job.idx
if idx < 0 {
return
}
if idx < len(loop.jobs)-1 {
loop.jobs[idx] = loop.jobs[len(loop.jobs)-1]
loop.jobs[idx].idx = idx
}
loop.jobs[len(loop.jobs)-1] = nil
loop.jobs = loop.jobs[:len(loop.jobs)-1]
job.idx = -1
}
func (loop *EventLoop) clearImmediate(i *Immediate) {
if i != nil && !i.cancelled {
i.cancelled = true
loop.jobCount--
}
}
func (i *Interval) doCancel() bool {
close(i.stopChan)
return false
}
func (t *Timer) doCancel() bool {
return t.timer.Stop()
}
func (i *Interval) run(loop *EventLoop) {
L:
for {
select {
case <-i.stopChan:
i.ticker.Stop()
break L
case <-i.ticker.C:
loop.jobChan <- func() {
loop.doInterval(i)
}
}
}
loop.jobChan <- func() {
loop.removeJob(&i.job)
}
}

View file

@ -0,0 +1,231 @@
package require
import (
"errors"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"runtime"
"sync"
"syscall"
"text/template"
js "github.com/dop251/goja"
"github.com/dop251/goja/parser"
)
type ModuleLoader func(*js.Runtime, *js.Object)
// SourceLoader represents a function that returns a file data at a given path.
// The function should return ModuleFileDoesNotExistError if the file either doesn't exist or is a directory.
// This error will be ignored by the resolver and the search will continue. Any other errors will be propagated.
type SourceLoader func(path string) ([]byte, error)
var (
InvalidModuleError = errors.New("Invalid module")
IllegalModuleNameError = errors.New("Illegal module name")
NoSuchBuiltInModuleError = errors.New("No such built-in module")
ModuleFileDoesNotExistError = errors.New("module file does not exist")
)
// Registry contains a cache of compiled modules which can be used by multiple Runtimes
type Registry struct {
sync.Mutex
native map[string]ModuleLoader
builtin map[string]ModuleLoader
compiled map[string]*js.Program
srcLoader SourceLoader
globalFolders []string
fsEnabled bool
}
type RequireModule struct {
r *Registry
runtime *js.Runtime
modules map[string]*js.Object
nodeModules map[string]*js.Object
}
func NewRegistry(opts ...Option) *Registry {
r := &Registry{}
for _, opt := range opts {
opt(r)
}
return r
}
type Option func(*Registry)
// WithLoader sets a function which will be called by the require() function in order to get a source code for a
// module at the given path. The same function will be used to get external source maps.
// Note, this only affects the modules loaded by the require() function. If you need to use it as a source map
// loader for code parsed in a different way (such as runtime.RunString() or eval()), use (*Runtime).SetParserOptions()
func WithLoader(srcLoader SourceLoader) Option {
return func(r *Registry) {
r.srcLoader = srcLoader
}
}
// WithGlobalFolders appends the given paths to the registry's list of
// global folders to search if the requested module is not found
// elsewhere. By default, a registry's global folders list is empty.
// In the reference Node.js implementation, the default global folders
// list is $NODE_PATH, $HOME/.node_modules, $HOME/.node_libraries and
// $PREFIX/lib/node, see
// https://nodejs.org/api/modules.html#modules_loading_from_the_global_folders.
func WithGlobalFolders(globalFolders ...string) Option {
return func(r *Registry) {
r.globalFolders = globalFolders
}
}
func WithFsEnable(enabled bool) Option {
return func(r *Registry) {
r.fsEnabled = enabled
}
}
// Enable adds the require() function to the specified runtime.
func (r *Registry) Enable(runtime *js.Runtime) *RequireModule {
rrt := &RequireModule{
r: r,
runtime: runtime,
modules: make(map[string]*js.Object),
nodeModules: make(map[string]*js.Object),
}
runtime.Set("require", rrt.require)
return rrt
}
func (r *Registry) RegisterNodeModule(name string, loader ModuleLoader) {
r.Lock()
defer r.Unlock()
if r.builtin == nil {
r.builtin = make(map[string]ModuleLoader)
}
name = filepathClean(name)
r.builtin[name] = loader
}
func (r *Registry) RegisterNativeModule(name string, loader ModuleLoader) {
r.Lock()
defer r.Unlock()
if r.native == nil {
r.native = make(map[string]ModuleLoader)
}
name = filepathClean(name)
r.native[name] = loader
}
// DefaultSourceLoader is used if none was set (see WithLoader()). It simply loads files from the host's filesystem.
func DefaultSourceLoader(filename string) ([]byte, error) {
fp := filepath.FromSlash(filename)
f, err := os.Open(fp)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
err = ModuleFileDoesNotExistError
} else if runtime.GOOS == "windows" {
if errors.Is(err, syscall.Errno(0x7b)) { // ERROR_INVALID_NAME, The filename, directory name, or volume label syntax is incorrect.
err = ModuleFileDoesNotExistError
}
}
return nil, err
}
defer f.Close()
// On some systems (e.g. plan9 and FreeBSD) it is possible to use the standard read() call on directories
// which means we cannot rely on read() returning an error, we have to do stat() instead.
if fi, err := f.Stat(); err == nil {
if fi.IsDir() {
return nil, ModuleFileDoesNotExistError
}
} else {
return nil, err
}
return io.ReadAll(f)
}
func (r *Registry) getSource(p string) ([]byte, error) {
srcLoader := r.srcLoader
if srcLoader == nil {
srcLoader = DefaultSourceLoader
}
return srcLoader(p)
}
func (r *Registry) getCompiledSource(p string) (*js.Program, error) {
r.Lock()
defer r.Unlock()
prg := r.compiled[p]
if prg == nil {
buf, err := r.getSource(p)
if err != nil {
return nil, err
}
s := string(buf)
if path.Ext(p) == ".json" {
s = "module.exports = JSON.parse('" + template.JSEscapeString(s) + "')"
}
source := "(function(exports, require, module) {" + s + "\n})"
parsed, err := js.Parse(p, source, parser.WithSourceMapLoader(r.srcLoader))
if err != nil {
return nil, err
}
prg, err = js.CompileAST(parsed, false)
if err == nil {
if r.compiled == nil {
r.compiled = make(map[string]*js.Program)
}
r.compiled[p] = prg
}
return prg, err
}
return prg, nil
}
func (r *RequireModule) require(call js.FunctionCall) js.Value {
ret, err := r.Require(call.Argument(0).String())
if err != nil {
if _, ok := err.(*js.Exception); !ok {
panic(r.runtime.NewGoError(err))
}
panic(err)
}
return ret
}
func filepathClean(p string) string {
return path.Clean(p)
}
// Require can be used to import modules from Go source (similar to JS require() function).
func (r *RequireModule) Require(p string) (ret js.Value, err error) {
module, err := r.resolve(p)
if err != nil {
return
}
ret = module.Get("exports")
return
}
func Require(runtime *js.Runtime, name string) js.Value {
if r, ok := js.AssertFunction(runtime.Get("require")); ok {
mod, err := r(js.Undefined(), runtime.ToValue(name))
if err != nil {
panic(err)
}
return mod
}
panic(runtime.NewTypeError("Please enable require for this runtime using new(require.Registry).Enable(runtime)"))
}

View file

@ -0,0 +1,277 @@
package require
import (
"encoding/json"
"errors"
"path"
"path/filepath"
"runtime"
"strings"
js "github.com/dop251/goja"
)
const NodePrefix = "node:"
// NodeJS module search algorithm described by
// https://nodejs.org/api/modules.html#modules_all_together
func (r *RequireModule) resolve(modpath string) (module *js.Object, err error) {
origPath, modpath := modpath, filepathClean(modpath)
if modpath == "" {
return nil, IllegalModuleNameError
}
var start string
err = nil
if path.IsAbs(origPath) {
start = "/"
} else {
start = r.getCurrentModulePath()
}
p := path.Join(start, modpath)
if isFileOrDirectoryPath(origPath) && r.r.fsEnabled {
if module = r.modules[p]; module != nil {
return
}
module, err = r.loadAsFileOrDirectory(p)
if err == nil && module != nil {
r.modules[p] = module
}
} else {
module, err = r.loadNative(origPath)
if err == nil {
return
} else {
if err == InvalidModuleError {
err = nil
} else {
return
}
}
if module = r.nodeModules[p]; module != nil {
return
}
if r.r.fsEnabled {
module, err = r.loadNodeModules(modpath, start)
if err == nil && module != nil {
r.nodeModules[p] = module
}
}
}
if module == nil && err == nil {
err = InvalidModuleError
}
return
}
func (r *RequireModule) loadNative(path string) (*js.Object, error) {
module := r.modules[path]
if module != nil {
return module, nil
}
var ldr ModuleLoader
if r.r.native != nil {
ldr = r.r.native[path]
}
var isBuiltIn, withPrefix bool
if ldr == nil {
if r.r.builtin != nil {
ldr = r.r.builtin[path]
}
if ldr == nil && strings.HasPrefix(path, NodePrefix) {
ldr = r.r.builtin[path[len(NodePrefix):]]
if ldr == nil {
return nil, NoSuchBuiltInModuleError
}
withPrefix = true
}
isBuiltIn = true
}
if ldr != nil {
module = r.createModuleObject()
r.modules[path] = module
if isBuiltIn {
if withPrefix {
r.modules[path[len(NodePrefix):]] = module
} else {
if !strings.HasPrefix(path, NodePrefix) {
r.modules[NodePrefix+path] = module
}
}
}
ldr(r.runtime, module)
return module, nil
}
return nil, InvalidModuleError
}
func (r *RequireModule) loadAsFileOrDirectory(path string) (module *js.Object, err error) {
if module, err = r.loadAsFile(path); module != nil || err != nil {
return
}
return r.loadAsDirectory(path)
}
func (r *RequireModule) loadAsFile(path string) (module *js.Object, err error) {
if module, err = r.loadModule(path); module != nil || err != nil {
return
}
p := path + ".js"
if module, err = r.loadModule(p); module != nil || err != nil {
return
}
p = path + ".json"
return r.loadModule(p)
}
func (r *RequireModule) loadIndex(modpath string) (module *js.Object, err error) {
p := path.Join(modpath, "index.js")
if module, err = r.loadModule(p); module != nil || err != nil {
return
}
p = path.Join(modpath, "index.json")
return r.loadModule(p)
}
func (r *RequireModule) loadAsDirectory(modpath string) (module *js.Object, err error) {
p := path.Join(modpath, "package.json")
buf, err := r.r.getSource(p)
if err != nil {
return r.loadIndex(modpath)
}
var pkg struct {
Main string
}
err = json.Unmarshal(buf, &pkg)
if err != nil || len(pkg.Main) == 0 {
return r.loadIndex(modpath)
}
m := path.Join(modpath, pkg.Main)
if module, err = r.loadAsFile(m); module != nil || err != nil {
return
}
return r.loadIndex(m)
}
func (r *RequireModule) loadNodeModule(modpath, start string) (*js.Object, error) {
return r.loadAsFileOrDirectory(path.Join(start, modpath))
}
func (r *RequireModule) loadNodeModules(modpath, start string) (module *js.Object, err error) {
for _, dir := range r.r.globalFolders {
if module, err = r.loadNodeModule(modpath, dir); module != nil || err != nil {
return
}
}
for {
var p string
if path.Base(start) != "node_modules" {
p = path.Join(start, "node_modules")
} else {
p = start
}
if module, err = r.loadNodeModule(modpath, p); module != nil || err != nil {
return
}
if start == ".." { // Dir('..') is '.'
break
}
parent := path.Dir(start)
if parent == start {
break
}
start = parent
}
return
}
func (r *RequireModule) getCurrentModulePath() string {
var buf [2]js.StackFrame
frames := r.runtime.CaptureCallStack(2, buf[:0])
if len(frames) < 2 {
return "."
}
return path.Dir(frames[1].SrcName())
}
func (r *RequireModule) createModuleObject() *js.Object {
module := r.runtime.NewObject()
module.Set("exports", r.runtime.NewObject())
return module
}
func (r *RequireModule) loadModule(path string) (*js.Object, error) {
module := r.modules[path]
if module == nil {
module = r.createModuleObject()
r.modules[path] = module
err := r.loadModuleFile(path, module)
if err != nil {
module = nil
delete(r.modules, path)
if errors.Is(err, ModuleFileDoesNotExistError) {
err = nil
}
}
return module, err
}
return module, nil
}
func (r *RequireModule) loadModuleFile(path string, jsModule *js.Object) error {
prg, err := r.r.getCompiledSource(path)
if err != nil {
return err
}
f, err := r.runtime.RunProgram(prg)
if err != nil {
return err
}
if call, ok := js.AssertFunction(f); ok {
jsExports := jsModule.Get("exports")
jsRequire := r.runtime.Get("require")
// Run the module source, with "jsExports" as "this",
// "jsExports" as the "exports" variable, "jsRequire"
// as the "require" variable and "jsModule" as the
// "module" variable (Nodejs capable).
_, err = call(jsExports, jsExports, jsRequire, jsModule)
if err != nil {
return err
}
} else {
return InvalidModuleError
}
return nil
}
func isFileOrDirectoryPath(path string) bool {
result := path == "." || path == ".." ||
strings.HasPrefix(path, "/") ||
strings.HasPrefix(path, "./") ||
strings.HasPrefix(path, "../")
if runtime.GOOS == "windows" {
result = result ||
strings.HasPrefix(path, `.\`) ||
strings.HasPrefix(path, `..\`) ||
filepath.IsAbs(path)
}
return result
}

View file

@ -0,0 +1,147 @@
package sghttp
import (
"bytes"
"context"
"crypto/tls"
"io"
"net/http"
"net/http/cookiejar"
"sync"
"time"
"github.com/sagernet/sing-box/script/jsc"
F "github.com/sagernet/sing/common/format"
"github.com/dop251/goja"
"golang.org/x/net/publicsuffix"
)
type SurgeHTTP struct {
vm *goja.Runtime
ctx context.Context
cookieAccess sync.RWMutex
cookieJar *cookiejar.Jar
errorHandler func(error)
}
func Enable(vm *goja.Runtime, ctx context.Context, errorHandler func(error)) {
sgHTTP := &SurgeHTTP{
vm: vm,
ctx: ctx,
errorHandler: errorHandler,
}
httpObject := vm.NewObject()
httpObject.Set("get", sgHTTP.request(http.MethodGet))
httpObject.Set("post", sgHTTP.request(http.MethodPost))
httpObject.Set("put", sgHTTP.request(http.MethodPut))
httpObject.Set("delete", sgHTTP.request(http.MethodDelete))
httpObject.Set("head", sgHTTP.request(http.MethodHead))
httpObject.Set("options", sgHTTP.request(http.MethodOptions))
httpObject.Set("patch", sgHTTP.request(http.MethodPatch))
httpObject.Set("trace", sgHTTP.request(http.MethodTrace))
vm.Set("$http", httpObject)
}
func (s *SurgeHTTP) request(method string) func(call goja.FunctionCall) goja.Value {
return func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) != 2 {
panic(s.vm.NewTypeError("invalid arguments"))
}
var (
url string
headers http.Header
body []byte
timeout = 5 * time.Second
insecure bool
autoCookie bool
autoRedirect bool
// policy string
binaryMode bool
)
switch optionsValue := call.Argument(0).(type) {
case goja.String:
url = optionsValue.String()
case *goja.Object:
url = jsc.AssertString(s.vm, optionsValue.Get("url"), "options.url", false)
headers = jsc.AssertHTTPHeader(s.vm, optionsValue.Get("headers"), "option.headers")
body = jsc.AssertStringBinary(s.vm, optionsValue.Get("body"), "options.body", true)
timeoutInt := jsc.AssertInt(s.vm, optionsValue.Get("timeout"), "options.timeout", true)
if timeoutInt > 0 {
timeout = time.Duration(timeoutInt) * time.Second
}
insecure = jsc.AssertBool(s.vm, optionsValue.Get("insecure"), "options.insecure", true)
autoCookie = jsc.AssertBool(s.vm, optionsValue.Get("auto-cookie"), "options.auto-cookie", true)
autoRedirect = jsc.AssertBool(s.vm, optionsValue.Get("auto-redirect"), "options.auto-redirect", true)
// policy = jsc.AssertString(s.vm, optionsValue.Get("policy"), "options.policy", true)
binaryMode = jsc.AssertBool(s.vm, optionsValue.Get("binary-mode"), "options.binary-mode", true)
default:
panic(s.vm.NewTypeError(F.ToString("invalid argument: options: expected string or object, but got ", optionsValue)))
}
callback := jsc.AssertFunction(s.vm, call.Argument(1), "callback")
httpClient := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: insecure,
},
ForceAttemptHTTP2: true,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if autoRedirect {
return nil
}
return http.ErrUseLastResponse
},
}
if autoCookie {
s.cookieAccess.Lock()
if s.cookieJar == nil {
s.cookieJar, _ = cookiejar.New(&cookiejar.Options{
PublicSuffixList: publicsuffix.List,
})
}
httpClient.Jar = s.cookieJar
s.cookieAccess.Lock()
}
request, err := http.NewRequestWithContext(s.ctx, method, url, bytes.NewReader(body))
if host := headers.Get("Host"); host != "" {
request.Host = host
headers.Del("Host")
}
request.Header = headers
if err != nil {
panic(s.vm.NewGoError(err))
}
go func() {
response, executeErr := httpClient.Do(request)
if err != nil {
_, err = callback(nil, s.vm.NewGoError(executeErr), nil, nil)
if err != nil {
s.errorHandler(err)
}
return
}
defer response.Body.Close()
var content []byte
content, err = io.ReadAll(response.Body)
if err != nil {
_, err = callback(nil, s.vm.NewGoError(err), nil, nil)
if err != nil {
s.errorHandler(err)
}
}
responseObject := s.vm.NewObject()
responseObject.Set("status", response.StatusCode)
responseObject.Set("headers", jsc.HeadersToValue(s.vm, response.Header))
var bodyValue goja.Value
if binaryMode {
bodyValue = jsc.NewUint8Array(s.vm, content)
} else {
bodyValue = s.vm.ToValue(string(content))
}
_, err = callback(nil, nil, responseObject, bodyValue)
}()
return goja.Undefined()
}
}

View file

@ -0,0 +1,111 @@
package sgnotification
import (
"context"
"encoding/base64"
"strings"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/script/jsc"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
"github.com/dop251/goja"
)
type SurgeNotification struct {
vm *goja.Runtime
logger logger.Logger
platformInterface platform.Interface
scriptTag string
}
func Enable(vm *goja.Runtime, ctx context.Context, logger logger.Logger) {
platformInterface := service.FromContext[platform.Interface](ctx)
notification := &SurgeNotification{
vm: vm,
logger: logger,
platformInterface: platformInterface,
}
notificationObject := vm.NewObject()
notificationObject.Set("post", notification.js_post)
vm.Set("$notification", notificationObject)
}
func (s *SurgeNotification) js_post(call goja.FunctionCall) goja.Value {
var (
title string
subtitle string
body string
openURL string
clipboard string
mediaURL string
mediaData []byte
mediaType string
autoDismiss int
)
title = jsc.AssertString(s.vm, call.Argument(0), "title", true)
subtitle = jsc.AssertString(s.vm, call.Argument(1), "subtitle", true)
body = jsc.AssertString(s.vm, call.Argument(2), "body", true)
options := jsc.AssertObject(s.vm, call.Argument(3), "options", true)
if options != nil {
action := jsc.AssertString(s.vm, options.Get("action"), "options.action", true)
switch action {
case "open-url":
openURL = jsc.AssertString(s.vm, options.Get("url"), "options.url", false)
case "clipboard":
clipboard = jsc.AssertString(s.vm, options.Get("clipboard"), "options.clipboard", false)
}
mediaURL = jsc.AssertString(s.vm, options.Get("media-url"), "options.media-url", true)
mediaBase64 := jsc.AssertString(s.vm, options.Get("media-base64"), "options.media-base64", true)
if mediaBase64 != "" {
mediaBinary, err := base64.StdEncoding.DecodeString(mediaBase64)
if err != nil {
panic(s.vm.NewGoError(E.Cause(err, "decode media-base64")))
}
mediaData = mediaBinary
mediaType = jsc.AssertString(s.vm, options.Get("media-base64-mime"), "options.media-base64-mime", false)
}
autoDismiss = int(jsc.AssertInt(s.vm, options.Get("auto-dismiss"), "options.auto-dismiss", true))
}
if title != "" && subtitle == "" && body == "" {
body = title
title = ""
} else if title != "" && subtitle != "" && body == "" {
body = subtitle
subtitle = ""
}
var builder strings.Builder
if title != "" {
builder.WriteString("[")
builder.WriteString(title)
if subtitle != "" {
builder.WriteString(" - ")
builder.WriteString(subtitle)
}
builder.WriteString("]: ")
}
builder.WriteString(body)
s.logger.Info("notification: " + builder.String())
if s.platformInterface != nil {
err := s.platformInterface.SendNotification(&platform.Notification{
Identifier: "surge-script-notification-" + s.scriptTag,
TypeName: "Surge Script Notification (" + s.scriptTag + ")",
TypeID: 11,
Title: title,
Subtitle: subtitle,
Body: body,
OpenURL: openURL,
Clipboard: clipboard,
MediaURL: mediaURL,
MediaData: mediaData,
MediaType: mediaType,
Timeout: autoDismiss,
})
if err != nil {
s.logger.Error(E.Cause(err, "send notification"))
}
}
return goja.Undefined()
}

View file

@ -0,0 +1,76 @@
package sgstore
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing/service"
"github.com/dop251/goja"
)
type SurgePersistentStore struct {
vm *goja.Runtime
cacheFile adapter.CacheFile
data map[string]string
tag string
}
func Enable(vm *goja.Runtime, ctx context.Context) {
object := vm.NewObject()
cacheFile := service.FromContext[adapter.CacheFile](ctx)
tag := vm.Get("$script").(*goja.Object).Get("name").String()
store := &SurgePersistentStore{
vm: vm,
cacheFile: cacheFile,
tag: tag,
}
if cacheFile == nil {
store.data = make(map[string]string)
}
object.Set("read", store.js_read)
object.Set("write", store.js_write)
vm.Set("$persistentStore", object)
}
func (s *SurgePersistentStore) js_read(call goja.FunctionCall) goja.Value {
if len(call.Arguments) > 1 {
panic(s.vm.NewTypeError("invalid arguments"))
}
key := jsc.AssertString(s.vm, call.Argument(0), "key", true)
if key == "" {
key = s.tag
}
var value string
if s.cacheFile != nil {
value = s.cacheFile.SurgePersistentStoreRead(key)
} else {
value = s.data[key]
}
if value == "" {
return goja.Null()
} else {
return s.vm.ToValue(value)
}
}
func (s *SurgePersistentStore) js_write(call goja.FunctionCall) goja.Value {
if len(call.Arguments) == 0 || len(call.Arguments) > 2 {
panic(s.vm.NewTypeError("invalid arguments"))
}
data := jsc.AssertString(s.vm, call.Argument(0), "data", true)
key := jsc.AssertString(s.vm, call.Argument(1), "key", true)
if key == "" {
key = s.tag
}
if s.cacheFile != nil {
err := s.cacheFile.SurgePersistentStoreWrite(key, data)
if err != nil {
panic(s.vm.NewGoError(err))
}
} else {
s.data[key] = data
}
return goja.Undefined()
}

View file

@ -0,0 +1,45 @@
package sgutils
import (
"bytes"
"compress/gzip"
"io"
"github.com/sagernet/sing-box/script/jsc"
E "github.com/sagernet/sing/common/exceptions"
"github.com/dop251/goja"
)
type SurgeUtils struct {
vm *goja.Runtime
}
func Enable(runtime *goja.Runtime) {
utils := &SurgeUtils{runtime}
object := runtime.NewObject()
object.Set("geoip", utils.js_stub)
object.Set("ipasn", utils.js_stub)
object.Set("ipaso", utils.js_stub)
object.Set("ungzip", utils.js_ungzip)
}
func (u *SurgeUtils) js_stub(call goja.FunctionCall) goja.Value {
panic(u.vm.NewGoError(E.New("not implemented")))
}
func (u *SurgeUtils) js_ungzip(call goja.FunctionCall) goja.Value {
if len(call.Arguments) != 1 {
panic(u.vm.NewGoError(E.New("invalid argument")))
}
binary := jsc.AssertBinary(u.vm, call.Argument(0), "binary", false)
reader, err := gzip.NewReader(bytes.NewReader(binary))
if err != nil {
panic(u.vm.NewGoError(err))
}
binary, err = io.ReadAll(reader)
if err != nil {
panic(u.vm.NewGoError(err))
}
return jsc.NewUint8Array(u.vm, binary)
}

26
script/script.go Normal file
View file

@ -0,0 +1,26 @@
package script
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
func NewScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (adapter.Script, error) {
switch options.Type {
case C.ScriptTypeSurgeGeneric:
return NewSurgeGenericScript(ctx, logger, options)
case C.ScriptTypeSurgeHTTPRequest:
return NewSurgeHTTPRequestScript(ctx, logger, options)
case C.ScriptTypeSurgeHTTPResponse:
return NewSurgeHTTPResponseScript(ctx, logger, options)
case C.ScriptTypeSurgeCron:
return NewSurgeCronScript(ctx, logger, options)
default:
return nil, E.New("unknown script type: ", options.Type)
}
}

119
script/script_surge_cron.go Normal file
View file

@ -0,0 +1,119 @@
package script
import (
"context"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/adhocore/gronx"
)
var _ adapter.GenericScript = (*SurgeCronScript)(nil)
type SurgeCronScript struct {
GenericScript
ctx context.Context
expression string
timer *time.Timer
}
func NewSurgeCronScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (*SurgeCronScript, error) {
source, err := NewSource(ctx, logger, options)
if err != nil {
return nil, err
}
if !gronx.IsValid(options.CronOptions.Expression) {
return nil, E.New("invalid cron expression: ", options.CronOptions.Expression)
}
return &SurgeCronScript{
GenericScript: GenericScript{
logger: logger,
tag: options.Tag,
timeout: time.Duration(options.Timeout),
arguments: options.Arguments,
source: source,
},
ctx: ctx,
expression: options.CronOptions.Expression,
}, nil
}
func (s *SurgeCronScript) Type() string {
return C.ScriptTypeSurgeCron
}
func (s *SurgeCronScript) Tag() string {
return s.tag
}
func (s *SurgeCronScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
return s.source.StartContext(ctx, startContext)
}
func (s *SurgeCronScript) PostStart() error {
err := s.source.PostStart()
if err != nil {
return err
}
go s.loop()
return nil
}
func (s *SurgeCronScript) loop() {
s.logger.Debug("starting event")
err := s.Run(s.ctx)
if err != nil {
s.logger.Error(E.Cause(err, "running event"))
}
nextTick, err := gronx.NextTick(s.expression, false)
if err != nil {
s.logger.Error(E.Cause(err, "determine next tick"))
return
}
s.timer = time.NewTimer(nextTick.Sub(time.Now()))
s.logger.Debug("next event at: ", nextTick.Format(log.DefaultTimeFormat))
for {
select {
case <-s.ctx.Done():
return
case <-s.timer.C:
s.logger.Debug("starting event")
err = s.Run(s.ctx)
if err != nil {
s.logger.Error(E.Cause(err, "running event"))
}
nextTick, err = gronx.NextTick(s.expression, false)
if err != nil {
s.logger.Error(E.Cause(err, "determine next tick"))
return
}
s.timer.Reset(nextTick.Sub(time.Now()))
s.logger.Debug("next event at: ", nextTick)
}
}
}
func (s *SurgeCronScript) Close() error {
return s.source.Close()
}
func (s *SurgeCronScript) Run(ctx context.Context) error {
program := s.source.Program()
if program == nil {
return E.New("invalid script")
}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
vm := NewRuntime(ctx, s.logger, cancel)
err := SetSurgeModules(vm, ctx, s.logger, cancel, s.Tag(), s.Type(), s.arguments)
if err != nil {
return err
}
return ExecuteSurgeGeneral(vm, program, ctx, s.timeout)
}

View file

@ -0,0 +1,183 @@
package script
import (
"context"
"runtime"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/locale"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/console"
"github.com/sagernet/sing-box/script/modules/eventloop"
"github.com/sagernet/sing-box/script/modules/require"
"github.com/sagernet/sing-box/script/modules/sghttp"
"github.com/sagernet/sing-box/script/modules/sgnotification"
"github.com/sagernet/sing-box/script/modules/sgstore"
"github.com/sagernet/sing-box/script/modules/sgutils"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/ntp"
"github.com/dop251/goja"
"github.com/dop251/goja/parser"
)
const defaultScriptTimeout = 10 * time.Second
var _ adapter.GenericScript = (*GenericScript)(nil)
type GenericScript struct {
logger logger.ContextLogger
tag string
timeout time.Duration
arguments []any
source Source
}
func NewSurgeGenericScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (*GenericScript, error) {
source, err := NewSource(ctx, logger, options)
if err != nil {
return nil, err
}
return &GenericScript{
logger: logger,
tag: options.Tag,
timeout: time.Duration(options.Timeout),
arguments: options.Arguments,
source: source,
}, nil
}
func (s *GenericScript) Type() string {
return C.ScriptTypeSurgeGeneric
}
func (s *GenericScript) Tag() string {
return s.tag
}
func (s *GenericScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
return s.source.StartContext(ctx, startContext)
}
func (s *GenericScript) PostStart() error {
return s.source.PostStart()
}
func (s *GenericScript) Close() error {
return s.source.Close()
}
func (s *GenericScript) Run(ctx context.Context) error {
program := s.source.Program()
if program == nil {
return E.New("invalid script")
}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
vm := NewRuntime(ctx, s.logger, cancel)
err := SetSurgeModules(vm, ctx, s.logger, cancel, s.Tag(), s.Type(), s.arguments)
if err != nil {
return err
}
return ExecuteSurgeGeneral(vm, program, ctx, s.timeout)
}
func NewRuntime(ctx context.Context, logger logger.ContextLogger, cancel context.CancelCauseFunc) *goja.Runtime {
vm := goja.New()
if timeFunc := ntp.TimeFuncFromContext(ctx); timeFunc != nil {
vm.SetTimeSource(timeFunc)
}
vm.SetParserOptions(parser.WithDisableSourceMaps)
registry := require.NewRegistry(require.WithLoader(func(path string) ([]byte, error) {
return nil, E.New("unsupported usage")
}))
registry.Enable(vm)
registry.RegisterNodeModule(console.ModuleName, console.Require(ctx, logger))
console.Enable(vm)
eventloop.Enable(vm, cancel)
return vm
}
func SetSurgeModules(vm *goja.Runtime, ctx context.Context, logger logger.Logger, errorHandler func(error), tag string, scriptType string, arguments []any) error {
script := vm.NewObject()
script.Set("name", F.ToString("script:", tag))
script.Set("startTime", jsc.TimeToValue(vm, time.Now()))
script.Set("type", scriptType)
vm.Set("$script", script)
environment := vm.NewObject()
var system string
switch runtime.GOOS {
case "ios":
system = "iOS"
case "darwin":
system = "macOS"
case "tvos":
system = "tvOS"
case "linux":
system = "Linux"
case "android":
system = "Android"
case "windows":
system = "Windows"
default:
system = runtime.GOOS
}
environment.Set("system", system)
environment.Set("surge-build", "N/A")
environment.Set("surge-version", "sing-box "+C.Version)
environment.Set("language", locale.Current().Locale)
environment.Set("device-model", "N/A")
vm.Set("$environment", environment)
sgstore.Enable(vm, ctx)
sgutils.Enable(vm)
sghttp.Enable(vm, ctx, errorHandler)
sgnotification.Enable(vm, ctx, logger)
vm.Set("$argument", arguments)
return nil
}
func ExecuteSurgeGeneral(vm *goja.Runtime, program *goja.Program, ctx context.Context, timeout time.Duration) error {
if timeout == 0 {
timeout = defaultScriptTimeout
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
vm.ClearInterrupt()
done := make(chan struct{})
doneFunc := common.OnceFunc(func() {
close(done)
})
vm.Set("done", func(call goja.FunctionCall) goja.Value {
doneFunc()
return goja.Undefined()
})
var err error
go func() {
_, err = vm.RunProgram(program)
if err != nil {
doneFunc()
}
}()
select {
case <-ctx.Done():
vm.Interrupt(ctx.Err())
return ctx.Err()
case <-done:
if err != nil {
vm.Interrupt(err)
} else {
vm.Interrupt("script done")
}
}
return err
}

View file

@ -0,0 +1,165 @@
package script
import (
"context"
"net/http"
"regexp"
"time"
"unsafe"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
"github.com/dop251/goja"
)
var _ adapter.HTTPRequestScript = (*SurgeHTTPRequestScript)(nil)
type SurgeHTTPRequestScript struct {
GenericScript
pattern *regexp.Regexp
requiresBody bool
maxSize int64
binaryBodyMode bool
}
func NewSurgeHTTPRequestScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (*SurgeHTTPRequestScript, error) {
source, err := NewSource(ctx, logger, options)
if err != nil {
return nil, err
}
pattern, err := regexp.Compile(options.HTTPOptions.Pattern)
if err != nil {
return nil, E.Cause(err, "parse pattern")
}
return &SurgeHTTPRequestScript{
GenericScript: GenericScript{
logger: logger,
tag: options.Tag,
timeout: time.Duration(options.Timeout),
arguments: options.Arguments,
source: source,
},
pattern: pattern,
requiresBody: options.HTTPOptions.RequiresBody,
maxSize: options.HTTPOptions.MaxSize,
binaryBodyMode: options.HTTPOptions.BinaryBodyMode,
}, nil
}
func (s *SurgeHTTPRequestScript) Type() string {
return C.ScriptTypeSurgeHTTPRequest
}
func (s *SurgeHTTPRequestScript) Tag() string {
return s.tag
}
func (s *SurgeHTTPRequestScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
return s.source.StartContext(ctx, startContext)
}
func (s *SurgeHTTPRequestScript) PostStart() error {
return s.source.PostStart()
}
func (s *SurgeHTTPRequestScript) Close() error {
return s.source.Close()
}
func (s *SurgeHTTPRequestScript) Match(requestURL string) bool {
return s.pattern.MatchString(requestURL)
}
func (s *SurgeHTTPRequestScript) RequiresBody() bool {
return s.requiresBody
}
func (s *SurgeHTTPRequestScript) MaxSize() int64 {
return s.maxSize
}
func (s *SurgeHTTPRequestScript) Run(ctx context.Context, request *http.Request, body []byte) (*adapter.HTTPRequestScriptResult, error) {
program := s.source.Program()
if program == nil {
return nil, E.New("invalid script")
}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
vm := NewRuntime(ctx, s.logger, cancel)
err := SetSurgeModules(vm, ctx, s.logger, cancel, s.Tag(), s.Type(), s.arguments)
if err != nil {
return nil, err
}
return ExecuteSurgeHTTPRequest(vm, program, ctx, s.timeout, request, body, s.binaryBodyMode)
}
func ExecuteSurgeHTTPRequest(vm *goja.Runtime, program *goja.Program, ctx context.Context, timeout time.Duration, request *http.Request, body []byte, binaryBody bool) (*adapter.HTTPRequestScriptResult, error) {
if timeout == 0 {
timeout = defaultScriptTimeout
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
vm.ClearInterrupt()
requestObject := vm.NewObject()
requestObject.Set("url", request.URL.String())
requestObject.Set("method", request.Method)
requestObject.Set("headers", jsc.HeadersToValue(vm, request.Header))
if !binaryBody {
requestObject.Set("body", string(body))
} else {
requestObject.Set("body", jsc.NewUint8Array(vm, body))
}
requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request))))
vm.Set("request", requestObject)
done := make(chan struct{})
doneFunc := common.OnceFunc(func() {
close(done)
})
var result adapter.HTTPRequestScriptResult
vm.Set("done", func(call goja.FunctionCall) goja.Value {
defer doneFunc()
resultObject := jsc.AssertObject(vm, call.Argument(0), "done() argument", true)
if resultObject == nil {
panic(vm.NewGoError(E.New("request rejected by script")))
}
result.URL = jsc.AssertString(vm, resultObject.Get("url"), "url", true)
result.Headers = jsc.AssertHTTPHeader(vm, resultObject.Get("headers"), "headers")
result.Body = jsc.AssertStringBinary(vm, resultObject.Get("body"), "body", true)
responseObject := jsc.AssertObject(vm, resultObject.Get("response"), "response", true)
if responseObject != nil {
result.Response = &adapter.HTTPRequestScriptResponse{
Status: int(jsc.AssertInt(vm, responseObject.Get("status"), "status", true)),
Headers: jsc.AssertHTTPHeader(vm, responseObject.Get("headers"), "headers"),
Body: jsc.AssertStringBinary(vm, responseObject.Get("body"), "body", true),
}
}
return goja.Undefined()
})
var err error
go func() {
_, err = vm.RunProgram(program)
if err != nil {
doneFunc()
}
}()
select {
case <-ctx.Done():
vm.Interrupt(ctx.Err())
return nil, ctx.Err()
case <-done:
if err != nil {
vm.Interrupt(err)
} else {
vm.Interrupt("script done")
}
}
return &result, err
}

View file

@ -0,0 +1,175 @@
package script
import (
"context"
"net/http"
"regexp"
"sync"
"time"
"unsafe"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
"github.com/dop251/goja"
)
var _ adapter.HTTPResponseScript = (*SurgeHTTPResponseScript)(nil)
type SurgeHTTPResponseScript struct {
GenericScript
pattern *regexp.Regexp
requiresBody bool
maxSize int64
binaryBodyMode bool
}
func NewSurgeHTTPResponseScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (*SurgeHTTPResponseScript, error) {
source, err := NewSource(ctx, logger, options)
if err != nil {
return nil, err
}
pattern, err := regexp.Compile(options.HTTPOptions.Pattern)
if err != nil {
return nil, E.Cause(err, "parse pattern")
}
return &SurgeHTTPResponseScript{
GenericScript: GenericScript{
logger: logger,
tag: options.Tag,
timeout: time.Duration(options.Timeout),
arguments: options.Arguments,
source: source,
},
pattern: pattern,
requiresBody: options.HTTPOptions.RequiresBody,
maxSize: options.HTTPOptions.MaxSize,
binaryBodyMode: options.HTTPOptions.BinaryBodyMode,
}, nil
}
func (s *SurgeHTTPResponseScript) Type() string {
return C.ScriptTypeSurgeHTTPResponse
}
func (s *SurgeHTTPResponseScript) Tag() string {
return s.tag
}
func (s *SurgeHTTPResponseScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
return s.source.StartContext(ctx, startContext)
}
func (s *SurgeHTTPResponseScript) PostStart() error {
return s.source.PostStart()
}
func (s *SurgeHTTPResponseScript) Close() error {
return s.source.Close()
}
func (s *SurgeHTTPResponseScript) Match(requestURL string) bool {
return s.pattern.MatchString(requestURL)
}
func (s *SurgeHTTPResponseScript) RequiresBody() bool {
return s.requiresBody
}
func (s *SurgeHTTPResponseScript) MaxSize() int64 {
return s.maxSize
}
func (s *SurgeHTTPResponseScript) Run(ctx context.Context, request *http.Request, response *http.Response, body []byte) (*adapter.HTTPResponseScriptResult, error) {
program := s.source.Program()
if program == nil {
return nil, E.New("invalid script")
}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
vm := NewRuntime(ctx, s.logger, cancel)
err := SetSurgeModules(vm, ctx, s.logger, cancel, s.Tag(), s.Type(), s.arguments)
if err != nil {
return nil, err
}
return ExecuteSurgeHTTPResponse(vm, program, ctx, s.timeout, request, response, body, s.binaryBodyMode)
}
func ExecuteSurgeHTTPResponse(vm *goja.Runtime, program *goja.Program, ctx context.Context, timeout time.Duration, request *http.Request, response *http.Response, body []byte, binaryBody bool) (*adapter.HTTPResponseScriptResult, error) {
if timeout == 0 {
timeout = defaultScriptTimeout
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
vm.ClearInterrupt()
requestObject := vm.NewObject()
requestObject.Set("url", request.URL.String())
requestObject.Set("method", request.Method)
requestObject.Set("headers", jsc.HeadersToValue(vm, request.Header))
requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request))))
vm.Set("request", requestObject)
responseObject := vm.NewObject()
responseObject.Set("status", response.StatusCode)
responseObject.Set("headers", jsc.HeadersToValue(vm, response.Header))
if !binaryBody {
responseObject.Set("body", string(body))
} else {
responseObject.Set("body", jsc.NewUint8Array(vm, body))
}
vm.Set("response", responseObject)
done := make(chan struct{})
doneFunc := common.OnceFunc(func() {
close(done)
})
var (
access sync.Mutex
result adapter.HTTPResponseScriptResult
)
vm.Set("done", func(call goja.FunctionCall) goja.Value {
resultObject := jsc.AssertObject(vm, call.Argument(0), "done() argument", true)
if resultObject == nil {
panic(vm.NewGoError(E.New("response rejected by script")))
}
access.Lock()
defer access.Unlock()
result.Status = int(jsc.AssertInt(vm, resultObject.Get("status"), "status", true))
result.Headers = jsc.AssertHTTPHeader(vm, resultObject.Get("headers"), "headers")
result.Body = jsc.AssertStringBinary(vm, resultObject.Get("body"), "body", true)
doneFunc()
return goja.Undefined()
})
var scriptErr error
go func() {
_, err := vm.RunProgram(program)
if err != nil {
access.Lock()
scriptErr = err
access.Unlock()
doneFunc()
}
}()
select {
case <-ctx.Done():
println(1)
vm.Interrupt(ctx.Err())
return nil, ctx.Err()
case <-done:
access.Lock()
defer access.Unlock()
if scriptErr != nil {
vm.Interrupt(scriptErr)
} else {
vm.Interrupt("script done")
}
return &result, scriptErr
}
}

31
script/source.go Normal file
View file

@ -0,0 +1,31 @@
package script
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/dop251/goja"
)
type Source interface {
StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error
PostStart() error
Program() *goja.Program
Close() error
}
func NewSource(ctx context.Context, logger logger.Logger, options option.Script) (Source, error) {
switch options.Source {
case C.ScriptSourceLocal:
return NewLocalSource(ctx, logger, options)
case C.ScriptSourceRemote:
return NewRemoteSource(ctx, logger, options)
default:
return nil, E.New("unknown source type: ", options.Source)
}
}

92
script/source_local.go Normal file
View file

@ -0,0 +1,92 @@
package script
import (
"context"
"os"
"path/filepath"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service/filemanager"
"github.com/dop251/goja"
)
var _ Source = (*LocalSource)(nil)
type LocalSource struct {
ctx context.Context
logger logger.Logger
tag string
program *goja.Program
watcher *fswatch.Watcher
}
func NewLocalSource(ctx context.Context, logger logger.Logger, options option.Script) (*LocalSource, error) {
script := &LocalSource{
ctx: ctx,
logger: logger,
tag: options.Tag,
}
filePath := filemanager.BasePath(ctx, options.LocalOptions.Path)
filePath, _ = filepath.Abs(options.LocalOptions.Path)
err := script.reloadFile(filePath)
if err != nil {
return nil, err
}
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: []string{filePath},
Callback: func(path string) {
uErr := script.reloadFile(path)
if uErr != nil {
logger.Error(E.Cause(uErr, "reload script ", path))
}
},
})
if err != nil {
return nil, err
}
script.watcher = watcher
return script, nil
}
func (s *LocalSource) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
if s.watcher != nil {
err := s.watcher.Start()
if err != nil {
s.logger.Error(E.Cause(err, "watch script file"))
}
}
return nil
}
func (s *LocalSource) reloadFile(path string) error {
content, err := os.ReadFile(path)
if err != nil {
return err
}
program, err := goja.Compile("script:"+s.tag, string(content), false)
if err != nil {
return E.Cause(err, "compile ", path)
}
if s.program != nil {
s.logger.Info("reloaded from ", path)
}
s.program = program
return nil
}
func (s *LocalSource) PostStart() error {
return nil
}
func (s *LocalSource) Program() *goja.Program {
return s.program
}
func (s *LocalSource) Close() error {
return s.watcher.Close()
}

224
script/source_remote.go Normal file
View file

@ -0,0 +1,224 @@
package script
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"runtime"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/dop251/goja"
)
var _ Source = (*RemoteSource)(nil)
type RemoteSource struct {
ctx context.Context
cancel context.CancelFunc
logger logger.Logger
outbound adapter.OutboundManager
options option.Script
updateInterval time.Duration
dialer N.Dialer
program *goja.Program
lastUpdated time.Time
lastEtag string
updateTicker *time.Ticker
cacheFile adapter.CacheFile
pauseManager pause.Manager
}
func NewRemoteSource(ctx context.Context, logger logger.Logger, options option.Script) (*RemoteSource, error) {
ctx, cancel := context.WithCancel(ctx)
var updateInterval time.Duration
if options.RemoteOptions.UpdateInterval > 0 {
updateInterval = time.Duration(options.RemoteOptions.UpdateInterval)
} else {
updateInterval = 24 * time.Hour
}
return &RemoteSource{
ctx: ctx,
cancel: cancel,
logger: logger,
outbound: service.FromContext[adapter.OutboundManager](ctx),
options: options,
updateInterval: updateInterval,
pauseManager: service.FromContext[pause.Manager](ctx),
}, nil
}
func (s *RemoteSource) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx)
var dialer N.Dialer
if s.options.RemoteOptions.DownloadDetour != "" {
outbound, loaded := s.outbound.Outbound(s.options.RemoteOptions.DownloadDetour)
if !loaded {
return E.New("download detour not found: ", s.options.RemoteOptions.DownloadDetour)
}
dialer = outbound
} else {
dialer = s.outbound.Default()
}
s.dialer = dialer
if s.cacheFile != nil {
if savedSet := s.cacheFile.LoadScript(s.options.Tag); savedSet != nil {
err := s.loadBytes(savedSet.Content)
if err != nil {
return E.Cause(err, "restore cached rule-set")
}
s.lastUpdated = savedSet.LastUpdated
s.lastEtag = savedSet.LastEtag
}
}
if s.lastUpdated.IsZero() {
err := s.fetchOnce(ctx, startContext)
if err != nil {
return E.Cause(err, "initial rule-set: ", s.options.Tag)
}
}
s.updateTicker = time.NewTicker(s.updateInterval)
return nil
}
func (s *RemoteSource) PostStart() error {
go s.loopUpdate()
return nil
}
func (s *RemoteSource) Program() *goja.Program {
return s.program
}
func (s *RemoteSource) loadBytes(content []byte) error {
program, err := goja.Compile(F.ToString("script:", s.options.Tag), string(content), false)
if err != nil {
return err
}
s.program = program
return nil
}
func (s *RemoteSource) loopUpdate() {
if time.Since(s.lastUpdated) > s.updateInterval {
err := s.fetchOnce(s.ctx, nil)
if err != nil {
s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err)
}
}
for {
runtime.GC()
select {
case <-s.ctx.Done():
return
case <-s.updateTicker.C:
s.pauseManager.WaitActive()
err := s.fetchOnce(s.ctx, nil)
if err != nil {
s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err)
}
}
}
}
func (s *RemoteSource) fetchOnce(ctx context.Context, startContext *adapter.HTTPStartContext) error {
s.logger.Debug("updating script ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL)
var httpClient *http.Client
if startContext != nil {
httpClient = startContext.HTTPClient(s.options.RemoteOptions.DownloadDetour, s.dialer)
} else {
httpClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(s.ctx),
RootCAs: adapter.RootPoolFromContext(s.ctx),
},
},
}
}
request, err := http.NewRequest("GET", s.options.RemoteOptions.URL, nil)
if err != nil {
return err
}
if s.lastEtag != "" {
request.Header.Set("If-None-Match", s.lastEtag)
}
response, err := httpClient.Do(request.WithContext(ctx))
if err != nil {
return err
}
switch response.StatusCode {
case http.StatusOK:
case http.StatusNotModified:
s.lastUpdated = time.Now()
if s.cacheFile != nil {
savedRuleSet := s.cacheFile.LoadScript(s.options.Tag)
if savedRuleSet != nil {
savedRuleSet.LastUpdated = s.lastUpdated
err = s.cacheFile.SaveScript(s.options.Tag, savedRuleSet)
if err != nil {
s.logger.Error("save script updated time: ", err)
return nil
}
}
}
s.logger.Info("update script ", s.options.Tag, ": not modified")
return nil
default:
return E.New("unexpected status: ", response.Status)
}
content, err := io.ReadAll(response.Body)
if err != nil {
response.Body.Close()
return err
}
err = s.loadBytes(content)
if err != nil {
response.Body.Close()
return err
}
response.Body.Close()
eTagHeader := response.Header.Get("Etag")
if eTagHeader != "" {
s.lastEtag = eTagHeader
}
s.lastUpdated = time.Now()
if s.cacheFile != nil {
err = s.cacheFile.SaveScript(s.options.Tag, &adapter.SavedBinary{
LastUpdated: s.lastUpdated,
Content: content,
LastEtag: s.lastEtag,
})
if err != nil {
s.logger.Error("save script cache: ", err)
}
}
s.logger.Info("updated script ", s.options.Tag)
return nil
}
func (s *RemoteSource) Close() error {
if s.updateTicker != nil {
s.updateTicker.Stop()
}
s.cancel()
return nil
}