diff --git a/stack.go b/stack.go index 2e96e9d..7eb7854 100644 --- a/stack.go +++ b/stack.go @@ -2,6 +2,8 @@ package tun import ( "context" + "encoding/binary" + "net" "net/netip" "github.com/sagernet/sing/common/control" @@ -52,3 +54,13 @@ func NewStack( return nil, E.New("unknown stack: ", stack) } } + +func BroadcastAddr(inet4Address []netip.Prefix) netip.Addr { + if len(inet4Address) == 0 { + return netip.Addr{} + } + prefix := inet4Address[0] + var broadcastAddr [4]byte + binary.BigEndian.PutUint32(broadcastAddr[:], binary.BigEndian.Uint32(prefix.Masked().Addr().AsSlice())|^binary.BigEndian.Uint32(net.CIDRMask(prefix.Bits(), 32))) + return netip.AddrFrom4(broadcastAddr) +} diff --git a/stack_gvisor.go b/stack_gvisor.go index 616598c..f32f7fc 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -34,6 +34,7 @@ type GVisor struct { tunMtu uint32 endpointIndependentNat bool udpTimeout int64 + broadcastAddr netip.Addr handler Handler logger logger.Logger stack *stack.Stack @@ -59,6 +60,7 @@ func NewGVisor( tunMtu: options.MTU, endpointIndependentNat: options.EndpointIndependentNat, udpTimeout: options.UDPTimeout, + broadcastAddr: BroadcastAddr(options.Inet4Address), handler: options.Handler, logger: options.Logger, } @@ -70,7 +72,7 @@ func (t *GVisor) Start() error { if err != nil { return err } - linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.tun.CreateVectorisedWriter()} + linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun.CreateVectorisedWriter()} ipStack, err := newGVisorStack(linkEndpoint) if err != nil { return err diff --git a/stack_gvisor_filter.go b/stack_gvisor_filter.go index 7f943ae..4b6ba98 100644 --- a/stack_gvisor_filter.go +++ b/stack_gvisor_filter.go @@ -3,6 +3,8 @@ package tun import ( + "net/netip" + "github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/stack" @@ -14,18 +16,20 @@ var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil) type LinkEndpointFilter struct { stack.LinkEndpoint - Writer N.VectorisedWriter + BroadcastAddress netip.Addr + Writer N.VectorisedWriter } func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) { - w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.Writer}) + w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.BroadcastAddress, w.Writer}) } var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil) type networkDispatcherFilter struct { stack.NetworkDispatcher - writer N.VectorisedWriter + broadcastAddress netip.Addr + writer N.VectorisedWriter } func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) { @@ -44,7 +48,7 @@ func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkPro return } destination := AddrFromAddress(network.DestinationAddress()) - if destination.IsMulticast() || !destination.IsGlobalUnicast() { + if destination == w.broadcastAddress || !destination.IsGlobalUnicast() { _, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices()) return } diff --git a/stack_system.go b/stack_system.go index e2b87cc..6a9598d 100644 --- a/stack_system.go +++ b/stack_system.go @@ -31,6 +31,7 @@ type System struct { inet4Address netip.Addr inet6ServerAddress netip.Addr inet6Address netip.Addr + broadcastAddr netip.Addr udpTimeout int64 tcpListener net.Listener tcpListener6 net.Listener @@ -60,6 +61,7 @@ func NewSystem(options StackOptions) (Stack, error) { logger: options.Logger, inet4Prefixes: options.Inet4Address, inet6Prefixes: options.Inet6Address, + broadcastAddr: BroadcastAddr(options.Inet4Address), bindInterface: options.ForwarderBindInterface, interfaceFinder: options.InterfaceFinder, } @@ -234,7 +236,7 @@ func (s *System) acceptLoop(listener net.Listener) { func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error { destination := packet.DestinationIP() - if destination.IsMulticast() || !destination.IsGlobalUnicast() { + if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { return common.Error(s.tun.Write(packet)) } switch packet.Protocol() {