mirror of
https://github.com/SagerNet/sing-tun.git
synced 2025-04-05 04:47:41 +03:00
Compare commits
25 commits
v0.6.0-alp
...
dev
Author | SHA1 | Date | |
---|---|---|---|
|
35b5747b44 | ||
|
5cb6d27288 | ||
|
9105485a50 | ||
|
57aba1a5c4 | ||
|
7f3343169a | ||
|
22b811f938 | ||
|
618be14c7b | ||
|
c8c2984261 | ||
|
8cc5351bb3 | ||
|
d093b82064 | ||
|
aa9d9c6296 | ||
|
f457988090 | ||
|
d0887eabba | ||
|
edabb6d7ba | ||
|
d38f9adaef | ||
|
618d3f9a52 | ||
|
c21c623174 | ||
|
091b5da950 | ||
|
59a6bdc1fa | ||
|
c177abb523 | ||
|
6ef42f019b | ||
|
c1f61d08ba | ||
|
2b8115e83b | ||
|
06b4d4ecd1 | ||
|
4ebeb2fa86 |
40 changed files with 4244 additions and 1543 deletions
1
Makefile
1
Makefile
|
@ -29,4 +29,5 @@ lint_install:
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go build -v .
|
go build -v .
|
||||||
|
go test -bench=. ./internal/checksum_test
|
||||||
#go test -v .
|
#go test -v .
|
||||||
|
|
4
go.mod
4
go.mod
|
@ -6,10 +6,10 @@ require (
|
||||||
github.com/go-ole/go-ole v1.3.0
|
github.com/go-ole/go-ole v1.3.0
|
||||||
github.com/google/btree v1.1.3
|
github.com/google/btree v1.1.3
|
||||||
github.com/sagernet/fswatch v0.1.1
|
github.com/sagernet/fswatch v0.1.1
|
||||||
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3
|
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
|
||||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
|
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
|
||||||
github.com/sagernet/nftables v0.3.0-beta.4
|
github.com/sagernet/nftables v0.3.0-beta.4
|
||||||
github.com/sagernet/sing v0.6.0-alpha.11
|
github.com/sagernet/sing v0.6.0-beta.2
|
||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
|
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
|
||||||
golang.org/x/net v0.26.0
|
golang.org/x/net v0.26.0
|
||||||
|
|
8
go.sum
8
go.sum
|
@ -16,14 +16,14 @@ github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8Ku
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
|
github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
|
||||||
github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
|
github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
|
||||||
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 h1:RxEz7LhPNiF/gX/Hg+OXr5lqsM9iVAgmaK1L1vzlDRM=
|
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs=
|
||||||
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
|
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
|
||||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
|
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
|
||||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
|
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
|
||||||
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
|
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
|
||||||
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
||||||
github.com/sagernet/sing v0.6.0-alpha.11 h1:ZcZlA0/vdDeiipAbjK73x9VabGJ/RRcAJgWhOo/OoBk=
|
github.com/sagernet/sing v0.6.0-beta.2 h1:Dcutp3kxrsZes9q3oTiHQhYYjQvDn5rwp1OI9fDLYwQ=
|
||||||
github.com/sagernet/sing v0.6.0-alpha.11/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
github.com/sagernet/sing v0.6.0-beta.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
|
|
33
internal/checksum_test/sum_bench_test.go
Normal file
33
internal/checksum_test/sum_bench_test.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package checksum_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||||
|
"github.com/sagernet/sing-tun/internal/tschecksum"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkTsChecksum(b *testing.B) {
|
||||||
|
packet := make([][]byte, 1000)
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
packet[i] = make([]byte, 1500)
|
||||||
|
rand.Read(packet[i])
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tschecksum.Checksum(packet[i%1000], 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGChecksum(b *testing.B) {
|
||||||
|
packet := make([][]byte, 1000)
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
packet[i] = make([]byte, 1500)
|
||||||
|
rand.Read(packet[i])
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
checksum.ChecksumDefault(packet[i%1000], 0)
|
||||||
|
}
|
||||||
|
}
|
|
@ -30,34 +30,6 @@ func Put(b []byte, xsum uint16) {
|
||||||
binary.BigEndian.PutUint16(b, xsum)
|
binary.BigEndian.PutUint16(b, xsum)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
|
|
||||||
// given byte array. This function uses an optimized version of the checksum
|
|
||||||
// algorithm.
|
|
||||||
//
|
|
||||||
// The initial checksum must have been computed on an even number of bytes.
|
|
||||||
func Checksum(buf []byte, initial uint16) uint16 {
|
|
||||||
s, _ := calculateChecksum(buf, false, initial)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checksumer calculates checksum defined in RFC 1071.
|
|
||||||
type Checksumer struct {
|
|
||||||
sum uint16
|
|
||||||
odd bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds b to checksum.
|
|
||||||
func (c *Checksumer) Add(b []byte) {
|
|
||||||
if len(b) > 0 {
|
|
||||||
c.sum, c.odd = calculateChecksum(b, c.odd, c.sum)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checksum returns the latest checksum value.
|
|
||||||
func (c *Checksumer) Checksum() uint16 {
|
|
||||||
return c.sum
|
|
||||||
}
|
|
||||||
|
|
||||||
// Combine combines the two uint16 to form their checksum. This is done
|
// Combine combines the two uint16 to form their checksum. This is done
|
||||||
// by adding them and the carry.
|
// by adding them and the carry.
|
||||||
//
|
//
|
||||||
|
@ -66,3 +38,8 @@ func Combine(a, b uint16) uint16 {
|
||||||
v := uint32(a) + uint32(b)
|
v := uint32(a) + uint32(b)
|
||||||
return uint16(v + v>>16)
|
return uint16(v + v>>16)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ChecksumDefault(buf []byte, initial uint16) uint16 {
|
||||||
|
s, _ := calculateChecksum(buf, false, initial)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
12
internal/gtcpip/checksum/checksum_default.go
Normal file
12
internal/gtcpip/checksum/checksum_default.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
//go:build !amd64
|
||||||
|
|
||||||
|
package checksum
|
||||||
|
|
||||||
|
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
|
||||||
|
// given byte array. This function uses an optimized version of the checksum
|
||||||
|
// algorithm.
|
||||||
|
//
|
||||||
|
// The initial checksum must have been computed on an even number of bytes.
|
||||||
|
func Checksum(buf []byte, initial uint16) uint16 {
|
||||||
|
return ChecksumDefault(buf, initial)
|
||||||
|
}
|
9
internal/gtcpip/checksum/checksum_ts.go
Normal file
9
internal/gtcpip/checksum/checksum_ts.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
//go:build amd64
|
||||||
|
|
||||||
|
package checksum
|
||||||
|
|
||||||
|
import "github.com/sagernet/sing-tun/internal/tschecksum"
|
||||||
|
|
||||||
|
func Checksum(buf []byte, initial uint16) uint16 {
|
||||||
|
return tschecksum.Checksum(buf, initial)
|
||||||
|
}
|
136
internal/gtcpip/header/interfaces.go
Normal file
136
internal/gtcpip/header/interfaces.go
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
// Copyright 2018 The gVisor Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package header
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
tcpip "github.com/sagernet/sing-tun/internal/gtcpip"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MaxIPPacketSize is the maximum supported IP packet size, excluding
|
||||||
|
// jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
|
||||||
|
// in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
|
||||||
|
// 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
|
||||||
|
// m is the minimum IPv6 header size; we leave room for some potential
|
||||||
|
// IP options.
|
||||||
|
MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transport offers generic methods to query and/or update the fields of the
|
||||||
|
// header of a transport protocol buffer.
|
||||||
|
type Transport interface {
|
||||||
|
// SourcePort returns the value of the "source port" field.
|
||||||
|
SourcePort() uint16
|
||||||
|
|
||||||
|
// DestinationPort returns the value of the "destination port" field.
|
||||||
|
DestinationPort() uint16
|
||||||
|
|
||||||
|
// Checksum returns the value of the "checksum" field.
|
||||||
|
Checksum() uint16
|
||||||
|
|
||||||
|
// SetSourcePort sets the value of the "source port" field.
|
||||||
|
SetSourcePort(uint16)
|
||||||
|
|
||||||
|
// SetDestinationPort sets the value of the "destination port" field.
|
||||||
|
SetDestinationPort(uint16)
|
||||||
|
|
||||||
|
// SetChecksum sets the value of the "checksum" field.
|
||||||
|
SetChecksum(uint16)
|
||||||
|
|
||||||
|
// Payload returns the data carried in the transport buffer.
|
||||||
|
Payload() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChecksummableTransport is a Transport that supports checksumming.
|
||||||
|
type ChecksummableTransport interface {
|
||||||
|
Transport
|
||||||
|
|
||||||
|
// SetSourcePortWithChecksumUpdate sets the source port and updates
|
||||||
|
// the checksum.
|
||||||
|
//
|
||||||
|
// The receiver's checksum must be a fully calculated checksum.
|
||||||
|
SetSourcePortWithChecksumUpdate(port uint16)
|
||||||
|
|
||||||
|
// SetDestinationPortWithChecksumUpdate sets the destination port and updates
|
||||||
|
// the checksum.
|
||||||
|
//
|
||||||
|
// The receiver's checksum must be a fully calculated checksum.
|
||||||
|
SetDestinationPortWithChecksumUpdate(port uint16)
|
||||||
|
|
||||||
|
// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
|
||||||
|
// updated address in the pseudo header.
|
||||||
|
//
|
||||||
|
// If fullChecksum is true, the receiver's checksum field is assumed to hold a
|
||||||
|
// fully calculated checksum. Otherwise, it is assumed to hold a partially
|
||||||
|
// calculated checksum which only reflects the pseudo header.
|
||||||
|
UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Network offers generic methods to query and/or update the fields of the
|
||||||
|
// header of a network protocol buffer.
|
||||||
|
type Network interface {
|
||||||
|
// SourceAddress returns the value of the "source address" field.
|
||||||
|
SourceAddress() tcpip.Address
|
||||||
|
|
||||||
|
// DestinationAddress returns the value of the "destination address"
|
||||||
|
// field.
|
||||||
|
DestinationAddress() tcpip.Address
|
||||||
|
|
||||||
|
DestinationAddr() netip.Addr
|
||||||
|
|
||||||
|
// Checksum returns the value of the "checksum" field.
|
||||||
|
Checksum() uint16
|
||||||
|
|
||||||
|
// SetSourceAddress sets the value of the "source address" field.
|
||||||
|
SetSourceAddress(tcpip.Address)
|
||||||
|
|
||||||
|
// SetDestinationAddress sets the value of the "destination address"
|
||||||
|
// field.
|
||||||
|
SetDestinationAddress(tcpip.Address)
|
||||||
|
|
||||||
|
SetDestinationAddr(addr netip.Addr)
|
||||||
|
|
||||||
|
// SetChecksum sets the value of the "checksum" field.
|
||||||
|
SetChecksum(uint16)
|
||||||
|
|
||||||
|
// TransportProtocol returns the number of the transport protocol
|
||||||
|
// stored in the payload.
|
||||||
|
TransportProtocol() tcpip.TransportProtocolNumber
|
||||||
|
|
||||||
|
// Payload returns a byte slice containing the payload of the network
|
||||||
|
// packet.
|
||||||
|
Payload() []byte
|
||||||
|
|
||||||
|
// TOS returns the values of the "type of service" and "flow label" fields.
|
||||||
|
TOS() (uint8, uint32)
|
||||||
|
|
||||||
|
// SetTOS sets the values of the "type of service" and "flow label" fields.
|
||||||
|
SetTOS(t uint8, l uint32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChecksummableNetwork is a Network that supports checksumming.
|
||||||
|
type ChecksummableNetwork interface {
|
||||||
|
Network
|
||||||
|
|
||||||
|
// SetSourceAddressWithChecksumUpdate sets the source address and updates the
|
||||||
|
// checksum to reflect the new address.
|
||||||
|
SetSourceAddressWithChecksumUpdate(tcpip.Address)
|
||||||
|
|
||||||
|
// SetDestinationAddressWithChecksumUpdate sets the destination address and
|
||||||
|
// updates the checksum to reflect the new address.
|
||||||
|
SetDestinationAddressWithChecksumUpdate(tcpip.Address)
|
||||||
|
}
|
|
@ -18,10 +18,8 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/sagernet/gvisor/pkg/buffer"
|
|
||||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
)
|
)
|
||||||
|
@ -145,79 +143,6 @@ func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) i
|
||||||
return ((padLen + align - 1) & ^(align - 1)) - padLen
|
return ((padLen + align - 1) & ^(align - 1)) - padLen
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPv6PayloadHeader is implemented by the various headers that can be found
|
|
||||||
// in an IPv6 payload.
|
|
||||||
//
|
|
||||||
// These headers include IPv6 extension headers or upper layer data.
|
|
||||||
type IPv6PayloadHeader interface {
|
|
||||||
isIPv6PayloadHeader()
|
|
||||||
|
|
||||||
// Release frees all resources held by the header.
|
|
||||||
Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator
|
|
||||||
// encounters a Next Header field it does not recognize as an IPv6 extension
|
|
||||||
// header. The caller is responsible for releasing the underlying buffer after
|
|
||||||
// it's no longer needed.
|
|
||||||
type IPv6RawPayloadHeader struct {
|
|
||||||
Identifier IPv6ExtensionHeaderIdentifier
|
|
||||||
Buf buffer.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
|
|
||||||
func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {}
|
|
||||||
|
|
||||||
// Release implements IPv6PayloadHeader.Release.
|
|
||||||
func (i IPv6RawPayloadHeader) Release() {
|
|
||||||
i.Buf.Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ipv6OptionsExtHdr is an IPv6 extension header that holds options.
|
|
||||||
type ipv6OptionsExtHdr struct {
|
|
||||||
buf *buffer.View
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release implements IPv6PayloadHeader.Release.
|
|
||||||
func (i ipv6OptionsExtHdr) Release() {
|
|
||||||
if i.buf != nil {
|
|
||||||
i.buf.Release()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iter returns an iterator over the IPv6 extension header options held in b.
|
|
||||||
func (i ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
|
|
||||||
it := IPv6OptionsExtHdrOptionsIterator{}
|
|
||||||
it.reader = i.buf
|
|
||||||
return it
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header
|
|
||||||
// options.
|
|
||||||
//
|
|
||||||
// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last
|
|
||||||
// used, no changes to the underlying buffer may happen. Doing so may cause
|
|
||||||
// undefined and unexpected behaviour. It is fine to obtain an
|
|
||||||
// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then
|
|
||||||
// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator
|
|
||||||
// obtained before modification is no longer used.
|
|
||||||
type IPv6OptionsExtHdrOptionsIterator struct {
|
|
||||||
reader *buffer.View
|
|
||||||
|
|
||||||
// optionOffset is the number of bytes from the first byte of the
|
|
||||||
// options field to the beginning of the current option.
|
|
||||||
optionOffset uint32
|
|
||||||
|
|
||||||
// nextOptionOffset is the offset of the next option.
|
|
||||||
nextOptionOffset uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// OptionOffset returns the number of bytes parsed while processing the
|
|
||||||
// option field of the current Extension Header.
|
|
||||||
func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 {
|
|
||||||
return i.optionOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6OptionUnknownAction is the action that must be taken if the processing
|
// IPv6OptionUnknownAction is the action that must be taken if the processing
|
||||||
// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
|
// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
|
||||||
type IPv6OptionUnknownAction int
|
type IPv6OptionUnknownAction int
|
||||||
|
@ -294,143 +219,6 @@ func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUn
|
||||||
// is malformed.
|
// is malformed.
|
||||||
var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
|
var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
|
||||||
|
|
||||||
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
|
|
||||||
// header option that is unknown by the parsing utilities.
|
|
||||||
type IPv6UnknownExtHdrOption struct {
|
|
||||||
Identifier IPv6ExtHdrOptionIdentifier
|
|
||||||
Data *buffer.View
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
|
|
||||||
func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
|
|
||||||
return ipv6UnknownActionFromIdentifier(o.Identifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
|
|
||||||
func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
|
|
||||||
|
|
||||||
// Next returns the next option in the options data.
|
|
||||||
//
|
|
||||||
// If the next item is not a known extension header option,
|
|
||||||
// IPv6UnknownExtHdrOption will be returned with the option identifier and data.
|
|
||||||
//
|
|
||||||
// The return is of the format (option, done, error). done will be true when
|
|
||||||
// Next is unable to return anything because the iterator has reached the end of
|
|
||||||
// the options data, or an error occurred.
|
|
||||||
func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
|
|
||||||
for {
|
|
||||||
i.optionOffset = i.nextOptionOffset
|
|
||||||
temp, err := i.reader.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
// If we can't read the first byte of a new option, then we know the
|
|
||||||
// options buffer has been exhausted and we are done iterating.
|
|
||||||
return nil, true, nil
|
|
||||||
}
|
|
||||||
id := IPv6ExtHdrOptionIdentifier(temp)
|
|
||||||
|
|
||||||
// If the option identifier indicates the option is a Pad1 option, then we
|
|
||||||
// know the option does not have Length and Data fields. End processing of
|
|
||||||
// the Pad1 option and continue processing the buffer as a new option.
|
|
||||||
if id == ipv6Pad1ExtHdrOptionIdentifier {
|
|
||||||
i.nextOptionOffset = i.optionOffset + 1
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
length, err := i.reader.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
if err != io.EOF {
|
|
||||||
// ReadByte should only ever return nil or io.EOF.
|
|
||||||
panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
|
|
||||||
// we start parsing an option; we expect the reader to contain enough
|
|
||||||
// bytes for the whole option.
|
|
||||||
return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do we have enough bytes in the reader for the next option?
|
|
||||||
if n := i.reader.Size(); n < int(length) {
|
|
||||||
// Consume the remaining buffer.
|
|
||||||
i.reader.TrimFront(i.reader.Size())
|
|
||||||
|
|
||||||
// We return the same error as if we failed to read a non-padding option
|
|
||||||
// so consumers of this iterator don't need to differentiate between
|
|
||||||
// padding and non-padding options.
|
|
||||||
return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
|
|
||||||
}
|
|
||||||
|
|
||||||
i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */
|
|
||||||
|
|
||||||
switch id {
|
|
||||||
case ipv6PadNExtHdrOptionIdentifier:
|
|
||||||
// Special-case the variable length padding option to avoid a copy.
|
|
||||||
i.reader.TrimFront(int(length))
|
|
||||||
continue
|
|
||||||
case ipv6RouterAlertHopByHopOptionIdentifier:
|
|
||||||
var routerAlertValue [ipv6RouterAlertPayloadLength]byte
|
|
||||||
if n, err := io.ReadFull(i.reader, routerAlertValue[:]); err != nil {
|
|
||||||
switch err {
|
|
||||||
case io.EOF, io.ErrUnexpectedEOF:
|
|
||||||
return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
|
|
||||||
default:
|
|
||||||
return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
|
|
||||||
}
|
|
||||||
} else if n != int(length) {
|
|
||||||
return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
|
|
||||||
}
|
|
||||||
return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
|
|
||||||
default:
|
|
||||||
bytes := buffer.NewView(int(length))
|
|
||||||
if n, err := io.CopyN(bytes, i.reader, int64(length)); err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
err = io.ErrUnexpectedEOF
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
|
|
||||||
}
|
|
||||||
return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options
|
|
||||||
// extension header.
|
|
||||||
type IPv6HopByHopOptionsExtHdr struct {
|
|
||||||
ipv6OptionsExtHdr
|
|
||||||
}
|
|
||||||
|
|
||||||
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
|
|
||||||
func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {}
|
|
||||||
|
|
||||||
// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options
|
|
||||||
// extension header.
|
|
||||||
type IPv6DestinationOptionsExtHdr struct {
|
|
||||||
ipv6OptionsExtHdr
|
|
||||||
}
|
|
||||||
|
|
||||||
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
|
|
||||||
func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {}
|
|
||||||
|
|
||||||
// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific
|
|
||||||
// data as outlined in RFC 8200 section 4.4.
|
|
||||||
type IPv6RoutingExtHdr struct {
|
|
||||||
Buf *buffer.View
|
|
||||||
}
|
|
||||||
|
|
||||||
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
|
|
||||||
func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {}
|
|
||||||
|
|
||||||
// Release implements IPv6PayloadHeader.Release.
|
|
||||||
func (b IPv6RoutingExtHdr) Release() {
|
|
||||||
b.Buf.Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SegmentsLeft returns the Segments Left field.
|
|
||||||
func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 {
|
|
||||||
return b.Buf.AsSlice()[ipv6RoutingExtHdrSegmentsLeftIdx]
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
|
// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
|
||||||
// data as outlined in RFC 8200 section 4.5.
|
// data as outlined in RFC 8200 section 4.5.
|
||||||
//
|
//
|
||||||
|
@ -473,242 +261,6 @@ func (b IPv6FragmentExtHdr) IsAtomic() bool {
|
||||||
return !b.More() && b.FragmentOffset() == 0
|
return !b.More() && b.FragmentOffset() == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
|
|
||||||
//
|
|
||||||
// The IPv6 payload may contain IPv6 extension headers before any upper layer
|
|
||||||
// data.
|
|
||||||
//
|
|
||||||
// Note, between when an IPv6PayloadIterator is obtained and last used, no
|
|
||||||
// changes to the payload may happen. Doing so may cause undefined and
|
|
||||||
// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate
|
|
||||||
// over the first few headers then modify the backing payload so long as the
|
|
||||||
// IPv6PayloadIterator obtained before modification is no longer used.
|
|
||||||
type IPv6PayloadIterator struct {
|
|
||||||
// The identifier of the next header to parse.
|
|
||||||
nextHdrIdentifier IPv6ExtensionHeaderIdentifier
|
|
||||||
|
|
||||||
payload buffer.Buffer
|
|
||||||
|
|
||||||
// Indicates to the iterator that it should return the remaining payload as a
|
|
||||||
// raw payload on the next call to Next.
|
|
||||||
forceRaw bool
|
|
||||||
|
|
||||||
// headerOffset is the offset of the beginning of the current extension
|
|
||||||
// header starting from the beginning of the fixed header.
|
|
||||||
headerOffset uint32
|
|
||||||
|
|
||||||
// parseOffset is the byte offset into the current extension header of the
|
|
||||||
// field we are currently examining. It can be added to the header offset
|
|
||||||
// if the absolute offset within the packet is required.
|
|
||||||
parseOffset uint32
|
|
||||||
|
|
||||||
// nextOffset is the offset of the next header.
|
|
||||||
nextOffset uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// HeaderOffset returns the offset to the start of the extension
|
|
||||||
// header most recently processed.
|
|
||||||
func (i IPv6PayloadIterator) HeaderOffset() uint32 {
|
|
||||||
return i.headerOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseOffset returns the number of bytes successfully parsed.
|
|
||||||
func (i IPv6PayloadIterator) ParseOffset() uint32 {
|
|
||||||
return i.headerOffset + i.parseOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
|
|
||||||
// extension headers, or a raw payload if the payload cannot be parsed. The
|
|
||||||
// iterator takes ownership of the payload.
|
|
||||||
func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.Buffer) IPv6PayloadIterator {
|
|
||||||
return IPv6PayloadIterator{
|
|
||||||
nextHdrIdentifier: nextHdrIdentifier,
|
|
||||||
payload: payload,
|
|
||||||
nextOffset: IPv6FixedHeaderSize,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release frees the resources owned by the iterator.
|
|
||||||
func (i *IPv6PayloadIterator) Release() {
|
|
||||||
i.payload.Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
// AsRawHeader returns the remaining payload of i as a raw header and
|
|
||||||
// optionally consumes the iterator.
|
|
||||||
//
|
|
||||||
// If consume is true, calls to Next after calling AsRawHeader on i will
|
|
||||||
// indicate that the iterator is done. The returned header takes ownership of
|
|
||||||
// its payload.
|
|
||||||
func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
|
|
||||||
identifier := i.nextHdrIdentifier
|
|
||||||
|
|
||||||
var buf buffer.Buffer
|
|
||||||
if consume {
|
|
||||||
// Since we consume the iterator, we return the payload as is.
|
|
||||||
buf = i.payload
|
|
||||||
|
|
||||||
// Mark i as done, but keep track of where we were for error reporting.
|
|
||||||
*i = IPv6PayloadIterator{
|
|
||||||
nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
|
|
||||||
headerOffset: i.headerOffset,
|
|
||||||
nextOffset: i.nextOffset,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
buf = i.payload.Clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next returns the next item in the payload.
|
|
||||||
//
|
|
||||||
// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader
|
|
||||||
// will be returned with the remaining bytes and next header identifier.
|
|
||||||
//
|
|
||||||
// The return is of the format (header, done, error). done will be true when
|
|
||||||
// Next is unable to return anything because the iterator has reached the end of
|
|
||||||
// the payload, or an error occurred.
|
|
||||||
func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
|
|
||||||
i.headerOffset = i.nextOffset
|
|
||||||
i.parseOffset = 0
|
|
||||||
// We could be forced to return i as a raw header when the previous header was
|
|
||||||
// a fragment extension header as the data following the fragment extension
|
|
||||||
// header may not be complete.
|
|
||||||
if i.forceRaw {
|
|
||||||
return i.AsRawHeader(true /* consume */), false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is the header we are parsing a known extension header?
|
|
||||||
switch i.nextHdrIdentifier {
|
|
||||||
case IPv6HopByHopOptionsExtHdrIdentifier:
|
|
||||||
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, true, err
|
|
||||||
}
|
|
||||||
|
|
||||||
i.nextHdrIdentifier = nextHdrIdentifier
|
|
||||||
return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil
|
|
||||||
case IPv6RoutingExtHdrIdentifier:
|
|
||||||
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, true, err
|
|
||||||
}
|
|
||||||
|
|
||||||
i.nextHdrIdentifier = nextHdrIdentifier
|
|
||||||
return IPv6RoutingExtHdr{view}, false, nil
|
|
||||||
case IPv6FragmentExtHdrIdentifier:
|
|
||||||
var data [6]byte
|
|
||||||
// We ignore the returned bytes because we know the fragment extension
|
|
||||||
// header specific data will fit in data.
|
|
||||||
nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, true, err
|
|
||||||
}
|
|
||||||
|
|
||||||
fragmentExtHdr := IPv6FragmentExtHdr(data)
|
|
||||||
|
|
||||||
// If the packet is not the first fragment, do not attempt to parse anything
|
|
||||||
// after the fragment extension header as the payload following the fragment
|
|
||||||
// extension header should not contain any headers; the first fragment must
|
|
||||||
// hold all the headers up to and including any upper layer headers, as per
|
|
||||||
// RFC 8200 section 4.5.
|
|
||||||
if fragmentExtHdr.FragmentOffset() != 0 {
|
|
||||||
i.forceRaw = true
|
|
||||||
}
|
|
||||||
|
|
||||||
i.nextHdrIdentifier = nextHdrIdentifier
|
|
||||||
return fragmentExtHdr, false, nil
|
|
||||||
case IPv6DestinationOptionsExtHdrIdentifier:
|
|
||||||
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, true, err
|
|
||||||
}
|
|
||||||
|
|
||||||
i.nextHdrIdentifier = nextHdrIdentifier
|
|
||||||
return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil
|
|
||||||
case IPv6NoNextHeaderIdentifier:
|
|
||||||
// This indicates the end of the IPv6 payload.
|
|
||||||
return nil, true, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// The header we are parsing is not a known extension header. Return the
|
|
||||||
// raw payload.
|
|
||||||
return i.AsRawHeader(true /* consume */), false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NextHeaderIdentifier returns the identifier of the header next returned by
|
|
||||||
// it.Next().
|
|
||||||
func (i *IPv6PayloadIterator) NextHeaderIdentifier() IPv6ExtensionHeaderIdentifier {
|
|
||||||
return i.nextHdrIdentifier
|
|
||||||
}
|
|
||||||
|
|
||||||
// nextHeaderData returns the extension header's Next Header field and raw data.
|
|
||||||
//
|
|
||||||
// fragmentHdr indicates that the extension header being parsed is the Fragment
|
|
||||||
// extension header so the Length field should be ignored as it is Reserved
|
|
||||||
// for the Fragment extension header.
|
|
||||||
//
|
|
||||||
// If bytes is not nil, extension header specific data will be read into bytes
|
|
||||||
// if it has enough capacity. If bytes is provided but does not have enough
|
|
||||||
// capacity for the data, nextHeaderData will panic.
|
|
||||||
func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, *buffer.View, error) {
|
|
||||||
// We ignore the number of bytes read because we know we will only ever read
|
|
||||||
// at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read
|
|
||||||
// would return io.EOF to indicate that io.Reader has reached the end of the
|
|
||||||
// payload.
|
|
||||||
rdr := i.payload.AsBufferReader()
|
|
||||||
nextHdrIdentifier, err := rdr.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
|
|
||||||
}
|
|
||||||
i.parseOffset++
|
|
||||||
|
|
||||||
var length uint8
|
|
||||||
length, err = rdr.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
if fragmentHdr {
|
|
||||||
return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
|
|
||||||
}
|
|
||||||
if fragmentHdr {
|
|
||||||
length = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make parseOffset point to the first byte of the Extension Header
|
|
||||||
// specific data.
|
|
||||||
i.parseOffset++
|
|
||||||
|
|
||||||
// length is in 8 byte chunks but doesn't include the first one.
|
|
||||||
// See RFC 8200 for each header type, sections 4.3-4.6 and the requirement
|
|
||||||
// in section 4.8 for new extension headers at the top of page 24.
|
|
||||||
// [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet
|
|
||||||
// units, not including the first 8 octets.
|
|
||||||
i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit)
|
|
||||||
|
|
||||||
bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
|
|
||||||
if fragmentHdr {
|
|
||||||
if n := len(bytes); n < bytesLen {
|
|
||||||
panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier))
|
|
||||||
}
|
|
||||||
if n, err := io.ReadFull(&rdr, bytes); err != nil {
|
|
||||||
return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
|
|
||||||
}
|
|
||||||
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), nil, nil
|
|
||||||
}
|
|
||||||
v := buffer.NewView(bytesLen)
|
|
||||||
if n, err := io.CopyN(v, &rdr, int64(bytesLen)); err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
err = io.ErrUnexpectedEOF
|
|
||||||
}
|
|
||||||
v.Release()
|
|
||||||
return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
|
|
||||||
}
|
|
||||||
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6SerializableExtHdr provides serialization for IPv6 extension
|
// IPv6SerializableExtHdr provides serialization for IPv6 extension
|
||||||
// headers.
|
// headers.
|
||||||
type IPv6SerializableExtHdr interface {
|
type IPv6SerializableExtHdr interface {
|
||||||
|
|
712
internal/tschecksum/checksum.go
Normal file
712
internal/tschecksum/checksum.go
Normal file
|
@ -0,0 +1,712 @@
|
||||||
|
package tschecksum
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"math/bits"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"golang.org/x/sys/cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
// checksumGeneric64 is a reference implementation of checksum using 64 bit
|
||||||
|
// arithmetic for use in testing or when an architecture-specific implementation
|
||||||
|
// is not available.
|
||||||
|
func checksumGeneric64(b []byte, initial uint16) uint16 {
|
||||||
|
var ac uint64
|
||||||
|
var carry uint64
|
||||||
|
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac = uint64(initial)
|
||||||
|
} else {
|
||||||
|
ac = uint64(bits.ReverseBytes16(initial))
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(b) >= 128 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry)
|
||||||
|
}
|
||||||
|
b = b[128:]
|
||||||
|
}
|
||||||
|
if len(b) >= 64 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
|
||||||
|
}
|
||||||
|
b = b[64:]
|
||||||
|
}
|
||||||
|
if len(b) >= 32 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
|
||||||
|
}
|
||||||
|
b = b[32:]
|
||||||
|
}
|
||||||
|
if len(b) >= 16 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||||
|
}
|
||||||
|
b = b[16:]
|
||||||
|
}
|
||||||
|
if len(b) >= 8 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry)
|
||||||
|
}
|
||||||
|
b = b[8:]
|
||||||
|
}
|
||||||
|
if len(b) >= 4 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry)
|
||||||
|
}
|
||||||
|
b = b[4:]
|
||||||
|
}
|
||||||
|
if len(b) >= 2 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry)
|
||||||
|
}
|
||||||
|
b = b[2:]
|
||||||
|
}
|
||||||
|
if len(b) >= 1 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add64(ac, uint64(b[0]), carry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
folded := ipChecksumFold64(ac, carry)
|
||||||
|
if !cpu.IsBigEndian {
|
||||||
|
folded = bits.ReverseBytes16(folded)
|
||||||
|
}
|
||||||
|
return folded
|
||||||
|
}
|
||||||
|
|
||||||
|
// checksumGeneric32 is a reference implementation of checksum using 32 bit
|
||||||
|
// arithmetic for use in testing or when an architecture-specific implementation
|
||||||
|
// is not available.
|
||||||
|
func checksumGeneric32(b []byte, initial uint16) uint16 {
|
||||||
|
var ac uint32
|
||||||
|
var carry uint32
|
||||||
|
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac = uint32(initial)
|
||||||
|
} else {
|
||||||
|
ac = uint32(bits.ReverseBytes16(initial))
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(b) >= 64 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:8]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry)
|
||||||
|
}
|
||||||
|
b = b[64:]
|
||||||
|
}
|
||||||
|
if len(b) >= 32 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
|
||||||
|
}
|
||||||
|
b = b[32:]
|
||||||
|
}
|
||||||
|
if len(b) >= 16 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
|
||||||
|
}
|
||||||
|
b = b[16:]
|
||||||
|
}
|
||||||
|
if len(b) >= 8 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
|
||||||
|
}
|
||||||
|
b = b[8:]
|
||||||
|
}
|
||||||
|
if len(b) >= 4 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry)
|
||||||
|
}
|
||||||
|
b = b[4:]
|
||||||
|
}
|
||||||
|
if len(b) >= 2 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry)
|
||||||
|
}
|
||||||
|
b = b[2:]
|
||||||
|
}
|
||||||
|
if len(b) >= 1 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry)
|
||||||
|
} else {
|
||||||
|
ac, carry = bits.Add32(ac, uint32(b[0]), carry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
folded := ipChecksumFold32(ac, carry)
|
||||||
|
if !cpu.IsBigEndian {
|
||||||
|
folded = bits.ReverseBytes16(folded)
|
||||||
|
}
|
||||||
|
return folded
|
||||||
|
}
|
||||||
|
|
||||||
|
// checksumGeneric32Alternate is an alternate reference implementation of
|
||||||
|
// checksum using 32 bit arithmetic for use in testing or when an
|
||||||
|
// architecture-specific implementation is not available.
|
||||||
|
func checksumGeneric32Alternate(b []byte, initial uint16) uint16 {
|
||||||
|
var ac uint32
|
||||||
|
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac = uint32(initial)
|
||||||
|
} else {
|
||||||
|
ac = uint32(bits.ReverseBytes16(initial))
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(b) >= 64 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[32:34]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[34:36]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[36:38]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[38:40]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[40:42]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[42:44]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[44:46]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[46:48]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[48:50]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[50:52]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[52:54]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[54:56]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[56:58]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[58:60]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[60:62]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[62:64]))
|
||||||
|
} else {
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[32:34]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[34:36]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[36:38]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[38:40]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[40:42]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[42:44]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[44:46]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[46:48]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[48:50]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[50:52]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[52:54]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[54:56]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[56:58]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[58:60]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[60:62]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[62:64]))
|
||||||
|
}
|
||||||
|
b = b[64:]
|
||||||
|
}
|
||||||
|
if len(b) >= 32 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
|
||||||
|
} else {
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
|
||||||
|
}
|
||||||
|
b = b[32:]
|
||||||
|
}
|
||||||
|
if len(b) >= 16 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
|
||||||
|
} else {
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
|
||||||
|
}
|
||||||
|
b = b[16:]
|
||||||
|
}
|
||||||
|
if len(b) >= 8 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
|
||||||
|
} else {
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
|
||||||
|
}
|
||||||
|
b = b[8:]
|
||||||
|
}
|
||||||
|
if len(b) >= 4 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||||
|
} else {
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||||
|
}
|
||||||
|
b = b[4:]
|
||||||
|
}
|
||||||
|
if len(b) >= 2 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint32(binary.BigEndian.Uint16(b))
|
||||||
|
} else {
|
||||||
|
ac += uint32(binary.LittleEndian.Uint16(b))
|
||||||
|
}
|
||||||
|
b = b[2:]
|
||||||
|
}
|
||||||
|
if len(b) >= 1 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint32(b[0]) << 8
|
||||||
|
} else {
|
||||||
|
ac += uint32(b[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
folded := ipChecksumFold32(ac, 0)
|
||||||
|
if !cpu.IsBigEndian {
|
||||||
|
folded = bits.ReverseBytes16(folded)
|
||||||
|
}
|
||||||
|
return folded
|
||||||
|
}
|
||||||
|
|
||||||
|
// checksumGeneric64Alternate is an alternate reference implementation of
|
||||||
|
// checksum using 64 bit arithmetic for use in testing or when an
|
||||||
|
// architecture-specific implementation is not available.
|
||||||
|
func checksumGeneric64Alternate(b []byte, initial uint16) uint16 {
|
||||||
|
var ac uint64
|
||||||
|
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac = uint64(initial)
|
||||||
|
} else {
|
||||||
|
ac = uint64(bits.ReverseBytes16(initial))
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(b) >= 64 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||||
|
} else {
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[32:36]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[36:40]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[40:44]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[44:48]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[48:52]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[52:56]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[56:60]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[60:64]))
|
||||||
|
}
|
||||||
|
b = b[64:]
|
||||||
|
}
|
||||||
|
if len(b) >= 32 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||||
|
} else {
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
|
||||||
|
}
|
||||||
|
b = b[32:]
|
||||||
|
}
|
||||||
|
if len(b) >= 16 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
} else {
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
|
||||||
|
}
|
||||||
|
b = b[16:]
|
||||||
|
}
|
||||||
|
if len(b) >= 8 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
} else {
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
|
||||||
|
}
|
||||||
|
b = b[8:]
|
||||||
|
}
|
||||||
|
if len(b) >= 4 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b))
|
||||||
|
} else {
|
||||||
|
ac += uint64(binary.LittleEndian.Uint32(b))
|
||||||
|
}
|
||||||
|
b = b[4:]
|
||||||
|
}
|
||||||
|
if len(b) >= 2 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint64(binary.BigEndian.Uint16(b))
|
||||||
|
} else {
|
||||||
|
ac += uint64(binary.LittleEndian.Uint16(b))
|
||||||
|
}
|
||||||
|
b = b[2:]
|
||||||
|
}
|
||||||
|
if len(b) >= 1 {
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
ac += uint64(b[0]) << 8
|
||||||
|
} else {
|
||||||
|
ac += uint64(b[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
folded := ipChecksumFold64(ac, 0)
|
||||||
|
if !cpu.IsBigEndian {
|
||||||
|
folded = bits.ReverseBytes16(folded)
|
||||||
|
}
|
||||||
|
return folded
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipChecksumFold64 folds a 64-bit ones'-complement partial sum down to 16
// bits using end-around carry.
//
// unfolded is the 64-bit accumulator and initialCarry is the carry (0 or 1)
// left over from the last 64-bit addition that produced it. The result is in
// the same byte order as the words that were accumulated; callers byte-swap
// it when required.
func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 {
	// Add the two 32-bit halves together with the incoming carry.
	sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry))
	// sum <= 0xffff_ffff and carry <= 1, so the value below is at most
	// 0xffff + 0xffff + 1 = 0x1_ffff and cannot overflow 32 bits. The maximum
	// is reached when both halves are 0xffff_ffff and initialCarry is 1.
	sum = (sum >> 16) + (sum & 0xffff) + carry
	// sum <= 0x1_ffff, therefore this fold yields at most 0x1_0000.
	sum = (sum >> 16) + (sum & 0xffff)
	// sum <= 0x1_0000 therefore this is the last fold needed:
	// if (sum >> 16) > 0 then
	// (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore
	// the addition will not overflow
	// otherwise (sum >> 16) == 0 and sum will be unchanged
	sum = (sum >> 16) + (sum & 0xffff)
	return uint16(sum)
}
|
||||||
|
|
||||||
|
// ipChecksumFold32 folds a 32-bit ones'-complement partial sum, together with
// a carry left over from the addition that produced it, down to 16 bits.
func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 {
	const lowMask = 0xffff

	hi, lo := unfolded>>16, unfolded&lowMask
	// hi and lo are each at most 0xffff; with a carry of 0 or 1 the total is
	// at most 0x1_ffff.
	acc := hi + lo + initialCarry

	// Two further end-around folds: the first leaves at most 0x1_0000, and
	// the second either adds that single carry bit onto a zero low half or
	// leaves a value that already fits in 16 bits unchanged.
	for round := 0; round < 2; round++ {
		acc = (acc >> 16) + (acc & lowMask)
	}
	return uint16(acc)
}
|
||||||
|
|
||||||
|
func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) {
|
||||||
|
sum, carry = initial, carryIn
|
||||||
|
switch len(addr) {
|
||||||
|
case 4: // IPv4
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry)
|
||||||
|
} else {
|
||||||
|
sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry)
|
||||||
|
}
|
||||||
|
case 16: // IPv6
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry)
|
||||||
|
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry)
|
||||||
|
} else {
|
||||||
|
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry)
|
||||||
|
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("bad addr length")
|
||||||
|
}
|
||||||
|
return sum, carry
|
||||||
|
}
|
||||||
|
|
||||||
|
func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) {
|
||||||
|
sum, carry = initial, carryIn
|
||||||
|
switch len(addr) {
|
||||||
|
case 4: // IPv4
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
|
||||||
|
} else {
|
||||||
|
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
|
||||||
|
}
|
||||||
|
case 16: // IPv6
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
|
||||||
|
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry)
|
||||||
|
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry)
|
||||||
|
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry)
|
||||||
|
} else {
|
||||||
|
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
|
||||||
|
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry)
|
||||||
|
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry)
|
||||||
|
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("bad addr length")
|
||||||
|
}
|
||||||
|
return sum, carry
|
||||||
|
}
|
||||||
|
|
||||||
|
func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
|
||||||
|
var sum uint64
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
sum = uint64(totalLen) + uint64(protocol)
|
||||||
|
} else {
|
||||||
|
sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8
|
||||||
|
}
|
||||||
|
sum, carry := addrPartialChecksum64(srcAddr, sum, 0)
|
||||||
|
sum, carry = addrPartialChecksum64(dstAddr, sum, carry)
|
||||||
|
|
||||||
|
foldedSum := ipChecksumFold64(sum, carry)
|
||||||
|
if !cpu.IsBigEndian {
|
||||||
|
foldedSum = bits.ReverseBytes16(foldedSum)
|
||||||
|
}
|
||||||
|
return foldedSum
|
||||||
|
}
|
||||||
|
|
||||||
|
func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
|
||||||
|
var sum uint32
|
||||||
|
if cpu.IsBigEndian {
|
||||||
|
sum = uint32(totalLen) + uint32(protocol)
|
||||||
|
} else {
|
||||||
|
sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8
|
||||||
|
}
|
||||||
|
sum, carry := addrPartialChecksum32(srcAddr, sum, 0)
|
||||||
|
sum, carry = addrPartialChecksum32(dstAddr, sum, carry)
|
||||||
|
|
||||||
|
foldedSum := ipChecksumFold32(sum, carry)
|
||||||
|
if !cpu.IsBigEndian {
|
||||||
|
foldedSum = bits.ReverseBytes16(foldedSum)
|
||||||
|
}
|
||||||
|
return foldedSum
|
||||||
|
}
|
||||||
|
|
||||||
|
// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and
|
||||||
|
// dstAddr must be 4 or 16 bytes in length.
|
||||||
|
func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
|
||||||
|
if strconv.IntSize < 64 {
|
||||||
|
return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen)
|
||||||
|
}
|
||||||
|
return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen)
|
||||||
|
}
|
23
internal/tschecksum/checksum_amd64.go
Normal file
23
internal/tschecksum/checksum_amd64.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package tschecksum
|
||||||
|
|
||||||
|
import "golang.org/x/sys/cpu"
|
||||||
|
|
||||||
|
// checksum is the implementation that Checksum dispatches to. It starts out
// as checksumAMD64 and may be replaced by init with checksumAVX2 or
// checksumSSE2, depending on the CPU features detected at start-up.
var checksum = checksumAMD64
|
||||||
|
|
||||||
|
// Checksum computes an IP checksum starting with the provided initial value.
|
||||||
|
// The length of data should be at least 128 bytes for best performance. Smaller
|
||||||
|
// buffers will still compute a correct result.
|
||||||
|
func Checksum(data []byte, initial uint16) uint16 {
|
||||||
|
return checksum(data, initial)
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 {
|
||||||
|
checksum = checksumAVX2
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cpu.X86.HasSSE2 {
|
||||||
|
checksum = checksumSSE2
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
18
internal/tschecksum/checksum_generated_amd64.go
Normal file
18
internal/tschecksum/checksum_generated_amd64.go
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
|
||||||
|
|
||||||
|
package tschecksum
|
||||||
|
|
||||||
|
// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)
// over b, starting from initial. This is a body-less declaration; the
// implementation is the assembly routine of the same name in
// checksum_generated_amd64.s.
//
//go:noescape
func checksumAVX2(b []byte, initial uint16) uint16
|
||||||
|
|
||||||
|
// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)
// over b, starting from initial. This is a body-less declaration; the
// implementation is the assembly routine of the same name in
// checksum_generated_amd64.s.
//
//go:noescape
func checksumSSE2(b []byte, initial uint16) uint16
|
||||||
|
|
||||||
|
// checksumAMD64 computes an IP checksum using amd64 baseline instructions
// over b, starting from initial. This is a body-less declaration; the
// implementation is the assembly routine of the same name in
// checksum_generated_amd64.s.
//
//go:noescape
func checksumAMD64(b []byte, initial uint16) uint16
|
851
internal/tschecksum/checksum_generated_amd64.s
Normal file
851
internal/tschecksum/checksum_generated_amd64.s
Normal file
|
@ -0,0 +1,851 @@
|
||||||
|
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
|
||||||
|
|
||||||
|
#include "textflag.h"
|
||||||
|
|
||||||
|
// xmmLoadMasks is a read-only table of seven 16-byte masks. The mask at byte
// offset 16*i has its last 2*(i+1) bytes set to 0xff and every earlier byte
// zero. checksumAVX2 and checksumSSE2 index it as -16(CX)(BX*8), with CX
// holding the table address and BX the number of bytes remaining, to keep
// only the wanted tail of a 16-byte overlapped load.
DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"
DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112
|
||||||
|
|
||||||
|
// func checksumAVX2(b []byte, initial uint16) uint16
// Requires: AVX, AVX2, BMI2
//
// Register use: AX is the running sum, seeded with the byte-swapped initial
// value; DX points at the next unread byte of b; BX starts as len(b) and is
// consumed by the length tests below. The folded sum is byte-swapped again
// before it is stored to ret.
TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34
	MOVWQZX initial+24(FP), AX
	XCHGB AH, AL
	MOVQ b_base+0(FP), DX
	MOVQ b_len+8(FP), BX

	// handle odd length buffers; they are difficult to handle in general
	// The last byte is added zero-extended and dropped from the length.
	TESTQ $0x00000001, BX
	JZ lengthIsEven
	MOVBQZX -1(DX)(BX*1), CX
	DECQ BX
	ADDQ CX, AX

lengthIsEven:
	// handle tiny buffers (<=31 bytes) specially
	CMPQ BX, $0x1f
	JGT bufferIsNotTiny
	XORQ CX, CX
	XORQ SI, SI
	XORQ DI, DI

	// shift twice to start because length is guaranteed to be even
	// n = n >> 2; CF = originalN & 2
	SHRQ $0x02, BX
	JNC handleTiny4

	// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
	MOVWQZX (DX), CX
	ADDQ $0x02, DX

handleTiny4:
	// n = n >> 1; CF = originalN & 4
	SHRQ $0x01, BX
	JNC handleTiny8

	// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
	MOVLQZX (DX), SI
	ADDQ $0x04, DX

handleTiny8:
	// n = n >> 1; CF = originalN & 8
	SHRQ $0x01, BX
	JNC handleTiny16

	// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
	MOVQ (DX), DI
	ADDQ $0x08, DX

handleTiny16:
	// n = n >> 1; CF = originalN & 16
	// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
	SHRQ $0x01, BX
	JNC handleTinyFinish
	ADDQ (DX), AX
	ADCQ 8(DX), AX

handleTinyFinish:
	// CF should be included from the previous add, so we use ADCQ.
	// If we arrived via the JNC above, then CF=0 due to the branch condition,
	// so ADCQ will still produce the correct result.
	ADCQ CX, AX
	ADCQ SI, AX
	ADCQ DI, AX
	JMP foldAndReturn

bufferIsNotTiny:
	// skip all SIMD for small buffers
	// (buffers of 256 bytes or more go to startSIMD)
	CMPQ BX, $0x00000100
	JGE startSIMD

	// Accumulate carries in this register. It is never expected to overflow.
	XORQ SI, SI

	// We will perform an overlapped read for buffers with length not a multiple of 8.
	// Overlapped in this context means some memory will be read twice, but a shift will
	// eliminate the duplicated data. This extra read is performed at the end of the buffer to
	// preserve any alignment that may exist for the start of the buffer.
	MOVQ BX, CX
	SHRQ $0x03, BX
	ANDQ $0x07, CX
	JZ handleRemaining8
	LEAQ (DX)(BX*8), DI
	MOVQ -8(DI)(CX*1), DI

	// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
	SHLQ $0x03, CX
	NEGQ CX
	ADDQ $0x40, CX
	SHRQ CL, DI
	ADDQ DI, AX
	ADCQ $0x00, SI

	// BX now counts whole 8-byte words. Each handler below shifts out one bit
	// of that count and, when the bit is set, sums a block of that size.
handleRemaining8:
	SHRQ $0x01, BX
	JNC handleRemaining16
	ADDQ (DX), AX
	ADCQ $0x00, SI
	ADDQ $0x08, DX

handleRemaining16:
	SHRQ $0x01, BX
	JNC handleRemaining32
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x10, DX

handleRemaining32:
	SHRQ $0x01, BX
	JNC handleRemaining64
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x20, DX

handleRemaining64:
	SHRQ $0x01, BX
	JNC handleRemaining128
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x40, DX

handleRemaining128:
	SHRQ $0x01, BX
	JNC handleRemainingComplete
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ 64(DX), AX
	ADCQ 72(DX), AX
	ADCQ 80(DX), AX
	ADCQ 88(DX), AX
	ADCQ 96(DX), AX
	ADCQ 104(DX), AX
	ADCQ 112(DX), AX
	ADCQ 120(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x80, DX

handleRemainingComplete:
	// Add the collected carries back into the sum before folding.
	ADDQ SI, AX
	JMP foldAndReturn

startSIMD:
	// Y0-Y3 are four separate accumulators, all cleared here.
	VPXOR Y0, Y0, Y0
	VPXOR Y1, Y1, Y1
	VPXOR Y2, Y2, Y2
	VPXOR Y3, Y3, Y3
	MOVQ BX, CX

	// Update number of bytes remaining after the loop completes
	ANDQ $0xff, BX

	// Number of 256 byte iterations
	SHRQ $0x08, CX
	JZ smallLoop

	// Each VPMOVZXWD reads 16 bytes and widens the 16-bit words into Y4,
	// which is then added to one of the four accumulators in rotation.
bigLoop:
	VPMOVZXWD (DX), Y4
	VPADDD Y4, Y0, Y0
	VPMOVZXWD 16(DX), Y4
	VPADDD Y4, Y1, Y1
	VPMOVZXWD 32(DX), Y4
	VPADDD Y4, Y2, Y2
	VPMOVZXWD 48(DX), Y4
	VPADDD Y4, Y3, Y3
	VPMOVZXWD 64(DX), Y4
	VPADDD Y4, Y0, Y0
	VPMOVZXWD 80(DX), Y4
	VPADDD Y4, Y1, Y1
	VPMOVZXWD 96(DX), Y4
	VPADDD Y4, Y2, Y2
	VPMOVZXWD 112(DX), Y4
	VPADDD Y4, Y3, Y3
	VPMOVZXWD 128(DX), Y4
	VPADDD Y4, Y0, Y0
	VPMOVZXWD 144(DX), Y4
	VPADDD Y4, Y1, Y1
	VPMOVZXWD 160(DX), Y4
	VPADDD Y4, Y2, Y2
	VPMOVZXWD 176(DX), Y4
	VPADDD Y4, Y3, Y3
	VPMOVZXWD 192(DX), Y4
	VPADDD Y4, Y0, Y0
	VPMOVZXWD 208(DX), Y4
	VPADDD Y4, Y1, Y1
	VPMOVZXWD 224(DX), Y4
	VPADDD Y4, Y2, Y2
	VPMOVZXWD 240(DX), Y4
	VPADDD Y4, Y3, Y3
	ADDQ $0x00000100, DX
	DECQ CX
	JNZ bigLoop
	CMPQ BX, $0x10
	JLT doneSmallLoop

	// now read a single 16 byte unit of data at a time
smallLoop:
	VPMOVZXWD (DX), Y4
	VPADDD Y4, Y0, Y0
	ADDQ $0x10, DX
	SUBQ $0x10, BX
	CMPQ BX, $0x10
	JGE smallLoop

doneSmallLoop:
	CMPQ BX, $0x00
	JE doneSIMD

	// There are between 1 and 15 bytes remaining. Perform an overlapped read.
	// The 16 bytes ending at the end of the buffer are loaded and ANDed with
	// the xmmLoadMasks entry selected by BX.
	LEAQ xmmLoadMasks<>+0(SB), CX
	VMOVDQU -16(DX)(BX*1), X4
	VPAND -16(CX)(BX*8), X4, X4
	VPMOVZXWD X4, Y4
	VPADDD Y4, Y0, Y0

doneSIMD:
	// Multi-chain loop is done, combine the accumulators
	VPADDD Y1, Y0, Y0
	VPADDD Y2, Y0, Y0
	VPADDD Y3, Y0, Y0

	// extract the YMM into a pair of XMM and sum them
	VEXTRACTI128 $0x01, Y0, X1
	VPADDD X0, X1, X0

	// extract the XMM into GP64
	VPEXTRQ $0x00, X0, CX
	VPEXTRQ $0x01, X0, DX

	// no more AVX code, clear upper registers to avoid SSE slowdowns
	VZEROUPPER
	ADDQ CX, AX
	ADCQ DX, AX

foldAndReturn:
	// add CF and fold
	// RORX leaves the flags untouched, so each ADC still sees the carry from
	// the preceding addition.
	RORXQ $0x20, AX, CX
	ADCL CX, AX
	RORXL $0x10, AX, CX
	ADCW CX, AX
	ADCW $0x00, AX
	XCHGB AH, AL
	MOVW AX, ret+32(FP)
	RET
|
||||||
|
|
||||||
|
// func checksumSSE2(b []byte, initial uint16) uint16
// Requires: SSE2
//
// Register use: AX is the running sum, seeded with the byte-swapped initial
// value; DX points at the next unread byte of b; BX starts as len(b) and is
// consumed by the length tests below. The folded sum is byte-swapped again
// before it is stored to ret.
TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34
	MOVWQZX initial+24(FP), AX
	XCHGB AH, AL
	MOVQ b_base+0(FP), DX
	MOVQ b_len+8(FP), BX

	// handle odd length buffers; they are difficult to handle in general
	// The last byte is added zero-extended and dropped from the length.
	TESTQ $0x00000001, BX
	JZ lengthIsEven
	MOVBQZX -1(DX)(BX*1), CX
	DECQ BX
	ADDQ CX, AX

lengthIsEven:
	// handle tiny buffers (<=31 bytes) specially
	CMPQ BX, $0x1f
	JGT bufferIsNotTiny
	XORQ CX, CX
	XORQ SI, SI
	XORQ DI, DI

	// shift twice to start because length is guaranteed to be even
	// n = n >> 2; CF = originalN & 2
	SHRQ $0x02, BX
	JNC handleTiny4

	// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
	MOVWQZX (DX), CX
	ADDQ $0x02, DX

handleTiny4:
	// n = n >> 1; CF = originalN & 4
	SHRQ $0x01, BX
	JNC handleTiny8

	// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
	MOVLQZX (DX), SI
	ADDQ $0x04, DX

handleTiny8:
	// n = n >> 1; CF = originalN & 8
	SHRQ $0x01, BX
	JNC handleTiny16

	// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
	MOVQ (DX), DI
	ADDQ $0x08, DX

handleTiny16:
	// n = n >> 1; CF = originalN & 16
	// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
	SHRQ $0x01, BX
	JNC handleTinyFinish
	ADDQ (DX), AX
	ADCQ 8(DX), AX

handleTinyFinish:
	// CF should be included from the previous add, so we use ADCQ.
	// If we arrived via the JNC above, then CF=0 due to the branch condition,
	// so ADCQ will still produce the correct result.
	ADCQ CX, AX
	ADCQ SI, AX
	ADCQ DI, AX
	JMP foldAndReturn

bufferIsNotTiny:
	// skip all SIMD for small buffers
	// (buffers of 256 bytes or more go to startSIMD)
	CMPQ BX, $0x00000100
	JGE startSIMD

	// Accumulate carries in this register. It is never expected to overflow.
	XORQ SI, SI

	// We will perform an overlapped read for buffers with length not a multiple of 8.
	// Overlapped in this context means some memory will be read twice, but a shift will
	// eliminate the duplicated data. This extra read is performed at the end of the buffer to
	// preserve any alignment that may exist for the start of the buffer.
	MOVQ BX, CX
	SHRQ $0x03, BX
	ANDQ $0x07, CX
	JZ handleRemaining8
	LEAQ (DX)(BX*8), DI
	MOVQ -8(DI)(CX*1), DI

	// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
	SHLQ $0x03, CX
	NEGQ CX
	ADDQ $0x40, CX
	SHRQ CL, DI
	ADDQ DI, AX
	ADCQ $0x00, SI

	// BX now counts whole 8-byte words. Each handler below shifts out one bit
	// of that count and, when the bit is set, sums a block of that size.
handleRemaining8:
	SHRQ $0x01, BX
	JNC handleRemaining16
	ADDQ (DX), AX
	ADCQ $0x00, SI
	ADDQ $0x08, DX

handleRemaining16:
	SHRQ $0x01, BX
	JNC handleRemaining32
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x10, DX

handleRemaining32:
	SHRQ $0x01, BX
	JNC handleRemaining64
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x20, DX

handleRemaining64:
	SHRQ $0x01, BX
	JNC handleRemaining128
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x40, DX

handleRemaining128:
	SHRQ $0x01, BX
	JNC handleRemainingComplete
	ADDQ (DX), AX
	ADCQ 8(DX), AX
	ADCQ 16(DX), AX
	ADCQ 24(DX), AX
	ADCQ 32(DX), AX
	ADCQ 40(DX), AX
	ADCQ 48(DX), AX
	ADCQ 56(DX), AX
	ADCQ 64(DX), AX
	ADCQ 72(DX), AX
	ADCQ 80(DX), AX
	ADCQ 88(DX), AX
	ADCQ 96(DX), AX
	ADCQ 104(DX), AX
	ADCQ 112(DX), AX
	ADCQ 120(DX), AX
	ADCQ $0x00, SI
	ADDQ $0x80, DX

handleRemainingComplete:
	// Add the collected carries back into the sum before folding.
	ADDQ SI, AX
	JMP foldAndReturn

startSIMD:
	// X0-X3 are four separate accumulators. X4 is cleared and never written
	// again; it supplies the zero words for the PUNPCK instructions below.
	PXOR X0, X0
	PXOR X1, X1
	PXOR X2, X2
	PXOR X3, X3
	PXOR X4, X4
	MOVQ BX, CX

	// Update number of bytes remaining after the loop completes
	ANDQ $0xff, BX

	// Number of 256 byte iterations
	SHRQ $0x08, CX
	JZ smallLoop

	// Each group loads 16 bytes, interleaves the high and low halves with the
	// zero words of X4, and adds each half to a different accumulator.
bigLoop:
	MOVOU (DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X0
	PADDD X6, X2
	MOVOU 16(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X1
	PADDD X6, X3
	MOVOU 32(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X2
	PADDD X6, X0
	MOVOU 48(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X3
	PADDD X6, X1
	MOVOU 64(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X0
	PADDD X6, X2
	MOVOU 80(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X1
	PADDD X6, X3
	MOVOU 96(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X2
	PADDD X6, X0
	MOVOU 112(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X3
	PADDD X6, X1
	MOVOU 128(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X0
	PADDD X6, X2
	MOVOU 144(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X1
	PADDD X6, X3
	MOVOU 160(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X2
	PADDD X6, X0
	MOVOU 176(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X3
	PADDD X6, X1
	MOVOU 192(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X0
	PADDD X6, X2
	MOVOU 208(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X1
	PADDD X6, X3
	MOVOU 224(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X2
	PADDD X6, X0
	MOVOU 240(DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X3
	PADDD X6, X1
	ADDQ $0x00000100, DX
	DECQ CX
	JNZ bigLoop
	CMPQ BX, $0x10
	JLT doneSmallLoop

	// now read a single 16 byte unit of data at a time
smallLoop:
	MOVOU (DX), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X0
	PADDD X6, X1
	ADDQ $0x10, DX
	SUBQ $0x10, BX
	CMPQ BX, $0x10
	JGE smallLoop

doneSmallLoop:
	CMPQ BX, $0x00
	JE doneSIMD

	// There are between 1 and 15 bytes remaining. Perform an overlapped read.
	// The 16 bytes ending at the end of the buffer are loaded and ANDed with
	// the xmmLoadMasks entry selected by BX.
	LEAQ xmmLoadMasks<>+0(SB), CX
	MOVOU -16(DX)(BX*1), X5
	PAND -16(CX)(BX*8), X5
	MOVOA X5, X6
	PUNPCKHWL X4, X5
	PUNPCKLWL X4, X6
	PADDD X5, X0
	PADDD X6, X1

doneSIMD:
	// Multi-chain loop is done, combine the accumulators
	PADDD X1, X0
	PADDD X2, X0
	PADDD X3, X0

	// extract the XMM into GP64
	MOVQ X0, CX
	PSRLDQ $0x08, X0
	MOVQ X0, DX
	ADDQ CX, AX
	ADCQ DX, AX

foldAndReturn:
	// add CF and fold
	// The pending carry is added into the copied low 32 bits, then the sum is
	// reduced from 64 to 32 to 16 bits.
	MOVL AX, CX
	ADCQ $0x00, CX
	SHRQ $0x20, AX
	ADDQ CX, AX
	MOVWQZX AX, CX
	SHRQ $0x10, AX
	ADDQ CX, AX
	MOVW AX, CX
	SHRQ $0x10, AX
	ADDW CX, AX
	ADCW $0x00, AX
	XCHGB AH, AL
	MOVW AX, ret+32(FP)
	RET
|
||||||
|
|
||||||
|
// func checksumAMD64(b []byte, initial uint16) uint16
|
||||||
|
TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34
|
||||||
|
MOVWQZX initial+24(FP), AX
|
||||||
|
XCHGB AH, AL
|
||||||
|
MOVQ b_base+0(FP), DX
|
||||||
|
MOVQ b_len+8(FP), BX
|
||||||
|
|
||||||
|
// handle odd length buffers; they are difficult to handle in general
|
||||||
|
TESTQ $0x00000001, BX
|
||||||
|
JZ lengthIsEven
|
||||||
|
MOVBQZX -1(DX)(BX*1), CX
|
||||||
|
DECQ BX
|
||||||
|
ADDQ CX, AX
|
||||||
|
|
||||||
|
lengthIsEven:
|
||||||
|
// handle tiny buffers (<=31 bytes) specially
|
||||||
|
CMPQ BX, $0x1f
|
||||||
|
JGT bufferIsNotTiny
|
||||||
|
XORQ CX, CX
|
||||||
|
XORQ SI, SI
|
||||||
|
XORQ DI, DI
|
||||||
|
|
||||||
|
// shift twice to start because length is guaranteed to be even
|
||||||
|
// n = n >> 2; CF = originalN & 2
|
||||||
|
SHRQ $0x02, BX
|
||||||
|
JNC handleTiny4
|
||||||
|
|
||||||
|
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
|
||||||
|
MOVWQZX (DX), CX
|
||||||
|
ADDQ $0x02, DX
|
||||||
|
|
||||||
|
handleTiny4:
|
||||||
|
// n = n >> 1; CF = originalN & 4
|
||||||
|
SHRQ $0x01, BX
|
||||||
|
JNC handleTiny8
|
||||||
|
|
||||||
|
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
|
||||||
|
MOVLQZX (DX), SI
|
||||||
|
ADDQ $0x04, DX
|
||||||
|
|
||||||
|
handleTiny8:
|
||||||
|
// n = n >> 1; CF = originalN & 8
|
||||||
|
SHRQ $0x01, BX
|
||||||
|
JNC handleTiny16
|
||||||
|
|
||||||
|
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
|
||||||
|
MOVQ (DX), DI
|
||||||
|
ADDQ $0x08, DX
|
||||||
|
|
||||||
|
handleTiny16:
|
||||||
|
// n = n >> 1; CF = originalN & 16
|
||||||
|
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
|
||||||
|
SHRQ $0x01, BX
|
||||||
|
JNC handleTinyFinish
|
||||||
|
ADDQ (DX), AX
|
||||||
|
ADCQ 8(DX), AX
|
||||||
|
|
||||||
|
handleTinyFinish:
|
||||||
|
// CF should be included from the previous add, so we use ADCQ.
|
||||||
|
// If we arrived via the JNC above, then CF=0 due to the branch condition,
|
||||||
|
// so ADCQ will still produce the correct result.
|
||||||
|
ADCQ CX, AX
|
||||||
|
ADCQ SI, AX
|
||||||
|
ADCQ DI, AX
|
||||||
|
JMP foldAndReturn
|
||||||
|
|
||||||
|
bufferIsNotTiny:
|
||||||
|
// Number of 256 byte iterations into loop counter
|
||||||
|
MOVQ BX, CX
|
||||||
|
|
||||||
|
// Update number of bytes remaining after the loop completes
|
||||||
|
ANDQ $0xff, BX
|
||||||
|
SHRQ $0x08, CX
|
||||||
|
JZ startCleanup
|
||||||
|
CLC
|
||||||
|
XORQ SI, SI
|
||||||
|
XORQ DI, DI
|
||||||
|
XORQ R8, R8
|
||||||
|
XORQ R9, R9
|
||||||
|
XORQ R10, R10
|
||||||
|
XORQ R11, R11
|
||||||
|
XORQ R12, R12
|
||||||
|
|
||||||
|
bigLoop:
|
||||||
|
ADDQ (DX), AX
|
||||||
|
ADCQ 8(DX), AX
|
||||||
|
ADCQ 16(DX), AX
|
||||||
|
ADCQ 24(DX), AX
|
||||||
|
ADCQ $0x00, SI
|
||||||
|
ADDQ 32(DX), DI
|
||||||
|
ADCQ 40(DX), DI
|
||||||
|
ADCQ 48(DX), DI
|
||||||
|
ADCQ 56(DX), DI
|
||||||
|
ADCQ $0x00, R8
|
||||||
|
ADDQ 64(DX), R9
|
||||||
|
ADCQ 72(DX), R9
|
||||||
|
ADCQ 80(DX), R9
|
||||||
|
ADCQ 88(DX), R9
|
||||||
|
ADCQ $0x00, R10
|
||||||
|
ADDQ 96(DX), R11
|
||||||
|
ADCQ 104(DX), R11
|
||||||
|
ADCQ 112(DX), R11
|
||||||
|
ADCQ 120(DX), R11
|
||||||
|
ADCQ $0x00, R12
|
||||||
|
ADDQ 128(DX), AX
|
||||||
|
ADCQ 136(DX), AX
|
||||||
|
ADCQ 144(DX), AX
|
||||||
|
ADCQ 152(DX), AX
|
||||||
|
ADCQ $0x00, SI
|
||||||
|
ADDQ 160(DX), DI
|
||||||
|
ADCQ 168(DX), DI
|
||||||
|
ADCQ 176(DX), DI
|
||||||
|
ADCQ 184(DX), DI
|
||||||
|
ADCQ $0x00, R8
|
||||||
|
ADDQ 192(DX), R9
|
||||||
|
ADCQ 200(DX), R9
|
||||||
|
ADCQ 208(DX), R9
|
||||||
|
ADCQ 216(DX), R9
|
||||||
|
ADCQ $0x00, R10
|
||||||
|
ADDQ 224(DX), R11
|
||||||
|
ADCQ 232(DX), R11
|
||||||
|
ADCQ 240(DX), R11
|
||||||
|
ADCQ 248(DX), R11
|
||||||
|
ADCQ $0x00, R12
|
||||||
|
ADDQ $0x00000100, DX
|
||||||
|
SUBQ $0x01, CX
|
||||||
|
JNZ bigLoop
|
||||||
|
ADDQ SI, AX
|
||||||
|
ADCQ DI, AX
|
||||||
|
ADCQ R8, AX
|
||||||
|
ADCQ R9, AX
|
||||||
|
ADCQ R10, AX
|
||||||
|
ADCQ R11, AX
|
||||||
|
ADCQ R12, AX
|
||||||
|
|
||||||
|
// accumulate CF (twice, in case the first time overflows)
|
||||||
|
ADCQ $0x00, AX
|
||||||
|
ADCQ $0x00, AX
|
||||||
|
|
||||||
|
startCleanup:
|
||||||
|
// Accumulate carries in this register. It is never expected to overflow.
|
||||||
|
XORQ SI, SI
|
||||||
|
|
||||||
|
// We will perform an overlapped read for buffers with length not a multiple of 8.
|
||||||
|
// Overlapped in this context means some memory will be read twice, but a shift will
|
||||||
|
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
|
||||||
|
// preserve any alignment that may exist for the start of the buffer.
|
||||||
|
MOVQ BX, CX
|
||||||
|
SHRQ $0x03, BX
|
||||||
|
ANDQ $0x07, CX
|
||||||
|
JZ handleRemaining8
|
||||||
|
LEAQ (DX)(BX*8), DI
|
||||||
|
MOVQ -8(DI)(CX*1), DI
|
||||||
|
|
||||||
|
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
|
||||||
|
SHLQ $0x03, CX
|
||||||
|
NEGQ CX
|
||||||
|
ADDQ $0x40, CX
|
||||||
|
SHRQ CL, DI
|
||||||
|
ADDQ DI, AX
|
||||||
|
ADCQ $0x00, SI
|
||||||
|
|
||||||
|
handleRemaining8:
|
||||||
|
SHRQ $0x01, BX
|
||||||
|
JNC handleRemaining16
|
||||||
|
ADDQ (DX), AX
|
||||||
|
ADCQ $0x00, SI
|
||||||
|
ADDQ $0x08, DX
|
||||||
|
|
||||||
|
handleRemaining16:
|
||||||
|
SHRQ $0x01, BX
|
||||||
|
JNC handleRemaining32
|
||||||
|
ADDQ (DX), AX
|
||||||
|
ADCQ 8(DX), AX
|
||||||
|
ADCQ $0x00, SI
|
||||||
|
ADDQ $0x10, DX
|
||||||
|
|
||||||
|
handleRemaining32:
|
||||||
|
SHRQ $0x01, BX
|
||||||
|
JNC handleRemaining64
|
||||||
|
ADDQ (DX), AX
|
||||||
|
ADCQ 8(DX), AX
|
||||||
|
ADCQ 16(DX), AX
|
||||||
|
ADCQ 24(DX), AX
|
||||||
|
ADCQ $0x00, SI
|
||||||
|
ADDQ $0x20, DX
|
||||||
|
|
||||||
|
handleRemaining64:
|
||||||
|
SHRQ $0x01, BX
|
||||||
|
JNC handleRemaining128
|
||||||
|
ADDQ (DX), AX
|
||||||
|
ADCQ 8(DX), AX
|
||||||
|
ADCQ 16(DX), AX
|
||||||
|
ADCQ 24(DX), AX
|
||||||
|
ADCQ 32(DX), AX
|
||||||
|
ADCQ 40(DX), AX
|
||||||
|
ADCQ 48(DX), AX
|
||||||
|
ADCQ 56(DX), AX
|
||||||
|
ADCQ $0x00, SI
|
||||||
|
ADDQ $0x40, DX
|
||||||
|
|
||||||
|
handleRemaining128:
|
||||||
|
SHRQ $0x01, BX
|
||||||
|
JNC handleRemainingComplete
|
||||||
|
ADDQ (DX), AX
|
||||||
|
ADCQ 8(DX), AX
|
||||||
|
ADCQ 16(DX), AX
|
||||||
|
ADCQ 24(DX), AX
|
||||||
|
ADCQ 32(DX), AX
|
||||||
|
ADCQ 40(DX), AX
|
||||||
|
ADCQ 48(DX), AX
|
||||||
|
ADCQ 56(DX), AX
|
||||||
|
ADCQ 64(DX), AX
|
||||||
|
ADCQ 72(DX), AX
|
||||||
|
ADCQ 80(DX), AX
|
||||||
|
ADCQ 88(DX), AX
|
||||||
|
ADCQ 96(DX), AX
|
||||||
|
ADCQ 104(DX), AX
|
||||||
|
ADCQ 112(DX), AX
|
||||||
|
ADCQ 120(DX), AX
|
||||||
|
ADCQ $0x00, SI
|
||||||
|
ADDQ $0x80, DX
|
||||||
|
|
||||||
|
handleRemainingComplete:
|
||||||
|
ADDQ SI, AX
|
||||||
|
|
||||||
|
foldAndReturn:
|
||||||
|
// add CF and fold
|
||||||
|
MOVL AX, CX
|
||||||
|
ADCQ $0x00, CX
|
||||||
|
SHRQ $0x20, AX
|
||||||
|
ADDQ CX, AX
|
||||||
|
MOVWQZX AX, CX
|
||||||
|
SHRQ $0x10, AX
|
||||||
|
ADDQ CX, AX
|
||||||
|
MOVW AX, CX
|
||||||
|
SHRQ $0x10, AX
|
||||||
|
ADDW CX, AX
|
||||||
|
ADCW $0x00, AX
|
||||||
|
XCHGB AH, AL
|
||||||
|
MOVW AX, ret+32(FP)
|
||||||
|
RET
|
15
internal/tschecksum/checksum_generic.go
Normal file
15
internal/tschecksum/checksum_generic.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
// This file contains IP checksum algorithms that are not specific to any
|
||||||
|
// architecture and don't use hardware acceleration.
|
||||||
|
|
||||||
|
//go:build !amd64
|
||||||
|
|
||||||
|
package tschecksum
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
func Checksum(data []byte, initial uint16) uint16 {
|
||||||
|
if strconv.IntSize < 64 {
|
||||||
|
return checksumGeneric32(data, initial)
|
||||||
|
}
|
||||||
|
return checksumGeneric64(data, initial)
|
||||||
|
}
|
578
internal/tschecksum/generate_amd64.go
Normal file
578
internal/tschecksum/generate_amd64.go
Normal file
|
@ -0,0 +1,578 @@
|
||||||
|
//go:build ignore
|
||||||
|
|
||||||
|
//go:generate go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"math/bits"
|
||||||
|
|
||||||
|
"github.com/mmcloughlin/avo/operand"
|
||||||
|
"github.com/mmcloughlin/avo/reg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// checksumSignature is the Go signature shared by every generated checksum
// routine; it is handed to TEXT when each assembly function is declared.
const checksumSignature = "func(b []byte, initial uint16) uint16"
|
||||||
|
|
||||||
|
func loadParams() (accum, buf, n reg.GPVirtual) {
|
||||||
|
accum, buf, n = GP64(), GP64(), GP64()
|
||||||
|
Load(Param("initial"), accum)
|
||||||
|
XCHGB(accum.As8H(), accum.As8L())
|
||||||
|
Load(Param("b").Base(), buf)
|
||||||
|
Load(Param("b").Len(), n)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// simdStrategy selects which vector instruction set generateSIMDChecksum
// emits code for.
type simdStrategy int

const (
	// sse2 generates code that works on 128-bit XMM registers.
	sse2 simdStrategy = iota
	// avx2 generates code that works on 256-bit YMM registers.
	avx2
)
|
||||||
|
|
||||||
|
// tinyBufferSize is the largest length, in bytes, that handleTinyBuffers sums
// itself; for anything longer it branches to the caller's continue label.
const tinyBufferSize = 31 // A buffer is tiny if it has at most 31 bytes.
|
||||||
|
|
||||||
|
func generateSIMDChecksum(name, doc string, minSIMDSize, chains int, strategy simdStrategy) {
|
||||||
|
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
|
||||||
|
Pragma("noescape")
|
||||||
|
Doc(doc)
|
||||||
|
|
||||||
|
accum64, buf, n := loadParams()
|
||||||
|
|
||||||
|
handleOddLength(n, buf, accum64)
|
||||||
|
// no chance of overflow because accum64 was initialized by a uint16 and
|
||||||
|
// handleOddLength adds at most a uint8
|
||||||
|
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
|
||||||
|
Label("bufferIsNotTiny")
|
||||||
|
|
||||||
|
const simdReadSize = 16
|
||||||
|
|
||||||
|
if minSIMDSize > tinyBufferSize {
|
||||||
|
Comment("skip all SIMD for small buffers")
|
||||||
|
if minSIMDSize <= math.MaxUint8 {
|
||||||
|
CMPQ(n, operand.U8(minSIMDSize))
|
||||||
|
} else {
|
||||||
|
CMPQ(n, operand.U32(minSIMDSize))
|
||||||
|
}
|
||||||
|
JGE(operand.LabelRef("startSIMD"))
|
||||||
|
|
||||||
|
handleRemaining(n, buf, accum64, minSIMDSize-1)
|
||||||
|
JMP(operand.LabelRef("foldAndReturn"))
|
||||||
|
}
|
||||||
|
|
||||||
|
Label("startSIMD")
|
||||||
|
|
||||||
|
// chains is the number of accumulators to use. This improves speed via
|
||||||
|
// reduced data dependency. We combine the accumulators once when the big
|
||||||
|
// loop is complete.
|
||||||
|
simdAccumulate := make([]reg.VecVirtual, chains)
|
||||||
|
for i := range simdAccumulate {
|
||||||
|
switch strategy {
|
||||||
|
case sse2:
|
||||||
|
simdAccumulate[i] = XMM()
|
||||||
|
PXOR(simdAccumulate[i], simdAccumulate[i])
|
||||||
|
case avx2:
|
||||||
|
simdAccumulate[i] = YMM()
|
||||||
|
VPXOR(simdAccumulate[i], simdAccumulate[i], simdAccumulate[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var zero reg.VecVirtual
|
||||||
|
if strategy == sse2 {
|
||||||
|
zero = XMM()
|
||||||
|
PXOR(zero, zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of loads per big loop
|
||||||
|
const unroll = 16
|
||||||
|
// Number of bytes
|
||||||
|
loopSize := uint64(simdReadSize * unroll)
|
||||||
|
if bits.Len64(loopSize) != bits.Len64(loopSize-1)+1 {
|
||||||
|
panic("loopSize is not a power of 2")
|
||||||
|
}
|
||||||
|
loopCount := GP64()
|
||||||
|
|
||||||
|
MOVQ(n, loopCount)
|
||||||
|
Comment("Update number of bytes remaining after the loop completes")
|
||||||
|
ANDQ(operand.Imm(loopSize-1), n)
|
||||||
|
Comment(fmt.Sprintf("Number of %d byte iterations", loopSize))
|
||||||
|
SHRQ(operand.Imm(uint64(bits.Len64(loopSize-1))), loopCount)
|
||||||
|
JZ(operand.LabelRef("smallLoop"))
|
||||||
|
Label("bigLoop")
|
||||||
|
for i := 0; i < unroll; i++ {
|
||||||
|
chain := i % chains
|
||||||
|
switch strategy {
|
||||||
|
case sse2:
|
||||||
|
sse2AccumulateStep(i*simdReadSize, buf, zero, simdAccumulate[chain], simdAccumulate[(chain+chains/2)%chains])
|
||||||
|
case avx2:
|
||||||
|
avx2AccumulateStep(i*simdReadSize, buf, simdAccumulate[chain])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ADDQ(operand.U32(loopSize), buf)
|
||||||
|
DECQ(loopCount)
|
||||||
|
JNZ(operand.LabelRef("bigLoop"))
|
||||||
|
|
||||||
|
Label("bigCleanup")
|
||||||
|
|
||||||
|
CMPQ(n, operand.Imm(uint64(simdReadSize)))
|
||||||
|
JLT(operand.LabelRef("doneSmallLoop"))
|
||||||
|
|
||||||
|
Commentf("now read a single %d byte unit of data at a time", simdReadSize)
|
||||||
|
Label("smallLoop")
|
||||||
|
|
||||||
|
switch strategy {
|
||||||
|
case sse2:
|
||||||
|
sse2AccumulateStep(0, buf, zero, simdAccumulate[0], simdAccumulate[1])
|
||||||
|
case avx2:
|
||||||
|
avx2AccumulateStep(0, buf, simdAccumulate[0])
|
||||||
|
}
|
||||||
|
ADDQ(operand.Imm(uint64(simdReadSize)), buf)
|
||||||
|
SUBQ(operand.Imm(uint64(simdReadSize)), n)
|
||||||
|
CMPQ(n, operand.Imm(uint64(simdReadSize)))
|
||||||
|
JGE(operand.LabelRef("smallLoop"))
|
||||||
|
|
||||||
|
Label("doneSmallLoop")
|
||||||
|
CMPQ(n, operand.Imm(0))
|
||||||
|
JE(operand.LabelRef("doneSIMD"))
|
||||||
|
|
||||||
|
Commentf("There are between 1 and %d bytes remaining. Perform an overlapped read.", simdReadSize-1)
|
||||||
|
|
||||||
|
maskDataPtr := GP64()
|
||||||
|
LEAQ(operand.NewDataAddr(operand.NewStaticSymbol("xmmLoadMasks"), 0), maskDataPtr)
|
||||||
|
dataAddr := operand.Mem{Index: n, Scale: 1, Base: buf, Disp: -simdReadSize}
|
||||||
|
// scale 8 is only correct here because n is guaranteed to be even and we
|
||||||
|
// do not generate masks for odd lengths
|
||||||
|
maskAddr := operand.Mem{Base: maskDataPtr, Index: n, Scale: 8, Disp: -16}
|
||||||
|
remainder := XMM()
|
||||||
|
|
||||||
|
switch strategy {
|
||||||
|
case sse2:
|
||||||
|
MOVOU(dataAddr, remainder)
|
||||||
|
PAND(maskAddr, remainder)
|
||||||
|
low := XMM()
|
||||||
|
MOVOA(remainder, low)
|
||||||
|
PUNPCKHWL(zero, remainder)
|
||||||
|
PUNPCKLWL(zero, low)
|
||||||
|
PADDD(remainder, simdAccumulate[0])
|
||||||
|
PADDD(low, simdAccumulate[1])
|
||||||
|
case avx2:
|
||||||
|
// Note: this is very similar to the sse2 path but MOVOU has a massive
|
||||||
|
// performance hit if used here, presumably due to switching between SSE
|
||||||
|
// and AVX2 modes.
|
||||||
|
VMOVDQU(dataAddr, remainder)
|
||||||
|
VPAND(maskAddr, remainder, remainder)
|
||||||
|
|
||||||
|
temp := YMM()
|
||||||
|
VPMOVZXWD(remainder, temp)
|
||||||
|
VPADDD(temp, simdAccumulate[0], simdAccumulate[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
Label("doneSIMD")
|
||||||
|
|
||||||
|
Comment("Multi-chain loop is done, combine the accumulators")
|
||||||
|
for i := range simdAccumulate {
|
||||||
|
if i == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch strategy {
|
||||||
|
case sse2:
|
||||||
|
PADDD(simdAccumulate[i], simdAccumulate[0])
|
||||||
|
case avx2:
|
||||||
|
VPADDD(simdAccumulate[i], simdAccumulate[0], simdAccumulate[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strategy == avx2 {
|
||||||
|
Comment("extract the YMM into a pair of XMM and sum them")
|
||||||
|
tmp := YMM()
|
||||||
|
VEXTRACTI128(operand.Imm(1), simdAccumulate[0], tmp.AsX())
|
||||||
|
|
||||||
|
xAccumulate := XMM()
|
||||||
|
VPADDD(simdAccumulate[0].AsX(), tmp.AsX(), xAccumulate)
|
||||||
|
simdAccumulate = []reg.VecVirtual{xAccumulate}
|
||||||
|
}
|
||||||
|
|
||||||
|
Comment("extract the XMM into GP64")
|
||||||
|
low, high := GP64(), GP64()
|
||||||
|
switch strategy {
|
||||||
|
case sse2:
|
||||||
|
MOVQ(simdAccumulate[0], low)
|
||||||
|
PSRLDQ(operand.Imm(8), simdAccumulate[0])
|
||||||
|
MOVQ(simdAccumulate[0], high)
|
||||||
|
case avx2:
|
||||||
|
VPEXTRQ(operand.Imm(0), simdAccumulate[0], low)
|
||||||
|
VPEXTRQ(operand.Imm(1), simdAccumulate[0], high)
|
||||||
|
|
||||||
|
Comment("no more AVX code, clear upper registers to avoid SSE slowdowns")
|
||||||
|
VZEROUPPER()
|
||||||
|
}
|
||||||
|
ADDQ(low, accum64)
|
||||||
|
ADCQ(high, accum64)
|
||||||
|
Label("foldAndReturn")
|
||||||
|
foldWithCF(accum64, strategy == avx2)
|
||||||
|
XCHGB(accum64.As8H(), accum64.As8L())
|
||||||
|
Store(accum64.As16(), ReturnIndex(0))
|
||||||
|
RET()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleOddLength generates instructions to incorporate the last byte into
|
||||||
|
// accum64 if the length is odd. CF may be set if accum64 overflows; be sure to
|
||||||
|
// handle that if overflow is possible.
|
||||||
|
func handleOddLength(n, buf, accum64 reg.GPVirtual) {
|
||||||
|
Comment("handle odd length buffers; they are difficult to handle in general")
|
||||||
|
TESTQ(operand.U32(1), n)
|
||||||
|
JZ(operand.LabelRef("lengthIsEven"))
|
||||||
|
|
||||||
|
tmp := GP64()
|
||||||
|
MOVBQZX(operand.Mem{Base: buf, Index: n, Scale: 1, Disp: -1}, tmp)
|
||||||
|
DECQ(n)
|
||||||
|
ADDQ(tmp, accum64)
|
||||||
|
|
||||||
|
Label("lengthIsEven")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sse2AccumulateStep(offset int, buf reg.GPVirtual, zero, accumulate1, accumulate2 reg.VecVirtual) {
|
||||||
|
high, low := XMM(), XMM()
|
||||||
|
MOVOU(operand.Mem{Disp: offset, Base: buf}, high)
|
||||||
|
MOVOA(high, low)
|
||||||
|
PUNPCKHWL(zero, high)
|
||||||
|
PUNPCKLWL(zero, low)
|
||||||
|
PADDD(high, accumulate1)
|
||||||
|
PADDD(low, accumulate2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func avx2AccumulateStep(offset int, buf reg.GPVirtual, accumulate reg.VecVirtual) {
|
||||||
|
tmp := YMM()
|
||||||
|
VPMOVZXWD(operand.Mem{Disp: offset, Base: buf}, tmp)
|
||||||
|
VPADDD(tmp, accumulate, accumulate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateAMD64Checksum(name, doc string) {
|
||||||
|
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
|
||||||
|
Pragma("noescape")
|
||||||
|
Doc(doc)
|
||||||
|
|
||||||
|
accum64, buf, n := loadParams()
|
||||||
|
|
||||||
|
handleOddLength(n, buf, accum64)
|
||||||
|
// no chance of overflow because accum64 was initialized by a uint16 and
|
||||||
|
// handleOddLength adds at most a uint8
|
||||||
|
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
|
||||||
|
Label("bufferIsNotTiny")
|
||||||
|
|
||||||
|
const (
|
||||||
|
// numChains is the number of accumulators and carry counters to use.
|
||||||
|
// This improves speed via reduced data dependency. We combine the
|
||||||
|
// accumulators and carry counters once when the loop is complete.
|
||||||
|
numChains = 4
|
||||||
|
unroll = 32 // The number of 64-bit reads to perform per iteration of the loop.
|
||||||
|
loopSize = 8 * unroll // The number of bytes read per iteration of the loop.
|
||||||
|
)
|
||||||
|
if bits.Len(loopSize) != bits.Len(loopSize-1)+1 {
|
||||||
|
panic("loopSize is not a power of 2")
|
||||||
|
}
|
||||||
|
loopCount := GP64()
|
||||||
|
|
||||||
|
Comment(fmt.Sprintf("Number of %d byte iterations into loop counter", loopSize))
|
||||||
|
MOVQ(n, loopCount)
|
||||||
|
Comment("Update number of bytes remaining after the loop completes")
|
||||||
|
ANDQ(operand.Imm(loopSize-1), n)
|
||||||
|
SHRQ(operand.Imm(uint64(bits.Len(loopSize-1))), loopCount)
|
||||||
|
JZ(operand.LabelRef("startCleanup"))
|
||||||
|
CLC()
|
||||||
|
|
||||||
|
chains := make([]struct {
|
||||||
|
accum reg.GPVirtual
|
||||||
|
carries reg.GPVirtual
|
||||||
|
}, numChains)
|
||||||
|
for i := range chains {
|
||||||
|
if i == 0 {
|
||||||
|
chains[i].accum = accum64
|
||||||
|
} else {
|
||||||
|
chains[i].accum = GP64()
|
||||||
|
XORQ(chains[i].accum, chains[i].accum)
|
||||||
|
}
|
||||||
|
chains[i].carries = GP64()
|
||||||
|
XORQ(chains[i].carries, chains[i].carries)
|
||||||
|
}
|
||||||
|
|
||||||
|
Label("bigLoop")
|
||||||
|
|
||||||
|
var curChain int
|
||||||
|
for i := 0; i < unroll; i++ {
|
||||||
|
// It is significantly faster to use a ADCX/ADOX pair instead of plain
|
||||||
|
// ADC, which results in two dependency chains, however those require
|
||||||
|
// ADX support, which was added after AVX2. If AVX2 is available, that's
|
||||||
|
// even better than ADCX/ADOX.
|
||||||
|
//
|
||||||
|
// However, multiple dependency chains using multiple accumulators and
|
||||||
|
// occasionally storing CF into temporary counters seems to work almost
|
||||||
|
// as well.
|
||||||
|
addr := operand.Mem{Disp: i * 8, Base: buf}
|
||||||
|
|
||||||
|
if i%4 == 0 {
|
||||||
|
if i > 0 {
|
||||||
|
ADCQ(operand.Imm(0), chains[curChain].carries)
|
||||||
|
curChain = (curChain + 1) % len(chains)
|
||||||
|
}
|
||||||
|
ADDQ(addr, chains[curChain].accum)
|
||||||
|
} else {
|
||||||
|
ADCQ(addr, chains[curChain].accum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ADCQ(operand.Imm(0), chains[curChain].carries)
|
||||||
|
ADDQ(operand.U32(loopSize), buf)
|
||||||
|
SUBQ(operand.Imm(1), loopCount)
|
||||||
|
JNZ(operand.LabelRef("bigLoop"))
|
||||||
|
for i := range chains {
|
||||||
|
if i == 0 {
|
||||||
|
ADDQ(chains[i].carries, accum64)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ADCQ(chains[i].accum, accum64)
|
||||||
|
ADCQ(chains[i].carries, accum64)
|
||||||
|
}
|
||||||
|
|
||||||
|
accumulateCF(accum64)
|
||||||
|
|
||||||
|
Label("startCleanup")
|
||||||
|
handleRemaining(n, buf, accum64, loopSize-1)
|
||||||
|
Label("foldAndReturn")
|
||||||
|
foldWithCF(accum64, false)
|
||||||
|
|
||||||
|
XCHGB(accum64.As8H(), accum64.As8L())
|
||||||
|
Store(accum64.As16(), ReturnIndex(0))
|
||||||
|
RET()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTinyBuffers computes checksums if the buffer length (the n parameter)
|
||||||
|
// is less than 32. After computing the checksum, a jump to returnLabel will
|
||||||
|
// be executed. Otherwise, if the buffer length is at least 32, nothing will be
|
||||||
|
// modified; a jump to continueLabel will be executed instead.
|
||||||
|
//
|
||||||
|
// When jumping to returnLabel, CF may be set and must be accommodated e.g.
|
||||||
|
// using foldWithCF or accumulateCF.
|
||||||
|
//
|
||||||
|
// Anecdotally, this appears to be faster than attempting to coordinate an
|
||||||
|
// overlapped read (which would also require special handling for buffers
|
||||||
|
// smaller than 8).
|
||||||
|
func handleTinyBuffers(n, buf, accum reg.GPVirtual, returnLabel, continueLabel operand.LabelRef) {
|
||||||
|
Comment("handle tiny buffers (<=31 bytes) specially")
|
||||||
|
CMPQ(n, operand.Imm(tinyBufferSize))
|
||||||
|
JGT(continueLabel)
|
||||||
|
|
||||||
|
tmp2, tmp4, tmp8 := GP64(), GP64(), GP64()
|
||||||
|
XORQ(tmp2, tmp2)
|
||||||
|
XORQ(tmp4, tmp4)
|
||||||
|
XORQ(tmp8, tmp8)
|
||||||
|
|
||||||
|
Comment("shift twice to start because length is guaranteed to be even",
|
||||||
|
"n = n >> 2; CF = originalN & 2")
|
||||||
|
SHRQ(operand.Imm(2), n)
|
||||||
|
JNC(operand.LabelRef("handleTiny4"))
|
||||||
|
Comment("tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]")
|
||||||
|
MOVWQZX(operand.Mem{Base: buf}, tmp2)
|
||||||
|
ADDQ(operand.Imm(2), buf)
|
||||||
|
|
||||||
|
Label("handleTiny4")
|
||||||
|
Comment("n = n >> 1; CF = originalN & 4")
|
||||||
|
SHRQ(operand.Imm(1), n)
|
||||||
|
JNC(operand.LabelRef("handleTiny8"))
|
||||||
|
Comment("tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]")
|
||||||
|
MOVLQZX(operand.Mem{Base: buf}, tmp4)
|
||||||
|
ADDQ(operand.Imm(4), buf)
|
||||||
|
|
||||||
|
Label("handleTiny8")
|
||||||
|
Comment("n = n >> 1; CF = originalN & 8")
|
||||||
|
SHRQ(operand.Imm(1), n)
|
||||||
|
JNC(operand.LabelRef("handleTiny16"))
|
||||||
|
Comment("tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]")
|
||||||
|
MOVQ(operand.Mem{Base: buf}, tmp8)
|
||||||
|
ADDQ(operand.Imm(8), buf)
|
||||||
|
|
||||||
|
Label("handleTiny16")
|
||||||
|
Comment("n = n >> 1; CF = originalN & 16",
|
||||||
|
"n == 0 now, otherwise we would have branched after comparing with tinyBufferSize")
|
||||||
|
SHRQ(operand.Imm(1), n)
|
||||||
|
JNC(operand.LabelRef("handleTinyFinish"))
|
||||||
|
ADDQ(operand.Mem{Base: buf}, accum)
|
||||||
|
ADCQ(operand.Mem{Base: buf, Disp: 8}, accum)
|
||||||
|
|
||||||
|
Label("handleTinyFinish")
|
||||||
|
Comment("CF should be included from the previous add, so we use ADCQ.",
|
||||||
|
"If we arrived via the JNC above, then CF=0 due to the branch condition,",
|
||||||
|
"so ADCQ will still produce the correct result.")
|
||||||
|
ADCQ(tmp2, accum)
|
||||||
|
ADCQ(tmp4, accum)
|
||||||
|
ADCQ(tmp8, accum)
|
||||||
|
|
||||||
|
JMP(returnLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRemaining generates a series of conditional unrolled additions,
|
||||||
|
// starting with 8 bytes long and doubling each time until the length reaches
|
||||||
|
// max. This is the reverse order of what may be intuitive, but makes the branch
|
||||||
|
// conditions convenient to compute: perform one right shift each time and test
|
||||||
|
// against CF.
|
||||||
|
//
|
||||||
|
// When done, CF may be set and must be accommodated e.g., using foldWithCF or
|
||||||
|
// accumulateCF.
|
||||||
|
//
|
||||||
|
// If n is not a multiple of 8, an extra 64 bit read at the end of the buffer
|
||||||
|
// will be performed, overlapping with data that will be read later. The
|
||||||
|
// duplicate data will be shifted off.
|
||||||
|
//
|
||||||
|
// The original buffer length must have been at least 8 bytes long, even if
|
||||||
|
// n < 8, otherwise this will access memory before the start of the buffer,
|
||||||
|
// which may be unsafe.
|
||||||
|
func handleRemaining(n, buf, accum64 reg.GPVirtual, max int) {
|
||||||
|
Comment("Accumulate carries in this register. It is never expected to overflow.")
|
||||||
|
carries := GP64()
|
||||||
|
XORQ(carries, carries)
|
||||||
|
|
||||||
|
Comment("We will perform an overlapped read for buffers with length not a multiple of 8.",
|
||||||
|
"Overlapped in this context means some memory will be read twice, but a shift will",
|
||||||
|
"eliminate the duplicated data. This extra read is performed at the end of the buffer to",
|
||||||
|
"preserve any alignment that may exist for the start of the buffer.")
|
||||||
|
leftover := reg.RCX
|
||||||
|
MOVQ(n, leftover)
|
||||||
|
SHRQ(operand.Imm(3), n) // n is now the number of 64 bit reads remaining
|
||||||
|
ANDQ(operand.Imm(0x7), leftover) // leftover is now the number of bytes to read from the end
|
||||||
|
JZ(operand.LabelRef("handleRemaining8"))
|
||||||
|
endBuf := GP64()
|
||||||
|
// endBuf is the position near the end of the buffer that is just past the
|
||||||
|
// last multiple of 8: (buf + len(buf)) & ^0x7
|
||||||
|
LEAQ(operand.Mem{Base: buf, Index: n, Scale: 8}, endBuf)
|
||||||
|
|
||||||
|
overlapRead := GP64()
|
||||||
|
// equivalent to overlapRead = binary.LittleEndian.Uint64(buf[len(buf)-8:len(buf)])
|
||||||
|
MOVQ(operand.Mem{Base: endBuf, Index: leftover, Scale: 1, Disp: -8}, overlapRead)
|
||||||
|
|
||||||
|
Comment("Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)")
|
||||||
|
SHLQ(operand.Imm(3), leftover) // leftover = leftover * 8
|
||||||
|
NEGQ(leftover) // leftover = -leftover; this completes the (-leftoverBytes*8) part of the expression
|
||||||
|
ADDQ(operand.Imm(64), leftover) // now we have (64 - leftoverBytes*8)
|
||||||
|
SHRQ(reg.CL, overlapRead) // shift right by (64 - leftoverBytes*8); CL is the low 8 bits of leftover (set to RCX above) and variable shift only accepts CL
|
||||||
|
|
||||||
|
ADDQ(overlapRead, accum64)
|
||||||
|
ADCQ(operand.Imm(0), carries)
|
||||||
|
|
||||||
|
for curBytes := 8; curBytes <= max; curBytes *= 2 {
|
||||||
|
Label(fmt.Sprintf("handleRemaining%d", curBytes))
|
||||||
|
SHRQ(operand.Imm(1), n)
|
||||||
|
if curBytes*2 <= max {
|
||||||
|
JNC(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
|
||||||
|
} else {
|
||||||
|
JNC(operand.LabelRef("handleRemainingComplete"))
|
||||||
|
}
|
||||||
|
|
||||||
|
numLoads := curBytes / 8
|
||||||
|
for i := 0; i < numLoads; i++ {
|
||||||
|
addr := operand.Mem{Base: buf, Disp: i * 8}
|
||||||
|
// It is possible to add the multiple dependency chains trick here
|
||||||
|
// that generateAMD64Checksum uses but anecdotally it does not
|
||||||
|
// appear to outweigh the cost.
|
||||||
|
if i == 0 {
|
||||||
|
ADDQ(addr, accum64)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ADCQ(addr, accum64)
|
||||||
|
}
|
||||||
|
ADCQ(operand.Imm(0), carries)
|
||||||
|
|
||||||
|
if curBytes > math.MaxUint8 {
|
||||||
|
ADDQ(operand.U32(uint64(curBytes)), buf)
|
||||||
|
} else {
|
||||||
|
ADDQ(operand.U8(uint64(curBytes)), buf)
|
||||||
|
}
|
||||||
|
if curBytes*2 >= max {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
JMP(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
|
||||||
|
}
|
||||||
|
Label("handleRemainingComplete")
|
||||||
|
ADDQ(carries, accum64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func accumulateCF(accum64 reg.GPVirtual) {
|
||||||
|
Comment("accumulate CF (twice, in case the first time overflows)")
|
||||||
|
// accum64 += CF
|
||||||
|
ADCQ(operand.Imm(0), accum64)
|
||||||
|
// accum64 += CF again if the previous add overflowed. The previous add was
|
||||||
|
// 0 or 1. If it overflowed, then accum64 == 0, so adding another 1 can
|
||||||
|
// never overflow.
|
||||||
|
ADCQ(operand.Imm(0), accum64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// foldWithCF generates instructions to fold accum (a GP64) into a 16-bit value
|
||||||
|
// according to ones-complement arithmetic. BMI2 instructions will be used if
|
||||||
|
// allowBMI2 is true (requires fewer instructions).
|
||||||
|
func foldWithCF(accum reg.GPVirtual, allowBMI2 bool) {
|
||||||
|
Comment("add CF and fold")
|
||||||
|
|
||||||
|
// CF|accum max value starts as 0x1_ffff_ffff_ffff_ffff
|
||||||
|
|
||||||
|
tmp := GP64()
|
||||||
|
if allowBMI2 {
|
||||||
|
// effectively, tmp = accum >> 32 (technically, this is a rotate)
|
||||||
|
RORXQ(operand.Imm(32), accum, tmp)
|
||||||
|
// accum as uint32 = uint32(accum) + uint32(tmp64) + CF; max value 0xffff_ffff + CF set
|
||||||
|
ADCL(tmp.As32(), accum.As32())
|
||||||
|
// effectively, tmp64 as uint32 = uint32(accum) >> 16 (also a rotate)
|
||||||
|
RORXL(operand.Imm(16), accum.As32(), tmp.As32())
|
||||||
|
// accum as uint16 = uint16(accum) + uint16(tmp) + CF; max value 0xffff + CF unset or 0xfffe + CF set
|
||||||
|
ADCW(tmp.As16(), accum.As16())
|
||||||
|
} else {
|
||||||
|
// tmp = uint32(accum); max value 0xffff_ffff
|
||||||
|
// MOVL clears the upper 32 bits of a GP64 so this is equivalent to the
|
||||||
|
// non-existent MOVLQZX.
|
||||||
|
MOVL(accum.As32(), tmp.As32())
|
||||||
|
// tmp += CF; max value 0x1_0000_0000, CF unset
|
||||||
|
ADCQ(operand.Imm(0), tmp)
|
||||||
|
// accum = accum >> 32; max value 0xffff_ffff
|
||||||
|
SHRQ(operand.Imm(32), accum)
|
||||||
|
// accum = accum + tmp; max value 0x1_ffff_ffff + CF unset
|
||||||
|
ADDQ(tmp, accum)
|
||||||
|
// tmp = uint16(accum); max value 0xffff
|
||||||
|
MOVWQZX(accum.As16(), tmp)
|
||||||
|
// accum = accum >> 16; max value 0x1_ffff
|
||||||
|
SHRQ(operand.Imm(16), accum)
|
||||||
|
// accum = accum + tmp; max value 0x2_fffe + CF unset
|
||||||
|
ADDQ(tmp, accum)
|
||||||
|
// tmp as uint16 = uint16(accum); max value 0xffff
|
||||||
|
MOVW(accum.As16(), tmp.As16())
|
||||||
|
// accum = accum >> 16; max value 0x2
|
||||||
|
SHRQ(operand.Imm(16), accum)
|
||||||
|
// accum as uint16 = uint16(accum) + uint16(tmp); max value 0xffff + CF unset or 0x2 + CF set
|
||||||
|
ADDW(tmp.As16(), accum.As16())
|
||||||
|
}
|
||||||
|
// accum as uint16 += CF; will not overflow: either CF was 0 or accum <= 0xfffe
|
||||||
|
ADCW(operand.Imm(0), accum.As16())
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateLoadMasks() {
|
||||||
|
var offset int
|
||||||
|
// xmmLoadMasks is a table of masks that can be used with PAND to zero all but the last N bytes in an XMM, N=2,4,6,8,10,12,14
|
||||||
|
GLOBL("xmmLoadMasks", RODATA|NOPTR)
|
||||||
|
|
||||||
|
for n := 2; n < 16; n += 2 {
|
||||||
|
var pattern [16]byte
|
||||||
|
for i := 0; i < len(pattern); i++ {
|
||||||
|
if i < len(pattern)-n {
|
||||||
|
pattern[i] = 0
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pattern[i] = 0xff
|
||||||
|
}
|
||||||
|
DATA(offset, operand.String(pattern[:]))
|
||||||
|
offset += len(pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
generateLoadMasks()
|
||||||
|
generateSIMDChecksum("checksumAVX2", "checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)", 256, 4, avx2)
|
||||||
|
generateSIMDChecksum("checksumSSE2", "checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)", 256, 4, sse2)
|
||||||
|
generateAMD64Checksum("checksumAMD64", "checksumAMD64 computes an IP checksum using amd64 baseline instructions")
|
||||||
|
Generate()
|
||||||
|
}
|
|
@ -51,12 +51,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldInterface := m.defaultInterface.Load()
|
|
||||||
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
|
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
|
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
|
||||||
}
|
}
|
||||||
m.defaultInterface.Store(newInterface)
|
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||||
if oldInterface != nil && oldInterface.Equals(*newInterface) && oldVPNEnabled == m.androidVPNEnabled {
|
if oldInterface != nil && oldInterface.Equals(*newInterface) && oldVPNEnabled == m.androidVPNEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -165,12 +165,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||||
if defaultInterface == nil {
|
if defaultInterface == nil {
|
||||||
return ErrNoRoute
|
return ErrNoRoute
|
||||||
}
|
}
|
||||||
oldInterface := m.defaultInterface.Load()
|
|
||||||
newInterface, err := m.interfaceFinder.ByIndex(defaultInterface.Index)
|
newInterface, err := m.interfaceFinder.ByIndex(defaultInterface.Index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "find updated interface: ", defaultInterface.Name)
|
return E.Cause(err, "find updated interface: ", defaultInterface.Name)
|
||||||
}
|
}
|
||||||
m.defaultInterface.Store(newInterface)
|
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||||
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ type networkUpdateMonitor struct {
|
||||||
var ErrNetlinkBanned = E.New(
|
var ErrNetlinkBanned = E.New(
|
||||||
"netlink socket in Android is banned by Google, " +
|
"netlink socket in Android is banned by Google, " +
|
||||||
"use the root or system (ADB) user to run sing-box, " +
|
"use the root or system (ADB) user to run sing-box, " +
|
||||||
"or switch to the sing-box Adnroid graphical interface client",
|
"or switch to the sing-box Android graphical interface client",
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
|
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
|
||||||
|
|
|
@ -25,12 +25,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldInterface := m.defaultInterface.Load()
|
|
||||||
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
|
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
|
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
|
||||||
}
|
}
|
||||||
m.defaultInterface.Store(newInterface)
|
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||||
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -102,13 +102,12 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||||
return ErrNoRoute
|
return ErrNoRoute
|
||||||
}
|
}
|
||||||
|
|
||||||
oldInterface := m.defaultInterface.Load()
|
|
||||||
newInterface, err := m.interfaceFinder.ByIndex(index)
|
newInterface, err := m.interfaceFinder.ByIndex(index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "find updated interface: ", alias)
|
return E.Cause(err, "find updated interface: ", alias)
|
||||||
}
|
}
|
||||||
m.defaultInterface.Store(newInterface)
|
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||||
if oldInterface != nil && !oldInterface.Equals(*newInterface) {
|
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m.emit(newInterface, 0)
|
m.emit(newInterface, 0)
|
||||||
|
|
|
@ -44,7 +44,7 @@ type autoRedirect struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
|
func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
|
||||||
r := &autoRedirect{
|
return &autoRedirect{
|
||||||
tunOptions: options.TunOptions,
|
tunOptions: options.TunOptions,
|
||||||
ctx: options.Context,
|
ctx: options.Context,
|
||||||
handler: options.Handler,
|
handler: options.Handler,
|
||||||
|
@ -56,7 +56,10 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
|
||||||
customRedirectPortFunc: options.CustomRedirectPort,
|
customRedirectPortFunc: options.CustomRedirectPort,
|
||||||
routeAddressSet: options.RouteAddressSet,
|
routeAddressSet: options.RouteAddressSet,
|
||||||
routeExcludeAddressSet: options.RouteExcludeAddressSet,
|
routeExcludeAddressSet: options.RouteExcludeAddressSet,
|
||||||
}
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *autoRedirect) Start() error {
|
||||||
var err error
|
var err error
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
r.enableIPv4 = true
|
r.enableIPv4 = true
|
||||||
|
@ -74,7 +77,7 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Extend(E.Cause(err, "root permission is required for auto redirect"), os.Getenv("PATH"))
|
return E.Extend(E.Cause(err, "root permission is required for auto redirect"), os.Getenv("PATH"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -90,7 +93,7 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
|
||||||
if !r.useNFTables {
|
if !r.useNFTables {
|
||||||
r.iptablesPath, err = exec.LookPath("iptables")
|
r.iptablesPath, err = exec.LookPath("iptables")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "iptables is required")
|
return E.Cause(err, "iptables is required")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -100,7 +103,7 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
|
||||||
r.ip6tablesPath, err = exec.LookPath("ip6tables")
|
r.ip6tablesPath, err = exec.LookPath("ip6tables")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !r.enableIPv4 {
|
if !r.enableIPv4 {
|
||||||
return nil, E.Cause(err, "ip6tables is required")
|
return E.Cause(err, "ip6tables is required")
|
||||||
} else {
|
} else {
|
||||||
r.enableIPv6 = false
|
r.enableIPv6 = false
|
||||||
r.logger.Error("device has no ip6tables nat support: ", err)
|
r.logger.Error("device has no ip6tables nat support: ", err)
|
||||||
|
@ -109,10 +112,6 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *autoRedirect) Start() error {
|
|
||||||
if r.customRedirectPortFunc != nil {
|
if r.customRedirectPortFunc != nil {
|
||||||
r.customRedirectPort = r.customRedirectPortFunc()
|
r.customRedirectPort = r.customRedirectPortFunc()
|
||||||
}
|
}
|
||||||
|
@ -132,7 +131,6 @@ func (r *autoRedirect) Start() error {
|
||||||
}
|
}
|
||||||
r.redirectServer = server
|
r.redirectServer = server
|
||||||
}
|
}
|
||||||
var err error
|
|
||||||
if r.useNFTables {
|
if r.useNFTables {
|
||||||
r.cleanupNFTables()
|
r.cleanupNFTables()
|
||||||
err = r.setupNFTables()
|
err = r.setupNFTables()
|
||||||
|
|
|
@ -32,6 +32,10 @@ func (r *autoRedirect) setupNFTables() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = r.interfaceFinder.Update()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
r.localAddresses = common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix {
|
r.localAddresses = common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix {
|
||||||
return common.Filter(it.Addresses, func(prefix netip.Prefix) bool {
|
return common.Filter(it.Addresses, func(prefix netip.Prefix) bool {
|
||||||
return it.Name == "lo" || prefix.Addr().IsGlobalUnicast()
|
return it.Name == "lo" || prefix.Addr().IsGlobalUnicast()
|
||||||
|
|
|
@ -19,13 +19,11 @@ import (
|
||||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const WithGVisor = true
|
const WithGVisor = true
|
||||||
|
|
||||||
const defaultNIC tcpip.NICID = 1
|
const DefaultNIC tcpip.NICID = 1
|
||||||
|
|
||||||
type GVisor struct {
|
type GVisor struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
@ -68,28 +66,11 @@ func (t *GVisor) Start() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
|
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
|
||||||
ipStack, err := newGVisorStack(linkEndpoint)
|
ipStack, err := NewGVisorStack(linkEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
|
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket)
|
||||||
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
|
|
||||||
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
|
|
||||||
pErr := t.handler.PrepareConnection(N.NetworkTCP, source, destination)
|
|
||||||
if pErr != nil {
|
|
||||||
r.Complete(pErr != ErrDrop)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
conn := &gLazyConn{
|
|
||||||
parentCtx: t.ctx,
|
|
||||||
stack: t.stack,
|
|
||||||
request: r,
|
|
||||||
localAddr: source.TCPAddr(),
|
|
||||||
remoteAddr: destination.TCPAddr(),
|
|
||||||
}
|
|
||||||
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
|
|
||||||
})
|
|
||||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
|
||||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
|
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
|
||||||
t.stack = ipStack
|
t.stack = ipStack
|
||||||
t.endpoint = linkEndpoint
|
t.endpoint = linkEndpoint
|
||||||
|
@ -124,7 +105,7 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
|
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
|
||||||
ipStack := stack.New(stack.Options{
|
ipStack := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||||
ipv4.NewProtocol,
|
ipv4.NewProtocol,
|
||||||
|
@ -137,19 +118,19 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
|
||||||
icmp.NewProtocol6,
|
icmp.NewProtocol6,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
err := ipStack.CreateNIC(defaultNIC, ep)
|
err := ipStack.CreateNIC(DefaultNIC, ep)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, gonet.TranslateNetstackError(err)
|
return nil, gonet.TranslateNetstackError(err)
|
||||||
}
|
}
|
||||||
ipStack.SetRouteTable([]tcpip.Route{
|
ipStack.SetRouteTable([]tcpip.Route{
|
||||||
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
|
{Destination: header.IPv4EmptySubnet, NIC: DefaultNIC},
|
||||||
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
|
{Destination: header.IPv6EmptySubnet, NIC: DefaultNIC},
|
||||||
})
|
})
|
||||||
err = ipStack.SetSpoofing(defaultNIC, true)
|
err = ipStack.SetSpoofing(DefaultNIC, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, gonet.TranslateNetstackError(err)
|
return nil, gonet.TranslateNetstackError(err)
|
||||||
}
|
}
|
||||||
err = ipStack.SetPromiscuousMode(defaultNIC, true)
|
err = ipStack.SetPromiscuousMode(DefaultNIC, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, gonet.TranslateNetstackError(err)
|
return nil, gonet.TranslateNetstackError(err)
|
||||||
}
|
}
|
||||||
|
|
51
stack_gvisor_tcp.go
Normal file
51
stack_gvisor_tcp.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
//go:build with_gvisor
|
||||||
|
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TCPForwarder struct {
|
||||||
|
ctx context.Context
|
||||||
|
stack *stack.Stack
|
||||||
|
handler Handler
|
||||||
|
forwarder *tcp.Forwarder
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
|
||||||
|
forwarder := &TCPForwarder{
|
||||||
|
ctx: ctx,
|
||||||
|
stack: stack,
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
|
||||||
|
return forwarder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||||
|
return f.forwarder.HandlePacket(id, pkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
|
||||||
|
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
|
||||||
|
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
|
||||||
|
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
|
||||||
|
if pErr != nil {
|
||||||
|
r.Complete(pErr != ErrDrop)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn := &gLazyConn{
|
||||||
|
parentCtx: f.ctx,
|
||||||
|
stack: f.stack,
|
||||||
|
request: r,
|
||||||
|
localAddr: source.TCPAddr(),
|
||||||
|
remoteAddr: destination.TCPAddr(),
|
||||||
|
}
|
||||||
|
go f.handler.NewConnectionEx(f.ctx, conn, source, destination, nil)
|
||||||
|
}
|
|
@ -123,7 +123,7 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
|
||||||
defer packetBuffer.Release()
|
defer packetBuffer.Release()
|
||||||
|
|
||||||
route, err := w.stack.FindRoute(
|
route, err := w.stack.FindRoute(
|
||||||
defaultNIC,
|
DefaultNIC,
|
||||||
AddressFromAddr(destination.Addr),
|
AddressFromAddr(destination.Addr),
|
||||||
w.source,
|
w.source,
|
||||||
w.sourceNetwork,
|
w.sourceNetwork,
|
||||||
|
|
|
@ -38,7 +38,7 @@ func (m *Mixed) Start() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
endpoint := channel.New(1024, uint32(m.mtu), "")
|
endpoint := channel.New(1024, uint32(m.mtu), "")
|
||||||
ipStack, err := newGVisorStack(endpoint)
|
ipStack, err := NewGVisorStack(endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -50,6 +50,18 @@ func (m *Mixed) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Mixed) Close() error {
|
||||||
|
if m.stack == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.endpoint.Attach(nil)
|
||||||
|
m.stack.Close()
|
||||||
|
for _, endpoint := range m.stack.CleanupEndpoints() {
|
||||||
|
endpoint.Abort()
|
||||||
|
}
|
||||||
|
return m.System.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Mixed) tunLoop() {
|
func (m *Mixed) tunLoop() {
|
||||||
if winTun, isWinTun := m.tun.(WinTun); isWinTun {
|
if winTun, isWinTun := m.tun.(WinTun); isWinTun {
|
||||||
m.wintunLoop(winTun)
|
m.wintunLoop(winTun)
|
||||||
|
@ -137,7 +149,7 @@ func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(writeBuffers) > 0 {
|
if len(writeBuffers) > 0 {
|
||||||
err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
|
_, err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.logger.Trace(E.Cause(err, "batch write packet"))
|
m.logger.Trace(E.Cause(err, "batch write packet"))
|
||||||
}
|
}
|
||||||
|
@ -151,10 +163,10 @@ func (m *Mixed) processPacket(packet []byte) bool {
|
||||||
writeBack bool
|
writeBack bool
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
switch ipVersion := packet[0] >> 4; ipVersion {
|
switch ipVersion := header.IPVersion(packet); ipVersion {
|
||||||
case 4:
|
case header.IPv4Version:
|
||||||
writeBack, err = m.processIPv4(packet)
|
writeBack, err = m.processIPv4(packet)
|
||||||
case 6:
|
case header.IPv6Version:
|
||||||
writeBack, err = m.processIPv6(packet)
|
writeBack, err = m.processIPv6(packet)
|
||||||
default:
|
default:
|
||||||
err = E.New("ip: unknown version: ", ipVersion)
|
err = E.New("ip: unknown version: ", ipVersion)
|
||||||
|
@ -222,15 +234,3 @@ func (m *Mixed) packetLoop() {
|
||||||
packet.DecRef()
|
packet.DecRef()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mixed) Close() error {
|
|
||||||
if m.stack == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m.endpoint.Attach(nil)
|
|
||||||
m.stack.Close()
|
|
||||||
for _, endpoint := range m.stack.CleanupEndpoints() {
|
|
||||||
endpoint.Abort()
|
|
||||||
}
|
|
||||||
return m.System.Close()
|
|
||||||
}
|
|
||||||
|
|
|
@ -244,7 +244,7 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(writeBuffers) > 0 {
|
if len(writeBuffers) > 0 {
|
||||||
err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
|
_, err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Trace(E.Cause(err, "batch write packet"))
|
s.logger.Trace(E.Cause(err, "batch write packet"))
|
||||||
}
|
}
|
||||||
|
@ -419,7 +419,7 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
|
||||||
ipHdr.SetChecksum(0)
|
ipHdr.SetChecksum(0)
|
||||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||||
if PacketOffset > 0 {
|
if PacketOffset > 0 {
|
||||||
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
|
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||||
} else {
|
} else {
|
||||||
newPacket.Advance(-s.frontHeadroom)
|
newPacket.Advance(-s.frontHeadroom)
|
||||||
}
|
}
|
||||||
|
@ -502,7 +502,7 @@ func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) erro
|
||||||
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
|
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
|
||||||
}
|
}
|
||||||
if PacketOffset > 0 {
|
if PacketOffset > 0 {
|
||||||
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
|
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
|
||||||
} else {
|
} else {
|
||||||
newPacket.Advance(-s.frontHeadroom)
|
newPacket.Advance(-s.frontHeadroom)
|
||||||
}
|
}
|
||||||
|
@ -586,7 +586,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error
|
||||||
sourceAddress := ipHdr.SourceAddr()
|
sourceAddress := ipHdr.SourceAddr()
|
||||||
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
|
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
|
||||||
ipHdr.SetDestinationAddr(sourceAddress)
|
ipHdr.SetDestinationAddr(sourceAddress)
|
||||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, checksum.Checksum(icmpHdr.Payload(), 0)))
|
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
|
||||||
ipHdr.SetChecksum(0)
|
ipHdr.SetChecksum(0)
|
||||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||||
return nil
|
return nil
|
||||||
|
@ -684,7 +684,7 @@ func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) e
|
||||||
}))
|
}))
|
||||||
copy(icmpHdr.Payload(), payload)
|
copy(icmpHdr.Payload(), payload)
|
||||||
if PacketOffset > 0 {
|
if PacketOffset > 0 {
|
||||||
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
|
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
|
||||||
} else {
|
} else {
|
||||||
newPacket.Advance(-s.frontHeadroom)
|
newPacket.Advance(-s.frontHeadroom)
|
||||||
}
|
}
|
||||||
|
@ -724,7 +724,7 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
|
||||||
ipHdr.SetChecksum(0)
|
ipHdr.SetChecksum(0)
|
||||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||||
if PacketOffset > 0 {
|
if PacketOffset > 0 {
|
||||||
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
|
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||||
} else {
|
} else {
|
||||||
newPacket.Advance(-w.frontHeadroom)
|
newPacket.Advance(-w.frontHeadroom)
|
||||||
}
|
}
|
||||||
|
@ -763,7 +763,7 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
|
||||||
udpHdr.SetChecksum(0)
|
udpHdr.SetChecksum(0)
|
||||||
}
|
}
|
||||||
if PacketOffset > 0 {
|
if PacketOffset > 0 {
|
||||||
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
|
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
|
||||||
} else {
|
} else {
|
||||||
newPacket.Advance(-w.frontHeadroom)
|
newPacket.Advance(-w.frontHeadroom)
|
||||||
}
|
}
|
||||||
|
|
34
stack_system_packet.go
Normal file
34
stack_system_packet.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||||
|
)
|
||||||
|
|
||||||
|
func PacketIPVersion(packet []byte) int {
|
||||||
|
return header.IPVersion(packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func PacketFillHeader(packet []byte, ipVersion int) {
|
||||||
|
if PacketOffset > 0 {
|
||||||
|
switch ipVersion {
|
||||||
|
case header.IPv4Version:
|
||||||
|
packet[3] = syscall.AF_INET
|
||||||
|
case header.IPv6Version:
|
||||||
|
packet[3] = syscall.AF_INET6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func PacketDestination(packet []byte) netip.Addr {
|
||||||
|
switch ipVersion := header.IPVersion(packet); ipVersion {
|
||||||
|
case header.IPv4Version:
|
||||||
|
return header.IPv4(packet).DestinationAddr()
|
||||||
|
case header.IPv6Version:
|
||||||
|
return header.IPv6(packet).DestinationAddr()
|
||||||
|
default:
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
}
|
27
tun.go
27
tun.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common/control"
|
||||||
F "github.com/sagernet/sing/common/format"
|
F "github.com/sagernet/sing/common/format"
|
||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
@ -24,8 +25,10 @@ type Handler interface {
|
||||||
type Tun interface {
|
type Tun interface {
|
||||||
io.ReadWriter
|
io.ReadWriter
|
||||||
N.VectorisedWriter
|
N.VectorisedWriter
|
||||||
|
Name() (string, error)
|
||||||
Start() error
|
Start() error
|
||||||
Close() error
|
Close() error
|
||||||
|
UpdateRouteOptions(tunOptions Options) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type WinTun interface {
|
type WinTun interface {
|
||||||
|
@ -38,7 +41,7 @@ type LinuxTUN interface {
|
||||||
N.FrontHeadroom
|
N.FrontHeadroom
|
||||||
BatchSize() int
|
BatchSize() int
|
||||||
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
|
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
|
||||||
BatchWrite(buffers [][]byte, offset int) error
|
BatchWrite(buffers [][]byte, offset int) (n int, err error)
|
||||||
TXChecksumOffload() bool
|
TXChecksumOffload() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,6 +57,7 @@ type Options struct {
|
||||||
MTU uint32
|
MTU uint32
|
||||||
GSO bool
|
GSO bool
|
||||||
AutoRoute bool
|
AutoRoute bool
|
||||||
|
InterfaceScope bool
|
||||||
Inet4Gateway netip.Addr
|
Inet4Gateway netip.Addr
|
||||||
Inet6Gateway netip.Addr
|
Inet6Gateway netip.Addr
|
||||||
DNSServers []netip.Addr
|
DNSServers []netip.Addr
|
||||||
|
@ -74,6 +78,7 @@ type Options struct {
|
||||||
IncludeAndroidUser []int
|
IncludeAndroidUser []int
|
||||||
IncludePackage []string
|
IncludePackage []string
|
||||||
ExcludePackage []string
|
ExcludePackage []string
|
||||||
|
InterfaceFinder control.InterfaceFinder
|
||||||
InterfaceMonitor DefaultInterfaceMonitor
|
InterfaceMonitor DefaultInterfaceMonitor
|
||||||
FileDescriptor int
|
FileDescriptor int
|
||||||
Logger logger.Logger
|
Logger logger.Logger
|
||||||
|
@ -99,10 +104,12 @@ func (o *Options) Inet4GatewayAddr() netip.Addr {
|
||||||
case "darwin":
|
case "darwin":
|
||||||
return o.Inet4Address[0].Addr()
|
return o.Inet4Address[0].Addr()
|
||||||
default:
|
default:
|
||||||
if HasNextAddress(o.Inet4Address[0], 1) {
|
if !o.InterfaceScope {
|
||||||
return o.Inet4Address[0].Addr().Next()
|
if HasNextAddress(o.Inet4Address[0], 1) {
|
||||||
} else {
|
return o.Inet4Address[0].Addr().Next()
|
||||||
return o.Inet4Address[0].Addr()
|
} else {
|
||||||
|
return o.Inet4Address[0].Addr()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,10 +130,12 @@ func (o *Options) Inet6GatewayAddr() netip.Addr {
|
||||||
case "darwin":
|
case "darwin":
|
||||||
return o.Inet6Address[0].Addr()
|
return o.Inet6Address[0].Addr()
|
||||||
default:
|
default:
|
||||||
if HasNextAddress(o.Inet6Address[0], 1) {
|
if !o.InterfaceScope {
|
||||||
return o.Inet6Address[0].Addr().Next()
|
if HasNextAddress(o.Inet6Address[0], 1) {
|
||||||
} else {
|
return o.Inet6Address[0].Addr().Next()
|
||||||
return o.Inet6Address[0].Addr()
|
} else {
|
||||||
|
return o.Inet6Address[0].Addr()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
@ -28,7 +29,15 @@ type NativeTun struct {
|
||||||
options Options
|
options Options
|
||||||
inet4Address [4]byte
|
inet4Address [4]byte
|
||||||
inet6Address [16]byte
|
inet6Address [16]byte
|
||||||
routerSet bool
|
routeSet bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) Name() (string, error) {
|
||||||
|
return unix.GetsockoptString(
|
||||||
|
int(t.tunFile.Fd()),
|
||||||
|
2, /* #define SYSPROTO_CONTROL 2 */
|
||||||
|
2, /* #define UTUN_OPT_IFNAME 2 */
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(options Options) (Tun, error) {
|
func New(options Options) (Tun, error) {
|
||||||
|
@ -96,9 +105,10 @@ var (
|
||||||
|
|
||||||
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
|
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
|
||||||
var packetHeader []byte
|
var packetHeader []byte
|
||||||
if buffers[0].Byte(0)>>4 == 4 {
|
switch header.IPVersion(buffers[0].Bytes()) {
|
||||||
|
case header.IPv4Version:
|
||||||
packetHeader = packetHeader4[:]
|
packetHeader = packetHeader4[:]
|
||||||
} else {
|
case header.IPv6Version:
|
||||||
packetHeader = packetHeader6[:]
|
packetHeader = packetHeader6[:]
|
||||||
}
|
}
|
||||||
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
|
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
|
||||||
|
@ -248,44 +258,63 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error {
|
||||||
|
err := t.unsetRoutes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.options = tunOptions
|
||||||
|
return t.setRoutes()
|
||||||
|
}
|
||||||
|
|
||||||
func (t *NativeTun) setRoutes() error {
|
func (t *NativeTun) setRoutes() error {
|
||||||
if t.options.AutoRoute && t.options.FileDescriptor == 0 {
|
if t.options.FileDescriptor == 0 {
|
||||||
routeRanges, err := t.options.BuildAutoRouteRanges(false)
|
routeRanges, err := t.options.BuildAutoRouteRanges(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
|
if len(routeRanges) > 0 {
|
||||||
for _, destination := range routeRanges {
|
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
|
||||||
var gateway netip.Addr
|
for _, destination := range routeRanges {
|
||||||
if destination.Addr().Is4() {
|
var gateway netip.Addr
|
||||||
gateway = gateway4
|
if destination.Addr().Is4() {
|
||||||
} else {
|
gateway = gateway4
|
||||||
gateway = gateway6
|
|
||||||
}
|
|
||||||
err = execRoute(unix.RTM_ADD, destination, gateway)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
err = execRoute(unix.RTM_DELETE, destination, gateway)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "remove existing route: ", destination)
|
|
||||||
}
|
|
||||||
err = execRoute(unix.RTM_ADD, destination, gateway)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "re-add route: ", destination)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return E.Cause(err, "add route: ", destination)
|
gateway = gateway6
|
||||||
|
}
|
||||||
|
var interfaceIndex int
|
||||||
|
if t.options.InterfaceScope {
|
||||||
|
iff, err := t.options.InterfaceFinder.ByName(t.options.Name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
interfaceIndex = iff.Index
|
||||||
|
}
|
||||||
|
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, unix.EEXIST) {
|
||||||
|
err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "remove existing route: ", destination)
|
||||||
|
}
|
||||||
|
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "re-add route: ", destination)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return E.Cause(err, "add route: ", destination)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
flushDNSCache()
|
||||||
|
t.routeSet = true
|
||||||
}
|
}
|
||||||
flushDNSCache()
|
|
||||||
t.routerSet = true
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NativeTun) unsetRoutes() error {
|
func (t *NativeTun) unsetRoutes() error {
|
||||||
if !t.routerSet {
|
if !t.routeSet {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
routeRanges, err := t.options.BuildAutoRouteRanges(false)
|
routeRanges, err := t.options.BuildAutoRouteRanges(false)
|
||||||
|
@ -300,7 +329,7 @@ func (t *NativeTun) unsetRoutes() error {
|
||||||
} else {
|
} else {
|
||||||
gateway = gateway6
|
gateway = gateway6
|
||||||
}
|
}
|
||||||
err = execRoute(unix.RTM_DELETE, destination, gateway)
|
err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = E.Errors(err, E.Cause(err, "delete route: ", destination))
|
err = E.Errors(err, E.Cause(err, "delete route: ", destination))
|
||||||
}
|
}
|
||||||
|
@ -317,7 +346,7 @@ func useSocket(domain, typ, proto int, block func(socketFd int) error) error {
|
||||||
return block(socketFd)
|
return block(socketFd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error {
|
func execRoute(rtmType int, interfaceScope bool, interfaceIndex int, destination netip.Prefix, gateway netip.Addr) error {
|
||||||
routeMessage := route.RouteMessage{
|
routeMessage := route.RouteMessage{
|
||||||
Type: rtmType,
|
Type: rtmType,
|
||||||
Version: unix.RTM_VERSION,
|
Version: unix.RTM_VERSION,
|
||||||
|
@ -326,6 +355,10 @@ func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error
|
||||||
}
|
}
|
||||||
if rtmType == unix.RTM_ADD {
|
if rtmType == unix.RTM_ADD {
|
||||||
routeMessage.Flags |= unix.RTF_UP
|
routeMessage.Flags |= unix.RTF_UP
|
||||||
|
if interfaceScope {
|
||||||
|
routeMessage.Flags |= unix.RTF_IFSCOPE
|
||||||
|
routeMessage.Index = interfaceIndex
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if gateway.Is4() {
|
if gateway.Is4() {
|
||||||
routeMessage.Addrs = []route.Addr{
|
routeMessage.Addrs = []route.Addr{
|
||||||
|
|
426
tun_linux.go
426
tun_linux.go
|
@ -2,6 +2,7 @@ package tun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -13,6 +14,8 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/sagernet/netlink"
|
"github.com/sagernet/netlink"
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
@ -35,13 +38,15 @@ type NativeTun struct {
|
||||||
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
|
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
|
||||||
options Options
|
options Options
|
||||||
ruleIndex6 []int
|
ruleIndex6 []int
|
||||||
gsoEnabled bool
|
readAccess sync.Mutex
|
||||||
gsoBuffer []byte
|
writeAccess sync.Mutex
|
||||||
|
vnetHdr bool
|
||||||
|
writeBuffer []byte
|
||||||
gsoToWrite []int
|
gsoToWrite []int
|
||||||
gsoReadAccess sync.Mutex
|
tcpGROTable *tcpGROTable
|
||||||
tcpGROAccess sync.Mutex
|
udpGroAccess sync.Mutex
|
||||||
tcp4GROTable *tcpGROTable
|
udpGROTable *udpGROTable
|
||||||
tcp6GROTable *tcpGROTable
|
gro groDisablementFlags
|
||||||
txChecksumOffload bool
|
txChecksumOffload bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,105 +85,6 @@ func New(options Options) (Tun, error) {
|
||||||
return nativeTun, nil
|
return nativeTun, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NativeTun) FrontHeadroom() int {
|
|
||||||
if t.gsoEnabled {
|
|
||||||
return virtioNetHdrLen
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *NativeTun) Read(p []byte) (n int, err error) {
|
|
||||||
if t.gsoEnabled {
|
|
||||||
n, err = t.tunFile.Read(t.gsoBuffer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var sizes [1]int
|
|
||||||
n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n = sizes[0]
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
return t.tunFile.Read(p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *NativeTun) Write(p []byte) (n int, err error) {
|
|
||||||
if t.gsoEnabled {
|
|
||||||
err = t.BatchWrite([][]byte{p}, virtioNetHdrLen)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n = len(p)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return t.tunFile.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
|
|
||||||
if t.gsoEnabled {
|
|
||||||
n := buf.LenMulti(buffers)
|
|
||||||
buffer := buf.NewSize(virtioNetHdrLen + n)
|
|
||||||
buffer.Truncate(virtioNetHdrLen)
|
|
||||||
buf.CopyMulti(buffer.Extend(n), buffers)
|
|
||||||
_, err := t.tunFile.Write(buffer.Bytes())
|
|
||||||
buffer.Release()
|
|
||||||
return err
|
|
||||||
} else {
|
|
||||||
return t.tunWriter.WriteVectorised(buffers)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *NativeTun) BatchSize() int {
|
|
||||||
if !t.gsoEnabled {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
/* // Not works on some devices: https://github.com/SagerNet/sing-box/issues/1605
|
|
||||||
batchSize := int(gsoMaxSize/t.options.MTU) * 2
|
|
||||||
if batchSize > idealBatchSize {
|
|
||||||
batchSize = idealBatchSize
|
|
||||||
}
|
|
||||||
return batchSize*/
|
|
||||||
return idealBatchSize
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
|
|
||||||
t.gsoReadAccess.Lock()
|
|
||||||
defer t.gsoReadAccess.Unlock()
|
|
||||||
n, err = t.tunFile.Read(t.gsoBuffer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
|
|
||||||
t.tcpGROAccess.Lock()
|
|
||||||
defer func() {
|
|
||||||
t.tcp4GROTable.reset()
|
|
||||||
t.tcp6GROTable.reset()
|
|
||||||
t.tcpGROAccess.Unlock()
|
|
||||||
}()
|
|
||||||
t.gsoToWrite = t.gsoToWrite[:0]
|
|
||||||
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
offset -= virtioNetHdrLen
|
|
||||||
for _, bufferIndex := range t.gsoToWrite {
|
|
||||||
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var controlPath string
|
var controlPath string
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -196,29 +102,26 @@ func open(name string, vnetHdr bool) (int, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
|
ifr, err := unix.NewIfreq(name)
|
||||||
var ifr struct {
|
if err != nil {
|
||||||
name [16]byte
|
unix.Close(fd)
|
||||||
flags uint16
|
return 0, err
|
||||||
_ [22]byte
|
|
||||||
}
|
}
|
||||||
|
flags := unix.IFF_TUN | unix.IFF_NO_PI
|
||||||
copy(ifr.name[:], name)
|
|
||||||
ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI
|
|
||||||
if vnetHdr {
|
if vnetHdr {
|
||||||
ifr.flags |= unix.IFF_VNET_HDR
|
flags |= unix.IFF_VNET_HDR
|
||||||
}
|
}
|
||||||
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.TUNSETIFF, uintptr(unsafe.Pointer(&ifr)))
|
ifr.SetUint16(uint16(flags))
|
||||||
if errno != 0 {
|
err = unix.IoctlIfreq(fd, unix.TUNSETIFF, ifr)
|
||||||
|
if err != nil {
|
||||||
unix.Close(fd)
|
unix.Close(fd)
|
||||||
return -1, errno
|
return 0, err
|
||||||
}
|
}
|
||||||
|
err = unix.SetNonblock(fd, true)
|
||||||
if err = unix.SetNonblock(fd, true); err != nil {
|
if err != nil {
|
||||||
unix.Close(fd)
|
unix.Close(fd)
|
||||||
return -1, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return fd, nil
|
return fd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,22 +153,10 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.options.GSO {
|
if t.options.GSO {
|
||||||
var vnetHdrEnabled bool
|
err = t.enableGSO()
|
||||||
vnetHdrEnabled, err = checkVNETHDREnabled(t.tunFd, t.options.Name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "enable offload: check IFF_VNET_HDR enabled")
|
t.options.Logger.Warn(err)
|
||||||
}
|
}
|
||||||
if !vnetHdrEnabled {
|
|
||||||
return E.Cause(err, "enable offload: IFF_VNET_HDR not enabled")
|
|
||||||
}
|
|
||||||
err = setTCPOffload(t.tunFd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
t.gsoEnabled = true
|
|
||||||
t.gsoBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize))
|
|
||||||
t.tcp4GROTable = newTCPGROTable()
|
|
||||||
t.tcp6GROTable = newTCPGROTable()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var rxChecksumOffload bool
|
var rxChecksumOffload bool
|
||||||
|
@ -280,7 +171,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err == nil && !txChecksumOffload {
|
if !txChecksumOffload {
|
||||||
err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM)
|
err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -292,6 +183,83 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) enableGSO() error {
|
||||||
|
vnetHdrEnabled, err := checkVNETHDREnabled(t.tunFd, t.options.Name)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "enable offload: check IFF_VNET_HDR enabled")
|
||||||
|
}
|
||||||
|
if !vnetHdrEnabled {
|
||||||
|
return E.Cause(err, "enable offload: IFF_VNET_HDR not enabled")
|
||||||
|
}
|
||||||
|
err = setTCPOffload(t.tunFd)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "enable TCP offload")
|
||||||
|
}
|
||||||
|
t.vnetHdr = true
|
||||||
|
t.writeBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize))
|
||||||
|
t.tcpGROTable = newTCPGROTable()
|
||||||
|
t.udpGROTable = newUDPGROTable()
|
||||||
|
err = setUDPOffload(t.tunFd)
|
||||||
|
if err != nil {
|
||||||
|
t.gro.disableUDPGRO()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) probeTCPGRO() error {
|
||||||
|
ipPort := netip.AddrPortFrom(t.options.Inet4Address[0].Addr(), 0)
|
||||||
|
fingerprint := []byte("sing-tun-probe-tun-gro")
|
||||||
|
segmentSize := len(fingerprint)
|
||||||
|
iphLen := 20
|
||||||
|
tcphLen := 20
|
||||||
|
totalLen := iphLen + tcphLen + segmentSize
|
||||||
|
bufs := make([][]byte, 2)
|
||||||
|
for i := range bufs {
|
||||||
|
bufs[i] = make([]byte, virtioNetHdrLen+totalLen, virtioNetHdrLen+(totalLen*2))
|
||||||
|
ipv4H := header.IPv4(bufs[i][virtioNetHdrLen:])
|
||||||
|
ipv4H.Encode(&header.IPv4Fields{
|
||||||
|
SrcAddr: ipPort.Addr(),
|
||||||
|
DstAddr: ipPort.Addr(),
|
||||||
|
Protocol: unix.IPPROTO_TCP,
|
||||||
|
// Use a zero value TTL as best effort means to reduce chance of
|
||||||
|
// probe packet leaking further than it needs to.
|
||||||
|
TTL: 0,
|
||||||
|
TotalLength: uint16(totalLen),
|
||||||
|
})
|
||||||
|
tcpH := header.TCP(bufs[i][virtioNetHdrLen+iphLen:])
|
||||||
|
tcpH.Encode(&header.TCPFields{
|
||||||
|
SrcPort: ipPort.Port(),
|
||||||
|
DstPort: ipPort.Port(),
|
||||||
|
SeqNum: 1 + uint32(i*segmentSize),
|
||||||
|
AckNum: 1,
|
||||||
|
DataOffset: 20,
|
||||||
|
Flags: header.TCPFlagAck,
|
||||||
|
WindowSize: 3000,
|
||||||
|
})
|
||||||
|
copy(bufs[i][virtioNetHdrLen+iphLen+tcphLen:], fingerprint)
|
||||||
|
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
|
||||||
|
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddressSlice(), ipv4H.DestinationAddressSlice(), uint16(tcphLen+segmentSize))
|
||||||
|
pseudoCsum = checksum.Checksum(bufs[i][virtioNetHdrLen+iphLen+tcphLen:], pseudoCsum)
|
||||||
|
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
||||||
|
}
|
||||||
|
_, err := t.BatchWrite(bufs, virtioNetHdrLen)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) Name() (string, error) {
|
||||||
|
var ifr [unix.IFNAMSIZ + 64]byte
|
||||||
|
_, _, errno := unix.Syscall(
|
||||||
|
unix.SYS_IOCTL,
|
||||||
|
uintptr(t.tunFd),
|
||||||
|
uintptr(unix.TUNGETIFF),
|
||||||
|
uintptr(unsafe.Pointer(&ifr[0])),
|
||||||
|
)
|
||||||
|
if errno != 0 {
|
||||||
|
return "", os.NewSyscallError("ioctl TUNGETIFF", errno)
|
||||||
|
}
|
||||||
|
return unix.ByteSliceToString(ifr[:]), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *NativeTun) Start() error {
|
func (t *NativeTun) Start() error {
|
||||||
if t.options.FileDescriptor != 0 {
|
if t.options.FileDescriptor != 0 {
|
||||||
return nil
|
return nil
|
||||||
|
@ -307,6 +275,15 @@ func (t *NativeTun) Start() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if t.vnetHdr && len(t.options.Inet4Address) > 0 {
|
||||||
|
err = t.probeTCPGRO()
|
||||||
|
if err != nil {
|
||||||
|
t.gro.disableTCPGRO()
|
||||||
|
t.gro.disableUDPGRO()
|
||||||
|
t.options.Logger.Warn(E.Cause(err, "disabled TUN TCP & UDP GRO due to GRO probe error"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if t.options.IPRoute2TableIndex == 0 {
|
if t.options.IPRoute2TableIndex == 0 {
|
||||||
for {
|
for {
|
||||||
t.options.IPRoute2TableIndex = int(rand.Uint32())
|
t.options.IPRoute2TableIndex = int(rand.Uint32())
|
||||||
|
@ -348,6 +325,164 @@ func (t *NativeTun) Close() error {
|
||||||
return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile)))
|
return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) Read(p []byte) (n int, err error) {
|
||||||
|
if t.vnetHdr {
|
||||||
|
n, err = t.tunFile.Read(t.writeBuffer)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, syscall.EBADFD) {
|
||||||
|
err = os.ErrClosed
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var sizes [1]int
|
||||||
|
n, err = handleVirtioRead(t.writeBuffer[:n], [][]byte{p}, sizes[:], 0)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n = sizes[0]
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
return t.tunFile.Read(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
|
||||||
|
// each buffer. It mutates sizes to reflect the size of each element of bufs,
|
||||||
|
// and returns the number of packets read.
|
||||||
|
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||||
|
var hdr virtioNetHdr
|
||||||
|
err := hdr.decode(in)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
in = in[virtioNetHdrLen:]
|
||||||
|
|
||||||
|
options, err := hdr.toGSOOptions()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't trust HdrLen from the kernel as it can be equal to the length
|
||||||
|
// of the entire first packet when the kernel is handling it as part of a
|
||||||
|
// FORWARD path. Instead, parse the transport header length and add it onto
|
||||||
|
// CsumStart, which is synonymous for IP header length.
|
||||||
|
if options.GSOType == GSOUDPL4 {
|
||||||
|
options.HdrLen = options.CsumStart + 8
|
||||||
|
} else if options.GSOType != GSONone {
|
||||||
|
if len(in) <= int(options.CsumStart+12) {
|
||||||
|
return 0, errors.New("packet is too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpHLen := uint16(in[options.CsumStart+12] >> 4 * 4)
|
||||||
|
if tcpHLen < 20 || tcpHLen > 60 {
|
||||||
|
// A TCP header must be between 20 and 60 bytes in length.
|
||||||
|
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
||||||
|
}
|
||||||
|
options.HdrLen = options.CsumStart + tcpHLen
|
||||||
|
}
|
||||||
|
|
||||||
|
return GSOSplit(in, options, bufs, sizes, offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) Write(p []byte) (n int, err error) {
|
||||||
|
if t.vnetHdr {
|
||||||
|
buffer := buf.Get(virtioNetHdrLen + len(p))
|
||||||
|
copy(buffer[virtioNetHdrLen:], p)
|
||||||
|
_, err = t.BatchWrite([][]byte{buffer}, virtioNetHdrLen)
|
||||||
|
buf.Put(buffer)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n = len(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return t.tunFile.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
|
||||||
|
if t.vnetHdr {
|
||||||
|
n := buf.LenMulti(buffers)
|
||||||
|
buffer := buf.NewSize(virtioNetHdrLen + n)
|
||||||
|
buffer.Truncate(virtioNetHdrLen)
|
||||||
|
buf.CopyMulti(buffer.Extend(n), buffers)
|
||||||
|
_, err := t.tunFile.Write(buffer.Bytes())
|
||||||
|
buffer.Release()
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
return t.tunWriter.WriteVectorised(buffers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) FrontHeadroom() int {
|
||||||
|
if t.vnetHdr {
|
||||||
|
return virtioNetHdrLen
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) BatchSize() int {
|
||||||
|
if !t.vnetHdr {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
/* // Not works on some devices: https://github.com/SagerNet/sing-box/issues/1605
|
||||||
|
batchSize := int(gsoMaxSize/t.options.MTU) * 2
|
||||||
|
if batchSize > idealBatchSize {
|
||||||
|
batchSize = idealBatchSize
|
||||||
|
}
|
||||||
|
return batchSize*/
|
||||||
|
return idealBatchSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
|
||||||
|
t.readAccess.Lock()
|
||||||
|
defer t.readAccess.Unlock()
|
||||||
|
n, err = t.tunFile.Read(t.writeBuffer)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return handleVirtioRead(t.writeBuffer[:n], buffers, readN, offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) (int, error) {
|
||||||
|
t.writeAccess.Lock()
|
||||||
|
defer func() {
|
||||||
|
t.tcpGROTable.reset()
|
||||||
|
t.udpGROTable.reset()
|
||||||
|
t.writeAccess.Unlock()
|
||||||
|
}()
|
||||||
|
var (
|
||||||
|
errs error
|
||||||
|
total int
|
||||||
|
)
|
||||||
|
t.gsoToWrite = t.gsoToWrite[:0]
|
||||||
|
if t.vnetHdr {
|
||||||
|
err := handleGRO(buffers, offset, t.tcpGROTable, t.udpGROTable, t.gro, &t.gsoToWrite)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
offset -= virtioNetHdrLen
|
||||||
|
} else {
|
||||||
|
for i := range buffers {
|
||||||
|
t.gsoToWrite = append(t.gsoToWrite, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, toWrite := range t.gsoToWrite {
|
||||||
|
n, err := t.tunFile.Write(buffers[toWrite][offset:])
|
||||||
|
if errors.Is(err, syscall.EBADFD) {
|
||||||
|
return total, os.ErrClosed
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
errs = errors.Join(errs, err)
|
||||||
|
} else {
|
||||||
|
total += n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total, errs
|
||||||
|
}
|
||||||
|
|
||||||
func (t *NativeTun) TXChecksumOffload() bool {
|
func (t *NativeTun) TXChecksumOffload() bool {
|
||||||
return t.txChecksumOffload
|
return t.txChecksumOffload
|
||||||
}
|
}
|
||||||
|
@ -359,6 +494,25 @@ func prefixToIPNet(prefix netip.Prefix) *net.IPNet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error {
|
||||||
|
if t.options.FileDescriptor > 0 {
|
||||||
|
return nil
|
||||||
|
} else if !t.options.AutoRoute {
|
||||||
|
t.options = tunOptions
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tunLink, err := netlink.LinkByName(t.options.Name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = t.unsetRoute0(tunLink)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.options = tunOptions
|
||||||
|
return t.setRoute(tunLink)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *NativeTun) routes(tunLink netlink.Link) ([]netlink.Route, error) {
|
func (t *NativeTun) routes(tunLink netlink.Link) ([]netlink.Route, error) {
|
||||||
routeRanges, err := t.options.BuildAutoRouteRanges(false)
|
routeRanges, err := t.options.BuildAutoRouteRanges(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -12,6 +12,12 @@ import (
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TODO: support TSO with ECN bits
|
||||||
|
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||||
|
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
|
||||||
|
)
|
||||||
|
|
||||||
func checkVNETHDREnabled(fd int, name string) (bool, error) {
|
func checkVNETHDREnabled(fd int, name string) (bool, error) {
|
||||||
ifr, err := unix.NewIfreq(name)
|
ifr, err := unix.NewIfreq(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -25,17 +31,17 @@ func checkVNETHDREnabled(fd int, name string) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func setTCPOffload(fd int) error {
|
func setTCPOffload(fd int) error {
|
||||||
const (
|
err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads)
|
||||||
// TODO: support TSO with ECN bits
|
|
||||||
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
|
||||||
)
|
|
||||||
err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunOffloads)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload")
|
return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setUDPOffload(fd int) error {
|
||||||
|
return unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads)
|
||||||
|
}
|
||||||
|
|
||||||
type ifreqData struct {
|
type ifreqData struct {
|
||||||
ifrName [unix.IFNAMSIZ]byte
|
ifrName [unix.IFNAMSIZ]byte
|
||||||
ifrData uintptr
|
ifrData uintptr
|
||||||
|
|
|
@ -10,11 +10,12 @@ import (
|
||||||
var _ GVisorTun = (*NativeTun)(nil)
|
var _ GVisorTun = (*NativeTun)(nil)
|
||||||
|
|
||||||
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
|
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
|
||||||
if t.gsoEnabled {
|
if t.vnetHdr {
|
||||||
return fdbased.New(&fdbased.Options{
|
return fdbased.New(&fdbased.Options{
|
||||||
FDs: []int{t.tunFd},
|
FDs: []int{t.tunFd},
|
||||||
MTU: t.options.MTU,
|
MTU: t.options.MTU,
|
||||||
GSOMaxSize: gsoMaxSize,
|
GSOMaxSize: gsoMaxSize,
|
||||||
|
GRO: true,
|
||||||
RXChecksumOffload: true,
|
RXChecksumOffload: true,
|
||||||
TXChecksumOffload: t.txChecksumOffload,
|
TXChecksumOffload: t.txChecksumOffload,
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,768 +0,0 @@
|
||||||
//go:build linux
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package tun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
gsoMaxSize = 65536
|
|
||||||
tcpFlagsOffset = 13
|
|
||||||
idealBatchSize = 128
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
tcpFlagFIN uint8 = 0x01
|
|
||||||
tcpFlagPSH uint8 = 0x08
|
|
||||||
tcpFlagACK uint8 = 0x10
|
|
||||||
)
|
|
||||||
|
|
||||||
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
|
|
||||||
// kernel symbol is virtio_net_hdr.
|
|
||||||
type virtioNetHdr struct {
|
|
||||||
flags uint8
|
|
||||||
gsoType uint8
|
|
||||||
hdrLen uint16
|
|
||||||
gsoSize uint16
|
|
||||||
csumStart uint16
|
|
||||||
csumOffset uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *virtioNetHdr) decode(b []byte) error {
|
|
||||||
if len(b) < virtioNetHdrLen {
|
|
||||||
return io.ErrShortBuffer
|
|
||||||
}
|
|
||||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *virtioNetHdr) encode(b []byte) error {
|
|
||||||
if len(b) < virtioNetHdrLen {
|
|
||||||
return io.ErrShortBuffer
|
|
||||||
}
|
|
||||||
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
|
|
||||||
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
|
|
||||||
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
|
|
||||||
)
|
|
||||||
|
|
||||||
// flowKey represents the key for a flow.
|
|
||||||
type flowKey struct {
|
|
||||||
srcAddr, dstAddr [16]byte
|
|
||||||
srcPort, dstPort uint16
|
|
||||||
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
|
|
||||||
type tcpGROTable struct {
|
|
||||||
itemsByFlow map[flowKey][]tcpGROItem
|
|
||||||
itemsPool [][]tcpGROItem
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTCPGROTable() *tcpGROTable {
|
|
||||||
t := &tcpGROTable{
|
|
||||||
itemsByFlow: make(map[flowKey][]tcpGROItem, idealBatchSize),
|
|
||||||
itemsPool: make([][]tcpGROItem, idealBatchSize),
|
|
||||||
}
|
|
||||||
for i := range t.itemsPool {
|
|
||||||
t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize)
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
|
|
||||||
key := flowKey{}
|
|
||||||
addrSize := dstAddr - srcAddr
|
|
||||||
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
|
|
||||||
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
|
|
||||||
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
|
|
||||||
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
|
|
||||||
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
|
||||||
// returning the packets found for the flow, or inserting a new one if none
|
|
||||||
// is found.
|
|
||||||
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
|
|
||||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
|
||||||
items, ok := t.itemsByFlow[key]
|
|
||||||
if ok {
|
|
||||||
return items, ok
|
|
||||||
}
|
|
||||||
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
|
||||||
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert an item in the table for the provided packet and packet metadata.
|
|
||||||
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
|
|
||||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
|
||||||
item := tcpGROItem{
|
|
||||||
key: key,
|
|
||||||
bufsIndex: uint16(bufsIndex),
|
|
||||||
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
|
|
||||||
iphLen: uint8(tcphOffset),
|
|
||||||
tcphLen: uint8(tcphLen),
|
|
||||||
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
|
|
||||||
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
|
|
||||||
}
|
|
||||||
items, ok := t.itemsByFlow[key]
|
|
||||||
if !ok {
|
|
||||||
items = t.newItems()
|
|
||||||
}
|
|
||||||
items = append(items, item)
|
|
||||||
t.itemsByFlow[key] = items
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
|
|
||||||
items, _ := t.itemsByFlow[item.key]
|
|
||||||
items[i] = item
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
|
|
||||||
items, _ := t.itemsByFlow[key]
|
|
||||||
items = append(items[:i], items[i+1:]...)
|
|
||||||
t.itemsByFlow[key] = items
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
|
|
||||||
// of a GRO evaluation across a vector of packets.
|
|
||||||
type tcpGROItem struct {
|
|
||||||
key flowKey
|
|
||||||
sentSeq uint32 // the sequence number
|
|
||||||
bufsIndex uint16 // the index into the original bufs slice
|
|
||||||
numMerged uint16 // the number of packets merged into this item
|
|
||||||
gsoSize uint16 // payload size
|
|
||||||
iphLen uint8 // ip header len
|
|
||||||
tcphLen uint8 // tcp header len
|
|
||||||
pshSet bool // psh flag is set
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) newItems() []tcpGROItem {
|
|
||||||
var items []tcpGROItem
|
|
||||||
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
|
|
||||||
return items
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) reset() {
|
|
||||||
for k, items := range t.itemsByFlow {
|
|
||||||
items = items[:0]
|
|
||||||
t.itemsPool = append(t.itemsPool, items)
|
|
||||||
delete(t.itemsByFlow, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// canCoalesce represents the outcome of checking if two TCP packets are
|
|
||||||
// candidates for coalescing.
|
|
||||||
type canCoalesce int
|
|
||||||
|
|
||||||
const (
|
|
||||||
coalescePrepend canCoalesce = -1
|
|
||||||
coalesceUnavailable canCoalesce = 0
|
|
||||||
coalesceAppend canCoalesce = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
|
||||||
// described by item. This function makes considerations that match the kernel's
|
|
||||||
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
|
|
||||||
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
|
||||||
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
|
||||||
if tcphLen != item.tcphLen {
|
|
||||||
// cannot coalesce with unequal tcp options len
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if tcphLen > 20 {
|
|
||||||
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
|
|
||||||
// cannot coalesce with unequal tcp options
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if pkt[0]>>4 == 6 {
|
|
||||||
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
|
|
||||||
// cannot coalesce with unequal Traffic class values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if pkt[7] != pktTarget[7] {
|
|
||||||
// cannot coalesce with unequal Hop limit values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if pkt[1] != pktTarget[1] {
|
|
||||||
// cannot coalesce with unequal ToS values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if pkt[6]>>5 != pktTarget[6]>>5 {
|
|
||||||
// cannot coalesce with unequal DF or reserved bits. MF is checked
|
|
||||||
// further up the stack.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if pkt[8] != pktTarget[8] {
|
|
||||||
// cannot coalesce with unequal TTL values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// seq adjacency
|
|
||||||
lhsLen := item.gsoSize
|
|
||||||
lhsLen += item.numMerged * item.gsoSize
|
|
||||||
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
|
|
||||||
if item.pshSet {
|
|
||||||
// We cannot append to a segment that has the PSH flag set, PSH
|
|
||||||
// can only be set on the final segment in a reassembled group.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
|
|
||||||
// A smaller than gsoSize packet has been appended previously.
|
|
||||||
// Nothing can come after a smaller packet on the end.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if gsoSize > item.gsoSize {
|
|
||||||
// We cannot have a larger packet following a smaller one.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
return coalesceAppend
|
|
||||||
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
|
|
||||||
if pshSet {
|
|
||||||
// We cannot prepend with a segment that has the PSH flag set, PSH
|
|
||||||
// can only be set on the final segment in a reassembled group.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if gsoSize < item.gsoSize {
|
|
||||||
// We cannot have a larger packet following a smaller one.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if gsoSize > item.gsoSize && item.numMerged > 0 {
|
|
||||||
// There's at least one previous merge, and we're larger than all
|
|
||||||
// previous. This would put multiple smaller packets on the end.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
return coalescePrepend
|
|
||||||
}
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
|
|
||||||
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
|
|
||||||
srcAddrAt := ipv4SrcAddrOffset
|
|
||||||
addrSize := 4
|
|
||||||
if isV6 {
|
|
||||||
srcAddrAt = ipv6SrcAddrOffset
|
|
||||||
addrSize = 16
|
|
||||||
}
|
|
||||||
tcpTotalLen := uint16(len(pkt) - int(iphLen))
|
|
||||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
|
|
||||||
return ^checksumFold(pkt[iphLen:], tcpCSumNoFold) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int

const (
	// coalesceInsufficientCap: the head buffer lacks the capacity to hold the
	// merged packet without reallocating.
	coalesceInsufficientCap coalesceResult = iota
	// coalescePSHEnding: a prepend was refused because the packet being
	// prepended has PSH set.
	coalescePSHEnding
	// coalesceItemInvalidCSum: the packet already tracked by the item failed
	// TCP checksum validation.
	coalesceItemInvalidCSum
	// coalescePktInvalidCSum: the packet being merged failed TCP checksum
	// validation.
	coalescePktInvalidCSum
	// coalesceSuccess: the packets were merged.
	coalesceSuccess
)
|
|
||||||
|
|
||||||
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
|
|
||||||
// item, returning the outcome. This function may swap bufs elements in the
|
|
||||||
// event of a prepend as item's bufs index is already being tracked for writing
|
|
||||||
// to a Device.
|
|
||||||
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
|
|
||||||
var pktHead []byte // the packet that will end up at the front
|
|
||||||
headersLen := item.iphLen + item.tcphLen
|
|
||||||
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
|
|
||||||
|
|
||||||
// Copy data
|
|
||||||
if mode == coalescePrepend {
|
|
||||||
pktHead = pkt
|
|
||||||
if cap(pkt)-bufsOffset < coalescedLen {
|
|
||||||
// We don't want to allocate a new underlying array if capacity is
|
|
||||||
// too small.
|
|
||||||
return coalesceInsufficientCap
|
|
||||||
}
|
|
||||||
if pshSet {
|
|
||||||
return coalescePSHEnding
|
|
||||||
}
|
|
||||||
if item.numMerged == 0 {
|
|
||||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
|
||||||
return coalesceItemInvalidCSum
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
|
||||||
return coalescePktInvalidCSum
|
|
||||||
}
|
|
||||||
item.sentSeq = seq
|
|
||||||
extendBy := coalescedLen - len(pktHead)
|
|
||||||
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
|
|
||||||
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
|
|
||||||
// Flip the slice headers in bufs as part of prepend. The index of item
|
|
||||||
// is already being tracked for writing.
|
|
||||||
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
|
|
||||||
} else {
|
|
||||||
pktHead = bufs[item.bufsIndex][bufsOffset:]
|
|
||||||
if cap(pktHead)-bufsOffset < coalescedLen {
|
|
||||||
// We don't want to allocate a new underlying array if capacity is
|
|
||||||
// too small.
|
|
||||||
return coalesceInsufficientCap
|
|
||||||
}
|
|
||||||
if item.numMerged == 0 {
|
|
||||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
|
||||||
return coalesceItemInvalidCSum
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
|
||||||
return coalescePktInvalidCSum
|
|
||||||
}
|
|
||||||
if pshSet {
|
|
||||||
// We are appending a segment with PSH set.
|
|
||||||
item.pshSet = pshSet
|
|
||||||
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
|
|
||||||
}
|
|
||||||
extendBy := len(pkt) - int(headersLen)
|
|
||||||
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
|
|
||||||
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
|
|
||||||
}
|
|
||||||
|
|
||||||
if gsoSize > item.gsoSize {
|
|
||||||
item.gsoSize = gsoSize
|
|
||||||
}
|
|
||||||
|
|
||||||
item.numMerged++
|
|
||||||
return coalesceSuccess
|
|
||||||
}
|
|
||||||
|
|
||||||
// ipv4FlagMoreFragments is the More Fragments bit as it appears in byte 6 of
// the IPv4 header.
const ipv4FlagMoreFragments uint8 = 1 << 5

const (
	// Byte offsets of the source address within the IPv4 and IPv6 headers.
	ipv4SrcAddrOffset = 12
	ipv6SrcAddrOffset = 8

	// maxUint16 is the largest value representable in 16 bits.
	maxUint16 = 0xffff
)

// tcpGROResult describes what tcpGRO did with the packet it evaluated.
type tcpGROResult int

const (
	// tcpGROResultNoop: no action was taken.
	tcpGROResultNoop tcpGROResult = iota
	// tcpGROResultTableInsert: the packet was inserted into the table.
	tcpGROResultTableInsert
	// tcpGROResultCoalesced: the packet was coalesced with another packet in
	// the table.
	tcpGROResultCoalesced
)
|
|
||||||
|
|
||||||
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
|
|
||||||
// existing packets tracked in table. It returns a tcpGROResultNoop when no
|
|
||||||
// action was taken, tcpGROResultTableInsert when the evaluated packet was
|
|
||||||
// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
|
|
||||||
// coalesced with another packet in table.
|
|
||||||
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
|
|
||||||
pkt := bufs[pktI][offset:]
|
|
||||||
if len(pkt) > maxUint16 {
|
|
||||||
// A valid IPv4 or IPv6 packet will never exceed this.
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
iphLen := int((pkt[0] & 0x0F) * 4)
|
|
||||||
if isV6 {
|
|
||||||
iphLen = 40
|
|
||||||
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
|
|
||||||
if ipv6HPayloadLen != len(pkt)-iphLen {
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
|
|
||||||
if totalLen != len(pkt) {
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(pkt) < iphLen {
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
|
|
||||||
if tcphLen < 20 || tcphLen > 60 {
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
if len(pkt) < iphLen+tcphLen {
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
if !isV6 {
|
|
||||||
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
|
|
||||||
// no GRO support for fragmented segments for now
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tcpFlags := pkt[iphLen+tcpFlagsOffset]
|
|
||||||
var pshSet bool
|
|
||||||
// not a candidate if any non-ACK flags (except PSH+ACK) are set
|
|
||||||
if tcpFlags != tcpFlagACK {
|
|
||||||
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
pshSet = true
|
|
||||||
}
|
|
||||||
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
|
|
||||||
// not a candidate if payload len is 0
|
|
||||||
if gsoSize < 1 {
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
|
|
||||||
srcAddrOffset := ipv4SrcAddrOffset
|
|
||||||
addrLen := 4
|
|
||||||
if isV6 {
|
|
||||||
srcAddrOffset = ipv6SrcAddrOffset
|
|
||||||
addrLen = 16
|
|
||||||
}
|
|
||||||
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
|
||||||
if !existing {
|
|
||||||
return tcpGROResultNoop
|
|
||||||
}
|
|
||||||
for i := len(items) - 1; i >= 0; i-- {
|
|
||||||
// In the best case of packets arriving in order iterating in reverse is
|
|
||||||
// more efficient if there are multiple items for a given flow. This
|
|
||||||
// also enables a natural table.deleteAt() in the
|
|
||||||
// coalesceItemInvalidCSum case without the need for index tracking.
|
|
||||||
// This algorithm makes a best effort to coalesce in the event of
|
|
||||||
// unordered packets, where pkt may land anywhere in items from a
|
|
||||||
// sequence number perspective, however once an item is inserted into
|
|
||||||
// the table it is never compared across other items later.
|
|
||||||
item := items[i]
|
|
||||||
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
|
|
||||||
if can != coalesceUnavailable {
|
|
||||||
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
|
|
||||||
switch result {
|
|
||||||
case coalesceSuccess:
|
|
||||||
table.updateAt(item, i)
|
|
||||||
return tcpGROResultCoalesced
|
|
||||||
case coalesceItemInvalidCSum:
|
|
||||||
// delete the item with an invalid csum
|
|
||||||
table.deleteAt(item.key, i)
|
|
||||||
case coalescePktInvalidCSum:
|
|
||||||
// no point in inserting an item that we can't coalesce
|
|
||||||
return tcpGROResultNoop
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// failed to coalesce with any other packets; store the item in the flow
|
|
||||||
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
|
||||||
return tcpGROResultTableInsert
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTCP4NoIPOptions(b []byte) bool {
|
|
||||||
if len(b) < 40 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[0]>>4 != 4 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[0]&0x0F != 5 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[9] != unix.IPPROTO_TCP {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTCP6NoEH(b []byte) bool {
|
|
||||||
if len(b) < 60 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[0]>>4 != 6 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[6] != unix.IPPROTO_TCP {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyCoalesceAccounting updates bufs to account for coalescing based on the
|
|
||||||
// metadata found in table.
|
|
||||||
func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
|
|
||||||
for _, items := range table.itemsByFlow {
|
|
||||||
for _, item := range items {
|
|
||||||
if item.numMerged > 0 {
|
|
||||||
hdr := virtioNetHdr{
|
|
||||||
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
|
|
||||||
hdrLen: uint16(item.iphLen + item.tcphLen),
|
|
||||||
gsoSize: item.gsoSize,
|
|
||||||
csumStart: uint16(item.iphLen),
|
|
||||||
csumOffset: 16,
|
|
||||||
}
|
|
||||||
pkt := bufs[item.bufsIndex][offset:]
|
|
||||||
|
|
||||||
// Recalculate the total len (IPv4) or payload len (IPv6).
|
|
||||||
// Recalculate the (IPv4) header checksum.
|
|
||||||
if isV6 {
|
|
||||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
|
|
||||||
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
|
|
||||||
} else {
|
|
||||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
|
|
||||||
pkt[10], pkt[11] = 0, 0
|
|
||||||
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
|
|
||||||
iphCSum := ^checksumFold(pkt[:item.iphLen], 0) // compute IPv4 header checksum
|
|
||||||
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
|
|
||||||
}
|
|
||||||
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the pseudo header checksum and place it at the TCP
|
|
||||||
// checksum offset. Downstream checksum offloading will combine
|
|
||||||
// this with computation of the tcp header and payload checksum.
|
|
||||||
addrLen := 4
|
|
||||||
addrOffset := ipv4SrcAddrOffset
|
|
||||||
if isV6 {
|
|
||||||
addrLen = 16
|
|
||||||
addrOffset = ipv6SrcAddrOffset
|
|
||||||
}
|
|
||||||
srcAddrAt := offset + addrOffset
|
|
||||||
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
|
|
||||||
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
|
|
||||||
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
|
|
||||||
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksumFold([]byte{}, psum))
|
|
||||||
} else {
|
|
||||||
hdr := virtioNetHdr{}
|
|
||||||
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
|
|
||||||
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
|
|
||||||
// empty (but non-nil), and are passed in to save allocs as the caller may reset
|
|
||||||
// and recycle them across vectors of packets.
|
|
||||||
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
|
|
||||||
for i := range bufs {
|
|
||||||
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
|
|
||||||
return errors.New("invalid offset")
|
|
||||||
}
|
|
||||||
var result tcpGROResult
|
|
||||||
switch {
|
|
||||||
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
|
|
||||||
result = tcpGRO(bufs, offset, i, tcp4Table, false)
|
|
||||||
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
|
|
||||||
result = tcpGRO(bufs, offset, i, tcp6Table, true)
|
|
||||||
}
|
|
||||||
switch result {
|
|
||||||
case tcpGROResultNoop:
|
|
||||||
hdr := virtioNetHdr{}
|
|
||||||
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fallthrough
|
|
||||||
case tcpGROResultTableInsert:
|
|
||||||
*toWrite = append(*toWrite, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
|
|
||||||
err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
|
|
||||||
return E.Errors(err4, err6)
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpTSO splits packets from in into outBuffs, writing the size of each
|
|
||||||
// element into sizes. It returns the number of buffers populated, and/or an
|
|
||||||
// error.
|
|
||||||
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
|
|
||||||
iphLen := int(hdr.csumStart)
|
|
||||||
srcAddrOffset := ipv6SrcAddrOffset
|
|
||||||
addrLen := 16
|
|
||||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
|
||||||
in[10], in[11] = 0, 0 // clear ipv4 header checksum
|
|
||||||
srcAddrOffset = ipv4SrcAddrOffset
|
|
||||||
addrLen = 4
|
|
||||||
}
|
|
||||||
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
|
|
||||||
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
|
|
||||||
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
|
|
||||||
nextSegmentDataAt := int(hdr.hdrLen)
|
|
||||||
i := 0
|
|
||||||
for ; nextSegmentDataAt < len(in); i++ {
|
|
||||||
if i == len(outBuffs) {
|
|
||||||
return i - 1, ErrTooManySegments
|
|
||||||
}
|
|
||||||
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
|
|
||||||
if nextSegmentEnd > len(in) {
|
|
||||||
nextSegmentEnd = len(in)
|
|
||||||
}
|
|
||||||
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
|
|
||||||
totalLen := int(hdr.hdrLen) + segmentDataLen
|
|
||||||
sizes[i] = totalLen
|
|
||||||
out := outBuffs[i][outOffset:]
|
|
||||||
|
|
||||||
copy(out, in[:iphLen])
|
|
||||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
|
||||||
// For IPv4 we are responsible for incrementing the ID field,
|
|
||||||
// updating the total len field, and recalculating the header
|
|
||||||
// checksum.
|
|
||||||
if i > 0 {
|
|
||||||
id := binary.BigEndian.Uint16(out[4:])
|
|
||||||
id += uint16(i)
|
|
||||||
binary.BigEndian.PutUint16(out[4:], id)
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
|
|
||||||
ipv4CSum := ^checksumFold(out[:iphLen], 0)
|
|
||||||
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
|
|
||||||
} else {
|
|
||||||
// For IPv6 we are responsible for updating the payload length field.
|
|
||||||
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TCP header
|
|
||||||
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
|
|
||||||
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
|
|
||||||
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
|
|
||||||
if nextSegmentEnd != len(in) {
|
|
||||||
// FIN and PSH should only be set on last segment
|
|
||||||
clearFlags := tcpFlagFIN | tcpFlagPSH
|
|
||||||
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
|
|
||||||
}
|
|
||||||
|
|
||||||
// payload
|
|
||||||
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
|
|
||||||
|
|
||||||
// TCP checksum
|
|
||||||
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
|
|
||||||
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
|
|
||||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
|
|
||||||
tcpCSum := ^checksumFold(out[hdr.csumStart:totalLen], tcpCSumNoFold)
|
|
||||||
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
|
|
||||||
|
|
||||||
nextSegmentDataAt += int(hdr.gsoSize)
|
|
||||||
}
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
|
|
||||||
cSumAt := cSumStart + cSumOffset
|
|
||||||
// The initial value at the checksum offset should be summed with the
|
|
||||||
// checksum we compute. This is typically the pseudo-header checksum.
|
|
||||||
initial := binary.BigEndian.Uint16(in[cSumAt:])
|
|
||||||
in[cSumAt], in[cSumAt+1] = 0, 0
|
|
||||||
binary.BigEndian.PutUint16(in[cSumAt:], ^checksumFold(in[cSumStart:], uint64(initial)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
|
|
||||||
// each buffer. It mutates sizes to reflect the size of each element of bufs,
|
|
||||||
// and returns the number of packets read.
|
|
||||||
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
|
|
||||||
var hdr virtioNetHdr
|
|
||||||
err := hdr.decode(in)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
in = in[virtioNetHdrLen:]
|
|
||||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
|
|
||||||
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
|
||||||
// This means CHECKSUM_PARTIAL in skb context. We are responsible
|
|
||||||
// for computing the checksum starting at hdr.csumStart and placing
|
|
||||||
// at hdr.csumOffset.
|
|
||||||
err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(in) > len(bufs[0][offset:]) {
|
|
||||||
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
|
|
||||||
}
|
|
||||||
n := copy(bufs[0][offset:], in)
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
|
||||||
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipVersion := in[0] >> 4
|
|
||||||
switch ipVersion {
|
|
||||||
case 4:
|
|
||||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
|
||||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
|
||||||
}
|
|
||||||
case 6:
|
|
||||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
|
||||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(in) <= int(hdr.csumStart+12) {
|
|
||||||
return 0, errors.New("packet is too short")
|
|
||||||
}
|
|
||||||
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
|
|
||||||
// of the entire first packet when the kernel is handling it as part of a
|
|
||||||
// FORWARD path. Instead, parse the TCP header length and add it onto
|
|
||||||
// csumStart, which is synonymous for IP header length.
|
|
||||||
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
|
|
||||||
if tcpHLen < 20 || tcpHLen > 60 {
|
|
||||||
// A TCP header must be between 20 and 60 bytes in length.
|
|
||||||
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
|
||||||
}
|
|
||||||
hdr.hdrLen = hdr.csumStart + tcpHLen
|
|
||||||
|
|
||||||
if len(in) < int(hdr.hdrLen) {
|
|
||||||
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
if hdr.hdrLen < hdr.csumStart {
|
|
||||||
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
|
|
||||||
}
|
|
||||||
cSumAt := int(hdr.csumStart + hdr.csumOffset)
|
|
||||||
if cSumAt+1 >= len(in) {
|
|
||||||
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
|
||||||
}
|
|
||||||
|
|
||||||
return tcpTSO(in, hdr, bufs, sizes, offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checksumNoFold(b []byte, initial uint64) uint64 {
|
|
||||||
return uint64(checksum.Checksum(b, uint16(initial)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func checksumFold(b []byte, initial uint64) uint16 {
|
|
||||||
ac := checksumNoFold(b, initial)
|
|
||||||
ac = (ac >> 16) + (ac & 0xffff)
|
|
||||||
ac = (ac >> 16) + (ac & 0xffff)
|
|
||||||
ac = (ac >> 16) + (ac & 0xffff)
|
|
||||||
ac = (ac >> 16) + (ac & 0xffff)
|
|
||||||
return uint16(ac)
|
|
||||||
}
|
|
||||||
|
|
||||||
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
|
|
||||||
sum := checksumNoFold(srcAddr, 0)
|
|
||||||
sum = checksumNoFold(dstAddr, sum)
|
|
||||||
sum = checksumNoFold([]byte{0, protocol}, sum)
|
|
||||||
tmp := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(tmp, totalLen)
|
|
||||||
return checksumNoFold(tmp, sum)
|
|
||||||
}
|
|
|
@ -1,5 +0,0 @@
|
||||||
package tun
|
|
||||||
|
|
||||||
import E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
|
|
||||||
// ErrTooManySegments is returned when a packet splits into more segments than
// there are output buffers available to hold them.
var ErrTooManySegments = E.New("too many segments")
|
|
229
tun_offload.go
Normal file
229
tun_offload.go
Normal file
|
@ -0,0 +1,229 @@
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// gsoMaxSize is 65536 (64 KiB).
	// NOTE(review): presumably the largest GSO packet size handled; confirm
	// against the code that uses it.
	gsoMaxSize = 1 << 16
	// idealBatchSize is 128.
	// NOTE(review): presumably the preferred number of packets per batch;
	// confirm against the code that uses it.
	idealBatchSize = 1 << 7
)
|
||||||
|
|
||||||
|
// GSOType represents the type of segmentation offload.
type GSOType int

const (
	// GSONone means no segmentation offload.
	GSONone GSOType = iota
	// GSOTCPv4 is TCP segmentation offload over IPv4.
	GSOTCPv4
	// GSOTCPv6 is TCP segmentation offload over IPv6.
	GSOTCPv6
	// GSOUDPL4 is UDP segmentation offload; GSOSplit accepts it for both
	// IPv4 and IPv6 packets.
	GSOUDPL4
)

// String returns the name of the constant, or "unknown" for any value that is
// not one of the declared GSO types.
func (g GSOType) String() string {
	names := [...]string{
		GSONone:  "GSONone",
		GSOTCPv4: "GSOTCPv4",
		GSOTCPv6: "GSOTCPv6",
		GSOUDPL4: "GSOUDPL4",
	}
	if g < 0 || int(g) >= len(names) {
		return "unknown"
	}
	return names[g]
}
|
||||||
|
|
||||||
|
// GSOOptions is loosely modeled after struct virtio_net_hdr from the VIRTIO
// specification. It is a common representation of GSO metadata that can be
// applied to support packet GSO across tun.Device implementations.
//
// All offsets and lengths are in bytes and are relative to the start of the
// packet, i.e. the first byte of the IP header.
type GSOOptions struct {
	// GSOType represents the type of segmentation offload.
	GSOType GSOType
	// HdrLen is the sum of the layer 3 and 4 header lengths. This field may be
	// zero when GSOType == GSONone.
	HdrLen uint16
	// CsumStart is the head byte index of the packet data to be checksummed,
	// i.e. the start of the TCP or UDP header.
	CsumStart uint16
	// CsumOffset is the offset from CsumStart where the 2-byte checksum value
	// should be placed.
	CsumOffset uint16
	// GSOSize is the size of each segment exclusive of HdrLen. The tail segment
	// may be smaller than this value.
	GSOSize uint16
	// NeedsCsum may be set where GSOType == GSONone. When set, the checksum
	// at CsumStart + CsumOffset must be a partial checksum, i.e. the
	// pseudo-header sum.
	NeedsCsum bool
}
|
||||||
|
|
||||||
|
// Byte offsets of the source address within the IPv4 and IPv6 headers.
const (
	ipv4SrcAddrOffset = 12
	ipv6SrcAddrOffset = 8
)

// tcpFlagsOffset is the byte offset of the flags byte within the TCP header.
const tcpFlagsOffset = 13

// Individual bits of the TCP flags byte.
const (
	tcpFlagFIN uint8 = 1 << 0
	tcpFlagPSH uint8 = 1 << 3
	tcpFlagACK uint8 = 1 << 4
)

// IP protocol numbers, defined here in order to avoid importation of any
// platform-specific pkgs.
const (
	ipProtoTCP = 6
	ipProtoUDP = 17
)
|
||||||
|
|
||||||
|
// GSOSplit splits packets from 'in' into outBufs[<index>][outOffset:], writing
|
||||||
|
// the size of each element into sizes. It returns the number of buffers
|
||||||
|
// populated, and/or an error. Callers may pass an 'in' slice that overlaps with
|
||||||
|
// the first element of outBuffers, i.e. &in[0] may be equal to
|
||||||
|
// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the
|
||||||
|
// value of options.NeedsCsum. Length of each outBufs element must be greater
|
||||||
|
// than or equal to the length of 'in', otherwise output may be silently
|
||||||
|
// truncated.
|
||||||
|
func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) {
|
||||||
|
cSumAt := int(options.CsumStart) + int(options.CsumOffset)
|
||||||
|
if cSumAt+1 >= len(in) {
|
||||||
|
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(in) < int(options.HdrLen) {
|
||||||
|
return 0, fmt.Errorf("length of packet (%d) < GSO HdrLen (%d)", len(in), options.HdrLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle the conditions where we are copying a single element to outBuffs.
|
||||||
|
payloadLen := len(in) - int(options.HdrLen)
|
||||||
|
if options.GSOType == GSONone || payloadLen < int(options.GSOSize) {
|
||||||
|
if len(in) > len(outBufs[0][outOffset:]) {
|
||||||
|
return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:]))
|
||||||
|
}
|
||||||
|
if options.NeedsCsum {
|
||||||
|
// The initial value at the checksum offset should be summed with
|
||||||
|
// the checksum we compute. This is typically the pseudo-header sum.
|
||||||
|
initial := binary.BigEndian.Uint16(in[cSumAt:])
|
||||||
|
in[cSumAt], in[cSumAt+1] = 0, 0
|
||||||
|
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum.Checksum(in[options.CsumStart:], initial))
|
||||||
|
}
|
||||||
|
sizes[0] = copy(outBufs[0][outOffset:], in)
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.HdrLen < options.CsumStart {
|
||||||
|
return 0, fmt.Errorf("GSO HdrLen (%d) < GSO CsumStart (%d)", options.HdrLen, options.CsumStart)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipVersion := in[0] >> 4
|
||||||
|
switch ipVersion {
|
||||||
|
case 4:
|
||||||
|
if options.GSOType != GSOTCPv4 && options.GSOType != GSOUDPL4 {
|
||||||
|
return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
|
||||||
|
}
|
||||||
|
if len(in) < 20 {
|
||||||
|
return 0, fmt.Errorf("length of packet (%d) < minimum ipv4 header size (%d)", len(in), 20)
|
||||||
|
}
|
||||||
|
case 6:
|
||||||
|
if options.GSOType != GSOTCPv6 && options.GSOType != GSOUDPL4 {
|
||||||
|
return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
|
||||||
|
}
|
||||||
|
if len(in) < 40 {
|
||||||
|
return 0, fmt.Errorf("length of packet (%d) < minimum ipv6 header size (%d)", len(in), 40)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
iphLen := int(options.CsumStart)
|
||||||
|
srcAddrOffset := ipv6SrcAddrOffset
|
||||||
|
addrLen := 16
|
||||||
|
if ipVersion == 4 {
|
||||||
|
srcAddrOffset = ipv4SrcAddrOffset
|
||||||
|
addrLen = 4
|
||||||
|
}
|
||||||
|
transportCsumAt := int(options.CsumStart + options.CsumOffset)
|
||||||
|
var firstTCPSeqNum uint32
|
||||||
|
var protocol uint8
|
||||||
|
if options.GSOType == GSOTCPv4 || options.GSOType == GSOTCPv6 {
|
||||||
|
protocol = ipProtoTCP
|
||||||
|
if len(in) < int(options.CsumStart)+20 {
|
||||||
|
return 0, fmt.Errorf("length of packet (%d) < GSO CsumStart (%d) + minimum TCP header size (%d)",
|
||||||
|
len(in), options.CsumStart, 20)
|
||||||
|
}
|
||||||
|
firstTCPSeqNum = binary.BigEndian.Uint32(in[options.CsumStart+4:])
|
||||||
|
} else {
|
||||||
|
protocol = ipProtoUDP
|
||||||
|
}
|
||||||
|
nextSegmentDataAt := int(options.HdrLen)
|
||||||
|
i := 0
|
||||||
|
for ; nextSegmentDataAt < len(in); i++ {
|
||||||
|
if i == len(outBufs) {
|
||||||
|
return i - 1, ErrTooManySegments
|
||||||
|
}
|
||||||
|
nextSegmentEnd := nextSegmentDataAt + int(options.GSOSize)
|
||||||
|
if nextSegmentEnd > len(in) {
|
||||||
|
nextSegmentEnd = len(in)
|
||||||
|
}
|
||||||
|
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
|
||||||
|
totalLen := int(options.HdrLen) + segmentDataLen
|
||||||
|
sizes[i] = totalLen
|
||||||
|
out := outBufs[i][outOffset:]
|
||||||
|
|
||||||
|
copy(out, in[:iphLen])
|
||||||
|
if ipVersion == 4 {
|
||||||
|
// For IPv4 we are responsible for incrementing the ID field,
|
||||||
|
// updating the total len field, and recalculating the header
|
||||||
|
// checksum.
|
||||||
|
if i > 0 {
|
||||||
|
id := binary.BigEndian.Uint16(out[4:])
|
||||||
|
id += uint16(i)
|
||||||
|
binary.BigEndian.PutUint16(out[4:], id)
|
||||||
|
}
|
||||||
|
out[10], out[11] = 0, 0 // clear ipv4 header checksum
|
||||||
|
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
|
||||||
|
ipv4CSum := ^checksum.Checksum(out[:iphLen], 0)
|
||||||
|
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
|
||||||
|
} else {
|
||||||
|
// For IPv6 we are responsible for updating the payload length field.
|
||||||
|
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy transport header
|
||||||
|
copy(out[options.CsumStart:options.HdrLen], in[options.CsumStart:options.HdrLen])
|
||||||
|
|
||||||
|
if protocol == ipProtoTCP {
|
||||||
|
// set TCP seq and adjust TCP flags
|
||||||
|
tcpSeq := firstTCPSeqNum + uint32(options.GSOSize*uint16(i))
|
||||||
|
binary.BigEndian.PutUint32(out[options.CsumStart+4:], tcpSeq)
|
||||||
|
if nextSegmentEnd != len(in) {
|
||||||
|
// FIN and PSH should only be set on last segment
|
||||||
|
clearFlags := tcpFlagFIN | tcpFlagPSH
|
||||||
|
out[options.CsumStart+tcpFlagsOffset] &^= clearFlags
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// set UDP header len
|
||||||
|
binary.BigEndian.PutUint16(out[options.CsumStart+4:], uint16(segmentDataLen)+(options.HdrLen-options.CsumStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
// payload
|
||||||
|
copy(out[options.HdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
|
||||||
|
|
||||||
|
// transport checksum
|
||||||
|
out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
|
||||||
|
transportHeaderLen := int(options.HdrLen - options.CsumStart)
|
||||||
|
lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
|
||||||
|
transportCSum := header.PseudoHeaderChecksum(tcpip.TransportProtocolNumber(protocol), in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
|
||||||
|
transportCSum = ^checksum.Checksum(out[options.CsumStart:totalLen], transportCSum)
|
||||||
|
binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum)
|
||||||
|
|
||||||
|
nextSegmentDataAt += int(options.GSOSize)
|
||||||
|
}
|
||||||
|
return i, nil
|
||||||
|
}
|
10
tun_offload_errors.go
Normal file
10
tun_offload_errors.go
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrTooManySegments reports that splitting a GSO packet would yield more
// segments than there are output buffers to hold them. It is a per-packet
// condition: callers should keep reading after receiving it.
var ErrTooManySegments = errors.New("too many segments")
|
937
tun_offload_linux.go
Normal file
937
tun_offload_linux.go
Normal file
|
@ -0,0 +1,937 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||||
|
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// virtioNetHdr mirrors the kernel's struct virtio_net_hdr (see
// include/uapi/linux/virtio_net.h). The field order must not change: decode
// and encode copy the struct's raw memory to and from the wire.
type virtioNetHdr struct {
	// flags carries VIRTIO_NET_HDR_F_* bits, e.g. NEEDS_CSUM.
	flags uint8
	// gsoType is one of the VIRTIO_NET_HDR_GSO_* values.
	gsoType uint8
	// hdrLen is the combined length of the IP and transport headers.
	hdrLen uint16
	// gsoSize is the payload size of each segment.
	gsoSize uint16
	// csumStart is the offset at which checksumming begins.
	csumStart uint16
	// csumOffset is the offset of the checksum field relative to csumStart.
	csumOffset uint16
}
|
||||||
|
|
||||||
|
func (v *virtioNetHdr) toGSOOptions() (GSOOptions, error) {
|
||||||
|
var gsoType GSOType
|
||||||
|
switch v.gsoType {
|
||||||
|
case unix.VIRTIO_NET_HDR_GSO_NONE:
|
||||||
|
gsoType = GSONone
|
||||||
|
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
|
||||||
|
gsoType = GSOTCPv4
|
||||||
|
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
|
||||||
|
gsoType = GSOTCPv6
|
||||||
|
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
|
||||||
|
gsoType = GSOUDPL4
|
||||||
|
default:
|
||||||
|
return GSOOptions{}, fmt.Errorf("unsupported virtio gsoType: %d", v.gsoType)
|
||||||
|
}
|
||||||
|
return GSOOptions{
|
||||||
|
GSOType: gsoType,
|
||||||
|
HdrLen: v.hdrLen,
|
||||||
|
CsumStart: v.csumStart,
|
||||||
|
CsumOffset: v.csumOffset,
|
||||||
|
GSOSize: v.gsoSize,
|
||||||
|
NeedsCsum: v.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *virtioNetHdr) decode(b []byte) error {
|
||||||
|
if len(b) < virtioNetHdrLen {
|
||||||
|
return io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *virtioNetHdr) encode(b []byte) error {
|
||||||
|
if len(b) < virtioNetHdrLen {
|
||||||
|
return io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
|
||||||
|
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
|
||||||
|
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
|
||||||
|
)
|
||||||
|
|
||||||
|
// tcpFlowKey identifies a TCP flow for GRO purposes. It is used as a map key,
// so every field takes part in flow equality.
type tcpFlowKey struct {
	// srcAddr and dstAddr hold the IP addresses; IPv4 addresses occupy the
	// first four bytes and leave the rest zero.
	srcAddr [16]byte
	dstAddr [16]byte
	srcPort uint16
	dstPort uint16
	// rxAck is the TCP acknowledgment number. Packets with different ack
	// values are not coalesced; they are treated as separate flows.
	rxAck uint32
	isV6  bool
}
|
||||||
|
|
||||||
|
// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
|
||||||
|
type tcpGROTable struct {
|
||||||
|
itemsByFlow map[tcpFlowKey][]tcpGROItem
|
||||||
|
itemsPool [][]tcpGROItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTCPGROTable() *tcpGROTable {
|
||||||
|
t := &tcpGROTable{
|
||||||
|
itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, idealBatchSize),
|
||||||
|
itemsPool: make([][]tcpGROItem, idealBatchSize),
|
||||||
|
}
|
||||||
|
for i := range t.itemsPool {
|
||||||
|
t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize)
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
|
||||||
|
key := tcpFlowKey{}
|
||||||
|
addrSize := dstAddrOffset - srcAddrOffset
|
||||||
|
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
|
||||||
|
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
|
||||||
|
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
|
||||||
|
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
|
||||||
|
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
|
||||||
|
key.isV6 = addrSize == 16
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
||||||
|
// returning the packets found for the flow, or inserting a new one if none
|
||||||
|
// is found.
|
||||||
|
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
|
||||||
|
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||||
|
items, ok := t.itemsByFlow[key]
|
||||||
|
if ok {
|
||||||
|
return items, ok
|
||||||
|
}
|
||||||
|
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
||||||
|
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert an item in the table for the provided packet and packet metadata.
|
||||||
|
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
|
||||||
|
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||||
|
item := tcpGROItem{
|
||||||
|
key: key,
|
||||||
|
bufsIndex: uint16(bufsIndex),
|
||||||
|
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
|
||||||
|
iphLen: uint8(tcphOffset),
|
||||||
|
tcphLen: uint8(tcphLen),
|
||||||
|
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
|
||||||
|
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
|
||||||
|
}
|
||||||
|
items, ok := t.itemsByFlow[key]
|
||||||
|
if !ok {
|
||||||
|
items = t.newItems()
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
t.itemsByFlow[key] = items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
|
||||||
|
items, _ := t.itemsByFlow[item.key]
|
||||||
|
items[i] = item
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
|
||||||
|
items, _ := t.itemsByFlow[key]
|
||||||
|
items = append(items[:i], items[i+1:]...)
|
||||||
|
t.itemsByFlow[key] = items
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
|
||||||
|
// of a GRO evaluation across a vector of packets.
|
||||||
|
type tcpGROItem struct {
|
||||||
|
key tcpFlowKey
|
||||||
|
sentSeq uint32 // the sequence number
|
||||||
|
bufsIndex uint16 // the index into the original bufs slice
|
||||||
|
numMerged uint16 // the number of packets merged into this item
|
||||||
|
gsoSize uint16 // payload size
|
||||||
|
iphLen uint8 // ip header len
|
||||||
|
tcphLen uint8 // tcp header len
|
||||||
|
pshSet bool // psh flag is set
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) newItems() []tcpGROItem {
|
||||||
|
var items []tcpGROItem
|
||||||
|
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) reset() {
|
||||||
|
for k, items := range t.itemsByFlow {
|
||||||
|
items = items[:0]
|
||||||
|
t.itemsPool = append(t.itemsPool, items)
|
||||||
|
delete(t.itemsByFlow, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpFlowKey identifies a UDP flow for GRO purposes. It is used as a map key,
// so every field takes part in flow equality.
type udpFlowKey struct {
	// srcAddr and dstAddr hold the IP addresses; IPv4 addresses occupy the
	// first four bytes and leave the rest zero.
	srcAddr [16]byte
	dstAddr [16]byte
	srcPort uint16
	dstPort uint16
	isV6    bool
}
|
||||||
|
|
||||||
|
// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
|
||||||
|
type udpGROTable struct {
|
||||||
|
itemsByFlow map[udpFlowKey][]udpGROItem
|
||||||
|
itemsPool [][]udpGROItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPGROTable() *udpGROTable {
|
||||||
|
u := &udpGROTable{
|
||||||
|
itemsByFlow: make(map[udpFlowKey][]udpGROItem, idealBatchSize),
|
||||||
|
itemsPool: make([][]udpGROItem, idealBatchSize),
|
||||||
|
}
|
||||||
|
for i := range u.itemsPool {
|
||||||
|
u.itemsPool[i] = make([]udpGROItem, 0, idealBatchSize)
|
||||||
|
}
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
|
||||||
|
key := udpFlowKey{}
|
||||||
|
addrSize := dstAddrOffset - srcAddrOffset
|
||||||
|
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
|
||||||
|
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
|
||||||
|
key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
|
||||||
|
key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
|
||||||
|
key.isV6 = addrSize == 16
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
||||||
|
// returning the packets found for the flow, or inserting a new one if none
|
||||||
|
// is found.
|
||||||
|
func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
|
||||||
|
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
|
||||||
|
items, ok := u.itemsByFlow[key]
|
||||||
|
if ok {
|
||||||
|
return items, ok
|
||||||
|
}
|
||||||
|
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
||||||
|
u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert an item in the table for the provided packet and packet metadata.
|
||||||
|
func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
|
||||||
|
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
|
||||||
|
item := udpGROItem{
|
||||||
|
key: key,
|
||||||
|
bufsIndex: uint16(bufsIndex),
|
||||||
|
gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
|
||||||
|
iphLen: uint8(udphOffset),
|
||||||
|
cSumKnownInvalid: cSumKnownInvalid,
|
||||||
|
}
|
||||||
|
items, ok := u.itemsByFlow[key]
|
||||||
|
if !ok {
|
||||||
|
items = u.newItems()
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
u.itemsByFlow[key] = items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpGROTable) updateAt(item udpGROItem, i int) {
|
||||||
|
items, _ := u.itemsByFlow[item.key]
|
||||||
|
items[i] = item
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
|
||||||
|
// of a GRO evaluation across a vector of packets.
|
||||||
|
type udpGROItem struct {
|
||||||
|
key udpFlowKey
|
||||||
|
bufsIndex uint16 // the index into the original bufs slice
|
||||||
|
numMerged uint16 // the number of packets merged into this item
|
||||||
|
gsoSize uint16 // payload size
|
||||||
|
iphLen uint8 // ip header len
|
||||||
|
cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpGROTable) newItems() []udpGROItem {
|
||||||
|
var items []udpGROItem
|
||||||
|
items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpGROTable) reset() {
|
||||||
|
for k, items := range u.itemsByFlow {
|
||||||
|
items = items[:0]
|
||||||
|
u.itemsPool = append(u.itemsPool, items)
|
||||||
|
delete(u.itemsByFlow, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// canCoalesce is the verdict of checking whether a packet may be merged with
// a tracked item, and if so on which side.
type canCoalesce int

const (
	// coalescePrepend: the packet belongs directly in front of the item.
	coalescePrepend canCoalesce = -1
	// coalesceUnavailable: the packet cannot be merged with the item.
	coalesceUnavailable canCoalesce = 0
	// coalesceAppend: the packet belongs directly after the item.
	coalesceAppend canCoalesce = 1
)
|
||||||
|
|
||||||
|
// ipHeadersCanCoalesce reports whether the IP headers of pktA and pktB are
// compatible for a GRO merge. Both packets must be at least 9 bytes long.
// The address family is taken from pktA's version nibble:
//
//   - IPv6: byte 0 and the high nibble of byte 1 (version and traffic class)
//     and byte 7 (hop limit) must match.
//   - otherwise (IPv4): byte 1 (ToS), the top three bits of byte 6 (reserved,
//     DF and MF flags) and byte 8 (TTL) must match. Fragmentation is ruled
//     out separately by the caller.
func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
	const minLen = 9
	if len(pktA) < minLen || len(pktB) < minLen {
		return false
	}
	if version := pktA[0] >> 4; version == 6 {
		sameTrafficClass := pktA[0] == pktB[0] && pktA[1]>>4 == pktB[1]>>4
		sameHopLimit := pktA[7] == pktB[7]
		return sameTrafficClass && sameHopLimit
	}
	sameToS := pktA[1] == pktB[1]
	sameFlags := pktA[6]>>5 == pktB[6]>>5
	sameTTL := pktA[8] == pktB[8]
	return sameToS && sameFlags && sameTTL
}
|
||||||
|
|
||||||
|
// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
||||||
|
// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
|
||||||
|
// packets involved in the current GRO evaluation. bufsOffset is the offset at
|
||||||
|
// which packet data begins within bufs.
|
||||||
|
func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
||||||
|
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
||||||
|
if !ipHeadersCanCoalesce(pkt, pktTarget) {
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
|
||||||
|
// A smaller than gsoSize packet has been appended previously.
|
||||||
|
// Nothing can come after a smaller packet on the end.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize > item.gsoSize {
|
||||||
|
// We cannot have a larger packet following a smaller one.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
return coalesceAppend
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
||||||
|
// described by item. This function makes considerations that match the kernel's
|
||||||
|
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
|
||||||
|
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
||||||
|
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
||||||
|
if tcphLen != item.tcphLen {
|
||||||
|
// cannot coalesce with unequal tcp options len
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if tcphLen > 20 {
|
||||||
|
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
|
||||||
|
// cannot coalesce with unequal tcp options
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !ipHeadersCanCoalesce(pkt, pktTarget) {
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
// seq adjacency
|
||||||
|
lhsLen := item.gsoSize
|
||||||
|
lhsLen += item.numMerged * item.gsoSize
|
||||||
|
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
|
||||||
|
if item.pshSet {
|
||||||
|
// We cannot append to a segment that has the PSH flag set, PSH
|
||||||
|
// can only be set on the final segment in a reassembled group.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
|
||||||
|
// A smaller than gsoSize packet has been appended previously.
|
||||||
|
// Nothing can come after a smaller packet on the end.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize > item.gsoSize {
|
||||||
|
// We cannot have a larger packet following a smaller one.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
return coalesceAppend
|
||||||
|
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
|
||||||
|
if pshSet {
|
||||||
|
// We cannot prepend with a segment that has the PSH flag set, PSH
|
||||||
|
// can only be set on the final segment in a reassembled group.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize < item.gsoSize {
|
||||||
|
// We cannot have a larger packet following a smaller one.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize > item.gsoSize && item.numMerged > 0 {
|
||||||
|
// There's at least one previous merge, and we're larger than all
|
||||||
|
// previous. This would put multiple smaller packets on the end.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
return coalescePrepend
|
||||||
|
}
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
|
||||||
|
srcAddrAt := ipv4SrcAddrOffset
|
||||||
|
addrSize := 4
|
||||||
|
if isV6 {
|
||||||
|
srcAddrAt = ipv6SrcAddrOffset
|
||||||
|
addrSize = 16
|
||||||
|
}
|
||||||
|
lenForPseudo := uint16(len(pkt) - int(iphLen))
|
||||||
|
cSum := header.PseudoHeaderChecksum(tcpip.TransportProtocolNumber(proto), pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
|
||||||
|
return ^checksum.Checksum(pkt[iphLen:], cSum) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// coalesceResult is the outcome of an attempt to merge two packets.
type coalesceResult int

const (
	// coalesceInsufficientCap: the head buffer lacks capacity for the merge.
	coalesceInsufficientCap coalesceResult = iota
	// coalescePSHEnding: a prepend was refused because pkt has PSH set.
	coalescePSHEnding
	// coalesceItemInvalidCSum: the tracked item failed checksum validation.
	coalesceItemInvalidCSum
	// coalescePktInvalidCSum: the incoming packet failed checksum validation.
	coalescePktInvalidCSum
	// coalesceSuccess: the packets were merged.
	coalesceSuccess
)
|
||||||
|
|
||||||
|
// coalesceUDPPackets attempts to coalesce pkt with the packet described by
|
||||||
|
// item, and returns the outcome.
|
||||||
|
func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
|
||||||
|
pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
|
||||||
|
headersLen := item.iphLen + udphLen
|
||||||
|
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
|
||||||
|
|
||||||
|
if cap(pktHead)-bufsOffset < coalescedLen {
|
||||||
|
// We don't want to allocate a new underlying array if capacity is
|
||||||
|
// too small.
|
||||||
|
return coalesceInsufficientCap
|
||||||
|
}
|
||||||
|
if item.numMerged == 0 {
|
||||||
|
if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
|
||||||
|
return coalesceItemInvalidCSum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
|
||||||
|
return coalescePktInvalidCSum
|
||||||
|
}
|
||||||
|
extendBy := len(pkt) - int(headersLen)
|
||||||
|
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
|
||||||
|
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
|
||||||
|
|
||||||
|
item.numMerged++
|
||||||
|
return coalesceSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
|
||||||
|
// item, and returns the outcome. This function may swap bufs elements in the
|
||||||
|
// event of a prepend as item's bufs index is already being tracked for writing
|
||||||
|
// to a Device.
|
||||||
|
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
|
||||||
|
var pktHead []byte // the packet that will end up at the front
|
||||||
|
headersLen := item.iphLen + item.tcphLen
|
||||||
|
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
|
||||||
|
|
||||||
|
// Copy data
|
||||||
|
if mode == coalescePrepend {
|
||||||
|
pktHead = pkt
|
||||||
|
if cap(pkt)-bufsOffset < coalescedLen {
|
||||||
|
// We don't want to allocate a new underlying array if capacity is
|
||||||
|
// too small.
|
||||||
|
return coalesceInsufficientCap
|
||||||
|
}
|
||||||
|
if pshSet {
|
||||||
|
return coalescePSHEnding
|
||||||
|
}
|
||||||
|
if item.numMerged == 0 {
|
||||||
|
if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
|
||||||
|
return coalesceItemInvalidCSum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
|
||||||
|
return coalescePktInvalidCSum
|
||||||
|
}
|
||||||
|
item.sentSeq = seq
|
||||||
|
extendBy := coalescedLen - len(pktHead)
|
||||||
|
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
|
||||||
|
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
|
||||||
|
// Flip the slice headers in bufs as part of prepend. The index of item
|
||||||
|
// is already being tracked for writing.
|
||||||
|
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
|
||||||
|
} else {
|
||||||
|
pktHead = bufs[item.bufsIndex][bufsOffset:]
|
||||||
|
if cap(pktHead)-bufsOffset < coalescedLen {
|
||||||
|
// We don't want to allocate a new underlying array if capacity is
|
||||||
|
// too small.
|
||||||
|
return coalesceInsufficientCap
|
||||||
|
}
|
||||||
|
if item.numMerged == 0 {
|
||||||
|
if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
|
||||||
|
return coalesceItemInvalidCSum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
|
||||||
|
return coalescePktInvalidCSum
|
||||||
|
}
|
||||||
|
if pshSet {
|
||||||
|
// We are appending a segment with PSH set.
|
||||||
|
item.pshSet = pshSet
|
||||||
|
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
|
||||||
|
}
|
||||||
|
extendBy := len(pkt) - int(headersLen)
|
||||||
|
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
|
||||||
|
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
|
||||||
|
}
|
||||||
|
|
||||||
|
if gsoSize > item.gsoSize {
|
||||||
|
item.gsoSize = gsoSize
|
||||||
|
}
|
||||||
|
|
||||||
|
item.numMerged++
|
||||||
|
return coalesceSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
	// ipv4FlagMoreFragments masks the MF bit within byte 6 of an IPv4 header.
	ipv4FlagMoreFragments uint8 = 0x20

	// maxUint16 is the largest value a uint16 can hold.
	maxUint16 = 1<<16 - 1
)
|
||||||
|
|
||||||
|
// groResult describes what a GRO evaluation did with a packet.
type groResult int

const (
	// groResultNoop: no action was taken for the packet.
	groResultNoop groResult = iota
	// groResultTableInsert: the packet was inserted into the table.
	groResultTableInsert
	// groResultCoalesced: the packet was merged into another packet.
	groResultCoalesced
)
|
||||||
|
|
||||||
|
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
|
||||||
|
// existing packets tracked in table. It returns a groResultNoop when no
|
||||||
|
// action was taken, groResultTableInsert when the evaluated packet was
|
||||||
|
// inserted into table, and groResultCoalesced when the evaluated packet was
|
||||||
|
// coalesced with another packet in table.
|
||||||
|
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
|
||||||
|
pkt := bufs[pktI][offset:]
|
||||||
|
if len(pkt) > maxUint16 {
|
||||||
|
// A valid IPv4 or IPv6 packet will never exceed this.
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
iphLen := int((pkt[0] & 0x0F) * 4)
|
||||||
|
if isV6 {
|
||||||
|
iphLen = 40
|
||||||
|
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
|
||||||
|
if ipv6HPayloadLen != len(pkt)-iphLen {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
|
||||||
|
if totalLen != len(pkt) {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(pkt) < iphLen {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
|
||||||
|
if tcphLen < 20 || tcphLen > 60 {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
if len(pkt) < iphLen+tcphLen {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
if !isV6 {
|
||||||
|
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
|
||||||
|
// no GRO support for fragmented segments for now
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tcpFlags := pkt[iphLen+tcpFlagsOffset]
|
||||||
|
var pshSet bool
|
||||||
|
// not a candidate if any non-ACK flags (except PSH+ACK) are set
|
||||||
|
if tcpFlags != tcpFlagACK {
|
||||||
|
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
pshSet = true
|
||||||
|
}
|
||||||
|
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
|
||||||
|
// not a candidate if payload len is 0
|
||||||
|
if gsoSize < 1 {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
|
||||||
|
srcAddrOffset := ipv4SrcAddrOffset
|
||||||
|
addrLen := 4
|
||||||
|
if isV6 {
|
||||||
|
srcAddrOffset = ipv6SrcAddrOffset
|
||||||
|
addrLen = 16
|
||||||
|
}
|
||||||
|
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
||||||
|
if !existing {
|
||||||
|
return groResultTableInsert
|
||||||
|
}
|
||||||
|
for i := len(items) - 1; i >= 0; i-- {
|
||||||
|
// In the best case of packets arriving in order iterating in reverse is
|
||||||
|
// more efficient if there are multiple items for a given flow. This
|
||||||
|
// also enables a natural table.deleteAt() in the
|
||||||
|
// coalesceItemInvalidCSum case without the need for index tracking.
|
||||||
|
// This algorithm makes a best effort to coalesce in the event of
|
||||||
|
// unordered packets, where pkt may land anywhere in items from a
|
||||||
|
// sequence number perspective, however once an item is inserted into
|
||||||
|
// the table it is never compared across other items later.
|
||||||
|
item := items[i]
|
||||||
|
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
|
||||||
|
if can != coalesceUnavailable {
|
||||||
|
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
|
||||||
|
switch result {
|
||||||
|
case coalesceSuccess:
|
||||||
|
table.updateAt(item, i)
|
||||||
|
return groResultCoalesced
|
||||||
|
case coalesceItemInvalidCSum:
|
||||||
|
// delete the item with an invalid csum
|
||||||
|
table.deleteAt(item.key, i)
|
||||||
|
case coalescePktInvalidCSum:
|
||||||
|
// no point in inserting an item that we can't coalesce
|
||||||
|
return groResultNoop
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// failed to coalesce with any other packets; store the item in the flow
|
||||||
|
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
||||||
|
return groResultTableInsert
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
|
||||||
|
// metadata found in table.
|
||||||
|
func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
|
||||||
|
for _, items := range table.itemsByFlow {
|
||||||
|
for _, item := range items {
|
||||||
|
if item.numMerged > 0 {
|
||||||
|
hdr := virtioNetHdr{
|
||||||
|
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
|
||||||
|
hdrLen: uint16(item.iphLen + item.tcphLen),
|
||||||
|
gsoSize: item.gsoSize,
|
||||||
|
csumStart: uint16(item.iphLen),
|
||||||
|
csumOffset: 16,
|
||||||
|
}
|
||||||
|
pkt := bufs[item.bufsIndex][offset:]
|
||||||
|
|
||||||
|
// Recalculate the total len (IPv4) or payload len (IPv6).
|
||||||
|
// Recalculate the (IPv4) header checksum.
|
||||||
|
if item.key.isV6 {
|
||||||
|
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
|
||||||
|
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
|
||||||
|
} else {
|
||||||
|
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||||
|
pkt[10], pkt[11] = 0, 0
|
||||||
|
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
|
||||||
|
iphCSum := ^checksum.Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
|
||||||
|
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
|
||||||
|
}
|
||||||
|
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the pseudo header checksum and place it at the TCP
|
||||||
|
// checksum offset. Downstream checksum offloading will combine
|
||||||
|
// this with computation of the tcp header and payload checksum.
|
||||||
|
addrLen := 4
|
||||||
|
addrOffset := ipv4SrcAddrOffset
|
||||||
|
if item.key.isV6 {
|
||||||
|
addrLen = 16
|
||||||
|
addrOffset = ipv6SrcAddrOffset
|
||||||
|
}
|
||||||
|
srcAddrAt := offset + addrOffset
|
||||||
|
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
|
||||||
|
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
|
||||||
|
psum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
|
||||||
|
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum.Checksum([]byte{}, psum))
|
||||||
|
} else {
|
||||||
|
hdr := virtioNetHdr{}
|
||||||
|
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
|
||||||
|
// metadata found in table.
|
||||||
|
func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
|
||||||
|
for _, items := range table.itemsByFlow {
|
||||||
|
for _, item := range items {
|
||||||
|
if item.numMerged > 0 {
|
||||||
|
hdr := virtioNetHdr{
|
||||||
|
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
|
||||||
|
hdrLen: uint16(item.iphLen + udphLen),
|
||||||
|
gsoSize: item.gsoSize,
|
||||||
|
csumStart: uint16(item.iphLen),
|
||||||
|
csumOffset: 6,
|
||||||
|
}
|
||||||
|
pkt := bufs[item.bufsIndex][offset:]
|
||||||
|
|
||||||
|
// Recalculate the total len (IPv4) or payload len (IPv6).
|
||||||
|
// Recalculate the (IPv4) header checksum.
|
||||||
|
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
|
||||||
|
if item.key.isV6 {
|
||||||
|
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
|
||||||
|
} else {
|
||||||
|
pkt[10], pkt[11] = 0, 0
|
||||||
|
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
|
||||||
|
iphCSum := ^checksum.Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
|
||||||
|
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
|
||||||
|
}
|
||||||
|
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recalculate the UDP len field value
|
||||||
|
binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
|
||||||
|
|
||||||
|
// Calculate the pseudo header checksum and place it at the UDP
|
||||||
|
// checksum offset. Downstream checksum offloading will combine
|
||||||
|
// this with computation of the udp header and payload checksum.
|
||||||
|
addrLen := 4
|
||||||
|
addrOffset := ipv4SrcAddrOffset
|
||||||
|
if item.key.isV6 {
|
||||||
|
addrLen = 16
|
||||||
|
addrOffset = ipv6SrcAddrOffset
|
||||||
|
}
|
||||||
|
srcAddrAt := offset + addrOffset
|
||||||
|
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
|
||||||
|
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
|
||||||
|
psum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
|
||||||
|
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum.Checksum([]byte{}, psum))
|
||||||
|
} else {
|
||||||
|
hdr := virtioNetHdr{}
|
||||||
|
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// groCandidateType is the classification returned by packetIsGROCandidate: it
// says whether a packet may be handed to GRO and, if so, which transport and
// IP version it belongs to.
type groCandidateType uint8

const (
	// notGROCandidate marks a packet that must bypass GRO.
	notGROCandidate groCandidateType = iota
	// tcp4GROCandidate marks a TCP over IPv4 packet.
	tcp4GROCandidate
	// tcp6GROCandidate marks a TCP over IPv6 packet.
	tcp6GROCandidate
	// udp4GROCandidate marks a UDP over IPv4 packet.
	udp4GROCandidate
	// udp6GROCandidate marks a UDP over IPv6 packet.
	udp6GROCandidate
)
|
||||||
|
|
||||||
|
// groDisablementFlags is a bit set recording which GRO paths have been turned
// off. In the zero value no bit is set, so both TCP and UDP GRO are allowed.
type groDisablementFlags int

const (
	// tcpGRODisabled is the bit that forbids TCP GRO.
	tcpGRODisabled groDisablementFlags = 1 << iota
	// udpGRODisabled is the bit that forbids UDP GRO.
	udpGRODisabled
)

// disableTCPGRO sets the TCP bit; other bits are left untouched.
func (g *groDisablementFlags) disableTCPGRO() {
	*g = *g | tcpGRODisabled
}

// canTCPGRO reports whether the TCP bit is still clear.
func (g *groDisablementFlags) canTCPGRO() bool {
	return *g&tcpGRODisabled != tcpGRODisabled
}

// disableUDPGRO sets the UDP bit; other bits are left untouched.
func (g *groDisablementFlags) disableUDPGRO() {
	*g = *g | udpGRODisabled
}

// canUDPGRO reports whether the UDP bit is still clear.
func (g *groDisablementFlags) canUDPGRO() bool {
	return *g&udpGRODisabled != udpGRODisabled
}
|
||||||
|
|
||||||
|
func packetIsGROCandidate(b []byte, gro groDisablementFlags) groCandidateType {
|
||||||
|
if len(b) < 28 {
|
||||||
|
return notGROCandidate
|
||||||
|
}
|
||||||
|
if b[0]>>4 == 4 {
|
||||||
|
if b[0]&0x0F != 5 {
|
||||||
|
// IPv4 packets w/IP options do not coalesce
|
||||||
|
return notGROCandidate
|
||||||
|
}
|
||||||
|
if b[9] == unix.IPPROTO_TCP && len(b) >= 40 && gro.canTCPGRO() {
|
||||||
|
return tcp4GROCandidate
|
||||||
|
}
|
||||||
|
if b[9] == unix.IPPROTO_UDP && gro.canUDPGRO() {
|
||||||
|
return udp4GROCandidate
|
||||||
|
}
|
||||||
|
} else if b[0]>>4 == 6 {
|
||||||
|
if b[6] == unix.IPPROTO_TCP && len(b) >= 60 && gro.canTCPGRO() {
|
||||||
|
return tcp6GROCandidate
|
||||||
|
}
|
||||||
|
if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && gro.canUDPGRO() {
|
||||||
|
return udp6GROCandidate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return notGROCandidate
|
||||||
|
}
|
||||||
|
|
||||||
|
// udphLen is the length in bytes of a UDP header.
const udphLen = 8
|
||||||
|
|
||||||
|
// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
|
||||||
|
// existing packets tracked in table. It returns a groResultNoop when no
|
||||||
|
// action was taken, groResultTableInsert when the evaluated packet was
|
||||||
|
// inserted into table, and groResultCoalesced when the evaluated packet was
|
||||||
|
// coalesced with another packet in table.
|
||||||
|
func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
|
||||||
|
pkt := bufs[pktI][offset:]
|
||||||
|
if len(pkt) > maxUint16 {
|
||||||
|
// A valid IPv4 or IPv6 packet will never exceed this.
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
iphLen := int((pkt[0] & 0x0F) * 4)
|
||||||
|
if isV6 {
|
||||||
|
iphLen = 40
|
||||||
|
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
|
||||||
|
if ipv6HPayloadLen != len(pkt)-iphLen {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
|
||||||
|
if totalLen != len(pkt) {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(pkt) < iphLen {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
if len(pkt) < iphLen+udphLen {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
if !isV6 {
|
||||||
|
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
|
||||||
|
// no GRO support for fragmented segments for now
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gsoSize := uint16(len(pkt) - udphLen - iphLen)
|
||||||
|
// not a candidate if payload len is 0
|
||||||
|
if gsoSize < 1 {
|
||||||
|
return groResultNoop
|
||||||
|
}
|
||||||
|
srcAddrOffset := ipv4SrcAddrOffset
|
||||||
|
addrLen := 4
|
||||||
|
if isV6 {
|
||||||
|
srcAddrOffset = ipv6SrcAddrOffset
|
||||||
|
addrLen = 16
|
||||||
|
}
|
||||||
|
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
|
||||||
|
if !existing {
|
||||||
|
return groResultTableInsert
|
||||||
|
}
|
||||||
|
// With UDP we only check the last item, otherwise we could reorder packets
|
||||||
|
// for a given flow. We must also always insert a new item, or successfully
|
||||||
|
// coalesce with an existing item, for the same reason.
|
||||||
|
item := items[len(items)-1]
|
||||||
|
can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
|
||||||
|
var pktCSumKnownInvalid bool
|
||||||
|
if can == coalesceAppend {
|
||||||
|
result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
|
||||||
|
switch result {
|
||||||
|
case coalesceSuccess:
|
||||||
|
table.updateAt(item, len(items)-1)
|
||||||
|
return groResultCoalesced
|
||||||
|
case coalesceItemInvalidCSum:
|
||||||
|
// If the existing item has an invalid csum we take no action. A new
|
||||||
|
// item will be stored after it, and the existing item will never be
|
||||||
|
// revisited as part of future coalescing candidacy checks.
|
||||||
|
case coalescePktInvalidCSum:
|
||||||
|
// We must insert a new item, but we also mark it as invalid csum
|
||||||
|
// to prevent a repeat checksum validation.
|
||||||
|
pktCSumKnownInvalid = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// failed to coalesce with any other packets; store the item in the flow
|
||||||
|
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
|
||||||
|
return groResultTableInsert
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
|
||||||
|
// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
|
||||||
|
// empty (but non-nil), and are passed in to save allocs as the caller may reset
|
||||||
|
// and recycle them across vectors of packets. gro indicates if TCP and UDP GRO
|
||||||
|
// are supported/enabled.
|
||||||
|
func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error {
|
||||||
|
for i := range bufs {
|
||||||
|
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
|
||||||
|
return errors.New("invalid offset")
|
||||||
|
}
|
||||||
|
var result groResult
|
||||||
|
switch packetIsGROCandidate(bufs[i][offset:], gro) {
|
||||||
|
case tcp4GROCandidate:
|
||||||
|
result = tcpGRO(bufs, offset, i, tcpTable, false)
|
||||||
|
case tcp6GROCandidate:
|
||||||
|
result = tcpGRO(bufs, offset, i, tcpTable, true)
|
||||||
|
case udp4GROCandidate:
|
||||||
|
result = udpGRO(bufs, offset, i, udpTable, false)
|
||||||
|
case udp6GROCandidate:
|
||||||
|
result = udpGRO(bufs, offset, i, udpTable, true)
|
||||||
|
}
|
||||||
|
switch result {
|
||||||
|
case groResultNoop:
|
||||||
|
hdr := virtioNetHdr{}
|
||||||
|
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case groResultTableInsert:
|
||||||
|
*toWrite = append(*toWrite, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
|
||||||
|
errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
|
||||||
|
return errors.Join(errTCP, errUDP)
|
||||||
|
}
|
66
tun_rules.go
66
tun_rules.go
|
@ -108,7 +108,7 @@ const autoRouteUseSubRanges = runtime.GOOS == "darwin"
|
||||||
|
|
||||||
func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Prefix, error) {
|
func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Prefix, error) {
|
||||||
var routeRanges []netip.Prefix
|
var routeRanges []netip.Prefix
|
||||||
if o.AutoRoute && len(o.Inet4Address) > 0 {
|
if len(o.Inet4Address) > 0 {
|
||||||
var inet4Ranges []netip.Prefix
|
var inet4Ranges []netip.Prefix
|
||||||
if len(o.Inet4RouteAddress) > 0 {
|
if len(o.Inet4RouteAddress) > 0 {
|
||||||
inet4Ranges = o.Inet4RouteAddress
|
inet4Ranges = o.Inet4RouteAddress
|
||||||
|
@ -119,19 +119,27 @@ func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Pref
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if autoRouteUseSubRanges && !underNetworkExtension {
|
} else if o.AutoRoute {
|
||||||
inet4Ranges = []netip.Prefix{
|
if autoRouteUseSubRanges && !underNetworkExtension {
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8),
|
inet4Ranges = []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8),
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7),
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6),
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5),
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4),
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3),
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2),
|
||||||
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "darwin" {
|
||||||
|
for _, address := range o.Inet4Address {
|
||||||
|
if address.Bits() < 32 {
|
||||||
|
inet4Ranges = append(inet4Ranges, address.Masked())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
|
||||||
}
|
}
|
||||||
if len(o.Inet4RouteExcludeAddress) == 0 {
|
if len(o.Inet4RouteExcludeAddress) == 0 {
|
||||||
routeRanges = append(routeRanges, inet4Ranges...)
|
routeRanges = append(routeRanges, inet4Ranges...)
|
||||||
|
@ -161,19 +169,27 @@ func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Pref
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if autoRouteUseSubRanges && !underNetworkExtension {
|
} else if o.AutoRoute {
|
||||||
inet6Ranges = []netip.Prefix{
|
if autoRouteUseSubRanges && !underNetworkExtension {
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8),
|
inet6Ranges = []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8),
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7),
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6),
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5),
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4),
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3),
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2),
|
||||||
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "darwin" {
|
||||||
|
for _, address := range o.Inet6Address {
|
||||||
|
if address.Bits() < 32 {
|
||||||
|
inet6Ranges = append(inet6Ranges, address.Masked())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
|
||||||
}
|
}
|
||||||
if len(o.Inet6RouteExcludeAddress) == 0 {
|
if len(o.Inet6RouteExcludeAddress) == 0 {
|
||||||
routeRanges = append(routeRanges, inet6Ranges...)
|
routeRanges = append(routeRanges, inet6Ranges...)
|
||||||
|
|
118
tun_windows.go
118
tun_windows.go
|
@ -72,7 +72,7 @@ func (t *NativeTun) configure() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "set ipv4 address")
|
return E.Cause(err, "set ipv4 address")
|
||||||
}
|
}
|
||||||
if !t.options.EXP_DisableDNSHijack {
|
if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack {
|
||||||
dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is4)
|
dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is4)
|
||||||
if len(dnsServers) == 0 && HasNextAddress(t.options.Inet4Address[0], 1) {
|
if len(dnsServers) == 0 && HasNextAddress(t.options.Inet4Address[0], 1) {
|
||||||
dnsServers = []netip.Addr{t.options.Inet4Address[0].Addr().Next()}
|
dnsServers = []netip.Addr{t.options.Inet4Address[0].Addr().Next()}
|
||||||
|
@ -83,6 +83,11 @@ func (t *NativeTun) configure() error {
|
||||||
return E.Cause(err, "set ipv4 dns")
|
return E.Cause(err, "set ipv4 dns")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "set ipv4 dns")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(t.options.Inet6Address) > 0 {
|
if len(t.options.Inet6Address) > 0 {
|
||||||
|
@ -90,7 +95,7 @@ func (t *NativeTun) configure() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "set ipv6 address")
|
return E.Cause(err, "set ipv6 address")
|
||||||
}
|
}
|
||||||
if !t.options.EXP_DisableDNSHijack {
|
if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack {
|
||||||
dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is6)
|
dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is6)
|
||||||
if len(dnsServers) == 0 && HasNextAddress(t.options.Inet6Address[0], 1) {
|
if len(dnsServers) == 0 && HasNextAddress(t.options.Inet6Address[0], 1) {
|
||||||
dnsServers = []netip.Addr{t.options.Inet6Address[0].Addr().Next()}
|
dnsServers = []netip.Addr{t.options.Inet6Address[0].Addr().Next()}
|
||||||
|
@ -101,6 +106,11 @@ func (t *NativeTun) configure() error {
|
||||||
return E.Cause(err, "set ipv6 dns")
|
return E.Cause(err, "set ipv6 dns")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "set ipv6 dns")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
|
if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
|
||||||
|
@ -148,6 +158,10 @@ func (t *NativeTun) configure() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) Name() (string, error) {
|
||||||
|
return t.options.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *NativeTun) Start() error {
|
func (t *NativeTun) Start() error {
|
||||||
if !t.options.AutoRoute {
|
if !t.options.AutoRoute {
|
||||||
return nil
|
return nil
|
||||||
|
@ -158,13 +172,7 @@ func (t *NativeTun) Start() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, routeRange := range routeRanges {
|
err = addRouteList(luid, routeRanges, gateway4, gateway6, 0)
|
||||||
if routeRange.Addr().Is4() {
|
|
||||||
err = luid.AddRoute(routeRange, gateway4, 0)
|
|
||||||
} else {
|
|
||||||
err = luid.AddRoute(routeRange, gateway6, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -349,7 +357,40 @@ func (t *NativeTun) Start() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NativeTun) Read(p []byte) (n int, err error) {
|
func (t *NativeTun) Read(p []byte) (n int, err error) {
|
||||||
return 0, os.ErrInvalid
|
t.running.Add(1)
|
||||||
|
defer t.running.Done()
|
||||||
|
retry:
|
||||||
|
if t.close.Load() == 1 {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
start := nanotime()
|
||||||
|
shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
|
||||||
|
for {
|
||||||
|
if t.close.Load() == 1 {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
var packet []byte
|
||||||
|
packet, err = t.session.ReceivePacket()
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
n = copy(p, packet)
|
||||||
|
t.session.ReleaseReceivePacket(packet)
|
||||||
|
t.rate.update(uint64(n))
|
||||||
|
return
|
||||||
|
case windows.ERROR_NO_MORE_ITEMS:
|
||||||
|
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
||||||
|
windows.WaitForSingleObject(t.readWait, windows.INFINITE)
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
procyield(1)
|
||||||
|
continue
|
||||||
|
case windows.ERROR_HANDLE_EOF:
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
case windows.ERROR_INVALID_DATA:
|
||||||
|
return 0, errors.New("send ring corrupt")
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("read failed: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
|
func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
|
||||||
|
@ -498,6 +539,63 @@ func (t *NativeTun) Close() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error {
|
||||||
|
t.options = tunOptions
|
||||||
|
if !t.options.AutoRoute {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
|
||||||
|
routeRanges, err := t.options.BuildAutoRouteRanges(false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
luid := winipcfg.LUID(t.adapter.LUID())
|
||||||
|
err = luid.FlushRoutes(windows.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = addRouteList(luid, routeRanges, gateway4, gateway6, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = windnsapi.FlushResolverCache()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addRouteList(luid winipcfg.LUID, destinations []netip.Prefix, gateway4 netip.Addr, gateway6 netip.Addr, metric uint32) error {
|
||||||
|
row := winipcfg.MibIPforwardRow2{}
|
||||||
|
row.Init()
|
||||||
|
row.InterfaceLUID = luid
|
||||||
|
row.Metric = metric
|
||||||
|
nextHop4 := row.NextHop
|
||||||
|
nextHop6 := row.NextHop
|
||||||
|
if gateway4.IsValid() {
|
||||||
|
nextHop4.SetAddr(gateway4)
|
||||||
|
}
|
||||||
|
if gateway6.IsValid() {
|
||||||
|
nextHop6.SetAddr(gateway6)
|
||||||
|
}
|
||||||
|
for _, destination := range destinations {
|
||||||
|
err := row.DestinationPrefix.SetPrefix(destination)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if destination.Addr().Is4() {
|
||||||
|
row.NextHop = nextHop4
|
||||||
|
} else {
|
||||||
|
row.NextHop = nextHop6
|
||||||
|
}
|
||||||
|
err = row.Create()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func generateGUIDByDeviceName(name string) *windows.GUID {
|
func generateGUIDByDeviceName(name string) *windows.GUID {
|
||||||
hash := md5.New()
|
hash := md5.New()
|
||||||
hash.Write([]byte("wintun"))
|
hash.Write([]byte("wintun"))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue