mitm: Minor fixes

This commit is contained in:
世界 2025-02-03 10:59:07 +08:00
parent 5361d2acec
commit fb3007fa80
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 54 additions and 38 deletions

View file

@ -40,7 +40,7 @@ func getMobileConfig(ctx context.Context) http.HandlerFunc {
mobileConfig := map[string]interface{}{
"PayloadContent": []interface{}{
map[string]interface{}{
"PayloadCertificateFileName": "Certificate.cer",
"PayloadCertificateFileName": "Certificates.cer",
"PayloadContent": certificate.Raw,
"PayloadDescription": "Adds a root certificate",
"PayloadDisplayName": certificate.Subject.CommonName,

View file

@ -26,6 +26,7 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@ -165,7 +166,7 @@ func (e *Engine) newTLS(ctx context.Context, this N.Dialer, conn net.Conn, metad
tlsConn := tls.Server(conn, tlsConfig)
err := tlsConn.HandshakeContext(ctx)
if err != nil {
return E.Cause(err, "TLS handshake")
return E.Cause(err, "TLS handshake failed for ", metadata.ClientHello.ServerName, ", ", strings.Join(metadata.ClientHello.SupportedProtos, ", "))
}
if tlsConn.ConnectionState().NegotiatedProtocol == "h2" {
return e.newHTTP2(ctx, this, tlsConn, tlsConfig, metadata, onClose)
@ -183,7 +184,11 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
return E.Cause(err, "read HTTP request")
}
rawRequestURL := request.URL
rawRequestURL.Scheme = "https"
if tlsConfig != nil {
rawRequestURL.Scheme = "https"
} else {
rawRequestURL.Scheme = "http"
}
if rawRequestURL.Host == "" {
rawRequestURL.Host = request.Host
}
@ -482,7 +487,7 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
response.Body = io.NopCloser(bytes.NewReader(responseBody))
}
if options.Print {
e.printResponse(ctx, response, responseBody)
e.printResponse(ctx, request, response, responseBody)
}
if responseScript != nil {
if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
@ -578,6 +583,22 @@ func (e *Engine) newHTTP1(ctx context.Context, this N.Dialer, conn net.Conn, tls
}
func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tlsConfig *tls.Config, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
httpTransport := &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
ctx = adapter.WithContext(ctx, &metadata)
if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
return dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
} else {
return this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
},
TLSClientConfig: tlsConfig,
}
err := http2.ConfigureTransport(httpTransport)
if err != nil {
return E.Cause(err, "configure HTTP/2 transport")
}
handler := &engineHandler{
Engine: e,
conn: conn,
@ -585,27 +606,7 @@ func (e *Engine) newHTTP2(ctx context.Context, this N.Dialer, conn net.Conn, tls
dialer: this,
metadata: metadata,
httpClient: &http.Client{
Transport: &http2.Transport{
AllowHTTP: true,
MaxReadFrameSize: math.MaxUint32,
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
ctx = adapter.WithContext(ctx, &metadata)
var (
remoteConn net.Conn
err error
)
if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
} else {
remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
if err != nil {
return nil, err
}
return tls.Client(remoteConn, cfg), nil
},
TLSClientConfig: tlsConfig,
},
Transport: httpTransport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
@ -635,7 +636,6 @@ type engineHandler struct {
func (e *engineHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
err := e.serveHTTP(request.Context(), writer, request)
if err != nil {
e.conn.Close()
if E.IsClosedOrCanceled(err) {
e.logger.DebugContext(request.Context(), E.Cause(err, "connection closed"))
} else {
@ -921,7 +921,7 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
response.Body = io.NopCloser(bytes.NewReader(responseBody))
}
if options.Print {
e.printResponse(ctx, response, responseBody)
e.printResponse(ctx, request, response, responseBody)
}
if responseScript != nil {
if responseBody == nil && responseScript.RequiresBody() && response.ContentLength > 0 && (responseScript.MaxSize() == 0 && response.ContentLength <= 131072 || response.ContentLength <= responseScript.MaxSize()) {
@ -1021,42 +1021,58 @@ func (e *engineHandler) serveHTTP(ctx context.Context, writer http.ResponseWrite
}
func (e *Engine) printRequest(ctx context.Context, request *http.Request, body []byte) {
e.logger.TraceContext(ctx, "request: ", request.Proto, " ", request.Method, " ", request.URL.String())
var builder strings.Builder
builder.WriteString(F.ToString(request.Proto, " ", request.Method, " ", request.URL))
builder.WriteString("\n")
if request.URL.Hostname() != "" && request.URL.Hostname() != request.Host {
e.logger.TraceContext(ctx, "request: ", "Host: ", request.Host)
builder.WriteString("Host: ")
builder.WriteString(request.Host)
builder.WriteString("\n")
}
for key, values := range request.Header {
for _, value := range values {
e.logger.TraceContext(ctx, "request: ", key, ": ", value)
builder.WriteString(key)
builder.WriteString(": ")
builder.WriteString(value)
builder.WriteString("\n")
}
}
if len(body) > 0 {
builder.WriteString("\n")
if !bytes.ContainsFunc(body, func(r rune) bool {
return !unicode.IsPrint(r) && !unicode.IsSpace(r)
}) {
e.logger.TraceContext(ctx, "request: body: ", string(body))
builder.Write(body)
} else {
e.logger.TraceContext(ctx, "request: body unprintable")
builder.WriteString("(body not printable)")
}
}
e.logger.InfoContext(ctx, "request: ", builder.String())
}
func (e *Engine) printResponse(ctx context.Context, response *http.Response, body []byte) {
e.logger.TraceContext(ctx, "response: ", response.Proto, " ", response.Status)
func (e *Engine) printResponse(ctx context.Context, request *http.Request, response *http.Response, body []byte) {
var builder strings.Builder
builder.WriteString(F.ToString(response.Proto, " ", response.Status, " ", request.URL))
builder.WriteString("\n")
for key, values := range response.Header {
for _, value := range values {
e.logger.TraceContext(ctx, "response: ", key, ": ", value)
builder.WriteString(key)
builder.WriteString(": ")
builder.WriteString(value)
builder.WriteString("\n")
}
}
if len(body) > 0 {
builder.WriteString("\n")
if !bytes.ContainsFunc(body, func(r rune) bool {
return !unicode.IsPrint(r) && !unicode.IsSpace(r)
}) {
e.logger.TraceContext(ctx, "response: ", string(body))
builder.Write(body)
} else {
builder.WriteString("(body not printable)")
}
} else {
e.logger.TraceContext(ctx, "response: body unprintable")
}
e.logger.InfoContext(ctx, "response: ", builder.String())
}
type simpleResponseWriter struct {