Compare commits

...

27 commits

Author SHA1 Message Date
世界 35b5747b44 Ignore UDP offload error 2025-03-19 20:37:03 +08:00
wwqgtxx 5cb6d27288 Add bench test to makefile 2025-03-19 20:37:02 +08:00
greathongtu 9105485a50 Fix typo of 'Android' in ErrNetlinkBanned error message 2025-03-19 20:37:02 +08:00
世界 57aba1a5c4 Fix checksum bench 2025-03-19 20:37:02 +08:00
wwqgtxx 7f3343169a Better atomic using 2025-03-19 20:37:02 +08:00
wwqgtxx 22b811f938 Remove gvisor's buffer dependence of gtcpip 2025-03-19 20:37:02 +08:00
世界 618be14c7b Fix generate darwin rules 2025-02-17 21:56:54 +08:00
bemarkt c8c2984261 Fix wrong parameter in ICMPv4Checksum 2025-02-10 14:20:22 +08:00
世界 8cc5351bb3 auto-redirect: Move initialize to start 2025-02-06 08:44:55 +08:00
世界 d093b82064 auto-redirect: Fetch interfaces 2025-01-07 19:01:54 +08:00
世界 aa9d9c6296 Improve add windows route 2024-12-29 21:19:14 +08:00
世界 f457988090 Add tun.UpdateRouteOptions 2024-12-23 22:17:27 +08:00
世界 d0887eabba Fix set gateway 2024-12-13 09:31:03 +08:00
世界 edabb6d7ba Fix set windows dns 2024-12-13 09:30:52 +08:00
世界 d38f9adaef Fix Read not implemented for windows 2024-12-13 09:30:52 +08:00
Zxneric 618d3f9a52 Fix wrong condition on interface check 2024-11-29 22:46:05 +08:00
世界 c21c623174 Add Tun.Name 2024-11-28 14:27:11 +08:00
世界 091b5da950 Revert "Update udpant usages" 2024-11-27 18:04:39 +08:00
世界 59a6bdc1fa Update udpant usages 2024-11-26 19:17:09 +08:00
世界 c177abb523 Copy TCP GRO probe from tailscale 2024-11-23 13:13:48 +08:00
世界 6ef42f019b Fix build on bsd systems 2024-11-23 12:58:52 +08:00
世界 c1f61d08ba Make GSO optional 2024-11-23 12:58:52 +08:00
世界 2b8115e83b Copy UDP GSO support from tailscale 2024-11-23 12:58:52 +08:00
世界 06b4d4ecd1 Add tailscale checksum 2024-11-23 12:58:52 +08:00
世界 4ebeb2fa86 Export interface for WireGuard 2024-11-23 12:58:52 +08:00
Benyamin 8a18f0c99e Fix android routing rules about vpn protection 2024-11-18 18:30:51 +08:00
世界 b8b3759826 Fix HandshakeFailure usages 2024-11-14 22:04:31 +08:00
41 changed files with 4254 additions and 1550 deletions

Makefile

@@ -29,4 +29,5 @@ lint_install:
 test:
 	go build -v .
+	go test -bench=. ./internal/checksum_test
 	#go test -v .

go.mod

@@ -6,10 +6,10 @@ require (
 	github.com/go-ole/go-ole v1.3.0
 	github.com/google/btree v1.1.3
 	github.com/sagernet/fswatch v0.1.1
-	github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3
+	github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
 	github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
 	github.com/sagernet/nftables v0.3.0-beta.4
-	github.com/sagernet/sing v0.6.0-alpha.6
+	github.com/sagernet/sing v0.6.0-beta.2
 	go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
 	golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
 	golang.org/x/net v0.26.0

go.sum

@@ -16,14 +16,14 @@ github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8Ku
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
 github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
-github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 h1:RxEz7LhPNiF/gX/Hg+OXr5lqsM9iVAgmaK1L1vzlDRM=
-github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
+github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs=
+github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
 github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
 github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
 github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
 github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
-github.com/sagernet/sing v0.6.0-alpha.6 h1:R0abM8ZeazyAKo9d3DNxtrgW17g3tZAD8al7O5+ADOw=
-github.com/sagernet/sing v0.6.0-alpha.6/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
+github.com/sagernet/sing v0.6.0-beta.2 h1:Dcutp3kxrsZes9q3oTiHQhYYjQvDn5rwp1OI9fDLYwQ=
+github.com/sagernet/sing v0.6.0-beta.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
 github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=

internal/checksum_test

@@ -0,0 +1,33 @@
package checksum_test

import (
	"crypto/rand"
	"testing"

	"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
	"github.com/sagernet/sing-tun/internal/tschecksum"
)

func BenchmarkTsChecksum(b *testing.B) {
	packet := make([][]byte, 1000)
	for i := 0; i < 1000; i++ {
		packet[i] = make([]byte, 1500)
		rand.Read(packet[i])
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		tschecksum.Checksum(packet[i%1000], 0)
	}
}

func BenchmarkGChecksum(b *testing.B) {
	packet := make([][]byte, 1000)
	for i := 0; i < 1000; i++ {
		packet[i] = make([]byte, 1500)
		rand.Read(packet[i])
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		checksum.ChecksumDefault(packet[i%1000], 0)
	}
}
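The two benchmarks above only compare speed; since tschecksum.Checksum and checksum.ChecksumDefault are meant to be drop-in equivalents, a correctness cross-check is the natural companion. A minimal sketch, reusing the imports of the file above and assuming placement in the same internal/checksum_test package (the test itself is hypothetical, not part of this diff):

func TestChecksumEquivalence(t *testing.T) {
	buf := make([]byte, 1500)
	rand.Read(buf)
	// Sweep lengths to exercise the SIMD main loops as well as the
	// scalar tails and odd-length handling.
	for _, n := range []int{0, 1, 2, 7, 64, 128, 1499, 1500} {
		want := checksum.ChecksumDefault(buf[:n], 0)
		got := tschecksum.Checksum(buf[:n], 0)
		if got != want {
			t.Fatalf("length %d: tschecksum %#04x != generic %#04x", n, got, want)
		}
	}
}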

internal/gtcpip/checksum

@@ -30,34 +30,6 @@ func Put(b []byte, xsum uint16) {
binary.BigEndian.PutUint16(b, xsum)
}
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
// given byte array. This function uses an optimized version of the checksum
// algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
s, _ := calculateChecksum(buf, false, initial)
return s
}
// Checksumer calculates checksum defined in RFC 1071.
type Checksumer struct {
sum uint16
odd bool
}
// Add adds b to checksum.
func (c *Checksumer) Add(b []byte) {
if len(b) > 0 {
c.sum, c.odd = calculateChecksum(b, c.odd, c.sum)
}
}
// Checksum returns the latest checksum value.
func (c *Checksumer) Checksum() uint16 {
return c.sum
}
// Combine combines the two uint16 to form their checksum. This is done
// by adding them and the carry.
//
@@ -66,3 +38,8 @@ func Combine(a, b uint16) uint16 {
	v := uint32(a) + uint32(b)
	return uint16(v + v>>16)
}

func ChecksumDefault(buf []byte, initial uint16) uint16 {
	s, _ := calculateChecksum(buf, false, initial)
	return s
}
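Combine is what makes chunked checksumming work: the one's-complement sum is associative, so partial sums over even-length chunks can be merged with end-around carry, or fed back in through the initial parameter. A minimal sketch of both styles (illustrative only; these are internal packages, so this compiles only inside the repository):

package main

import (
	"fmt"

	"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
)

func main() {
	pkt := []byte{0x45, 0x00, 0x00, 0x54, 0xde, 0xad, 0x40, 0x00, 0x40, 0x01}

	// One pass over the whole buffer.
	whole := checksum.ChecksumDefault(pkt, 0)

	// Chained: the sum of an even-length prefix becomes the initial
	// value for the remainder, per the doc comment removed above.
	chained := checksum.ChecksumDefault(pkt[6:], checksum.ChecksumDefault(pkt[:6], 0))

	// Combined: two independent partial sums merged with end-around carry.
	combined := checksum.Combine(
		checksum.ChecksumDefault(pkt[:6], 0),
		checksum.ChecksumDefault(pkt[6:], 0),
	)

	// All three print the same value.
	fmt.Printf("%#04x %#04x %#04x\n", whole, chained, combined)
}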


@@ -0,0 +1,12 @@
//go:build !amd64

package checksum

// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
// given byte array. This function uses an optimized version of the checksum
// algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
	return ChecksumDefault(buf, initial)
}


@@ -0,0 +1,9 @@
//go:build amd64

package checksum

import "github.com/sagernet/sing-tun/internal/tschecksum"

func Checksum(buf []byte, initial uint16) uint16 {
	return tschecksum.Checksum(buf, initial)
}


@@ -0,0 +1,136 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"net/netip"
tcpip "github.com/sagernet/sing-tun/internal/gtcpip"
)
const (
// MaxIPPacketSize is the maximum supported IP packet size, excluding
// jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
// in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
// 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
// m is the minimum IPv6 header size; we leave room for some potential
// IP options.
MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
)
// Transport offers generic methods to query and/or update the fields of the
// header of a transport protocol buffer.
type Transport interface {
// SourcePort returns the value of the "source port" field.
SourcePort() uint16
// DestinationPort returns the value of the "destination port" field.
DestinationPort() uint16
// Checksum returns the value of the "checksum" field.
Checksum() uint16
// SetSourcePort sets the value of the "source port" field.
SetSourcePort(uint16)
// SetDestinationPort sets the value of the "destination port" field.
SetDestinationPort(uint16)
// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)
// Payload returns the data carried in the transport buffer.
Payload() []byte
}
// ChecksummableTransport is a Transport that supports checksumming.
type ChecksummableTransport interface {
Transport
// SetSourcePortWithChecksumUpdate sets the source port and updates
// the checksum.
//
// The receiver's checksum must be a fully calculated checksum.
SetSourcePortWithChecksumUpdate(port uint16)
// SetDestinationPortWithChecksumUpdate sets the destination port and updates
// the checksum.
//
// The receiver's checksum must be a fully calculated checksum.
SetDestinationPortWithChecksumUpdate(port uint16)
// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
// updated address in the pseudo header.
//
// If fullChecksum is true, the receiver's checksum field is assumed to hold a
// fully calculated checksum. Otherwise, it is assumed to hold a partially
// calculated checksum which only reflects the pseudo header.
UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
}
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
// SourceAddress returns the value of the "source address" field.
SourceAddress() tcpip.Address
// DestinationAddress returns the value of the "destination address"
// field.
DestinationAddress() tcpip.Address
DestinationAddr() netip.Addr
// Checksum returns the value of the "checksum" field.
Checksum() uint16
// SetSourceAddress sets the value of the "source address" field.
SetSourceAddress(tcpip.Address)
// SetDestinationAddress sets the value of the "destination address"
// field.
SetDestinationAddress(tcpip.Address)
SetDestinationAddr(addr netip.Addr)
// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)
// TransportProtocol returns the number of the transport protocol
// stored in the payload.
TransportProtocol() tcpip.TransportProtocolNumber
// Payload returns a byte slice containing the payload of the network
// packet.
Payload() []byte
// TOS returns the values of the "type of service" and "flow label" fields.
TOS() (uint8, uint32)
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}
// ChecksummableNetwork is a Network that supports checksumming.
type ChecksummableNetwork interface {
Network
// SetSourceAddressAndChecksum sets the source address and updates the
// checksum to reflect the new address.
SetSourceAddressWithChecksumUpdate(tcpip.Address)
// SetDestinationAddressAndChecksum sets the destination address and
// updates the checksum to reflect the new address.
SetDestinationAddressWithChecksumUpdate(tcpip.Address)
}
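UpdateChecksumPseudoHeaderAddress and the SetXxxWithChecksumUpdate methods rely on incremental checksum updates in the style of RFC 1624: to swap an address, subtract the old bytes from the sum and add the new ones, with no need to re-checksum the payload. A sketch of the underlying arithmetic, assuming the checksum package from this diff (updateForAddress is a hypothetical helper, not part of the interface):

// xsum is a fully calculated (not yet complemented) checksum; oldAddr and
// newAddr are equal-length 4- or 16-byte addresses. Adding the one's
// complement of a value subtracts it in one's-complement arithmetic,
// modulo the 0x0000/0xffff ambiguity discussed in RFC 1624.
func updateForAddress(xsum uint16, oldAddr, newAddr []byte) uint16 {
	xsum = checksum.Combine(xsum, ^checksum.ChecksumDefault(oldAddr, 0))
	return checksum.Combine(xsum, checksum.ChecksumDefault(newAddr, 0))
}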


@@ -18,10 +18,8 @@ import (
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing/common"
)
@@ -145,79 +143,6 @@ func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int {
return ((padLen + align - 1) & ^(align - 1)) - padLen
}
// IPv6PayloadHeader is implemented by the various headers that can be found
// in an IPv6 payload.
//
// These headers include IPv6 extension headers or upper layer data.
type IPv6PayloadHeader interface {
isIPv6PayloadHeader()
// Release frees all resources held by the header.
Release()
}
// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator
// encounters a Next Header field it does not recognize as an IPv6 extension
// header. The caller is responsible for releasing the underlying buffer after
// it's no longer needed.
type IPv6RawPayloadHeader struct {
Identifier IPv6ExtensionHeaderIdentifier
Buf buffer.Buffer
}
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {}
// Release implements IPv6PayloadHeader.Release.
func (i IPv6RawPayloadHeader) Release() {
i.Buf.Release()
}
// ipv6OptionsExtHdr is an IPv6 extension header that holds options.
type ipv6OptionsExtHdr struct {
buf *buffer.View
}
// Release implements IPv6PayloadHeader.Release.
func (i ipv6OptionsExtHdr) Release() {
if i.buf != nil {
i.buf.Release()
}
}
// Iter returns an iterator over the IPv6 extension header options held in b.
func (i ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
it := IPv6OptionsExtHdrOptionsIterator{}
it.reader = i.buf
return it
}
// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header
// options.
//
// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last
// used, no changes to the underlying buffer may happen. Doing so may cause
// undefined and unexpected behaviour. It is fine to obtain an
// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then
// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator
// obtained before modification is no longer used.
type IPv6OptionsExtHdrOptionsIterator struct {
reader *buffer.View
// optionOffset is the number of bytes from the first byte of the
// options field to the beginning of the current option.
optionOffset uint32
// nextOptionOffset is the offset of the next option.
nextOptionOffset uint32
}
// OptionOffset returns the number of bytes parsed while processing the
// option field of the current Extension Header.
func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 {
return i.optionOffset
}
// IPv6OptionUnknownAction is the action that must be taken if the processing
// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
type IPv6OptionUnknownAction int
@@ -294,143 +219,6 @@ func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction {
// is malformed.
var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
// header option that is unknown by the parsing utilities.
type IPv6UnknownExtHdrOption struct {
Identifier IPv6ExtHdrOptionIdentifier
Data *buffer.View
}
// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
return ipv6UnknownActionFromIdentifier(o.Identifier)
}
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
// Next returns the next option in the options data.
//
// If the next item is not a known extension header option,
// IPv6UnknownExtHdrOption will be returned with the option identifier and data.
//
// The return is of the format (option, done, error). done will be true when
// Next is unable to return anything because the iterator has reached the end of
// the options data, or an error occurred.
func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
for {
i.optionOffset = i.nextOptionOffset
temp, err := i.reader.ReadByte()
if err != nil {
// If we can't read the first byte of a new option, then we know the
// options buffer has been exhausted and we are done iterating.
return nil, true, nil
}
id := IPv6ExtHdrOptionIdentifier(temp)
// If the option identifier indicates the option is a Pad1 option, then we
// know the option does not have Length and Data fields. End processing of
// the Pad1 option and continue processing the buffer as a new option.
if id == ipv6Pad1ExtHdrOptionIdentifier {
i.nextOptionOffset = i.optionOffset + 1
continue
}
length, err := i.reader.ReadByte()
if err != nil {
if err != io.EOF {
// ReadByte should only ever return nil or io.EOF.
panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err))
}
// We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
// we start parsing an option; we expect the reader to contain enough
// bytes for the whole option.
return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
}
// Do we have enough bytes in the reader for the next option?
if n := i.reader.Size(); n < int(length) {
// Consume the remaining buffer.
i.reader.TrimFront(i.reader.Size())
// We return the same error as if we failed to read a non-padding option
// so consumers of this iterator don't need to differentiate between
// padding and non-padding options.
return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
}
i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */
switch id {
case ipv6PadNExtHdrOptionIdentifier:
// Special-case the variable length padding option to avoid a copy.
i.reader.TrimFront(int(length))
continue
case ipv6RouterAlertHopByHopOptionIdentifier:
var routerAlertValue [ipv6RouterAlertPayloadLength]byte
if n, err := io.ReadFull(i.reader, routerAlertValue[:]); err != nil {
switch err {
case io.EOF, io.ErrUnexpectedEOF:
return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
default:
return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
}
} else if n != int(length) {
return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
}
return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
default:
bytes := buffer.NewView(int(length))
if n, err := io.CopyN(bytes, i.reader, int64(length)); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
}
return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
}
}
// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options
// extension header.
type IPv6HopByHopOptionsExtHdr struct {
ipv6OptionsExtHdr
}
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {}
// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options
// extension header.
type IPv6DestinationOptionsExtHdr struct {
ipv6OptionsExtHdr
}
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {}
// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific
// data as outlined in RFC 8200 section 4.4.
type IPv6RoutingExtHdr struct {
Buf *buffer.View
}
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {}
// Release implements IPv6PayloadHeader.Release.
func (b IPv6RoutingExtHdr) Release() {
b.Buf.Release()
}
// SegmentsLeft returns the Segments Left field.
func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 {
return b.Buf.AsSlice()[ipv6RoutingExtHdrSegmentsLeftIdx]
}
// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
// data as outlined in RFC 8200 section 4.5.
//
@@ -473,242 +261,6 @@ func (b IPv6FragmentExtHdr) IsAtomic() bool {
return !b.More() && b.FragmentOffset() == 0
}
// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
//
// The IPv6 payload may contain IPv6 extension headers before any upper layer
// data.
//
// Note, between when an IPv6PayloadIterator is obtained and last used, no
// changes to the payload may happen. Doing so may cause undefined and
// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate
// over the first few headers then modify the backing payload so long as the
// IPv6PayloadIterator obtained before modification is no longer used.
type IPv6PayloadIterator struct {
// The identifier of the next header to parse.
nextHdrIdentifier IPv6ExtensionHeaderIdentifier
payload buffer.Buffer
// Indicates to the iterator that it should return the remaining payload as a
// raw payload on the next call to Next.
forceRaw bool
// headerOffset is the offset of the beginning of the current extension
// header starting from the beginning of the fixed header.
headerOffset uint32
// parseOffset is the byte offset into the current extension header of the
// field we are currently examining. It can be added to the header offset
// if the absolute offset within the packet is required.
parseOffset uint32
// nextOffset is the offset of the next header.
nextOffset uint32
}
// HeaderOffset returns the offset to the start of the extension
// header most recently processed.
func (i IPv6PayloadIterator) HeaderOffset() uint32 {
return i.headerOffset
}
// ParseOffset returns the number of bytes successfully parsed.
func (i IPv6PayloadIterator) ParseOffset() uint32 {
return i.headerOffset + i.parseOffset
}
// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
// extension headers, or a raw payload if the payload cannot be parsed. The
// iterator takes ownership of the payload.
func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.Buffer) IPv6PayloadIterator {
return IPv6PayloadIterator{
nextHdrIdentifier: nextHdrIdentifier,
payload: payload,
nextOffset: IPv6FixedHeaderSize,
}
}
// Release frees the resources owned by the iterator.
func (i *IPv6PayloadIterator) Release() {
i.payload.Release()
}
// AsRawHeader returns the remaining payload of i as a raw header and
// optionally consumes the iterator.
//
// If consume is true, calls to Next after calling AsRawHeader on i will
// indicate that the iterator is done. The returned header takes ownership of
// its payload.
func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
identifier := i.nextHdrIdentifier
var buf buffer.Buffer
if consume {
// Since we consume the iterator, we return the payload as is.
buf = i.payload
// Mark i as done, but keep track of where we were for error reporting.
*i = IPv6PayloadIterator{
nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
headerOffset: i.headerOffset,
nextOffset: i.nextOffset,
}
} else {
buf = i.payload.Clone()
}
return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf}
}
// Next returns the next item in the payload.
//
// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader
// will be returned with the remaining bytes and next header identifier.
//
// The return is of the format (header, done, error). done will be true when
// Next is unable to return anything because the iterator has reached the end of
// the payload, or an error occurred.
func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
i.headerOffset = i.nextOffset
i.parseOffset = 0
// We could be forced to return i as a raw header when the previous header was
// a fragment extension header as the data following the fragment extension
// header may not be complete.
if i.forceRaw {
return i.AsRawHeader(true /* consume */), false, nil
}
// Is the header we are parsing a known extension header?
switch i.nextHdrIdentifier {
case IPv6HopByHopOptionsExtHdrIdentifier:
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
if err != nil {
return nil, true, err
}
i.nextHdrIdentifier = nextHdrIdentifier
return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil
case IPv6RoutingExtHdrIdentifier:
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
if err != nil {
return nil, true, err
}
i.nextHdrIdentifier = nextHdrIdentifier
return IPv6RoutingExtHdr{view}, false, nil
case IPv6FragmentExtHdrIdentifier:
var data [6]byte
// We ignore the returned bytes because we know the fragment extension
// header specific data will fit in data.
nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
if err != nil {
return nil, true, err
}
fragmentExtHdr := IPv6FragmentExtHdr(data)
// If the packet is not the first fragment, do not attempt to parse anything
// after the fragment extension header as the payload following the fragment
// extension header should not contain any headers; the first fragment must
// hold all the headers up to and including any upper layer headers, as per
// RFC 8200 section 4.5.
if fragmentExtHdr.FragmentOffset() != 0 {
i.forceRaw = true
}
i.nextHdrIdentifier = nextHdrIdentifier
return fragmentExtHdr, false, nil
case IPv6DestinationOptionsExtHdrIdentifier:
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
if err != nil {
return nil, true, err
}
i.nextHdrIdentifier = nextHdrIdentifier
return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil
case IPv6NoNextHeaderIdentifier:
// This indicates the end of the IPv6 payload.
return nil, true, nil
default:
// The header we are parsing is not a known extension header. Return the
// raw payload.
return i.AsRawHeader(true /* consume */), false, nil
}
}
// NextHeaderIdentifier returns the identifier of the header next returned by
// it.Next().
func (i *IPv6PayloadIterator) NextHeaderIdentifier() IPv6ExtensionHeaderIdentifier {
return i.nextHdrIdentifier
}
// nextHeaderData returns the extension header's Next Header field and raw data.
//
// fragmentHdr indicates that the extension header being parsed is the Fragment
// extension header so the Length field should be ignored as it is Reserved
// for the Fragment extension header.
//
// If bytes is not nil, extension header specific data will be read into bytes
// if it has enough capacity. If bytes is provided but does not have enough
// capacity for the data, nextHeaderData will panic.
func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, *buffer.View, error) {
// We ignore the number of bytes read because we know we will only ever read
// at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read
// would return io.EOF to indicate that io.Reader has reached the end of the
// payload.
rdr := i.payload.AsBufferReader()
nextHdrIdentifier, err := rdr.ReadByte()
if err != nil {
return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
i.parseOffset++
var length uint8
length, err = rdr.ReadByte()
if err != nil {
if fragmentHdr {
return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
if fragmentHdr {
length = 0
}
// Make parseOffset point to the first byte of the Extension Header
// specific data.
i.parseOffset++
// length is in 8 byte chunks but doesn't include the first one.
// See RFC 8200 for each header type, sections 4.3-4.6 and the requirement
// in section 4.8 for new extension headers at the top of page 24.
// [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet
// units, not including the first 8 octets.
i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit)
bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
if fragmentHdr {
if n := len(bytes); n < bytesLen {
panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier))
}
if n, err := io.ReadFull(&rdr, bytes); err != nil {
return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
}
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), nil, nil
}
v := buffer.NewView(bytesLen)
if n, err := io.CopyN(v, &rdr, int64(bytesLen)); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
v.Release()
return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
}
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), v, nil
}
// IPv6SerializableExtHdr provides serialization for IPv6 extension
// headers.
type IPv6SerializableExtHdr interface {


@@ -0,0 +1,712 @@
package tschecksum
import (
"encoding/binary"
"math/bits"
"strconv"
"golang.org/x/sys/cpu"
)
// checksumGeneric64 is a reference implementation of checksum using 64 bit
// arithmetic for use in testing or when an architecture-specific implementation
// is not available.
func checksumGeneric64(b []byte, initial uint16) uint16 {
var ac uint64
var carry uint64
if cpu.IsBigEndian {
ac = uint64(initial)
} else {
ac = uint64(bits.ReverseBytes16(initial))
}
for len(b) >= 128 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry)
}
b = b[128:]
}
if len(b) >= 64 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry)
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry)
} else {
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry)
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry)
} else {
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry)
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry)
} else {
ac, carry = bits.Add64(ac, uint64(b[0]), carry)
}
}
folded := ipChecksumFold64(ac, carry)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric32 is a reference implementation of checksum using 32 bit
// arithmetic for use in testing or when an architecture-specific implementation
// is not available.
func checksumGeneric32(b []byte, initial uint16) uint16 {
var ac uint32
var carry uint32
if cpu.IsBigEndian {
ac = uint32(initial)
} else {
ac = uint32(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry)
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry)
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry)
} else {
ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry)
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry)
} else {
ac, carry = bits.Add32(ac, uint32(b[0]), carry)
}
}
folded := ipChecksumFold32(ac, carry)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric32Alternate is an alternate reference implementation of
// checksum using 32 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric32Alternate(b []byte, initial uint16) uint16 {
var ac uint32
if cpu.IsBigEndian {
ac = uint32(initial)
} else {
ac = uint32(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
ac += uint32(binary.BigEndian.Uint16(b[32:34]))
ac += uint32(binary.BigEndian.Uint16(b[34:36]))
ac += uint32(binary.BigEndian.Uint16(b[36:38]))
ac += uint32(binary.BigEndian.Uint16(b[38:40]))
ac += uint32(binary.BigEndian.Uint16(b[40:42]))
ac += uint32(binary.BigEndian.Uint16(b[42:44]))
ac += uint32(binary.BigEndian.Uint16(b[44:46]))
ac += uint32(binary.BigEndian.Uint16(b[46:48]))
ac += uint32(binary.BigEndian.Uint16(b[48:50]))
ac += uint32(binary.BigEndian.Uint16(b[50:52]))
ac += uint32(binary.BigEndian.Uint16(b[52:54]))
ac += uint32(binary.BigEndian.Uint16(b[54:56]))
ac += uint32(binary.BigEndian.Uint16(b[56:58]))
ac += uint32(binary.BigEndian.Uint16(b[58:60]))
ac += uint32(binary.BigEndian.Uint16(b[60:62]))
ac += uint32(binary.BigEndian.Uint16(b[62:64]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
ac += uint32(binary.LittleEndian.Uint16(b[32:34]))
ac += uint32(binary.LittleEndian.Uint16(b[34:36]))
ac += uint32(binary.LittleEndian.Uint16(b[36:38]))
ac += uint32(binary.LittleEndian.Uint16(b[38:40]))
ac += uint32(binary.LittleEndian.Uint16(b[40:42]))
ac += uint32(binary.LittleEndian.Uint16(b[42:44]))
ac += uint32(binary.LittleEndian.Uint16(b[44:46]))
ac += uint32(binary.LittleEndian.Uint16(b[46:48]))
ac += uint32(binary.LittleEndian.Uint16(b[48:50]))
ac += uint32(binary.LittleEndian.Uint16(b[50:52]))
ac += uint32(binary.LittleEndian.Uint16(b[52:54]))
ac += uint32(binary.LittleEndian.Uint16(b[54:56]))
ac += uint32(binary.LittleEndian.Uint16(b[56:58]))
ac += uint32(binary.LittleEndian.Uint16(b[58:60]))
ac += uint32(binary.LittleEndian.Uint16(b[60:62]))
ac += uint32(binary.LittleEndian.Uint16(b[62:64]))
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b))
} else {
ac += uint32(binary.LittleEndian.Uint16(b))
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac += uint32(b[0]) << 8
} else {
ac += uint32(b[0])
}
}
folded := ipChecksumFold32(ac, 0)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric64Alternate is an alternate reference implementation of
// checksum using 64 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric64Alternate(b []byte, initial uint16) uint16 {
var ac uint64
if cpu.IsBigEndian {
ac = uint64(initial)
} else {
ac = uint64(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
ac += uint64(binary.LittleEndian.Uint32(b[32:36]))
ac += uint64(binary.LittleEndian.Uint32(b[36:40]))
ac += uint64(binary.LittleEndian.Uint32(b[40:44]))
ac += uint64(binary.LittleEndian.Uint32(b[44:48]))
ac += uint64(binary.LittleEndian.Uint32(b[48:52]))
ac += uint64(binary.LittleEndian.Uint32(b[52:56]))
ac += uint64(binary.LittleEndian.Uint32(b[56:60]))
ac += uint64(binary.LittleEndian.Uint32(b[60:64]))
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b))
} else {
ac += uint64(binary.LittleEndian.Uint32(b))
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint16(b))
} else {
ac += uint64(binary.LittleEndian.Uint16(b))
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac += uint64(b[0]) << 8
} else {
ac += uint64(b[0])
}
}
folded := ipChecksumFold64(ac, 0)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 {
sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry))
// if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff
// therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is
// no need to save the carry flag
sum = (sum >> 16) + (sum & 0xffff) + carry
// sum <= 0x1_fffe therefore this is the last fold needed:
// if (sum >> 16) > 0 then
// (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore
// the addition will not overflow
// otherwise (sum >> 16) == 0 and sum will be unchanged
sum = (sum >> 16) + (sum & 0xffff)
return uint16(sum)
}
func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 {
sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry
// sum <= 0x1_ffff:
// 0xffff + 0xffff = 0x1_fffe
// initialCarry is 0 or 1, for a combined maximum of 0x1_ffff
sum = (sum >> 16) + (sum & 0xffff)
// sum <= 0x1_0000 therefore this is the last fold needed:
// if (sum >> 16) > 0 then
// (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore
// the addition will not overflow
// otherwise (sum >> 16) == 0 and sum will be unchanged
sum = (sum >> 16) + (sum & 0xffff)
return uint16(sum)
}
func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) {
sum, carry = initial, carryIn
switch len(addr) {
case 4: // IPv4
if cpu.IsBigEndian {
sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry)
} else {
sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry)
}
case 16: // IPv6
if cpu.IsBigEndian {
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry)
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry)
} else {
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry)
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry)
}
default:
panic("bad addr length")
}
return sum, carry
}
func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) {
sum, carry = initial, carryIn
switch len(addr) {
case 4: // IPv4
if cpu.IsBigEndian {
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
} else {
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
}
case 16: // IPv6
if cpu.IsBigEndian {
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry)
} else {
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry)
}
default:
panic("bad addr length")
}
return sum, carry
}
func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
var sum uint64
if cpu.IsBigEndian {
sum = uint64(totalLen) + uint64(protocol)
} else {
sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8
}
sum, carry := addrPartialChecksum64(srcAddr, sum, 0)
sum, carry = addrPartialChecksum64(dstAddr, sum, carry)
foldedSum := ipChecksumFold64(sum, carry)
if !cpu.IsBigEndian {
foldedSum = bits.ReverseBytes16(foldedSum)
}
return foldedSum
}
func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
var sum uint32
if cpu.IsBigEndian {
sum = uint32(totalLen) + uint32(protocol)
} else {
sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8
}
sum, carry := addrPartialChecksum32(srcAddr, sum, 0)
sum, carry = addrPartialChecksum32(dstAddr, sum, carry)
foldedSum := ipChecksumFold32(sum, carry)
if !cpu.IsBigEndian {
foldedSum = bits.ReverseBytes16(foldedSum)
}
return foldedSum
}
// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and
// dstAddr must be 4 or 16 bytes in length.
func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
if strconv.IntSize < 64 {
return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen)
}
return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen)
}
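PseudoHeaderChecksum produces a folded pseudo-header sum in exactly the representation Checksum expects as its initial argument, which is how a transport checksum is assembled from the two halves. A minimal sketch for UDP over either address family (udpChecksum is a hypothetical helper; Checksum here is the package's exported entry point shown in the next file):

func udpChecksum(srcAddr, dstAddr, udpHdrAndPayload []byte) uint16 {
	const protoUDP = 17
	initial := PseudoHeaderChecksum(protoUDP, srcAddr, dstAddr, uint16(len(udpHdrAndPayload)))
	// The header field stores the complement of the folded sum; UDP
	// transmits an all-zero result as 0xffff (RFC 768).
	xsum := ^Checksum(udpHdrAndPayload, initial)
	if xsum == 0 {
		xsum = 0xffff
	}
	return xsum
}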


@@ -0,0 +1,23 @@
package tschecksum

import "golang.org/x/sys/cpu"

var checksum = checksumAMD64

// Checksum computes an IP checksum starting with the provided initial value.
// The length of data should be at least 128 bytes for best performance. Smaller
// buffers will still compute a correct result.
func Checksum(data []byte, initial uint16) uint16 {
	return checksum(data, initial)
}

func init() {
	if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 {
		checksum = checksumAVX2
		return
	}
	if cpu.X86.HasSSE2 {
		checksum = checksumSSE2
		return
	}
}
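The doc comment above hints that buffer size matters: inputs of 128 bytes and up reach the unrolled vector loops, while shorter ones fall through to the scalar tails. A size-sweep benchmark makes that visible; a sketch in the style of the checksum_test file earlier in this diff (name and placement hypothetical; needs strconv in addition to the imports shown there):

func BenchmarkTsChecksumBySize(b *testing.B) {
	for _, n := range []int{32, 128, 1500} {
		buf := make([]byte, n)
		rand.Read(buf)
		b.Run(strconv.Itoa(n), func(b *testing.B) {
			b.SetBytes(int64(n))
			for i := 0; i < b.N; i++ {
				tschecksum.Checksum(buf, 0)
			}
		})
	}
}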

internal/tschecksum/checksum_generated_amd64.go

@@ -0,0 +1,18 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.

package tschecksum

// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)
//
//go:noescape
func checksumAVX2(b []byte, initial uint16) uint16

// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)
//
//go:noescape
func checksumSSE2(b []byte, initial uint16) uint16

// checksumAMD64 computes an IP checksum using amd64 baseline instructions
//
//go:noescape
func checksumAMD64(b []byte, initial uint16) uint16

View file

@ -0,0 +1,851 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
#include "textflag.h"
DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"
DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112
// func checksumAVX2(b []byte, initial uint16) uint16
// Requires: AVX, AVX2, BMI2
TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// skip all SIMD for small buffers
CMPQ BX, $0x00000100
JGE startSIMD
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
JMP foldAndReturn
startSIMD:
VPXOR Y0, Y0, Y0
VPXOR Y1, Y1, Y1
VPXOR Y2, Y2, Y2
VPXOR Y3, Y3, Y3
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
// Number of 256 byte iterations
SHRQ $0x08, CX
JZ smallLoop
bigLoop:
VPMOVZXWD (DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 16(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 32(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 48(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 64(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 80(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 96(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 112(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 128(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 144(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 160(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 176(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 192(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 208(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 224(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 240(DX), Y4
VPADDD Y4, Y3, Y3
ADDQ $0x00000100, DX
DECQ CX
JNZ bigLoop
CMPQ BX, $0x10
JLT doneSmallLoop
// now read a single 16 byte unit of data at a time
smallLoop:
VPMOVZXWD (DX), Y4
VPADDD Y4, Y0, Y0
ADDQ $0x10, DX
SUBQ $0x10, BX
CMPQ BX, $0x10
JGE smallLoop
doneSmallLoop:
CMPQ BX, $0x00
JE doneSIMD
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
LEAQ xmmLoadMasks<>+0(SB), CX
VMOVDQU -16(DX)(BX*1), X4
VPAND -16(CX)(BX*8), X4, X4
VPMOVZXWD X4, Y4
VPADDD Y4, Y0, Y0
doneSIMD:
// Multi-chain loop is done, combine the accumulators
VPADDD Y1, Y0, Y0
VPADDD Y2, Y0, Y0
VPADDD Y3, Y0, Y0
// extract the YMM into a pair of XMM and sum them
VEXTRACTI128 $0x01, Y0, X1
VPADDD X0, X1, X0
// extract the XMM into GP64
VPEXTRQ $0x00, X0, CX
VPEXTRQ $0x01, X0, DX
// no more AVX code, clear upper registers to avoid SSE slowdowns
VZEROUPPER
ADDQ CX, AX
ADCQ DX, AX
foldAndReturn:
// add CF and fold
RORXQ $0x20, AX, CX
ADCL CX, AX
RORXL $0x10, AX, CX
ADCW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
// func checksumSSE2(b []byte, initial uint16) uint16
// Requires: SSE2
TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// skip all SIMD for small buffers
CMPQ BX, $0x00000100
JGE startSIMD
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
JMP foldAndReturn
startSIMD:
PXOR X0, X0
PXOR X1, X1
PXOR X2, X2
PXOR X3, X3
PXOR X4, X4
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
// Number of 256 byte iterations
SHRQ $0x08, CX
JZ smallLoop
bigLoop:
MOVOU (DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 16(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 32(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 48(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 64(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 80(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 96(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 112(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 128(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 144(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 160(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 176(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 192(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 208(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 224(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 240(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
ADDQ $0x00000100, DX
DECQ CX
JNZ bigLoop
CMPQ BX, $0x10
JLT doneSmallLoop
// now read a single 16 byte unit of data at a time
smallLoop:
MOVOU (DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X1
ADDQ $0x10, DX
SUBQ $0x10, BX
CMPQ BX, $0x10
JGE smallLoop
doneSmallLoop:
CMPQ BX, $0x00
JE doneSIMD
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
LEAQ xmmLoadMasks<>+0(SB), CX
MOVOU -16(DX)(BX*1), X5
PAND -16(CX)(BX*8), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X1
doneSIMD:
// Multi-chain loop is done, combine the accumulators
PADDD X1, X0
PADDD X2, X0
PADDD X3, X0
// extract the XMM into GP64
MOVQ X0, CX
PSRLDQ $0x08, X0
MOVQ X0, DX
ADDQ CX, AX
ADCQ DX, AX
foldAndReturn:
// add CF and fold
MOVL AX, CX
ADCQ $0x00, CX
SHRQ $0x20, AX
ADDQ CX, AX
MOVWQZX AX, CX
SHRQ $0x10, AX
ADDQ CX, AX
MOVW AX, CX
SHRQ $0x10, AX
ADDW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
// func checksumAMD64(b []byte, initial uint16) uint16
TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// Number of 256 byte iterations into loop counter
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
SHRQ $0x08, CX
JZ startCleanup
CLC
XORQ SI, SI
XORQ DI, DI
XORQ R8, R8
XORQ R9, R9
XORQ R10, R10
XORQ R11, R11
XORQ R12, R12
bigLoop:
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ 32(DX), DI
ADCQ 40(DX), DI
ADCQ 48(DX), DI
ADCQ 56(DX), DI
ADCQ $0x00, R8
ADDQ 64(DX), R9
ADCQ 72(DX), R9
ADCQ 80(DX), R9
ADCQ 88(DX), R9
ADCQ $0x00, R10
ADDQ 96(DX), R11
ADCQ 104(DX), R11
ADCQ 112(DX), R11
ADCQ 120(DX), R11
ADCQ $0x00, R12
ADDQ 128(DX), AX
ADCQ 136(DX), AX
ADCQ 144(DX), AX
ADCQ 152(DX), AX
ADCQ $0x00, SI
ADDQ 160(DX), DI
ADCQ 168(DX), DI
ADCQ 176(DX), DI
ADCQ 184(DX), DI
ADCQ $0x00, R8
ADDQ 192(DX), R9
ADCQ 200(DX), R9
ADCQ 208(DX), R9
ADCQ 216(DX), R9
ADCQ $0x00, R10
ADDQ 224(DX), R11
ADCQ 232(DX), R11
ADCQ 240(DX), R11
ADCQ 248(DX), R11
ADCQ $0x00, R12
ADDQ $0x00000100, DX
SUBQ $0x01, CX
JNZ bigLoop
ADDQ SI, AX
ADCQ DI, AX
ADCQ R8, AX
ADCQ R9, AX
ADCQ R10, AX
ADCQ R11, AX
ADCQ R12, AX
// accumulate CF (twice, in case the first time overflows)
ADCQ $0x00, AX
ADCQ $0x00, AX
startCleanup:
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
foldAndReturn:
// add CF and fold
MOVL AX, CX
ADCQ $0x00, CX
SHRQ $0x20, AX
ADDQ CX, AX
MOVWQZX AX, CX
SHRQ $0x10, AX
ADDQ CX, AX
MOVW AX, CX
SHRQ $0x10, AX
ADDW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
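In plain Go, the arithmetic at each foldAndReturn label reduces to the following sketch; the XCHGB byte swaps that keep the accumulator in network byte order are omitted, and the function name is illustrative.

// Fold a 64-bit ones-complement accumulator plus a possible carry flag
// down to 16 bits. The maximum value at each step mirrors the comments
// in the generator's foldWithCF.
func fold(sum, carry uint64) uint16 {
	sum = sum>>32 + uint64(uint32(sum)) + carry // at most 33 bits
	sum = sum>>16 + sum&0xffff                  // at most 0x2fffe
	sum = sum>>16 + sum&0xffff                  // at most 0x10000
	sum = sum>>16 + sum&0xffff                  // fits in 16 bits
	return uint16(sum)
}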

View file

@ -0,0 +1,15 @@
// This file contains IP checksum algorithms that are not specific to any
// architecture and don't use hardware acceleration.
//go:build !amd64
package tschecksum
import "strconv"
func Checksum(data []byte, initial uint16) uint16 {
if strconv.IntSize < 64 {
return checksumGeneric32(data, initial)
}
return checksumGeneric64(data, initial)
}
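checksumGeneric32 and checksumGeneric64 are not part of this hunk. For orientation, the portable ones-complement algorithm they implement looks roughly like the sketch below; the real versions read 32 or 64 bits per step (per the dispatch above), while this one shows only the arithmetic.

// Minimal RFC 1071-style checksum: sum 16-bit big-endian words, pad an
// odd trailing byte on the right, then fold the carries back in.
func checksumSketch(data []byte, initial uint16) uint16 {
	sum := uint64(initial)
	for len(data) >= 2 {
		sum += uint64(data[0])<<8 | uint64(data[1])
		data = data[2:]
	}
	if len(data) == 1 {
		sum += uint64(data[0]) << 8
	}
	for sum>>16 != 0 {
		sum = sum>>16 + sum&0xffff
	}
	return uint16(sum)
}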

View file

@ -0,0 +1,578 @@
//go:build ignore
//go:generate go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go
package main
import (
"fmt"
"math"
"math/bits"
"github.com/mmcloughlin/avo/operand"
"github.com/mmcloughlin/avo/reg"
)
const checksumSignature = "func(b []byte, initial uint16) uint16"
func loadParams() (accum, buf, n reg.GPVirtual) {
accum, buf, n = GP64(), GP64(), GP64()
Load(Param("initial"), accum)
XCHGB(accum.As8H(), accum.As8L())
Load(Param("b").Base(), buf)
Load(Param("b").Len(), n)
return
}
type simdStrategy int
const (
sse2 = iota
avx2
)
const tinyBufferSize = 31 // A buffer is tiny if it has at most 31 bytes.
func generateSIMDChecksum(name, doc string, minSIMDSize, chains int, strategy simdStrategy) {
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
Pragma("noescape")
Doc(doc)
accum64, buf, n := loadParams()
handleOddLength(n, buf, accum64)
// no chance of overflow because accum64 was initialized by a uint16 and
// handleOddLength adds at most a uint8
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
Label("bufferIsNotTiny")
const simdReadSize = 16
if minSIMDSize > tinyBufferSize {
Comment("skip all SIMD for small buffers")
if minSIMDSize <= math.MaxUint8 {
CMPQ(n, operand.U8(minSIMDSize))
} else {
CMPQ(n, operand.U32(minSIMDSize))
}
JGE(operand.LabelRef("startSIMD"))
handleRemaining(n, buf, accum64, minSIMDSize-1)
JMP(operand.LabelRef("foldAndReturn"))
}
Label("startSIMD")
// chains is the number of accumulators to use. This improves speed via
// reduced data dependency. We combine the accumulators once when the big
// loop is complete.
simdAccumulate := make([]reg.VecVirtual, chains)
for i := range simdAccumulate {
switch strategy {
case sse2:
simdAccumulate[i] = XMM()
PXOR(simdAccumulate[i], simdAccumulate[i])
case avx2:
simdAccumulate[i] = YMM()
VPXOR(simdAccumulate[i], simdAccumulate[i], simdAccumulate[i])
}
}
var zero reg.VecVirtual
if strategy == sse2 {
zero = XMM()
PXOR(zero, zero)
}
// Number of loads per big loop
const unroll = 16
// Number of bytes
loopSize := uint64(simdReadSize * unroll)
if bits.Len64(loopSize) != bits.Len64(loopSize-1)+1 {
panic("loopSize is not a power of 2")
}
loopCount := GP64()
MOVQ(n, loopCount)
Comment("Update number of bytes remaining after the loop completes")
ANDQ(operand.Imm(loopSize-1), n)
Comment(fmt.Sprintf("Number of %d byte iterations", loopSize))
SHRQ(operand.Imm(uint64(bits.Len64(loopSize-1))), loopCount)
JZ(operand.LabelRef("smallLoop"))
Label("bigLoop")
for i := 0; i < unroll; i++ {
chain := i % chains
switch strategy {
case sse2:
sse2AccumulateStep(i*simdReadSize, buf, zero, simdAccumulate[chain], simdAccumulate[(chain+chains/2)%chains])
case avx2:
avx2AccumulateStep(i*simdReadSize, buf, simdAccumulate[chain])
}
}
ADDQ(operand.U32(loopSize), buf)
DECQ(loopCount)
JNZ(operand.LabelRef("bigLoop"))
Label("bigCleanup")
CMPQ(n, operand.Imm(uint64(simdReadSize)))
JLT(operand.LabelRef("doneSmallLoop"))
Commentf("now read a single %d byte unit of data at a time", simdReadSize)
Label("smallLoop")
switch strategy {
case sse2:
sse2AccumulateStep(0, buf, zero, simdAccumulate[0], simdAccumulate[1])
case avx2:
avx2AccumulateStep(0, buf, simdAccumulate[0])
}
ADDQ(operand.Imm(uint64(simdReadSize)), buf)
SUBQ(operand.Imm(uint64(simdReadSize)), n)
CMPQ(n, operand.Imm(uint64(simdReadSize)))
JGE(operand.LabelRef("smallLoop"))
Label("doneSmallLoop")
CMPQ(n, operand.Imm(0))
JE(operand.LabelRef("doneSIMD"))
Commentf("There are between 1 and %d bytes remaining. Perform an overlapped read.", simdReadSize-1)
maskDataPtr := GP64()
LEAQ(operand.NewDataAddr(operand.NewStaticSymbol("xmmLoadMasks"), 0), maskDataPtr)
dataAddr := operand.Mem{Index: n, Scale: 1, Base: buf, Disp: -simdReadSize}
// scale 8 is only correct here because n is guaranteed to be even and we
// do not generate masks for odd lengths
maskAddr := operand.Mem{Base: maskDataPtr, Index: n, Scale: 8, Disp: -16}
remainder := XMM()
switch strategy {
case sse2:
MOVOU(dataAddr, remainder)
PAND(maskAddr, remainder)
low := XMM()
MOVOA(remainder, low)
PUNPCKHWL(zero, remainder)
PUNPCKLWL(zero, low)
PADDD(remainder, simdAccumulate[0])
PADDD(low, simdAccumulate[1])
case avx2:
// Note: this is very similar to the sse2 path but MOVOU has a massive
// performance hit if used here, presumably due to switching between SSE
// and AVX2 modes.
VMOVDQU(dataAddr, remainder)
VPAND(maskAddr, remainder, remainder)
temp := YMM()
VPMOVZXWD(remainder, temp)
VPADDD(temp, simdAccumulate[0], simdAccumulate[0])
}
Label("doneSIMD")
Comment("Multi-chain loop is done, combine the accumulators")
for i := range simdAccumulate {
if i == 0 {
continue
}
switch strategy {
case sse2:
PADDD(simdAccumulate[i], simdAccumulate[0])
case avx2:
VPADDD(simdAccumulate[i], simdAccumulate[0], simdAccumulate[0])
}
}
if strategy == avx2 {
Comment("extract the YMM into a pair of XMM and sum them")
tmp := YMM()
VEXTRACTI128(operand.Imm(1), simdAccumulate[0], tmp.AsX())
xAccumulate := XMM()
VPADDD(simdAccumulate[0].AsX(), tmp.AsX(), xAccumulate)
simdAccumulate = []reg.VecVirtual{xAccumulate}
}
Comment("extract the XMM into GP64")
low, high := GP64(), GP64()
switch strategy {
case sse2:
MOVQ(simdAccumulate[0], low)
PSRLDQ(operand.Imm(8), simdAccumulate[0])
MOVQ(simdAccumulate[0], high)
case avx2:
VPEXTRQ(operand.Imm(0), simdAccumulate[0], low)
VPEXTRQ(operand.Imm(1), simdAccumulate[0], high)
Comment("no more AVX code, clear upper registers to avoid SSE slowdowns")
VZEROUPPER()
}
ADDQ(low, accum64)
ADCQ(high, accum64)
Label("foldAndReturn")
foldWithCF(accum64, strategy == avx2)
XCHGB(accum64.As8H(), accum64.As8L())
Store(accum64.As16(), ReturnIndex(0))
RET()
}
// handleOddLength generates instructions to incorporate the last byte into
// accum64 if the length is odd. CF may be set if accum64 overflows; be sure to
// handle that if overflow is possible.
func handleOddLength(n, buf, accum64 reg.GPVirtual) {
Comment("handle odd length buffers; they are difficult to handle in general")
TESTQ(operand.U32(1), n)
JZ(operand.LabelRef("lengthIsEven"))
tmp := GP64()
MOVBQZX(operand.Mem{Base: buf, Index: n, Scale: 1, Disp: -1}, tmp)
DECQ(n)
ADDQ(tmp, accum64)
Label("lengthIsEven")
}
func sse2AccumulateStep(offset int, buf reg.GPVirtual, zero, accumulate1, accumulate2 reg.VecVirtual) {
high, low := XMM(), XMM()
MOVOU(operand.Mem{Disp: offset, Base: buf}, high)
MOVOA(high, low)
PUNPCKHWL(zero, high)
PUNPCKLWL(zero, low)
PADDD(high, accumulate1)
PADDD(low, accumulate2)
}
func avx2AccumulateStep(offset int, buf reg.GPVirtual, accumulate reg.VecVirtual) {
tmp := YMM()
VPMOVZXWD(operand.Mem{Disp: offset, Base: buf}, tmp)
VPADDD(tmp, accumulate, accumulate)
}
func generateAMD64Checksum(name, doc string) {
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
Pragma("noescape")
Doc(doc)
accum64, buf, n := loadParams()
handleOddLength(n, buf, accum64)
// no chance of overflow because accum64 was initialized by a uint16 and
// handleOddLength adds at most a uint8
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
Label("bufferIsNotTiny")
const (
// numChains is the number of accumulators and carry counters to use.
// This improves speed via reduced data dependency. We combine the
// accumulators and carry counters once when the loop is complete.
numChains = 4
unroll = 32 // The number of 64-bit reads to perform per iteration of the loop.
loopSize = 8 * unroll // The number of bytes read per iteration of the loop.
)
if bits.Len(loopSize) != bits.Len(loopSize-1)+1 {
panic("loopSize is not a power of 2")
}
loopCount := GP64()
Comment(fmt.Sprintf("Number of %d byte iterations into loop counter", loopSize))
MOVQ(n, loopCount)
Comment("Update number of bytes remaining after the loop completes")
ANDQ(operand.Imm(loopSize-1), n)
SHRQ(operand.Imm(uint64(bits.Len(loopSize-1))), loopCount)
JZ(operand.LabelRef("startCleanup"))
CLC()
chains := make([]struct {
accum reg.GPVirtual
carries reg.GPVirtual
}, numChains)
for i := range chains {
if i == 0 {
chains[i].accum = accum64
} else {
chains[i].accum = GP64()
XORQ(chains[i].accum, chains[i].accum)
}
chains[i].carries = GP64()
XORQ(chains[i].carries, chains[i].carries)
}
Label("bigLoop")
var curChain int
for i := 0; i < unroll; i++ {
// It is significantly faster to use a ADCX/ADOX pair instead of plain
// ADC, which results in two dependency chains, however those require
// ADX support, which was added after AVX2. If AVX2 is available, that's
// even better than ADCX/ADOX.
//
// However, multiple dependency chains using multiple accumulators and
// occasionally storing CF into temporary counters seems to work almost
// as well.
addr := operand.Mem{Disp: i * 8, Base: buf}
if i%4 == 0 {
if i > 0 {
ADCQ(operand.Imm(0), chains[curChain].carries)
curChain = (curChain + 1) % len(chains)
}
ADDQ(addr, chains[curChain].accum)
} else {
ADCQ(addr, chains[curChain].accum)
}
}
ADCQ(operand.Imm(0), chains[curChain].carries)
ADDQ(operand.U32(loopSize), buf)
SUBQ(operand.Imm(1), loopCount)
JNZ(operand.LabelRef("bigLoop"))
for i := range chains {
if i == 0 {
ADDQ(chains[i].carries, accum64)
continue
}
ADCQ(chains[i].accum, accum64)
ADCQ(chains[i].carries, accum64)
}
accumulateCF(accum64)
Label("startCleanup")
handleRemaining(n, buf, accum64, loopSize-1)
Label("foldAndReturn")
foldWithCF(accum64, false)
XCHGB(accum64.As8H(), accum64.As8L())
Store(accum64.As16(), ReturnIndex(0))
RET()
}
// handleTinyBuffers computes checksums if the buffer length (the n parameter)
// is less than 32. After computing the checksum, a jump to returnLabel will
// be executed. Otherwise, if the buffer length is at least 32, nothing will be
// modified; a jump to continueLabel will be executed instead.
//
// When jumping to returnLabel, CF may be set and must be accommodated e.g.
// using foldWithCF or accumulateCF.
//
// Anecdotally, this appears to be faster than attempting to coordinate an
// overlapped read (which would also require special handling for buffers
// smaller than 8).
func handleTinyBuffers(n, buf, accum reg.GPVirtual, returnLabel, continueLabel operand.LabelRef) {
Comment("handle tiny buffers (<=31 bytes) specially")
CMPQ(n, operand.Imm(tinyBufferSize))
JGT(continueLabel)
tmp2, tmp4, tmp8 := GP64(), GP64(), GP64()
XORQ(tmp2, tmp2)
XORQ(tmp4, tmp4)
XORQ(tmp8, tmp8)
Comment("shift twice to start because length is guaranteed to be even",
"n = n >> 2; CF = originalN & 2")
SHRQ(operand.Imm(2), n)
JNC(operand.LabelRef("handleTiny4"))
Comment("tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]")
MOVWQZX(operand.Mem{Base: buf}, tmp2)
ADDQ(operand.Imm(2), buf)
Label("handleTiny4")
Comment("n = n >> 1; CF = originalN & 4")
SHRQ(operand.Imm(1), n)
JNC(operand.LabelRef("handleTiny8"))
Comment("tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]")
MOVLQZX(operand.Mem{Base: buf}, tmp4)
ADDQ(operand.Imm(4), buf)
Label("handleTiny8")
Comment("n = n >> 1; CF = originalN & 8")
SHRQ(operand.Imm(1), n)
JNC(operand.LabelRef("handleTiny16"))
Comment("tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]")
MOVQ(operand.Mem{Base: buf}, tmp8)
ADDQ(operand.Imm(8), buf)
Label("handleTiny16")
Comment("n = n >> 1; CF = originalN & 16",
"n == 0 now, otherwise we would have branched after comparing with tinyBufferSize")
SHRQ(operand.Imm(1), n)
JNC(operand.LabelRef("handleTinyFinish"))
ADDQ(operand.Mem{Base: buf}, accum)
ADCQ(operand.Mem{Base: buf, Disp: 8}, accum)
Label("handleTinyFinish")
Comment("CF should be included from the previous add, so we use ADCQ.",
"If we arrived via the JNC above, then CF=0 due to the branch condition,",
"so ADCQ will still produce the correct result.")
ADCQ(tmp2, accum)
ADCQ(tmp4, accum)
ADCQ(tmp8, accum)
JMP(returnLabel)
}
// handleRemaining generates a series of conditional unrolled additions,
// starting at 8 bytes and doubling each time until the length reaches
// max. This is the reverse order of what may be intuitive, but makes the branch
// conditions convenient to compute: perform one right shift each time and test
// against CF.
//
// When done, CF may be set and must be accommodated e.g., using foldWithCF or
// accumulateCF.
//
// If n is not a multiple of 8, an extra 64 bit read at the end of the buffer
// will be performed, overlapping with data that will be read later. The
// duplicate data will be shifted off.
//
// The original buffer length must have been at least 8 bytes long, even if
// n < 8, otherwise this will access memory before the start of the buffer,
// which may be unsafe.
func handleRemaining(n, buf, accum64 reg.GPVirtual, max int) {
Comment("Accumulate carries in this register. It is never expected to overflow.")
carries := GP64()
XORQ(carries, carries)
Comment("We will perform an overlapped read for buffers with length not a multiple of 8.",
"Overlapped in this context means some memory will be read twice, but a shift will",
"eliminate the duplicated data. This extra read is performed at the end of the buffer to",
"preserve any alignment that may exist for the start of the buffer.")
leftover := reg.RCX
MOVQ(n, leftover)
SHRQ(operand.Imm(3), n) // n is now the number of 64 bit reads remaining
ANDQ(operand.Imm(0x7), leftover) // leftover is now the number of bytes to read from the end
JZ(operand.LabelRef("handleRemaining8"))
endBuf := GP64()
// endBuf is the position near the end of the buffer that is just past the
// last multiple of 8: (buf + len(buf)) & ^0x7
LEAQ(operand.Mem{Base: buf, Index: n, Scale: 8}, endBuf)
overlapRead := GP64()
// equivalent to overlapRead = binary.LittleEndian.Uint64(buf[len(buf)-8:len(buf)])
MOVQ(operand.Mem{Base: endBuf, Index: leftover, Scale: 1, Disp: -8}, overlapRead)
Comment("Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)")
SHLQ(operand.Imm(3), leftover) // leftover = leftover * 8
NEGQ(leftover) // leftover = -leftover; this completes the (-leftoverBytes*8) part of the expression
ADDQ(operand.Imm(64), leftover) // now we have (64 - leftoverBytes*8)
SHRQ(reg.CL, overlapRead) // shift right by (64 - leftoverBytes*8); CL is the low 8 bits of leftover (set to RCX above) and variable shift only accepts CL
ADDQ(overlapRead, accum64)
ADCQ(operand.Imm(0), carries)
for curBytes := 8; curBytes <= max; curBytes *= 2 {
Label(fmt.Sprintf("handleRemaining%d", curBytes))
SHRQ(operand.Imm(1), n)
if curBytes*2 <= max {
JNC(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
} else {
JNC(operand.LabelRef("handleRemainingComplete"))
}
numLoads := curBytes / 8
for i := 0; i < numLoads; i++ {
addr := operand.Mem{Base: buf, Disp: i * 8}
// It is possible to add the multiple dependency chains trick here
// that generateAMD64Checksum uses but anecdotally it does not
// appear to outweigh the cost.
if i == 0 {
ADDQ(addr, accum64)
continue
}
ADCQ(addr, accum64)
}
ADCQ(operand.Imm(0), carries)
if curBytes > math.MaxUint8 {
ADDQ(operand.U32(uint64(curBytes)), buf)
} else {
ADDQ(operand.U8(uint64(curBytes)), buf)
}
if curBytes*2 >= max {
continue
}
JMP(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
}
Label("handleRemainingComplete")
ADDQ(carries, accum64)
}
func accumulateCF(accum64 reg.GPVirtual) {
Comment("accumulate CF (twice, in case the first time overflows)")
// accum64 += CF
ADCQ(operand.Imm(0), accum64)
// accum64 += CF again if the previous add overflowed. The previous add was
// 0 or 1. If it overflowed, then accum64 == 0, so adding another 1 can
// never overflow.
ADCQ(operand.Imm(0), accum64)
}
// foldWithCF generates instructions to fold accum (a GP64) into a 16-bit value
// according to ones-complement arithmetic. BMI2 instructions will be used if
// allowBMI2 is true (requires fewer instructions).
func foldWithCF(accum reg.GPVirtual, allowBMI2 bool) {
Comment("add CF and fold")
// CF|accum max value starts as 0x1_ffff_ffff_ffff_ffff
tmp := GP64()
if allowBMI2 {
// effectively, tmp = accum >> 32 (technically, this is a rotate)
RORXQ(operand.Imm(32), accum, tmp)
// accum as uint32 = uint32(accum) + uint32(tmp64) + CF; max value 0xffff_ffff + CF set
ADCL(tmp.As32(), accum.As32())
// effectively, tmp64 as uint32 = uint32(accum) >> 16 (also a rotate)
RORXL(operand.Imm(16), accum.As32(), tmp.As32())
// accum as uint16 = uint16(accum) + uint16(tmp) + CF; max value 0xffff + CF unset or 0xfffe + CF set
ADCW(tmp.As16(), accum.As16())
} else {
// tmp = uint32(accum); max value 0xffff_ffff
// MOVL clears the upper 32 bits of a GP64 so this is equivalent to the
// non-existent MOVLQZX.
MOVL(accum.As32(), tmp.As32())
// tmp += CF; max value 0x1_0000_0000, CF unset
ADCQ(operand.Imm(0), tmp)
// accum = accum >> 32; max value 0xffff_ffff
SHRQ(operand.Imm(32), accum)
// accum = accum + tmp; max value 0x1_ffff_ffff + CF unset
ADDQ(tmp, accum)
// tmp = uint16(accum); max value 0xffff
MOVWQZX(accum.As16(), tmp)
// accum = accum >> 16; max value 0x1_ffff
SHRQ(operand.Imm(16), accum)
// accum = accum + tmp; max value 0x2_fffe + CF unset
ADDQ(tmp, accum)
// tmp as uint16 = uint16(accum); max value 0xffff
MOVW(accum.As16(), tmp.As16())
// accum = accum >> 16; max value 0x2
SHRQ(operand.Imm(16), accum)
// accum as uint16 = uint16(accum) + uint16(tmp); max value 0xffff + CF unset or 0x2 + CF set
ADDW(tmp.As16(), accum.As16())
}
// accum as uint16 += CF; will not overflow: either CF was 0 or accum <= 0xfffe
ADCW(operand.Imm(0), accum.As16())
}
func generateLoadMasks() {
var offset int
// xmmLoadMasks is a table of masks that can be used with PAND to zero all but the last N bytes in an XMM, N=2,4,6,8,10,12,14
GLOBL("xmmLoadMasks", RODATA|NOPTR)
for n := 2; n < 16; n += 2 {
var pattern [16]byte
for i := 0; i < len(pattern); i++ {
if i < len(pattern)-n {
pattern[i] = 0
continue
}
pattern[i] = 0xff
}
DATA(offset, operand.String(pattern[:]))
offset += len(pattern)
}
}
func main() {
generateLoadMasks()
generateSIMDChecksum("checksumAVX2", "checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)", 256, 4, avx2)
generateSIMDChecksum("checksumSSE2", "checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)", 256, 4, sse2)
generateAMD64Checksum("checksumAMD64", "checksumAMD64 computes an IP checksum using amd64 baseline instructions")
Generate()
}
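For readers following the generator, handleRemaining's branch structure is easier to see in plain Go: each right shift of the word count drops one bit into CF, and that bit decides whether a block of the current size is consumed. A rough equivalent as a sketch (the overlapped tail read is omitted; encoding/binary and math/bits imports assumed):

// After each shift, the bit shifted out selects whether to add a block
// of 1, 2, 4, ... 64-bit words, mirroring handleRemaining8/16/32/...
func remainingSketch(buf []byte, n int) uint64 {
	var sum, carries uint64
	words := n >> 3 // the n&7 leftover bytes are the overlapped read's job
	for size := 1; words != 0; size <<= 1 {
		if words&1 == 1 {
			var carry uint64
			for i := 0; i < size; i++ {
				sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(buf[8*i:]), carry)
			}
			carries += carry
			buf = buf[8*size:]
		}
		words >>= 1
	}
	return sum + carries
}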

View file

@ -51,12 +51,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
return err
}
oldInterface := m.defaultInterface.Load()
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
if err != nil {
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
}
m.defaultInterface.Store(newInterface)
oldInterface := m.defaultInterface.Swap(newInterface)
if oldInterface != nil && oldInterface.Equals(*newInterface) && oldVPNEnabled == m.androidVPNEnabled {
return nil
}
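This hunk, and the matching ones for the other platforms below, replaces a separate Load/Store pair with a single atomic Swap: the new interface is published and the previous value fetched in one operation, closing the window in which a concurrent update could be lost. In sketch form (the field is assumed to be an atomic pointer type):

// Before: old := m.defaultInterface.Load(); m.defaultInterface.Store(newInterface)
// After: one atomic read-modify-write; oldInterface is nil on the first update.
oldInterface := m.defaultInterface.Swap(newInterface)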

View file

@ -165,12 +165,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
if defaultInterface == nil {
return ErrNoRoute
}
oldInterface := m.defaultInterface.Load()
newInterface, err := m.interfaceFinder.ByIndex(defaultInterface.Index)
if err != nil {
return E.Cause(err, "find updated interface: ", defaultInterface.Name)
}
m.defaultInterface.Store(newInterface)
oldInterface := m.defaultInterface.Swap(newInterface)
if oldInterface != nil && oldInterface.Equals(*newInterface) {
return nil
}

View file

@ -27,7 +27,7 @@ type networkUpdateMonitor struct {
var ErrNetlinkBanned = E.New(
"netlink socket in Android is banned by Google, " +
"use the root or system (ADB) user to run sing-box, " +
"or switch to the sing-box Adnroid graphical interface client",
"or switch to the sing-box Android graphical interface client",
)
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {

View file

@ -25,12 +25,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
return err
}
oldInterface := m.defaultInterface.Load()
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
if err != nil {
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
}
m.defaultInterface.Store(newInterface)
oldInterface := m.defaultInterface.Swap(newInterface)
if oldInterface != nil && oldInterface.Equals(*newInterface) {
return nil
}

View file

@ -102,13 +102,12 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
return ErrNoRoute
}
oldInterface := m.defaultInterface.Load()
newInterface, err := m.interfaceFinder.ByIndex(index)
if err != nil {
return E.Cause(err, "find updated interface: ", alias)
}
m.defaultInterface.Store(newInterface)
if oldInterface != nil && !oldInterface.Equals(*newInterface) {
oldInterface := m.defaultInterface.Swap(newInterface)
if oldInterface != nil && oldInterface.Equals(*newInterface) {
return nil
}
m.emit(newInterface, 0)

View file

@ -44,7 +44,7 @@ type autoRedirect struct {
}
func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
r := &autoRedirect{
return &autoRedirect{
tunOptions: options.TunOptions,
ctx: options.Context,
handler: options.Handler,
@ -56,7 +56,10 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
customRedirectPortFunc: options.CustomRedirectPort,
routeAddressSet: options.RouteAddressSet,
routeExcludeAddressSet: options.RouteExcludeAddressSet,
}
}, nil
}
func (r *autoRedirect) Start() error {
var err error
if runtime.GOOS == "android" {
r.enableIPv4 = true
@ -74,7 +77,7 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
}
}
if err != nil {
return nil, E.Extend(E.Cause(err, "root permission is required for auto redirect"), os.Getenv("PATH"))
return E.Extend(E.Cause(err, "root permission is required for auto redirect"), os.Getenv("PATH"))
}
}
} else {
@ -90,7 +93,7 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
if !r.useNFTables {
r.iptablesPath, err = exec.LookPath("iptables")
if err != nil {
return nil, E.Cause(err, "iptables is required")
return E.Cause(err, "iptables is required")
}
}
}
@ -100,7 +103,7 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
r.ip6tablesPath, err = exec.LookPath("ip6tables")
if err != nil {
if !r.enableIPv4 {
return nil, E.Cause(err, "ip6tables is required")
return E.Cause(err, "ip6tables is required")
} else {
r.enableIPv6 = false
r.logger.Error("device has no ip6tables nat support: ", err)
@ -109,10 +112,6 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
}
}
}
return r, nil
}
func (r *autoRedirect) Start() error {
if r.customRedirectPortFunc != nil {
r.customRedirectPort = r.customRedirectPortFunc()
}
@ -132,7 +131,6 @@ func (r *autoRedirect) Start() error {
}
r.redirectServer = server
}
var err error
if r.useNFTables {
r.cleanupNFTables()
err = r.setupNFTables()

View file

@ -32,6 +32,10 @@ func (r *autoRedirect) setupNFTables() error {
return err
}
err = r.interfaceFinder.Update()
if err != nil {
return err
}
r.localAddresses = common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix {
return common.Filter(it.Addresses, func(prefix netip.Prefix) bool {
return it.Name == "lo" || prefix.Addr().IsGlobalUnicast()

View file

@ -19,13 +19,11 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const WithGVisor = true
const defaultNIC tcpip.NICID = 1
const DefaultNIC tcpip.NICID = 1
type GVisor struct {
ctx context.Context
@ -68,28 +66,11 @@ func (t *GVisor) Start() error {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := newGVisorStack(linkEndpoint)
ipStack, err := NewGVisorStack(linkEndpoint)
if err != nil {
return err
}
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := t.handler.PrepareConnection(N.NetworkTCP, source, destination)
if pErr != nil {
r.Complete(pErr != ErrDrop)
return
}
conn := &gLazyConn{
parentCtx: t.ctx,
stack: t.stack,
request: r,
localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(),
}
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint
@ -124,7 +105,7 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
}
}
func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
@ -137,19 +118,19 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
icmp.NewProtocol6,
},
})
err := ipStack.CreateNIC(defaultNIC, ep)
err := ipStack.CreateNIC(DefaultNIC, ep)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv4EmptySubnet, NIC: DefaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: DefaultNIC},
})
err = ipStack.SetSpoofing(defaultNIC, true)
err = ipStack.SetSpoofing(DefaultNIC, true)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
err = ipStack.SetPromiscuousMode(defaultNIC, true)
err = ipStack.SetPromiscuousMode(DefaultNIC, true)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}

View file

@ -5,6 +5,7 @@ package tun
import (
"context"
"net"
"os"
"time"
"github.com/sagernet/gvisor/pkg/tcpip"
@ -64,7 +65,7 @@ func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
func (c *gLazyConn) HandshakeFailure(err error) error {
if c.handshakeDone {
return nil
return os.ErrInvalid
}
c.request.Complete(err != ErrDrop)
c.handshakeDone = true

stack_gvisor_tcp.go (new file, 51 lines)
View file

@ -0,0 +1,51 @@
//go:build with_gvisor
package tun
import (
"context"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type TCPForwarder struct {
ctx context.Context
stack *stack.Stack
handler Handler
forwarder *tcp.Forwarder
}
func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
forwarder := &TCPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
}
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
return forwarder
}
func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
return f.forwarder.HandlePacket(id, pkt)
}
func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
if pErr != nil {
r.Complete(pErr != ErrDrop)
return
}
conn := &gLazyConn{
parentCtx: f.ctx,
stack: f.stack,
request: r,
localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(),
}
go f.handler.NewConnectionEx(f.ctx, conn, source, destination, nil)
}
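A sketch of how the exported forwarder plugs into a stack, mirroring the Start change earlier in this diff (the tun package qualifier and imports are assumed):

// Route every inbound TCP packet on the stack through handler.
func attachTCP(ctx context.Context, ipStack *stack.Stack, handler tun.Handler) {
	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(ctx, ipStack, handler).HandlePacket)
}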

View file

@ -102,13 +102,13 @@ func (w *UDPBackWriter) HandshakeSuccess() error {
func (w *UDPBackWriter) HandshakeFailure(err error) error {
w.access.Lock()
defer w.access.Unlock()
if w.packet != nil {
wErr := gWriteUnreachable(w.stack, w.packet)
w.packet.DecRef()
w.packet = nil
return wErr
if w.packet == nil {
return os.ErrInvalid
}
return nil
wErr := gWriteUnreachable(w.stack, w.packet)
w.packet.DecRef()
w.packet = nil
return wErr
}
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
@ -123,7 +123,7 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
defer packetBuffer.Release()
route, err := w.stack.FindRoute(
defaultNIC,
DefaultNIC,
AddressFromAddr(destination.Addr),
w.source,
w.sourceNetwork,

View file

@ -38,7 +38,7 @@ func (m *Mixed) Start() error {
return err
}
endpoint := channel.New(1024, uint32(m.mtu), "")
ipStack, err := newGVisorStack(endpoint)
ipStack, err := NewGVisorStack(endpoint)
if err != nil {
return err
}
@ -50,6 +50,18 @@ func (m *Mixed) Start() error {
return nil
}
func (m *Mixed) Close() error {
if m.stack == nil {
return nil
}
m.endpoint.Attach(nil)
m.stack.Close()
for _, endpoint := range m.stack.CleanupEndpoints() {
endpoint.Abort()
}
return m.System.Close()
}
func (m *Mixed) tunLoop() {
if winTun, isWinTun := m.tun.(WinTun); isWinTun {
m.wintunLoop(winTun)
@ -137,7 +149,7 @@ func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
_, err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
@ -151,10 +163,10 @@ func (m *Mixed) processPacket(packet []byte) bool {
writeBack bool
err error
)
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
switch ipVersion := header.IPVersion(packet); ipVersion {
case header.IPv4Version:
writeBack, err = m.processIPv4(packet)
case 6:
case header.IPv6Version:
writeBack, err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
@ -222,15 +234,3 @@ func (m *Mixed) packetLoop() {
packet.DecRef()
}
}
func (m *Mixed) Close() error {
if m.stack == nil {
return nil
}
m.endpoint.Attach(nil)
m.stack.Close()
for _, endpoint := range m.stack.CleanupEndpoints() {
endpoint.Abort()
}
return m.System.Close()
}

View file

@ -244,7 +244,7 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
_, err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}
@ -419,7 +419,7 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
@ -502,7 +502,7 @@ func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) erro
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
@ -586,7 +586,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error
sourceAddress := ipHdr.SourceAddr()
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
ipHdr.SetDestinationAddr(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, checksum.Checksum(icmpHdr.Payload(), 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return nil
@ -684,7 +684,7 @@ func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) e
}))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
@ -724,7 +724,7 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
@ -763,7 +763,7 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetChecksum(0)
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}

stack_system_packet.go (new file, 34 lines)
View file

@ -0,0 +1,34 @@
package tun
import (
"net/netip"
"syscall"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
)
func PacketIPVersion(packet []byte) int {
return header.IPVersion(packet)
}
func PacketFillHeader(packet []byte, ipVersion int) {
if PacketOffset > 0 {
switch ipVersion {
case header.IPv4Version:
packet[3] = syscall.AF_INET
case header.IPv6Version:
packet[3] = syscall.AF_INET6
}
}
}
func PacketDestination(packet []byte) netip.Addr {
switch ipVersion := header.IPVersion(packet); ipVersion {
case header.IPv4Version:
return header.IPv4(packet).DestinationAddr()
case header.IPv6Version:
return header.IPv6(packet).DestinationAddr()
default:
return netip.Addr{}
}
}
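A usage sketch for these helpers: inspect a raw packet before deciding whether to process it (shouldHandle is a hypothetical caller in the same package):

func shouldHandle(packet []byte) bool {
	switch PacketIPVersion(packet) {
	case header.IPv4Version, header.IPv6Version:
		dst := PacketDestination(packet)
		return dst.IsValid() && !dst.IsLoopback()
	default:
		return false
	}
}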

tun.go (27 changed lines)
View file

@ -8,6 +8,7 @@ import (
"strconv"
"strings"
"github.com/sagernet/sing/common/control"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
@ -24,8 +25,10 @@ type Handler interface {
type Tun interface {
io.ReadWriter
N.VectorisedWriter
Name() (string, error)
Start() error
Close() error
UpdateRouteOptions(tunOptions Options) error
}
type WinTun interface {
@ -38,7 +41,7 @@ type LinuxTUN interface {
N.FrontHeadroom
BatchSize() int
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
BatchWrite(buffers [][]byte, offset int) error
BatchWrite(buffers [][]byte, offset int) (n int, err error)
TXChecksumOffload() bool
}
@ -54,6 +57,7 @@ type Options struct {
MTU uint32
GSO bool
AutoRoute bool
InterfaceScope bool
Inet4Gateway netip.Addr
Inet6Gateway netip.Addr
DNSServers []netip.Addr
@ -74,6 +78,7 @@ type Options struct {
IncludeAndroidUser []int
IncludePackage []string
ExcludePackage []string
InterfaceFinder control.InterfaceFinder
InterfaceMonitor DefaultInterfaceMonitor
FileDescriptor int
Logger logger.Logger
@ -99,10 +104,12 @@ func (o *Options) Inet4GatewayAddr() netip.Addr {
case "darwin":
return o.Inet4Address[0].Addr()
default:
if HasNextAddress(o.Inet4Address[0], 1) {
return o.Inet4Address[0].Addr().Next()
} else {
return o.Inet4Address[0].Addr()
if !o.InterfaceScope {
if HasNextAddress(o.Inet4Address[0], 1) {
return o.Inet4Address[0].Addr().Next()
} else {
return o.Inet4Address[0].Addr()
}
}
}
}
@ -123,10 +130,12 @@ func (o *Options) Inet6GatewayAddr() netip.Addr {
case "darwin":
return o.Inet6Address[0].Addr()
default:
if HasNextAddress(o.Inet6Address[0], 1) {
return o.Inet6Address[0].Addr().Next()
} else {
return o.Inet6Address[0].Addr()
if !o.InterfaceScope {
if HasNextAddress(o.Inet6Address[0], 1) {
return o.Inet6Address[0].Addr().Next()
} else {
return o.Inet6Address[0].Addr()
}
}
}
}

View file

@ -9,6 +9,7 @@ import (
"syscall"
"unsafe"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
@ -28,7 +29,15 @@ type NativeTun struct {
options Options
inet4Address [4]byte
inet6Address [16]byte
routerSet bool
routeSet bool
}
func (t *NativeTun) Name() (string, error) {
return unix.GetsockoptString(
int(t.tunFile.Fd()),
2, /* #define SYSPROTO_CONTROL 2 */
2, /* #define UTUN_OPT_IFNAME 2 */
)
}
func New(options Options) (Tun, error) {
@ -96,9 +105,10 @@ var (
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
var packetHeader []byte
if buffers[0].Byte(0)>>4 == 4 {
switch header.IPVersion(buffers[0].Bytes()) {
case header.IPv4Version:
packetHeader = packetHeader4[:]
} else {
case header.IPv6Version:
packetHeader = packetHeader6[:]
}
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
@ -248,44 +258,63 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
return nil
}
func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error {
err := t.unsetRoutes()
if err != nil {
return err
}
t.options = tunOptions
return t.setRoutes()
}
func (t *NativeTun) setRoutes() error {
if t.options.AutoRoute && t.options.FileDescriptor == 0 {
if t.options.FileDescriptor == 0 {
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
for _, destination := range routeRanges {
var gateway netip.Addr
if destination.Addr().Is4() {
gateway = gateway4
} else {
gateway = gateway6
}
err = execRoute(unix.RTM_ADD, destination, gateway)
if err != nil {
if errors.Is(err, unix.EEXIST) {
err = execRoute(unix.RTM_DELETE, destination, gateway)
if err != nil {
return E.Cause(err, "remove existing route: ", destination)
}
err = execRoute(unix.RTM_ADD, destination, gateway)
if err != nil {
return E.Cause(err, "re-add route: ", destination)
}
if len(routeRanges) > 0 {
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
for _, destination := range routeRanges {
var gateway netip.Addr
if destination.Addr().Is4() {
gateway = gateway4
} else {
return E.Cause(err, "add route: ", destination)
gateway = gateway6
}
var interfaceIndex int
if t.options.InterfaceScope {
iff, err := t.options.InterfaceFinder.ByName(t.options.Name)
if err != nil {
return err
}
interfaceIndex = iff.Index
}
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
if err != nil {
if errors.Is(err, unix.EEXIST) {
err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway)
if err != nil {
return E.Cause(err, "remove existing route: ", destination)
}
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
if err != nil {
return E.Cause(err, "re-add route: ", destination)
}
} else {
return E.Cause(err, "add route: ", destination)
}
}
}
flushDNSCache()
t.routeSet = true
}
flushDNSCache()
t.routerSet = true
}
return nil
}
func (t *NativeTun) unsetRoutes() error {
if !t.routerSet {
if !t.routeSet {
return nil
}
routeRanges, err := t.options.BuildAutoRouteRanges(false)
@ -300,7 +329,7 @@ func (t *NativeTun) unsetRoutes() error {
} else {
gateway = gateway6
}
err = execRoute(unix.RTM_DELETE, destination, gateway)
err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway)
if err != nil {
err = E.Errors(err, E.Cause(err, "delete route: ", destination))
}
@ -317,7 +346,7 @@ func useSocket(domain, typ, proto int, block func(socketFd int) error) error {
return block(socketFd)
}
func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error {
func execRoute(rtmType int, interfaceScope bool, interfaceIndex int, destination netip.Prefix, gateway netip.Addr) error {
routeMessage := route.RouteMessage{
Type: rtmType,
Version: unix.RTM_VERSION,
@ -326,6 +355,10 @@ func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error
}
if rtmType == unix.RTM_ADD {
routeMessage.Flags |= unix.RTF_UP
if interfaceScope {
routeMessage.Flags |= unix.RTF_IFSCOPE
routeMessage.Index = interfaceIndex
}
}
if gateway.Is4() {
routeMessage.Addrs = []route.Addr{

View file

@ -2,6 +2,7 @@ package tun
import (
"errors"
"fmt"
"math/rand"
"net"
"net/netip"
@ -13,6 +14,8 @@ import (
"unsafe"
"github.com/sagernet/netlink"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
@ -35,13 +38,15 @@ type NativeTun struct {
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
options Options
ruleIndex6 []int
readAccess sync.Mutex
writeAccess sync.Mutex
vnetHdr bool
writeBuffer []byte
gsoToWrite []int
tcpGROTable *tcpGROTable
udpGroAccess sync.Mutex
udpGROTable *udpGROTable
gro groDisablementFlags
txChecksumOffload bool
}
@ -80,105 +85,6 @@ func New(options Options) (Tun, error) {
return nativeTun, nil
}
func (t *NativeTun) FrontHeadroom() int {
if t.gsoEnabled {
return virtioNetHdrLen
}
return 0
}
func (t *NativeTun) Read(p []byte) (n int, err error) {
if t.gsoEnabled {
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
var sizes [1]int
n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0)
if err != nil {
return
}
if n == 0 {
return
}
n = sizes[0]
return
} else {
return t.tunFile.Read(p)
}
}
func (t *NativeTun) Write(p []byte) (n int, err error) {
if t.gsoEnabled {
err = t.BatchWrite([][]byte{p}, virtioNetHdrLen)
if err != nil {
return
}
n = len(p)
return
}
return t.tunFile.Write(p)
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
if t.gsoEnabled {
n := buf.LenMulti(buffers)
buffer := buf.NewSize(virtioNetHdrLen + n)
buffer.Truncate(virtioNetHdrLen)
buf.CopyMulti(buffer.Extend(n), buffers)
_, err := t.tunFile.Write(buffer.Bytes())
buffer.Release()
return err
} else {
return t.tunWriter.WriteVectorised(buffers)
}
}
func (t *NativeTun) BatchSize() int {
if !t.gsoEnabled {
return 1
}
/* // Does not work on some devices: https://github.com/SagerNet/sing-box/issues/1605
batchSize := int(gsoMaxSize/t.options.MTU) * 2
if batchSize > idealBatchSize {
batchSize = idealBatchSize
}
return batchSize*/
return idealBatchSize
}
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
t.gsoReadAccess.Lock()
defer t.gsoReadAccess.Unlock()
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
}
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
t.tcpGROAccess.Lock()
defer func() {
t.tcp4GROTable.reset()
t.tcp6GROTable.reset()
t.tcpGROAccess.Unlock()
}()
t.gsoToWrite = t.gsoToWrite[:0]
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
if err != nil {
return err
}
offset -= virtioNetHdrLen
for _, bufferIndex := range t.gsoToWrite {
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
if err != nil {
return err
}
}
return nil
}
var controlPath string
func init() {
@ -196,29 +102,26 @@ func open(name string, vnetHdr bool) (int, error) {
if err != nil {
return -1, err
}
ifr, err := unix.NewIfreq(name)
if err != nil {
unix.Close(fd)
return 0, err
}
flags := unix.IFF_TUN | unix.IFF_NO_PI
if vnetHdr {
flags |= unix.IFF_VNET_HDR
}
ifr.SetUint16(uint16(flags))
err = unix.IoctlIfreq(fd, unix.TUNSETIFF, ifr)
if err != nil {
unix.Close(fd)
return 0, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
unix.Close(fd)
return 0, err
}
return fd, nil
}
@ -250,22 +153,10 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
}
if t.options.GSO {
err = t.enableGSO()
if err != nil {
t.options.Logger.Warn(err)
}
}
var rxChecksumOffload bool
@ -280,7 +171,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
if err != nil {
return err
}
if !txChecksumOffload {
err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM)
if err != nil {
return err
@ -292,6 +183,83 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
return nil
}
func (t *NativeTun) enableGSO() error {
vnetHdrEnabled, err := checkVNETHDREnabled(t.tunFd, t.options.Name)
if err != nil {
return E.Cause(err, "enable offload: check IFF_VNET_HDR enabled")
}
if !vnetHdrEnabled {
return E.Cause(err, "enable offload: IFF_VNET_HDR not enabled")
}
err = setTCPOffload(t.tunFd)
if err != nil {
return E.Cause(err, "enable TCP offload")
}
t.vnetHdr = true
t.writeBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize))
t.tcpGROTable = newTCPGROTable()
t.udpGROTable = newUDPGROTable()
err = setUDPOffload(t.tunFd)
if err != nil {
t.gro.disableUDPGRO()
}
return nil
}
func (t *NativeTun) probeTCPGRO() error {
ipPort := netip.AddrPortFrom(t.options.Inet4Address[0].Addr(), 0)
fingerprint := []byte("sing-tun-probe-tun-gro")
segmentSize := len(fingerprint)
iphLen := 20
tcphLen := 20
totalLen := iphLen + tcphLen + segmentSize
bufs := make([][]byte, 2)
for i := range bufs {
bufs[i] = make([]byte, virtioNetHdrLen+totalLen, virtioNetHdrLen+(totalLen*2))
ipv4H := header.IPv4(bufs[i][virtioNetHdrLen:])
ipv4H.Encode(&header.IPv4Fields{
SrcAddr: ipPort.Addr(),
DstAddr: ipPort.Addr(),
Protocol: unix.IPPROTO_TCP,
// Use a zero value TTL as best effort means to reduce chance of
// probe packet leaking further than it needs to.
TTL: 0,
TotalLength: uint16(totalLen),
})
tcpH := header.TCP(bufs[i][virtioNetHdrLen+iphLen:])
tcpH.Encode(&header.TCPFields{
SrcPort: ipPort.Port(),
DstPort: ipPort.Port(),
SeqNum: 1 + uint32(i*segmentSize),
AckNum: 1,
DataOffset: 20,
Flags: header.TCPFlagAck,
WindowSize: 3000,
})
copy(bufs[i][virtioNetHdrLen+iphLen+tcphLen:], fingerprint)
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddressSlice(), ipv4H.DestinationAddressSlice(), uint16(tcphLen+segmentSize))
pseudoCsum = checksum.Checksum(bufs[i][virtioNetHdrLen+iphLen+tcphLen:], pseudoCsum)
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
}
_, err := t.BatchWrite(bufs, virtioNetHdrLen)
return err
}
func (t *NativeTun) Name() (string, error) {
var ifr [unix.IFNAMSIZ + 64]byte
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(t.tunFd),
uintptr(unix.TUNGETIFF),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return "", os.NewSyscallError("ioctl TUNGETIFF", errno)
}
return unix.ByteSliceToString(ifr[:]), nil
}
func (t *NativeTun) Start() error {
if t.options.FileDescriptor != 0 {
return nil
@ -307,6 +275,15 @@ func (t *NativeTun) Start() error {
return err
}
if t.vnetHdr && len(t.options.Inet4Address) > 0 {
err = t.probeTCPGRO()
if err != nil {
t.gro.disableTCPGRO()
t.gro.disableUDPGRO()
t.options.Logger.Warn(E.Cause(err, "disabled TUN TCP & UDP GRO due to GRO probe error"))
}
}
if t.options.IPRoute2TableIndex == 0 {
for {
t.options.IPRoute2TableIndex = int(rand.Uint32())
@ -348,6 +325,164 @@ func (t *NativeTun) Close() error {
return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile)))
}
func (t *NativeTun) Read(p []byte) (n int, err error) {
if t.vnetHdr {
n, err = t.tunFile.Read(t.writeBuffer)
if err != nil {
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
return
}
var sizes [1]int
n, err = handleVirtioRead(t.writeBuffer[:n], [][]byte{p}, sizes[:], 0)
if err != nil {
return
}
if n == 0 {
return
}
n = sizes[0]
return
} else {
return t.tunFile.Read(p)
}
}
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of bufs,
// and returns the number of packets read.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
var hdr virtioNetHdr
err := hdr.decode(in)
if err != nil {
return 0, err
}
in = in[virtioNetHdrLen:]
options, err := hdr.toGSOOptions()
if err != nil {
return 0, err
}
// Don't trust HdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the transport header length and add it onto
// CsumStart, which is synonymous with IP header length.
if options.GSOType == GSOUDPL4 {
options.HdrLen = options.CsumStart + 8
} else if options.GSOType != GSONone {
if len(in) <= int(options.CsumStart+12) {
return 0, errors.New("packet is too short")
}
tcpHLen := uint16(in[options.CsumStart+12] >> 4 * 4)
if tcpHLen < 20 || tcpHLen > 60 {
// A TCP header must be between 20 and 60 bytes in length.
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
options.HdrLen = options.CsumStart + tcpHLen
}
return GSOSplit(in, options, bufs, sizes, offset)
}
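The data-offset arithmetic above is easy to misread: in Go, >> and * share one precedence level and associate left, so in[options.CsumStart+12] >> 4 * 4 parses as (value >> 4) * 4, scaling the TCP data-offset nibble from 32-bit words to bytes. A minimal illustration:
b := byte(0x80) // raw byte 12 of a TCP header; data-offset nibble = 8 (32-bit words)
tcpHLen := uint16(b>>4) * 4 // 8 words * 4 = 32 bytes: 20-byte header plus 12 bytes of options
_ = tcpHLen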
func (t *NativeTun) Write(p []byte) (n int, err error) {
if t.vnetHdr {
buffer := buf.Get(virtioNetHdrLen + len(p))
copy(buffer[virtioNetHdrLen:], p)
_, err = t.BatchWrite([][]byte{buffer}, virtioNetHdrLen)
buf.Put(buffer)
if err != nil {
return
}
n = len(p)
return
}
return t.tunFile.Write(p)
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
if t.vnetHdr {
n := buf.LenMulti(buffers)
buffer := buf.NewSize(virtioNetHdrLen + n)
buffer.Truncate(virtioNetHdrLen)
buf.CopyMulti(buffer.Extend(n), buffers)
_, err := t.tunFile.Write(buffer.Bytes())
buffer.Release()
return err
} else {
return t.tunWriter.WriteVectorised(buffers)
}
}
func (t *NativeTun) FrontHeadroom() int {
if t.vnetHdr {
return virtioNetHdrLen
}
return 0
}
func (t *NativeTun) BatchSize() int {
if !t.vnetHdr {
return 1
}
/* // Does not work on some devices: https://github.com/SagerNet/sing-box/issues/1605
batchSize := int(gsoMaxSize/t.options.MTU) * 2
if batchSize > idealBatchSize {
batchSize = idealBatchSize
}
return batchSize*/
return idealBatchSize
}
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
t.readAccess.Lock()
defer t.readAccess.Unlock()
n, err = t.tunFile.Read(t.writeBuffer)
if err != nil {
return
}
return handleVirtioRead(t.writeBuffer[:n], buffers, readN, offset)
}
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) (int, error) {
t.writeAccess.Lock()
defer func() {
t.tcpGROTable.reset()
t.udpGROTable.reset()
t.writeAccess.Unlock()
}()
var (
errs error
total int
)
t.gsoToWrite = t.gsoToWrite[:0]
if t.vnetHdr {
err := handleGRO(buffers, offset, t.tcpGROTable, t.udpGROTable, t.gro, &t.gsoToWrite)
if err != nil {
return 0, err
}
offset -= virtioNetHdrLen
} else {
for i := range buffers {
t.gsoToWrite = append(t.gsoToWrite, i)
}
}
for _, toWrite := range t.gsoToWrite {
n, err := t.tunFile.Write(buffers[toWrite][offset:])
if errors.Is(err, syscall.EBADFD) {
return total, os.ErrClosed
}
if err != nil {
errs = errors.Join(errs, err)
} else {
total += n
}
}
return total, errs
}
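A minimal usage sketch of the batch API, assuming t is an opened *NativeTun from this package: callers must reserve FrontHeadroom() bytes in front of every buffer, which is where BatchWrite encodes the virtio header in place.
batch := t.BatchSize()
headroom := t.FrontHeadroom()
buffers := make([][]byte, batch)
for i := range buffers {
buffers[i] = make([]byte, headroom+gsoMaxSize)
}
sizes := make([]int, batch)
n, err := t.BatchRead(buffers, headroom, sizes)
if err != nil {
return err // or handle
}
for i := 0; i < n; i++ {
packet := buffers[i][headroom : headroom+sizes[i]]
_ = packet // hand the packet to the stack
}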
func (t *NativeTun) TXChecksumOffload() bool {
return t.txChecksumOffload
}
@ -359,6 +494,25 @@ func prefixToIPNet(prefix netip.Prefix) *net.IPNet {
}
}
func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error {
if t.options.FileDescriptor > 0 {
return nil
} else if !t.options.AutoRoute {
t.options = tunOptions
return nil
}
tunLink, err := netlink.LinkByName(t.options.Name)
if err != nil {
return err
}
err = t.unsetRoute0(tunLink)
if err != nil {
return err
}
t.options = tunOptions
return t.setRoute(tunLink)
}
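A minimal sketch of the intended call pattern, assuming newOptions is a modified copy of the running options: with a borrowed file descriptor the call is a no-op, and with AutoRoute disabled it only records the new options, so it is safe to call unconditionally.
if err := t.UpdateRouteOptions(newOptions); err != nil {
// the old routes are already gone at this point; surface the error
return E.Cause(err, "update route options")
}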
func (t *NativeTun) routes(tunLink netlink.Link) ([]netlink.Route, error) {
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
@ -592,6 +746,7 @@ func (t *NativeTun) rules() []*netlink.Rule {
it = netlink.NewRule()
if t.options.InterfaceMonitor.OverrideAndroidVPN() {
it.Mark = protectedFromVPN
it.MarkSet = true
}
it.Mask = protectedFromVPN
it.Priority = priority
@ -604,6 +759,7 @@ func (t *NativeTun) rules() []*netlink.Rule {
it = netlink.NewRule()
if t.options.InterfaceMonitor.OverrideAndroidVPN() {
it.Mark = protectedFromVPN
it.MarkSet = true
}
it.Mask = protectedFromVPN
it.Family = unix.AF_INET6


@ -12,6 +12,12 @@ import (
"golang.org/x/sys/unix"
)
const (
// TODO: support TSO with ECN bits
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
func checkVNETHDREnabled(fd int, name string) (bool, error) {
ifr, err := unix.NewIfreq(name)
if err != nil {
@ -25,17 +31,17 @@ func checkVNETHDREnabled(fd int, name string) (bool, error) {
}
func setTCPOffload(fd int) error {
err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads)
if err != nil {
return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload")
}
return nil
}
func setUDPOffload(fd int) error {
return unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads)
}
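A hedged note on why enableGSO (above) tolerates a setUDPOffload failure: the USO flags are a newer kernel feature than the TSO flags, so the degradation path is roughly:
// Sketch of the fallback contract used by enableGSO: TCP offload is required
// for the vnethdr path, UDP offload is best-effort.
if err := setUDPOffload(fd); err != nil {
// kernel accepted the TCP offloads but rejected TUN_F_USO4/TUN_F_USO6;
// continue with UDP GRO disabled
}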
type ifreqData struct {
ifrName [unix.IFNAMSIZ]byte
ifrData uintptr


@ -10,11 +10,12 @@ import (
var _ GVisorTun = (*NativeTun)(nil)
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
if t.vnetHdr {
return fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
GSOMaxSize: gsoMaxSize,
GRO: true,
RXChecksumOffload: true,
TXChecksumOffload: t.txChecksumOffload,
})


@ -1,768 +0,0 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"unsafe"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
const (
gsoMaxSize = 65536
tcpFlagsOffset = 13
idealBatchSize = 128
)
const (
tcpFlagFIN uint8 = 0x01
tcpFlagPSH uint8 = 0x08
tcpFlagACK uint8 = 0x10
)
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr.
type virtioNetHdr struct {
flags uint8
gsoType uint8
hdrLen uint16
gsoSize uint16
csumStart uint16
csumOffset uint16
}
func (v *virtioNetHdr) decode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
return nil
}
func (v *virtioNetHdr) encode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
return nil
}
const (
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
// flowKey represents the key for a flow.
type flowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
}
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
type tcpGROTable struct {
itemsByFlow map[flowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
itemsByFlow: make(map[flowKey][]tcpGROItem, idealBatchSize),
itemsPool: make([][]tcpGROItem, idealBatchSize),
}
for i := range t.itemsPool {
t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize)
}
return t
}
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
key := flowKey{}
addrSize := dstAddr - srcAddr
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
iphLen: uint8(tcphOffset),
tcphLen: uint8(tcphLen),
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
}
items, ok := t.itemsByFlow[key]
if !ok {
items = t.newItems()
}
items = append(items, item)
t.itemsByFlow[key] = items
}
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items, _ := t.itemsByFlow[item.key]
items[i] = item
}
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
}
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
key flowKey
sentSeq uint32 // the sequence number
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
tcphLen uint8 // tcp header len
pshSet bool // psh flag is set
}
func (t *tcpGROTable) newItems() []tcpGROItem {
var items []tcpGROItem
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
return items
}
func (t *tcpGROTable) reset() {
for k, items := range t.itemsByFlow {
items = items[:0]
t.itemsPool = append(t.itemsPool, items)
delete(t.itemsByFlow, k)
}
}
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
const (
coalescePrepend canCoalesce = -1
coalesceUnavailable canCoalesce = 0
coalesceAppend canCoalesce = 1
)
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if tcphLen != item.tcphLen {
// cannot coalesce with unequal tcp options len
return coalesceUnavailable
}
if tcphLen > 20 {
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
// cannot coalesce with unequal tcp options
return coalesceUnavailable
}
}
if pkt[0]>>4 == 6 {
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
// cannot coalesce with unequal Traffic class values
return coalesceUnavailable
}
if pkt[7] != pktTarget[7] {
// cannot coalesce with unequal Hop limit values
return coalesceUnavailable
}
} else {
if pkt[1] != pktTarget[1] {
// cannot coalesce with unequal ToS values
return coalesceUnavailable
}
if pkt[6]>>5 != pktTarget[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return coalesceUnavailable
}
if pkt[8] != pktTarget[8] {
// cannot coalesce with unequal TTL values
return coalesceUnavailable
}
}
// seq adjacency
lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
if item.pshSet {
// We cannot append to a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
if pshSet {
// We cannot prepend with a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if gsoSize < item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
if gsoSize > item.gsoSize && item.numMerged > 0 {
// There's at least one previous merge, and we're larger than all
// previous. This would put multiple smaller packets on the end.
return coalesceUnavailable
}
return coalescePrepend
}
return coalesceUnavailable
}
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
tcpTotalLen := uint16(len(pkt) - int(iphLen))
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
return ^checksumFold(pkt[iphLen:], tcpCSumNoFold) == 0
}
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int
const (
coalesceInsufficientCap coalesceResult = iota
coalescePSHEnding
coalesceItemInvalidCSum
coalescePktInvalidCSum
coalesceSuccess
)
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, returning the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
var pktHead []byte // the packet that will end up at the front
headersLen := item.iphLen + item.tcphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
// Copy data
if mode == coalescePrepend {
pktHead = pkt
if cap(pkt)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if pshSet {
return coalescePSHEnding
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
extendBy := coalescedLen - len(pktHead)
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
// Flip the slice headers in bufs as part of prepend. The index of item
// is already being tracked for writing.
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
} else {
pktHead = bufs[item.bufsIndex][bufsOffset:]
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
// We are appending a segment with PSH set.
item.pshSet = pshSet
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
}
if gsoSize > item.gsoSize {
item.gsoSize = gsoSize
}
item.numMerged++
return coalesceSuccess
}
const (
ipv4FlagMoreFragments uint8 = 0x20
)
const (
ipv4SrcAddrOffset = 12
ipv6SrcAddrOffset = 8
maxUint16 = 1<<16 - 1
)
type tcpGROResult int
const (
tcpGROResultNoop tcpGROResult = iota
tcpGROResultTableInsert
tcpGROResultCoalesced
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a tcpGROResultNoop when no
// action was taken, tcpGROResultTableInsert when the evaluated packet was
// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return tcpGROResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return tcpGROResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return tcpGROResultNoop
}
}
if len(pkt) < iphLen {
return tcpGROResultNoop
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return tcpGROResultNoop
}
if len(pkt) < iphLen+tcphLen {
return tcpGROResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return tcpGROResultNoop
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
var pshSet bool
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return tcpGROResultNoop
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return tcpGROResultNoop
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return tcpGROResultNoop
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
// more efficient if there are multiple items for a given flow. This
// also enables a natural table.deleteAt() in the
// coalesceItemInvalidCSum case without the need for index tracking.
// This algorithm makes a best effort to coalesce in the event of
// unordered packets, where pkt may land anywhere in items from a
// sequence number perspective, however once an item is inserted into
// the table it is never compared across other items later.
item := items[i]
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
if can != coalesceUnavailable {
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return tcpGROResultCoalesced
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return tcpGROResultNoop
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return tcpGROResultTableInsert
}
func isTCP4NoIPOptions(b []byte) bool {
if len(b) < 40 {
return false
}
if b[0]>>4 != 4 {
return false
}
if b[0]&0x0F != 5 {
return false
}
if b[9] != unix.IPPROTO_TCP {
return false
}
return true
}
func isTCP6NoEH(b []byte) bool {
if len(b) < 60 {
return false
}
if b[0]>>4 != 6 {
return false
}
if b[6] != unix.IPPROTO_TCP {
return false
}
return true
}
// applyCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(item.iphLen + item.tcphLen),
gsoSize: item.gsoSize,
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
pkt := bufs[item.bufsIndex][offset:]
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
if isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pkt[10], pkt[11] = 0, 0
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
iphCSum := ^checksumFold(pkt[:item.iphLen], 0) // compute IPv4 header checksum
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
// Calculate the pseudo header checksum and place it at the TCP
// checksum offset. Downstream checksum offloading will combine
// this with computation of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := offset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksumFold([]byte{}, psum))
} else {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
}
}
}
return nil
}
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets.
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
for i := range bufs {
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
var result tcpGROResult
switch {
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
result = tcpGRO(bufs, offset, i, tcp4Table, false)
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
result = tcpGRO(bufs, offset, i, tcp6Table, true)
}
switch result {
case tcpGROResultNoop:
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
fallthrough
case tcpGROResultTableInsert:
*toWrite = append(*toWrite, i)
}
}
err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
return E.Errors(err4, err6)
}
// tcpTSO splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
iphLen := int(hdr.csumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
in[10], in[11] = 0, 0 // clear ipv4 header checksum
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
nextSegmentDataAt := int(hdr.hdrLen)
i := 0
for ; nextSegmentDataAt < len(in); i++ {
if i == len(outBuffs) {
return i - 1, ErrTooManySegments
}
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
if nextSegmentEnd > len(in) {
nextSegmentEnd = len(in)
}
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
totalLen := int(hdr.hdrLen) + segmentDataLen
sizes[i] = totalLen
out := outBuffs[i][outOffset:]
copy(out, in[:iphLen])
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
if i > 0 {
id := binary.BigEndian.Uint16(out[4:])
id += uint16(i)
binary.BigEndian.PutUint16(out[4:], id)
}
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
ipv4CSum := ^checksumFold(out[:iphLen], 0)
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
} else {
// For IPv6 we are responsible for updating the payload length field.
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
// TCP header
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
if nextSegmentEnd != len(in) {
// FIN and PSH should only be set on last segment
clearFlags := tcpFlagFIN | tcpFlagPSH
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
}
// payload
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
// TCP checksum
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
tcpCSum := ^checksumFold(out[hdr.csumStart:totalLen], tcpCSumNoFold)
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
nextSegmentDataAt += int(hdr.gsoSize)
}
return i, nil
}
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
cSumAt := cSumStart + cSumOffset
// The initial value at the checksum offset should be summed with the
// checksum we compute. This is typically the pseudo-header checksum.
initial := binary.BigEndian.Uint16(in[cSumAt:])
in[cSumAt], in[cSumAt+1] = 0, 0
binary.BigEndian.PutUint16(in[cSumAt:], ^checksumFold(in[cSumStart:], uint64(initial)))
return nil
}
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of bufs,
// and returns the number of packets read.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
var hdr virtioNetHdr
err := hdr.decode(in)
if err != nil {
return 0, err
}
in = in[virtioNetHdrLen:]
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
// This means CHECKSUM_PARTIAL in skb context. We are responsible
// for computing the checksum starting at hdr.csumStart and placing
// at hdr.csumOffset.
err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
if err != nil {
return 0, err
}
}
if len(in) > len(bufs[0][offset:]) {
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
}
n := copy(bufs[0][offset:], in)
sizes[0] = n
return 1, nil
}
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
case 6:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
if len(in) <= int(hdr.csumStart+12) {
return 0, errors.New("packet is too short")
}
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the TCP header length and add it onto
// csumStart, which is synonymous with IP header length.
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
if tcpHLen < 20 || tcpHLen > 60 {
// A TCP header must be between 20 and 60 bytes in length.
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
hdr.hdrLen = hdr.csumStart + tcpHLen
if len(in) < int(hdr.hdrLen) {
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
}
if hdr.hdrLen < hdr.csumStart {
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
}
cSumAt := int(hdr.csumStart + hdr.csumOffset)
if cSumAt+1 >= len(in) {
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
return tcpTSO(in, hdr, bufs, sizes, offset)
}
func checksumNoFold(b []byte, initial uint64) uint64 {
return uint64(checksum.Checksum(b, uint16(initial)))
}
func checksumFold(b []byte, initial uint64) uint16 {
ac := checksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
sum := checksumNoFold(srcAddr, 0)
sum = checksumNoFold(dstAddr, sum)
sum = checksumNoFold([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumNoFold(tmp, sum)
}


@ -1,5 +0,0 @@
package tun
import E "github.com/sagernet/sing/common/exceptions"
var ErrTooManySegments = E.New("too many segments")

tun_offload.go (new file, 229 lines)

@ -0,0 +1,229 @@
package tun
import (
"encoding/binary"
"fmt"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
)
const (
gsoMaxSize = 65536
idealBatchSize = 128
)
// GSOType represents the type of segmentation offload.
type GSOType int
const (
GSONone GSOType = iota
GSOTCPv4
GSOTCPv6
GSOUDPL4
)
func (g GSOType) String() string {
switch g {
case GSONone:
return "GSONone"
case GSOTCPv4:
return "GSOTCPv4"
case GSOTCPv6:
return "GSOTCPv6"
case GSOUDPL4:
return "GSOUDPL4"
default:
return "unknown"
}
}
// GSOOptions is loosely modeled after struct virtio_net_hdr from the VIRTIO
// specification. It is a common representation of GSO metadata that can be
// applied to support packet GSO across tun.Device implementations.
type GSOOptions struct {
// GSOType represents the type of segmentation offload.
GSOType GSOType
// HdrLen is the sum of the layer 3 and 4 header lengths. This field may be
// zero when GSOType == GSONone.
HdrLen uint16
// CsumStart is the head byte index of the packet data to be checksummed,
// i.e. the start of the TCP or UDP header.
CsumStart uint16
// CsumOffset is the offset from CsumStart where the 2-byte checksum value
// should be placed.
CsumOffset uint16
// GSOSize is the size of each segment exclusive of HdrLen. The tail segment
// may be smaller than this value.
GSOSize uint16
// NeedsCsum may be set where GSOType == GSONone. When set, the checksum
// at CsumStart + CsumOffset must be a partial checksum, i.e. the
// pseudo-header sum.
NeedsCsum bool
}
const (
ipv4SrcAddrOffset = 12
ipv6SrcAddrOffset = 8
)
const tcpFlagsOffset = 13
const (
tcpFlagFIN uint8 = 0x01
tcpFlagPSH uint8 = 0x08
tcpFlagACK uint8 = 0x10
)
const (
// defined here in order to avoid importation of any platform-specific pkgs
ipProtoTCP = 6
ipProtoUDP = 17
)
// GSOSplit splits packets from 'in' into outBufs[<index>][outOffset:], writing
// the size of each element into sizes. It returns the number of buffers
// populated, and/or an error. Callers may pass an 'in' slice that overlaps with
// the first element of outBuffers, i.e. &in[0] may be equal to
// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the
// value of options.NeedsCsum. Length of each outBufs element must be greater
// than or equal to the length of 'in', otherwise output may be silently
// truncated.
func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) {
cSumAt := int(options.CsumStart) + int(options.CsumOffset)
if cSumAt+1 >= len(in) {
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
if len(in) < int(options.HdrLen) {
return 0, fmt.Errorf("length of packet (%d) < GSO HdrLen (%d)", len(in), options.HdrLen)
}
// Handle the conditions where we are copying a single element to outBuffs.
payloadLen := len(in) - int(options.HdrLen)
if options.GSOType == GSONone || payloadLen < int(options.GSOSize) {
if len(in) > len(outBufs[0][outOffset:]) {
return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:]))
}
if options.NeedsCsum {
// The initial value at the checksum offset should be summed with
// the checksum we compute. This is typically the pseudo-header sum.
initial := binary.BigEndian.Uint16(in[cSumAt:])
in[cSumAt], in[cSumAt+1] = 0, 0
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum.Checksum(in[options.CsumStart:], initial))
}
sizes[0] = copy(outBufs[0][outOffset:], in)
return 1, nil
}
if options.HdrLen < options.CsumStart {
return 0, fmt.Errorf("GSO HdrLen (%d) < GSO CsumStart (%d)", options.HdrLen, options.CsumStart)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
if options.GSOType != GSOTCPv4 && options.GSOType != GSOUDPL4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
}
if len(in) < 20 {
return 0, fmt.Errorf("length of packet (%d) < minimum ipv4 header size (%d)", len(in), 20)
}
case 6:
if options.GSOType != GSOTCPv6 && options.GSOType != GSOUDPL4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
}
if len(in) < 40 {
return 0, fmt.Errorf("length of packet (%d) < minimum ipv6 header size (%d)", len(in), 40)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
iphLen := int(options.CsumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
if ipVersion == 4 {
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
transportCsumAt := int(options.CsumStart + options.CsumOffset)
var firstTCPSeqNum uint32
var protocol uint8
if options.GSOType == GSOTCPv4 || options.GSOType == GSOTCPv6 {
protocol = ipProtoTCP
if len(in) < int(options.CsumStart)+20 {
return 0, fmt.Errorf("length of packet (%d) < GSO CsumStart (%d) + minimum TCP header size (%d)",
len(in), options.CsumStart, 20)
}
firstTCPSeqNum = binary.BigEndian.Uint32(in[options.CsumStart+4:])
} else {
protocol = ipProtoUDP
}
nextSegmentDataAt := int(options.HdrLen)
i := 0
for ; nextSegmentDataAt < len(in); i++ {
if i == len(outBufs) {
return i - 1, ErrTooManySegments
}
nextSegmentEnd := nextSegmentDataAt + int(options.GSOSize)
if nextSegmentEnd > len(in) {
nextSegmentEnd = len(in)
}
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
totalLen := int(options.HdrLen) + segmentDataLen
sizes[i] = totalLen
out := outBufs[i][outOffset:]
copy(out, in[:iphLen])
if ipVersion == 4 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
if i > 0 {
id := binary.BigEndian.Uint16(out[4:])
id += uint16(i)
binary.BigEndian.PutUint16(out[4:], id)
}
out[10], out[11] = 0, 0 // clear ipv4 header checksum
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
ipv4CSum := ^checksum.Checksum(out[:iphLen], 0)
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
} else {
// For IPv6 we are responsible for updating the payload length field.
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
// copy transport header
copy(out[options.CsumStart:options.HdrLen], in[options.CsumStart:options.HdrLen])
if protocol == ipProtoTCP {
// set TCP seq and adjust TCP flags
tcpSeq := firstTCPSeqNum + uint32(options.GSOSize*uint16(i))
binary.BigEndian.PutUint32(out[options.CsumStart+4:], tcpSeq)
if nextSegmentEnd != len(in) {
// FIN and PSH should only be set on last segment
clearFlags := tcpFlagFIN | tcpFlagPSH
out[options.CsumStart+tcpFlagsOffset] &^= clearFlags
}
} else {
// set UDP header len
binary.BigEndian.PutUint16(out[options.CsumStart+4:], uint16(segmentDataLen)+(options.HdrLen-options.CsumStart))
}
// payload
copy(out[options.HdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
// transport checksum
out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
transportHeaderLen := int(options.HdrLen - options.CsumStart)
lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
transportCSum := header.PseudoHeaderChecksum(tcpip.TransportProtocolNumber(protocol), in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
transportCSum = ^checksum.Checksum(out[options.CsumStart:totalLen], transportCSum)
binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum)
nextSegmentDataAt += int(options.GSOSize)
}
return i, nil
}
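A usage sketch for GSOSplit; pkt, the header sizes, and the segment size below are illustrative assumptions, not values taken from the diff:
var pkt []byte // assumed: a coalesced TCP/IPv4 packet, IP and TCP headers included
outBufs := make([][]byte, idealBatchSize)
for i := range outBufs {
outBufs[i] = make([]byte, gsoMaxSize)
}
sizes := make([]int, idealBatchSize)
n, err := GSOSplit(pkt, GSOOptions{
GSOType: GSOTCPv4,
HdrLen: 40, // 20-byte IPv4 header plus 20-byte TCP header
CsumStart: 20, // transport header begins after the IPv4 header
CsumOffset: 16, // TCP checksum offset within the TCP header
GSOSize: 1340, // payload bytes per segment
NeedsCsum: true,
}, outBufs, sizes, 0)
if err != nil {
return err // ErrTooManySegments when outBufs is exhausted
}
// n segments are in outBufs[0][:sizes[0]] ... outBufs[n-1][:sizes[n-1]]
_ = n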

tun_offload_errors.go (new file, 10 lines)

@ -0,0 +1,10 @@
package tun
import (
"errors"
)
// ErrTooManySegments is returned by Device.Read() when segmentation
// overflows the length of supplied buffers. This error should not cause
// reads to cease.
var ErrTooManySegments = errors.New("too many segments")

tun_offload_linux.go (new file, 937 lines)

@ -0,0 +1,937 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"unsafe"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"golang.org/x/sys/unix"
)
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr.
type virtioNetHdr struct {
flags uint8
gsoType uint8
hdrLen uint16
gsoSize uint16
csumStart uint16
csumOffset uint16
}
func (v *virtioNetHdr) toGSOOptions() (GSOOptions, error) {
var gsoType GSOType
switch v.gsoType {
case unix.VIRTIO_NET_HDR_GSO_NONE:
gsoType = GSONone
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
gsoType = GSOTCPv4
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
gsoType = GSOTCPv6
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
gsoType = GSOUDPL4
default:
return GSOOptions{}, fmt.Errorf("unsupported virtio gsoType: %d", v.gsoType)
}
return GSOOptions{
GSOType: gsoType,
HdrLen: v.hdrLen,
CsumStart: v.csumStart,
CsumOffset: v.csumOffset,
GSOSize: v.gsoSize,
NeedsCsum: v.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0,
}, nil
}
func (v *virtioNetHdr) decode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
return nil
}
func (v *virtioNetHdr) encode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
return nil
}
const (
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
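Because encode and decode copy the struct through unsafe.Slice, a header round-trips losslessly through its wire form; a small sketch:
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
hdrLen: 40,
gsoSize: 1340,
csumStart: 20,
csumOffset: 16,
}
b := make([]byte, virtioNetHdrLen)
_ = hdr.encode(b) // cannot fail here: len(b) == virtioNetHdrLen
var decoded virtioNetHdr
_ = decoded.decode(b)
// decoded == hdr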
// tcpFlowKey represents the key for a TCP flow.
type tcpFlowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
isV6 bool
}
// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
type tcpGROTable struct {
itemsByFlow map[tcpFlowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, idealBatchSize),
itemsPool: make([][]tcpGROItem, idealBatchSize),
}
for i := range t.itemsPool {
t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize)
}
return t
}
func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
key := tcpFlowKey{}
addrSize := dstAddrOffset - srcAddrOffset
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
key.isV6 = addrSize == 16
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
iphLen: uint8(tcphOffset),
tcphLen: uint8(tcphLen),
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
}
items, ok := t.itemsByFlow[key]
if !ok {
items = t.newItems()
}
items = append(items, item)
t.itemsByFlow[key] = items
}
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items, _ := t.itemsByFlow[item.key]
items[i] = item
}
func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
}
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
key tcpFlowKey
sentSeq uint32 // the sequence number
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
tcphLen uint8 // tcp header len
pshSet bool // psh flag is set
}
func (t *tcpGROTable) newItems() []tcpGROItem {
var items []tcpGROItem
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
return items
}
func (t *tcpGROTable) reset() {
for k, items := range t.itemsByFlow {
items = items[:0]
t.itemsPool = append(t.itemsPool, items)
delete(t.itemsByFlow, k)
}
}
// udpFlowKey represents the key for a UDP flow.
type udpFlowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
isV6 bool
}
// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
type udpGROTable struct {
itemsByFlow map[udpFlowKey][]udpGROItem
itemsPool [][]udpGROItem
}
func newUDPGROTable() *udpGROTable {
u := &udpGROTable{
itemsByFlow: make(map[udpFlowKey][]udpGROItem, idealBatchSize),
itemsPool: make([][]udpGROItem, idealBatchSize),
}
for i := range u.itemsPool {
u.itemsPool[i] = make([]udpGROItem, 0, idealBatchSize)
}
return u
}
func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
key := udpFlowKey{}
addrSize := dstAddrOffset - srcAddrOffset
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
key.isV6 = addrSize == 16
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
items, ok := u.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
item := udpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
iphLen: uint8(udphOffset),
cSumKnownInvalid: cSumKnownInvalid,
}
items, ok := u.itemsByFlow[key]
if !ok {
items = u.newItems()
}
items = append(items, item)
u.itemsByFlow[key] = items
}
func (u *udpGROTable) updateAt(item udpGROItem, i int) {
items, _ := u.itemsByFlow[item.key]
items[i] = item
}
// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type udpGROItem struct {
key udpFlowKey
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
}
func (u *udpGROTable) newItems() []udpGROItem {
var items []udpGROItem
items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
return items
}
func (u *udpGROTable) reset() {
for k, items := range u.itemsByFlow {
items = items[:0]
u.itemsPool = append(u.itemsPool, items)
delete(u.itemsByFlow, k)
}
}
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
const (
coalescePrepend canCoalesce = -1
coalesceUnavailable canCoalesce = 0
coalesceAppend canCoalesce = 1
)
// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
// meet all requirements to be merged as part of a GRO operation, otherwise it
// returns false.
func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
if len(pktA) < 9 || len(pktB) < 9 {
return false
}
if pktA[0]>>4 == 6 {
if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
// cannot coalesce with unequal Traffic class values
return false
}
if pktA[7] != pktB[7] {
// cannot coalesce with unequal Hop limit values
return false
}
} else {
if pktA[1] != pktB[1] {
// cannot coalesce with unequal ToS values
return false
}
if pktA[6]>>5 != pktB[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return false
}
if pktA[8] != pktB[8] {
// cannot coalesce with unequal TTL values
return false
}
}
return true
}
// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
// packets involved in the current GRO evaluation. bufsOffset is the offset at
// which packet data begins within bufs.
func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if !ipHeadersCanCoalesce(pkt, pktTarget) {
return coalesceUnavailable
}
if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
}
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if tcphLen != item.tcphLen {
// cannot coalesce with unequal tcp options len
return coalesceUnavailable
}
if tcphLen > 20 {
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
// cannot coalesce with unequal tcp options
return coalesceUnavailable
}
}
if !ipHeadersCanCoalesce(pkt, pktTarget) {
return coalesceUnavailable
}
// seq adjacency
lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
if item.pshSet {
// We cannot append to a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
if pshSet {
// We cannot prepend with a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if gsoSize < item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
if gsoSize > item.gsoSize && item.numMerged > 0 {
// There's at least one previous merge, and we're larger than all
// previous. This would put multiple smaller packets on the end.
return coalesceUnavailable
}
return coalescePrepend
}
return coalesceUnavailable
}
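// checksumValid reports whether the transport checksum over pkt verifies.
// It seeds a one's-complement sum with the pseudo-header checksum derived
// from the IP source/destination addresses and the payload length, then sums
// the transport header and payload; per RFC 1071, a segment whose checksum
// field is correct sums to 0xFFFF, so the complement of the total is zero.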
func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
lenForPseudo := uint16(len(pkt) - int(iphLen))
cSum := header.PseudoHeaderChecksum(tcpip.TransportProtocolNumber(proto), pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
return ^checksum.Checksum(pkt[iphLen:], cSum) == 0
}
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int
const (
coalesceInsufficientCap coalesceResult = iota
coalescePSHEnding
coalesceItemInvalidCSum
coalescePktInvalidCSum
coalesceSuccess
)
// coalesceUDPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome.
func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
headersLen := item.iphLen + udphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
return coalesceItemInvalidCSum
}
}
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
return coalescePktInvalidCSum
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
item.numMerged++
return coalesceSuccess
}
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
var pktHead []byte // the packet that will end up at the front
headersLen := item.iphLen + item.tcphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
// Copy data
if mode == coalescePrepend {
pktHead = pkt
if cap(pkt)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if pshSet {
return coalescePSHEnding
}
if item.numMerged == 0 {
if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalesceItemInvalidCSum
}
}
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
extendBy := coalescedLen - len(pktHead)
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
// Flip the slice headers in bufs as part of prepend. The index of item
// is already being tracked for writing.
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
} else {
pktHead = bufs[item.bufsIndex][bufsOffset:]
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalesceItemInvalidCSum
}
}
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
// We are appending a segment with PSH set.
item.pshSet = pshSet
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
}
if gsoSize > item.gsoSize {
item.gsoSize = gsoSize
}
item.numMerged++
return coalesceSuccess
}
const (
ipv4FlagMoreFragments uint8 = 0x20
)
const (
maxUint16 = 1<<16 - 1
)
type groResult int
const (
groResultNoop groResult = iota
groResultTableInsert
groResultCoalesced
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return groResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return groResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return groResultNoop
}
}
if len(pkt) < iphLen {
return groResultNoop
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return groResultNoop
}
if len(pkt) < iphLen+tcphLen {
return groResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return groResultNoop
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
var pshSet bool
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return groResultNoop
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return groResultNoop
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return groResultTableInsert
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
// more efficient if there are multiple items for a given flow. This
// also enables a natural table.deleteAt() in the
// coalesceItemInvalidCSum case without the need for index tracking.
// This algorithm makes a best effort to coalesce in the event of
// unordered packets, where pkt may land anywhere in items from a
// sequence number perspective, however once an item is inserted into
// the table it is never compared across other items later.
item := items[i]
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
if can != coalesceUnavailable {
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return groResultCoalesced
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return groResultNoop
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return groResultTableInsert
}
// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(item.iphLen + item.tcphLen),
gsoSize: item.gsoSize,
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
pkt := bufs[item.bufsIndex][offset:]
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
if item.key.isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pkt[10], pkt[11] = 0, 0
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
iphCSum := ^checksum.Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
// Calculate the pseudo header checksum and place it at the TCP
// checksum offset. Downstream checksum offloading will combine
// this with computation of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if item.key.isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := offset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum.Checksum([]byte{}, psum))
} else {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
}
}
}
return nil
}
// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(item.iphLen + udphLen),
gsoSize: item.gsoSize,
csumStart: uint16(item.iphLen),
csumOffset: 6,
}
pkt := bufs[item.bufsIndex][offset:]
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
if item.key.isV6 {
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
pkt[10], pkt[11] = 0, 0
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
iphCSum := ^checksum.Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
// Recalculate the UDP len field value
binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
// Calculate the pseudo header checksum and place it at the UDP
// checksum offset. Downstream checksum offloading will combine
// this with computation of the udp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if item.key.isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := offset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum.Checksum([]byte{}, psum))
} else {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
}
}
}
return nil
}
type groCandidateType uint8
const (
notGROCandidate groCandidateType = iota
tcp4GROCandidate
tcp6GROCandidate
udp4GROCandidate
udp6GROCandidate
)
type groDisablementFlags int
const (
tcpGRODisabled groDisablementFlags = 1 << iota
udpGRODisabled
)
func (g *groDisablementFlags) disableTCPGRO() {
*g |= tcpGRODisabled
}
func (g *groDisablementFlags) canTCPGRO() bool {
return (*g)&tcpGRODisabled == 0
}
func (g *groDisablementFlags) disableUDPGRO() {
*g |= udpGRODisabled
}
func (g *groDisablementFlags) canUDPGRO() bool {
return (*g)&udpGRODisabled == 0
}
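// A minimal sketch of deriving the flags from negotiated offload state; the
// function below is illustrative only and not part of this file:
func exampleGROFlags(tcpOffload, udpOffload bool) groDisablementFlags {
	var gro groDisablementFlags // zero value: both TCP and UDP GRO enabled
	if !tcpOffload {
		gro.disableTCPGRO()
	}
	if !udpOffload {
		gro.disableUDPGRO()
	}
	return gro
}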
func packetIsGROCandidate(b []byte, gro groDisablementFlags) groCandidateType {
if len(b) < 28 {
// too short for even minimal IPv4 (20) + UDP (8) headers
return notGROCandidate
}
if b[0]>>4 == 4 {
if b[0]&0x0F != 5 {
// IPv4 packets w/IP options do not coalesce
return notGROCandidate
}
if b[9] == unix.IPPROTO_TCP && len(b) >= 40 && gro.canTCPGRO() { // IPv4 (20) + TCP (20)
return tcp4GROCandidate
}
if b[9] == unix.IPPROTO_UDP && gro.canUDPGRO() {
return udp4GROCandidate
}
} else if b[0]>>4 == 6 {
if b[6] == unix.IPPROTO_TCP && len(b) >= 60 && gro.canTCPGRO() { // IPv6 (40) + TCP (20)
return tcp6GROCandidate
}
if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && gro.canUDPGRO() { // IPv6 (40) + UDP (8)
return udp6GROCandidate
}
}
return notGROCandidate
}
const (
udphLen = 8
)
// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return groResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return groResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return groResultNoop
}
}
if len(pkt) < iphLen {
return groResultNoop
}
if len(pkt) < iphLen+udphLen {
return groResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return groResultNoop
}
}
gsoSize := uint16(len(pkt) - udphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return groResultNoop
}
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
if !existing {
return groResultTableInsert
}
// With UDP we only check the last item, otherwise we could reorder packets
// for a given flow. We must also always insert a new item, or successfully
// coalesce with an existing item, for the same reason.
item := items[len(items)-1]
can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
var pktCSumKnownInvalid bool
if can == coalesceAppend {
result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, len(items)-1)
return groResultCoalesced
case coalesceItemInvalidCSum:
// If the existing item has an invalid csum we take no action. A new
// item will be stored after it, and the existing item will never be
// revisited as part of future coalescing candidacy checks.
case coalescePktInvalidCSum:
// We must insert a new item, but we also mark it as invalid csum
// to prevent a repeat checksum validation.
pktCSumKnownInvalid = true
default:
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
return groResultTableInsert
}
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets. gro indicates if TCP and UDP GRO
// are supported/enabled.
func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error {
for i := range bufs {
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
var result groResult
switch packetIsGROCandidate(bufs[i][offset:], gro) {
case tcp4GROCandidate:
result = tcpGRO(bufs, offset, i, tcpTable, false)
case tcp6GROCandidate:
result = tcpGRO(bufs, offset, i, tcpTable, true)
case udp4GROCandidate:
result = udpGRO(bufs, offset, i, udpTable, false)
case udp6GROCandidate:
result = udpGRO(bufs, offset, i, udpTable, true)
}
switch result {
case groResultNoop:
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
fallthrough
case groResultTableInsert:
*toWrite = append(*toWrite, i)
}
}
errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
return errors.Join(errTCP, errUDP)
}
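// A minimal sketch of driving handleGRO from a vectored write path. It
// assumes the newTCPGROTable/newUDPGROTable constructors defined earlier in
// this file; in the real device the tables and index slice are long-lived
// and reset between batches rather than rebuilt:
func exampleHandleGRO(bufs [][]byte, offset int) ([]int, error) {
	tcpTable := newTCPGROTable()
	udpTable := newUDPGROTable()
	toWrite := make([]int, 0, len(bufs))
	var gro groDisablementFlags // zero value: both protocols enabled
	if err := handleGRO(bufs, offset, tcpTable, udpTable, gro, &toWrite); err != nil {
		return nil, err
	}
	return toWrite, nil // indices into bufs to hand to the device
}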

View file

@@ -108,7 +108,7 @@ const autoRouteUseSubRanges = runtime.GOOS == "darwin"
func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Prefix, error) {
var routeRanges []netip.Prefix
if o.AutoRoute && len(o.Inet4Address) > 0 {
if len(o.Inet4Address) > 0 {
var inet4Ranges []netip.Prefix
if len(o.Inet4RouteAddress) > 0 {
inet4Ranges = o.Inet4RouteAddress
@@ -119,19 +119,27 @@ func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Pref
}
}
}
} else if autoRouteUseSubRanges && !underNetworkExtension {
inet4Ranges = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1),
} else if o.AutoRoute {
if autoRouteUseSubRanges && !underNetworkExtension {
inet4Ranges = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1),
}
} else {
inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
} else if runtime.GOOS == "darwin" {
for _, address := range o.Inet4Address {
if address.Bits() < 32 {
inet4Ranges = append(inet4Ranges, address.Masked())
}
}
} else {
inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if len(o.Inet4RouteExcludeAddress) == 0 {
routeRanges = append(routeRanges, inet4Ranges...)
@@ -161,19 +169,27 @@ func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Pref
}
}
}
} else if autoRouteUseSubRanges && !underNetworkExtension {
inet6Ranges = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1),
} else if o.AutoRoute {
if autoRouteUseSubRanges && !underNetworkExtension {
inet6Ranges = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1),
}
} else {
inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
} else if runtime.GOOS == "darwin" {
for _, address := range o.Inet6Address {
if address.Bits() < 128 {
inet6Ranges = append(inet6Ranges, address.Masked())
}
}
} else {
inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
if len(o.Inet6RouteExcludeAddress) == 0 {
routeRanges = append(routeRanges, inet6Ranges...)
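
The /1 through /8 sub-ranges above are a darwin-specific trick: together they cover every address except 0.0.0.0/8 (and the analogous ::/8 for IPv6), so tun traffic is captured by more-specific routes without replacing the system default route. A quick way to check that coverage, as a hypothetical standalone program (not part of this change) using the go4.org/netipx dependency already in go.mod:

package main

import (
	"fmt"
	"net/netip"

	"go4.org/netipx"
)

func main() {
	var b netipx.IPSetBuilder
	for i := 0; i < 8; i++ {
		// 1.0.0.0/8, 2.0.0.0/7, ... 128.0.0.0/1
		b.AddPrefix(netip.PrefixFrom(netip.AddrFrom4([4]byte{byte(1 << i)}), 8-i))
	}
	set, _ := b.IPSet()
	fmt.Println(set.Prefixes()) // covers everything except 0.0.0.0/8
}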

View file

@@ -72,7 +72,7 @@ func (t *NativeTun) configure() error {
if err != nil {
return E.Cause(err, "set ipv4 address")
}
if !t.options.EXP_DisableDNSHijack {
if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack {
dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is4)
if len(dnsServers) == 0 && HasNextAddress(t.options.Inet4Address[0], 1) {
dnsServers = []netip.Addr{t.options.Inet4Address[0].Addr().Next()}
@@ -83,6 +83,11 @@ func (t *NativeTun) configure() error {
return E.Cause(err, "set ipv4 dns")
}
}
} else {
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), nil, nil)
if err != nil {
return E.Cause(err, "set ipv4 dns")
}
}
}
if len(t.options.Inet6Address) > 0 {
@@ -90,7 +95,7 @@ func (t *NativeTun) configure() error {
if err != nil {
return E.Cause(err, "set ipv6 address")
}
if !t.options.EXP_DisableDNSHijack {
if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack {
dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is6)
if len(dnsServers) == 0 && HasNextAddress(t.options.Inet6Address[0], 1) {
dnsServers = []netip.Addr{t.options.Inet6Address[0].Addr().Next()}
@@ -101,6 +106,11 @@ func (t *NativeTun) configure() error {
return E.Cause(err, "set ipv6 dns")
}
}
} else {
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), nil, nil)
if err != nil {
return E.Cause(err, "set ipv6 dns")
}
}
}
if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
@@ -148,6 +158,10 @@ func (t *NativeTun) configure() error {
return nil
}
func (t *NativeTun) Name() (string, error) {
return t.options.Name, nil
}
func (t *NativeTun) Start() error {
if !t.options.AutoRoute {
return nil
@@ -158,13 +172,7 @@ func (t *NativeTun) Start() error {
if err != nil {
return err
}
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = luid.AddRoute(routeRange, gateway4, 0)
} else {
err = luid.AddRoute(routeRange, gateway6, 0)
}
}
err = addRouteList(luid, routeRanges, gateway4, gateway6, 0)
if err != nil {
return err
}
@@ -349,7 +357,40 @@ func (t *NativeTun) Start() error {
}
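// Read blocks until a packet arrives on the Wintun receive ring, optionally
// spinning first when recent throughput makes busy-waiting cheaper than a
// kernel wait (see shouldSpin below).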
func (t *NativeTun) Read(p []byte) (n int, err error) {
return 0, os.ErrInvalid
t.running.Add(1)
defer t.running.Done()
retry:
if t.close.Load() == 1 {
return 0, os.ErrClosed
}
start := nanotime()
// Spin on the ring only while the measured packet rate is high enough that
// busy-waiting beats parking on the read event; otherwise fall through to
// WaitForSingleObject below.
shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
if t.close.Load() == 1 {
return 0, os.ErrClosed
}
var packet []byte
packet, err = t.session.ReceivePacket()
switch err {
case nil:
n = copy(p, packet)
t.session.ReleaseReceivePacket(packet)
t.rate.update(uint64(n))
return
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(t.readWait, windows.INFINITE)
goto retry
}
procyield(1)
continue
case windows.ERROR_HANDLE_EOF:
return 0, os.ErrClosed
case windows.ERROR_INVALID_DATA:
return 0, errors.New("send ring corrupt")
}
return 0, fmt.Errorf("read failed: %w", err)
}
}
func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
@@ -498,6 +539,63 @@ func (t *NativeTun) Close() error {
return err
}
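// UpdateRouteOptions stores the new options and, when auto-route is enabled,
// re-applies routing: it flushes every route on the adapter LUID, re-adds the
// freshly computed auto-route ranges, and clears the Windows resolver cache
// so cached answers do not outlive the old routes.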
func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error {
t.options = tunOptions
if !t.options.AutoRoute {
return nil
}
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
luid := winipcfg.LUID(t.adapter.LUID())
err = luid.FlushRoutes(windows.AF_UNSPEC)
if err != nil {
return err
}
err = addRouteList(luid, routeRanges, gateway4, gateway6, 0)
if err != nil {
return err
}
err = windnsapi.FlushResolverCache()
if err != nil {
return err
}
return nil
}
func addRouteList(luid winipcfg.LUID, destinations []netip.Prefix, gateway4 netip.Addr, gateway6 netip.Addr, metric uint32) error {
row := winipcfg.MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
row.Metric = metric
nextHop4 := row.NextHop
nextHop6 := row.NextHop
if gateway4.IsValid() {
nextHop4.SetAddr(gateway4)
}
if gateway6.IsValid() {
nextHop6.SetAddr(gateway6)
}
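// The two next-hop sockaddrs are encoded once here and reused for every
// destination below; only the destination prefix changes per row.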
for _, destination := range destinations {
err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return err
}
if destination.Addr().Is4() {
row.NextHop = nextHop4
} else {
row.NextHop = nextHop6
}
err = row.Create()
if err != nil {
return err
}
}
return nil
}
func generateGUIDByDeviceName(name string) *windows.GUID {
hash := md5.New()
hash.Write([]byte("wintun"))