mirror of
https://github.com/SagerNet/sing-tun.git
synced 2025-04-06 05:17:39 +03:00
Update BatchTUN API for WireGuard
This commit is contained in:
parent
0e138754d5
commit
3195f6f4a2
5 changed files with 31 additions and 35 deletions
|
@ -145,15 +145,13 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
|
|||
func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
|
||||
frontHeadroom := m.tun.FrontHeadroom()
|
||||
packetBuffers := make([][]byte, batchSize)
|
||||
readBuffers := make([][]byte, batchSize)
|
||||
writeBuffers := make([][]byte, batchSize)
|
||||
packetSizes := make([]int, batchSize)
|
||||
for i := range packetBuffers {
|
||||
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset)
|
||||
readBuffers[i] = packetBuffers[i][frontHeadroom:]
|
||||
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom)
|
||||
}
|
||||
for {
|
||||
n, err := linuxTUN.BatchRead(readBuffers, packetSizes)
|
||||
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
|
||||
if err != nil {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
|
@ -169,13 +167,13 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
|
|||
continue
|
||||
}
|
||||
packetBuffer := packetBuffers[i]
|
||||
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize]
|
||||
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
|
||||
if m.processPacket(packet) {
|
||||
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
|
||||
}
|
||||
}
|
||||
if len(writeBuffers) > 0 {
|
||||
err = linuxTUN.BatchWrite(writeBuffers)
|
||||
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
|
||||
if err != nil {
|
||||
m.logger.Trace(E.Cause(err, "batch write packet"))
|
||||
}
|
||||
|
|
|
@ -198,15 +198,13 @@ func (s *System) wintunLoop(winTun WinTun) {
|
|||
func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
|
||||
frontHeadroom := s.tun.FrontHeadroom()
|
||||
packetBuffers := make([][]byte, batchSize)
|
||||
readBuffers := make([][]byte, batchSize)
|
||||
writeBuffers := make([][]byte, batchSize)
|
||||
packetSizes := make([]int, batchSize)
|
||||
for i := range packetBuffers {
|
||||
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset)
|
||||
readBuffers[i] = packetBuffers[i][frontHeadroom:]
|
||||
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom)
|
||||
}
|
||||
for {
|
||||
n, err := linuxTUN.BatchRead(readBuffers, packetSizes)
|
||||
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
|
||||
if err != nil {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
|
@ -222,13 +220,13 @@ func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
|
|||
continue
|
||||
}
|
||||
packetBuffer := packetBuffers[i]
|
||||
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize]
|
||||
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
|
||||
if s.processPacket(packet) {
|
||||
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
|
||||
}
|
||||
}
|
||||
if len(writeBuffers) > 0 {
|
||||
err = linuxTUN.BatchWrite(writeBuffers)
|
||||
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
|
||||
if err != nil {
|
||||
s.logger.Trace(E.Cause(err, "batch write packet"))
|
||||
}
|
||||
|
|
4
tun.go
4
tun.go
|
@ -36,8 +36,8 @@ type WinTun interface {
|
|||
type BatchTUN interface {
|
||||
Tun
|
||||
BatchSize() int
|
||||
BatchRead(buffers [][]byte, readN []int) (n int, err error)
|
||||
BatchWrite(buffers [][]byte) error
|
||||
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
|
||||
BatchWrite(buffers [][]byte, offset int) error
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
|
|
30
tun_linux.go
30
tun_linux.go
|
@ -35,6 +35,8 @@ type NativeTun struct {
|
|||
ruleIndex6 []int
|
||||
gsoEnabled bool
|
||||
gsoBuffer []byte
|
||||
gsoToWrite []int
|
||||
gsoReadAccess sync.Mutex
|
||||
tcpGROAccess sync.Mutex
|
||||
tcp4GROTable *tcpGROTable
|
||||
tcp6GROTable *tcpGROTable
|
||||
|
@ -105,7 +107,7 @@ func (t *NativeTun) Read(p []byte) (n int, err error) {
|
|||
|
||||
func (t *NativeTun) Write(p []byte) (n int, err error) {
|
||||
if t.gsoEnabled {
|
||||
err = t.BatchWrite([][]byte{p})
|
||||
err = t.BatchWrite([][]byte{p}, virtioNetHdrLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -140,37 +142,31 @@ func (t *NativeTun) BatchSize() int {
|
|||
return batchSize
|
||||
}
|
||||
|
||||
func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) {
|
||||
if t.gsoEnabled {
|
||||
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
|
||||
t.gsoReadAccess.Lock()
|
||||
defer t.gsoReadAccess.Unlock()
|
||||
n, err = t.tunFile.Read(t.gsoBuffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n, err = handleVirtioRead(t.gsoBuffer[:n], buffers, readN, 0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
} else {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
|
||||
}
|
||||
|
||||
func (t *NativeTun) BatchWrite(buffers [][]byte) error {
|
||||
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
|
||||
t.tcpGROAccess.Lock()
|
||||
defer func() {
|
||||
t.tcp4GROTable.reset()
|
||||
t.tcp6GROTable.reset()
|
||||
t.tcpGROAccess.Unlock()
|
||||
}()
|
||||
var toWrite []int
|
||||
err := handleGRO(buffers, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite)
|
||||
t.gsoToWrite = t.gsoToWrite[:0]
|
||||
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, bufferIndex := range toWrite {
|
||||
_, err = t.tunFile.Write(buffers[bufferIndex])
|
||||
offset -= virtioNetHdrLen
|
||||
for _, bufferIndex := range t.gsoToWrite {
|
||||
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -750,8 +750,12 @@ func checksumNoFold(b []byte, initial uint64) uint64 {
|
|||
}
|
||||
|
||||
func checksumFold(b []byte, initial uint64) uint16 {
|
||||
r := clashtcpip.Checksum(uint32(initial), b)
|
||||
return binary.BigEndian.Uint16(r[:])
|
||||
ac := checksumNoFold(b, initial)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
return uint16(ac)
|
||||
}
|
||||
|
||||
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue