diff --git a/go.mod b/go.mod index 76fdb09..7d799b2 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,9 @@ go 1.18 require ( github.com/fsnotify/fsnotify v1.7.0 github.com/go-ole/go-ole v1.3.0 - github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 github.com/sagernet/gvisor v0.0.0-20231209105102-8d27a30e436e github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 - github.com/sagernet/sing v0.2.20-0.20231211084415-35e7014b0898 + github.com/sagernet/sing v0.3.0-rc.2 github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/net v0.19.0 diff --git a/go.sum b/go.sum index d090c88..f7b43f6 100644 --- a/go.sum +++ b/go.sum @@ -1,20 +1,22 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= -github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA= -github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h1:QUQ4RRHD6hGGHdFMEtR8T2P6GS6R3D/CXKdaYHKKXms= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/sagernet/gvisor v0.0.0-20231209105102-8d27a30e436e h1:DOkjByVeAR56dkszjnMZke4wr7yM/1xHaJF3G9olkEE= github.com/sagernet/gvisor v0.0.0-20231209105102-8d27a30e436e/go.mod h1:fLxq/gtp0qzkaEwywlRRiGmjOK5ES/xUzyIKIFP2Asw= github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE= 
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= -github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= -github.com/sagernet/sing v0.2.20-0.20231211084415-35e7014b0898 h1:ZR0wpw4/0NCICOX10SIUW8jpPVV7+D98nGA6p4zWICo= -github.com/sagernet/sing v0.2.20-0.20231211084415-35e7014b0898/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80= +github.com/sagernet/sing v0.3.0-beta.3 h1:E2xBoJUducK/FE6EwMk95Rt2bkXeht9l1BTYRui+DXs= +github.com/sagernet/sing v0.3.0-beta.3/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g= +github.com/sagernet/sing v0.3.0-rc.2 h1:l5rq+bTrNhpAPd2Vjzi/sEhil4O6Bb1CKv6LdPLJKug= +github.com/sagernet/sing v0.3.0-rc.2/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g= github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg= github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= @@ -22,9 +24,9 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/W golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/clashtcpip/tcp.go b/internal/clashtcpip/tcp.go index 3e0ee73..ee7a894 100644 --- a/internal/clashtcpip/tcp.go +++ b/internal/clashtcpip/tcp.go @@ -50,6 +50,10 @@ func (p TCPPacket) SetChecksum(sum [2]byte) { p[17] = sum[1] } +func (p TCPPacket) OffloadChecksum() { + p.SetChecksum(zeroChecksum) +} + func (p TCPPacket) ResetChecksum(psum uint32) { p.SetChecksum(zeroChecksum) p.SetChecksum(Checksum(psum, p)) diff --git a/internal/clashtcpip/udp.go b/internal/clashtcpip/udp.go index f5773a1..f576e99 100644 --- a/internal/clashtcpip/udp.go +++ b/internal/clashtcpip/udp.go @@ -45,6 +45,10 @@ func (p UDPPacket) SetChecksum(sum [2]byte) { p[7] = sum[1] } +func (p UDPPacket) OffloadChecksum() { + p.SetChecksum(zeroChecksum) +} + func (p UDPPacket) ResetChecksum(psum uint32) { p.SetChecksum(zeroChecksum) p.SetChecksum(Checksum(psum, p)) diff --git a/stack.go b/stack.go index 7eb7854..2d5ef5a 100644 --- a/stack.go +++ b/stack.go @@ -19,10 +19,7 @@ type Stack interface { type StackOptions struct { Context context.Context Tun Tun - Name string - MTU uint32 - Inet4Address []netip.Prefix - Inet6Address []netip.Prefix + TunOptions Options EndpointIndependentNat bool UDPTimeout int64 Handler Handler @@ -37,7 +34,7 @@ func NewStack( ) (Stack, error) { switch stack { case "": - if WithGVisor { + if WithGVisor && !options.TunOptions.GSO { return NewMixed(options) } else { return NewSystem(options) @@ -48,8 +45,6 @@ func NewStack( return NewMixed(options) case "system": return NewSystem(options) - case "lwip": - return NewLWIP(options) default: 
return nil, E.New("unknown stack: ", stack) } diff --git a/stack_gvisor.go b/stack_gvisor.go index 108af21..ca41a37 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -31,7 +31,6 @@ const defaultNIC tcpip.NICID = 1 type GVisor struct { ctx context.Context tun GVisorTun - tunMtu uint32 endpointIndependentNat bool udpTimeout int64 broadcastAddr netip.Addr @@ -57,10 +56,9 @@ func NewGVisor( gStack := &GVisor{ ctx: options.Context, tun: gTun, - tunMtu: options.MTU, endpointIndependentNat: options.EndpointIndependentNat, udpTimeout: options.UDPTimeout, - broadcastAddr: BroadcastAddr(options.Inet4Address), + broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), handler: options.Handler, logger: options.Logger, } @@ -72,7 +70,7 @@ func (t *GVisor) Start() error { if err != nil { return err } - linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun.CreateVectorisedWriter()} + linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun} ipStack, err := newGVisorStack(linkEndpoint) if err != nil { return err diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index ca5856a..74ce4b5 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -82,7 +82,6 @@ type UDPBackWriter struct { source tcpip.Address sourcePort uint16 sourceNetwork tcpip.NetworkProtocolNumber - packet stack.PacketBufferPtr } func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error { @@ -149,12 +148,6 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock return nil } -type gRequest struct { - stack *stack.Stack - id stack.TransportEndpointID - pkt stack.PacketBufferPtr -} - type gUDPConn struct { *gonet.UDPConn } diff --git a/stack_lwip.go b/stack_lwip.go deleted file mode 100644 index 42cb651..0000000 --- a/stack_lwip.go +++ /dev/null @@ -1,144 +0,0 @@ -//go:build with_lwip - -package tun - -import ( - "context" - "net" - "net/netip" - "os" - - lwip "github.com/sagernet/go-tun2socks/core" - 
"github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/udpnat" -) - -type LWIP struct { - ctx context.Context - tun Tun - tunMtu uint32 - udpTimeout int64 - handler Handler - stack lwip.LWIPStack - udpNat *udpnat.Service[netip.AddrPort] -} - -func NewLWIP( - options StackOptions, -) (Stack, error) { - return &LWIP{ - ctx: options.Context, - tun: options.Tun, - tunMtu: options.MTU, - handler: options.Handler, - stack: lwip.NewLWIPStack(), - udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler), - }, nil -} - -func (l *LWIP) Start() error { - lwip.RegisterTCPConnHandler(l) - lwip.RegisterUDPConnHandler(l) - lwip.RegisterOutputFn(l.tun.Write) - go l.loopIn() - return nil -} - -func (l *LWIP) loopIn() { - if winTun, isWintun := l.tun.(WinTun); isWintun { - l.loopInWintun(winTun) - return - } - buffer := make([]byte, int(l.tunMtu)+PacketOffset) - for { - n, err := l.tun.Read(buffer) - if err != nil { - return - } - _, err = l.stack.Write(buffer[PacketOffset:n]) - if err != nil { - if err.Error() == "stack closed" { - return - } - l.handler.NewError(context.Background(), err) - } - } -} - -func (l *LWIP) loopInWintun(tun WinTun) { - for { - packet, release, err := tun.ReadPacket() - if err != nil { - return - } - _, err = l.stack.Write(packet) - release() - if err != nil { - if err.Error() == "stack closed" { - return - } - l.handler.NewError(context.Background(), err) - } - } -} - -func (l *LWIP) Close() error { - lwip.RegisterTCPConnHandler(nil) - lwip.RegisterUDPConnHandler(nil) - lwip.RegisterOutputFn(func(bytes []byte) (int, error) { - return 0, os.ErrClosed - }) - return l.stack.Close() -} - -func (l *LWIP) Handle(conn net.Conn) error { - lAddr := conn.LocalAddr() - rAddr := conn.RemoteAddr() - if lAddr == nil || rAddr == nil { - conn.Close() - return nil - } - go func() { - var metadata M.Metadata - 
metadata.Source = M.SocksaddrFromNet(lAddr) - metadata.Destination = M.SocksaddrFromNet(rAddr) - hErr := l.handler.NewConnection(l.ctx, conn, metadata) - if hErr != nil { - conn.(lwip.TCPConn).Abort() - } - }() - return nil -} - -func (l *LWIP) ReceiveTo(conn lwip.UDPConn, data []byte, addr M.Socksaddr) error { - var upstreamMetadata M.Metadata - upstreamMetadata.Source = conn.LocalAddr() - upstreamMetadata.Destination = addr - - l.udpNat.NewPacket( - l.ctx, - upstreamMetadata.Source.AddrPort(), - buf.As(data).ToOwned(), - upstreamMetadata, - func(natConn N.PacketConn) N.PacketWriter { - return &LWIPUDPBackWriter{conn} - }, - ) - return nil -} - -type LWIPUDPBackWriter struct { - conn lwip.UDPConn -} - -func (w *LWIPUDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - defer buffer.Release() - return common.Error(w.conn.WriteFrom(buffer.Bytes(), destination)) -} - -func (w *LWIPUDPBackWriter) Close() error { - return w.conn.Close() -} diff --git a/stack_lwip_stub.go b/stack_lwip_stub.go deleted file mode 100644 index 403a45e..0000000 --- a/stack_lwip_stub.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !with_lwip - -package tun - -import E "github.com/sagernet/sing/common/exceptions" - -func NewLWIP( - options StackOptions, -) (Stack, error) { - return nil, E.New(`LWIP is not included in this build, rebuild with -tags with_lwip`) -} diff --git a/stack_mixed.go b/stack_mixed.go index 41d0ce2..811e0fd 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -13,17 +13,14 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/gvisor/pkg/waiter" "github.com/sagernet/sing-tun/internal/clashtcpip" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" ) type Mixed struct { *System - writer N.VectorisedWriter 
endpointIndependentNat bool stack *stack.Stack endpoint *channel.Endpoint @@ -38,7 +35,6 @@ func NewMixed( } return &Mixed{ System: system.(*System), - writer: options.Tun.CreateVectorisedWriter(), endpointIndependentNat: options.EndpointIndependentNat, }, nil } @@ -48,7 +44,7 @@ func (m *Mixed) Start() error { if err != nil { return err } - endpoint := channel.New(1024, m.mtu, "") + endpoint := channel.New(1024, uint32(m.mtu), "") ipStack, err := newGVisorStack(endpoint) if err != nil { return err @@ -95,26 +91,34 @@ func (m *Mixed) tunLoop() { m.wintunLoop(winTun) return } + if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN { + m.frontHeadroom = linuxTUN.FrontHeadroom() + m.txChecksumOffload = linuxTUN.TXChecksumOffload() + batchSize := linuxTUN.BatchSize() + if batchSize > 1 { + m.batchLoop(linuxTUN, batchSize) + return + } + } packetBuffer := make([]byte, m.mtu+PacketOffset) for { n, err := m.tun.Read(packetBuffer) if err != nil { - return + if E.IsClosed(err) { + return + } + m.logger.Error(E.Cause(err, "read packet")) } if n < clashtcpip.IPv4PacketMinLength { continue } + rawPacket := packetBuffer[:n] packet := packetBuffer[PacketOffset:n] - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: - err = m.processIPv4(packet) - case 6: - err = m.processIPv6(packet) - default: - err = E.New("ip: unknown version: ", ipVersion) - } - if err != nil { - m.logger.Trace(err) + if m.processPacket(packet) { + _, err = m.tun.Write(rawPacket) + if err != nil { + m.logger.Trace(E.Cause(err, "write packet")) + } } } } @@ -129,62 +133,119 @@ func (m *Mixed) wintunLoop(winTun WinTun) { release() continue } - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: - err = m.processIPv4(packet) - case 6: - err = m.processIPv6(packet) - default: - err = E.New("ip: unknown version: ", ipVersion) - } - if err != nil { - m.logger.Trace(err) + if m.processPacket(packet) { + _, err = winTun.Write(packet) + if err != nil { + m.logger.Trace(E.Cause(err, "write packet")) + } 
} release() } } -func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error { - destination := packet.DestinationIP() - if destination == m.broadcastAddr || !destination.IsGlobalUnicast() { - return common.Error(m.tun.Write(packet)) +func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) { + packetBuffers := make([][]byte, batchSize) + writeBuffers := make([][]byte, batchSize) + packetSizes := make([]int, batchSize) + for i := range packetBuffers { + packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom) } - switch packet.Protocol() { - case clashtcpip.TCP: - return m.processIPv4TCP(packet, packet.Payload()) - case clashtcpip.UDP: - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(packet), - }) - m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) - pkt.DecRef() - return nil - case clashtcpip.ICMP: - return m.processIPv4ICMP(packet, packet.Payload()) - default: - return common.Error(m.tun.Write(packet)) + for { + n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes) + if err != nil { + if E.IsClosed(err) { + return + } + m.logger.Error(E.Cause(err, "batch read packet")) + } + if n == 0 { + continue + } + for i := 0; i < n; i++ { + packetSize := packetSizes[i] + if packetSize < clashtcpip.IPv4PacketMinLength { + continue + } + packetBuffer := packetBuffers[i] + packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize] + if m.processPacket(packet) { + writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize]) + } + } + if len(writeBuffers) > 0 { + err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom) + if err != nil { + m.logger.Trace(E.Cause(err, "batch write packet")) + } + writeBuffers = writeBuffers[:0] + } } } -func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error { - if !packet.DestinationIP().IsGlobalUnicast() { - return common.Error(m.tun.Write(packet)) +func (m *Mixed) processPacket(packet []byte) bool { + var ( + writeBack bool + err error + 
) + switch ipVersion := packet[0] >> 4; ipVersion { + case 4: + writeBack, err = m.processIPv4(packet) + case 6: + writeBack, err = m.processIPv6(packet) + default: + err = E.New("ip: unknown version: ", ipVersion) + } + if err != nil { + m.logger.Trace(err) + return false + } + return writeBack +} + +func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { + writeBack = true + destination := packet.DestinationIP() + if destination == m.broadcastAddr || !destination.IsGlobalUnicast() { + return } switch packet.Protocol() { case clashtcpip.TCP: - return m.processIPv6TCP(packet, packet.Payload()) + err = m.processIPv4TCP(packet, packet.Payload()) case clashtcpip.UDP: + writeBack = false pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(packet), + Payload: buffer.MakeWithData(packet), + IsForwardedPacket: true, + }) + m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) + pkt.DecRef() + return + case clashtcpip.ICMP: + err = m.processIPv4ICMP(packet, packet.Payload()) + } + return +} + +func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { + writeBack = true + if !packet.DestinationIP().IsGlobalUnicast() { + return + } + switch packet.Protocol() { + case clashtcpip.TCP: + err = m.processIPv6TCP(packet, packet.Payload()) + case clashtcpip.UDP: + writeBack = false + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + IsForwardedPacket: true, }) m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) pkt.DecRef() - return nil case clashtcpip.ICMPv6: - return m.processIPv6ICMP(packet, packet.Payload()) - default: - return common.Error(m.tun.Write(packet)) + err = m.processIPv6ICMP(packet, packet.Payload()) } + return } func (m *Mixed) packetLoop() { @@ -193,7 +254,7 @@ func (m *Mixed) packetLoop() { if packet == nil { break } - bufio.WriteVectorised(m.writer, packet.AsSlices()) + bufio.WriteVectorised(m.tun, 
packet.AsSlices()) packet.DecRef() } } diff --git a/stack_system.go b/stack_system.go index dd305fe..c1d1216 100644 --- a/stack_system.go +++ b/stack_system.go @@ -22,7 +22,7 @@ type System struct { ctx context.Context tun Tun tunName string - mtu uint32 + mtu int handler Handler logger logger.Logger inet4Prefixes []netip.Prefix @@ -41,6 +41,8 @@ type System struct { udpNat *udpnat.Service[netip.AddrPort] bindInterface bool interfaceFinder control.InterfaceFinder + frontHeadroom int + txChecksumOffload bool } type Session struct { @@ -54,29 +56,29 @@ func NewSystem(options StackOptions) (Stack, error) { stack := &System{ ctx: options.Context, tun: options.Tun, - tunName: options.Name, - mtu: options.MTU, + tunName: options.TunOptions.Name, + mtu: int(options.TunOptions.MTU), udpTimeout: options.UDPTimeout, handler: options.Handler, logger: options.Logger, - inet4Prefixes: options.Inet4Address, - inet6Prefixes: options.Inet6Address, - broadcastAddr: BroadcastAddr(options.Inet4Address), + inet4Prefixes: options.TunOptions.Inet4Address, + inet6Prefixes: options.TunOptions.Inet6Address, + broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), bindInterface: options.ForwarderBindInterface, interfaceFinder: options.InterfaceFinder, } - if len(options.Inet4Address) > 0 { - if options.Inet4Address[0].Bits() == 32 { + if len(options.TunOptions.Inet4Address) > 0 { + if options.TunOptions.Inet4Address[0].Bits() == 32 { return nil, E.New("need one more IPv4 address in first prefix for system stack") } - stack.inet4ServerAddress = options.Inet4Address[0].Addr() + stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr() stack.inet4Address = stack.inet4ServerAddress.Next() } - if len(options.Inet6Address) > 0 { - if options.Inet6Address[0].Bits() == 128 { + if len(options.TunOptions.Inet6Address) > 0 { + if options.TunOptions.Inet6Address[0].Bits() == 128 { return nil, E.New("need one more IPv6 address in first prefix for system stack") } - 
stack.inet6ServerAddress = options.Inet6Address[0].Addr() + stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr() stack.inet6Address = stack.inet6ServerAddress.Next() } if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() { @@ -144,26 +146,34 @@ func (s *System) tunLoop() { s.wintunLoop(winTun) return } + if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN { + s.frontHeadroom = linuxTUN.FrontHeadroom() + s.txChecksumOffload = linuxTUN.TXChecksumOffload() + batchSize := linuxTUN.BatchSize() + if batchSize > 1 { + s.batchLoop(linuxTUN, batchSize) + return + } + } packetBuffer := make([]byte, s.mtu+PacketOffset) for { n, err := s.tun.Read(packetBuffer) if err != nil { - return + if E.IsClosed(err) { + return + } + s.logger.Error(E.Cause(err, "read packet")) } if n < clashtcpip.IPv4PacketMinLength { continue } + rawPacket := packetBuffer[:n] packet := packetBuffer[PacketOffset:n] - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: - err = s.processIPv4(packet) - case 6: - err = s.processIPv6(packet) - default: - err = E.New("ip: unknown version: ", ipVersion) - } - if err != nil { - s.logger.Trace(err) + if s.processPacket(packet) { + _, err = s.tun.Write(rawPacket) + if err != nil { + s.logger.Trace(E.Cause(err, "write packet")) + } } } } @@ -178,21 +188,75 @@ func (s *System) wintunLoop(winTun WinTun) { release() continue } - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: - err = s.processIPv4(packet) - case 6: - err = s.processIPv6(packet) - default: - err = E.New("ip: unknown version: ", ipVersion) - } - if err != nil { - s.logger.Trace(err) + if s.processPacket(packet) { + _, err = winTun.Write(packet) + if err != nil { + s.logger.Trace(E.Cause(err, "write packet")) + } } release() } } +func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) { + packetBuffers := make([][]byte, batchSize) + writeBuffers := make([][]byte, batchSize) + packetSizes := make([]int, batchSize) + for i := range packetBuffers { + 
packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom) + } + for { + n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes) + if err != nil { + if E.IsClosed(err) { + return + } + s.logger.Error(E.Cause(err, "batch read packet")) + } + if n == 0 { + continue + } + for i := 0; i < n; i++ { + packetSize := packetSizes[i] + if packetSize < clashtcpip.IPv4PacketMinLength { + continue + } + packetBuffer := packetBuffers[i] + packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize] + if s.processPacket(packet) { + writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize]) + } + } + if len(writeBuffers) > 0 { + err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom) + if err != nil { + s.logger.Trace(E.Cause(err, "batch write packet")) + } + writeBuffers = writeBuffers[:0] + } + } +} + +func (s *System) processPacket(packet []byte) bool { + var ( + writeBack bool + err error + ) + switch ipVersion := packet[0] >> 4; ipVersion { + case 4: + writeBack, err = s.processIPv4(packet) + case 6: + writeBack, err = s.processIPv6(packet) + default: + err = E.New("ip: unknown version: ", ipVersion) + } + if err != nil { + s.logger.Trace(err) + return false + } + return writeBack +} + func (s *System) acceptLoop(listener net.Listener) { for { conn, err := listener.Accept() @@ -234,44 +298,46 @@ func (s *System) acceptLoop(listener net.Listener) { } } -func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error { +func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { + writeBack = true destination := packet.DestinationIP() if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return s.processIPv4TCP(packet, packet.Payload()) + err = s.processIPv4TCP(packet, packet.Payload()) case clashtcpip.UDP: - return s.processIPv4UDP(packet, packet.Payload()) + writeBack = false + err = 
s.processIPv4UDP(packet, packet.Payload()) case clashtcpip.ICMP: - return s.processIPv4ICMP(packet, packet.Payload()) - default: - return common.Error(s.tun.Write(packet)) + err = s.processIPv4ICMP(packet, packet.Payload()) } + return } -func (s *System) processIPv6(packet clashtcpip.IPv6Packet) error { +func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { + writeBack = true if !packet.DestinationIP().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return s.processIPv6TCP(packet, packet.Payload()) + err = s.processIPv6TCP(packet, packet.Payload()) case clashtcpip.UDP: - return s.processIPv6UDP(packet, packet.Payload()) + writeBack = false + err = s.processIPv6UDP(packet, packet.Payload()) case clashtcpip.ICMPv6: - return s.processIPv6ICMP(packet, packet.Payload()) - default: - return common.Error(s.tun.Write(packet)) + err = s.processIPv6ICMP(packet, packet.Payload()) } + return } func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return nil } else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { @@ -288,16 +354,21 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip. 
packet.SetDestinationIP(s.inet4ServerAddress) header.SetDestinationPort(s.tcpPort) } - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(s.tun.Write(packet)) + if !s.txChecksumOffload { + header.ResetChecksum(packet.PseudoSum()) + packet.ResetChecksum() + } else { + header.OffloadChecksum() + packet.ResetChecksum() + } + return nil } func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return nil } else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { @@ -314,9 +385,12 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip. packet.SetDestinationIP(s.inet6ServerAddress) header.SetDestinationPort(s.tcpPort6) } - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(s.tun.Write(packet)) + if !s.txChecksumOffload { + header.ResetChecksum(packet.PseudoSum()) + } else { + header.OffloadChecksum() + } + return nil } func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { @@ -332,7 +406,7 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip. source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return nil } data := buf.As(header.Payload()) if data.Len() == 0 { @@ -346,7 +420,13 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip. 
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter4{s.tun, headerCopy, source} + return &systemUDPPacketWriter4{ + s.tun, + s.frontHeadroom + PacketOffset, + headerCopy, + source, + s.txChecksumOffload, + } }) return nil } @@ -358,7 +438,7 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip. source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return nil } data := buf.As(header.Payload()) if data.Len() == 0 { @@ -372,7 +452,13 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip. headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter6{s.tun, headerCopy, source} + return &systemUDPPacketWriter6{ + s.tun, + s.frontHeadroom + PacketOffset, + headerCopy, + source, + s.txChecksumOffload, + } }) return nil } @@ -387,7 +473,7 @@ func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip packet.SetDestinationIP(sourceAddress) header.ResetChecksum() packet.ResetChecksum() - return common.Error(s.tun.Write(packet)) + return nil } func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { @@ -400,102 +486,21 @@ func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip packet.SetDestinationIP(sourceAddress) header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(packet)) -} - -type systemTCPDirectPacketWriter4 struct { - tun Tun - source netip.AddrPort -} - -func (w *systemTCPDirectPacketWriter4) WritePacket(p []byte) error { - packet := clashtcpip.IPv4Packet(p) - 
header := clashtcpip.TCPPacket(packet.Payload()) - packet.SetDestinationIP(w.source.Addr()) - header.SetDestinationPort(w.source.Port()) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemTCPDirectPacketWriter6 struct { - tun Tun - source netip.AddrPort -} - -func (w *systemTCPDirectPacketWriter6) WritePacket(p []byte) error { - packet := clashtcpip.IPv6Packet(p) - header := clashtcpip.TCPPacket(packet.Payload()) - packet.SetDestinationIP(w.source.Addr()) - header.SetDestinationPort(w.source.Port()) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemUDPDirectPacketWriter4 struct { - tun Tun - source netip.AddrPort -} - -func (w *systemUDPDirectPacketWriter4) WritePacket(p []byte) error { - packet := clashtcpip.IPv4Packet(p) - header := clashtcpip.UDPPacket(packet.Payload()) - packet.SetDestinationIP(w.source.Addr()) - header.SetDestinationPort(w.source.Port()) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemUDPDirectPacketWriter6 struct { - tun Tun - source netip.AddrPort -} - -func (w *systemUDPDirectPacketWriter6) WritePacket(p []byte) error { - packet := clashtcpip.IPv6Packet(p) - header := clashtcpip.UDPPacket(packet.Payload()) - packet.SetDestinationIP(w.source.Addr()) - header.SetDestinationPort(w.source.Port()) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemICMPDirectPacketWriter4 struct { - tun Tun - source netip.Addr -} - -func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error { - packet := clashtcpip.IPv4Packet(p) - packet.SetDestinationIP(w.source) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemICMPDirectPacketWriter6 struct { - tun Tun - source netip.Addr -} - -func (w 
*systemICMPDirectPacketWriter6) WritePacket(p []byte) error { - packet := clashtcpip.IPv6Packet(p) - packet.SetDestinationIP(w.source) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) + return nil } type systemUDPPacketWriter4 struct { - tun Tun - header []byte - source netip.AddrPort + tun Tun + frontHeadroom int + header []byte + source netip.AddrPort + txChecksumOffload bool } func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - newPacket := buf.NewSize(len(w.header) + buffer.Len()) + newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) defer newPacket.Release() + newPacket.Resize(w.frontHeadroom, 0) newPacket.Write(w.header) newPacket.Write(buffer.Bytes()) ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes()) @@ -506,20 +511,33 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetDestinationPort(udpHdr.SourcePort()) udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize)) - udpHdr.ResetChecksum(ipHdr.PseudoSum()) - ipHdr.ResetChecksum() + if !w.txChecksumOffload { + udpHdr.ResetChecksum(ipHdr.PseudoSum()) + ipHdr.ResetChecksum() + } else { + udpHdr.OffloadChecksum() + ipHdr.ResetChecksum() + } + if PacketOffset > 0 { + newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET + } else { + newPacket.Advance(-w.frontHeadroom) + } return common.Error(w.tun.Write(newPacket.Bytes())) } type systemUDPPacketWriter6 struct { - tun Tun - header []byte - source netip.AddrPort + tun Tun + frontHeadroom int + header []byte + source netip.AddrPort + txChecksumOffload bool } func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - newPacket := buf.NewSize(len(w.header) + buffer.Len()) + newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) defer newPacket.Release() + newPacket.Resize(w.frontHeadroom, 0) newPacket.Write(w.header) 
newPacket.Write(buffer.Bytes()) ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes()) @@ -531,6 +549,15 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetDestinationPort(udpHdr.SourcePort()) udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(udpLen) - udpHdr.ResetChecksum(ipHdr.PseudoSum()) + if !w.txChecksumOffload { + udpHdr.ResetChecksum(ipHdr.PseudoSum()) + } else { + udpHdr.OffloadChecksum() + } + if PacketOffset > 0 { + newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 + } else { + newPacket.Advance(-w.frontHeadroom) + } return common.Error(w.tun.Write(newPacket.Bytes())) } diff --git a/tun.go b/tun.go index 977d4eb..6d86228 100644 --- a/tun.go +++ b/tun.go @@ -23,7 +23,7 @@ type Handler interface { type Tun interface { io.ReadWriter - CreateVectorisedWriter() N.VectorisedWriter + N.VectorisedWriter Close() error } @@ -32,11 +32,21 @@ type WinTun interface { ReadPacket() ([]byte, func(), error) } +type LinuxTUN interface { + Tun + N.FrontHeadroom + BatchSize() int + BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) + BatchWrite(buffers [][]byte, offset int) error + TXChecksumOffload() bool +} + type Options struct { Name string Inet4Address []netip.Prefix Inet6Address []netip.Prefix MTU uint32 + GSO bool AutoRoute bool StrictRoute bool Inet4RouteAddress []netip.Prefix @@ -54,6 +64,9 @@ type Options struct { TableIndex int FileDescriptor int Logger logger.Logger + + // No work for TCP, do not use. 
+ _TXChecksumOffload bool } func CalculateInterfaceName(name string) (tunName string) { diff --git a/tun_darwin.go b/tun_darwin.go index 553eb02..26782f0 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -5,7 +5,6 @@ import ( "net" "net/netip" "os" - "runtime" "syscall" "unsafe" @@ -68,44 +67,22 @@ func New(options Options) (Tun, error) { if !ok { panic("create vectorised writer") } - runtime.SetFinalizer(nativeTun.tunFile, nil) return nativeTun, nil } func (t *NativeTun) Read(p []byte) (n int, err error) { - /*n, err = t.tunFile.Read(p) - if n < 4 { - return 0, err - } - - copy(p[:], p[4:]) - return n - 4, err*/ return t.tunFile.Read(p) } +func (t *NativeTun) Write(p []byte) (n int, err error) { + return t.tunFile.Write(p) +} + var ( packetHeader4 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET} packetHeader6 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET6} ) -func (t *NativeTun) Write(p []byte) (n int, err error) { - var packetHeader []byte - if p[0]>>4 == 4 { - packetHeader = packetHeader4[:] - } else { - packetHeader = packetHeader6[:] - } - _, err = bufio.WriteVectorised(t.tunWriter, [][]byte{packetHeader, p}) - if err == nil { - n = len(p) - } - return -} - -func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { - return t -} - func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { var packetHeader []byte if buffers[0].Byte(0)>>4 == 4 { diff --git a/tun_darwin_gvisor.go b/tun_darwin_gvisor.go index 132d0c3..a1f13ae 100644 --- a/tun_darwin_gvisor.go +++ b/tun_darwin_gvisor.go @@ -36,7 +36,7 @@ func (e *DarwinEndpoint) LinkAddress() tcpip.LinkAddress { } func (e *DarwinEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return stack.CapabilityNone + return stack.CapabilityRXChecksumOffload } func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) { @@ -51,13 +51,13 @@ func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } func (e *DarwinEndpoint) dispatchLoop() { - packetBuffer := make([]byte, e.tun.mtu+4) + 
packetBuffer := make([]byte, e.tun.mtu+PacketOffset) for { n, err := e.tun.tunFile.Read(packetBuffer) if err != nil { break } - packet := packetBuffer[4:n] + packet := packetBuffer[PacketOffset:n] var networkProtocol tcpip.NetworkProtocolNumber switch header.IPVersion(packet) { case header.IPv4Version: @@ -112,14 +112,7 @@ func (e *DarwinEndpoint) ParseHeader(ptr stack.PacketBufferPtr) bool { func (e *DarwinEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) { var n int for _, packet := range packetBufferList.AsSlice() { - var packetHeader []byte - switch packet.NetworkProtocolNumber { - case header.IPv4ProtocolNumber: - packetHeader = packetHeader4[:] - case header.IPv6ProtocolNumber: - packetHeader = packetHeader6[:] - } - _, err := bufio.WriteVectorised(e.tun.tunWriter, append([][]byte{packetHeader}, packet.AsSlices()...)) + _, err := bufio.WriteVectorised(e.tun, packet.AsSlices()) if err != nil { return n, &tcpip.ErrAborted{} } diff --git a/tun_linux.go b/tun_linux.go index a261455..69dafdf 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -7,11 +7,13 @@ import ( "os" "os/exec" "runtime" + "sync" "syscall" "unsafe" "github.com/sagernet/netlink" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" @@ -22,17 +24,29 @@ import ( "golang.org/x/sys/unix" ) +var _ LinuxTUN = (*NativeTun)(nil) + type NativeTun struct { tunFd int tunFile *os.File + tunWriter N.VectorisedWriter interfaceCallback *list.Element[DefaultInterfaceUpdateCallback] options Options ruleIndex6 []int + gsoEnabled bool + gsoBuffer []byte + gsoToWrite []int + gsoReadAccess sync.Mutex + tcpGROAccess sync.Mutex + tcp4GROTable *tcpGROTable + tcp6GROTable *tcpGROTable + txChecksumOffload bool } func New(options Options) (Tun, error) { + var nativeTun *NativeTun if options.FileDescriptor == 0 { - tunFd, err := 
open(options.Name) + tunFd, err := open(options.Name, options.GSO) if err != nil { return nil, err } @@ -40,38 +54,125 @@ func New(options Options) (Tun, error) { if err != nil { return nil, E.Errors(err, unix.Close(tunFd)) } - nativeTun := &NativeTun{ + nativeTun = &NativeTun{ tunFd: tunFd, tunFile: os.NewFile(uintptr(tunFd), "tun"), options: options, } - runtime.SetFinalizer(nativeTun.tunFile, nil) err = nativeTun.configure(tunLink) if err != nil { return nil, E.Errors(err, unix.Close(tunFd)) } - return nativeTun, nil } else { - nativeTun := &NativeTun{ + nativeTun = &NativeTun{ tunFd: options.FileDescriptor, tunFile: os.NewFile(uintptr(options.FileDescriptor), "tun"), options: options, } - runtime.SetFinalizer(nativeTun.tunFile, nil) - return nativeTun, nil } + var ok bool + nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile) + if !ok { + panic("create vectorised writer") + } + return nativeTun, nil +} + +func (t *NativeTun) FrontHeadroom() int { + if t.gsoEnabled { + return virtioNetHdrLen + } + return 0 } func (t *NativeTun) Read(p []byte) (n int, err error) { - return t.tunFile.Read(p) + if t.gsoEnabled { + n, err = t.tunFile.Read(t.gsoBuffer) + if err != nil { + return + } + var sizes [1]int + n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0) + if err != nil { + return + } + if n == 0 { + return + } + n = sizes[0] + return + } else { + return t.tunFile.Read(p) + } } func (t *NativeTun) Write(p []byte) (n int, err error) { + if t.gsoEnabled { + err = t.BatchWrite([][]byte{p}, virtioNetHdrLen) + if err != nil { + return + } + n = len(p) + return + } return t.tunFile.Write(p) } -func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { - return bufio.NewVectorisedWriter(t.tunFile) +func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { + if t.gsoEnabled { + n := buf.LenMulti(buffers) + buffer := buf.NewSize(virtioNetHdrLen + n) + buffer.Truncate(virtioNetHdrLen) + buf.CopyMulti(buffer.Extend(n), 
buffers) + _, err := t.tunFile.Write(buffer.Bytes()) + buffer.Release() + return err + } else { + return t.tunWriter.WriteVectorised(buffers) + } +} + +func (t *NativeTun) BatchSize() int { + if !t.gsoEnabled { + return 1 + } + batchSize := int(gsoMaxSize/t.options.MTU) * 2 + if batchSize > idealBatchSize { + batchSize = idealBatchSize + } + return batchSize +} + +func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) { + t.gsoReadAccess.Lock() + defer t.gsoReadAccess.Unlock() + n, err = t.tunFile.Read(t.gsoBuffer) + if err != nil { + return + } + return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset) +} + +func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error { + t.tcpGROAccess.Lock() + defer func() { + t.tcp4GROTable.reset() + t.tcp6GROTable.reset() + t.tcpGROAccess.Unlock() + }() + t.gsoToWrite = t.gsoToWrite[:0] + err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite) + if err != nil { + return err + } + offset -= virtioNetHdrLen + for _, bufferIndex := range t.gsoToWrite { + _, err = t.tunFile.Write(buffers[bufferIndex][offset:]) + if err != nil { + return err + } + } + return nil } var controlPath string @@ -86,7 +187,7 @@ func init() { } } -func open(name string) (int, error) { +func open(name string, vnetHdr bool) (int, error) { fd, err := unix.Open(controlPath, unix.O_RDWR, 0) if err != nil { return -1, err @@ -100,6 +201,9 @@ func open(name string) (int, error) { copy(ifr.name[:], name) ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI + if vnetHdr { + ifr.flags |= unix.IFF_VNET_HDR + } _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.TUNSETIFF, uintptr(unsafe.Pointer(&ifr))) if errno != 0 { unix.Close(fd) @@ -142,6 +246,46 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { } } + if t.options.GSO { + var vnetHdrEnabled bool + vnetHdrEnabled, err = checkVNETHDREnabled(t.tunFd, t.options.Name) + if err != nil { + return E.Cause(err, "enable offload: check 
IFF_VNET_HDR enabled") + } + if !vnetHdrEnabled { + return E.Cause(err, "enable offload: IFF_VNET_HDR not enabled") + } + err = setTCPOffload(t.tunFd) + if err != nil { + return err + } + t.gsoEnabled = true + t.gsoBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize)) + t.tcp4GROTable = newTCPGROTable() + t.tcp6GROTable = newTCPGROTable() + } + + var rxChecksumOffload bool + rxChecksumOffload, err = checkChecksumOffload(t.options.Name, unix.ETHTOOL_GRXCSUM) + if err == nil && !rxChecksumOffload { + _ = setChecksumOffload(t.options.Name, unix.ETHTOOL_SRXCSUM) + } + + if t.options._TXChecksumOffload { + var txChecksumOffload bool + txChecksumOffload, err = checkChecksumOffload(t.options.Name, unix.ETHTOOL_GTXCSUM) + if err != nil { + return err + } + if err == nil && !txChecksumOffload { + err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM) + if err != nil { + return err + } + } + t.txChecksumOffload = true + } + err = netlink.LinkSetUp(tunLink) if err != nil { return err @@ -188,6 +332,10 @@ func (t *NativeTun) Close() error { return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile))) } +func (t *NativeTun) TXChecksumOffload() bool { + return t.txChecksumOffload +} + func prefixToIPNet(prefix netip.Prefix) *net.IPNet { return &net.IPNet{ IP: prefix.Addr().AsSlice(), diff --git a/tun_linux_flags.go b/tun_linux_flags.go new file mode 100644 index 0000000..1b84baa --- /dev/null +++ b/tun_linux_flags.go @@ -0,0 +1,84 @@ +//go:build linux + +package tun + +import ( + "os" + "syscall" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/sys/unix" +) + +func checkVNETHDREnabled(fd int, name string) (bool, error) { + ifr, err := unix.NewIfreq(name) + if err != nil { + return false, err + } + err = unix.IoctlIfreq(fd, unix.TUNGETIFF, ifr) + if err != nil { + return false, os.NewSyscallError("TUNGETIFF", err) + } + return ifr.Uint16()&unix.IFF_VNET_HDR != 0, nil +} + +func setTCPOffload(fd int) error { + 
const ( + // TODO: support TSO with ECN bits + tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + ) + err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunOffloads) + if err != nil { + return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload") + } + return nil +} + +type ifreqData struct { + ifrName [unix.IFNAMSIZ]byte + ifrData uintptr +} + +type ethtoolValue struct { + cmd uint32 + data uint32 +} + +//go:linkname ioctlPtr golang.org/x/sys/unix.ioctlPtr +func ioctlPtr(fd int, req uint, arg unsafe.Pointer) (err error) + +func checkChecksumOffload(name string, cmd uint32) (bool, error) { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return false, err + } + defer syscall.Close(fd) + ifr := ifreqData{} + copy(ifr.ifrName[:], name) + data := ethtoolValue{cmd: cmd} + ifr.ifrData = uintptr(unsafe.Pointer(&data)) + err = ioctlPtr(fd, unix.SIOCETHTOOL, unsafe.Pointer(&ifr)) + if err != nil { + return false, os.NewSyscallError("SIOCETHTOOL ETHTOOL_GTXCSUM", err) + } + return data.data == 0, nil +} + +func setChecksumOffload(name string, cmd uint32) error { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(fd) + ifr := ifreqData{} + copy(ifr.ifrName[:], name) + data := ethtoolValue{cmd: cmd, data: 0} + ifr.ifrData = uintptr(unsafe.Pointer(&data)) + err = ioctlPtr(fd, unix.SIOCETHTOOL, unsafe.Pointer(&ifr)) + if err != nil { + return os.NewSyscallError("SIOCETHTOOL ETHTOOL_STXCSUM", err) + } + return nil +} diff --git a/tun_linux_gvisor.go b/tun_linux_gvisor.go index b5c400c..1edeab1 100644 --- a/tun_linux_gvisor.go +++ b/tun_linux_gvisor.go @@ -10,8 +10,19 @@ import ( var _ GVisorTun = (*NativeTun)(nil) func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { + if t.gsoEnabled { + return fdbased.New(&fdbased.Options{ + FDs: []int{t.tunFd}, + MTU: t.options.MTU, + GSOMaxSize: gsoMaxSize, + RXChecksumOffload: true, 
+ TXChecksumOffload: t.txChecksumOffload, + }) + } return fdbased.New(&fdbased.Options{ - FDs: []int{t.tunFd}, - MTU: t.options.MTU, + FDs: []int{t.tunFd}, + MTU: t.options.MTU, + RXChecksumOffload: true, + TXChecksumOffload: t.txChecksumOffload, }) } diff --git a/tun_linux_offload.go b/tun_linux_offload.go new file mode 100644 index 0000000..930b939 --- /dev/null +++ b/tun_linux_offload.go @@ -0,0 +1,768 @@ +//go:build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "unsafe" + + "github.com/sagernet/sing-tun/internal/clashtcpip" + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/sys/unix" +) + +const ( + gsoMaxSize = 65536 + tcpFlagsOffset = 13 + idealBatchSize = 128 +) + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// flowKey represents the key for a flow. 
+type flowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. +} + +// tcpGROTable holds flow and coalescing information for the purposes of GRO. +type tcpGROTable struct { + itemsByFlow map[flowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[flowKey][]tcpGROItem, idealBatchSize), + itemsPool: make([][]tcpGROItem, idealBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize) + } + return t +} + +func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { + key := flowKey{} + addrSize := dstAddr - srcAddr + copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) + copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { + key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. 
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { + key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key flowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key flowKey + sentSeq uint32 // the sequence number + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. 
+type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if pkt[0]>>4 == 6 { + if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 { + // cannot coalesce with unequal Traffic class values + return coalesceUnavailable + } + if pkt[7] != pktTarget[7] { + // cannot coalesce with unequal Hop limit values + return coalesceUnavailable + } + } else { + if pkt[1] != pktTarget[1] { + // cannot coalesce with unequal ToS values + return coalesceUnavailable + } + if pkt[6]>>5 != pktTarget[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return coalesceUnavailable + } + if pkt[8] != pktTarget[8] { + // cannot coalesce with unequal TTL values + return coalesceUnavailable + } + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. 
+ return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. + return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + tcpTotalLen := uint16(len(pkt) - int(iphLen)) + tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) + return ^checksumFold(pkt[iphLen:], tcpCSumNoFold) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess +) + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, returning the outcome. 
This function may swap bufs elements in the +// event of a prepend as item's bufs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { + return coalesceItemInvalidCSum + } + } + if !tcpChecksumValid(pkt, item.iphLen, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) + copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) + // Flip the slice headers in bufs as part of prepend. The index of item + // is already being tracked for writing. + bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + } else { + pktHead = bufs[item.bufsIndex][bufsOffset:] + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { + return coalesceItemInvalidCSum + } + } + if !tcpChecksumValid(pkt, item.iphLen, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. 
+ item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments uint8 = 0x20 +) + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 + maxUint16 = 1<<16 - 1 +) + +type tcpGROResult int + +const ( + tcpGROResultNoop tcpGROResult = iota + tcpGROResultTableInsert + tcpGROResultCoalesced +) + +// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a tcpGROResultNoop when no +// action was taken, tcpGROResultTableInsert when the evaluated packet was +// inserted into table, and tcpGROResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. 
+ return tcpGROResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return tcpGROResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return tcpGROResultNoop + } + } + if len(pkt) < iphLen { + return tcpGROResultNoop + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return tcpGROResultNoop + } + if len(pkt) < iphLen+tcphLen { + return tcpGROResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return tcpGROResultNoop + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return tcpGROResultNoop + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return tcpGROResultNoop + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return tcpGROResultNoop + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. 
+ // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. + item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return tcpGROResultCoalesced + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return tcpGROResultNoop + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return tcpGROResultTableInsert +} + +func isTCP4NoIPOptions(b []byte) bool { + if len(b) < 40 { + return false + } + if b[0]>>4 != 4 { + return false + } + if b[0]&0x0F != 5 { + return false + } + if b[9] != unix.IPPROTO_TCP { + return false + } + return true +} + +func isTCP6NoEH(b []byte) bool { + if len(b) < 60 { + return false + } + if b[0]>>4 != 6 { + return false + } + if b[6] != unix.IPPROTO_TCP { + return false + } + return true +} + +// applyCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. 
+func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { // finalize IP/TCP headers and the virtio header for every flow tracked in table
+	for _, items := range table.itemsByFlow {
+		for _, item := range items {
+			if item.numMerged > 0 { // this packet absorbed others: rewrite lengths/checksums and request GSO on write
+				hdr := virtioNetHdr{
+					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+					hdrLen:     uint16(item.iphLen + item.tcphLen),
+					gsoSize:    item.gsoSize,
+					csumStart:  uint16(item.iphLen),
+					csumOffset: 16, // offset of the checksum field within the TCP header
+				}
+				pkt := bufs[item.bufsIndex][offset:]
+
+				// Recalculate the total len (IPv4) or payload len (IPv6).
+				// Recalculate the (IPv4) header checksum.
+				if isV6 {
+					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+				} else {
+					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+					pkt[10], pkt[11] = 0, 0 // zero the IPv4 header checksum field before recomputing it
+					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+					iphCSum := ^checksumFold(pkt[:item.iphLen], 0)        // compute IPv4 header checksum
+					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
+				}
+				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+				if err != nil {
+					return err
+				}
+
+				// Calculate the pseudo header checksum and place it at the TCP
+				// checksum offset. Downstream checksum offloading will combine
+				// this with computation of the tcp header and payload checksum.
+				addrLen := 4
+				addrOffset := ipv4SrcAddrOffset
+				if isV6 {
+					addrLen = 16
+					addrOffset = ipv6SrcAddrOffset
+				}
+				srcAddrAt := offset + addrOffset
+				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] // dst address immediately follows src in both IPv4 and IPv6 headers
+				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksumFold([]byte{}, psum))
+			} else { // packet was not merged with any other: emit an all-zero (no-offload) virtio header
+				hdr := virtioNetHdr{}
+				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+				if err != nil {
+					return err
+				}
+			}
+		}
+	}
+	return nil
+}
+
+// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
+// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
+// empty (but non-nil), and are passed in to save allocs as the caller may reset
+// and recycle them across vectors of packets.
+func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
+	for i := range bufs {
+		if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { // need room for the virtio header in front and at least one packet byte
+			return errors.New("invalid offset")
+		}
+		var result tcpGROResult
+		switch {
+		case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
+			result = tcpGRO(bufs, offset, i, tcp4Table, false)
+		case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
+			result = tcpGRO(bufs, offset, i, tcp6Table, true)
+		}
+		switch result {
+		case tcpGROResultNoop: // not coalesced: give it a zero virtio header, then fall through to mark it for writing
+			hdr := virtioNetHdr{}
+			err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
+			if err != nil {
+				return err
+			}
+			fallthrough
+		case tcpGROResultTableInsert:
+			*toWrite = append(*toWrite, i)
+		}
+	}
+	err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
+	err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
+	return E.Errors(err4, err6)
+}
+
+// tcpTSO splits packets from in into outBuffs, writing the size of each
+// element into sizes. It returns the number of buffers populated, and/or an
+// error.
+func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
+	iphLen := int(hdr.csumStart) // csumStart doubles as the IP header length
+	srcAddrOffset := ipv6SrcAddrOffset
+	addrLen := 16
+	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+		in[10], in[11] = 0, 0 // clear ipv4 header checksum
+		srcAddrOffset = ipv4SrcAddrOffset
+		addrLen = 4
+	}
+	tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
+	in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
+	firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+	nextSegmentDataAt := int(hdr.hdrLen)
+	i := 0
+	for ; nextSegmentDataAt < len(in); i++ {
+		if i == len(outBuffs) {
+			return i - 1, ErrTooManySegments
+		}
+		nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
+		if nextSegmentEnd > len(in) {
+			nextSegmentEnd = len(in) // final segment may carry less than gsoSize bytes
+		}
+		segmentDataLen := nextSegmentEnd - nextSegmentDataAt
+		totalLen := int(hdr.hdrLen) + segmentDataLen
+		sizes[i] = totalLen
+		out := outBuffs[i][outOffset:]
+
+		copy(out, in[:iphLen])
+		if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+			// For IPv4 we are responsible for incrementing the ID field,
+			// updating the total len field, and recalculating the header
+			// checksum.
+			if i > 0 {
+				id := binary.BigEndian.Uint16(out[4:])
+				id += uint16(i)
+				binary.BigEndian.PutUint16(out[4:], id)
+			}
+			binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
+			ipv4CSum := ^checksumFold(out[:iphLen], 0)
+			binary.BigEndian.PutUint16(out[10:], ipv4CSum)
+		} else {
+			// For IPv6 we are responsible for updating the payload length field.
+			binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
+		}
+
+		// TCP header
+		copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
+		tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) // advance the sequence number by the payload already emitted
+		binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+		if nextSegmentEnd != len(in) {
+			// FIN and PSH should only be set on last segment
+			clearFlags := tcpFlagFIN | tcpFlagPSH
+			out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+		}
+
+		// payload
+		copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
+
+		// TCP checksum
+		tcpHLen := int(hdr.hdrLen - hdr.csumStart)
+		tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
+		tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
+		tcpCSum := ^checksumFold(out[hdr.csumStart:totalLen], tcpCSumNoFold)
+		binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
+
+		nextSegmentDataAt += int(hdr.gsoSize)
+	}
+	return i, nil
+}
+
+func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { // compute the checksum over in[cSumStart:] for a non-GSO packet and store it in place
+	cSumAt := cSumStart + cSumOffset
+	// The initial value at the checksum offset should be summed with the
+	// checksum we compute. This is typically the pseudo-header checksum.
+	initial := binary.BigEndian.Uint16(in[cSumAt:])
+	in[cSumAt], in[cSumAt+1] = 0, 0 // zero the field so it does not feed into its own sum
+	binary.BigEndian.PutUint16(in[cSumAt:], ^checksumFold(in[cSumStart:], uint64(initial)))
+	return nil
+}
+
+// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of bufs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
+	var hdr virtioNetHdr
+	err := hdr.decode(in)
+	if err != nil {
+		return 0, err
+	}
+	in = in[virtioNetHdrLen:] // step past the virtio header to the packet itself
+	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+		if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+			// This means CHECKSUM_PARTIAL in skb context. We are responsible
+			// for computing the checksum starting at hdr.csumStart and placing
+			// at hdr.csumOffset.
+			err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+			if err != nil {
+				return 0, err
+			}
+		}
+		if len(in) > len(bufs[0][offset:]) {
+			return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
+		}
+		n := copy(bufs[0][offset:], in)
+		sizes[0] = n
+		return 1, nil // no GSO: the single packet passes through unchanged
+	}
+	if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+		return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
+	}
+
+	ipVersion := in[0] >> 4
+	switch ipVersion { // the IP version on the wire must agree with the advertised GSO type
+	case 4:
+		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+		}
+	case 6:
+		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+		}
+	default:
+		return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+	}
+
+	if len(in) <= int(hdr.csumStart+12) { // need at least the TCP data-offset byte at csumStart+12
+		return 0, errors.New("packet is too short")
+	}
+	// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+	// of the entire first packet when the kernel is handling it as part of a
+	// FORWARD path. Instead, parse the TCP header length and add it onto
+	// csumStart, which is synonymous for IP header length.
+	tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+	if tcpHLen < 20 || tcpHLen > 60 {
+		// A TCP header must be between 20 and 60 bytes in length.
+		return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+	}
+	hdr.hdrLen = hdr.csumStart + tcpHLen
+
+	if len(in) < int(hdr.hdrLen) {
+		return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
+	}
+
+	if hdr.hdrLen < hdr.csumStart {
+		return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
+	}
+	cSumAt := int(hdr.csumStart + hdr.csumOffset)
+	if cSumAt+1 >= len(in) { // checksum field is 2 bytes; both must lie within the packet
+		return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+	}
+
+	return tcpTSO(in, hdr, bufs, sizes, offset)
+}
+
+func checksumNoFold(b []byte, initial uint64) uint64 { // sum of b added onto initial, carries not yet folded
+	return initial + uint64(clashtcpip.Sum(b))
+}
+
+func checksumFold(b []byte, initial uint64) uint16 { // fold the 64-bit running sum down to the 16-bit one's-complement form
+	ac := checksumNoFold(b, initial)
+	ac = (ac >> 16) + (ac & 0xffff)
+	ac = (ac >> 16) + (ac & 0xffff)
+	ac = (ac >> 16) + (ac & 0xffff)
+	ac = (ac >> 16) + (ac & 0xffff)
+	return uint16(ac)
+}
+
+func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { // unfolded checksum over the TCP pseudo header (addresses, protocol, length)
+	sum := checksumNoFold(srcAddr, 0)
+	sum = checksumNoFold(dstAddr, sum)
+	sum = checksumNoFold([]byte{0, protocol}, sum)
+	tmp := make([]byte, 2)
+	binary.BigEndian.PutUint16(tmp, totalLen)
+	return checksumNoFold(tmp, sum)
+}
diff --git a/tun_linux_offload_errors.go b/tun_linux_offload_errors.go
new file mode 100644
index 0000000..8e5db90
--- /dev/null
+++ b/tun_linux_offload_errors.go
@@ -0,0 +1,5 @@
+package tun
+
+import E "github.com/sagernet/sing/common/exceptions"
+
+var ErrTooManySegments = E.New("too many segments") // returned by tcpTSO when outBuffs has fewer elements than segments
diff --git a/tun_nonlinux.go b/tun_nonlinux.go
new file mode 100644
index 0000000..28ce640
--- /dev/null
+++ b/tun_nonlinux.go
@@ -0,0 +1,5 @@
+//go:build !linux
+
+package tun
+
+const OffloadOffset = 0 // no virtio header space is reserved on non-Linux platforms
diff --git a/tun_windows.go b/tun_windows.go
index 2028746..7e1a0c3 100644
--- a/tun_windows.go
+++ b/tun_windows.go
@@ -19,7 +19,6 @@ import (
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
-	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/windnsapi"
 
 	"golang.org/x/sys/windows"
@@ -454,10 +453,6 @@ func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
 	return 0, fmt.Errorf("write failed: %w", err)
 }
 
-func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
-	return t
-}
-
 func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
 	defer buf.ReleaseMulti(buffers)
 	return common.Error(t.write(buf.ToSliceMulti(buffers)))
diff --git a/tun_windows_gvisor.go b/tun_windows_gvisor.go
index 49ad60f..5bea8d7 100644
--- a/tun_windows_gvisor.go
+++ b/tun_windows_gvisor.go
@@ -35,7 +35,7 @@ func (e *WintunEndpoint) LinkAddress() tcpip.LinkAddress {
 }
 
 func (e *WintunEndpoint) Capabilities() stack.LinkEndpointCapabilities {
-	return stack.CapabilityNone
+	return stack.CapabilityRXChecksumOffload // NOTE(review): tells gVisor inbound checksums are pre-verified — confirm upstream packets really are checksummed
 }
 
 func (e *WintunEndpoint) Attach(dispatcher stack.NetworkDispatcher) {