Update BatchTUN API for WireGuard

This commit is contained in:
世界 2023-12-15 12:18:45 +08:00
parent 0e138754d5
commit 3195f6f4a2
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 31 additions and 35 deletions

View file

@ -145,15 +145,13 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := m.tun.FrontHeadroom()
packetBuffers := make([][]byte, batchSize)
readBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset)
readBuffers[i] = packetBuffers[i][frontHeadroom:]
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(readBuffers, packetSizes)
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
@ -169,13 +167,13 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize]
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
if m.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers)
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}

View file

@ -198,15 +198,13 @@ func (s *System) wintunLoop(winTun WinTun) {
func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := s.tun.FrontHeadroom()
packetBuffers := make([][]byte, batchSize)
readBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset)
readBuffers[i] = packetBuffers[i][frontHeadroom:]
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(readBuffers, packetSizes)
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
@ -222,13 +220,13 @@ func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize]
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
if s.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers)
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}

4
tun.go
View file

@ -36,8 +36,8 @@ type WinTun interface {
type BatchTUN interface {
Tun
BatchSize() int
BatchRead(buffers [][]byte, readN []int) (n int, err error)
BatchWrite(buffers [][]byte) error
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
BatchWrite(buffers [][]byte, offset int) error
}
type Options struct {

View file

@ -35,6 +35,8 @@ type NativeTun struct {
ruleIndex6 []int
gsoEnabled bool
gsoBuffer []byte
gsoToWrite []int
gsoReadAccess sync.Mutex
tcpGROAccess sync.Mutex
tcp4GROTable *tcpGROTable
tcp6GROTable *tcpGROTable
@ -105,7 +107,7 @@ func (t *NativeTun) Read(p []byte) (n int, err error) {
func (t *NativeTun) Write(p []byte) (n int, err error) {
if t.gsoEnabled {
err = t.BatchWrite([][]byte{p})
err = t.BatchWrite([][]byte{p}, virtioNetHdrLen)
if err != nil {
return
}
@ -140,37 +142,31 @@ func (t *NativeTun) BatchSize() int {
return batchSize
}
func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) {
if t.gsoEnabled {
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
t.gsoReadAccess.Lock()
defer t.gsoReadAccess.Unlock()
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
n, err = handleVirtioRead(t.gsoBuffer[:n], buffers, readN, 0)
if err != nil {
return
}
return
} else {
return 0, os.ErrInvalid
}
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
}
func (t *NativeTun) BatchWrite(buffers [][]byte) error {
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
t.tcpGROAccess.Lock()
defer func() {
t.tcp4GROTable.reset()
t.tcp6GROTable.reset()
t.tcpGROAccess.Unlock()
}()
var toWrite []int
err := handleGRO(buffers, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite)
t.gsoToWrite = t.gsoToWrite[:0]
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
if err != nil {
return err
}
for _, bufferIndex := range toWrite {
_, err = t.tunFile.Write(buffers[bufferIndex])
offset -= virtioNetHdrLen
for _, bufferIndex := range t.gsoToWrite {
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
if err != nil {
return err
}

View file

@ -750,8 +750,12 @@ func checksumNoFold(b []byte, initial uint64) uint64 {
}
func checksumFold(b []byte, initial uint64) uint16 {
r := clashtcpip.Checksum(uint32(initial), b)
return binary.BigEndian.Uint16(r[:])
ac := checksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {