From 209ec123ca7bafb1aa41e3263d1e9b86a91a795b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 22 Apr 2023 20:14:22 +0800 Subject: [PATCH] Update gVisor to 20230417.0 --- Makefile | 9 + go.mod | 4 +- go.sum | 8 +- gvisor.go | 4 +- gvisor_udp.go | 7 +- internal/fdbased/README.md | 3 + internal/fdbased/endpoint.go | 815 ++++++++++++++++++ internal/fdbased/endpoint_unsafe.go | 24 + internal/fdbased/mmap.go | 207 +++++ internal/fdbased/mmap_stub.go | 24 + internal/fdbased/mmap_unsafe.go | 91 ++ internal/fdbased/packet_dispatchers.go | 344 ++++++++ internal/fdbased/stopfd/stopfd.go | 52 ++ .../fdbased/stopfd/stopfd_state_autogen.go | 6 + route_gvisor.go | 2 +- route_nat_gvisor.go | 2 +- tun_darwin_gvisor.go | 2 +- tun_linux_gvisor.go | 3 +- tun_windows_gvisor.go | 2 +- 19 files changed, 1593 insertions(+), 16 deletions(-) create mode 100644 internal/fdbased/README.md create mode 100644 internal/fdbased/endpoint.go create mode 100644 internal/fdbased/endpoint_unsafe.go create mode 100644 internal/fdbased/mmap.go create mode 100644 internal/fdbased/mmap_stub.go create mode 100644 internal/fdbased/mmap_unsafe.go create mode 100644 internal/fdbased/packet_dispatchers.go create mode 100644 internal/fdbased/stopfd/stopfd.go create mode 100644 internal/fdbased/stopfd/stopfd_state_autogen.go diff --git a/Makefile b/Makefile index 01cc64e..f3b0209 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,12 @@ +build: + GOOS=darwin GOARCH=arm64 go build -v -tags with_gvisor . + GOOS=ios GOARCH=arm64 go build -v -tags with_gvisor . + GOOS=linux GOARCH=amd64 go build -v -tags with_gvisor . + GOOS=linux GOARCH=arm64 go build -v -tags with_gvisor . + GOOS=linux GOARCH=386 go build -v -tags with_gvisor . + GOOS=linux GOARCH=arm go build -v -tags with_gvisor . + GOOS=windows GOARCH=amd64 go build -v -tags with_gvisor . + fmt: @gofumpt -l -w . @gofmt -s -w . 
diff --git a/go.mod b/go.mod index c60f705..1952530 100644 --- a/go.mod +++ b/go.mod @@ -9,11 +9,11 @@ require ( github.com/sagernet/sing v0.2.4 golang.org/x/net v0.9.0 golang.org/x/sys v0.7.0 - gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c + gvisor.dev/gvisor v0.0.0-20230415003630-3981d5d5e523 ) require ( github.com/google/btree v1.0.1 // indirect github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) diff --git a/go.sum b/go.sum index 7be41fd..c8e4e87 100644 --- a/go.sum +++ b/go.sum @@ -18,7 +18,7 @@ golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c h1:m5lcgWnL3OElQNVyp3qcncItJ2c0sQlSGjYK2+nJTA4= -gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c/go.mod h1:TIvkJD0sxe8pIob3p6T8IzxXunlp6yfgktvTNp+DGNM= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +gvisor.dev/gvisor v0.0.0-20230415003630-3981d5d5e523 h1:zUQYeyyPLnSR6yMvLSOmLH37xDWCZ7BqlpE69fE5K3Q= +gvisor.dev/gvisor v0.0.0-20230415003630-3981d5d5e523/go.mod h1:pzr6sy8gDLfVmDAg8OYrlKvGEHw5C3PGTiBXBTCx76Q= diff --git a/gvisor.go b/gvisor.go index b525509..4a75a02 100644 --- a/gvisor.go +++ b/gvisor.go @@ -155,7 +155,7 @@ func (t *GVisor) Start() 
error { } }() }) - ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool { + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(id stack.TransportEndpointID, buffer stack.PacketBufferPtr) bool { if t.router != nil { var routeSession RouteSession routeSession.Network = syscall.IPPROTO_TCP @@ -218,7 +218,7 @@ func (t *GVisor) Start() error { } }() }) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool { + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, buffer stack.PacketBufferPtr) bool { if t.router != nil { var routeSession RouteSession routeSession.Network = syscall.IPPROTO_UDP diff --git a/gvisor_udp.go b/gvisor_udp.go index e62157d..f23bd3a 100644 --- a/gvisor_udp.go +++ b/gvisor_udp.go @@ -15,6 +15,7 @@ import ( "gvisor.dev/gvisor/pkg/bufferv2" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -33,7 +34,7 @@ func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, u } } -func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { +func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { var upstreamMetadata M.Metadata upstreamMetadata.Source = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.RemoteAddress)), id.RemotePort) upstreamMetadata.Destination = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.LocalAddress)), id.LocalPort) @@ -93,9 +94,9 @@ func (w *UDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) }) if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber { - xsum := udpHdr.CalculateChecksum(header.ChecksumCombine( + xsum := udpHdr.CalculateChecksum(checksum.Combine( route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen), - 
packet.Data().AsRange().Checksum(), + packet.Data().Checksum(), )) if xsum != math.MaxUint16 { xsum = ^xsum diff --git a/internal/fdbased/README.md b/internal/fdbased/README.md new file mode 100644 index 0000000..6bb26b5 --- /dev/null +++ b/internal/fdbased/README.md @@ -0,0 +1,3 @@ +# fdbased + +Version: release-20230417.0 \ No newline at end of file diff --git a/internal/fdbased/endpoint.go b/internal/fdbased/endpoint.go new file mode 100644 index 0000000..d2a5bf2 --- /dev/null +++ b/internal/fdbased/endpoint.go @@ -0,0 +1,815 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +// Package fdbased provides the implementation of data-link layer endpoints +// backed by boundary-preserving file descriptors (e.g., TUN devices, +// seqpacket/datagram sockets). +// +// FD based endpoints can be used in the networking stack by calling New() to +// create a new endpoint, and then passing it as an argument to +// Stack.CreateNIC(). +// +// FD based endpoints can use more than one file descriptor to read incoming +// packets. If more than one FD is specified and the underlying FD is an +// AF_PACKET then the endpoint will enable FANOUT mode on the socket so that the +// host kernel will consistently hash the packets to the sockets. This ensures +// that packets for the same TCP streams are not reordered. 
+// +// Similarly if more than one FD's are specified where the underlying FD is not +// AF_PACKET then it's the caller's responsibility to ensure that all inbound +// packets on the descriptors are consistently 5 tuple hashed to one of the +// descriptors to prevent TCP reordering. +// +// Since netstack today does not compute 5 tuple hashes for outgoing packets we +// only use the first FD to write outbound packets. Once 5 tuple hashes for +// all outbound packets are available we will make use of all underlying FD's to +// write outbound packets. +package fdbased + +import ( + "fmt" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// linkDispatcher reads packets from the link FD and dispatches them to the +// NetworkDispatcher. +type linkDispatcher interface { + Stop() + dispatch() (bool, tcpip.Error) + release() +} + +// PacketDispatchMode are the various supported methods of receiving and +// dispatching packets from the underlying FD. +type PacketDispatchMode int + +// BatchSize is the number of packets to write in each syscall. It is 47 +// because when GvisorGSO is in use then a single 65KB TCP segment can get +// split into 46 segments of 1420 bytes and a single 216 byte segment. +const BatchSize = 47 + +const ( + // Readv is the default dispatch mode and is the least performant of the + // dispatch options but the one that is supported by all underlying FD + // types. + Readv PacketDispatchMode = iota + // RecvMMsg enables use of recvmmsg() syscall instead of readv() to + // read inbound packets. This reduces # of syscalls needed to process + // packets. + // + // NOTE: recvmmsg() is only supported for sockets, so if the underlying + // FD is not a socket then the code will still fall back to the readv() + // path. 
+ RecvMMsg + // PacketMMap enables use of PACKET_RX_RING to receive packets from the + // NIC. PacketMMap requires that the underlying FD be an AF_PACKET. The + // primary use-case for this is runsc which uses an AF_PACKET FD to + // receive packets from the veth device. + PacketMMap +) + +func (p PacketDispatchMode) String() string { + switch p { + case Readv: + return "Readv" + case RecvMMsg: + return "RecvMMsg" + case PacketMMap: + return "PacketMMap" + default: + return fmt.Sprintf("unknown packet dispatch mode '%d'", p) + } +} + +var ( + _ stack.LinkEndpoint = (*endpoint)(nil) + _ stack.GSOEndpoint = (*endpoint)(nil) +) + +type fdInfo struct { + fd int + isSocket bool +} + +type endpoint struct { + // fds is the set of file descriptors each identifying one inbound/outbound + // channel. The endpoint will dispatch from all inbound channels as well as + // hash outbound packets to specific channels based on the packet hash. + fds []fdInfo + + // mtu (maximum transmission unit) is the maximum size of a packet. + mtu uint32 + + // hdrSize specifies the link-layer header size. If set to 0, no header + // is added/removed; otherwise an ethernet header is used. + hdrSize int + + // addr is the address of the endpoint. + addr tcpip.LinkAddress + + // caps holds the endpoint capabilities. + caps stack.LinkEndpointCapabilities + + // closed is a function to be called when the FD's peer (if any) closes + // its end of the communication pipe. + closed func(tcpip.Error) + + inboundDispatchers []linkDispatcher + + mu sync.RWMutex + // +checklocks:mu + dispatcher stack.NetworkDispatcher + + // packetDispatchMode controls the packet dispatcher used by this + // endpoint. + packetDispatchMode PacketDispatchMode + + // gsoMaxSize is the maximum GSO packet size. It is zero if GSO is + // disabled. + gsoMaxSize uint32 + + // wg keeps track of running goroutines. + wg sync.WaitGroup + + // gsoKind is the supported kind of GSO. 
+ gsoKind stack.SupportedGSO + + // maxSyscallHeaderBytes has the same meaning as + // Options.MaxSyscallHeaderBytes. + maxSyscallHeaderBytes uintptr + + // writevMaxIovs is the maximum number of iovecs that may be passed to + // rawfile.NonBlockingWriteIovec, as possibly limited by + // maxSyscallHeaderBytes. (No analogous limit is defined for + // rawfile.NonBlockingSendMMsg, since in that case the maximum number of + // iovecs also depends on the number of mmsghdrs. Instead, if sendBatch + // encounters a packet whose iovec count is limited by + // maxSyscallHeaderBytes, it falls back to writing the packet using writev + // via WritePacket.) + writevMaxIovs int +} + +// Options specify the details about the fd-based endpoint to be created. +type Options struct { + // FDs is a set of FDs used to read/write packets. + FDs []int + + // MTU is the mtu to use for this endpoint. + MTU uint32 + + // EthernetHeader if true, indicates that the endpoint should read/write + // ethernet frames instead of IP packets. + EthernetHeader bool + + // ClosedFunc is a function to be called when an endpoint's peer (if + // any) closes its end of the communication pipe. + ClosedFunc func(tcpip.Error) + + // Address is the link address for this endpoint. Only used if + // EthernetHeader is true. + Address tcpip.LinkAddress + + // SaveRestore if true, indicates that this NIC capability set should + // include CapabilitySaveRestore + SaveRestore bool + + // DisconnectOk if true, indicates that this NIC capability set should + // include CapabilityDisconnectOk. + DisconnectOk bool + + // GSOMaxSize is the maximum GSO packet size. It is zero if GSO is + // disabled. + GSOMaxSize uint32 + + // GvisorGSOEnabled indicates whether Gvisor GSO is enabled or not. + GvisorGSOEnabled bool + + // PacketDispatchMode specifies the type of inbound dispatcher to be + // used for this endpoint. 
+ PacketDispatchMode PacketDispatchMode + + // TXChecksumOffload if true, indicates that this endpoint's capability + // set should include CapabilityTXChecksumOffload. + TXChecksumOffload bool + + // RXChecksumOffload if true, indicates that this endpoint's capability + // set should include CapabilityRXChecksumOffload. + RXChecksumOffload bool + + // If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes + // of struct iovec, msghdr, and mmsghdr that may be passed by each host + // system call. + MaxSyscallHeaderBytes int + + // AFXDPFD is used with the experimental AF_XDP mode. + // TODO(b/240191988): Use multiple sockets. + // TODO(b/240191988): How do we handle the MTU issue? + AFXDPFD *int + + // InterfaceIndex is the interface index of the underlying device. + InterfaceIndex int +} + +// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT +// support in the host kernel. This allows us to use multiple FD's to receive +// from the same underlying NIC. The fanoutID needs to be the same for a given +// set of FD's that point to the same NIC. Trying to set the PACKET_FANOUT +// option for an FD with a fanoutID already in use by another FD for a different +// NIC will return an EINVAL. +// +// Since fanoutID must be unique within the network namespace, we start with +// the PID to avoid collisions. The only way to be sure of avoiding collisions +// is to run in a new network namespace. +var fanoutID atomicbitops.Int32 = atomicbitops.FromInt32(int32(unix.Getpid())) + +// New creates a new fd-based endpoint. +// +// Makes fd non-blocking, but does not take ownership of fd, which must remain +// open for the lifetime of the returned endpoint (until after the endpoint has +// stopped being used and Wait returns). 
+func New(opts *Options) (stack.LinkEndpoint, error) { + caps := stack.LinkEndpointCapabilities(0) + if opts.RXChecksumOffload { + caps |= stack.CapabilityRXChecksumOffload + } + + if opts.TXChecksumOffload { + caps |= stack.CapabilityTXChecksumOffload + } + + hdrSize := 0 + if opts.EthernetHeader { + hdrSize = header.EthernetMinimumSize + caps |= stack.CapabilityResolutionRequired + } + + if opts.SaveRestore { + caps |= stack.CapabilitySaveRestore + } + + if opts.DisconnectOk { + caps |= stack.CapabilityDisconnectOk + } + + if len(opts.FDs) == 0 { + return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") + } + + if opts.MaxSyscallHeaderBytes < 0 { + return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative") + } + + e := &endpoint{ + mtu: opts.MTU, + caps: caps, + closed: opts.ClosedFunc, + addr: opts.Address, + hdrSize: hdrSize, + packetDispatchMode: opts.PacketDispatchMode, + maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes), + writevMaxIovs: rawfile.MaxIovs, + } + if e.maxSyscallHeaderBytes != 0 { + if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs { + e.writevMaxIovs = max + } + } + + // Increment fanoutID to ensure that we don't re-use the same fanoutID + // for the next endpoint. + fid := fanoutID.Add(1) + + // Create per channel dispatchers. + for _, fd := range opts.FDs { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("unix.SetNonblock(%v) failed: %v", fd, err) + } + + isSocket, err := isSocketFD(fd) + if err != nil { + return nil, err + } + e.fds = append(e.fds, fdInfo{fd: fd, isSocket: isSocket}) + if isSocket { + if opts.GSOMaxSize != 0 { + if opts.GvisorGSOEnabled { + e.gsoKind = stack.GvisorGSOSupported + } else { + e.gsoKind = stack.HostGSOSupported + } + e.gsoMaxSize = opts.GSOMaxSize + } + } + + inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket, fid) + if err != nil { + return nil, fmt.Errorf("createInboundDispatcher(...) 
= %v", err) + } + e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher) + } + + return e, nil +} + +func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (linkDispatcher, error) { + // By default use the readv() dispatcher as it works with all kinds of + // FDs (tap/tun/unix domain sockets and af_packet). + inboundDispatcher, err := newReadVDispatcher(fd, e) + if err != nil { + return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", fd, e, err) + } + + if isSocket { + sa, err := unix.Getsockname(fd) + if err != nil { + return nil, fmt.Errorf("unix.Getsockname(%d) = %v", fd, err) + } + switch sa.(type) { + case *unix.SockaddrLinklayer: + // Enable PACKET_FANOUT mode if the underlying socket is of type + // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will + // prevent gvisor from receiving fragmented packets and the host does the + // reassembly on our behalf before delivering the fragments. This makes it + // hard to test fragmentation reassembly code in Netstack. + // + // See: include/uapi/linux/if_packet.h (struct fanout_args). + // + // NOTE: We are using SetSockOptInt here even though the underlying + // option is actually a struct. The code follows the example in the + // kernel documentation as described at the link below: + // + // See: https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt + // + // This works out because the actual implementation for the option zero + // initializes the structure and will initialize the max_members field + // to a proper value if zero. 
+ // + // See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881 + const fanoutType = unix.PACKET_FANOUT_HASH + fanoutArg := (int(fID) & 0xffff) | fanoutType<<16 + if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { + return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) + } + } + + switch e.packetDispatchMode { + case PacketMMap: + inboundDispatcher, err = newPacketMMapDispatcher(fd, e) + if err != nil { + return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", fd, e, err) + } + case RecvMMsg: + // If the provided FD is a socket then we optimize + // packet reads by using recvmmsg() instead of read() to + // read packets in a batch. + inboundDispatcher, err = newRecvMMsgDispatcher(fd, e) + if err != nil { + return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", fd, e, err) + } + case Readv: + default: + return nil, fmt.Errorf("unknown dispatch mode %d", e.packetDispatchMode) + } + } + return inboundDispatcher, nil +} + +func isSocketFD(fd int) (bool, error) { + var stat unix.Stat_t + if err := unix.Fstat(fd, &stat); err != nil { + return false, fmt.Errorf("unix.Fstat(%v,...) failed: %v", fd, err) + } + return (stat.Mode & unix.S_IFSOCK) == unix.S_IFSOCK, nil +} + +// Attach launches the goroutine that reads packets from the file descriptor and +// dispatches them via the provided dispatcher. If one is already attached, +// then nothing happens. +// +// Attach implements stack.LinkEndpoint.Attach. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + defer e.mu.Unlock() + // nil means the NIC is being removed. + if dispatcher == nil && e.dispatcher != nil { + for _, dispatcher := range e.inboundDispatchers { + dispatcher.Stop() + } + e.Wait() + e.dispatcher = nil + return + } + if dispatcher != nil && e.dispatcher == nil { + e.dispatcher = dispatcher + // Link endpoints are not savable. 
When transportation endpoints are + // saved, they stop sending outgoing packets and all incoming packets + // are rejected. + for i := range e.inboundDispatchers { + e.wg.Add(1) + go func(i int) { // S/R-SAFE: See above. + e.dispatchLoop(e.inboundDispatchers[i]) + e.wg.Done() + }(i) + } + } +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *endpoint) IsAttached() bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *endpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps +} + +// MaxHeaderLength returns the maximum size of the link-layer header. +func (e *endpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) +} + +// LinkAddress returns the link address of this endpoint. +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.addr +} + +// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop +// reading from its FD. +func (e *endpoint) Wait() { + e.wg.Wait() +} + +// virtioNetHdr is declared in linux/virtio_net.h. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +// marshal serializes h to a newly-allocated byte slice, in little-endian byte +// order. +// +// Note: Virtio v1.0 onwards specifies little-endian as the byte ordering used +// for general serialization. This makes it difficult to use go-marshal for +// virtio types, as go-marshal implicitly uses the native byte ordering. +func (h *virtioNetHdr) marshal() []byte { + buf := [virtioNetHdrSize]byte{ + 0: byte(h.flags), + 1: byte(h.gsoType), + + // Manually lay out the fields in little-endian byte order. Little endian => + // least significant bit goes to the lower address. 
+ + 2: byte(h.hdrLen), + 3: byte(h.hdrLen >> 8), + + 4: byte(h.gsoSize), + 5: byte(h.gsoSize >> 8), + + 6: byte(h.csumStart), + 7: byte(h.csumStart >> 8), + + 8: byte(h.csumOffset), + 9: byte(h.csumOffset >> 8), + } + return buf[:] +} + +// These constants are declared in linux/virtio_net.h. +const ( + _VIRTIO_NET_HDR_F_NEEDS_CSUM = 1 + + _VIRTIO_NET_HDR_GSO_TCPV4 = 1 + _VIRTIO_NET_HDR_GSO_TCPV6 = 4 +) + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(pkt stack.PacketBufferPtr) { + if e.hdrSize > 0 { + // Add ethernet header if needed. + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + eth.Encode(&header.EthernetFields{ + SrcAddr: pkt.EgressRoute.LocalLinkAddress, + DstAddr: pkt.EgressRoute.RemoteLinkAddress, + Type: pkt.NetworkProtocolNumber, + }) + } +} + +// writePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) writePacket(pkt stack.PacketBufferPtr) tcpip.Error { + fdInfo := e.fds[pkt.Hash%uint32(len(e.fds))] + fd := fdInfo.fd + var vnetHdrBuf []byte + if e.gsoKind == stack.HostGSOSupported { + vnetHdr := virtioNetHdr{} + if pkt.GSOOptions.Type != stack.GSONone { + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) + if pkt.GSOOptions.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset + } + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + } + vnetHdr.gsoSize = pkt.GSOOptions.MSS + } + } + vnetHdrBuf = vnetHdr.marshal() + } + + views := pkt.AsSlices() + numIovecs := len(views) + if len(vnetHdrBuf) != 0 { + 
numIovecs++ + } + if numIovecs > e.writevMaxIovs { + numIovecs = e.writevMaxIovs + } + + // Allocate small iovec arrays on the stack. + var iovecsArr [8]unix.Iovec + iovecs := iovecsArr[:0] + if numIovecs > len(iovecsArr) { + iovecs = make([]unix.Iovec, 0, numIovecs) + } + iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs) + for _, v := range views { + iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs) + } + return rawfile.NonBlockingWriteIovec(fd, iovecs) +} + +func (e *endpoint) sendBatch(batchFDInfo fdInfo, pkts []stack.PacketBufferPtr) (int, tcpip.Error) { + // Degrade to writePacket if underlying fd is not a socket. + if !batchFDInfo.isSocket { + var written int + var err tcpip.Error + for written < len(pkts) { + if err = e.writePacket(pkts[written]); err != nil { + break + } + written++ + } + return written, err + } + + // Send a batch of packets through batchFD. + batchFD := batchFDInfo.fd + mmsgHdrsStorage := make([]rawfile.MMsgHdr, 0, len(pkts)) + packets := 0 + for packets < len(pkts) { + mmsgHdrs := mmsgHdrsStorage + batch := pkts[packets:] + syscallHeaderBytes := uintptr(0) + for _, pkt := range batch { + var vnetHdrBuf []byte + if e.gsoKind == stack.HostGSOSupported { + vnetHdr := virtioNetHdr{} + if pkt.GSOOptions.Type != stack.GSONone { + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) + if pkt.GSOOptions.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset + } + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + } + vnetHdr.gsoSize = pkt.GSOOptions.MSS + } + } + vnetHdrBuf = vnetHdr.marshal() + } + + views := 
pkt.AsSlices() + numIovecs := len(views) + if len(vnetHdrBuf) != 0 { + numIovecs++ + } + if numIovecs > rawfile.MaxIovs { + numIovecs = rawfile.MaxIovs + } + if e.maxSyscallHeaderBytes != 0 { + syscallHeaderBytes += rawfile.SizeofMMsgHdr + uintptr(numIovecs)*rawfile.SizeofIovec + if syscallHeaderBytes > e.maxSyscallHeaderBytes { + // We can't fit this packet into this call to sendmmsg(). + // We could potentially do so if we reduced numIovecs + // further, but this might incur considerable extra + // copying. Leave it to the next batch instead. + break + } + } + + // We can't easily allocate iovec arrays on the stack here since + // they will escape this loop iteration via mmsgHdrs. + iovecs := make([]unix.Iovec, 0, numIovecs) + iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs) + for _, v := range views { + iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs) + } + + var mmsgHdr rawfile.MMsgHdr + mmsgHdr.Msg.Iov = &iovecs[0] + mmsgHdr.Msg.SetIovlen(len(iovecs)) + mmsgHdrs = append(mmsgHdrs, mmsgHdr) + } + + if len(mmsgHdrs) == 0 { + // We can't fit batch[0] into a mmsghdr while staying under + // e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the + // mmsghdr (by using writev) and re-buffer iovecs more aggressively + // if necessary (by using e.writevMaxIovs instead of + // rawfile.MaxIovs). + pkt := batch[0] + if err := e.writePacket(pkt); err != nil { + return packets, err + } + packets++ + } else { + for len(mmsgHdrs) > 0 { + sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) + if err != nil { + return packets, err + } + packets += sent + mmsgHdrs = mmsgHdrs[sent:] + } + } + } + + return packets, nil +} + +// WritePackets writes outbound packets to the underlying file descriptors. If +// one is not currently writable, the packet is dropped. 
+// +// Being a batch API, each packet in pkts should have the following +// fields populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { + // Preallocate to avoid repeated reallocation as we append to batch. + batch := make([]stack.PacketBufferPtr, 0, BatchSize) + batchFDInfo := fdInfo{fd: -1, isSocket: false} + sentPackets := 0 + for _, pkt := range pkts.AsSlice() { + if len(batch) == 0 { + batchFDInfo = e.fds[pkt.Hash%uint32(len(e.fds))] + } + pktFDInfo := e.fds[pkt.Hash%uint32(len(e.fds))] + if sendNow := pktFDInfo != batchFDInfo; !sendNow { + batch = append(batch, pkt) + continue + } + n, err := e.sendBatch(batchFDInfo, batch) + sentPackets += n + if err != nil { + return sentPackets, err + } + batch = batch[:0] + batch = append(batch, pkt) + batchFDInfo = pktFDInfo + } + + if len(batch) != 0 { + n, err := e.sendBatch(batchFDInfo, batch) + sentPackets += n + if err != nil { + return sentPackets, err + } + } + return sentPackets, nil +} + +// InjectOutbound implements stack.InjectableEndpoint.InjectOutbound. +func (e *endpoint) InjectOutbound(dest tcpip.Address, packet *bufferv2.View) tcpip.Error { + return rawfile.NonBlockingWrite(e.fds[0].fd, packet.AsSlice()) +} + +// dispatchLoop reads packets from the file descriptor in a loop and dispatches +// them to the network stack. +func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error { + for { + cont, err := inboundDispatcher.dispatch() + if err != nil || !cont { + if e.closed != nil { + e.closed(err) + } + inboundDispatcher.release() + return err + } + } +} + +// GSOMaxSize implements stack.GSOEndpoint. +func (e *endpoint) GSOMaxSize() uint32 { + return e.gsoMaxSize +} + +// SupportedGSO implements stack.GSOEndpoint. +func (e *endpoint) SupportedGSO() stack.SupportedGSO { + return e.gsoKind +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. 
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + +// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes +// to the FD, but does not read from it. All reads come from injected packets. +type InjectableEndpoint struct { + endpoint + + mu sync.RWMutex + // +checklocks:mu + dispatcher stack.NetworkDispatcher +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + defer e.mu.Unlock() + e.dispatcher = dispatcher +} + +// InjectInbound injects an inbound packet. If the endpoint is not attached, the +// packet is not delivered. +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverNetworkPacket(protocol, pkt) + } +} + +// NewInjectable creates a new fd-based InjectableEndpoint. +func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (*InjectableEndpoint, error) { + unix.SetNonblock(fd, true) + isSocket, err := isSocketFD(fd) + if err != nil { + return nil, err + } + + return &InjectableEndpoint{endpoint: endpoint{ + fds: []fdInfo{{fd: fd, isSocket: isSocket}}, + mtu: mtu, + caps: capabilities, + writevMaxIovs: rawfile.MaxIovs, + }}, nil +} diff --git a/internal/fdbased/endpoint_unsafe.go b/internal/fdbased/endpoint_unsafe.go new file mode 100644 index 0000000..904393f --- /dev/null +++ b/internal/fdbased/endpoint_unsafe.go @@ -0,0 +1,24 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package fdbased + +import ( + "unsafe" +) + +const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{})) diff --git a/internal/fdbased/mmap.go b/internal/fdbased/mmap.go new file mode 100644 index 0000000..218057f --- /dev/null +++ b/internal/fdbased/mmap.go @@ -0,0 +1,207 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package fdbased + +import ( + "encoding/binary" + "fmt" + + "github.com/sagernet/sing-tun/internal/fdbased/stopfd" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + tPacketAlignment = uintptr(16) + tpStatusKernel = 0 + tpStatusUser = 1 + tpStatusCopy = 2 + tpStatusLosing = 4 +) + +// We overallocate the frame size to accommodate space for the +// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding. 
+// +// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB +// +// NOTE: +// +// Frames need to be aligned at 16 byte boundaries. +// BlockSize needs to be page aligned. +// +// For details see PACKET_MMAP setting constraints in +// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt +const ( + tpFrameSize = 65536 + 128 + tpBlockSize = tpFrameSize * 32 + tpBlockNR = 1 + tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize +) + +// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct +// translation of the TPACKET_ALIGN macro in linux/if_packet.h. +func tPacketAlign(v uintptr) uintptr { + return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1)) +} + +// tPacketReq is the tpacket_req structure as described in +// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt +type tPacketReq struct { + tpBlockSize uint32 + tpBlockNR uint32 + tpFrameSize uint32 + tpFrameNR uint32 +} + +// tPacketHdr is the tpacket_hdr structure as described in linux/if_packet.h. +type tPacketHdr []byte + +const ( + tpStatusOffset = 0 + tpLenOffset = 8 + tpSnapLenOffset = 12 + tpMacOffset = 16 + tpNetOffset = 18 + tpSecOffset = 20 + tpUSecOffset = 24 +) + +func (t tPacketHdr) tpLen() uint32 { + return binary.LittleEndian.Uint32(t[tpLenOffset:]) +} + +func (t tPacketHdr) tpSnapLen() uint32 { + return binary.LittleEndian.Uint32(t[tpSnapLenOffset:]) +} + +func (t tPacketHdr) tpMac() uint16 { + return binary.LittleEndian.Uint16(t[tpMacOffset:]) +} + +func (t tPacketHdr) tpNet() uint16 { + return binary.LittleEndian.Uint16(t[tpNetOffset:]) +} + +func (t tPacketHdr) tpSec() uint32 { + return binary.LittleEndian.Uint32(t[tpSecOffset:]) +} + +func (t tPacketHdr) tpUSec() uint32 { + return binary.LittleEndian.Uint32(t[tpUSecOffset:]) +} + +func (t tPacketHdr) Payload() []byte { + return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()] +} + +// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
+// See: mmap_unsafe.go for implementation details. +type packetMMapDispatcher struct { + stopfd.StopFD + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // ringBuffer is only used when PacketMMap dispatcher is used and points + // to the start of the mmapped PACKET_RX_RING buffer. + ringBuffer []byte + + // ringOffset is the current offset into the ring buffer where the next + // inbound packet will be placed by the kernel. + ringOffset int +} + +func (*packetMMapDispatcher) release() {} + +func (d *packetMMapDispatcher) readMMappedPacket() (*bufferv2.View, bool, tcpip.Error) { + hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) + for hdr.tpStatus()&tpStatusUser == 0 { + stopped, errno := rawfile.BlockingPollUntilStopped(d.EFD, d.fd, unix.POLLIN|unix.POLLERR) + if errno != 0 { + if errno == unix.EINTR { + continue + } + return nil, stopped, rawfile.TranslateErrno(errno) + } + if stopped { + return nil, true, nil + } + if hdr.tpStatus()&tpStatusCopy != 0 { + // This frame is truncated so skip it after flipping the + // buffer to the kernel. + hdr.setTPStatus(tpStatusKernel) + d.ringOffset = (d.ringOffset + 1) % tpFrameNR + hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:]) + continue + } + } + + // Copy out the packet from the mmapped frame to a locally owned buffer. + pkt := bufferv2.NewView(int(hdr.tpSnapLen())) + pkt.Write(hdr.Payload()) + // Release packet to kernel. + hdr.setTPStatus(tpStatusKernel) + d.ringOffset = (d.ringOffset + 1) % tpFrameNR + return pkt, false, nil +} + +// dispatch reads packets from an mmapped ring buffer and dispatches them to the +// network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) { + pkt, stopped, err := d.readMMappedPacket() + if err != nil || stopped { + return false, err + } + var p tcpip.NetworkProtocolNumber + if d.e.hdrSize > 0 { + p = header.Ethernet(pkt.AsSlice()).Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + switch header.IPVersion(pkt.AsSlice()) { + case header.IPv4Version: + p = header.IPv4ProtocolNumber + case header.IPv6Version: + p = header.IPv6ProtocolNumber + default: + return true, nil + } + } + + pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: bufferv2.MakeWithView(pkt), + }) + defer pbuf.DecRef() + if d.e.hdrSize > 0 { + if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok { + panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize)) + } + } + d.e.mu.RLock() + dsp := d.e.dispatcher + d.e.mu.RUnlock() + dsp.DeliverNetworkPacket(p, pbuf) + return true, nil +} diff --git a/internal/fdbased/mmap_stub.go b/internal/fdbased/mmap_stub.go new file mode 100644 index 0000000..7c7efad --- /dev/null +++ b/internal/fdbased/mmap_stub.go @@ -0,0 +1,24 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !linux +// +build !linux + +package fdbased + +// Stubbed out version for non-linux platforms.
+ +func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + return nil, nil +} diff --git a/internal/fdbased/mmap_unsafe.go b/internal/fdbased/mmap_unsafe.go new file mode 100644 index 0000000..fa2c9d6 --- /dev/null +++ b/internal/fdbased/mmap_unsafe.go @@ -0,0 +1,91 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package fdbased + +import ( + "fmt" + "unsafe" + + "github.com/sagernet/sing-tun/internal/fdbased/stopfd" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/atomicbitops" +) + +// tPacketHdrlen is the TPACKET_HDRLEN variable defined in linux/if_packet.h. +var tPacketHdrlen = tPacketAlign(unsafe.Sizeof(tPacketHdr{}) + unsafe.Sizeof(unix.RawSockaddrLinklayer{})) + +// tpStatus returns the frame status field. +// The status is concurrently updated by the kernel; as a result we must +// use atomic operations to prevent races. +func (t tPacketHdr) tpStatus() uint32 { + hdr := unsafe.Pointer(&t[0]) + statusPtr := unsafe.Pointer(uintptr(hdr) + uintptr(tpStatusOffset)) + return (*atomicbitops.Uint32)(statusPtr).Load() +} + +// setTPStatus sets the frame status to the provided status. +// The status is concurrently updated by the kernel; as a result we must +// use atomic operations to prevent races.
+func (t tPacketHdr) setTPStatus(status uint32) { + hdr := unsafe.Pointer(&t[0]) + statusPtr := unsafe.Pointer(uintptr(hdr) + uintptr(tpStatusOffset)) + (*atomicbitops.Uint32)(statusPtr).Store(status) +} + +func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + stopFD, err := stopfd.New() + if err != nil { + return nil, err + } + d := &packetMMapDispatcher{ + StopFD: stopFD, + fd: fd, + e: e, + } + pageSize := unix.Getpagesize() + if tpBlockSize%pageSize != 0 { + return nil, fmt.Errorf("tpBlockSize: %d is not page aligned, pagesize: %d", tpBlockSize, pageSize) + } + tReq := tPacketReq{ + tpBlockSize: uint32(tpBlockSize), + tpBlockNR: uint32(tpBlockNR), + tpFrameSize: uint32(tpFrameSize), + tpFrameNR: uint32(tpFrameNR), + } + // Setup PACKET_RX_RING. + if err := setsockopt(d.fd, unix.SOL_PACKET, unix.PACKET_RX_RING, unsafe.Pointer(&tReq), unsafe.Sizeof(tReq)); err != nil { + return nil, fmt.Errorf("failed to enable PACKET_RX_RING: %v", err) + } + // Let's mmap the blocks. + sz := tpBlockSize * tpBlockNR + buf, err := unix.Mmap(d.fd, 0, sz, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) + if err != nil { + return nil, fmt.Errorf("unix.Mmap(...,0, %v, ...) failed = %v", sz, err) + } + d.ringBuffer = buf + return d, nil +} + +func setsockopt(fd, level, name int, val unsafe.Pointer, vallen uintptr) error { + if _, _, errno := unix.Syscall6(unix.SYS_SETSOCKOPT, uintptr(fd), uintptr(level), uintptr(name), uintptr(val), vallen, 0); errno != 0 { + return error(errno) + } + + return nil +} diff --git a/internal/fdbased/packet_dispatchers.go b/internal/fdbased/packet_dispatchers.go new file mode 100644 index 0000000..51d5c42 --- /dev/null +++ b/internal/fdbased/packet_dispatchers.go @@ -0,0 +1,344 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package fdbased + +import ( + "github.com/sagernet/sing-tun/internal/fdbased/stopfd" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// BufConfig defines the shape of the buffer used to read packets from the NIC. +var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} + +type iovecBuffer struct { + // buffer is the actual buffer that holds the packet contents. Some contents + // are reused across calls to pullBuffer if number of requested bytes is + // smaller than the number of bytes allocated in the buffer. + views []*bufferv2.View + + // iovecs are initialized with base pointers/len of the corresponding + // entries in the views defined above, except when GSO is enabled + // (skipsVnetHdr) then the first iovec points to a buffer for the vnet header + // which is stripped before the views are passed up the stack for further + // processing. + iovecs []unix.Iovec + + // sizes is an array of buffer sizes for the underlying views. sizes is + // immutable. + sizes []int + + // skipsVnetHdr is true if virtioNetHdr is to be skipped. + skipsVnetHdr bool + + // pulledIndex is the index of the last []byte buffer pulled from the + // underlying buffer storage during a call to pullBuffers. It is -1 + // if no buffer is pulled.
+ pulledIndex int +} + +func newIovecBuffer(sizes []int, skipsVnetHdr bool) *iovecBuffer { + b := &iovecBuffer{ + views: make([]*bufferv2.View, len(sizes)), + sizes: sizes, + skipsVnetHdr: skipsVnetHdr, + } + niov := len(b.views) + if b.skipsVnetHdr { + niov++ + } + b.iovecs = make([]unix.Iovec, niov) + return b +} + +func (b *iovecBuffer) nextIovecs() []unix.Iovec { + vnetHdrOff := 0 + if b.skipsVnetHdr { + var vnetHdr [virtioNetHdrSize]byte + // The kernel adds virtioNetHdr before each packet, but + // we don't use it, so we allocate a buffer for it, + // add it in iovecs but don't add it in a view. + b.iovecs[0] = unix.Iovec{Base: &vnetHdr[0]} + b.iovecs[0].SetLen(virtioNetHdrSize) + vnetHdrOff++ + } + + for i := range b.views { + if b.views[i] != nil { + break + } + v := bufferv2.NewViewSize(b.sizes[i]) + b.views[i] = v + b.iovecs[i+vnetHdrOff] = unix.Iovec{Base: v.BasePtr()} + b.iovecs[i+vnetHdrOff].SetLen(v.Size()) + } + return b.iovecs +} + +// pullBuffer extracts the enough underlying storage from b.buffer to hold n +// bytes. It removes this storage from b.buffer, returns a new buffer +// that holds the storage, and updates pulledIndex to indicate which part +// of b.buffer's storage must be reallocated during the next call to +// nextIovecs. +func (b *iovecBuffer) pullBuffer(n int) bufferv2.Buffer { + var views []*bufferv2.View + c := 0 + if b.skipsVnetHdr { + c += virtioNetHdrSize + if c >= n { + // Nothing in the packet. + return bufferv2.Buffer{} + } + } + // Remove the used views from the buffer. + for i, v := range b.views { + c += v.Size() + if c >= n { + b.views[i].CapLength(v.Size() - (c - n)) + views = append(views, b.views[:i+1]...) + break + } + } + for i := range views { + b.views[i] = nil + } + if b.skipsVnetHdr { + // Exclude the size of the vnet header. 
+ n -= virtioNetHdrSize + } + pulled := bufferv2.Buffer{} + for _, v := range views { + pulled.Append(v) + } + pulled.Truncate(int64(n)) + return pulled +} + +func (b *iovecBuffer) release() { + for _, v := range b.views { + if v != nil { + v.Release() + v = nil + } + } +} + +// readVDispatcher uses readv() system call to read inbound packets and +// dispatches them. +type readVDispatcher struct { + stopfd.StopFD + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // buf is the iovec buffer that contains the packet contents. + buf *iovecBuffer +} + +func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + stopFD, err := stopfd.New() + if err != nil { + return nil, err + } + d := &readVDispatcher{ + StopFD: stopFD, + fd: fd, + e: e, + } + skipsVnetHdr := d.e.gsoKind == stack.HostGSOSupported + d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) + return d, nil +} + +func (d *readVDispatcher) release() { + d.buf.release() +} + +// dispatch reads one packet from the file descriptor and dispatches it. +func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { + n, err := rawfile.BlockingReadvUntilStopped(d.EFD, d.fd, d.buf.nextIovecs()) + if n <= 0 || err != nil { + return false, err + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: d.buf.pullBuffer(n), + }) + defer pkt.DecRef() + + var p tcpip.NetworkProtocolNumber + if d.e.hdrSize > 0 { + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + p = header.Ethernet(hdr).Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + // IP version information is at the first octet, so pulling up 1 byte. 
+ h, ok := pkt.Data().PullUp(1) + if !ok { + return true, nil + } + switch header.IPVersion(h) { + case header.IPv4Version: + p = header.IPv4ProtocolNumber + case header.IPv6Version: + p = header.IPv6ProtocolNumber + default: + return true, nil + } + } + + d.e.mu.RLock() + dsp := d.e.dispatcher + d.e.mu.RUnlock() + dsp.DeliverNetworkPacket(p, pkt) + + return true, nil +} + +// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and +// dispatches them. +type recvMMsgDispatcher struct { + stopfd.StopFD + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // bufs is an array of iovec buffers that contain packet contents. + bufs []*iovecBuffer + + // msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to + // reference an array of iovecs in the iovecs field defined above. This + // array is passed as the parameter to recvmmsg call to retrieve + // potentially more than 1 packet per syscall. + msgHdrs []rawfile.MMsgHdr +} + +const ( + // MaxMsgsPerRecv is the maximum number of packets we want to retrieve + // in a single RecvMMsg call. + MaxMsgsPerRecv = 8 +) + +func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + stopFD, err := stopfd.New() + if err != nil { + return nil, err + } + d := &recvMMsgDispatcher{ + StopFD: stopFD, + fd: fd, + e: e, + bufs: make([]*iovecBuffer, MaxMsgsPerRecv), + msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), + } + skipsVnetHdr := d.e.gsoKind == stack.HostGSOSupported + for i := range d.bufs { + d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr) + } + return d, nil +} + +func (d *recvMMsgDispatcher) release() { + for _, iov := range d.bufs { + iov.release() + } +} + +// dispatch reads more than one packet at a time from the file +// descriptor and dispatches them. +func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { + // Fill message headers.
+ for k := range d.msgHdrs { + if d.msgHdrs[k].Msg.Iovlen > 0 { + break + } + iovecs := d.bufs[k].nextIovecs() + iovLen := len(iovecs) + d.msgHdrs[k].Len = 0 + d.msgHdrs[k].Msg.Iov = &iovecs[0] + d.msgHdrs[k].Msg.SetIovlen(iovLen) + } + + nMsgs, err := rawfile.BlockingRecvMMsgUntilStopped(d.EFD, d.fd, d.msgHdrs) + if nMsgs == -1 || err != nil { + return false, err + } + // Process each of received packets. + // Keep a list of packets so we can DecRef outside of the loop. + var pkts stack.PacketBufferList + + d.e.mu.RLock() + dsp := d.e.dispatcher + d.e.mu.RUnlock() + + defer func() { pkts.DecRef() }() + for k := 0; k < nMsgs; k++ { + n := int(d.msgHdrs[k].Len) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: d.bufs[k].pullBuffer(n), + }) + pkts.PushBack(pkt) + + // Mark that this iovec has been processed. + d.msgHdrs[k].Msg.Iovlen = 0 + + var p tcpip.NetworkProtocolNumber + if d.e.hdrSize > 0 { + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + p = header.Ethernet(hdr).Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data().PullUp(1) + if !ok { + // Skip this packet. + continue + } + switch header.IPVersion(h) { + case header.IPv4Version: + p = header.IPv4ProtocolNumber + case header.IPv6Version: + p = header.IPv6ProtocolNumber + default: + // Skip this packet. + continue + } + } + + dsp.DeliverNetworkPacket(p, pkt) + } + + return true, nil +} diff --git a/internal/fdbased/stopfd/stopfd.go b/internal/fdbased/stopfd/stopfd.go new file mode 100644 index 0000000..9cada9e --- /dev/null +++ b/internal/fdbased/stopfd/stopfd.go @@ -0,0 +1,52 @@ +// Copyright 2022 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +// Package stopfd provides a type that can be used to signal the stop of a dispatcher. +package stopfd + +import ( + "fmt" + + "golang.org/x/sys/unix" +) + +// StopFD is an eventfd used to signal the stop of a dispatcher. +type StopFD struct { + EFD int +} + +// New returns a new, initialized StopFD. +func New() (StopFD, error) { + efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK) + if err != nil { + return StopFD{EFD: -1}, fmt.Errorf("failed to create eventfd: %w", err) + } + return StopFD{EFD: efd}, nil +} + +// Stop writes to the eventfd and notifies the dispatcher to stop. It does not +// block. +func (sf *StopFD) Stop() { + increment := []byte{1, 0, 0, 0, 0, 0, 0, 0} + if n, err := unix.Write(sf.EFD, increment); n != len(increment) || err != nil { + // There are two possible errors documented in eventfd(2) for writing: + // 1. We are writing 8 bytes and not 0xffffffffffffffff, thus no EINVAL. + // 2. stop is only supposed to be called once, it can't reach the limit, + // thus no EAGAIN. + panic(fmt.Sprintf("write(EFD) = (%d, %s), want (%d, nil)", n, err, len(increment))) + } +} diff --git a/internal/fdbased/stopfd/stopfd_state_autogen.go b/internal/fdbased/stopfd/stopfd_state_autogen.go new file mode 100644 index 0000000..c013598 --- /dev/null +++ b/internal/fdbased/stopfd/stopfd_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify.
+ +//go:build (linux && amd64) || (linux && arm64) +// +build linux,amd64 linux,arm64 + +package stopfd diff --git a/route_gvisor.go b/route_gvisor.go index eaf6c1d..185a362 100644 --- a/route_gvisor.go +++ b/route_gvisor.go @@ -10,7 +10,7 @@ import ( type DirectDestination interface { WritePacket(buffer *buf.Buffer) error - WritePacketBuffer(buffer *stack.PacketBuffer) error + WritePacketBuffer(buffer stack.PacketBufferPtr) error Close() error Timeout() bool } diff --git a/route_nat_gvisor.go b/route_nat_gvisor.go index cb7c5ce..1c59335 100644 --- a/route_nat_gvisor.go +++ b/route_nat_gvisor.go @@ -8,7 +8,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -func (w *NatWriter) RewritePacketBuffer(packetBuffer *stack.PacketBuffer) { +func (w *NatWriter) RewritePacketBuffer(packetBuffer stack.PacketBufferPtr) { var bindAddr tcpip.Address if packetBuffer.NetworkProtocolNumber == header.IPv4ProtocolNumber { bindAddr = tcpip.Address(w.inet4Address.AsSlice()) diff --git a/tun_darwin_gvisor.go b/tun_darwin_gvisor.go index b57bf2c..f0d66bb 100644 --- a/tun_darwin_gvisor.go +++ b/tun_darwin_gvisor.go @@ -109,7 +109,7 @@ func (e *DarwinEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } -func (e *DarwinEndpoint) AddHeader(buffer *stack.PacketBuffer) { +func (e *DarwinEndpoint) AddHeader(buffer stack.PacketBufferPtr) { } func (e *DarwinEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) { diff --git a/tun_linux_gvisor.go b/tun_linux_gvisor.go index 4386220..6bd193c 100644 --- a/tun_linux_gvisor.go +++ b/tun_linux_gvisor.go @@ -3,7 +3,8 @@ package tun import ( - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" + "github.com/sagernet/sing-tun/internal/fdbased" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/tun_windows_gvisor.go b/tun_windows_gvisor.go index 11fe041..c8f7aad 100644 --- a/tun_windows_gvisor.go +++ b/tun_windows_gvisor.go @@ -99,7 +99,7 @@ func (e *WintunEndpoint) ARPHardwareType() 
header.ARPHardwareType { return header.ARPHardwareNone } -func (e *WintunEndpoint) AddHeader(buffer *stack.PacketBuffer) { +func (e *WintunEndpoint) AddHeader(buffer stack.PacketBufferPtr) { } func (e *WintunEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) {