From 74920b44ac92b00d43abafc0e860609ec366cf3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 2 Feb 2025 21:36:09 +0800 Subject: [PATCH] mitm: Add HTTP2 support --- box.go | 2 +- mitm/engine.go | 424 ++++++++++++++++++++++++++++++++++++++++++++++++- option/mitm.go | 4 +- 3 files changed, 420 insertions(+), 10 deletions(-) diff --git a/box.go b/box.go index 64f93de3..26995fdb 100644 --- a/box.go +++ b/box.go @@ -348,7 +348,7 @@ func New(options Options) (*Box, error) { services = append(services, adapter.NewLifecycleService(ntpService, "ntp service")) } mitmOptions := common.PtrValueOrDefault(options.MITM) - var mitmEngine *mitm.Engine + var mitmEngine adapter.MITMEngine if mitmOptions.Enabled { engine, err := mitm.NewEngine(ctx, logFactory.NewLogger("mitm"), mitmOptions) if err != nil { diff --git a/mitm/engine.go b/mitm/engine.go index 6efa3f90..32a6baec 100644 --- a/mitm/engine.go +++ b/mitm/engine.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "encoding/base64" "io" + "math" "mime" "net" "net/http" @@ -32,6 +33,7 @@ import ( "github.com/sagernet/sing/service" "golang.org/x/crypto/pkcs12" + "golang.org/x/net/http2" ) var _ adapter.MITMEngine = (*Engine)(nil) @@ -51,9 +53,9 @@ type Engine struct { func NewEngine(ctx context.Context, logger logger.ContextLogger, options option.MITMOptions) (*Engine, error) { engine := &Engine{ - ctx: ctx, - logger: logger, - // http2Enabled: options.HTTP2Enabled, + ctx: ctx, + logger: logger, + http2Enabled: options.HTTP2Enabled, } if options.TLSDecryptionOptions != nil && options.TLSDecryptionOptions.Enabled { pfxBytes, err := base64.StdEncoding.DecodeString(options.TLSDecryptionOptions.KeyPair) @@ -265,7 +267,7 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls } if result.Body != nil { request.Body = io.NopCloser(bytes.NewReader(result.Body)) - request.ContentLength = int64(len(body)) + request.ContentLength = int64(len(result.Body)) } } } @@ -421,7 +423,6 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls var innerErr atomic.TypedValue[error] httpClient := &http.Client{ Transport: &http.Transport{ - DisableCompression: true, DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { if tlsConfig != nil { return tls.Client(remoteConn, tlsConfig), nil @@ -558,8 +559,417 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls return nil } -func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn *tls.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { - // TODO: implement http2 support +func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + handler := &engineHandler{ + Engine: e, + conn: conn, + dialer: this, + metadata: metadata, + httpClient: &http.Client{ + Transport: &http2.Transport{ + AllowHTTP: true, + MaxReadFrameSize: math.MaxUint32, + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + ctx = adapter.WithContext(ctx, &metadata) + var ( + remoteConn net.Conn + err error + ) + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { + remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + if err != nil { + return nil, err + } + return tls.Client(remoteConn, cfg), nil + }, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + onClose: onClose, + } + http2Server := &http2.Server{ + MaxReadFrameSize: math.MaxUint32, + } + http2Server.ServeConn(conn, &http2.ServeConnOpts{ + Context: ctx, + Handler: handler, + }) + return nil +} + +type engineHandler struct { + *Engine + conn net.Conn + dialer N.Dialer + metadata adapter.InboundContext + onClose N.CloseHandlerFunc + + httpClient *http.Client +} + +func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + err := e.serveHTTP(request.Context(), writer, request) + if err != nil { + if E.IsClosedOrCanceled(err) { + e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed")) + } else { + e.logger.ErrorContext(request.Context(), err) + } + } +} + +func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { + options := e.metadata.MITM + e.metadata.MITM = nil + rawRequestURL := request.URL + rawRequestURL.Scheme = "https" + if rawRequestURL.Host == "" { + rawRequestURL.Host = request.Host + } + requestURL := rawRequestURL.String() + request.RequestURI = "" + var ( + requestMatch bool + requestScript adapter.HTTPRequestScript + ) + for _, script := range e.script.Scripts() { + if !common.Contains(options.Script, script.Tag()) { + continue + } + httpScript, isHTTP := script.(adapter.HTTPRequestScript) + if !isHTTP { + _, isHTTP = script.(adapter.HTTPScript) + if !isHTTP { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script") + } + continue + } + if !httpScript.Match(requestURL) { + continue + } + e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]") + requestScript = httpScript + requestMatch = true + break + } + var err error + if requestScript != nil { + var body []byte + if requestScript.RequiresBody() && request.ContentLength > 0 && (requestScript.MaxSize() == 0 && request.ContentLength <= 131072 || request.ContentLength <= requestScript.MaxSize()) { + body, err = io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body.Close() + request.Body = io.NopCloser(bytes.NewReader(body)) + } + result, err := requestScript.Run(ctx, request, body) + if err != nil { + return E.Cause(err, "execute script/", requestScript.Type(), "[", requestScript.Tag(), "]") + } + if result.Response != nil { + if result.Response.Status == 0 { + result.Response.Status = http.StatusOK + } + for key, values := range result.Response.Headers { + writer.Header()[key] = values + } + writer.WriteHeader(result.Response.Status) + if result.Response.Body != nil { + _, err = writer.Write(result.Response.Body) + if err != nil { + return E.Cause(err, "write fake response body") + } + } + return nil + } else { + if result.URL != "" { + var newURL *url.URL + newURL, err = url.Parse(result.URL) + if err != nil { + return E.Cause(err, "parse updated request URL") + } + request.URL = newURL + newDestination := M.ParseSocksaddrHostPortStr(newURL.Hostname(), newURL.Port()) + if newDestination.Port == 0 { + newDestination.Port = e.metadata.Destination.Port + } + e.metadata.Destination = newDestination + } + for key, values := range result.Headers { + request.Header[key] = values + } + if newHost := result.Headers.Get("Host"); newHost != "" { + request.Host = newHost + request.Header.Del("Host") + } + if result.Body != nil { + io.Copy(io.Discard, request.Body) + request.Body = io.NopCloser(bytes.NewReader(result.Body)) + request.ContentLength = int64(len(result.Body)) + } + } + } + if !requestMatch { + for i, rule := range options.SurgeURLRewrite { + if !rule.Pattern.MatchString(requestURL) { + continue + } + e.logger.DebugContext(ctx, "match url_rewrite[", i, "] => ", rule.String()) + if rule.Reject { + return E.New("request rejected by url_rewrite") + } else if rule.Redirect { + http.Redirect(writer, request, rule.Destination.String(), http.StatusFound) + return nil + } + requestMatch = true + request.URL = rule.Destination + newDestination := M.ParseSocksaddrHostPortStr(rule.Destination.Hostname(), rule.Destination.Port()) + if newDestination.Port == 0 { + newDestination.Port = e.metadata.Destination.Port + } + e.metadata.Destination = newDestination + break + } + for i, rule := range options.SurgeHeaderRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + if strings.ToLower(rule.Key) == "host" { + request.Host = rule.Value + continue + } + request.Header.Add(rule.Key, rule.Value) + case rule.Delete: + request.Header.Del(rule.Key) + case rule.Replace: + if request.Header.Get(rule.Key) != "" { + request.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := request.Header.Get(rule.Key); value != "" { + request.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + var body []byte + if request.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if request.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + body, err := io.ReadAll(request.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + request.Body.Close() + for mi := 0; i < len(rule.Match); i++ { + body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i])) + } + request.Body = io.NopCloser(bytes.NewReader(body)) + request.ContentLength = int64(len(body)) + } + } + if !requestMatch { + for i, rule := range options.SurgeMapLocal { + if !rule.Pattern.MatchString(requestURL) { + continue + } + requestMatch = true + e.logger.DebugContext(ctx, "match map_local[", i, "] => ", rule.String()) + go func() { + io.Copy(io.Discard, request.Body) + request.Body.Close() + }() + var ( + statusCode = http.StatusOK + headers = make(http.Header) + body []byte + ) + if rule.StatusCode > 0 { + statusCode = rule.StatusCode + } + switch { + case rule.File: + resource, err := os.ReadFile(rule.Data) + if err != nil { + return E.Cause(err, "open map local source") + } + mimeType := mime.TypeByExtension(filepath.Ext(rule.Data)) + if mimeType == "" { + mimeType = "application/octet-stream" + } + headers.Set("Content-Type", mimeType) + body = resource + case rule.Text: + headers.Set("Content-Type", "text/plain") + body = []byte(rule.Data) + case rule.TinyGif: + headers.Set("Content-Type", "image/gif") + body = surgeTinyGif() + case rule.Base64: + headers.Set("Content-Type", "application/octet-stream") + body = rule.Base64Data + } + for key, values := range headers { + writer.Header()[key] = values + } + writer.WriteHeader(statusCode) + _, err = writer.Write(body) + if err != nil { + return E.Cause(err, "write map local response") + } + return nil + } + } + requestCtx, cancel := context.WithCancel(ctx) + defer cancel() + response, err := e.httpClient.Do(request.WithContext(requestCtx)) + if err != nil { + cancel() + return E.Cause(err, "exchange request") + } + var ( + responseScript adapter.HTTPResponseScript + responseMatch bool + ) + for _, script := range e.script.Scripts() { + if !common.Contains(options.Script, script.Tag()) { + continue + } + httpScript, isHTTP := script.(adapter.HTTPResponseScript) + if !isHTTP { + _, isHTTP = script.(adapter.HTTPScript) + if !isHTTP { + e.logger.WarnContext(ctx, "specified script/", script.Type(), "[", script.Tag(), "] is not a HTTP request/response script") + } + continue + } + if !httpScript.Match(requestURL) { + continue + } + e.logger.DebugContext(ctx, "match script/", httpScript.Type(), "[", httpScript.Tag(), "]") + responseScript = httpScript + responseMatch = true + break + } + if responseScript != nil { + var body []byte + if responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) { + body, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP response body") + } + response.Body.Close() + response.Body = io.NopCloser(bytes.NewReader(body)) + } + var result *adapter.HTTPResponseScriptResult + result, err = responseScript.Run(ctx, request, response, body) + if err != nil { + return E.Cause(err, "execute script/", responseScript.Type(), "[", responseScript.Tag(), "]") + } + if result.Status > 0 { + response.Status = http.StatusText(result.Status) + response.StatusCode = result.Status + } + for key, values := range result.Headers { + response.Header[key] = values + } + if result.Body != nil { + response.Body.Close() + response.Body = io.NopCloser(bytes.NewReader(result.Body)) + response.ContentLength = int64(len(result.Body)) + } + } + if !responseMatch { + for i, rule := range options.SurgeHeaderRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match header_rewrite[", i, "] => ", rule.String()) + switch { + case rule.Add: + response.Header.Add(rule.Key, rule.Value) + case rule.Delete: + response.Header.Del(rule.Key) + case rule.Replace: + if response.Header.Get(rule.Key) != "" { + response.Header.Set(rule.Key, rule.Value) + } + case rule.ReplaceRegex: + if value := response.Header.Get(rule.Key); value != "" { + response.Header.Set(rule.Key, rule.Match.ReplaceAllString(value, rule.Value)) + } + } + } + for i, rule := range options.SurgeBodyRewrite { + if !rule.Response { + continue + } + if !rule.Pattern.MatchString(requestURL) { + continue + } + responseMatch = true + e.logger.DebugContext(ctx, "match body_rewrite[", i, "] => ", rule.String()) + var body []byte + if response.ContentLength <= 0 { + e.logger.WarnContext(ctx, "body replace skipped due to non-fixed content length") + break + } else if response.ContentLength > 131072 { + e.logger.WarnContext(ctx, "body replace skipped due to large content length: ", request.ContentLength) + break + } + body, err = io.ReadAll(response.Body) + if err != nil { + return E.Cause(err, "read HTTP request body") + } + response.Body.Close() + for mi := 0; i < len(rule.Match); i++ { + body = rule.Match[mi].ReplaceAll(body, []byte(rule.Replace[i])) + } + response.Body = io.NopCloser(bytes.NewReader(body)) + response.ContentLength = int64(len(body)) + } + } + if !requestMatch && !responseMatch { + e.logger.WarnContext(ctx, "request not modified") + } + for key, values := range request.Header { + writer.Header()[key] = values + } + writer.WriteHeader(response.StatusCode) + _, err = io.Copy(writer, response.Body) + response.Body.Close() + if err != nil { + return E.Cause(err, "write HTTP response") + } return nil } diff --git a/option/mitm.go b/option/mitm.go index 32aa53bc..be9f0180 100644 --- a/option/mitm.go +++ b/option/mitm.go @@ -5,8 +5,8 @@ import ( ) type MITMOptions struct { - Enabled bool `json:"enabled,omitempty"` - // HTTP2Enabled bool `json:"http2_enabled,omitempty"` + Enabled bool `json:"enabled,omitempty"` + HTTP2Enabled bool `json:"http2_enabled,omitempty"` TLSDecryptionOptions *TLSDecryptionOptions `json:"tls_decryption,omitempty"` }