Implement dns-hijack

This commit is contained in:
世界 2024-10-23 13:44:08 +08:00
parent 179e3cb2f5
commit 9f7683818f
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
6 changed files with 245 additions and 96 deletions

View file

@ -88,7 +88,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if deadline.NeedAdditionalReadDeadline(conn) {
conn = deadline.NewConn(conn)
}
selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, conn, nil, -1)
selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil, -1)
if err != nil {
return err
}
@ -109,6 +109,12 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
buf.ReleaseMulti(buffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error())
return nil
case *rule.RuleActionHijackDNS:
for _, buffer := range buffers {
conn = bufio.NewCachedConn(conn, buffer)
}
r.hijackDNSStream(ctx, conn, metadata)
return nil
}
}
if selectedRule == nil || selectReturn {
@ -226,7 +232,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn))
}*/
selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1)
selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1)
if err != nil {
return err
}
@ -238,32 +244,35 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
var loaded bool
selectedOutbound, loaded = r.Outbound(action.Outbound)
if !loaded {
buf.ReleaseMulti(buffers)
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("outbound not found: ", action.Outbound)
}
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
case *rule.RuleActionReturn:
selectReturn = true
case *rule.RuleActionReject:
buf.ReleaseMulti(buffers)
N.ReleaseMultiPacketBuffer(packetBuffers)
N.CloseOnHandshakeFailure(conn, onClose, syscall.ECONNREFUSED)
return nil
case *rule.RuleActionHijackDNS:
r.hijackDNSPacket(ctx, conn, packetBuffers, metadata)
return nil
}
}
if selectedRule == nil || selectReturn {
if r.defaultOutboundForPacketConnection == nil {
buf.ReleaseMulti(buffers)
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("missing default outbound with UDP support")
}
selectedOutbound = r.defaultOutboundForPacketConnection
}
if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) {
buf.ReleaseMulti(buffers)
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
}
for _, buffer := range buffers {
// TODO: check if metadata.Destination == packet destination
conn = bufio.NewCachedPacketConn(conn, buffer, metadata.Destination)
for _, buffer := range packetBuffers {
conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination)
N.PutPacketBuffer(buffer)
}
if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, selectedRule)
@ -297,7 +306,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
}
func (r *Router) PreMatch(metadata adapter.InboundContext) error {
selectedRule, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1)
selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1)
if err != nil {
return err
}
@ -314,7 +323,10 @@ func (r *Router) PreMatch(metadata adapter.InboundContext) error {
func (r *Router) matchRule(
ctx context.Context, metadata *adapter.InboundContext, preMatch bool,
inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int,
) (selectedRule adapter.Rule, selectedRuleIndex int, buffers []*buf.Buffer, fatalErr error) {
) (
selectedRule adapter.Rule, selectedRuleIndex int,
buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error,
) {
if r.processSearcher != nil && metadata.ProcessInfo == nil {
var originDestination netip.AddrPort
if metadata.OriginDestination.IsValid() {
@ -376,7 +388,7 @@ func (r *Router) matchRule(
//nolint:staticcheck
if metadata.InboundOptions != common.DefaultValue[option.InboundOptions]() {
if !preMatch && metadata.InboundOptions.SniffEnabled {
newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{
newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{
OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
Timeout: time.Duration(metadata.InboundOptions.SniffTimeout),
}, inputConn, inputPacketConn)
@ -384,7 +396,11 @@ func (r *Router) matchRule(
fatalErr = newErr
return
}
buffers = append(buffers, newBuffers...)
if newBuffer != nil {
buffers = []*buf.Buffer{newBuffer}
} else if len(newPackerBuffers) > 0 {
packetBuffers = newPackerBuffers
}
}
if dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS {
fatalErr = r.actionResolve(ctx, metadata, &rule.RuleActionResolve{
@ -421,22 +437,36 @@ match:
break
}
if !preMatch {
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
ruleDescription := currentRule.String()
if ruleDescription != "" {
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
} else {
r.logger.DebugContext(ctx, "match[", currentRuleIndex, "] => ", currentRule.Action())
}
} else {
switch currentRule.Action().Type() {
case C.RuleActionTypeReject, C.RuleActionTypeResolve:
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
ruleDescription := currentRule.String()
if ruleDescription != "" {
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
} else {
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] => ", currentRule.Action())
}
}
}
switch action := currentRule.Action().(type) {
case *rule.RuleActionSniff:
if !preMatch {
newBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
if newErr != nil {
fatalErr = newErr
return
}
buffers = append(buffers, newBuffers...)
if newBuffer != nil {
buffers = append(buffers, newBuffer)
} else if len(newPacketBuffers) > 0 {
packetBuffers = append(packetBuffers, newPacketBuffers...)
}
} else {
selectedRule = currentRule
selectedRuleIndex = currentRuleIndex
@ -455,12 +485,16 @@ match:
ruleIndex = currentRuleIndex
}
if !preMatch && metadata.Destination.Addr.IsUnspecified() {
newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn)
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn)
if newErr != nil {
fatalErr = newErr
return
}
buffers = append(buffers, newBuffers...)
if newBuffer != nil {
buffers = append(buffers, newBuffer)
} else if len(newPacketBuffers) > 0 {
packetBuffers = append(packetBuffers, newPacketBuffers...)
}
}
return
}
@ -468,18 +502,31 @@ match:
func (r *Router) actionSniff(
ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff,
inputConn net.Conn, inputPacketConn N.PacketConn,
) (buffers []*buf.Buffer, fatalErr error) {
) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) {
if sniff.Skip(metadata) {
return
} else if inputConn != nil && len(action.StreamSniffers) > 0 {
buffer := buf.NewPacket()
} else if inputConn != nil {
sniffBuffer := buf.NewPacket()
var streamSniffers []sniff.StreamSniffer
if len(action.StreamSniffers) > 0 {
streamSniffers = action.StreamSniffers
} else {
streamSniffers = []sniff.StreamSniffer{
sniff.TLSClientHello,
sniff.HTTPHost,
sniff.StreamDomainNameQuery,
sniff.BitTorrent,
sniff.SSH,
sniff.RDP,
}
}
err := sniff.PeekStream(
ctx,
metadata,
inputConn,
buffer,
sniffBuffer,
action.Timeout,
action.StreamSniffers...,
streamSniffers...,
)
if err == nil {
//goland:noinspection GoDeprecation
@ -497,15 +544,15 @@ func (r *Router) actionSniff(
r.logger.DebugContext(ctx, "sniffed protocol: ", metadata.Protocol)
}
}
if !buffer.IsEmpty() {
buffers = append(buffers, buffer)
if !sniffBuffer.IsEmpty() {
buffer = sniffBuffer
} else {
buffer.Release()
sniffBuffer.Release()
}
} else if inputPacketConn != nil && len(action.PacketSniffers) > 0 {
} else if inputPacketConn != nil {
for {
var (
buffer = buf.NewPacket()
sniffBuffer = buf.NewPacket()
destination M.Socksaddr
done = make(chan struct{})
err error
@ -516,7 +563,7 @@ func (r *Router) actionSniff(
sniffTimeout = action.Timeout
}
inputPacketConn.SetReadDeadline(time.Now().Add(sniffTimeout))
destination, err = inputPacketConn.ReadPacket(buffer)
destination, err = inputPacketConn.ReadPacket(sniffBuffer)
inputPacketConn.SetReadDeadline(time.Time{})
close(done)
}()
@ -528,7 +575,7 @@ func (r *Router) actionSniff(
return
}
if err != nil {
buffer.Release()
sniffBuffer.Release()
if !errors.Is(err, os.ErrDeadlineExceeded) {
fatalErr = err
return
@ -538,22 +585,40 @@ func (r *Router) actionSniff(
if metadata.Destination.Addr.IsUnspecified() {
metadata.Destination = destination
}
if len(buffers) > 0 {
if len(packetBuffers) > 0 {
err = sniff.PeekPacket(
ctx,
metadata,
buffer.Bytes(),
sniffBuffer.Bytes(),
sniff.QUICClientHello,
)
} else {
var packetSniffers []sniff.PacketSniffer
if len(action.PacketSniffers) > 0 {
packetSniffers = action.PacketSniffers
} else {
packetSniffers = []sniff.PacketSniffer{
sniff.DomainNameQuery,
sniff.QUICClientHello,
sniff.STUNMessage,
sniff.UTP,
sniff.UDPTracker,
sniff.DTLSRecord,
}
}
err = sniff.PeekPacket(
ctx, metadata,
buffer.Bytes(),
action.PacketSniffers...,
sniffBuffer.Bytes(),
packetSniffers...,
)
}
buffers = append(buffers, buffer)
if E.IsMulti(err, sniff.ErrClientHelloFragmented) && len(buffers) == 0 {
packetBuffer := N.NewPacketBuffer()
*packetBuffer = N.PacketBuffer{
Buffer: sniffBuffer,
Destination: destination,
}
packetBuffers = append(packetBuffers, packetBuffer)
if E.IsMulti(err, sniff.ErrClientHelloFragmented) && len(packetBuffers) == 0 {
r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
continue
}