Add GSO support

世界 2023-12-10 00:00:14 +08:00
parent fa89d2c0a5
commit 5b50c61b72
22 changed files with 1376 additions and 449 deletions

go.mod (3 changed lines)

@ -5,10 +5,9 @@ go 1.18
require (
github.com/fsnotify/fsnotify v1.7.0
github.com/go-ole/go-ole v1.3.0
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61
github.com/sagernet/gvisor v0.0.0-20231209105102-8d27a30e436e
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97
github.com/sagernet/sing v0.2.20-0.20231211084415-35e7014b0898
github.com/sagernet/sing v0.3.0-rc.2
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/net v0.19.0

go.sum (14 changed lines)

@ -1,20 +1,22 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA=
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h1:QUQ4RRHD6hGGHdFMEtR8T2P6GS6R3D/CXKdaYHKKXms=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/sagernet/gvisor v0.0.0-20231209105102-8d27a30e436e h1:DOkjByVeAR56dkszjnMZke4wr7yM/1xHaJF3G9olkEE=
github.com/sagernet/gvisor v0.0.0-20231209105102-8d27a30e436e/go.mod h1:fLxq/gtp0qzkaEwywlRRiGmjOK5ES/xUzyIKIFP2Asw=
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE=
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
github.com/sagernet/sing v0.2.20-0.20231211084415-35e7014b0898 h1:ZR0wpw4/0NCICOX10SIUW8jpPVV7+D98nGA6p4zWICo=
github.com/sagernet/sing v0.2.20-0.20231211084415-35e7014b0898/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
github.com/sagernet/sing v0.3.0-beta.3 h1:E2xBoJUducK/FE6EwMk95Rt2bkXeht9l1BTYRui+DXs=
github.com/sagernet/sing v0.3.0-beta.3/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g=
github.com/sagernet/sing v0.3.0-rc.2 h1:l5rq+bTrNhpAPd2Vjzi/sEhil4O6Bb1CKv6LdPLJKug=
github.com/sagernet/sing v0.3.0-rc.2/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
@ -22,9 +24,9 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/W
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View file

@ -50,6 +50,10 @@ func (p TCPPacket) SetChecksum(sum [2]byte) {
p[17] = sum[1]
}
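// OffloadChecksum zeroes the TCP checksum field, leaving checksum computation
// to the kernel or device when TX checksum offload is enabled.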
func (p TCPPacket) OffloadChecksum() {
p.SetChecksum(zeroChecksum)
}
func (p TCPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(psum, p))

View file

@ -45,6 +45,10 @@ func (p UDPPacket) SetChecksum(sum [2]byte) {
p[7] = sum[1]
}
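// OffloadChecksum zeroes the UDP checksum field, leaving checksum computation
// to the kernel or device when TX checksum offload is enabled.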
func (p UDPPacket) OffloadChecksum() {
p.SetChecksum(zeroChecksum)
}
func (p UDPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(psum, p))

View file

@ -19,10 +19,7 @@ type Stack interface {
type StackOptions struct {
Context context.Context
Tun Tun
Name string
MTU uint32
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
TunOptions Options
EndpointIndependentNat bool
UDPTimeout int64
Handler Handler
@ -37,7 +34,7 @@ func NewStack(
) (Stack, error) {
switch stack {
case "":
if WithGVisor {
if WithGVisor && !options.TunOptions.GSO {
return NewMixed(options)
} else {
return NewSystem(options)
@ -48,8 +45,6 @@ func NewStack(
return NewMixed(options)
case "system":
return NewSystem(options)
case "lwip":
return NewLWIP(options)
default:
return nil, E.New("unknown stack: ", stack)
}

View file

@ -31,7 +31,6 @@ const defaultNIC tcpip.NICID = 1
type GVisor struct {
ctx context.Context
tun GVisorTun
tunMtu uint32
endpointIndependentNat bool
udpTimeout int64
broadcastAddr netip.Addr
@ -57,10 +56,9 @@ func NewGVisor(
gStack := &GVisor{
ctx: options.Context,
tun: gTun,
tunMtu: options.MTU,
endpointIndependentNat: options.EndpointIndependentNat,
udpTimeout: options.UDPTimeout,
broadcastAddr: BroadcastAddr(options.Inet4Address),
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
handler: options.Handler,
logger: options.Logger,
}
@ -72,7 +70,7 @@ func (t *GVisor) Start() error {
if err != nil {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun.CreateVectorisedWriter()}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := newGVisorStack(linkEndpoint)
if err != nil {
return err

View file

@ -82,7 +82,6 @@ type UDPBackWriter struct {
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
packet stack.PacketBufferPtr
}
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
@ -149,12 +148,6 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
return nil
}
type gRequest struct {
stack *stack.Stack
id stack.TransportEndpointID
pkt stack.PacketBufferPtr
}
type gUDPConn struct {
*gonet.UDPConn
}

View file

@ -1,144 +0,0 @@
//go:build with_lwip
package tun
import (
"context"
"net"
"net/netip"
"os"
lwip "github.com/sagernet/go-tun2socks/core"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
)
type LWIP struct {
ctx context.Context
tun Tun
tunMtu uint32
udpTimeout int64
handler Handler
stack lwip.LWIPStack
udpNat *udpnat.Service[netip.AddrPort]
}
func NewLWIP(
options StackOptions,
) (Stack, error) {
return &LWIP{
ctx: options.Context,
tun: options.Tun,
tunMtu: options.MTU,
handler: options.Handler,
stack: lwip.NewLWIPStack(),
udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler),
}, nil
}
func (l *LWIP) Start() error {
lwip.RegisterTCPConnHandler(l)
lwip.RegisterUDPConnHandler(l)
lwip.RegisterOutputFn(l.tun.Write)
go l.loopIn()
return nil
}
func (l *LWIP) loopIn() {
if winTun, isWintun := l.tun.(WinTun); isWintun {
l.loopInWintun(winTun)
return
}
buffer := make([]byte, int(l.tunMtu)+PacketOffset)
for {
n, err := l.tun.Read(buffer)
if err != nil {
return
}
_, err = l.stack.Write(buffer[PacketOffset:n])
if err != nil {
if err.Error() == "stack closed" {
return
}
l.handler.NewError(context.Background(), err)
}
}
}
func (l *LWIP) loopInWintun(tun WinTun) {
for {
packet, release, err := tun.ReadPacket()
if err != nil {
return
}
_, err = l.stack.Write(packet)
release()
if err != nil {
if err.Error() == "stack closed" {
return
}
l.handler.NewError(context.Background(), err)
}
}
}
func (l *LWIP) Close() error {
lwip.RegisterTCPConnHandler(nil)
lwip.RegisterUDPConnHandler(nil)
lwip.RegisterOutputFn(func(bytes []byte) (int, error) {
return 0, os.ErrClosed
})
return l.stack.Close()
}
func (l *LWIP) Handle(conn net.Conn) error {
lAddr := conn.LocalAddr()
rAddr := conn.RemoteAddr()
if lAddr == nil || rAddr == nil {
conn.Close()
return nil
}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
hErr := l.handler.NewConnection(l.ctx, conn, metadata)
if hErr != nil {
conn.(lwip.TCPConn).Abort()
}
}()
return nil
}
func (l *LWIP) ReceiveTo(conn lwip.UDPConn, data []byte, addr M.Socksaddr) error {
var upstreamMetadata M.Metadata
upstreamMetadata.Source = conn.LocalAddr()
upstreamMetadata.Destination = addr
l.udpNat.NewPacket(
l.ctx,
upstreamMetadata.Source.AddrPort(),
buf.As(data).ToOwned(),
upstreamMetadata,
func(natConn N.PacketConn) N.PacketWriter {
return &LWIPUDPBackWriter{conn}
},
)
return nil
}
type LWIPUDPBackWriter struct {
conn lwip.UDPConn
}
func (w *LWIPUDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
return common.Error(w.conn.WriteFrom(buffer.Bytes(), destination))
}
func (w *LWIPUDPBackWriter) Close() error {
return w.conn.Close()
}

View file

@ -1,11 +0,0 @@
//go:build !with_lwip
package tun
import E "github.com/sagernet/sing/common/exceptions"
func NewLWIP(
options StackOptions,
) (Stack, error) {
return nil, E.New(`LWIP is not included in this build, rebuild with -tags with_lwip`)
}

View file

@ -13,17 +13,14 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing-tun/internal/clashtcpip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type Mixed struct {
*System
writer N.VectorisedWriter
endpointIndependentNat bool
stack *stack.Stack
endpoint *channel.Endpoint
@ -38,7 +35,6 @@ func NewMixed(
}
return &Mixed{
System: system.(*System),
writer: options.Tun.CreateVectorisedWriter(),
endpointIndependentNat: options.EndpointIndependentNat,
}, nil
}
@ -48,7 +44,7 @@ func (m *Mixed) Start() error {
if err != nil {
return err
}
endpoint := channel.New(1024, m.mtu, "")
endpoint := channel.New(1024, uint32(m.mtu), "")
ipStack, err := newGVisorStack(endpoint)
if err != nil {
return err
@ -95,26 +91,34 @@ func (m *Mixed) tunLoop() {
m.wintunLoop(winTun)
return
}
if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN {
m.frontHeadroom = linuxTUN.FrontHeadroom()
m.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
m.batchLoop(linuxTUN, batchSize)
return
}
}
packetBuffer := make([]byte, m.mtu+PacketOffset)
for {
n, err := m.tun.Read(packetBuffer)
if err != nil {
return
if E.IsClosed(err) {
return
}
m.logger.Error(E.Cause(err, "read packet"))
}
if n < clashtcpip.IPv4PacketMinLength {
continue
}
rawPacket := packetBuffer[:n]
packet := packetBuffer[PacketOffset:n]
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(packet)
case 6:
err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
if m.processPacket(packet) {
_, err = m.tun.Write(rawPacket)
if err != nil {
m.logger.Trace(E.Cause(err, "write packet"))
}
}
}
}
@ -129,62 +133,119 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
release()
continue
}
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(packet)
case 6:
err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
if m.processPacket(packet) {
_, err = winTun.Write(packet)
if err != nil {
m.logger.Trace(E.Cause(err, "write packet"))
}
}
release()
}
}
func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error {
destination := packet.DestinationIP()
if destination == m.broadcastAddr || !destination.IsGlobalUnicast() {
return common.Error(m.tun.Write(packet))
func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom)
}
switch packet.Protocol() {
case clashtcpip.TCP:
return m.processIPv4TCP(packet, packet.Payload())
case clashtcpip.UDP:
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
pkt.DecRef()
return nil
case clashtcpip.ICMP:
return m.processIPv4ICMP(packet, packet.Payload())
default:
return common.Error(m.tun.Write(packet))
for {
n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
}
m.logger.Error(E.Cause(err, "batch read packet"))
}
if n == 0 {
continue
}
for i := 0; i < n; i++ {
packetSize := packetSizes[i]
if packetSize < clashtcpip.IPv4PacketMinLength {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize]
if m.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
writeBuffers = writeBuffers[:0]
}
}
}
func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error {
if !packet.DestinationIP().IsGlobalUnicast() {
return common.Error(m.tun.Write(packet))
func (m *Mixed) processPacket(packet []byte) bool {
var (
writeBack bool
err error
)
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
writeBack, err = m.processIPv4(packet)
case 6:
writeBack, err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
return false
}
return writeBack
}
func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) {
writeBack = true
destination := packet.DestinationIP()
if destination == m.broadcastAddr || !destination.IsGlobalUnicast() {
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
return m.processIPv6TCP(packet, packet.Payload())
err = m.processIPv4TCP(packet, packet.Payload())
case clashtcpip.UDP:
writeBack = false
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
Payload: buffer.MakeWithData(packet),
IsForwardedPacket: true,
})
m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
pkt.DecRef()
return
case clashtcpip.ICMP:
err = m.processIPv4ICMP(packet, packet.Payload())
}
return
}
func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) {
writeBack = true
if !packet.DestinationIP().IsGlobalUnicast() {
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
err = m.processIPv6TCP(packet, packet.Payload())
case clashtcpip.UDP:
writeBack = false
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
IsForwardedPacket: true,
})
m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
pkt.DecRef()
return nil
case clashtcpip.ICMPv6:
return m.processIPv6ICMP(packet, packet.Payload())
default:
return common.Error(m.tun.Write(packet))
err = m.processIPv6ICMP(packet, packet.Payload())
}
return
}
func (m *Mixed) packetLoop() {
@ -193,7 +254,7 @@ func (m *Mixed) packetLoop() {
if packet == nil {
break
}
bufio.WriteVectorised(m.writer, packet.AsSlices())
bufio.WriteVectorised(m.tun, packet.AsSlices())
packet.DecRef()
}
}

View file

@ -22,7 +22,7 @@ type System struct {
ctx context.Context
tun Tun
tunName string
mtu uint32
mtu int
handler Handler
logger logger.Logger
inet4Prefixes []netip.Prefix
@ -41,6 +41,8 @@ type System struct {
udpNat *udpnat.Service[netip.AddrPort]
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
txChecksumOffload bool
}
type Session struct {
@ -54,29 +56,29 @@ func NewSystem(options StackOptions) (Stack, error) {
stack := &System{
ctx: options.Context,
tun: options.Tun,
tunName: options.Name,
mtu: options.MTU,
tunName: options.TunOptions.Name,
mtu: int(options.TunOptions.MTU),
udpTimeout: options.UDPTimeout,
handler: options.Handler,
logger: options.Logger,
inet4Prefixes: options.Inet4Address,
inet6Prefixes: options.Inet6Address,
broadcastAddr: BroadcastAddr(options.Inet4Address),
inet4Prefixes: options.TunOptions.Inet4Address,
inet6Prefixes: options.TunOptions.Inet6Address,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
}
if len(options.Inet4Address) > 0 {
if options.Inet4Address[0].Bits() == 32 {
if len(options.TunOptions.Inet4Address) > 0 {
if options.TunOptions.Inet4Address[0].Bits() == 32 {
return nil, E.New("need one more IPv4 address in first prefix for system stack")
}
stack.inet4ServerAddress = options.Inet4Address[0].Addr()
stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr()
stack.inet4Address = stack.inet4ServerAddress.Next()
}
if len(options.Inet6Address) > 0 {
if options.Inet6Address[0].Bits() == 128 {
if len(options.TunOptions.Inet6Address) > 0 {
if options.TunOptions.Inet6Address[0].Bits() == 128 {
return nil, E.New("need one more IPv6 address in first prefix for system stack")
}
stack.inet6ServerAddress = options.Inet6Address[0].Addr()
stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr()
stack.inet6Address = stack.inet6ServerAddress.Next()
}
if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() {
@ -144,26 +146,34 @@ func (s *System) tunLoop() {
s.wintunLoop(winTun)
return
}
if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN {
s.frontHeadroom = linuxTUN.FrontHeadroom()
s.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
s.batchLoop(linuxTUN, batchSize)
return
}
}
packetBuffer := make([]byte, s.mtu+PacketOffset)
for {
n, err := s.tun.Read(packetBuffer)
if err != nil {
return
if E.IsClosed(err) {
return
}
s.logger.Error(E.Cause(err, "read packet"))
}
if n < clashtcpip.IPv4PacketMinLength {
continue
}
rawPacket := packetBuffer[:n]
packet := packetBuffer[PacketOffset:n]
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = s.processIPv4(packet)
case 6:
err = s.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
s.logger.Trace(err)
if s.processPacket(packet) {
_, err = s.tun.Write(rawPacket)
if err != nil {
s.logger.Trace(E.Cause(err, "write packet"))
}
}
}
}
@ -178,21 +188,75 @@ func (s *System) wintunLoop(winTun WinTun) {
release()
continue
}
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = s.processIPv4(packet)
case 6:
err = s.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
s.logger.Trace(err)
if s.processPacket(packet) {
_, err = winTun.Write(packet)
if err != nil {
s.logger.Trace(E.Cause(err, "write packet"))
}
}
release()
}
}
func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
}
s.logger.Error(E.Cause(err, "batch read packet"))
}
if n == 0 {
continue
}
for i := 0; i < n; i++ {
packetSize := packetSizes[i]
if packetSize < clashtcpip.IPv4PacketMinLength {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize]
if s.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}
writeBuffers = writeBuffers[:0]
}
}
}
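// processPacket dispatches an inbound IP packet by version and reports whether
// the (possibly rewritten) packet should be written back to the TUN device.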
func (s *System) processPacket(packet []byte) bool {
var (
writeBack bool
err error
)
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
writeBack, err = s.processIPv4(packet)
case 6:
writeBack, err = s.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
s.logger.Trace(err)
return false
}
return writeBack
}
func (s *System) acceptLoop(listener net.Listener) {
for {
conn, err := listener.Accept()
@ -234,44 +298,46 @@ func (s *System) acceptLoop(listener net.Listener) {
}
}
func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) {
writeBack = true
destination := packet.DestinationIP()
if destination == s.broadcastAddr || !destination.IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv4TCP(packet, packet.Payload())
err = s.processIPv4TCP(packet, packet.Payload())
case clashtcpip.UDP:
return s.processIPv4UDP(packet, packet.Payload())
writeBack = false
err = s.processIPv4UDP(packet, packet.Payload())
case clashtcpip.ICMP:
return s.processIPv4ICMP(packet, packet.Payload())
default:
return common.Error(s.tun.Write(packet))
err = s.processIPv4ICMP(packet, packet.Payload())
}
return
}
func (s *System) processIPv6(packet clashtcpip.IPv6Packet) error {
func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) {
writeBack = true
if !packet.DestinationIP().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv6TCP(packet, packet.Payload())
err = s.processIPv6TCP(packet, packet.Payload())
case clashtcpip.UDP:
return s.processIPv6UDP(packet, packet.Payload())
writeBack = false
err = s.processIPv6UDP(packet, packet.Payload())
case clashtcpip.ICMPv6:
return s.processIPv6ICMP(packet, packet.Payload())
default:
return common.Error(s.tun.Write(packet))
err = s.processIPv6ICMP(packet, packet.Payload())
}
return
}
func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error {
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
return nil
} else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
@ -288,16 +354,21 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.
packet.SetDestinationIP(s.inet4ServerAddress)
header.SetDestinationPort(s.tcpPort)
}
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
if !s.txChecksumOffload {
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
} else {
header.OffloadChecksum()
packet.ResetChecksum()
}
return nil
}
func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error {
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
return nil
} else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
@ -314,9 +385,12 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.
packet.SetDestinationIP(s.inet6ServerAddress)
header.SetDestinationPort(s.tcpPort6)
}
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
if !s.txChecksumOffload {
header.ResetChecksum(packet.PseudoSum())
} else {
header.OffloadChecksum()
}
return nil
}
func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error {
@ -332,7 +406,7 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
return nil
}
data := buf.As(header.Payload())
if data.Len() == 0 {
@ -346,7 +420,13 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter4{s.tun, headerCopy, source}
return &systemUDPPacketWriter4{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source,
s.txChecksumOffload,
}
})
return nil
}
@ -358,7 +438,7 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
return nil
}
data := buf.As(header.Payload())
if data.Len() == 0 {
@ -372,7 +452,13 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.
headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter6{s.tun, headerCopy, source}
return &systemUDPPacketWriter6{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source,
s.txChecksumOffload,
}
})
return nil
}
@ -387,7 +473,7 @@ func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip
packet.SetDestinationIP(sourceAddress)
header.ResetChecksum()
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
return nil
}
func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error {
@ -400,102 +486,21 @@ func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip
packet.SetDestinationIP(sourceAddress)
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
}
type systemTCPDirectPacketWriter4 struct {
tun Tun
source netip.AddrPort
}
func (w *systemTCPDirectPacketWriter4) WritePacket(p []byte) error {
packet := clashtcpip.IPv4Packet(p)
header := clashtcpip.TCPPacket(packet.Payload())
packet.SetDestinationIP(w.source.Addr())
header.SetDestinationPort(w.source.Port())
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemTCPDirectPacketWriter6 struct {
tun Tun
source netip.AddrPort
}
func (w *systemTCPDirectPacketWriter6) WritePacket(p []byte) error {
packet := clashtcpip.IPv6Packet(p)
header := clashtcpip.TCPPacket(packet.Payload())
packet.SetDestinationIP(w.source.Addr())
header.SetDestinationPort(w.source.Port())
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemUDPDirectPacketWriter4 struct {
tun Tun
source netip.AddrPort
}
func (w *systemUDPDirectPacketWriter4) WritePacket(p []byte) error {
packet := clashtcpip.IPv4Packet(p)
header := clashtcpip.UDPPacket(packet.Payload())
packet.SetDestinationIP(w.source.Addr())
header.SetDestinationPort(w.source.Port())
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemUDPDirectPacketWriter6 struct {
tun Tun
source netip.AddrPort
}
func (w *systemUDPDirectPacketWriter6) WritePacket(p []byte) error {
packet := clashtcpip.IPv6Packet(p)
header := clashtcpip.UDPPacket(packet.Payload())
packet.SetDestinationIP(w.source.Addr())
header.SetDestinationPort(w.source.Port())
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemICMPDirectPacketWriter4 struct {
tun Tun
source netip.Addr
}
func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error {
packet := clashtcpip.IPv4Packet(p)
packet.SetDestinationIP(w.source)
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemICMPDirectPacketWriter6 struct {
tun Tun
source netip.Addr
}
func (w *systemICMPDirectPacketWriter6) WritePacket(p []byte) error {
packet := clashtcpip.IPv6Packet(p)
packet.SetDestinationIP(w.source)
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
return nil
}
type systemUDPPacketWriter4 struct {
tun Tun
header []byte
source netip.AddrPort
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
txChecksumOffload bool
}
func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.NewSize(len(w.header) + buffer.Len())
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())
ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes())
@ -506,20 +511,33 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize))
udpHdr.ResetChecksum(ipHdr.PseudoSum())
ipHdr.ResetChecksum()
if !w.txChecksumOffload {
udpHdr.ResetChecksum(ipHdr.PseudoSum())
ipHdr.ResetChecksum()
} else {
udpHdr.OffloadChecksum()
ipHdr.ResetChecksum()
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}
type systemUDPPacketWriter6 struct {
tun Tun
header []byte
source netip.AddrPort
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
txChecksumOffload bool
}
func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.NewSize(len(w.header) + buffer.Len())
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())
ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes())
@ -531,6 +549,15 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(udpLen)
udpHdr.ResetChecksum(ipHdr.PseudoSum())
if !w.txChecksumOffload {
udpHdr.ResetChecksum(ipHdr.PseudoSum())
} else {
udpHdr.OffloadChecksum()
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}

tun.go (15 changed lines)

@ -23,7 +23,7 @@ type Handler interface {
type Tun interface {
io.ReadWriter
CreateVectorisedWriter() N.VectorisedWriter
N.VectorisedWriter
Close() error
}
@ -32,11 +32,21 @@ type WinTun interface {
ReadPacket() ([]byte, func(), error)
}
type LinuxTUN interface {
Tun
N.FrontHeadroom
BatchSize() int
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
BatchWrite(buffers [][]byte, offset int) error
TXChecksumOffload() bool
}
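// Illustrative sketch (not part of this commit): driving the batch methods of a
// LinuxTUN directly, mirroring the batchLoop added to the system and mixed
// stacks in this commit. The process callback is a placeholder that reports
// whether a packet should be written back to the device.
func exampleBatchLoop(device LinuxTUN, mtu int, process func(packet []byte) bool) error {
	headroom := device.FrontHeadroom()
	buffers := make([][]byte, device.BatchSize())
	sizes := make([]int, len(buffers))
	for i := range buffers {
		// Each read buffer reserves the front headroom (the virtio-net header on Linux).
		buffers[i] = make([]byte, headroom+mtu)
	}
	writeBuffers := make([][]byte, 0, len(buffers))
	for {
		n, err := device.BatchRead(buffers, headroom, sizes)
		if err != nil {
			return err
		}
		writeBuffers = writeBuffers[:0]
		for i := 0; i < n; i++ {
			packet := buffers[i][headroom : headroom+sizes[i]]
			if process(packet) {
				writeBuffers = append(writeBuffers, buffers[i][:headroom+sizes[i]])
			}
		}
		if len(writeBuffers) > 0 {
			if err = device.BatchWrite(writeBuffers, headroom); err != nil {
				return err
			}
		}
	}
}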
type Options struct {
Name string
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
MTU uint32
GSO bool
AutoRoute bool
StrictRoute bool
Inet4RouteAddress []netip.Prefix
@ -54,6 +64,9 @@ type Options struct {
TableIndex int
FileDescriptor int
Logger logger.Logger
// Does not work for TCP, do not use.
_TXChecksumOffload bool
}
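// Illustrative sketch (not part of this commit): requesting GSO when creating a
// device. All field values are placeholders; on Linux, GSO maps to IFF_VNET_HDR
// plus TUNSETOFFLOAD (see the Linux TUN changes in this commit).
func exampleNewGSO() (Tun, error) {
	return New(Options{
		Name:         "tun0",
		MTU:          1500,
		GSO:          true,
		Inet4Address: []netip.Prefix{netip.MustParsePrefix("172.19.0.1/30")},
	})
}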
func CalculateInterfaceName(name string) (tunName string) {

View file

@ -5,7 +5,6 @@ import (
"net"
"net/netip"
"os"
"runtime"
"syscall"
"unsafe"
@ -68,44 +67,22 @@ func New(options Options) (Tun, error) {
if !ok {
panic("create vectorised writer")
}
runtime.SetFinalizer(nativeTun.tunFile, nil)
return nativeTun, nil
}
func (t *NativeTun) Read(p []byte) (n int, err error) {
/*n, err = t.tunFile.Read(p)
if n < 4 {
return 0, err
}
copy(p[:], p[4:])
return n - 4, err*/
return t.tunFile.Read(p)
}
func (t *NativeTun) Write(p []byte) (n int, err error) {
return t.tunFile.Write(p)
}
var (
packetHeader4 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET}
packetHeader6 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET6}
)
func (t *NativeTun) Write(p []byte) (n int, err error) {
var packetHeader []byte
if p[0]>>4 == 4 {
packetHeader = packetHeader4[:]
} else {
packetHeader = packetHeader6[:]
}
_, err = bufio.WriteVectorised(t.tunWriter, [][]byte{packetHeader, p})
if err == nil {
n = len(p)
}
return
}
func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
return t
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
var packetHeader []byte
if buffers[0].Byte(0)>>4 == 4 {

View file

@ -36,7 +36,7 @@ func (e *DarwinEndpoint) LinkAddress() tcpip.LinkAddress {
}
func (e *DarwinEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
return stack.CapabilityRXChecksumOffload
}
func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
@ -51,13 +51,13 @@ func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}
func (e *DarwinEndpoint) dispatchLoop() {
packetBuffer := make([]byte, e.tun.mtu+4)
packetBuffer := make([]byte, e.tun.mtu+PacketOffset)
for {
n, err := e.tun.tunFile.Read(packetBuffer)
if err != nil {
break
}
packet := packetBuffer[4:n]
packet := packetBuffer[PacketOffset:n]
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(packet) {
case header.IPv4Version:
@ -112,14 +112,7 @@ func (e *DarwinEndpoint) ParseHeader(ptr stack.PacketBufferPtr) bool {
func (e *DarwinEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) {
var n int
for _, packet := range packetBufferList.AsSlice() {
var packetHeader []byte
switch packet.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
packetHeader = packetHeader4[:]
case header.IPv6ProtocolNumber:
packetHeader = packetHeader6[:]
}
_, err := bufio.WriteVectorised(e.tun.tunWriter, append([][]byte{packetHeader}, packet.AsSlices()...))
_, err := bufio.WriteVectorised(e.tun, packet.AsSlices())
if err != nil {
return n, &tcpip.ErrAborted{}
}

View file

@ -7,11 +7,13 @@ import (
"os"
"os/exec"
"runtime"
"sync"
"syscall"
"unsafe"
"github.com/sagernet/netlink"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
@ -22,17 +24,29 @@ import (
"golang.org/x/sys/unix"
)
var _ LinuxTUN = (*NativeTun)(nil)
type NativeTun struct {
tunFd int
tunFile *os.File
tunWriter N.VectorisedWriter
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
options Options
ruleIndex6 []int
gsoEnabled bool
gsoBuffer []byte
gsoToWrite []int
gsoReadAccess sync.Mutex
tcpGROAccess sync.Mutex
tcp4GROTable *tcpGROTable
tcp6GROTable *tcpGROTable
txChecksumOffload bool
}
func New(options Options) (Tun, error) {
var nativeTun *NativeTun
if options.FileDescriptor == 0 {
tunFd, err := open(options.Name)
tunFd, err := open(options.Name, options.GSO)
if err != nil {
return nil, err
}
@ -40,38 +54,125 @@ func New(options Options) (Tun, error) {
if err != nil {
return nil, E.Errors(err, unix.Close(tunFd))
}
nativeTun := &NativeTun{
nativeTun = &NativeTun{
tunFd: tunFd,
tunFile: os.NewFile(uintptr(tunFd), "tun"),
options: options,
}
runtime.SetFinalizer(nativeTun.tunFile, nil)
err = nativeTun.configure(tunLink)
if err != nil {
return nil, E.Errors(err, unix.Close(tunFd))
}
return nativeTun, nil
} else {
nativeTun := &NativeTun{
nativeTun = &NativeTun{
tunFd: options.FileDescriptor,
tunFile: os.NewFile(uintptr(options.FileDescriptor), "tun"),
options: options,
}
runtime.SetFinalizer(nativeTun.tunFile, nil)
return nativeTun, nil
}
var ok bool
nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile)
if !ok {
panic("create vectorised writer")
}
return nativeTun, nil
}
func (t *NativeTun) FrontHeadroom() int {
if t.gsoEnabled {
return virtioNetHdrLen
}
return 0
}
func (t *NativeTun) Read(p []byte) (n int, err error) {
return t.tunFile.Read(p)
if t.gsoEnabled {
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
var sizes [1]int
n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0)
if err != nil {
return
}
if n == 0 {
return
}
n = sizes[0]
return
} else {
return t.tunFile.Read(p)
}
}
func (t *NativeTun) Write(p []byte) (n int, err error) {
if t.gsoEnabled {
err = t.BatchWrite([][]byte{p}, virtioNetHdrLen)
if err != nil {
return
}
n = len(p)
return
}
return t.tunFile.Write(p)
}
func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
return bufio.NewVectorisedWriter(t.tunFile)
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
if t.gsoEnabled {
n := buf.LenMulti(buffers)
buffer := buf.NewSize(virtioNetHdrLen + n)
buffer.Truncate(virtioNetHdrLen)
buf.CopyMulti(buffer.Extend(n), buffers)
_, err := t.tunFile.Write(buffer.Bytes())
buffer.Release()
return err
} else {
return t.tunWriter.WriteVectorised(buffers)
}
}
func (t *NativeTun) BatchSize() int {
if !t.gsoEnabled {
return 1
}
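// With GSO, scale the batch to roughly twice the number of MTU-sized segments
// that fit in one 64 KiB read; e.g. MTU 1500: 65536/1500 = 43, doubled to 86,
// then capped at idealBatchSize (128).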
batchSize := int(gsoMaxSize/t.options.MTU) * 2
if batchSize > idealBatchSize {
batchSize = idealBatchSize
}
return batchSize
}
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
t.gsoReadAccess.Lock()
defer t.gsoReadAccess.Unlock()
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
}
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
t.tcpGROAccess.Lock()
defer func() {
t.tcp4GROTable.reset()
t.tcp6GROTable.reset()
t.tcpGROAccess.Unlock()
}()
t.gsoToWrite = t.gsoToWrite[:0]
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
if err != nil {
return err
}
offset -= virtioNetHdrLen
for _, bufferIndex := range t.gsoToWrite {
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
if err != nil {
return err
}
}
return nil
}
var controlPath string
@ -86,7 +187,7 @@ func init() {
}
}
func open(name string) (int, error) {
func open(name string, vnetHdr bool) (int, error) {
fd, err := unix.Open(controlPath, unix.O_RDWR, 0)
if err != nil {
return -1, err
@ -100,6 +201,9 @@ func open(name string) (int, error) {
copy(ifr.name[:], name)
ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI
if vnetHdr {
ifr.flags |= unix.IFF_VNET_HDR
}
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.TUNSETIFF, uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
unix.Close(fd)
@ -142,6 +246,46 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
}
}
if t.options.GSO {
var vnetHdrEnabled bool
vnetHdrEnabled, err = checkVNETHDREnabled(t.tunFd, t.options.Name)
if err != nil {
return E.Cause(err, "enable offload: check IFF_VNET_HDR enabled")
}
if !vnetHdrEnabled {
return E.Cause(err, "enable offload: IFF_VNET_HDR not enabled")
}
err = setTCPOffload(t.tunFd)
if err != nil {
return err
}
t.gsoEnabled = true
t.gsoBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize))
t.tcp4GROTable = newTCPGROTable()
t.tcp6GROTable = newTCPGROTable()
}
var rxChecksumOffload bool
rxChecksumOffload, err = checkChecksumOffload(t.options.Name, unix.ETHTOOL_GRXCSUM)
if err == nil && !rxChecksumOffload {
_ = setChecksumOffload(t.options.Name, unix.ETHTOOL_SRXCSUM)
}
if t.options._TXChecksumOffload {
var txChecksumOffload bool
txChecksumOffload, err = checkChecksumOffload(t.options.Name, unix.ETHTOOL_GTXCSUM)
if err != nil {
return err
}
if !txChecksumOffload {
err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM)
if err != nil {
return err
}
}
t.txChecksumOffload = true
}
err = netlink.LinkSetUp(tunLink)
if err != nil {
return err
@ -188,6 +332,10 @@ func (t *NativeTun) Close() error {
return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile)))
}
func (t *NativeTun) TXChecksumOffload() bool {
return t.txChecksumOffload
}
func prefixToIPNet(prefix netip.Prefix) *net.IPNet {
return &net.IPNet{
IP: prefix.Addr().AsSlice(),

tun_linux_flags.go (new file, 84 lines)

@ -0,0 +1,84 @@
//go:build linux
package tun
import (
"os"
"syscall"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
func checkVNETHDREnabled(fd int, name string) (bool, error) {
ifr, err := unix.NewIfreq(name)
if err != nil {
return false, err
}
err = unix.IoctlIfreq(fd, unix.TUNGETIFF, ifr)
if err != nil {
return false, os.NewSyscallError("TUNGETIFF", err)
}
return ifr.Uint16()&unix.IFF_VNET_HDR != 0, nil
}
func setTCPOffload(fd int) error {
const (
// TODO: support TSO with ECN bits
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
)
err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunOffloads)
if err != nil {
return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload")
}
return nil
}
type ifreqData struct {
ifrName [unix.IFNAMSIZ]byte
ifrData uintptr
}
type ethtoolValue struct {
cmd uint32
data uint32
}
//go:linkname ioctlPtr golang.org/x/sys/unix.ioctlPtr
func ioctlPtr(fd int, req uint, arg unsafe.Pointer) (err error)
func checkChecksumOffload(name string, cmd uint32) (bool, error) {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return false, err
}
defer syscall.Close(fd)
ifr := ifreqData{}
copy(ifr.ifrName[:], name)
data := ethtoolValue{cmd: cmd}
ifr.ifrData = uintptr(unsafe.Pointer(&data))
err = ioctlPtr(fd, unix.SIOCETHTOOL, unsafe.Pointer(&ifr))
if err != nil {
return false, os.NewSyscallError("SIOCETHTOOL ETHTOOL_GTXCSUM", err)
}
return data.data == 0, nil
}
func setChecksumOffload(name string, cmd uint32) error {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(fd)
ifr := ifreqData{}
copy(ifr.ifrName[:], name)
data := ethtoolValue{cmd: cmd, data: 0}
ifr.ifrData = uintptr(unsafe.Pointer(&data))
err = ioctlPtr(fd, unix.SIOCETHTOOL, unsafe.Pointer(&ifr))
if err != nil {
return os.NewSyscallError("SIOCETHTOOL ETHTOOL_STXCSUM", err)
}
return nil
}

View file

@ -10,8 +10,19 @@ import (
var _ GVisorTun = (*NativeTun)(nil)
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
if t.gsoEnabled {
return fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
GSOMaxSize: gsoMaxSize,
RXChecksumOffload: true,
TXChecksumOffload: t.txChecksumOffload,
})
}
return fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
FDs: []int{t.tunFd},
MTU: t.options.MTU,
RXChecksumOffload: true,
TXChecksumOffload: t.txChecksumOffload,
})
}

tun_linux_offload.go (new file, 768 lines)

@ -0,0 +1,768 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"unsafe"
"github.com/sagernet/sing-tun/internal/clashtcpip"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
const (
gsoMaxSize = 65536
tcpFlagsOffset = 13
idealBatchSize = 128
)
const (
tcpFlagFIN uint8 = 0x01
tcpFlagPSH uint8 = 0x08
tcpFlagACK uint8 = 0x10
)
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr.
type virtioNetHdr struct {
flags uint8
gsoType uint8
hdrLen uint16
gsoSize uint16
csumStart uint16
csumOffset uint16
}
func (v *virtioNetHdr) decode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
return nil
}
func (v *virtioNetHdr) encode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
return nil
}
const (
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
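// Illustrative sketch (not part of this commit): the header written in front of
// a coalesced IPv4/TCP packet before it is handed to the kernel, mirroring
// applyCoalesceAccounting below. iphLen, tcphLen and gsoSize are assumed to
// describe the coalesced packet; buffer must start at the virtio-net header.
func exampleVirtioHeader(buffer []byte, iphLen, tcphLen uint8, gsoSize uint16) error {
	hdr := virtioNetHdr{
		flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // checksum is completed later (CHECKSUM_PARTIAL)
		gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,    // kernel re-segments as IPv4/TCP
		hdrLen:     uint16(iphLen) + uint16(tcphLen),
		gsoSize:    gsoSize,        // payload bytes per resulting segment
		csumStart:  uint16(iphLen), // checksum covers the TCP header and payload
		csumOffset: 16,             // offset of the checksum field inside the TCP header
	}
	return hdr.encode(buffer)
}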
// flowKey represents the key for a flow.
type flowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
}
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
type tcpGROTable struct {
itemsByFlow map[flowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
itemsByFlow: make(map[flowKey][]tcpGROItem, idealBatchSize),
itemsPool: make([][]tcpGROItem, idealBatchSize),
}
for i := range t.itemsPool {
t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize)
}
return t
}
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
key := flowKey{}
addrSize := dstAddr - srcAddr
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
iphLen: uint8(tcphOffset),
tcphLen: uint8(tcphLen),
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
}
items, ok := t.itemsByFlow[key]
if !ok {
items = t.newItems()
}
items = append(items, item)
t.itemsByFlow[key] = items
}
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items, _ := t.itemsByFlow[item.key]
items[i] = item
}
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
}
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
key flowKey
sentSeq uint32 // the sequence number
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
tcphLen uint8 // tcp header len
pshSet bool // psh flag is set
}
func (t *tcpGROTable) newItems() []tcpGROItem {
var items []tcpGROItem
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
return items
}
func (t *tcpGROTable) reset() {
for k, items := range t.itemsByFlow {
items = items[:0]
t.itemsPool = append(t.itemsPool, items)
delete(t.itemsByFlow, k)
}
}
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
const (
coalescePrepend canCoalesce = -1
coalesceUnavailable canCoalesce = 0
coalesceAppend canCoalesce = 1
)
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if tcphLen != item.tcphLen {
// cannot coalesce with unequal tcp options len
return coalesceUnavailable
}
if tcphLen > 20 {
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
// cannot coalesce with unequal tcp options
return coalesceUnavailable
}
}
if pkt[0]>>4 == 6 {
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
// cannot coalesce with unequal Traffic class values
return coalesceUnavailable
}
if pkt[7] != pktTarget[7] {
// cannot coalesce with unequal Hop limit values
return coalesceUnavailable
}
} else {
if pkt[1] != pktTarget[1] {
// cannot coalesce with unequal ToS values
return coalesceUnavailable
}
if pkt[6]>>5 != pktTarget[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return coalesceUnavailable
}
if pkt[8] != pktTarget[8] {
// cannot coalesce with unequal TTL values
return coalesceUnavailable
}
}
// seq adjacency
lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
if item.pshSet {
// We cannot append to a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
if pshSet {
// We cannot prepend with a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if gsoSize < item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
if gsoSize > item.gsoSize && item.numMerged > 0 {
// There's at least one previous merge, and we're larger than all
// previous. This would put multiple smaller packets on the end.
return coalesceUnavailable
}
return coalescePrepend
}
return coalesceUnavailable
}
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
tcpTotalLen := uint16(len(pkt) - int(iphLen))
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
return ^checksumFold(pkt[iphLen:], tcpCSumNoFold) == 0
}
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int
const (
coalesceInsufficientCap coalesceResult = iota
coalescePSHEnding
coalesceItemInvalidCSum
coalescePktInvalidCSum
coalesceSuccess
)
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, returning the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
var pktHead []byte // the packet that will end up at the front
headersLen := item.iphLen + item.tcphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
// Copy data
if mode == coalescePrepend {
pktHead = pkt
if cap(pkt)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if pshSet {
return coalescePSHEnding
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
extendBy := coalescedLen - len(pktHead)
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
// Flip the slice headers in bufs as part of prepend. The index of item
// is already being tracked for writing.
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
} else {
pktHead = bufs[item.bufsIndex][bufsOffset:]
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
// We are appending a segment with PSH set.
item.pshSet = pshSet
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
}
if gsoSize > item.gsoSize {
item.gsoSize = gsoSize
}
item.numMerged++
return coalesceSuccess
}
const (
ipv4FlagMoreFragments uint8 = 0x20
)
const (
ipv4SrcAddrOffset = 12
ipv6SrcAddrOffset = 8
maxUint16 = 1<<16 - 1
)
type tcpGROResult int
const (
tcpGROResultNoop tcpGROResult = iota
tcpGROResultTableInsert
tcpGROResultCoalesced
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a tcpGROResultNoop when no
// action was taken, tcpGROResultTableInsert when the evaluated packet was
// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return tcpGROResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return tcpGROResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return tcpGROResultNoop
}
}
if len(pkt) < iphLen {
return tcpGROResultNoop
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return tcpGROResultNoop
}
if len(pkt) < iphLen+tcphLen {
return tcpGROResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return tcpGROResultNoop
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
var pshSet bool
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return tcpGROResultNoop
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return tcpGROResultNoop
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return tcpGROResultNoop
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case, where packets arrive in order, iterating in reverse is
// more efficient when there are multiple items for a given flow. It also
// enables a natural table.deleteAt() in the coalesceItemInvalidCSum case
// without the need for index tracking. This algorithm makes a best effort to
// coalesce in the event of unordered packets, where pkt may land anywhere in
// items from a sequence number perspective; however, once an item is inserted
// into the table it is never compared against other items later.
item := items[i]
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
if can != coalesceUnavailable {
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return tcpGROResultCoalesced
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return tcpGROResultNoop
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return tcpGROResultTableInsert
}
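// isTCP4NoIPOptions reports whether b is an IPv4 TCP packet with a 20-byte
// IP header (IHL == 5), i.e. without IP options.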
func isTCP4NoIPOptions(b []byte) bool {
if len(b) < 40 {
return false
}
if b[0]>>4 != 4 {
return false
}
if b[0]&0x0F != 5 {
return false
}
if b[9] != unix.IPPROTO_TCP {
return false
}
return true
}
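// isTCP6NoEH reports whether b is an IPv6 TCP packet whose Next Header is
// TCP, i.e. without extension headers.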
func isTCP6NoEH(b []byte) bool {
if len(b) < 60 {
return false
}
if b[0]>>4 != 6 {
return false
}
if b[6] != unix.IPPROTO_TCP {
return false
}
return true
}
// applyCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
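// Coalesced packets get a virtio-net header describing the GSO layout so
// the kernel can resegment and checksum them; csumOffset 16 is the offset
// of the checksum field within the TCP header.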
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(item.iphLen + item.tcphLen),
gsoSize: item.gsoSize,
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
pkt := bufs[item.bufsIndex][offset:]
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
if isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pkt[10], pkt[11] = 0, 0
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
iphCSum := ^checksumFold(pkt[:item.iphLen], 0) // compute IPv4 header checksum
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
// Calculate the pseudo header checksum and place it at the TCP
// checksum offset. Downstream checksum offloading will combine
// this with computation of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := offset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksumFold([]byte{}, psum))
} else {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
}
}
}
return nil
}
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets.
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
for i := range bufs {
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
var result tcpGROResult
switch {
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
result = tcpGRO(bufs, offset, i, tcp4Table, false)
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
result = tcpGRO(bufs, offset, i, tcp6Table, true)
}
switch result {
case tcpGROResultNoop:
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
fallthrough
case tcpGROResultTableInsert:
*toWrite = append(*toWrite, i)
}
}
err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
return E.Errors(err4, err6)
}
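For orientation, a minimal sketch of how a batched write path could drive handleGRO before handing packets to the device. writeToDevice and the reset() helper on tcpGROTable are assumptions made for illustration; handleGRO and virtioNetHdrLen are the names defined in this file.
// Sketch only: writeToDevice and tcpGROTable.reset() are assumed helpers.
func writeBatchGRO(writeToDevice func([]byte) error, bufs [][]byte, offset int, tcp4, tcp6 *tcpGROTable, toWrite *[]int) error {
	*toWrite = (*toWrite)[:0]
	tcp4.reset() // tables must start empty (but non-nil), per handleGRO's contract
	tcp6.reset()
	if err := handleGRO(bufs, offset, tcp4, tcp6, toWrite); err != nil {
		return err
	}
	// Only packets that survived coalescing are written; each carries the
	// virtio-net header encoded at offset-virtioNetHdrLen.
	for _, i := range *toWrite {
		if err := writeToDevice(bufs[i][offset-virtioNetHdrLen:]); err != nil {
			return err
		}
	}
	return nil
}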
// tcpTSO splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
iphLen := int(hdr.csumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
in[10], in[11] = 0, 0 // clear ipv4 header checksum
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
nextSegmentDataAt := int(hdr.hdrLen)
i := 0
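// Walk the payload in gsoSize-sized chunks, emitting one segment per chunk.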
for ; nextSegmentDataAt < len(in); i++ {
if i == len(outBuffs) {
return i - 1, ErrTooManySegments
}
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
if nextSegmentEnd > len(in) {
nextSegmentEnd = len(in)
}
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
totalLen := int(hdr.hdrLen) + segmentDataLen
sizes[i] = totalLen
out := outBuffs[i][outOffset:]
copy(out, in[:iphLen])
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
if i > 0 {
id := binary.BigEndian.Uint16(out[4:])
id += uint16(i)
binary.BigEndian.PutUint16(out[4:], id)
}
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
ipv4CSum := ^checksumFold(out[:iphLen], 0)
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
} else {
// For IPv6 we are responsible for updating the payload length field.
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
// TCP header
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
if nextSegmentEnd != len(in) {
// FIN and PSH should only be set on last segment
clearFlags := tcpFlagFIN | tcpFlagPSH
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
}
// payload
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
// TCP checksum
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
tcpCSum := ^checksumFold(out[hdr.csumStart:totalLen], tcpCSumNoFold)
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
nextSegmentDataAt += int(hdr.gsoSize)
}
return i, nil
}
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
cSumAt := cSumStart + cSumOffset
// The initial value at the checksum offset should be summed with the
// checksum we compute. This is typically the pseudo-header checksum.
initial := binary.BigEndian.Uint16(in[cSumAt:])
in[cSumAt], in[cSumAt+1] = 0, 0
binary.BigEndian.PutUint16(in[cSumAt:], ^checksumFold(in[cSumStart:], uint64(initial)))
return nil
}
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of bufs,
// and returns the number of packets read.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
var hdr virtioNetHdr
err := hdr.decode(in)
if err != nil {
return 0, err
}
in = in[virtioNetHdrLen:]
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
// This means CHECKSUM_PARTIAL in skb context. We are responsible
// for computing the checksum starting at hdr.csumStart and placing
// it at hdr.csumStart+hdr.csumOffset.
err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
if err != nil {
return 0, err
}
}
if len(in) > len(bufs[0][offset:]) {
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
}
n := copy(bufs[0][offset:], in)
sizes[0] = n
return 1, nil
}
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
case 6:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
if len(in) <= int(hdr.csumStart+12) {
return 0, errors.New("packet is too short")
}
// Don't trust hdr.hdrLen from the kernel, as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the TCP header length and add it to
// csumStart, which is synonymous with the IP header length.
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
if tcpHLen < 20 || tcpHLen > 60 {
// A TCP header must be between 20 and 60 bytes in length.
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
hdr.hdrLen = hdr.csumStart + tcpHLen
if len(in) < int(hdr.hdrLen) {
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
}
if hdr.hdrLen < hdr.csumStart {
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
}
cSumAt := int(hdr.csumStart + hdr.csumOffset)
if cSumAt+1 >= len(in) {
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
return tcpTSO(in, hdr, bufs, sizes, offset)
}
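// checksumNoFold accumulates the Internet checksum sum of b on top of
// initial without folding it down to 16 bits.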
func checksumNoFold(b []byte, initial uint64) uint64 {
return initial + uint64(clashtcpip.Sum(b))
}
func checksumFold(b []byte, initial uint64) uint16 {
ac := checksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
sum := checksumNoFold(srcAddr, 0)
sum = checksumNoFold(dstAddr, sum)
sum = checksumNoFold([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumNoFold(tmp, sum)
}
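As a worked illustration of how these helpers compose (mirroring the TSO path above): the pseudo-header sum seeds the fold over the TCP header and payload, and the one's complement of the folded value is what gets written into the checksum field. tcpChecksum and tcpSegment below are illustrative names, not part of this change.
// Sketch: srcAddr and dstAddr are 4-byte (IPv4) or 16-byte (IPv6) addresses;
// tcpSegment is the TCP header plus payload with its checksum field zeroed.
func tcpChecksum(srcAddr, dstAddr, tcpSegment []byte) uint16 {
	psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(tcpSegment)))
	return ^checksumFold(tcpSegment, psum)
}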


@ -0,0 +1,5 @@
package tun
import E "github.com/sagernet/sing/common/exceptions"
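// ErrTooManySegments is returned by tcpTSO when splitting a GSO packet would
// produce more segments than there are output buffers.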
var ErrTooManySegments = E.New("too many segments")

5
tun_nonlinux.go Normal file

@ -0,0 +1,5 @@
//go:build !linux
package tun
const OffloadOffset = 0


@ -19,7 +19,6 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/windnsapi"
"golang.org/x/sys/windows"
@ -454,10 +453,6 @@ func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
return 0, fmt.Errorf("write failed: %w", err)
}
func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
return t
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
defer buf.ReleaseMulti(buffers)
return common.Error(t.write(buf.ToSliceMulti(buffers)))


@ -35,7 +35,7 @@ func (e *WintunEndpoint) LinkAddress() tcpip.LinkAddress {
}
func (e *WintunEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
return stack.CapabilityRXChecksumOffload
}
func (e *WintunEndpoint) Attach(dispatcher stack.NetworkDispatcher) {