diff --git a/adapter/outbound.go b/adapter/outbound.go index 2c2b1091..86a3bf5f 100644 --- a/adapter/outbound.go +++ b/adapter/outbound.go @@ -5,6 +5,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-tun" N "github.com/sagernet/sing/common/network" ) @@ -18,6 +19,11 @@ type Outbound interface { N.Dialer } +type DirectRouteOutbound interface { + Outbound + NewDirectRouteConnection(metadata InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) +} + type OutboundRegistry interface { option.OutboundOptionsRegistry CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error) diff --git a/adapter/router.go b/adapter/router.go index b82cb5d8..14a8f791 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -8,6 +8,7 @@ import ( "sync" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-tun" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" @@ -19,7 +20,7 @@ import ( type Router interface { Lifecycle ConnectionRouter - PreMatch(metadata InboundContext) error + PreMatch(metadata InboundContext, context tun.DirectRouteContext) (tun.DirectRouteDestination, error) ConnectionRouterEx RuleSet(tag string) (RuleSet, bool) NeedWIFIState() bool diff --git a/go.mod b/go.mod index bc2c1036..5792424b 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/sagernet/cors v1.2.1 github.com/sagernet/fswatch v0.1.1 github.com/sagernet/gomobile v0.1.4 - github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff + github.com/sagernet/gvisor v0.0.0-20250217052116-ed66b6946f72 github.com/sagernet/quic-go v0.49.0-beta.1 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 github.com/sagernet/sing v0.6.2-0.20250210105917-3464ed3babc0 @@ -34,7 +34,7 @@ require ( github.com/sagernet/sing-shadowsocks v0.2.7 github.com/sagernet/sing-shadowsocks2 v0.2.0 github.com/sagernet/sing-shadowtls v0.2.0 - github.com/sagernet/sing-tun v0.6.1 + github.com/sagernet/sing-tun v0.6.2-0.20250217135654-784bb584392f github.com/sagernet/sing-vmess v0.2.0 github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 github.com/sagernet/tailscale v1.79.0-mod.1 diff --git a/go.sum b/go.sum index 6d2dd2b2..8a6c3cc3 100644 --- a/go.sum +++ b/go.sum @@ -171,8 +171,8 @@ github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQ github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o= github.com/sagernet/gomobile v0.1.4 h1:WzX9ka+iHdupMgy2Vdich+OAt7TM8C2cZbIbzNjBrJY= github.com/sagernet/gomobile v0.1.4/go.mod h1:Pqq2+ZVvs10U7xK+UwJgwYWUykewi8H6vlslAO73n9E= -github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs= -github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw= +github.com/sagernet/gvisor v0.0.0-20250217052116-ed66b6946f72 h1:Jgv6N59yiVMEwimTcFV1EVcu2Aa7R2Wh1ZAYNzWP2qA= +github.com/sagernet/gvisor v0.0.0-20250217052116-ed66b6946f72/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= @@ -194,8 +194,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.0 h1:wpZNs6wKnR7mh1wV9OHwOyUr21VkS3wK github.com/sagernet/sing-shadowsocks2 v0.2.0/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowtls v0.2.0 h1:cLKe4OAOFwuhmAIuPLj//CIL7Q9js+pIDardhJ+/osk= github.com/sagernet/sing-shadowtls v0.2.0/go.mod h1:agU+Fw5X+xnWVyRHyFthoZCX3MfWKCFPm4JUf+1oaxo= -github.com/sagernet/sing-tun v0.6.1 h1:4l0+gnEKcGjlWfUVTD+W0BRApqIny/lU2ZliurE+VMo= -github.com/sagernet/sing-tun v0.6.1/go.mod h1:fisFCbC4Vfb6HqQNcwPJi2CDK2bf0Xapyz3j3t4cnHE= +github.com/sagernet/sing-tun v0.6.2-0.20250217135654-784bb584392f h1:VEtmCNfk8RuZcYz/Xx63F2QjcgG3z/7Pa0uiJciQo+Y= +github.com/sagernet/sing-tun v0.6.2-0.20250217135654-784bb584392f/go.mod h1:UiOi1ombGaAzWkGSgH4qcP7Zpq8FjWc1uQmleK8oPCE= github.com/sagernet/sing-vmess v0.2.0 h1:pCMGUXN2k7RpikQV65/rtXtDHzb190foTfF9IGTMZrI= github.com/sagernet/sing-vmess v0.2.0/go.mod h1:jDAZ0A0St1zVRkyvhAPRySOFfhC+4SQtO5VYyeFotgA= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ= diff --git a/protocol/tailscale/endpoint.go b/protocol/tailscale/endpoint.go index 6c2c111b..0694bc61 100644 --- a/protocol/tailscale/endpoint.go +++ b/protocol/tailscale/endpoint.go @@ -18,6 +18,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/sing-box/adapter" @@ -205,8 +206,10 @@ func (t *Endpoint) Start(stage adapter.StartStage) error { ipStack := t.server.ExportNetstack().ExportIPStack() ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(t.ctx, ipStack, t).HandlePacket) - udpForwarder := tun.NewUDPForwarder(t.ctx, ipStack, t, t.udpTimeout) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(t.ctx, ipStack, t, t.udpTimeout).HandlePacket) + icmpForwarder := tun.NewICMPForwarder(t.ctx, ipStack, t, t.udpTimeout) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) t.stack = ipStack localBackend := t.server.ExportLocalBackend() @@ -377,7 +380,7 @@ func (t *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n return udpConn, nil } -func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error { +func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { tsFilter := t.filter.Load() if tsFilter != nil { var ipProto ipproto.Proto @@ -390,9 +393,9 @@ func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destina response := tsFilter.Check(source.Addr, destination.Addr, destination.Port, ipProto) switch response { case filter.Drop: - return syscall.ECONNRESET + return nil, syscall.ECONNREFUSED case filter.DropSilently: - return tun.ErrDrop + return nil, tun.ErrDrop } } return t.router.PreMatch(adapter.InboundContext{ @@ -401,7 +404,7 @@ func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destina Network: network, Source: source, Destination: destination, - }) + }, routeContext) } func (t *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { diff --git a/protocol/tun/inbound.go b/protocol/tun/inbound.go index 00cc0561..fc655d0d 100644 --- a/protocol/tun/inbound.go +++ b/protocol/tun/inbound.go @@ -8,6 +8,7 @@ import ( "runtime" "strconv" "strings" + "syscall" "time" "github.com/sagernet/sing-box/adapter" @@ -438,15 +439,21 @@ func (t *Inbound) Close() error { ) } -func (t *Inbound) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error { - return t.router.PreMatch(adapter.InboundContext{ +func (t *Inbound) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { + routeDestination, err := t.router.PreMatch(adapter.InboundContext{ Inbound: t.tag, InboundType: C.TypeTun, Network: network, Source: source, Destination: destination, InboundOptions: t.inboundOptions, - }) + }, routeContext) + if err != nil { + if !E.IsMulti(err, tun.ErrDrop, syscall.ECONNREFUSED) { + t.logger.Warn(err) + } + } + return routeDestination, err } func (t *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index e167bec1..88e5458e 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/wireguard" + "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" @@ -43,7 +44,7 @@ type Endpoint struct { func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) { ep := &Endpoint{ - Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), + Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMPv4, N.NetworkICMPv6}, options.DialerOptions), ctx: ctx, router: router, dnsRouter: service.FromContext[adapter.DNSRouter](ctx), @@ -132,14 +133,14 @@ func (w *Endpoint) InterfaceUpdated() { return } -func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error { +func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, context tun.DirectRouteContext) (tun.DirectRouteDestination, error) { return w.router.PreMatch(adapter.InboundContext{ Inbound: w.Tag(), InboundType: w.Type(), Network: network, Source: source, Destination: destination, - }) + }, context) } func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { @@ -220,3 +221,12 @@ func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n } return w.endpoint.ListenPacket(ctx, destination) } + +func (w *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { + destination, err := w.endpoint.NewDirectRouteConnection(metadata, routeContext) + if err != nil { + return nil, err + } + w.logger.Info("linked ", metadata.Network, " connection to ", metadata.Destination.AddrString()) + return destination, nil +} diff --git a/route/route.go b/route/route.go index 55a83d15..9ccf216a 100644 --- a/route/route.go +++ b/route/route.go @@ -18,6 +18,7 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-mux" + "github.com/sagernet/sing-tun" "github.com/sagernet/sing-vmess" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -271,19 +272,36 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m return nil } -func (r *Router) PreMatch(metadata adapter.InboundContext) error { +func (r *Router) PreMatch(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil) if err != nil { - return err + return nil, err } if selectedRule == nil { - return nil + defaultOutbound := r.outbound.Default() + if !common.Contains(defaultOutbound.Network(), metadata.Network) { + return nil, E.New(metadata.Network, " is not supported by default outbound: ", defaultOutbound.Tag()) + } + return defaultOutbound.(adapter.DirectRouteOutbound).NewDirectRouteConnection(metadata, routeContext) } - rejectAction, isReject := selectedRule.Action().(*rule.RuleActionReject) - if !isReject { - return nil + switch action := selectedRule.Action().(type) { + case *rule.RuleActionReject: + return nil, action.Error(context.Background()) + case *rule.RuleActionRoute: + if routeContext == nil { + return nil, nil + } + outbound, loaded := r.outbound.Outbound(action.Outbound) + if !loaded { + return nil, E.New("outbound not found: ", action.Outbound) + } + if !common.Contains(outbound.Network(), metadata.Network) { + return nil, E.New(metadata.Network, " is not supported by outbound: ", action.Outbound) + } + return outbound.(adapter.DirectRouteOutbound).NewDirectRouteConnection(metadata, routeContext) + default: + return nil, nil } - return rejectAction.Error(context.Background()) } func (r *Router) matchRule( diff --git a/transport/wireguard/device.go b/transport/wireguard/device.go index 7a17b8f3..4e2d6a24 100644 --- a/transport/wireguard/device.go +++ b/transport/wireguard/device.go @@ -5,6 +5,7 @@ import ( "net/netip" "time" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/logger" N "github.com/sagernet/sing/common/network" @@ -17,6 +18,8 @@ type Device interface { N.Dialer Start() error SetDevice(device *device.Device) + Inet4Address() netip.Addr + Inet6Address() netip.Addr } type DeviceOptions struct { @@ -41,3 +44,8 @@ func NewDevice(options DeviceOptions) (Device, error) { return newSystemStackDevice(options) } } + +type NatDevice interface { + Device + CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) +} diff --git a/transport/wireguard/device_nat.go b/transport/wireguard/device_nat.go new file mode 100644 index 00000000..2c482d30 --- /dev/null +++ b/transport/wireguard/device_nat.go @@ -0,0 +1,85 @@ +package wireguard + +import ( + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" +) + +var _ Device = (*natDeviceWrapper)(nil) + +type natDeviceWrapper struct { + Device + gVisorOutbound + packetOutbound chan *buf.Buffer + mapping *tun.NatMapping + writer *tun.NatWriter + buffer [][]byte +} + +func NewNATDevice(upstream Device, ipRewrite bool) NatDevice { + wrapper := &natDeviceWrapper{ + Device: upstream, + gVisorOutbound: newGVisorOutbound(), + packetOutbound: make(chan *buf.Buffer, 256), + mapping: tun.NewNatMapping(ipRewrite), + } + if ipRewrite { + wrapper.writer = tun.NewNatWriter(upstream.Inet4Address(), upstream.Inet6Address()) + } + return wrapper +} + +func (d *natDeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { + for _, buffer := range bufs { + handled, err := d.mapping.WritePacket(buffer[offset:]) + if handled { + if err != nil { + return 0, err + } + } else { + d.buffer = append(d.buffer, buffer) + } + } + if len(d.buffer) > 0 { + _, err := d.Device.Write(d.buffer, offset) + if err != nil { + return 0, err + } + d.buffer = d.buffer[:0] + } + return 0, nil +} + +func (d *natDeviceWrapper) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { + session := tun.DirectRouteSession{ + Source: metadata.Source.Addr, + Destination: metadata.Destination.Addr, + } + d.mapping.CreateSession(session, routeContext) + return &natDestinationWrapper{d, session}, nil +} + +var _ tun.DirectRouteDestination = (*natDestinationWrapper)(nil) + +type natDestinationWrapper struct { + device *natDeviceWrapper + session tun.DirectRouteSession +} + +func (d *natDestinationWrapper) WritePacket(buffer *buf.Buffer) error { + if d.device.writer != nil { + d.device.writer.RewritePacket(buffer.Bytes()) + } + d.device.packetOutbound <- buffer + return nil +} + +func (d *natDestinationWrapper) Close() error { + d.device.mapping.DeleteSession(d.session) + return nil +} + +func (d *natDestinationWrapper) Timeout() bool { + return false +} diff --git a/transport/wireguard/device_nat_gvisor.go b/transport/wireguard/device_nat_gvisor.go new file mode 100644 index 00000000..edecba34 --- /dev/null +++ b/transport/wireguard/device_nat_gvisor.go @@ -0,0 +1,48 @@ +//go:build with_gvisor + +package wireguard + +import ( + "github.com/sagernet/gvisor/pkg/tcpip/stack" +) + +type gVisorOutbound struct { + outbound chan *stack.PacketBuffer +} + +func newGVisorOutbound() gVisorOutbound { + return gVisorOutbound{ + outbound: make(chan *stack.PacketBuffer, 256), + } +} + +func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + select { + case packet := <-d.outbound: + defer packet.DecRef() + var copyN int + /*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) { + copyN += copy(bufs[0][offset+copyN:], view.AsSlice()) + })*/ + for _, view := range packet.AsSlices() { + copyN += copy(bufs[0][offset+copyN:], view) + } + sizes[0] = copyN + return 1, nil + case packet := <-d.packetOutbound: + defer packet.Release() + sizes[0] = copy(bufs[0][offset:], packet.Bytes()) + return 1, nil + default: + } + return d.Device.Read(bufs, sizes, offset) +} + +func (d *natDestinationWrapper) WritePacketBuffer(packetBuffer *stack.PacketBuffer) error { + println("read from wg") + if d.device.writer != nil { + d.device.writer.RewritePacketBuffer(packetBuffer) + } + d.device.outbound <- packetBuffer + return nil +} diff --git a/transport/wireguard/device_nat_non_gvisor.go b/transport/wireguard/device_nat_non_gvisor.go new file mode 100644 index 00000000..e81e1e31 --- /dev/null +++ b/transport/wireguard/device_nat_non_gvisor.go @@ -0,0 +1,20 @@ +//go:build !with_gvisor + +package wireguard + +type gVisorOutbound struct{} + +func newGVisorOutbound() gVisorOutbound { + return gVisorOutbound{} +} + +func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + select { + case packet := <-d.packetOutbound: + defer packet.Release() + sizes[0] = copy(bufs[0][offset:], packet.Bytes()) + return 1, nil + default: + } + return d.Device.Read(bufs, sizes, offset) +} diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index f9440f02..eb759a2b 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -5,6 +5,7 @@ package wireguard import ( "context" "net" + "net/netip" "os" "github.com/sagernet/gvisor/pkg/buffer" @@ -14,9 +15,12 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -24,25 +28,30 @@ import ( wgTun "github.com/sagernet/wireguard-go/tun" ) -var _ Device = (*stackDevice)(nil) +var _ NatDevice = (*stackDevice)(nil) type stackDevice struct { - stack *stack.Stack - mtu uint32 - events chan wgTun.Event - outbound chan *stack.PacketBuffer - done chan struct{} - dispatcher stack.NetworkDispatcher - addr4 tcpip.Address - addr6 tcpip.Address + stack *stack.Stack + mtu uint32 + events chan wgTun.Event + outbound chan *stack.PacketBuffer + packetOutbound chan *buf.Buffer + done chan struct{} + dispatcher stack.NetworkDispatcher + addr4 tcpip.Address + addr6 tcpip.Address + mapping *tun.NatMapping + writer *tun.NatWriter } func newStackDevice(options DeviceOptions) (*stackDevice, error) { tunDevice := &stackDevice{ - mtu: options.MTU, - events: make(chan wgTun.Event, 1), - outbound: make(chan *stack.PacketBuffer, 256), - done: make(chan struct{}), + mtu: options.MTU, + events: make(chan wgTun.Event, 1), + outbound: make(chan *stack.PacketBuffer, 256), + packetOutbound: make(chan *buf.Buffer, 256), + done: make(chan struct{}), + mapping: tun.NewNatMapping(true), } ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice)) if err != nil { @@ -68,10 +77,14 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) { return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String()) } } + tunDevice.writer = tun.NewNatWriter(tunDevice.Inet4Address(), tunDevice.Inet6Address()) tunDevice.stack = ipStack if options.Handler != nil { ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) + icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) } return tunDevice, nil } @@ -130,6 +143,14 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) return udpConn, nil } +func (w *stackDevice) Inet4Address() netip.Addr { + return netip.AddrFrom4(w.addr4.As4()) +} + +func (w *stackDevice) Inet6Address() netip.Addr { + return netip.AddrFrom16(w.addr6.As16()) +} + func (w *stackDevice) SetDevice(device *device.Device) { } @@ -144,20 +165,24 @@ func (w *stackDevice) File() *os.File { func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { select { - case packetBuffer, ok := <-w.outbound: + case packet, ok := <-w.outbound: if !ok { return 0, os.ErrClosed } - defer packetBuffer.DecRef() - p := bufs[0] - p = p[offset:] - n := 0 - for _, slice := range packetBuffer.AsSlices() { - n += copy(p[n:], slice) + defer packet.DecRef() + var copyN int + /*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) { + copyN += copy(bufs[0][offset+copyN:], view.AsSlice()) + })*/ + for _, view := range packet.AsSlices() { + copyN += copy(bufs[0][offset+copyN:], view) } - sizes[0] = n - count = 1 - return + sizes[0] = copyN + return 1, nil + case packet := <-w.packetOutbound: + defer packet.Release() + sizes[0] = copy(bufs[0][offset:], packet.Bytes()) + return 1, nil case <-w.done: return 0, os.ErrClosed } @@ -169,6 +194,14 @@ func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) { if len(b) == 0 { continue } + handled, err := w.mapping.WritePacket(b) + if handled { + if err != nil { + return count, err + } + count++ + continue + } var networkProtocol tcpip.NetworkProtocolNumber switch header.IPVersion(b) { case header.IPv4Version: @@ -282,3 +315,157 @@ func (ep *wireEndpoint) Close() { func (ep *wireEndpoint) SetOnCloseAction(f func()) { } + +func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { + /* var wq waiter.Queue + ep, err := raw.NewEndpoint(w.stack, ipv4.ProtocolNumber, icmp.ProtocolNumber4, &wq) + if err != nil { + return nil, E.Cause(gonet.TranslateNetstackError(err), "create endpoint") + } + err = ep.Connect(tcpip.FullAddress{ + NIC: tun.DefaultNIC, + Port: metadata.Destination.Port, + Addr: tun.AddressFromAddr(metadata.Destination.Addr), + }) + if err != nil { + ep.Close() + return nil, E.Cause(gonet.TranslateNetstackError(err), "ICMP connect ", metadata.Destination) + } + fmt.Println("linked ", metadata.Network, " connection to ", metadata.Destination.AddrString()) + destination := &endpointNatDestination{ + ep: ep, + wq: &wq, + context: routeContext, + } + go destination.loopRead() + return destination, nil*/ + session := tun.DirectRouteSession{ + Source: metadata.Source.Addr, + Destination: metadata.Destination.Addr, + } + w.mapping.CreateSession(session, routeContext) + return &stackNatDestination{ + device: w, + session: session, + }, nil +} + +type stackNatDestination struct { + device *stackDevice + session tun.DirectRouteSession +} + +func (d *stackNatDestination) WritePacket(buffer *buf.Buffer) error { + if d.device.writer != nil { + d.device.writer.RewritePacket(buffer.Bytes()) + } + d.device.packetOutbound <- buffer + return nil +} + +func (d *stackNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error { + if d.device.writer != nil { + d.device.writer.RewritePacketBuffer(buffer) + } + d.device.outbound <- buffer + return nil +} + +func (d *stackNatDestination) Close() error { + d.device.mapping.DeleteSession(d.session) + return nil +} + +func (d *stackNatDestination) Timeout() bool { + return false +} + +/*type endpointNatDestination struct { + ep tcpip.Endpoint + wq *waiter.Queue + networkProto tcpip.NetworkProtocolNumber + context tun.DirectRouteContext + done chan struct{} +} + +func (d *endpointNatDestination) loopRead() { + for { + println("start read") + buffer, err := commonRead(d.ep, d.wq, d.done) + if err != nil { + log.Error(err) + return + } + println("done read") + ipHdr := header.IPv4(buffer.Bytes()) + if ipHdr.TransportProtocol() != header.ICMPv4ProtocolNumber { + buffer.Release() + continue + } + icmpHdr := header.ICMPv4(ipHdr.Payload()) + if icmpHdr.Type() != header.ICMPv4EchoReply { + buffer.Release() + continue + } + fmt.Println("read echo reply") + _ = d.context.WritePacket(ipHdr) + buffer.Release() + } +} + +func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, done chan struct{}) (*buf.Buffer, error) { + buffer := buf.NewPacket() + result, err := ep.Read(buffer, tcpip.ReadOptions{}) + if err != nil { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&waitEntry) + defer wq.EventUnregister(&waitEntry) + for { + result, err = ep.Read(buffer, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + break + } + select { + case <-notifyCh: + case <-done: + buffer.Release() + return nil, context.DeadlineExceeded + } + } + } + return nil, gonet.TranslateNetstackError(err) + } + buffer.Truncate(result.Count) + return buffer, nil +} + +func (d *endpointNatDestination) WritePacket(buffer *buf.Buffer) error { + _, err := d.ep.Write(buffer, tcpip.WriteOptions{}) + if err != nil { + return gonet.TranslateNetstackError(err) + } + return nil +} + +func (d *endpointNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error { + data := buffer.ToView().AsSlice() + println("write echo request buffer :" + fmt.Sprint(data)) + _, err := d.ep.Write(bytes.NewReader(data), tcpip.WriteOptions{}) + if err != nil { + log.Error(err) + return gonet.TranslateNetstackError(err) + } + return nil +} + +func (d *endpointNatDestination) Close() error { + d.ep.Abort() + close(d.done) + return nil +} + +func (d *endpointNatDestination) Timeout() bool { + return false +} +*/ diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index fa54f332..90abee4b 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -28,16 +28,36 @@ type systemDevice struct { batchDevice tun.LinuxTUN events chan wgTun.Event closeOnce sync.Once + addr4 netip.Addr + addr6 netip.Addr } func newSystemDevice(options DeviceOptions) (*systemDevice, error) { if options.Name == "" { options.Name = tun.CalculateInterfaceName("wg") } + var inet4Address netip.Addr + var inet6Address netip.Addr + if len(options.Address) > 0 { + if prefix := common.Find(options.Address, func(it netip.Prefix) bool { + return it.Addr().Is4() + }); prefix.IsValid() { + inet4Address = prefix.Addr() + } + } + if len(options.Address) > 0 { + if prefix := common.Find(options.Address, func(it netip.Prefix) bool { + return it.Addr().Is6() + }); prefix.IsValid() { + inet6Address = prefix.Addr() + } + } return &systemDevice{ options: options, dialer: options.CreateDialer(options.Name), events: make(chan wgTun.Event, 1), + addr4: inet4Address, + addr6: inet6Address, }, nil } @@ -49,6 +69,14 @@ func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr return w.dialer.ListenPacket(ctx, destination) } +func (w *systemDevice) Inet4Address() netip.Addr { + return w.addr4 +} + +func (w *systemDevice) Inet6Address() netip.Addr { + return w.addr6 +} + func (w *systemDevice) SetDevice(device *device.Device) { } diff --git a/transport/wireguard/endpoint.go b/transport/wireguard/endpoint.go index 69ce9170..af982948 100644 --- a/transport/wireguard/endpoint.go +++ b/transport/wireguard/endpoint.go @@ -10,6 +10,8 @@ import ( "os" "strings" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" @@ -29,6 +31,7 @@ type Endpoint struct { ipcConf string allowedAddress []netip.Prefix tunDevice Device + natDevice NatDevice device *device.Device pauseManager pause.Manager pauseCallback *list.Element[pause.Callback] @@ -111,12 +114,17 @@ func NewEndpoint(options EndpointOptions) (*Endpoint, error) { if err != nil { return nil, E.Cause(err, "create WireGuard device") } + natDevice, isNatDevice := tunDevice.(NatDevice) + if !isNatDevice { + natDevice = NewNATDevice(tunDevice, true) + } return &Endpoint{ options: options, peers: peers, ipcConf: ipcConf, allowedAddress: allowedAddresses, tunDevice: tunDevice, + natDevice: natDevice, }, nil } @@ -176,7 +184,13 @@ func (e *Endpoint) Start(resolve bool) error { e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...)) }, } - wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers) + var deviceInput Device + if e.natDevice != nil { + deviceInput = e.natDevice + } else { + deviceInput = e.tunDevice + } + wgDevice := device.NewDevice(e.options.Context, deviceInput, bind, logger, e.options.Workers) e.tunDevice.SetDevice(wgDevice) ipcConf := e.ipcConf for _, peer := range e.peers { @@ -194,6 +208,20 @@ func (e *Endpoint) Start(resolve bool) error { return nil } +func (e *Endpoint) Close() error { + if e.device != nil { + e.device.Close() + } + if e.pauseCallback != nil { + e.pauseManager.UnregisterCallback(e.pauseCallback) + } + return nil +} + +func (e *Endpoint) BindUpdate() error { + return e.device.BindUpdate() +} + func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { if !destination.Addr.IsValid() { return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination") @@ -208,18 +236,11 @@ func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n return e.tunDevice.ListenPacket(ctx, destination) } -func (e *Endpoint) BindUpdate() error { - return e.device.BindUpdate() -} - -func (e *Endpoint) Close() error { - if e.device != nil { - e.device.Close() +func (e *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext) (tun.DirectRouteDestination, error) { + if e.natDevice == nil { + return nil, os.ErrInvalid } - if e.pauseCallback != nil { - e.pauseManager.UnregisterCallback(e.pauseCallback) - } - return nil + return e.natDevice.CreateDestination(metadata, routeContext) } func (e *Endpoint) onPauseUpdated(event int) {