Export interface for WireGuard

This commit is contained in:
世界 2024-11-21 18:12:21 +08:00
parent 8a18f0c99e
commit 4ebeb2fa86
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 269 additions and 49 deletions

2
go.mod
View file

@ -9,7 +9,7 @@ require (
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
github.com/sagernet/nftables v0.3.0-beta.4
github.com/sagernet/sing v0.6.0-alpha.11
github.com/sagernet/sing v0.6.0-alpha.18
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/net v0.26.0

4
go.sum
View file

@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/sing v0.6.0-alpha.11 h1:ZcZlA0/vdDeiipAbjK73x9VabGJ/RRcAJgWhOo/OoBk=
github.com/sagernet/sing v0.6.0-alpha.11/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.6.0-alpha.18 h1:ih4CurU8KvbhfagYjSqVrE2LR0oBSXSZTNH2sAGPGiM=
github.com/sagernet/sing v0.6.0-alpha.18/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=

View file

@ -0,0 +1,136 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"net/netip"
tcpip "github.com/sagernet/sing-tun/internal/gtcpip"
)
const (
// MaxIPPacketSize is the maximum supported IP packet size, excluding
// jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
// in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
// 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
// m is the minimum IPv6 header size; we leave room for some potential
// IP options.
MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
)
// Transport offers generic methods to query and/or update the fields of the
// header of a transport protocol buffer.
type Transport interface {
// SourcePort returns the value of the "source port" field.
SourcePort() uint16
// Destination returns the value of the "destination port" field.
DestinationPort() uint16
// Checksum returns the value of the "checksum" field.
Checksum() uint16
// SetSourcePort sets the value of the "source port" field.
SetSourcePort(uint16)
// SetDestinationPort sets the value of the "destination port" field.
SetDestinationPort(uint16)
// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)
// Payload returns the data carried in the transport buffer.
Payload() []byte
}
// ChecksummableTransport is a Transport that supports checksumming.
type ChecksummableTransport interface {
Transport
// SetSourcePortWithChecksumUpdate sets the source port and updates
// the checksum.
//
// The receiver's checksum must be a fully calculated checksum.
SetSourcePortWithChecksumUpdate(port uint16)
// SetDestinationPortWithChecksumUpdate sets the destination port and updates
// the checksum.
//
// The receiver's checksum must be a fully calculated checksum.
SetDestinationPortWithChecksumUpdate(port uint16)
// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
// updated address in the pseudo header.
//
// If fullChecksum is true, the receiver's checksum field is assumed to hold a
// fully calculated checksum. Otherwise, it is assumed to hold a partially
// calculated checksum which only reflects the pseudo header.
UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
}
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
// SourceAddress returns the value of the "source address" field.
SourceAddress() tcpip.Address
// DestinationAddress returns the value of the "destination address"
// field.
DestinationAddress() tcpip.Address
DestinationAddr() netip.Addr
// Checksum returns the value of the "checksum" field.
Checksum() uint16
// SetSourceAddress sets the value of the "source address" field.
SetSourceAddress(tcpip.Address)
// SetDestinationAddress sets the value of the "destination address"
// field.
SetDestinationAddress(tcpip.Address)
SetDestinationAddr(addr netip.Addr)
// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)
// TransportProtocol returns the number of the transport protocol
// stored in the payload.
TransportProtocol() tcpip.TransportProtocolNumber
// Payload returns a byte slice containing the payload of the network
// packet.
Payload() []byte
// TOS returns the values of the "type of service" and "flow label" fields.
TOS() (uint8, uint32)
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}
// ChecksummableNetwork is a Network that supports checksumming.
type ChecksummableNetwork interface {
Network
// SetSourceAddressAndChecksum sets the source address and updates the
// checksum to reflect the new address.
SetSourceAddressWithChecksumUpdate(tcpip.Address)
// SetDestinationAddressAndChecksum sets the destination address and
// updates the checksum to reflect the new address.
SetDestinationAddressWithChecksumUpdate(tcpip.Address)
}

View file

@ -19,13 +19,11 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const WithGVisor = true
const defaultNIC tcpip.NICID = 1
const DefaultNIC tcpip.NICID = 1
type GVisor struct {
ctx context.Context
@ -68,28 +66,11 @@ func (t *GVisor) Start() error {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := newGVisorStack(linkEndpoint)
ipStack, err := NewGVisorStack(linkEndpoint)
if err != nil {
return err
}
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := t.handler.PrepareConnection(N.NetworkTCP, source, destination)
if pErr != nil {
r.Complete(pErr != ErrDrop)
return
}
conn := &gLazyConn{
parentCtx: t.ctx,
stack: t.stack,
request: r,
localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(),
}
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint
@ -124,7 +105,7 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
}
}
func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
@ -137,19 +118,19 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
icmp.NewProtocol6,
},
})
err := ipStack.CreateNIC(defaultNIC, ep)
err := ipStack.CreateNIC(DefaultNIC, ep)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv4EmptySubnet, NIC: DefaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: DefaultNIC},
})
err = ipStack.SetSpoofing(defaultNIC, true)
err = ipStack.SetSpoofing(DefaultNIC, true)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
err = ipStack.SetPromiscuousMode(defaultNIC, true)
err = ipStack.SetPromiscuousMode(DefaultNIC, true)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}

51
stack_gvisor_tcp.go Normal file
View file

@ -0,0 +1,51 @@
//go:build with_gvisor
package tun
import (
"context"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type TCPForwarder struct {
ctx context.Context
stack *stack.Stack
handler Handler
forwarder *tcp.Forwarder
}
func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
forwarder := &TCPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
}
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
return forwarder
}
func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
return f.forwarder.HandlePacket(id, pkt)
}
func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
if pErr != nil {
r.Complete(pErr != ErrDrop)
return
}
conn := &gLazyConn{
parentCtx: f.ctx,
stack: f.stack,
request: r,
localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(),
}
go f.handler.NewConnectionEx(f.ctx, conn, source, destination, nil)
}

View file

@ -123,7 +123,7 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
defer packetBuffer.Release()
route, err := w.stack.FindRoute(
defaultNIC,
DefaultNIC,
AddressFromAddr(destination.Addr),
w.source,
w.sourceNetwork,

View file

@ -38,7 +38,7 @@ func (m *Mixed) Start() error {
return err
}
endpoint := channel.New(1024, uint32(m.mtu), "")
ipStack, err := newGVisorStack(endpoint)
ipStack, err := NewGVisorStack(endpoint)
if err != nil {
return err
}
@ -137,7 +137,7 @@ func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
_, err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
@ -151,10 +151,10 @@ func (m *Mixed) processPacket(packet []byte) bool {
writeBack bool
err error
)
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
switch ipVersion := header.IPVersion(packet); ipVersion {
case header.IPv4Version:
writeBack, err = m.processIPv4(packet)
case 6:
case header.IPv6Version:
writeBack, err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)

View file

@ -419,7 +419,7 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
@ -502,7 +502,7 @@ func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) erro
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
@ -684,7 +684,7 @@ func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) e
}))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
@ -724,7 +724,7 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
@ -763,7 +763,7 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetChecksum(0)
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}

34
stack_system_packet.go Normal file
View file

@ -0,0 +1,34 @@
package tun
import (
"net/netip"
"syscall"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
)
func PacketIPVersion(packet []byte) int {
return header.IPVersion(packet)
}
func PacketFillHeader(packet []byte, ipVersion int) {
if PacketOffset > 0 {
switch ipVersion {
case header.IPv4Version:
packet[3] = syscall.AF_INET
case header.IPv6Version:
packet[3] = syscall.AF_INET6
}
}
}
func PacketDestination(packet []byte) netip.Addr {
switch ipVersion := header.IPVersion(packet); ipVersion {
case header.IPv4Version:
return header.IPv4(packet).DestinationAddr()
case header.IPv6Version:
return header.IPv6(packet).DestinationAddr()
default:
return netip.Addr{}
}
}

3
tun.go
View file

@ -1,6 +1,7 @@
package tun
import (
"github.com/sagernet/sing/common/control"
"io"
"net"
"net/netip"
@ -54,6 +55,7 @@ type Options struct {
MTU uint32
GSO bool
AutoRoute bool
InterfaceScope bool
Inet4Gateway netip.Addr
Inet6Gateway netip.Addr
DNSServers []netip.Addr
@ -74,6 +76,7 @@ type Options struct {
IncludeAndroidUser []int
IncludePackage []string
ExcludePackage []string
InterfaceFinder control.InterfaceFinder
InterfaceMonitor DefaultInterfaceMonitor
FileDescriptor int
Logger logger.Logger

View file

@ -3,6 +3,7 @@ package tun
import (
"errors"
"fmt"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"net"
"net/netip"
"os"
@ -96,9 +97,10 @@ var (
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
var packetHeader []byte
if buffers[0].Byte(0)>>4 == 4 {
switch header.IPVersion(buffers[0].Bytes()) {
case header.IPv4Version:
packetHeader = packetHeader4[:]
} else {
case header.IPv6Version:
packetHeader = packetHeader6[:]
}
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
@ -250,6 +252,7 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
func (t *NativeTun) setRoutes() error {
if t.options.AutoRoute && t.options.FileDescriptor == 0 {
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
@ -262,14 +265,22 @@ func (t *NativeTun) setRoutes() error {
} else {
gateway = gateway6
}
err = execRoute(unix.RTM_ADD, destination, gateway)
var interfaceIndex int
if t.options.InterfaceScope {
iff, err := t.options.InterfaceFinder.ByName(t.options.Name)
if err != nil {
return err
}
interfaceIndex = iff.Index
}
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
if err != nil {
if errors.Is(err, unix.EEXIST) {
err = execRoute(unix.RTM_DELETE, destination, gateway)
err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway)
if err != nil {
return E.Cause(err, "remove existing route: ", destination)
}
err = execRoute(unix.RTM_ADD, destination, gateway)
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
if err != nil {
return E.Cause(err, "re-add route: ", destination)
}
@ -300,7 +311,7 @@ func (t *NativeTun) unsetRoutes() error {
} else {
gateway = gateway6
}
err = execRoute(unix.RTM_DELETE, destination, gateway)
err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway)
if err != nil {
err = E.Errors(err, E.Cause(err, "delete route: ", destination))
}
@ -317,7 +328,7 @@ func useSocket(domain, typ, proto int, block func(socketFd int) error) error {
return block(socketFd)
}
func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error {
func execRoute(rtmType int, interfaceScope bool, interfaceIndex int, destination netip.Prefix, gateway netip.Addr) error {
routeMessage := route.RouteMessage{
Type: rtmType,
Version: unix.RTM_VERSION,
@ -326,6 +337,10 @@ func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error
}
if rtmType == unix.RTM_ADD {
routeMessage.Flags |= unix.RTF_UP
if interfaceScope {
routeMessage.Flags |= unix.RTF_IFSCOPE
routeMessage.Index = interfaceIndex
}
}
if gateway.Is4() {
routeMessage.Addrs = []route.Addr{