Fix processing multiple sniffs

This commit is contained in:
世界 2025-03-16 09:21:54 +08:00
parent 4f3ee61104
commit d55d5009c2
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 29 additions and 13 deletions

View file

@ -57,6 +57,7 @@ type InboundContext struct {
Domain string
Client string
SniffContext any
PacketSniffError error
// cache

View file

@ -9,6 +9,7 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
)
@ -34,7 +35,7 @@ func Skip(metadata *adapter.InboundContext) bool {
return false
}
func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.Conn, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) error {
func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.Conn, buffers []*buf.Buffer, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) error {
if timeout == 0 {
timeout = C.ReadPayloadTimeout
}
@ -55,7 +56,10 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.
}
errors = nil
for _, sniffer := range sniffers {
err = sniffer(ctx, metadata, bytes.NewReader(buffer.Bytes()))
reader := io.MultiReader(common.Map(append(buffers, buffer), func(it *buf.Buffer) io.Reader {
return bytes.NewReader(it.Bytes())
})...)
err = sniffer(ctx, metadata, reader)
if err == nil {
return nil
}

View file

@ -358,7 +358,7 @@ func (r *Router) matchRule(
newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{
OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
Timeout: time.Duration(metadata.InboundOptions.SniffTimeout),
}, inputConn, inputPacketConn)
}, inputConn, inputPacketConn, nil)
if newErr != nil {
fatalErr = newErr
return
@ -458,7 +458,7 @@ match:
switch action := currentRule.Action().(type) {
case *rule.RuleActionSniff:
if !preMatch {
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers)
if newErr != nil {
fatalErr = newErr
return
@ -490,7 +490,7 @@ match:
}
}
if !preMatch && inputPacketConn != nil && (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() {
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{Timeout: C.TCPTimeout}, inputConn, inputPacketConn)
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{Timeout: C.TCPTimeout}, inputConn, inputPacketConn, buffers)
if newErr != nil {
fatalErr = newErr
return
@ -506,11 +506,16 @@ match:
func (r *Router) actionSniff(
ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff,
inputConn net.Conn, inputPacketConn N.PacketConn,
inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer,
) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) {
if sniff.Skip(metadata) {
r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first")
return
} else if inputConn != nil {
} else if metadata.Protocol != "" {
r.logger.DebugContext(ctx, "duplicate sniff skipped")
return
}
if inputConn != nil {
sniffBuffer := buf.NewPacket()
var streamSniffers []sniff.StreamSniffer
if len(action.StreamSniffers) > 0 {
@ -529,6 +534,7 @@ func (r *Router) actionSniff(
ctx,
metadata,
inputConn,
inputBuffers,
sniffBuffer,
action.Timeout,
streamSniffers...,
@ -555,6 +561,10 @@ func (r *Router) actionSniff(
sniffBuffer.Release()
}
} else if inputPacketConn != nil {
if metadata.PacketSniffError != nil && !errors.Is(metadata.PacketSniffError, sniff.ErrClientHelloFragmented) {
r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.PacketSniffError)
return
}
for {
var (
sniffBuffer = buf.NewPacket()
@ -589,7 +599,7 @@ func (r *Router) actionSniff(
if (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() && !metadata.RouteOriginalDestination.IsValid() {
metadata.Destination = destination
}
if len(packetBuffers) > 0 {
if len(packetBuffers) > 0 || metadata.PacketSniffError != nil {
err = sniff.PeekPacket(
ctx,
metadata,
@ -622,7 +632,8 @@ func (r *Router) actionSniff(
Destination: destination,
}
packetBuffers = append(packetBuffers, packetBuffer)
if E.IsMulti(err, sniff.ErrClientHelloFragmented) {
metadata.PacketSniffError = err
if errors.Is(err, sniff.ErrClientHelloFragmented) {
r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
continue
}