Add multicast filter

This commit is contained in:
世界 2023-11-01 20:57:42 +08:00
parent b93db9639d
commit 150b116231
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 60 additions and 0 deletions

View file

@ -70,6 +70,7 @@ func (t *GVisor) Start() error {
if err != nil {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.tun.CreateVectorisedWriter()}
ipStack, err := newGVisorStack(linkEndpoint)
if err != nil {
return err

52
stack_gvisor_filter.go Normal file
View file

@ -0,0 +1,52 @@
//go:build with_gvisor
package tun
import (
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)
type LinkEndpointFilter struct {
stack.LinkEndpoint
Writer N.VectorisedWriter
}
func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) {
w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.Writer})
}
var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil)
type networkDispatcherFilter struct {
stack.NetworkDispatcher
writer N.VectorisedWriter
}
func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) {
var network header.Network
if protocol == header.IPv4ProtocolNumber {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv4MinimumSize); loaded {
network = header.IPv4(headerPackets)
}
} else {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv6MinimumSize); loaded {
network = header.IPv6(headerPackets)
}
}
if network == nil {
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
return
}
destination := AddrFromAddress(network.DestinationAddress())
if destination.IsMulticast() || !destination.IsGlobalUnicast() {
_, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices())
return
}
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
}

View file

@ -233,6 +233,10 @@ func (s *System) acceptLoop(listener net.Listener) {
}
func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
destination := packet.DestinationIP()
if destination.IsMulticast() || !destination.IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv4TCP(packet, packet.Payload())
@ -246,6 +250,9 @@ func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
}
func (s *System) processIPv6(packet clashtcpip.IPv6Packet) error {
if !packet.DestinationIP().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv6TCP(packet, packet.Payload())