diff --git a/Makefile b/Makefile index 3efdba9..f7a8532 100644 --- a/Makefile +++ b/Makefile @@ -29,4 +29,5 @@ lint_install: test: go build -v . + go test -bench=. ./internal/checksum_test #go test -v . diff --git a/go.mod b/go.mod index e7644c0..0bb96e7 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,10 @@ require ( github.com/go-ole/go-ole v1.3.0 github.com/google/btree v1.1.3 github.com/sagernet/fswatch v0.1.1 - github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 + github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/sagernet/nftables v0.3.0-beta.4 - github.com/sagernet/sing v0.6.0-alpha.11 + github.com/sagernet/sing v0.6.0-beta.2 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/net v0.26.0 diff --git a/go.sum b/go.sum index 270fa12..e2ea32a 100644 --- a/go.sum +++ b/go.sum @@ -16,14 +16,14 @@ github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8Ku github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs= github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o= -github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 h1:RxEz7LhPNiF/gX/Hg+OXr5lqsM9iVAgmaK1L1vzlDRM= -github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw= +github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs= +github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= 
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= -github.com/sagernet/sing v0.6.0-alpha.11 h1:ZcZlA0/vdDeiipAbjK73x9VabGJ/RRcAJgWhOo/OoBk= -github.com/sagernet/sing v0.6.0-alpha.11/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.6.0-beta.2 h1:Dcutp3kxrsZes9q3oTiHQhYYjQvDn5rwp1OI9fDLYwQ= +github.com/sagernet/sing v0.6.0-beta.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= diff --git a/internal/checksum_test/sum_bench_test.go b/internal/checksum_test/sum_bench_test.go new file mode 100644 index 0000000..35ee021 --- /dev/null +++ b/internal/checksum_test/sum_bench_test.go @@ -0,0 +1,33 @@ +package checksum_test + +import ( + "crypto/rand" + "testing" + + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/tschecksum" +) + +func BenchmarkTsChecksum(b *testing.B) { + packet := make([][]byte, 1000) + for i := 0; i < 1000; i++ { + packet[i] = make([]byte, 1500) + rand.Read(packet[i]) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + tschecksum.Checksum(packet[i%1000], 0) + } +} + +func BenchmarkGChecksum(b *testing.B) { + packet := make([][]byte, 1000) + for i := 0; i < 1000; i++ { + packet[i] = make([]byte, 1500) + rand.Read(packet[i]) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum.ChecksumDefault(packet[i%1000], 0) + } +} diff --git a/internal/gtcpip/checksum/checksum.go b/internal/gtcpip/checksum/checksum.go index 5d4e117..db03e64 100644 --- a/internal/gtcpip/checksum/checksum.go +++ b/internal/gtcpip/checksum/checksum.go @@ -30,34 +30,6 @@ func Put(b []byte, xsum uint16) { 
binary.BigEndian.PutUint16(b, xsum) } -// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the -// given byte array. This function uses an optimized version of the checksum -// algorithm. -// -// The initial checksum must have been computed on an even number of bytes. -func Checksum(buf []byte, initial uint16) uint16 { - s, _ := calculateChecksum(buf, false, initial) - return s -} - -// Checksumer calculates checksum defined in RFC 1071. -type Checksumer struct { - sum uint16 - odd bool -} - -// Add adds b to checksum. -func (c *Checksumer) Add(b []byte) { - if len(b) > 0 { - c.sum, c.odd = calculateChecksum(b, c.odd, c.sum) - } -} - -// Checksum returns the latest checksum value. -func (c *Checksumer) Checksum() uint16 { - return c.sum -} - // Combine combines the two uint16 to form their checksum. This is done // by adding them and the carry. // @@ -66,3 +38,8 @@ func Combine(a, b uint16) uint16 { v := uint32(a) + uint32(b) return uint16(v + v>>16) } + +func ChecksumDefault(buf []byte, initial uint16) uint16 { + s, _ := calculateChecksum(buf, false, initial) + return s +} diff --git a/internal/gtcpip/checksum/checksum_default.go b/internal/gtcpip/checksum/checksum_default.go new file mode 100644 index 0000000..99a2d75 --- /dev/null +++ b/internal/gtcpip/checksum/checksum_default.go @@ -0,0 +1,12 @@ +//go:build !amd64 + +package checksum + +// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the +// given byte array. This function uses an optimized version of the checksum +// algorithm. +// +// The initial checksum must have been computed on an even number of bytes. 
+func Checksum(buf []byte, initial uint16) uint16 { + return ChecksumDefault(buf, initial) +} diff --git a/internal/gtcpip/checksum/checksum_ts.go b/internal/gtcpip/checksum/checksum_ts.go new file mode 100644 index 0000000..f6766d3 --- /dev/null +++ b/internal/gtcpip/checksum/checksum_ts.go @@ -0,0 +1,9 @@ +//go:build amd64 + +package checksum + +import "github.com/sagernet/sing-tun/internal/tschecksum" + +func Checksum(buf []byte, initial uint16) uint16 { + return tschecksum.Checksum(buf, initial) +} diff --git a/internal/gtcpip/header/interfaces.go b/internal/gtcpip/header/interfaces.go new file mode 100644 index 0000000..b304532 --- /dev/null +++ b/internal/gtcpip/header/interfaces.go @@ -0,0 +1,136 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "net/netip" + + tcpip "github.com/sagernet/sing-tun/internal/gtcpip" +) + +const ( + // MaxIPPacketSize is the maximum supported IP packet size, excluding + // jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit + // in 16 bits). For IPv6, the payload max size (excluding jumbograms) is + // 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where + // m is the minimum IPv6 header size; we leave room for some potential + // IP options. + MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize +) + +// Transport offers generic methods to query and/or update the fields of the +// header of a transport protocol buffer. 
+type Transport interface { + // SourcePort returns the value of the "source port" field. + SourcePort() uint16 + + // DestinationPort returns the value of the "destination port" field. + DestinationPort() uint16 + + // Checksum returns the value of the "checksum" field. + Checksum() uint16 + + // SetSourcePort sets the value of the "source port" field. + SetSourcePort(uint16) + + // SetDestinationPort sets the value of the "destination port" field. + SetDestinationPort(uint16) + + // SetChecksum sets the value of the "checksum" field. + SetChecksum(uint16) + + // Payload returns the data carried in the transport buffer. + Payload() []byte +} + +// ChecksummableTransport is a Transport that supports checksumming. +type ChecksummableTransport interface { + Transport + + // SetSourcePortWithChecksumUpdate sets the source port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetSourcePortWithChecksumUpdate(port uint16) + + // SetDestinationPortWithChecksumUpdate sets the destination port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetDestinationPortWithChecksumUpdate(port uint16) + + // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an + // updated address in the pseudo header. + // + // If fullChecksum is true, the receiver's checksum field is assumed to hold a + // fully calculated checksum. Otherwise, it is assumed to hold a partially + // calculated checksum which only reflects the pseudo header. + UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) +} + +// Network offers generic methods to query and/or update the fields of the +// header of a network protocol buffer. +type Network interface { + // SourceAddress returns the value of the "source address" field. + SourceAddress() tcpip.Address + + // DestinationAddress returns the value of the "destination address" + // field. 
+ DestinationAddress() tcpip.Address + + DestinationAddr() netip.Addr + + // Checksum returns the value of the "checksum" field. + Checksum() uint16 + + // SetSourceAddress sets the value of the "source address" field. + SetSourceAddress(tcpip.Address) + + // SetDestinationAddress sets the value of the "destination address" + // field. + SetDestinationAddress(tcpip.Address) + + SetDestinationAddr(addr netip.Addr) + + // SetChecksum sets the value of the "checksum" field. + SetChecksum(uint16) + + // TransportProtocol returns the number of the transport protocol + // stored in the payload. + TransportProtocol() tcpip.TransportProtocolNumber + + // Payload returns a byte slice containing the payload of the network + // packet. + Payload() []byte + + // TOS returns the values of the "type of service" and "flow label" fields. + TOS() (uint8, uint32) + + // SetTOS sets the values of the "type of service" and "flow label" fields. + SetTOS(t uint8, l uint32) +} + +// ChecksummableNetwork is a Network that supports checksumming. +type ChecksummableNetwork interface { + Network + + // SetSourceAddressWithChecksumUpdate sets the source address and updates the + // checksum to reflect the new address. + SetSourceAddressWithChecksumUpdate(tcpip.Address) + + // SetDestinationAddressWithChecksumUpdate sets the destination address and + // updates the checksum to reflect the new address. 
+ SetDestinationAddressWithChecksumUpdate(tcpip.Address) +} diff --git a/internal/gtcpip/header/ipv6_extension_headers.go b/internal/gtcpip/header/ipv6_extension_headers.go index 3ab135d..20064d8 100644 --- a/internal/gtcpip/header/ipv6_extension_headers.go +++ b/internal/gtcpip/header/ipv6_extension_headers.go @@ -18,10 +18,8 @@ import ( "encoding/binary" "errors" "fmt" - "io" "math" - "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/sing-tun/internal/gtcpip" "github.com/sagernet/sing/common" ) @@ -145,79 +143,6 @@ func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) i return ((padLen + align - 1) & ^(align - 1)) - padLen } -// IPv6PayloadHeader is implemented by the various headers that can be found -// in an IPv6 payload. -// -// These headers include IPv6 extension headers or upper layer data. -type IPv6PayloadHeader interface { - isIPv6PayloadHeader() - - // Release frees all resources held by the header. - Release() -} - -// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator -// encounters a Next Header field it does not recognize as an IPv6 extension -// header. The caller is responsible for releasing the underlying buffer after -// it's no longer needed. -type IPv6RawPayloadHeader struct { - Identifier IPv6ExtensionHeaderIdentifier - Buf buffer.Buffer -} - -// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. -func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {} - -// Release implements IPv6PayloadHeader.Release. -func (i IPv6RawPayloadHeader) Release() { - i.Buf.Release() -} - -// ipv6OptionsExtHdr is an IPv6 extension header that holds options. -type ipv6OptionsExtHdr struct { - buf *buffer.View -} - -// Release implements IPv6PayloadHeader.Release. -func (i ipv6OptionsExtHdr) Release() { - if i.buf != nil { - i.buf.Release() - } -} - -// Iter returns an iterator over the IPv6 extension header options held in b. 
-func (i ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator { - it := IPv6OptionsExtHdrOptionsIterator{} - it.reader = i.buf - return it -} - -// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header -// options. -// -// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last -// used, no changes to the underlying buffer may happen. Doing so may cause -// undefined and unexpected behaviour. It is fine to obtain an -// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then -// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator -// obtained before modification is no longer used. -type IPv6OptionsExtHdrOptionsIterator struct { - reader *buffer.View - - // optionOffset is the number of bytes from the first byte of the - // options field to the beginning of the current option. - optionOffset uint32 - - // nextOptionOffset is the offset of the next option. - nextOptionOffset uint32 -} - -// OptionOffset returns the number of bytes parsed while processing the -// option field of the current Extension Header. -func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 { - return i.optionOffset -} - // IPv6OptionUnknownAction is the action that must be taken if the processing // IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2. type IPv6OptionUnknownAction int @@ -294,143 +219,6 @@ func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUn // is malformed. var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option") -// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension -// header option that is unknown by the parsing utilities. -type IPv6UnknownExtHdrOption struct { - Identifier IPv6ExtHdrOptionIdentifier - Data *buffer.View -} - -// UnknownAction implements IPv6OptionUnknownAction.UnknownAction. 
-func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction { - return ipv6UnknownActionFromIdentifier(o.Identifier) -} - -// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption. -func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {} - -// Next returns the next option in the options data. -// -// If the next item is not a known extension header option, -// IPv6UnknownExtHdrOption will be returned with the option identifier and data. -// -// The return is of the format (option, done, error). done will be true when -// Next is unable to return anything because the iterator has reached the end of -// the options data, or an error occurred. -func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) { - for { - i.optionOffset = i.nextOptionOffset - temp, err := i.reader.ReadByte() - if err != nil { - // If we can't read the first byte of a new option, then we know the - // options buffer has been exhausted and we are done iterating. - return nil, true, nil - } - id := IPv6ExtHdrOptionIdentifier(temp) - - // If the option identifier indicates the option is a Pad1 option, then we - // know the option does not have Length and Data fields. End processing of - // the Pad1 option and continue processing the buffer as a new option. - if id == ipv6Pad1ExtHdrOptionIdentifier { - i.nextOptionOffset = i.optionOffset + 1 - continue - } - - length, err := i.reader.ReadByte() - if err != nil { - if err != io.EOF { - // ReadByte should only ever return nil or io.EOF. - panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err)) - } - - // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once - // we start parsing an option; we expect the reader to contain enough - // bytes for the whole option. 
- return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF) - } - - // Do we have enough bytes in the reader for the next option? - if n := i.reader.Size(); n < int(length) { - // Consume the remaining buffer. - i.reader.TrimFront(i.reader.Size()) - - // We return the same error as if we failed to read a non-padding option - // so consumers of this iterator don't need to differentiate between - // padding and non-padding options. - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) - } - - i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */ - - switch id { - case ipv6PadNExtHdrOptionIdentifier: - // Special-case the variable length padding option to avoid a copy. - i.reader.TrimFront(int(length)) - continue - case ipv6RouterAlertHopByHopOptionIdentifier: - var routerAlertValue [ipv6RouterAlertPayloadLength]byte - if n, err := io.ReadFull(i.reader, routerAlertValue[:]); err != nil { - switch err { - case io.EOF, io.ErrUnexpectedEOF: - return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) - default: - return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err) - } - } else if n != int(length) { - return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) - } - return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil - default: - bytes := buffer.NewView(int(length)) - if n, err := io.CopyN(bytes, i.reader, int64(length)); err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - - return nil, true, fmt.Errorf("read %d out 
of %d option data bytes for option with id = %d: %w", n, length, id, err) - } - return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil - } - } -} - -// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options -// extension header. -type IPv6HopByHopOptionsExtHdr struct { - ipv6OptionsExtHdr -} - -// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. -func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {} - -// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options -// extension header. -type IPv6DestinationOptionsExtHdr struct { - ipv6OptionsExtHdr -} - -// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. -func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {} - -// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific -// data as outlined in RFC 8200 section 4.4. -type IPv6RoutingExtHdr struct { - Buf *buffer.View -} - -// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. -func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {} - -// Release implements IPv6PayloadHeader.Release. -func (b IPv6RoutingExtHdr) Release() { - b.Buf.Release() -} - -// SegmentsLeft returns the Segments Left field. -func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 { - return b.Buf.AsSlice()[ipv6RoutingExtHdrSegmentsLeftIdx] -} - // IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific // data as outlined in RFC 8200 section 4.5. // @@ -473,242 +261,6 @@ func (b IPv6FragmentExtHdr) IsAtomic() bool { return !b.More() && b.FragmentOffset() == 0 } -// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload. -// -// The IPv6 payload may contain IPv6 extension headers before any upper layer -// data. -// -// Note, between when an IPv6PayloadIterator is obtained and last used, no -// changes to the payload may happen. Doing so may cause undefined and -// unexpected behaviour. 
It is fine to obtain an IPv6PayloadIterator, iterate -// over the first few headers then modify the backing payload so long as the -// IPv6PayloadIterator obtained before modification is no longer used. -type IPv6PayloadIterator struct { - // The identifier of the next header to parse. - nextHdrIdentifier IPv6ExtensionHeaderIdentifier - - payload buffer.Buffer - - // Indicates to the iterator that it should return the remaining payload as a - // raw payload on the next call to Next. - forceRaw bool - - // headerOffset is the offset of the beginning of the current extension - // header starting from the beginning of the fixed header. - headerOffset uint32 - - // parseOffset is the byte offset into the current extension header of the - // field we are currently examining. It can be added to the header offset - // if the absolute offset within the packet is required. - parseOffset uint32 - - // nextOffset is the offset of the next header. - nextOffset uint32 -} - -// HeaderOffset returns the offset to the start of the extension -// header most recently processed. -func (i IPv6PayloadIterator) HeaderOffset() uint32 { - return i.headerOffset -} - -// ParseOffset returns the number of bytes successfully parsed. -func (i IPv6PayloadIterator) ParseOffset() uint32 { - return i.headerOffset + i.parseOffset -} - -// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing -// extension headers, or a raw payload if the payload cannot be parsed. The -// iterator takes ownership of the payload. -func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.Buffer) IPv6PayloadIterator { - return IPv6PayloadIterator{ - nextHdrIdentifier: nextHdrIdentifier, - payload: payload, - nextOffset: IPv6FixedHeaderSize, - } -} - -// Release frees the resources owned by the iterator. 
-func (i *IPv6PayloadIterator) Release() { - i.payload.Release() -} - -// AsRawHeader returns the remaining payload of i as a raw header and -// optionally consumes the iterator. -// -// If consume is true, calls to Next after calling AsRawHeader on i will -// indicate that the iterator is done. The returned header takes ownership of -// its payload. -func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { - identifier := i.nextHdrIdentifier - - var buf buffer.Buffer - if consume { - // Since we consume the iterator, we return the payload as is. - buf = i.payload - - // Mark i as done, but keep track of where we were for error reporting. - *i = IPv6PayloadIterator{ - nextHdrIdentifier: IPv6NoNextHeaderIdentifier, - headerOffset: i.headerOffset, - nextOffset: i.nextOffset, - } - } else { - buf = i.payload.Clone() - } - - return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf} -} - -// Next returns the next item in the payload. -// -// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader -// will be returned with the remaining bytes and next header identifier. -// -// The return is of the format (header, done, error). done will be true when -// Next is unable to return anything because the iterator has reached the end of -// the payload, or an error occurred. -func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { - i.headerOffset = i.nextOffset - i.parseOffset = 0 - // We could be forced to return i as a raw header when the previous header was - // a fragment extension header as the data following the fragment extension - // header may not be complete. - if i.forceRaw { - return i.AsRawHeader(true /* consume */), false, nil - } - - // Is the header we are parsing a known extension header? 
- switch i.nextHdrIdentifier { - case IPv6HopByHopOptionsExtHdrIdentifier: - nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil) - if err != nil { - return nil, true, err - } - - i.nextHdrIdentifier = nextHdrIdentifier - return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil - case IPv6RoutingExtHdrIdentifier: - nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil) - if err != nil { - return nil, true, err - } - - i.nextHdrIdentifier = nextHdrIdentifier - return IPv6RoutingExtHdr{view}, false, nil - case IPv6FragmentExtHdrIdentifier: - var data [6]byte - // We ignore the returned bytes because we know the fragment extension - // header specific data will fit in data. - nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:]) - if err != nil { - return nil, true, err - } - - fragmentExtHdr := IPv6FragmentExtHdr(data) - - // If the packet is not the first fragment, do not attempt to parse anything - // after the fragment extension header as the payload following the fragment - // extension header should not contain any headers; the first fragment must - // hold all the headers up to and including any upper layer headers, as per - // RFC 8200 section 4.5. - if fragmentExtHdr.FragmentOffset() != 0 { - i.forceRaw = true - } - - i.nextHdrIdentifier = nextHdrIdentifier - return fragmentExtHdr, false, nil - case IPv6DestinationOptionsExtHdrIdentifier: - nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil) - if err != nil { - return nil, true, err - } - - i.nextHdrIdentifier = nextHdrIdentifier - return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil - case IPv6NoNextHeaderIdentifier: - // This indicates the end of the IPv6 payload. - return nil, true, nil - - default: - // The header we are parsing is not a known extension header. Return the - // raw payload. 
- return i.AsRawHeader(true /* consume */), false, nil - } -} - -// NextHeaderIdentifier returns the identifier of the header next returned by -// it.Next(). -func (i *IPv6PayloadIterator) NextHeaderIdentifier() IPv6ExtensionHeaderIdentifier { - return i.nextHdrIdentifier -} - -// nextHeaderData returns the extension header's Next Header field and raw data. -// -// fragmentHdr indicates that the extension header being parsed is the Fragment -// extension header so the Length field should be ignored as it is Reserved -// for the Fragment extension header. -// -// If bytes is not nil, extension header specific data will be read into bytes -// if it has enough capacity. If bytes is provided but does not have enough -// capacity for the data, nextHeaderData will panic. -func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, *buffer.View, error) { - // We ignore the number of bytes read because we know we will only ever read - // at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read - // would return io.EOF to indicate that io.Reader has reached the end of the - // payload. - rdr := i.payload.AsBufferReader() - nextHdrIdentifier, err := rdr.ReadByte() - if err != nil { - return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err) - } - i.parseOffset++ - - var length uint8 - length, err = rdr.ReadByte() - if err != nil { - if fragmentHdr { - return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err) - } - - return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err) - } - if fragmentHdr { - length = 0 - } - - // Make parseOffset point to the first byte of the Extension Header - // specific data. - i.parseOffset++ - - // length is in 8 byte chunks but doesn't include the first one. 
- // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement - // in section 4.8 for new extension headers at the top of page 24. - // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet - // units, not including the first 8 octets. - i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit) - - bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded - if fragmentHdr { - if n := len(bytes); n < bytesLen { - panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier)) - } - if n, err := io.ReadFull(&rdr, bytes); err != nil { - return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err) - } - return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), nil, nil - } - v := buffer.NewView(bytesLen) - if n, err := io.CopyN(v, &rdr, int64(bytesLen)); err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - v.Release() - return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err) - } - return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), v, nil -} - // IPv6SerializableExtHdr provides serialization for IPv6 extension // headers. type IPv6SerializableExtHdr interface { diff --git a/internal/tschecksum/checksum.go b/internal/tschecksum/checksum.go new file mode 100644 index 0000000..677879f --- /dev/null +++ b/internal/tschecksum/checksum.go @@ -0,0 +1,712 @@ +package tschecksum + +import ( + "encoding/binary" + "math/bits" + "strconv" + + "golang.org/x/sys/cpu" +) + +// checksumGeneric64 is a reference implementation of checksum using 64 bit +// arithmetic for use in testing or when an architecture-specific implementation +// is not available. 
func checksumGeneric64(b []byte, initial uint16) uint16 {
	// ac accumulates native-endian 64-bit words; every bits.Add64 chains the
	// carry out of the previous addition, and the final pending carry is
	// folded in by ipChecksumFold64 below.
	var ac uint64
	var carry uint64

	// On little-endian hosts the sum is computed over byte-swapped words and
	// swapped back at the end; the initial value is byte-swapped on entry to
	// match.
	if cpu.IsBigEndian {
		ac = uint64(initial)
	} else {
		ac = uint64(bits.ReverseBytes16(initial))
	}

	// Unrolled accumulation tiers: 128, 64, 32, 16 and 8 bytes at a time,
	// then 4/2/1-byte tails.
	for len(b) >= 128 {
		if cpu.IsBigEndian {
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry)
		} else {
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry)
		}
		b = b[128:]
	}
	if len(b) >= 64 {
		if cpu.IsBigEndian {
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
		} else {
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
		}
		b = b[64:]
	}
	if len(b) >= 32 {
		if cpu.IsBigEndian {
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
		} else {
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
		}
		b = b[32:]
	}
	if len(b) >= 16 {
		if cpu.IsBigEndian {
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
		} else {
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
		}
		b = b[16:]
	}
	if len(b) >= 8 {
		if cpu.IsBigEndian {
			ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry)
		} else {
			ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry)
		}
		b = b[8:]
	}
	if len(b) >= 4 {
		if cpu.IsBigEndian {
			ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry)
		} else {
			ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry)
		}
		b = b[4:]
	}
	if len(b) >= 2 {
		if cpu.IsBigEndian {
			ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry)
		} else {
			ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry)
		}
		b = b[2:]
	}
	if len(b) >= 1 {
		// A trailing odd byte is treated as the high byte of a zero-padded
		// 16-bit word; on little-endian hosts no shift is needed because the
		// whole sum is byte-swapped below.
		if cpu.IsBigEndian {
			ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry)
		} else {
			ac, carry = bits.Add64(ac, uint64(b[0]), carry)
		}
	}

	// Fold the 64-bit accumulator (plus the pending carry bit) to 16 bits,
	// then undo the little-endian byte swap applied on entry.
	folded := ipChecksumFold64(ac, carry)
	if !cpu.IsBigEndian {
		folded = bits.ReverseBytes16(folded)
	}
	return folded
}

// checksumGeneric32 is a reference implementation of checksum using 32 bit
// arithmetic for use in testing or when an architecture-specific implementation
// is not available.
+func checksumGeneric32(b []byte, initial uint16) uint16 { + var ac uint32 + var carry uint32 + + if cpu.IsBigEndian { + ac = uint32(initial) + } else { + ac = uint32(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + ac, carry = 
bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), 
carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry) + } else { + ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry) + } else { + ac, carry = bits.Add32(ac, uint32(b[0]), carry) + } + } + + folded := ipChecksumFold32(ac, carry) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric32Alternate is an alternate reference implementation of +// checksum using 32 bit arithmetic for use in testing or when an +// architecture-specific implementation is not available. 
func checksumGeneric32Alternate(b []byte, initial uint16) uint16 {
	// ac accumulates native-endian 16-bit words with plain wrapping uint32
	// additions (no carry tracking). Wraparound is harmless for the final
	// result because 2^32 ≡ 1 (mod 0xffff), so the fold below still yields
	// the correct one's-complement sum.
	var ac uint32

	// On little-endian hosts the sum is computed over byte-swapped words and
	// swapped back at the end; the initial value is byte-swapped to match.
	if cpu.IsBigEndian {
		ac = uint32(initial)
	} else {
		ac = uint32(bits.ReverseBytes16(initial))
	}

	// Unrolled accumulation tiers: 64, 32, 16, 8 and 4 bytes at a time, then
	// 2/1-byte tails.
	for len(b) >= 64 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
			ac += uint32(binary.BigEndian.Uint16(b[4:6]))
			ac += uint32(binary.BigEndian.Uint16(b[6:8]))
			ac += uint32(binary.BigEndian.Uint16(b[8:10]))
			ac += uint32(binary.BigEndian.Uint16(b[10:12]))
			ac += uint32(binary.BigEndian.Uint16(b[12:14]))
			ac += uint32(binary.BigEndian.Uint16(b[14:16]))
			ac += uint32(binary.BigEndian.Uint16(b[16:18]))
			ac += uint32(binary.BigEndian.Uint16(b[18:20]))
			ac += uint32(binary.BigEndian.Uint16(b[20:22]))
			ac += uint32(binary.BigEndian.Uint16(b[22:24]))
			ac += uint32(binary.BigEndian.Uint16(b[24:26]))
			ac += uint32(binary.BigEndian.Uint16(b[26:28]))
			ac += uint32(binary.BigEndian.Uint16(b[28:30]))
			ac += uint32(binary.BigEndian.Uint16(b[30:32]))
			ac += uint32(binary.BigEndian.Uint16(b[32:34]))
			ac += uint32(binary.BigEndian.Uint16(b[34:36]))
			ac += uint32(binary.BigEndian.Uint16(b[36:38]))
			ac += uint32(binary.BigEndian.Uint16(b[38:40]))
			ac += uint32(binary.BigEndian.Uint16(b[40:42]))
			ac += uint32(binary.BigEndian.Uint16(b[42:44]))
			ac += uint32(binary.BigEndian.Uint16(b[44:46]))
			ac += uint32(binary.BigEndian.Uint16(b[46:48]))
			ac += uint32(binary.BigEndian.Uint16(b[48:50]))
			ac += uint32(binary.BigEndian.Uint16(b[50:52]))
			ac += uint32(binary.BigEndian.Uint16(b[52:54]))
			ac += uint32(binary.BigEndian.Uint16(b[54:56]))
			ac += uint32(binary.BigEndian.Uint16(b[56:58]))
			ac += uint32(binary.BigEndian.Uint16(b[58:60]))
			ac += uint32(binary.BigEndian.Uint16(b[60:62]))
			ac += uint32(binary.BigEndian.Uint16(b[62:64]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
			ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
			ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
			ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
			ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
			ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
			ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
			ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
			ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
			ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
			ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
			ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
			ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
			ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
			ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
			ac += uint32(binary.LittleEndian.Uint16(b[32:34]))
			ac += uint32(binary.LittleEndian.Uint16(b[34:36]))
			ac += uint32(binary.LittleEndian.Uint16(b[36:38]))
			ac += uint32(binary.LittleEndian.Uint16(b[38:40]))
			ac += uint32(binary.LittleEndian.Uint16(b[40:42]))
			ac += uint32(binary.LittleEndian.Uint16(b[42:44]))
			ac += uint32(binary.LittleEndian.Uint16(b[44:46]))
			ac += uint32(binary.LittleEndian.Uint16(b[46:48]))
			ac += uint32(binary.LittleEndian.Uint16(b[48:50]))
			ac += uint32(binary.LittleEndian.Uint16(b[50:52]))
			ac += uint32(binary.LittleEndian.Uint16(b[52:54]))
			ac += uint32(binary.LittleEndian.Uint16(b[54:56]))
			ac += uint32(binary.LittleEndian.Uint16(b[56:58]))
			ac += uint32(binary.LittleEndian.Uint16(b[58:60]))
			ac += uint32(binary.LittleEndian.Uint16(b[60:62]))
			ac += uint32(binary.LittleEndian.Uint16(b[62:64]))
		}
		b = b[64:]
	}
	if len(b) >= 32 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
			ac += uint32(binary.BigEndian.Uint16(b[4:6]))
			ac += uint32(binary.BigEndian.Uint16(b[6:8]))
			ac += uint32(binary.BigEndian.Uint16(b[8:10]))
			ac += uint32(binary.BigEndian.Uint16(b[10:12]))
			ac += uint32(binary.BigEndian.Uint16(b[12:14]))
			ac += uint32(binary.BigEndian.Uint16(b[14:16]))
			ac += uint32(binary.BigEndian.Uint16(b[16:18]))
			ac += uint32(binary.BigEndian.Uint16(b[18:20]))
			ac += uint32(binary.BigEndian.Uint16(b[20:22]))
			ac += uint32(binary.BigEndian.Uint16(b[22:24]))
			ac += uint32(binary.BigEndian.Uint16(b[24:26]))
			ac += uint32(binary.BigEndian.Uint16(b[26:28]))
			ac += uint32(binary.BigEndian.Uint16(b[28:30]))
			ac += uint32(binary.BigEndian.Uint16(b[30:32]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
			ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
			ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
			ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
			ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
			ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
			ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
			ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
			ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
			ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
			ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
			ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
			ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
			ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
			ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
		}
		b = b[32:]
	}
	if len(b) >= 16 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
			ac += uint32(binary.BigEndian.Uint16(b[4:6]))
			ac += uint32(binary.BigEndian.Uint16(b[6:8]))
			ac += uint32(binary.BigEndian.Uint16(b[8:10]))
			ac += uint32(binary.BigEndian.Uint16(b[10:12]))
			ac += uint32(binary.BigEndian.Uint16(b[12:14]))
			ac += uint32(binary.BigEndian.Uint16(b[14:16]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
			ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
			ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
			ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
			ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
			ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
			ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
		}
		b = b[16:]
	}
	if len(b) >= 8 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
			ac += uint32(binary.BigEndian.Uint16(b[4:6]))
			ac += uint32(binary.BigEndian.Uint16(b[6:8]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
			ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
			ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
		}
		b = b[8:]
	}
	if len(b) >= 4 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b[:2]))
			ac += uint32(binary.BigEndian.Uint16(b[2:4]))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b[:2]))
			ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
		}
		b = b[4:]
	}
	if len(b) >= 2 {
		if cpu.IsBigEndian {
			ac += uint32(binary.BigEndian.Uint16(b))
		} else {
			ac += uint32(binary.LittleEndian.Uint16(b))
		}
		b = b[2:]
	}
	if len(b) >= 1 {
		// A trailing odd byte is the high byte of a zero-padded 16-bit word;
		// no shift on little-endian because the sum is byte-swapped below.
		if cpu.IsBigEndian {
			ac += uint32(b[0]) << 8
		} else {
			ac += uint32(b[0])
		}
	}

	// No pending carry to fold in (wrapping adds were used), hence 0.
	folded := ipChecksumFold32(ac, 0)
	if !cpu.IsBigEndian {
		folded = bits.ReverseBytes16(folded)
	}
	return folded
}

// checksumGeneric64Alternate is an alternate reference implementation of
// checksum using 64 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric64Alternate(b []byte, initial uint16) uint16 {
	// ac accumulates native-endian 32-bit loads with plain wrapping uint64
	// additions (no carry tracking). Wraparound is harmless for the final
	// result because 2^64 ≡ 1 (mod 0xffff), so the fold below still yields
	// the correct one's-complement sum.
	var ac uint64

	// On little-endian hosts the sum is computed over byte-swapped words and
	// swapped back at the end; the initial value is byte-swapped to match.
	if cpu.IsBigEndian {
		ac = uint64(initial)
	} else {
		ac = uint64(bits.ReverseBytes16(initial))
	}

	// Unrolled accumulation tiers: 64, 32, 16, 8 and 4 bytes at a time, then
	// 2/1-byte tails.
	for len(b) >= 64 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b[:4]))
			ac += uint64(binary.BigEndian.Uint32(b[4:8]))
			ac += uint64(binary.BigEndian.Uint32(b[8:12]))
			ac += uint64(binary.BigEndian.Uint32(b[12:16]))
			ac += uint64(binary.BigEndian.Uint32(b[16:20]))
			ac += uint64(binary.BigEndian.Uint32(b[20:24]))
			ac += uint64(binary.BigEndian.Uint32(b[24:28]))
			ac += uint64(binary.BigEndian.Uint32(b[28:32]))
			ac += uint64(binary.BigEndian.Uint32(b[32:36]))
			ac += uint64(binary.BigEndian.Uint32(b[36:40]))
			ac += uint64(binary.BigEndian.Uint32(b[40:44]))
			ac += uint64(binary.BigEndian.Uint32(b[44:48]))
			ac += uint64(binary.BigEndian.Uint32(b[48:52]))
			ac += uint64(binary.BigEndian.Uint32(b[52:56]))
			ac += uint64(binary.BigEndian.Uint32(b[56:60]))
			ac += uint64(binary.BigEndian.Uint32(b[60:64]))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b[:4]))
			ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
			ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
			ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
			ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
			ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
			ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
			ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
			ac += uint64(binary.LittleEndian.Uint32(b[32:36]))
			ac += uint64(binary.LittleEndian.Uint32(b[36:40]))
			ac += uint64(binary.LittleEndian.Uint32(b[40:44]))
			ac += uint64(binary.LittleEndian.Uint32(b[44:48]))
			ac += uint64(binary.LittleEndian.Uint32(b[48:52]))
			ac += uint64(binary.LittleEndian.Uint32(b[52:56]))
			ac += uint64(binary.LittleEndian.Uint32(b[56:60]))
			ac += uint64(binary.LittleEndian.Uint32(b[60:64]))
		}
		b = b[64:]
	}
	if len(b) >= 32 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b[:4]))
			ac += uint64(binary.BigEndian.Uint32(b[4:8]))
			ac += uint64(binary.BigEndian.Uint32(b[8:12]))
			ac += uint64(binary.BigEndian.Uint32(b[12:16]))
			ac += uint64(binary.BigEndian.Uint32(b[16:20]))
			ac += uint64(binary.BigEndian.Uint32(b[20:24]))
			ac += uint64(binary.BigEndian.Uint32(b[24:28]))
			ac += uint64(binary.BigEndian.Uint32(b[28:32]))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b[:4]))
			ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
			ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
			ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
			ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
			ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
			ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
			ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
		}
		b = b[32:]
	}
	if len(b) >= 16 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b[:4]))
			ac += uint64(binary.BigEndian.Uint32(b[4:8]))
			ac += uint64(binary.BigEndian.Uint32(b[8:12]))
			ac += uint64(binary.BigEndian.Uint32(b[12:16]))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b[:4]))
			ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
			ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
			ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
		}
		b = b[16:]
	}
	if len(b) >= 8 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b[:4]))
			ac += uint64(binary.BigEndian.Uint32(b[4:8]))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b[:4]))
			ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
		}
		b = b[8:]
	}
	if len(b) >= 4 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint32(b))
		} else {
			ac += uint64(binary.LittleEndian.Uint32(b))
		}
		b = b[4:]
	}
	if len(b) >= 2 {
		if cpu.IsBigEndian {
			ac += uint64(binary.BigEndian.Uint16(b))
		} else {
			ac += uint64(binary.LittleEndian.Uint16(b))
		}
		b = b[2:]
	}
	if len(b) >= 1 {
		// A trailing odd byte is the high byte of a zero-padded 16-bit word;
		// no shift on little-endian because the sum is byte-swapped below.
		if cpu.IsBigEndian {
			ac += uint64(b[0]) << 8
		} else {
			ac += uint64(b[0])
		}
	}

	// No pending carry to fold in (wrapping adds were used), hence 0.
	folded := ipChecksumFold64(ac, 0)
	if !cpu.IsBigEndian {
		folded = bits.ReverseBytes16(folded)
	}
	return folded
}

// ipChecksumFold64 folds a 64-bit accumulator, plus a pending 0/1 carry bit,
// down to a 16-bit one's-complement sum.
func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 {
	sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry))
	// if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff
	// therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is
	// no need to save the carry flag
	sum = (sum >> 16) + (sum & 0xffff) + carry
	// sum <= 0x1_fffe therefore this is the last fold needed:
	// if (sum >> 16) > 0 then
	//    (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore
	//    the addition will not overflow
	// otherwise (sum >> 16) == 0 and sum will be unchanged
	sum = (sum >> 16) + (sum & 0xffff)
	return uint16(sum)
}

// ipChecksumFold32 folds a 32-bit accumulator, plus a pending 0/1 carry bit,
// down to a 16-bit one's-complement sum.
func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 {
	sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry
	// sum <= 0x1_ffff:
	//   0xffff + 0xffff = 0x1_fffe
	//   initialCarry is 0 or 1, for a combined maximum of 0x1_ffff
	sum = (sum >> 16) + (sum & 0xffff)
	// sum <= 0x1_0000 therefore this is the last fold needed:
	// if (sum >> 16) > 0 then
	//    (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore
	//    the addition will not overflow
	// otherwise (sum >> 16) == 0 and sum will be unchanged
	sum = (sum >> 16) + (sum & 0xffff)
	return uint16(sum)
}

// addrPartialChecksum64 adds a 4-byte (IPv4) or 16-byte (IPv6) address into a
// running 64-bit accumulator with a chained carry; panics on any other length.
func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) {
	sum, carry = initial, carryIn
	switch len(addr) {
	case 4: // IPv4
		if cpu.IsBigEndian {
			sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry)
		} else {
			sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry)
		}
	case 16: // IPv6
		if cpu.IsBigEndian {
			sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry)
			sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry)
		} else {
			sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry)
			sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry)
		}
	default:
		panic("bad addr length")
	}
	return sum, carry
}

// addrPartialChecksum32 is the 32-bit counterpart of addrPartialChecksum64;
// panics unless addr is 4 (IPv4) or 16 (IPv6) bytes long.
func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) {
	sum, carry = initial, carryIn
	switch len(addr) {
	case 4: // IPv4
		if cpu.IsBigEndian {
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
		} else {
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
		}
	case 16: // IPv6
		if cpu.IsBigEndian {
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry)
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry)
			sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry)
		} else {
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry)
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry)
			sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry)
		}
	default:
		panic("bad addr length")
	}
	return sum, carry
}

// pseudoHeaderChecksum64 computes the pseudo-header checksum (addresses,
// protocol and total length) using 64-bit arithmetic. On little-endian hosts
// totalLen is byte-swapped and protocol shifted into the high byte so the sum
// is carried out in native order, then swapped back after folding.
func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
	var sum uint64
	if cpu.IsBigEndian {
		sum = uint64(totalLen) + uint64(protocol)
	} else {
		sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8
	}
	sum, carry := addrPartialChecksum64(srcAddr, sum, 0)
	sum, carry = addrPartialChecksum64(dstAddr, sum, carry)

	foldedSum := ipChecksumFold64(sum, carry)
	if !cpu.IsBigEndian {
		foldedSum = bits.ReverseBytes16(foldedSum)
	}
	return foldedSum
}

// pseudoHeaderChecksum32 is the 32-bit counterpart of pseudoHeaderChecksum64.
func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
	var sum uint32
	if cpu.IsBigEndian {
		sum = uint32(totalLen) + uint32(protocol)
	} else {
		sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8
	}
	sum, carry := addrPartialChecksum32(srcAddr, sum, 0)
	sum, carry = addrPartialChecksum32(dstAddr, sum, carry)

	foldedSum := ipChecksumFold32(sum, carry)
	if !cpu.IsBigEndian {
		foldedSum = bits.ReverseBytes16(foldedSum)
	}
	return foldedSum
}

// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and
// dstAddr must be 4 or 16 bytes in length.
func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
	// strconv.IntSize selects the 32-bit implementation on 32-bit platforms
	// at compile time.
	if strconv.IntSize < 64 {
		return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen)
	}
	return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen)
}
diff --git a/internal/tschecksum/checksum_amd64.go b/internal/tschecksum/checksum_amd64.go
new file mode 100644
index 0000000..85b925a
--- /dev/null
+++ b/internal/tschecksum/checksum_amd64.go
@@ -0,0 +1,23 @@
package tschecksum

import "golang.org/x/sys/cpu"

// checksum is the kernel selected once at startup by init below; the default
// is the portable AMD64 implementation.
var checksum = checksumAMD64

// Checksum computes an IP checksum starting with the provided initial value.
// The length of data should be at least 128 bytes for best performance. Smaller
// buffers will still compute a correct result.
func Checksum(data []byte, initial uint16) uint16 {
	return checksum(data, initial)
}

func init() {
	// Prefer the AVX2 kernel (requires AVX+AVX2+BMI2), then SSE2, otherwise
	// keep the baseline AMD64 implementation.
	if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 {
		checksum = checksumAVX2
		return
	}
	if cpu.X86.HasSSE2 {
		checksum = checksumSSE2
		return
	}
}
diff --git a/internal/tschecksum/checksum_generated_amd64.go b/internal/tschecksum/checksum_generated_amd64.go
new file mode 100644
index 0000000..acc7350
--- /dev/null
+++ b/internal/tschecksum/checksum_generated_amd64.go
@@ -0,0 +1,18 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
+ +package tschecksum + +// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2) +// +//go:noescape +func checksumAVX2(b []byte, initial uint16) uint16 + +// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2) +// +//go:noescape +func checksumSSE2(b []byte, initial uint16) uint16 + +// checksumAMD64 computes an IP checksum using amd64 baseline instructions +// +//go:noescape +func checksumAMD64(b []byte, initial uint16) uint16 diff --git a/internal/tschecksum/checksum_generated_amd64.s b/internal/tschecksum/checksum_generated_amd64.s new file mode 100644 index 0000000..5f2e4c5 --- /dev/null +++ b/internal/tschecksum/checksum_generated_amd64.s @@ -0,0 +1,851 @@ +// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT. + +#include "textflag.h" + +DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" +DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff" +DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112 + +// func checksumAVX2(b []byte, initial uint16) uint16 +// Requires: AVX, AVX2, BMI2 +TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + 
DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + VPXOR Y0, Y0, Y0 + VPXOR Y1, Y1, Y1 + VPXOR Y2, Y2, Y2 + VPXOR Y3, Y3, Y3 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 16(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 32(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 48(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 64(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 80(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 96(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 112(DX), Y4 + VPADDD Y4, 
Y3, Y3 + VPMOVZXWD 128(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 144(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 160(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 176(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 192(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 208(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 224(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 240(DX), Y4 + VPADDD Y4, Y3, Y3 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + ADDQ $0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. + LEAQ xmmLoadMasks<>+0(SB), CX + VMOVDQU -16(DX)(BX*1), X4 + VPAND -16(CX)(BX*8), X4, X4 + VPMOVZXWD X4, Y4 + VPADDD Y4, Y0, Y0 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + VPADDD Y1, Y0, Y0 + VPADDD Y2, Y0, Y0 + VPADDD Y3, Y0, Y0 + + // extract the YMM into a pair of XMM and sum them + VEXTRACTI128 $0x01, Y0, X1 + VPADDD X0, X1, X0 + + // extract the XMM into GP64 + VPEXTRQ $0x00, X0, CX + VPEXTRQ $0x01, X0, DX + + // no more AVX code, clear upper registers to avoid SSE slowdowns + VZEROUPPER + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + RORXQ $0x20, AX, CX + ADCL CX, AX + RORXL $0x10, AX, CX + ADCW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumSSE2(b []byte, initial uint16) uint16 +// Requires: SSE2 +TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI 
+ + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + PXOR X0, X0 + PXOR X1, X1 + PXOR X2, X2 + PXOR X3, X3 + PXOR X4, X4 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 16(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 32(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 48(DX), X5 + MOVOA X5, X6 + 
PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 64(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 80(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 96(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 112(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 128(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 144(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 160(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 176(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 192(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 208(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 224(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 240(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + ADDQ $0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. 
+ LEAQ xmmLoadMasks<>+0(SB), CX + MOVOU -16(DX)(BX*1), X5 + PAND -16(CX)(BX*8), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + PADDD X1, X0 + PADDD X2, X0 + PADDD X3, X0 + + // extract the XMM into GP64 + MOVQ X0, CX + PSRLDQ $0x08, X0 + MOVQ X0, DX + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumAMD64(b []byte, initial uint16) uint16 +TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + 
+handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // Number of 256 byte iterations into loop counter + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + SHRQ $0x08, CX + JZ startCleanup + CLC + XORQ SI, SI + XORQ DI, DI + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + XORQ R12, R12 + +bigLoop: + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ 32(DX), DI + ADCQ 40(DX), DI + ADCQ 48(DX), DI + ADCQ 56(DX), DI + ADCQ $0x00, R8 + ADDQ 64(DX), R9 + ADCQ 72(DX), R9 + ADCQ 80(DX), R9 + ADCQ 88(DX), R9 + ADCQ $0x00, R10 + ADDQ 96(DX), R11 + ADCQ 104(DX), R11 + ADCQ 112(DX), R11 + ADCQ 120(DX), R11 + ADCQ $0x00, R12 + ADDQ 128(DX), AX + ADCQ 136(DX), AX + ADCQ 144(DX), AX + ADCQ 152(DX), AX + ADCQ $0x00, SI + ADDQ 160(DX), DI + ADCQ 168(DX), DI + ADCQ 176(DX), DI + ADCQ 184(DX), DI + ADCQ $0x00, R8 + ADDQ 192(DX), R9 + ADCQ 200(DX), R9 + ADCQ 208(DX), R9 + ADCQ 216(DX), R9 + ADCQ $0x00, R10 + ADDQ 224(DX), R11 + ADCQ 232(DX), R11 + ADCQ 240(DX), R11 + ADCQ 248(DX), R11 + ADCQ $0x00, R12 + ADDQ $0x00000100, DX + SUBQ $0x01, CX + JNZ bigLoop + ADDQ SI, AX + ADCQ DI, AX + ADCQ R8, AX + ADCQ R9, AX + ADCQ R10, AX + ADCQ R11, AX + ADCQ R12, AX + + // accumulate CF (twice, in case the first time overflows) + ADCQ $0x00, AX + ADCQ $0x00, AX + +startCleanup: + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. 
This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. + MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET diff --git a/internal/tschecksum/checksum_generic.go b/internal/tschecksum/checksum_generic.go new file mode 100644 index 0000000..2d6c134 --- /dev/null +++ b/internal/tschecksum/checksum_generic.go @@ -0,0 +1,15 @@ +// This file contains 
IP checksum algorithms that are not specific to any +// architecture and don't use hardware acceleration. + +//go:build !amd64 + +package tschecksum + +import "strconv" + +func Checksum(data []byte, initial uint16) uint16 { + if strconv.IntSize < 64 { + return checksumGeneric32(data, initial) + } + return checksumGeneric64(data, initial) +} diff --git a/internal/tschecksum/generate_amd64.go b/internal/tschecksum/generate_amd64.go new file mode 100644 index 0000000..a72a59e --- /dev/null +++ b/internal/tschecksum/generate_amd64.go @@ -0,0 +1,578 @@ +//go:build ignore + +//go:generate go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go + +package main + +import ( + "fmt" + "math" + "math/bits" + + "github.com/mmcloughlin/avo/operand" + "github.com/mmcloughlin/avo/reg" +) + +const checksumSignature = "func(b []byte, initial uint16) uint16" + +func loadParams() (accum, buf, n reg.GPVirtual) { + accum, buf, n = GP64(), GP64(), GP64() + Load(Param("initial"), accum) + XCHGB(accum.As8H(), accum.As8L()) + Load(Param("b").Base(), buf) + Load(Param("b").Len(), n) + return +} + +type simdStrategy int + +const ( + sse2 = iota + avx2 +) + +const tinyBufferSize = 31 // A buffer is tiny if it has at most 31 bytes. 
+ +func generateSIMDChecksum(name, doc string, minSIMDSize, chains int, strategy simdStrategy) { + TEXT(name, NOSPLIT|NOFRAME, checksumSignature) + Pragma("noescape") + Doc(doc) + + accum64, buf, n := loadParams() + + handleOddLength(n, buf, accum64) + // no chance of overflow because accum64 was initialized by a uint16 and + // handleOddLength adds at most a uint8 + handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny")) + Label("bufferIsNotTiny") + + const simdReadSize = 16 + + if minSIMDSize > tinyBufferSize { + Comment("skip all SIMD for small buffers") + if minSIMDSize <= math.MaxUint8 { + CMPQ(n, operand.U8(minSIMDSize)) + } else { + CMPQ(n, operand.U32(minSIMDSize)) + } + JGE(operand.LabelRef("startSIMD")) + + handleRemaining(n, buf, accum64, minSIMDSize-1) + JMP(operand.LabelRef("foldAndReturn")) + } + + Label("startSIMD") + + // chains is the number of accumulators to use. This improves speed via + // reduced data dependency. We combine the accumulators once when the big + // loop is complete. 
+ simdAccumulate := make([]reg.VecVirtual, chains) + for i := range simdAccumulate { + switch strategy { + case sse2: + simdAccumulate[i] = XMM() + PXOR(simdAccumulate[i], simdAccumulate[i]) + case avx2: + simdAccumulate[i] = YMM() + VPXOR(simdAccumulate[i], simdAccumulate[i], simdAccumulate[i]) + } + } + var zero reg.VecVirtual + if strategy == sse2 { + zero = XMM() + PXOR(zero, zero) + } + + // Number of loads per big loop + const unroll = 16 + // Number of bytes + loopSize := uint64(simdReadSize * unroll) + if bits.Len64(loopSize) != bits.Len64(loopSize-1)+1 { + panic("loopSize is not a power of 2") + } + loopCount := GP64() + + MOVQ(n, loopCount) + Comment("Update number of bytes remaining after the loop completes") + ANDQ(operand.Imm(loopSize-1), n) + Comment(fmt.Sprintf("Number of %d byte iterations", loopSize)) + SHRQ(operand.Imm(uint64(bits.Len64(loopSize-1))), loopCount) + JZ(operand.LabelRef("smallLoop")) + Label("bigLoop") + for i := 0; i < unroll; i++ { + chain := i % chains + switch strategy { + case sse2: + sse2AccumulateStep(i*simdReadSize, buf, zero, simdAccumulate[chain], simdAccumulate[(chain+chains/2)%chains]) + case avx2: + avx2AccumulateStep(i*simdReadSize, buf, simdAccumulate[chain]) + } + } + ADDQ(operand.U32(loopSize), buf) + DECQ(loopCount) + JNZ(operand.LabelRef("bigLoop")) + + Label("bigCleanup") + + CMPQ(n, operand.Imm(uint64(simdReadSize))) + JLT(operand.LabelRef("doneSmallLoop")) + + Commentf("now read a single %d byte unit of data at a time", simdReadSize) + Label("smallLoop") + + switch strategy { + case sse2: + sse2AccumulateStep(0, buf, zero, simdAccumulate[0], simdAccumulate[1]) + case avx2: + avx2AccumulateStep(0, buf, simdAccumulate[0]) + } + ADDQ(operand.Imm(uint64(simdReadSize)), buf) + SUBQ(operand.Imm(uint64(simdReadSize)), n) + CMPQ(n, operand.Imm(uint64(simdReadSize))) + JGE(operand.LabelRef("smallLoop")) + + Label("doneSmallLoop") + CMPQ(n, operand.Imm(0)) + JE(operand.LabelRef("doneSIMD")) + + Commentf("There are between 
1 and %d bytes remaining. Perform an overlapped read.", simdReadSize-1) + + maskDataPtr := GP64() + LEAQ(operand.NewDataAddr(operand.NewStaticSymbol("xmmLoadMasks"), 0), maskDataPtr) + dataAddr := operand.Mem{Index: n, Scale: 1, Base: buf, Disp: -simdReadSize} + // scale 8 is only correct here because n is guaranteed to be even and we + // do not generate masks for odd lengths + maskAddr := operand.Mem{Base: maskDataPtr, Index: n, Scale: 8, Disp: -16} + remainder := XMM() + + switch strategy { + case sse2: + MOVOU(dataAddr, remainder) + PAND(maskAddr, remainder) + low := XMM() + MOVOA(remainder, low) + PUNPCKHWL(zero, remainder) + PUNPCKLWL(zero, low) + PADDD(remainder, simdAccumulate[0]) + PADDD(low, simdAccumulate[1]) + case avx2: + // Note: this is very similar to the sse2 path but MOVOU has a massive + // performance hit if used here, presumably due to switching between SSE + // and AVX2 modes. + VMOVDQU(dataAddr, remainder) + VPAND(maskAddr, remainder, remainder) + + temp := YMM() + VPMOVZXWD(remainder, temp) + VPADDD(temp, simdAccumulate[0], simdAccumulate[0]) + } + + Label("doneSIMD") + + Comment("Multi-chain loop is done, combine the accumulators") + for i := range simdAccumulate { + if i == 0 { + continue + } + switch strategy { + case sse2: + PADDD(simdAccumulate[i], simdAccumulate[0]) + case avx2: + VPADDD(simdAccumulate[i], simdAccumulate[0], simdAccumulate[0]) + } + } + + if strategy == avx2 { + Comment("extract the YMM into a pair of XMM and sum them") + tmp := YMM() + VEXTRACTI128(operand.Imm(1), simdAccumulate[0], tmp.AsX()) + + xAccumulate := XMM() + VPADDD(simdAccumulate[0].AsX(), tmp.AsX(), xAccumulate) + simdAccumulate = []reg.VecVirtual{xAccumulate} + } + + Comment("extract the XMM into GP64") + low, high := GP64(), GP64() + switch strategy { + case sse2: + MOVQ(simdAccumulate[0], low) + PSRLDQ(operand.Imm(8), simdAccumulate[0]) + MOVQ(simdAccumulate[0], high) + case avx2: + VPEXTRQ(operand.Imm(0), simdAccumulate[0], low) + 
VPEXTRQ(operand.Imm(1), simdAccumulate[0], high) + + Comment("no more AVX code, clear upper registers to avoid SSE slowdowns") + VZEROUPPER() + } + ADDQ(low, accum64) + ADCQ(high, accum64) + Label("foldAndReturn") + foldWithCF(accum64, strategy == avx2) + XCHGB(accum64.As8H(), accum64.As8L()) + Store(accum64.As16(), ReturnIndex(0)) + RET() +} + +// handleOddLength generates instructions to incorporate the last byte into +// accum64 if the length is odd. CF may be set if accum64 overflows; be sure to +// handle that if overflow is possible. +func handleOddLength(n, buf, accum64 reg.GPVirtual) { + Comment("handle odd length buffers; they are difficult to handle in general") + TESTQ(operand.U32(1), n) + JZ(operand.LabelRef("lengthIsEven")) + + tmp := GP64() + MOVBQZX(operand.Mem{Base: buf, Index: n, Scale: 1, Disp: -1}, tmp) + DECQ(n) + ADDQ(tmp, accum64) + + Label("lengthIsEven") +} + +func sse2AccumulateStep(offset int, buf reg.GPVirtual, zero, accumulate1, accumulate2 reg.VecVirtual) { + high, low := XMM(), XMM() + MOVOU(operand.Mem{Disp: offset, Base: buf}, high) + MOVOA(high, low) + PUNPCKHWL(zero, high) + PUNPCKLWL(zero, low) + PADDD(high, accumulate1) + PADDD(low, accumulate2) +} + +func avx2AccumulateStep(offset int, buf reg.GPVirtual, accumulate reg.VecVirtual) { + tmp := YMM() + VPMOVZXWD(operand.Mem{Disp: offset, Base: buf}, tmp) + VPADDD(tmp, accumulate, accumulate) +} + +func generateAMD64Checksum(name, doc string) { + TEXT(name, NOSPLIT|NOFRAME, checksumSignature) + Pragma("noescape") + Doc(doc) + + accum64, buf, n := loadParams() + + handleOddLength(n, buf, accum64) + // no chance of overflow because accum64 was initialized by a uint16 and + // handleOddLength adds at most a uint8 + handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny")) + Label("bufferIsNotTiny") + + const ( + // numChains is the number of accumulators and carry counters to use. + // This improves speed via reduced data dependency. 
We combine the + // accumulators and carry counters once when the loop is complete. + numChains = 4 + unroll = 32 // The number of 64-bit reads to perform per iteration of the loop. + loopSize = 8 * unroll // The number of bytes read per iteration of the loop. + ) + if bits.Len(loopSize) != bits.Len(loopSize-1)+1 { + panic("loopSize is not a power of 2") + } + loopCount := GP64() + + Comment(fmt.Sprintf("Number of %d byte iterations into loop counter", loopSize)) + MOVQ(n, loopCount) + Comment("Update number of bytes remaining after the loop completes") + ANDQ(operand.Imm(loopSize-1), n) + SHRQ(operand.Imm(uint64(bits.Len(loopSize-1))), loopCount) + JZ(operand.LabelRef("startCleanup")) + CLC() + + chains := make([]struct { + accum reg.GPVirtual + carries reg.GPVirtual + }, numChains) + for i := range chains { + if i == 0 { + chains[i].accum = accum64 + } else { + chains[i].accum = GP64() + XORQ(chains[i].accum, chains[i].accum) + } + chains[i].carries = GP64() + XORQ(chains[i].carries, chains[i].carries) + } + + Label("bigLoop") + + var curChain int + for i := 0; i < unroll; i++ { + // It is significantly faster to use a ADCX/ADOX pair instead of plain + // ADC, which results in two dependency chains, however those require + // ADX support, which was added after AVX2. If AVX2 is available, that's + // even better than ADCX/ADOX. + // + // However, multiple dependency chains using multiple accumulators and + // occasionally storing CF into temporary counters seems to work almost + // as well. 
+ addr := operand.Mem{Disp: i * 8, Base: buf} + + if i%4 == 0 { + if i > 0 { + ADCQ(operand.Imm(0), chains[curChain].carries) + curChain = (curChain + 1) % len(chains) + } + ADDQ(addr, chains[curChain].accum) + } else { + ADCQ(addr, chains[curChain].accum) + } + } + ADCQ(operand.Imm(0), chains[curChain].carries) + ADDQ(operand.U32(loopSize), buf) + SUBQ(operand.Imm(1), loopCount) + JNZ(operand.LabelRef("bigLoop")) + for i := range chains { + if i == 0 { + ADDQ(chains[i].carries, accum64) + continue + } + ADCQ(chains[i].accum, accum64) + ADCQ(chains[i].carries, accum64) + } + + accumulateCF(accum64) + + Label("startCleanup") + handleRemaining(n, buf, accum64, loopSize-1) + Label("foldAndReturn") + foldWithCF(accum64, false) + + XCHGB(accum64.As8H(), accum64.As8L()) + Store(accum64.As16(), ReturnIndex(0)) + RET() +} + +// handleTinyBuffers computes checksums if the buffer length (the n parameter) +// is less than 32. After computing the checksum, a jump to returnLabel will +// be executed. Otherwise, if the buffer length is at least 32, nothing will be +// modified; a jump to continueLabel will be executed instead. +// +// When jumping to returnLabel, CF may be set and must be accommodated e.g. +// using foldWithCF or accumulateCF. +// +// Anecdotally, this appears to be faster than attempting to coordinate an +// overlapped read (which would also require special handling for buffers +// smaller than 8). 
+func handleTinyBuffers(n, buf, accum reg.GPVirtual, returnLabel, continueLabel operand.LabelRef) { + Comment("handle tiny buffers (<=31 bytes) specially") + CMPQ(n, operand.Imm(tinyBufferSize)) + JGT(continueLabel) + + tmp2, tmp4, tmp8 := GP64(), GP64(), GP64() + XORQ(tmp2, tmp2) + XORQ(tmp4, tmp4) + XORQ(tmp8, tmp8) + + Comment("shift twice to start because length is guaranteed to be even", + "n = n >> 2; CF = originalN & 2") + SHRQ(operand.Imm(2), n) + JNC(operand.LabelRef("handleTiny4")) + Comment("tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]") + MOVWQZX(operand.Mem{Base: buf}, tmp2) + ADDQ(operand.Imm(2), buf) + + Label("handleTiny4") + Comment("n = n >> 1; CF = originalN & 4") + SHRQ(operand.Imm(1), n) + JNC(operand.LabelRef("handleTiny8")) + Comment("tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]") + MOVLQZX(operand.Mem{Base: buf}, tmp4) + ADDQ(operand.Imm(4), buf) + + Label("handleTiny8") + Comment("n = n >> 1; CF = originalN & 8") + SHRQ(operand.Imm(1), n) + JNC(operand.LabelRef("handleTiny16")) + Comment("tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]") + MOVQ(operand.Mem{Base: buf}, tmp8) + ADDQ(operand.Imm(8), buf) + + Label("handleTiny16") + Comment("n = n >> 1; CF = originalN & 16", + "n == 0 now, otherwise we would have branched after comparing with tinyBufferSize") + SHRQ(operand.Imm(1), n) + JNC(operand.LabelRef("handleTinyFinish")) + ADDQ(operand.Mem{Base: buf}, accum) + ADCQ(operand.Mem{Base: buf, Disp: 8}, accum) + + Label("handleTinyFinish") + Comment("CF should be included from the previous add, so we use ADCQ.", + "If we arrived via the JNC above, then CF=0 due to the branch condition,", + "so ADCQ will still produce the correct result.") + ADCQ(tmp2, accum) + ADCQ(tmp4, accum) + ADCQ(tmp8, accum) + + JMP(returnLabel) +} + +// handleRemaining generates a series of conditional unrolled additions, +// starting with 8 bytes long and doubling each time until the length reaches +// max. 
This is the reverse order of what may be intuitive, but makes the branch +// conditions convenient to compute: perform one right shift each time and test +// against CF. +// +// When done, CF may be set and must be accommodated e.g., using foldWithCF or +// accumulateCF. +// +// If n is not a multiple of 8, an extra 64 bit read at the end of the buffer +// will be performed, overlapping with data that will be read later. The +// duplicate data will be shifted off. +// +// The original buffer length must have been at least 8 bytes long, even if +// n < 8, otherwise this will access memory before the start of the buffer, +// which may be unsafe. +func handleRemaining(n, buf, accum64 reg.GPVirtual, max int) { + Comment("Accumulate carries in this register. It is never expected to overflow.") + carries := GP64() + XORQ(carries, carries) + + Comment("We will perform an overlapped read for buffers with length not a multiple of 8.", + "Overlapped in this context means some memory will be read twice, but a shift will", + "eliminate the duplicated data. 
This extra read is performed at the end of the buffer to", + "preserve any alignment that may exist for the start of the buffer.") + leftover := reg.RCX + MOVQ(n, leftover) + SHRQ(operand.Imm(3), n) // n is now the number of 64 bit reads remaining + ANDQ(operand.Imm(0x7), leftover) // leftover is now the number of bytes to read from the end + JZ(operand.LabelRef("handleRemaining8")) + endBuf := GP64() + // endBuf is the position near the end of the buffer that is just past the + // last multiple of 8: (buf + len(buf)) & ^0x7 + LEAQ(operand.Mem{Base: buf, Index: n, Scale: 8}, endBuf) + + overlapRead := GP64() + // equivalent to overlapRead = binary.LittleEndian.Uint64(buf[len(buf)-8:len(buf)]) + MOVQ(operand.Mem{Base: endBuf, Index: leftover, Scale: 1, Disp: -8}, overlapRead) + + Comment("Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)") + SHLQ(operand.Imm(3), leftover) // leftover = leftover * 8 + NEGQ(leftover) // leftover = -leftover; this completes the (-leftoverBytes*8) part of the expression + ADDQ(operand.Imm(64), leftover) // now we have (64 - leftoverBytes*8) + SHRQ(reg.CL, overlapRead) // shift right by (64 - leftoverBytes*8); CL is the low 8 bits of leftover (set to RCX above) and variable shift only accepts CL + + ADDQ(overlapRead, accum64) + ADCQ(operand.Imm(0), carries) + + for curBytes := 8; curBytes <= max; curBytes *= 2 { + Label(fmt.Sprintf("handleRemaining%d", curBytes)) + SHRQ(operand.Imm(1), n) + if curBytes*2 <= max { + JNC(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2))) + } else { + JNC(operand.LabelRef("handleRemainingComplete")) + } + + numLoads := curBytes / 8 + for i := 0; i < numLoads; i++ { + addr := operand.Mem{Base: buf, Disp: i * 8} + // It is possible to add the multiple dependency chains trick here + // that generateAMD64Checksum uses but anecdotally it does not + // appear to outweigh the cost. 
+ if i == 0 { + ADDQ(addr, accum64) + continue + } + ADCQ(addr, accum64) + } + ADCQ(operand.Imm(0), carries) + + if curBytes > math.MaxUint8 { + ADDQ(operand.U32(uint64(curBytes)), buf) + } else { + ADDQ(operand.U8(uint64(curBytes)), buf) + } + if curBytes*2 >= max { + continue + } + JMP(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2))) + } + Label("handleRemainingComplete") + ADDQ(carries, accum64) +} + +func accumulateCF(accum64 reg.GPVirtual) { + Comment("accumulate CF (twice, in case the first time overflows)") + // accum64 += CF + ADCQ(operand.Imm(0), accum64) + // accum64 += CF again if the previous add overflowed. The previous add was + // 0 or 1. If it overflowed, then accum64 == 0, so adding another 1 can + // never overflow. + ADCQ(operand.Imm(0), accum64) +} + +// foldWithCF generates instructions to fold accum (a GP64) into a 16-bit value +// according to ones-complement arithmetic. BMI2 instructions will be used if +// allowBMI2 is true (requires fewer instructions). +func foldWithCF(accum reg.GPVirtual, allowBMI2 bool) { + Comment("add CF and fold") + + // CF|accum max value starts as 0x1_ffff_ffff_ffff_ffff + + tmp := GP64() + if allowBMI2 { + // effectively, tmp = accum >> 32 (technically, this is a rotate) + RORXQ(operand.Imm(32), accum, tmp) + // accum as uint32 = uint32(accum) + uint32(tmp64) + CF; max value 0xffff_ffff + CF set + ADCL(tmp.As32(), accum.As32()) + // effectively, tmp64 as uint32 = uint32(accum) >> 16 (also a rotate) + RORXL(operand.Imm(16), accum.As32(), tmp.As32()) + // accum as uint16 = uint16(accum) + uint16(tmp) + CF; max value 0xffff + CF unset or 0xfffe + CF set + ADCW(tmp.As16(), accum.As16()) + } else { + // tmp = uint32(accum); max value 0xffff_ffff + // MOVL clears the upper 32 bits of a GP64 so this is equivalent to the + // non-existent MOVLQZX. 
+ MOVL(accum.As32(), tmp.As32()) + // tmp += CF; max value 0x1_0000_0000, CF unset + ADCQ(operand.Imm(0), tmp) + // accum = accum >> 32; max value 0xffff_ffff + SHRQ(operand.Imm(32), accum) + // accum = accum + tmp; max value 0x1_ffff_ffff + CF unset + ADDQ(tmp, accum) + // tmp = uint16(accum); max value 0xffff + MOVWQZX(accum.As16(), tmp) + // accum = accum >> 16; max value 0x1_ffff + SHRQ(operand.Imm(16), accum) + // accum = accum + tmp; max value 0x2_fffe + CF unset + ADDQ(tmp, accum) + // tmp as uint16 = uint16(accum); max value 0xffff + MOVW(accum.As16(), tmp.As16()) + // accum = accum >> 16; max value 0x2 + SHRQ(operand.Imm(16), accum) + // accum as uint16 = uint16(accum) + uint16(tmp); max value 0xffff + CF unset or 0x2 + CF set + ADDW(tmp.As16(), accum.As16()) + } + // accum as uint16 += CF; will not overflow: either CF was 0 or accum <= 0xfffe + ADCW(operand.Imm(0), accum.As16()) +} + +func generateLoadMasks() { + var offset int + // xmmLoadMasks is a table of masks that can be used with PAND to zero all but the last N bytes in an XMM, N=2,4,6,8,10,12,14 + GLOBL("xmmLoadMasks", RODATA|NOPTR) + + for n := 2; n < 16; n += 2 { + var pattern [16]byte + for i := 0; i < len(pattern); i++ { + if i < len(pattern)-n { + pattern[i] = 0 + continue + } + pattern[i] = 0xff + } + DATA(offset, operand.String(pattern[:])) + offset += len(pattern) + } +} + +func main() { + generateLoadMasks() + generateSIMDChecksum("checksumAVX2", "checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)", 256, 4, avx2) + generateSIMDChecksum("checksumSSE2", "checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)", 256, 4, sse2) + generateAMD64Checksum("checksumAMD64", "checksumAMD64 computes an IP checksum using amd64 baseline instructions") + Generate() +} diff --git a/monitor_android.go b/monitor_android.go index 2734c85..1c7e711 100644 --- a/monitor_android.go +++ b/monitor_android.go @@ -51,12 +51,11 @@ func (m *defaultInterfaceMonitor) 
checkUpdate() error { return err } - oldInterface := m.defaultInterface.Load() newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index) if err != nil { return E.Cause(err, "find updated interface: ", link.Attrs().Name) } - m.defaultInterface.Store(newInterface) + oldInterface := m.defaultInterface.Swap(newInterface) if oldInterface != nil && oldInterface.Equals(*newInterface) && oldVPNEnabled == m.androidVPNEnabled { return nil } diff --git a/monitor_darwin.go b/monitor_darwin.go index 88ea90c..f937c37 100644 --- a/monitor_darwin.go +++ b/monitor_darwin.go @@ -165,12 +165,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error { if defaultInterface == nil { return ErrNoRoute } - oldInterface := m.defaultInterface.Load() newInterface, err := m.interfaceFinder.ByIndex(defaultInterface.Index) if err != nil { return E.Cause(err, "find updated interface: ", defaultInterface.Name) } - m.defaultInterface.Store(newInterface) + oldInterface := m.defaultInterface.Swap(newInterface) if oldInterface != nil && oldInterface.Equals(*newInterface) { return nil } diff --git a/monitor_linux.go b/monitor_linux.go index e92f469..86dd28b 100644 --- a/monitor_linux.go +++ b/monitor_linux.go @@ -27,7 +27,7 @@ type networkUpdateMonitor struct { var ErrNetlinkBanned = E.New( "netlink socket in Android is banned by Google, " + "use the root or system (ADB) user to run sing-box, " + - "or switch to the sing-box Adnroid graphical interface client", + "or switch to the sing-box Android graphical interface client", ) func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) { diff --git a/monitor_linux_default.go b/monitor_linux_default.go index e9cce1d..72ba1be 100644 --- a/monitor_linux_default.go +++ b/monitor_linux_default.go @@ -25,12 +25,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error { return err } - oldInterface := m.defaultInterface.Load() newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index) if err != nil { return 
E.Cause(err, "find updated interface: ", link.Attrs().Name) } - m.defaultInterface.Store(newInterface) + oldInterface := m.defaultInterface.Swap(newInterface) if oldInterface != nil && oldInterface.Equals(*newInterface) { return nil } diff --git a/monitor_windows.go b/monitor_windows.go index d58e701..18a795f 100644 --- a/monitor_windows.go +++ b/monitor_windows.go @@ -102,13 +102,12 @@ func (m *defaultInterfaceMonitor) checkUpdate() error { return ErrNoRoute } - oldInterface := m.defaultInterface.Load() newInterface, err := m.interfaceFinder.ByIndex(index) if err != nil { return E.Cause(err, "find updated interface: ", alias) } - m.defaultInterface.Store(newInterface) - if oldInterface != nil && !oldInterface.Equals(*newInterface) { + oldInterface := m.defaultInterface.Swap(newInterface) + if oldInterface != nil && oldInterface.Equals(*newInterface) { return nil } m.emit(newInterface, 0) diff --git a/redirect_linux.go b/redirect_linux.go index 1645b85..113d6f1 100644 --- a/redirect_linux.go +++ b/redirect_linux.go @@ -44,7 +44,7 @@ type autoRedirect struct { } func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { - r := &autoRedirect{ + return &autoRedirect{ tunOptions: options.TunOptions, ctx: options.Context, handler: options.Handler, @@ -56,7 +56,10 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { customRedirectPortFunc: options.CustomRedirectPort, routeAddressSet: options.RouteAddressSet, routeExcludeAddressSet: options.RouteExcludeAddressSet, - } + }, nil +} + +func (r *autoRedirect) Start() error { var err error if runtime.GOOS == "android" { r.enableIPv4 = true @@ -74,7 +77,7 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { } } if err != nil { - return nil, E.Extend(E.Cause(err, "root permission is required for auto redirect"), os.Getenv("PATH")) + return E.Extend(E.Cause(err, "root permission is required for auto redirect"), os.Getenv("PATH")) } } } else { @@ -90,7 +93,7 @@ func 
NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { if !r.useNFTables { r.iptablesPath, err = exec.LookPath("iptables") if err != nil { - return nil, E.Cause(err, "iptables is required") + return E.Cause(err, "iptables is required") } } } @@ -100,7 +103,7 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { r.ip6tablesPath, err = exec.LookPath("ip6tables") if err != nil { if !r.enableIPv4 { - return nil, E.Cause(err, "ip6tables is required") + return E.Cause(err, "ip6tables is required") } else { r.enableIPv6 = false r.logger.Error("device has no ip6tables nat support: ", err) @@ -109,10 +112,6 @@ func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { } } } - return r, nil -} - -func (r *autoRedirect) Start() error { if r.customRedirectPortFunc != nil { r.customRedirectPort = r.customRedirectPortFunc() } @@ -132,7 +131,6 @@ func (r *autoRedirect) Start() error { } r.redirectServer = server } - var err error if r.useNFTables { r.cleanupNFTables() err = r.setupNFTables() diff --git a/redirect_nftables.go b/redirect_nftables.go index be86114..02a9ca8 100644 --- a/redirect_nftables.go +++ b/redirect_nftables.go @@ -32,6 +32,10 @@ func (r *autoRedirect) setupNFTables() error { return err } + err = r.interfaceFinder.Update() + if err != nil { + return err + } r.localAddresses = common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix { return common.Filter(it.Addresses, func(prefix netip.Prefix) bool { return it.Name == "lo" || prefix.Addr().IsGlobalUnicast() diff --git a/stack_gvisor.go b/stack_gvisor.go index 83ca9e6..65bb7bd 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -19,13 +19,11 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" ) const WithGVisor = true -const defaultNIC 
tcpip.NICID = 1 +const DefaultNIC tcpip.NICID = 1 type GVisor struct { ctx context.Context @@ -68,28 +66,11 @@ func (t *GVisor) Start() error { return err } linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun} - ipStack, err := newGVisorStack(linkEndpoint) + ipStack, err := NewGVisorStack(linkEndpoint) if err != nil { return err } - tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) { - source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) - destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) - pErr := t.handler.PrepareConnection(N.NetworkTCP, source, destination) - if pErr != nil { - r.Complete(pErr != ErrDrop) - return - } - conn := &gLazyConn{ - parentCtx: t.ctx, - stack: t.stack, - request: r, - localAddr: source.TCPAddr(), - remoteAddr: destination.TCPAddr(), - } - go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) - }) - ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) t.stack = ipStack t.endpoint = linkEndpoint @@ -124,7 +105,7 @@ func AddrFromAddress(address tcpip.Address) netip.Addr { } } -func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { +func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { ipStack := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -137,19 +118,19 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { icmp.NewProtocol6, }, }) - err := ipStack.CreateNIC(defaultNIC, ep) + err := ipStack.CreateNIC(DefaultNIC, ep) if err != nil { return nil, gonet.TranslateNetstackError(err) } ipStack.SetRouteTable([]tcpip.Route{ - {Destination: 
header.IPv4EmptySubnet, NIC: defaultNIC}, - {Destination: header.IPv6EmptySubnet, NIC: defaultNIC}, + {Destination: header.IPv4EmptySubnet, NIC: DefaultNIC}, + {Destination: header.IPv6EmptySubnet, NIC: DefaultNIC}, }) - err = ipStack.SetSpoofing(defaultNIC, true) + err = ipStack.SetSpoofing(DefaultNIC, true) if err != nil { return nil, gonet.TranslateNetstackError(err) } - err = ipStack.SetPromiscuousMode(defaultNIC, true) + err = ipStack.SetPromiscuousMode(DefaultNIC, true) if err != nil { return nil, gonet.TranslateNetstackError(err) } diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go new file mode 100644 index 0000000..33cf40e --- /dev/null +++ b/stack_gvisor_tcp.go @@ -0,0 +1,51 @@ +//go:build with_gvisor + +package tun + +import ( + "context" + + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type TCPForwarder struct { + ctx context.Context + stack *stack.Stack + handler Handler + forwarder *tcp.Forwarder +} + +func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder { + forwarder := &TCPForwarder{ + ctx: ctx, + stack: stack, + handler: handler, + } + forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward) + return forwarder +} + +func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + return f.forwarder.HandlePacket(id, pkt) +} + +func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) { + source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) + destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) + pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination) + if pErr != nil { + r.Complete(pErr != ErrDrop) + return + } + conn := &gLazyConn{ + parentCtx: f.ctx, + stack: f.stack, + request: r, + localAddr: source.TCPAddr(), + remoteAddr: 
destination.TCPAddr(), + } + go f.handler.NewConnectionEx(f.ctx, conn, source, destination, nil) +} diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 150fd1a..3027798 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -123,7 +123,7 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock defer packetBuffer.Release() route, err := w.stack.FindRoute( - defaultNIC, + DefaultNIC, AddressFromAddr(destination.Addr), w.source, w.sourceNetwork, diff --git a/stack_mixed.go b/stack_mixed.go index be146e3..9293fb8 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -38,7 +38,7 @@ func (m *Mixed) Start() error { return err } endpoint := channel.New(1024, uint32(m.mtu), "") - ipStack, err := newGVisorStack(endpoint) + ipStack, err := NewGVisorStack(endpoint) if err != nil { return err } @@ -50,6 +50,18 @@ func (m *Mixed) Start() error { return nil } +func (m *Mixed) Close() error { + if m.stack == nil { + return nil + } + m.endpoint.Attach(nil) + m.stack.Close() + for _, endpoint := range m.stack.CleanupEndpoints() { + endpoint.Abort() + } + return m.System.Close() +} + func (m *Mixed) tunLoop() { if winTun, isWinTun := m.tun.(WinTun); isWinTun { m.wintunLoop(winTun) @@ -137,7 +149,7 @@ func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) { } } if len(writeBuffers) > 0 { - err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom) + _, err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom) if err != nil { m.logger.Trace(E.Cause(err, "batch write packet")) } @@ -151,10 +163,10 @@ func (m *Mixed) processPacket(packet []byte) bool { writeBack bool err error ) - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: + switch ipVersion := header.IPVersion(packet); ipVersion { + case header.IPv4Version: writeBack, err = m.processIPv4(packet) - case 6: + case header.IPv6Version: writeBack, err = m.processIPv6(packet) default: err = E.New("ip: unknown version: ", ipVersion) @@ -222,15 +234,3 @@ func (m *Mixed) packetLoop() { 
packet.DecRef() } } - -func (m *Mixed) Close() error { - if m.stack == nil { - return nil - } - m.endpoint.Attach(nil) - m.stack.Close() - for _, endpoint := range m.stack.CleanupEndpoints() { - endpoint.Abort() - } - return m.System.Close() -} diff --git a/stack_system.go b/stack_system.go index 39aead0..e48b8b6 100644 --- a/stack_system.go +++ b/stack_system.go @@ -244,7 +244,7 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) { } } if len(writeBuffers) > 0 { - err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom) + _, err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom) if err != nil { s.logger.Trace(E.Cause(err, "batch write packet")) } @@ -419,7 +419,7 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version) } else { newPacket.Advance(-s.frontHeadroom) } @@ -502,7 +502,7 @@ func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) erro tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize))) } if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version) } else { newPacket.Advance(-s.frontHeadroom) } @@ -586,7 +586,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error sourceAddress := ipHdr.SourceAddr() ipHdr.SetSourceAddr(ipHdr.DestinationAddr()) ipHdr.SetDestinationAddr(sourceAddress) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, checksum.Checksum(icmpHdr.Payload(), 0))) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) ipHdr.SetChecksum(0) 
ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) return nil @@ -684,7 +684,7 @@ func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) e })) copy(icmpHdr.Payload(), payload) if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version) } else { newPacket.Advance(-s.frontHeadroom) } @@ -724,7 +724,7 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version) } else { newPacket.Advance(-w.frontHeadroom) } @@ -763,7 +763,7 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetChecksum(0) } if PacketOffset > 0 { - newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 + PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version) } else { newPacket.Advance(-w.frontHeadroom) } diff --git a/stack_system_packet.go b/stack_system_packet.go new file mode 100644 index 0000000..b5060b0 --- /dev/null +++ b/stack_system_packet.go @@ -0,0 +1,34 @@ +package tun + +import ( + "net/netip" + "syscall" + + "github.com/sagernet/sing-tun/internal/gtcpip/header" +) + +func PacketIPVersion(packet []byte) int { + return header.IPVersion(packet) +} + +func PacketFillHeader(packet []byte, ipVersion int) { + if PacketOffset > 0 { + switch ipVersion { + case header.IPv4Version: + packet[3] = syscall.AF_INET + case header.IPv6Version: + packet[3] = syscall.AF_INET6 + } + } +} + +func PacketDestination(packet []byte) netip.Addr { + switch ipVersion := header.IPVersion(packet); ipVersion { + case header.IPv4Version: + return header.IPv4(packet).DestinationAddr() + case header.IPv6Version: + return header.IPv6(packet).DestinationAddr() + default: + return netip.Addr{} + } +} 
diff --git a/tun.go b/tun.go index d1738e8..b0f573a 100644 --- a/tun.go +++ b/tun.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" + "github.com/sagernet/sing/common/control" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -24,8 +25,10 @@ type Handler interface { type Tun interface { io.ReadWriter N.VectorisedWriter + Name() (string, error) Start() error Close() error + UpdateRouteOptions(tunOptions Options) error } type WinTun interface { @@ -38,7 +41,7 @@ type LinuxTUN interface { N.FrontHeadroom BatchSize() int BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) - BatchWrite(buffers [][]byte, offset int) error + BatchWrite(buffers [][]byte, offset int) (n int, err error) TXChecksumOffload() bool } @@ -54,6 +57,7 @@ type Options struct { MTU uint32 GSO bool AutoRoute bool + InterfaceScope bool Inet4Gateway netip.Addr Inet6Gateway netip.Addr DNSServers []netip.Addr @@ -74,6 +78,7 @@ type Options struct { IncludeAndroidUser []int IncludePackage []string ExcludePackage []string + InterfaceFinder control.InterfaceFinder InterfaceMonitor DefaultInterfaceMonitor FileDescriptor int Logger logger.Logger @@ -99,10 +104,12 @@ func (o *Options) Inet4GatewayAddr() netip.Addr { case "darwin": return o.Inet4Address[0].Addr() default: - if HasNextAddress(o.Inet4Address[0], 1) { - return o.Inet4Address[0].Addr().Next() - } else { - return o.Inet4Address[0].Addr() + if !o.InterfaceScope { + if HasNextAddress(o.Inet4Address[0], 1) { + return o.Inet4Address[0].Addr().Next() + } else { + return o.Inet4Address[0].Addr() + } } } } @@ -123,10 +130,12 @@ func (o *Options) Inet6GatewayAddr() netip.Addr { case "darwin": return o.Inet6Address[0].Addr() default: - if HasNextAddress(o.Inet6Address[0], 1) { - return o.Inet6Address[0].Addr().Next() - } else { - return o.Inet6Address[0].Addr() + if !o.InterfaceScope { + if HasNextAddress(o.Inet6Address[0], 1) { + return 
o.Inet6Address[0].Addr().Next() + } else { + return o.Inet6Address[0].Addr() + } } } } diff --git a/tun_darwin.go b/tun_darwin.go index 3b7a47e..a0dd54a 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -9,6 +9,7 @@ import ( "syscall" "unsafe" + "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -28,7 +29,15 @@ type NativeTun struct { options Options inet4Address [4]byte inet6Address [16]byte - routerSet bool + routeSet bool +} + +func (t *NativeTun) Name() (string, error) { + return unix.GetsockoptString( + int(t.tunFile.Fd()), + 2, /* #define SYSPROTO_CONTROL 2 */ + 2, /* #define UTUN_OPT_IFNAME 2 */ + ) } func New(options Options) (Tun, error) { @@ -96,9 +105,10 @@ var ( func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { var packetHeader []byte - if buffers[0].Byte(0)>>4 == 4 { + switch header.IPVersion(buffers[0].Bytes()) { + case header.IPv4Version: packetHeader = packetHeader4[:] - } else { + case header.IPv6Version: packetHeader = packetHeader6[:] } return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...)) @@ -248,44 +258,63 @@ func configure(tunFd int, ifIndex int, name string, options Options) error { return nil } +func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error { + err := t.unsetRoutes() + if err != nil { + return err + } + t.options = tunOptions + return t.setRoutes() +} + func (t *NativeTun) setRoutes() error { - if t.options.AutoRoute && t.options.FileDescriptor == 0 { + if t.options.FileDescriptor == 0 { routeRanges, err := t.options.BuildAutoRouteRanges(false) if err != nil { return err } - gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr() - for _, destination := range routeRanges { - var gateway netip.Addr - if destination.Addr().Is4() { - gateway = gateway4 - } else { - gateway = gateway6 - } - err = execRoute(unix.RTM_ADD, destination, 
gateway) - if err != nil { - if errors.Is(err, unix.EEXIST) { - err = execRoute(unix.RTM_DELETE, destination, gateway) - if err != nil { - return E.Cause(err, "remove existing route: ", destination) - } - err = execRoute(unix.RTM_ADD, destination, gateway) - if err != nil { - return E.Cause(err, "re-add route: ", destination) - } + if len(routeRanges) > 0 { + gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr() + for _, destination := range routeRanges { + var gateway netip.Addr + if destination.Addr().Is4() { + gateway = gateway4 } else { - return E.Cause(err, "add route: ", destination) + gateway = gateway6 + } + var interfaceIndex int + if t.options.InterfaceScope { + iff, err := t.options.InterfaceFinder.ByName(t.options.Name) + if err != nil { + return err + } + interfaceIndex = iff.Index + } + err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway) + if err != nil { + if errors.Is(err, unix.EEXIST) { + err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway) + if err != nil { + return E.Cause(err, "remove existing route: ", destination) + } + err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway) + if err != nil { + return E.Cause(err, "re-add route: ", destination) + } + } else { + return E.Cause(err, "add route: ", destination) + } } } + flushDNSCache() + t.routeSet = true } - flushDNSCache() - t.routerSet = true } return nil } func (t *NativeTun) unsetRoutes() error { - if !t.routerSet { + if !t.routeSet { return nil } routeRanges, err := t.options.BuildAutoRouteRanges(false) @@ -300,7 +329,7 @@ func (t *NativeTun) unsetRoutes() error { } else { gateway = gateway6 } - err = execRoute(unix.RTM_DELETE, destination, gateway) + err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway) if err != nil { err = E.Errors(err, E.Cause(err, "delete route: ", destination)) } @@ -317,7 +346,7 @@ func useSocket(domain, typ, proto int, block 
func(socketFd int) error) error { return block(socketFd) } -func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error { +func execRoute(rtmType int, interfaceScope bool, interfaceIndex int, destination netip.Prefix, gateway netip.Addr) error { routeMessage := route.RouteMessage{ Type: rtmType, Version: unix.RTM_VERSION, @@ -326,6 +355,10 @@ func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error } if rtmType == unix.RTM_ADD { routeMessage.Flags |= unix.RTF_UP + if interfaceScope { + routeMessage.Flags |= unix.RTF_IFSCOPE + routeMessage.Index = interfaceIndex + } } if gateway.Is4() { routeMessage.Addrs = []route.Addr{ diff --git a/tun_linux.go b/tun_linux.go index 09c9262..72aac6a 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -2,6 +2,7 @@ package tun import ( "errors" + "fmt" "math/rand" "net" "net/netip" @@ -13,6 +14,8 @@ import ( "unsafe" "github.com/sagernet/netlink" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -35,13 +38,15 @@ type NativeTun struct { interfaceCallback *list.Element[DefaultInterfaceUpdateCallback] options Options ruleIndex6 []int - gsoEnabled bool - gsoBuffer []byte + readAccess sync.Mutex + writeAccess sync.Mutex + vnetHdr bool + writeBuffer []byte gsoToWrite []int - gsoReadAccess sync.Mutex - tcpGROAccess sync.Mutex - tcp4GROTable *tcpGROTable - tcp6GROTable *tcpGROTable + tcpGROTable *tcpGROTable + udpGroAccess sync.Mutex + udpGROTable *udpGROTable + gro groDisablementFlags txChecksumOffload bool } @@ -80,105 +85,6 @@ func New(options Options) (Tun, error) { return nativeTun, nil } -func (t *NativeTun) FrontHeadroom() int { - if t.gsoEnabled { - return virtioNetHdrLen - } - return 0 -} - -func (t *NativeTun) Read(p []byte) (n int, err error) { - if t.gsoEnabled { - n, err = t.tunFile.Read(t.gsoBuffer) - if err != 
nil { - return - } - var sizes [1]int - n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0) - if err != nil { - return - } - if n == 0 { - return - } - n = sizes[0] - return - } else { - return t.tunFile.Read(p) - } -} - -func (t *NativeTun) Write(p []byte) (n int, err error) { - if t.gsoEnabled { - err = t.BatchWrite([][]byte{p}, virtioNetHdrLen) - if err != nil { - return - } - n = len(p) - return - } - return t.tunFile.Write(p) -} - -func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { - if t.gsoEnabled { - n := buf.LenMulti(buffers) - buffer := buf.NewSize(virtioNetHdrLen + n) - buffer.Truncate(virtioNetHdrLen) - buf.CopyMulti(buffer.Extend(n), buffers) - _, err := t.tunFile.Write(buffer.Bytes()) - buffer.Release() - return err - } else { - return t.tunWriter.WriteVectorised(buffers) - } -} - -func (t *NativeTun) BatchSize() int { - if !t.gsoEnabled { - return 1 - } - /* // Not works on some devices: https://github.com/SagerNet/sing-box/issues/1605 - batchSize := int(gsoMaxSize/t.options.MTU) * 2 - if batchSize > idealBatchSize { - batchSize = idealBatchSize - } - return batchSize*/ - return idealBatchSize -} - -func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) { - t.gsoReadAccess.Lock() - defer t.gsoReadAccess.Unlock() - n, err = t.tunFile.Read(t.gsoBuffer) - if err != nil { - return - } - return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset) -} - -func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error { - t.tcpGROAccess.Lock() - defer func() { - t.tcp4GROTable.reset() - t.tcp6GROTable.reset() - t.tcpGROAccess.Unlock() - }() - t.gsoToWrite = t.gsoToWrite[:0] - err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite) - if err != nil { - return err - } - offset -= virtioNetHdrLen - for _, bufferIndex := range t.gsoToWrite { - _, err = t.tunFile.Write(buffers[bufferIndex][offset:]) - if err != nil { - return err - } - } - return nil -} - var 
controlPath string func init() { @@ -196,29 +102,26 @@ func open(name string, vnetHdr bool) (int, error) { if err != nil { return -1, err } - - var ifr struct { - name [16]byte - flags uint16 - _ [22]byte + ifr, err := unix.NewIfreq(name) + if err != nil { + unix.Close(fd) + return 0, err } - - copy(ifr.name[:], name) - ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI + flags := unix.IFF_TUN | unix.IFF_NO_PI if vnetHdr { - ifr.flags |= unix.IFF_VNET_HDR + flags |= unix.IFF_VNET_HDR } - _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.TUNSETIFF, uintptr(unsafe.Pointer(&ifr))) - if errno != 0 { + ifr.SetUint16(uint16(flags)) + err = unix.IoctlIfreq(fd, unix.TUNSETIFF, ifr) + if err != nil { unix.Close(fd) - return -1, errno + return 0, err } - - if err = unix.SetNonblock(fd, true); err != nil { + err = unix.SetNonblock(fd, true) + if err != nil { unix.Close(fd) - return -1, err + return 0, err } - return fd, nil } @@ -250,22 +153,10 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { } if t.options.GSO { - var vnetHdrEnabled bool - vnetHdrEnabled, err = checkVNETHDREnabled(t.tunFd, t.options.Name) + err = t.enableGSO() if err != nil { - return E.Cause(err, "enable offload: check IFF_VNET_HDR enabled") + t.options.Logger.Warn(err) } - if !vnetHdrEnabled { - return E.Cause(err, "enable offload: IFF_VNET_HDR not enabled") - } - err = setTCPOffload(t.tunFd) - if err != nil { - return err - } - t.gsoEnabled = true - t.gsoBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize)) - t.tcp4GROTable = newTCPGROTable() - t.tcp6GROTable = newTCPGROTable() } var rxChecksumOffload bool @@ -280,7 +171,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { if err != nil { return err } - if err == nil && !txChecksumOffload { + if !txChecksumOffload { err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM) if err != nil { return err @@ -292,6 +183,83 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { return nil } +func (t *NativeTun) 
enableGSO() error { + vnetHdrEnabled, err := checkVNETHDREnabled(t.tunFd, t.options.Name) + if err != nil { + return E.Cause(err, "enable offload: check IFF_VNET_HDR enabled") + } + if !vnetHdrEnabled { + return E.Cause(err, "enable offload: IFF_VNET_HDR not enabled") + } + err = setTCPOffload(t.tunFd) + if err != nil { + return E.Cause(err, "enable TCP offload") + } + t.vnetHdr = true + t.writeBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize)) + t.tcpGROTable = newTCPGROTable() + t.udpGROTable = newUDPGROTable() + err = setUDPOffload(t.tunFd) + if err != nil { + t.gro.disableUDPGRO() + } + return nil +} + +func (t *NativeTun) probeTCPGRO() error { + ipPort := netip.AddrPortFrom(t.options.Inet4Address[0].Addr(), 0) + fingerprint := []byte("sing-tun-probe-tun-gro") + segmentSize := len(fingerprint) + iphLen := 20 + tcphLen := 20 + totalLen := iphLen + tcphLen + segmentSize + bufs := make([][]byte, 2) + for i := range bufs { + bufs[i] = make([]byte, virtioNetHdrLen+totalLen, virtioNetHdrLen+(totalLen*2)) + ipv4H := header.IPv4(bufs[i][virtioNetHdrLen:]) + ipv4H.Encode(&header.IPv4Fields{ + SrcAddr: ipPort.Addr(), + DstAddr: ipPort.Addr(), + Protocol: unix.IPPROTO_TCP, + // Use a zero value TTL as best effort means to reduce chance of + // probe packet leaking further than it needs to. 
+ TTL: 0, + TotalLength: uint16(totalLen), + }) + tcpH := header.TCP(bufs[i][virtioNetHdrLen+iphLen:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: ipPort.Port(), + DstPort: ipPort.Port(), + SeqNum: 1 + uint32(i*segmentSize), + AckNum: 1, + DataOffset: 20, + Flags: header.TCPFlagAck, + WindowSize: 3000, + }) + copy(bufs[i][virtioNetHdrLen+iphLen+tcphLen:], fingerprint) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddressSlice(), ipv4H.DestinationAddressSlice(), uint16(tcphLen+segmentSize)) + pseudoCsum = checksum.Checksum(bufs[i][virtioNetHdrLen+iphLen+tcphLen:], pseudoCsum) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + } + _, err := t.BatchWrite(bufs, virtioNetHdrLen) + return err +} + +func (t *NativeTun) Name() (string, error) { + var ifr [unix.IFNAMSIZ + 64]byte + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(t.tunFd), + uintptr(unix.TUNGETIFF), + uintptr(unsafe.Pointer(&ifr[0])), + ) + if errno != 0 { + return "", os.NewSyscallError("ioctl TUNGETIFF", errno) + } + return unix.ByteSliceToString(ifr[:]), nil +} + func (t *NativeTun) Start() error { if t.options.FileDescriptor != 0 { return nil @@ -307,6 +275,15 @@ func (t *NativeTun) Start() error { return err } + if t.vnetHdr && len(t.options.Inet4Address) > 0 { + err = t.probeTCPGRO() + if err != nil { + t.gro.disableTCPGRO() + t.gro.disableUDPGRO() + t.options.Logger.Warn(E.Cause(err, "disabled TUN TCP & UDP GRO due to GRO probe error")) + } + } + if t.options.IPRoute2TableIndex == 0 { for { t.options.IPRoute2TableIndex = int(rand.Uint32()) @@ -348,6 +325,164 @@ func (t *NativeTun) Close() error { return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile))) } +func (t *NativeTun) Read(p []byte) (n int, err error) { + if t.vnetHdr { + n, err = t.tunFile.Read(t.writeBuffer) + if err != nil { + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + return + } + var sizes [1]int 
+ n, err = handleVirtioRead(t.writeBuffer[:n], [][]byte{p}, sizes[:], 0) + if err != nil { + return + } + if n == 0 { + return + } + n = sizes[0] + return + } else { + return t.tunFile.Read(p) + } +} + +// handleVirtioRead splits in into bufs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of bufs, +// and returns the number of packets read. +func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { + var hdr virtioNetHdr + err := hdr.decode(in) + if err != nil { + return 0, err + } + in = in[virtioNetHdrLen:] + + options, err := hdr.toGSOOptions() + if err != nil { + return 0, err + } + + // Don't trust HdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the transport header length and add it onto + // CsumStart, which is synonymous for IP header length. + if options.GSOType == GSOUDPL4 { + options.HdrLen = options.CsumStart + 8 + } else if options.GSOType != GSONone { + if len(in) <= int(options.CsumStart+12) { + return 0, errors.New("packet is too short") + } + + tcpHLen := uint16(in[options.CsumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. 
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + options.HdrLen = options.CsumStart + tcpHLen + } + + return GSOSplit(in, options, bufs, sizes, offset) +} + +func (t *NativeTun) Write(p []byte) (n int, err error) { + if t.vnetHdr { + buffer := buf.Get(virtioNetHdrLen + len(p)) + copy(buffer[virtioNetHdrLen:], p) + _, err = t.BatchWrite([][]byte{buffer}, virtioNetHdrLen) + buf.Put(buffer) + if err != nil { + return + } + n = len(p) + return + } + return t.tunFile.Write(p) +} + +func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { + if t.vnetHdr { + n := buf.LenMulti(buffers) + buffer := buf.NewSize(virtioNetHdrLen + n) + buffer.Truncate(virtioNetHdrLen) + buf.CopyMulti(buffer.Extend(n), buffers) + _, err := t.tunFile.Write(buffer.Bytes()) + buffer.Release() + return err + } else { + return t.tunWriter.WriteVectorised(buffers) + } +} + +func (t *NativeTun) FrontHeadroom() int { + if t.vnetHdr { + return virtioNetHdrLen + } + return 0 +} + +func (t *NativeTun) BatchSize() int { + if !t.vnetHdr { + return 1 + } + /* // Not works on some devices: https://github.com/SagerNet/sing-box/issues/1605 + batchSize := int(gsoMaxSize/t.options.MTU) * 2 + if batchSize > idealBatchSize { + batchSize = idealBatchSize + } + return batchSize*/ + return idealBatchSize +} + +func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) { + t.readAccess.Lock() + defer t.readAccess.Unlock() + n, err = t.tunFile.Read(t.writeBuffer) + if err != nil { + return + } + return handleVirtioRead(t.writeBuffer[:n], buffers, readN, offset) +} + +func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) (int, error) { + t.writeAccess.Lock() + defer func() { + t.tcpGROTable.reset() + t.udpGROTable.reset() + t.writeAccess.Unlock() + }() + var ( + errs error + total int + ) + t.gsoToWrite = t.gsoToWrite[:0] + if t.vnetHdr { + err := handleGRO(buffers, offset, t.tcpGROTable, t.udpGROTable, t.gro, &t.gsoToWrite) + if err != nil { + 
return 0, err + } + offset -= virtioNetHdrLen + } else { + for i := range buffers { + t.gsoToWrite = append(t.gsoToWrite, i) + } + } + for _, toWrite := range t.gsoToWrite { + n, err := t.tunFile.Write(buffers[toWrite][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = errors.Join(errs, err) + } else { + total += n + } + } + return total, errs +} + func (t *NativeTun) TXChecksumOffload() bool { return t.txChecksumOffload } @@ -359,6 +494,25 @@ func prefixToIPNet(prefix netip.Prefix) *net.IPNet { } } +func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error { + if t.options.FileDescriptor > 0 { + return nil + } else if !t.options.AutoRoute { + t.options = tunOptions + return nil + } + tunLink, err := netlink.LinkByName(t.options.Name) + if err != nil { + return err + } + err = t.unsetRoute0(tunLink) + if err != nil { + return err + } + t.options = tunOptions + return t.setRoute(tunLink) +} + func (t *NativeTun) routes(tunLink netlink.Link) ([]netlink.Route, error) { routeRanges, err := t.options.BuildAutoRouteRanges(false) if err != nil { diff --git a/tun_linux_flags.go b/tun_linux_flags.go index 1b84baa..53fff08 100644 --- a/tun_linux_flags.go +++ b/tun_linux_flags.go @@ -12,6 +12,12 @@ import ( "golang.org/x/sys/unix" ) +const ( + // TODO: support TSO with ECN bits + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 +) + func checkVNETHDREnabled(fd int, name string) (bool, error) { ifr, err := unix.NewIfreq(name) if err != nil { @@ -25,17 +31,17 @@ func checkVNETHDREnabled(fd int, name string) (bool, error) { } func setTCPOffload(fd int) error { - const ( - // TODO: support TSO with ECN bits - tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 - ) - err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunOffloads) + err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads) if err != nil { return 
E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload") } return nil } +func setUDPOffload(fd int) error { + return unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) +} + type ifreqData struct { ifrName [unix.IFNAMSIZ]byte ifrData uintptr diff --git a/tun_linux_gvisor.go b/tun_linux_gvisor.go index 1edeab1..f82d762 100644 --- a/tun_linux_gvisor.go +++ b/tun_linux_gvisor.go @@ -10,11 +10,12 @@ import ( var _ GVisorTun = (*NativeTun)(nil) func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { - if t.gsoEnabled { + if t.vnetHdr { return fdbased.New(&fdbased.Options{ FDs: []int{t.tunFd}, MTU: t.options.MTU, GSOMaxSize: gsoMaxSize, + GRO: true, RXChecksumOffload: true, TXChecksumOffload: t.txChecksumOffload, }) diff --git a/tun_linux_offload.go b/tun_linux_offload.go deleted file mode 100644 index 488d3ff..0000000 --- a/tun_linux_offload.go +++ /dev/null @@ -1,768 +0,0 @@ -//go:build linux - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package tun - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "unsafe" - - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - E "github.com/sagernet/sing/common/exceptions" - - "golang.org/x/sys/unix" -) - -const ( - gsoMaxSize = 65536 - tcpFlagsOffset = 13 - idealBatchSize = 128 -) - -const ( - tcpFlagFIN uint8 = 0x01 - tcpFlagPSH uint8 = 0x08 - tcpFlagACK uint8 = 0x10 -) - -// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The -// kernel symbol is virtio_net_hdr. 
-type virtioNetHdr struct { - flags uint8 - gsoType uint8 - hdrLen uint16 - gsoSize uint16 - csumStart uint16 - csumOffset uint16 -} - -func (v *virtioNetHdr) decode(b []byte) error { - if len(b) < virtioNetHdrLen { - return io.ErrShortBuffer - } - copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) - return nil -} - -func (v *virtioNetHdr) encode(b []byte) error { - if len(b) < virtioNetHdrLen { - return io.ErrShortBuffer - } - copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) - return nil -} - -const ( - // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the - // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). - virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) -) - -// flowKey represents the key for a flow. -type flowKey struct { - srcAddr, dstAddr [16]byte - srcPort, dstPort uint16 - rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. -} - -// tcpGROTable holds flow and coalescing information for the purposes of GRO. 
-type tcpGROTable struct { - itemsByFlow map[flowKey][]tcpGROItem - itemsPool [][]tcpGROItem -} - -func newTCPGROTable() *tcpGROTable { - t := &tcpGROTable{ - itemsByFlow: make(map[flowKey][]tcpGROItem, idealBatchSize), - itemsPool: make([][]tcpGROItem, idealBatchSize), - } - for i := range t.itemsPool { - t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize) - } - return t -} - -func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { - key := flowKey{} - addrSize := dstAddr - srcAddr - copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) - copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) - key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) - key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) - key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) - return key -} - -// lookupOrInsert looks up a flow for the provided packet and metadata, -// returning the packets found for the flow, or inserting a new one if none -// is found. -func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) - items, ok := t.itemsByFlow[key] - if ok { - return items, ok - } - // TODO: insert() performs another map lookup. This could be rearranged to avoid. - t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) - return nil, false -} - -// insert an item in the table for the provided packet and packet metadata. 
-func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) - item := tcpGROItem{ - key: key, - bufsIndex: uint16(bufsIndex), - gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), - iphLen: uint8(tcphOffset), - tcphLen: uint8(tcphLen), - sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), - pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, - } - items, ok := t.itemsByFlow[key] - if !ok { - items = t.newItems() - } - items = append(items, item) - t.itemsByFlow[key] = items -} - -func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { - items, _ := t.itemsByFlow[item.key] - items[i] = item -} - -func (t *tcpGROTable) deleteAt(key flowKey, i int) { - items, _ := t.itemsByFlow[key] - items = append(items[:i], items[i+1:]...) - t.itemsByFlow[key] = items -} - -// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime -// of a GRO evaluation across a vector of packets. -type tcpGROItem struct { - key flowKey - sentSeq uint32 // the sequence number - bufsIndex uint16 // the index into the original bufs slice - numMerged uint16 // the number of packets merged into this item - gsoSize uint16 // payload size - iphLen uint8 // ip header len - tcphLen uint8 // tcp header len - pshSet bool // psh flag is set -} - -func (t *tcpGROTable) newItems() []tcpGROItem { - var items []tcpGROItem - items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] - return items -} - -func (t *tcpGROTable) reset() { - for k, items := range t.itemsByFlow { - items = items[:0] - t.itemsPool = append(t.itemsPool, items) - delete(t.itemsByFlow, k) - } -} - -// canCoalesce represents the outcome of checking if two TCP packets are -// candidates for coalescing. 
-type canCoalesce int - -const ( - coalescePrepend canCoalesce = -1 - coalesceUnavailable canCoalesce = 0 - coalesceAppend canCoalesce = 1 -) - -// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet -// described by item. This function makes considerations that match the kernel's -// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. -func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { - pktTarget := bufs[item.bufsIndex][bufsOffset:] - if tcphLen != item.tcphLen { - // cannot coalesce with unequal tcp options len - return coalesceUnavailable - } - if tcphLen > 20 { - if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { - // cannot coalesce with unequal tcp options - return coalesceUnavailable - } - } - if pkt[0]>>4 == 6 { - if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 { - // cannot coalesce with unequal Traffic class values - return coalesceUnavailable - } - if pkt[7] != pktTarget[7] { - // cannot coalesce with unequal Hop limit values - return coalesceUnavailable - } - } else { - if pkt[1] != pktTarget[1] { - // cannot coalesce with unequal ToS values - return coalesceUnavailable - } - if pkt[6]>>5 != pktTarget[6]>>5 { - // cannot coalesce with unequal DF or reserved bits. MF is checked - // further up the stack. - return coalesceUnavailable - } - if pkt[8] != pktTarget[8] { - // cannot coalesce with unequal TTL values - return coalesceUnavailable - } - } - // seq adjacency - lhsLen := item.gsoSize - lhsLen += item.numMerged * item.gsoSize - if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective - if item.pshSet { - // We cannot append to a segment that has the PSH flag set, PSH - // can only be set on the final segment in a reassembled group. 
- return coalesceUnavailable - } - if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { - // A smaller than gsoSize packet has been appended previously. - // Nothing can come after a smaller packet on the end. - return coalesceUnavailable - } - if gsoSize > item.gsoSize { - // We cannot have a larger packet following a smaller one. - return coalesceUnavailable - } - return coalesceAppend - } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective - if pshSet { - // We cannot prepend with a segment that has the PSH flag set, PSH - // can only be set on the final segment in a reassembled group. - return coalesceUnavailable - } - if gsoSize < item.gsoSize { - // We cannot have a larger packet following a smaller one. - return coalesceUnavailable - } - if gsoSize > item.gsoSize && item.numMerged > 0 { - // There's at least one previous merge, and we're larger than all - // previous. This would put multiple smaller packets on the end. - return coalesceUnavailable - } - return coalescePrepend - } - return coalesceUnavailable -} - -func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { - srcAddrAt := ipv4SrcAddrOffset - addrSize := 4 - if isV6 { - srcAddrAt = ipv6SrcAddrOffset - addrSize = 16 - } - tcpTotalLen := uint16(len(pkt) - int(iphLen)) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) - return ^checksumFold(pkt[iphLen:], tcpCSumNoFold) == 0 -} - -// coalesceResult represents the result of attempting to coalesce two TCP -// packets. -type coalesceResult int - -const ( - coalesceInsufficientCap coalesceResult = iota - coalescePSHEnding - coalesceItemInvalidCSum - coalescePktInvalidCSum - coalesceSuccess -) - -// coalesceTCPPackets attempts to coalesce pkt with the packet described by -// item, returning the outcome. 
This function may swap bufs elements in the -// event of a prepend as item's bufs index is already being tracked for writing -// to a Device. -func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { - var pktHead []byte // the packet that will end up at the front - headersLen := item.iphLen + item.tcphLen - coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) - - // Copy data - if mode == coalescePrepend { - pktHead = pkt - if cap(pkt)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if pshSet { - return coalescePSHEnding - } - if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { - return coalesceItemInvalidCSum - } - } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { - return coalescePktInvalidCSum - } - item.sentSeq = seq - extendBy := coalescedLen - len(pktHead) - bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) - copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) - // Flip the slice headers in bufs as part of prepend. The index of item - // is already being tracked for writing. - bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] - } else { - pktHead = bufs[item.bufsIndex][bufsOffset:] - if cap(pktHead)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { - return coalesceItemInvalidCSum - } - } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { - return coalescePktInvalidCSum - } - if pshSet { - // We are appending a segment with PSH set. 
- item.pshSet = pshSet - pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH - } - extendBy := len(pkt) - int(headersLen) - bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) - copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) - } - - if gsoSize > item.gsoSize { - item.gsoSize = gsoSize - } - - item.numMerged++ - return coalesceSuccess -} - -const ( - ipv4FlagMoreFragments uint8 = 0x20 -) - -const ( - ipv4SrcAddrOffset = 12 - ipv6SrcAddrOffset = 8 - maxUint16 = 1<<16 - 1 -) - -type tcpGROResult int - -const ( - tcpGROResultNoop tcpGROResult = iota - tcpGROResultTableInsert - tcpGROResultCoalesced -) - -// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It returns a tcpGROResultNoop when no -// action was taken, tcpGROResultTableInsert when the evaluated packet was -// inserted into table, and tcpGROResultCoalesced when the evaluated packet was -// coalesced with another packet in table. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult { - pkt := bufs[pktI][offset:] - if len(pkt) > maxUint16 { - // A valid IPv4 or IPv6 packet will never exceed this. 
- return tcpGROResultNoop - } - iphLen := int((pkt[0] & 0x0F) * 4) - if isV6 { - iphLen = 40 - ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) - if ipv6HPayloadLen != len(pkt)-iphLen { - return tcpGROResultNoop - } - } else { - totalLen := int(binary.BigEndian.Uint16(pkt[2:])) - if totalLen != len(pkt) { - return tcpGROResultNoop - } - } - if len(pkt) < iphLen { - return tcpGROResultNoop - } - tcphLen := int((pkt[iphLen+12] >> 4) * 4) - if tcphLen < 20 || tcphLen > 60 { - return tcpGROResultNoop - } - if len(pkt) < iphLen+tcphLen { - return tcpGROResultNoop - } - if !isV6 { - if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { - // no GRO support for fragmented segments for now - return tcpGROResultNoop - } - } - tcpFlags := pkt[iphLen+tcpFlagsOffset] - var pshSet bool - // not a candidate if any non-ACK flags (except PSH+ACK) are set - if tcpFlags != tcpFlagACK { - if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return tcpGROResultNoop - } - pshSet = true - } - gsoSize := uint16(len(pkt) - tcphLen - iphLen) - // not a candidate if payload len is 0 - if gsoSize < 1 { - return tcpGROResultNoop - } - seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) - srcAddrOffset := ipv4SrcAddrOffset - addrLen := 4 - if isV6 { - srcAddrOffset = ipv6SrcAddrOffset - addrLen = 16 - } - items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - if !existing { - return tcpGROResultNoop - } - for i := len(items) - 1; i >= 0; i-- { - // In the best case of packets arriving in order iterating in reverse is - // more efficient if there are multiple items for a given flow. This - // also enables a natural table.deleteAt() in the - // coalesceItemInvalidCSum case without the need for index tracking. 
- // This algorithm makes a best effort to coalesce in the event of - // unordered packets, where pkt may land anywhere in items from a - // sequence number perspective, however once an item is inserted into - // the table it is never compared across other items later. - item := items[i] - can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) - if can != coalesceUnavailable { - result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) - switch result { - case coalesceSuccess: - table.updateAt(item, i) - return tcpGROResultCoalesced - case coalesceItemInvalidCSum: - // delete the item with an invalid csum - table.deleteAt(item.key, i) - case coalescePktInvalidCSum: - // no point in inserting an item that we can't coalesce - return tcpGROResultNoop - default: - } - } - } - // failed to coalesce with any other packets; store the item in the flow - table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return tcpGROResultTableInsert -} - -func isTCP4NoIPOptions(b []byte) bool { - if len(b) < 40 { - return false - } - if b[0]>>4 != 4 { - return false - } - if b[0]&0x0F != 5 { - return false - } - if b[9] != unix.IPPROTO_TCP { - return false - } - return true -} - -func isTCP6NoEH(b []byte) bool { - if len(b) < 60 { - return false - } - if b[0]>>4 != 6 { - return false - } - if b[6] != unix.IPPROTO_TCP { - return false - } - return true -} - -// applyCoalesceAccounting updates bufs to account for coalescing based on the -// metadata found in table. 
-func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { - for _, items := range table.itemsByFlow { - for _, item := range items { - if item.numMerged > 0 { - hdr := virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(item.iphLen + item.tcphLen), - gsoSize: item.gsoSize, - csumStart: uint16(item.iphLen), - csumOffset: 16, - } - pkt := bufs[item.bufsIndex][offset:] - - // Recalculate the total len (IPv4) or payload len (IPv6). - // Recalculate the (IPv4) header checksum. - if isV6 { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len - } else { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - pkt[10], pkt[11] = 0, 0 - binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^checksumFold(pkt[:item.iphLen], 0) // compute IPv4 header checksum - binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field - } - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - - // Calculate the pseudo header checksum and place it at the TCP - // checksum offset. Downstream checksum offloading will combine - // this with computation of the tcp header and payload checksum. 
- addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := offset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) - binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksumFold([]byte{}, psum)) - } else { - hdr := virtioNetHdr{} - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - } - } - } - return nil -} - -// handleGRO evaluates bufs for GRO, and writes the indices of the resulting -// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be -// empty (but non-nil), and are passed in to save allocs as the caller may reset -// and recycle them across vectors of packets. -func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error { - for i := range bufs { - if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { - return errors.New("invalid offset") - } - var result tcpGROResult - switch { - case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - result = tcpGRO(bufs, offset, i, tcp4Table, false) - case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - result = tcpGRO(bufs, offset, i, tcp6Table, true) - } - switch result { - case tcpGROResultNoop: - hdr := virtioNetHdr{} - err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - fallthrough - case tcpGROResultTableInsert: - *toWrite = append(*toWrite, i) - } - } - err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false) - err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true) - return E.Errors(err4, err6) -} - -// tcpTSO splits packets from in into outBuffs, writing the size of each -// element into 
sizes. It returns the number of buffers populated, and/or an -// error. -func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) { - iphLen := int(hdr.csumStart) - srcAddrOffset := ipv6SrcAddrOffset - addrLen := 16 - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { - in[10], in[11] = 0, 0 // clear ipv4 header checksum - srcAddrOffset = ipv4SrcAddrOffset - addrLen = 4 - } - tcpCSumAt := int(hdr.csumStart + hdr.csumOffset) - in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum - firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:]) - nextSegmentDataAt := int(hdr.hdrLen) - i := 0 - for ; nextSegmentDataAt < len(in); i++ { - if i == len(outBuffs) { - return i - 1, ErrTooManySegments - } - nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) - if nextSegmentEnd > len(in) { - nextSegmentEnd = len(in) - } - segmentDataLen := nextSegmentEnd - nextSegmentDataAt - totalLen := int(hdr.hdrLen) + segmentDataLen - sizes[i] = totalLen - out := outBuffs[i][outOffset:] - - copy(out, in[:iphLen]) - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { - // For IPv4 we are responsible for incrementing the ID field, - // updating the total len field, and recalculating the header - // checksum. - if i > 0 { - id := binary.BigEndian.Uint16(out[4:]) - id += uint16(i) - binary.BigEndian.PutUint16(out[4:], id) - } - binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) - ipv4CSum := ^checksumFold(out[:iphLen], 0) - binary.BigEndian.PutUint16(out[10:], ipv4CSum) - } else { - // For IPv6 we are responsible for updating the payload length field. 
- binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) - } - - // TCP header - copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) - tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) - binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) - if nextSegmentEnd != len(in) { - // FIN and PSH should only be set on last segment - clearFlags := tcpFlagFIN | tcpFlagPSH - out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags - } - - // payload - copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) - - // TCP checksum - tcpHLen := int(hdr.hdrLen - hdr.csumStart) - tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) - tcpCSum := ^checksumFold(out[hdr.csumStart:totalLen], tcpCSumNoFold) - binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) - - nextSegmentDataAt += int(hdr.gsoSize) - } - return i, nil -} - -func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { - cSumAt := cSumStart + cSumOffset - // The initial value at the checksum offset should be summed with the - // checksum we compute. This is typically the pseudo-header checksum. - initial := binary.BigEndian.Uint16(in[cSumAt:]) - in[cSumAt], in[cSumAt+1] = 0, 0 - binary.BigEndian.PutUint16(in[cSumAt:], ^checksumFold(in[cSumStart:], uint64(initial))) - return nil -} - -// handleVirtioRead splits in into bufs, leaving offset bytes at the front of -// each buffer. It mutates sizes to reflect the size of each element of bufs, -// and returns the number of packets read. 
-func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { - var hdr virtioNetHdr - err := hdr.decode(in) - if err != nil { - return 0, err - } - in = in[virtioNetHdrLen:] - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { - if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { - // This means CHECKSUM_PARTIAL in skb context. We are responsible - // for computing the checksum starting at hdr.csumStart and placing - // at hdr.csumOffset. - err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) - if err != nil { - return 0, err - } - } - if len(in) > len(bufs[0][offset:]) { - return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) - } - n := copy(bufs[0][offset:], in) - sizes[0] = n - return 1, nil - } - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { - return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) - } - - ipVersion := in[0] >> 4 - switch ipVersion { - case 4: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { - return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) - } - case 6: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { - return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) - } - default: - return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) - } - - if len(in) <= int(hdr.csumStart+12) { - return 0, errors.New("packet is too short") - } - // Don't trust hdr.hdrLen from the kernel as it can be equal to the length - // of the entire first packet when the kernel is handling it as part of a - // FORWARD path. Instead, parse the TCP header length and add it onto - // csumStart, which is synonymous for IP header length. - tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) - if tcpHLen < 20 || tcpHLen > 60 { - // A TCP header must be between 20 and 60 bytes in length. 
- return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) - } - hdr.hdrLen = hdr.csumStart + tcpHLen - - if len(in) < int(hdr.hdrLen) { - return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) - } - - if hdr.hdrLen < hdr.csumStart { - return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) - } - cSumAt := int(hdr.csumStart + hdr.csumOffset) - if cSumAt+1 >= len(in) { - return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) - } - - return tcpTSO(in, hdr, bufs, sizes, offset) -} - -func checksumNoFold(b []byte, initial uint64) uint64 { - return uint64(checksum.Checksum(b, uint16(initial))) -} - -func checksumFold(b []byte, initial uint64) uint16 { - ac := checksumNoFold(b, initial) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - return uint16(ac) -} - -func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { - sum := checksumNoFold(srcAddr, 0) - sum = checksumNoFold(dstAddr, sum) - sum = checksumNoFold([]byte{0, protocol}, sum) - tmp := make([]byte, 2) - binary.BigEndian.PutUint16(tmp, totalLen) - return checksumNoFold(tmp, sum) -} diff --git a/tun_linux_offload_errors.go b/tun_linux_offload_errors.go deleted file mode 100644 index 8e5db90..0000000 --- a/tun_linux_offload_errors.go +++ /dev/null @@ -1,5 +0,0 @@ -package tun - -import E "github.com/sagernet/sing/common/exceptions" - -var ErrTooManySegments = E.New("too many segments") diff --git a/tun_offload.go b/tun_offload.go new file mode 100644 index 0000000..a0eee82 --- /dev/null +++ b/tun_offload.go @@ -0,0 +1,229 @@ +package tun + +import ( + "encoding/binary" + "fmt" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/gtcpip/header" +) + +const ( 
+ gsoMaxSize = 65536 + idealBatchSize = 128 +) + +// GSOType represents the type of segmentation offload. +type GSOType int + +const ( + GSONone GSOType = iota + GSOTCPv4 + GSOTCPv6 + GSOUDPL4 +) + +func (g GSOType) String() string { + switch g { + case GSONone: + return "GSONone" + case GSOTCPv4: + return "GSOTCPv4" + case GSOTCPv6: + return "GSOTCPv6" + case GSOUDPL4: + return "GSOUDPL4" + default: + return "unknown" + } +} + +// GSOOptions is loosely modeled after struct virtio_net_hdr from the VIRTIO +// specification. It is a common representation of GSO metadata that can be +// applied to support packet GSO across tun.Device implementations. +type GSOOptions struct { + // GSOType represents the type of segmentation offload. + GSOType GSOType + // HdrLen is the sum of the layer 3 and 4 header lengths. This field may be + // zero when GSOType == GSONone. + HdrLen uint16 + // CsumStart is the head byte index of the packet data to be checksummed, + // i.e. the start of the TCP or UDP header. + CsumStart uint16 + // CsumOffset is the offset from CsumStart where the 2-byte checksum value + // should be placed. + CsumOffset uint16 + // GSOSize is the size of each segment exclusive of HdrLen. The tail segment + // may be smaller than this value. + GSOSize uint16 + // NeedsCsum may be set where GSOType == GSONone. When set, the checksum + // at CsumStart + CsumOffset must be a partial checksum, i.e. the + // pseudo-header sum. + NeedsCsum bool +} + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +const ( + // defined here in order to avoid importation of any platform-specific pkgs + ipProtoTCP = 6 + ipProtoUDP = 17 +) + +// GSOSplit splits packets from 'in' into outBufs[][outOffset:], writing +// the size of each element into sizes. It returns the number of buffers +// populated, and/or an error. 
Callers may pass an 'in' slice that overlaps with +// the first element of outBuffers, i.e. &in[0] may be equal to +// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the +// value of options.NeedsCsum. Length of each outBufs element must be greater +// than or equal to the length of 'in', otherwise output may be silently +// truncated. +func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) { + cSumAt := int(options.CsumStart) + int(options.CsumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } + + if len(in) < int(options.HdrLen) { + return 0, fmt.Errorf("length of packet (%d) < GSO HdrLen (%d)", len(in), options.HdrLen) + } + + // Handle the conditions where we are copying a single element to outBuffs. + payloadLen := len(in) - int(options.HdrLen) + if options.GSOType == GSONone || payloadLen < int(options.GSOSize) { + if len(in) > len(outBufs[0][outOffset:]) { + return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:])) + } + if options.NeedsCsum { + // The initial value at the checksum offset should be summed with + // the checksum we compute. This is typically the pseudo-header sum. 
+ initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum.Checksum(in[options.CsumStart:], initial)) + } + sizes[0] = copy(outBufs[0][outOffset:], in) + return 1, nil + } + + if options.HdrLen < options.CsumStart { + return 0, fmt.Errorf("GSO HdrLen (%d) < GSO CsumStart (%d)", options.HdrLen, options.CsumStart) + } + + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if options.GSOType != GSOTCPv4 && options.GSOType != GSOUDPL4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType) + } + if len(in) < 20 { + return 0, fmt.Errorf("length of packet (%d) < minimum ipv4 header size (%d)", len(in), 20) + } + case 6: + if options.GSOType != GSOTCPv6 && options.GSOType != GSOUDPL4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType) + } + if len(in) < 40 { + return 0, fmt.Errorf("length of packet (%d) < minimum ipv6 header size (%d)", len(in), 40) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } + + iphLen := int(options.CsumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if ipVersion == 4 { + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + transportCsumAt := int(options.CsumStart + options.CsumOffset) + var firstTCPSeqNum uint32 + var protocol uint8 + if options.GSOType == GSOTCPv4 || options.GSOType == GSOTCPv6 { + protocol = ipProtoTCP + if len(in) < int(options.CsumStart)+20 { + return 0, fmt.Errorf("length of packet (%d) < GSO CsumStart (%d) + minimum TCP header size (%d)", + len(in), options.CsumStart, 20) + } + firstTCPSeqNum = binary.BigEndian.Uint32(in[options.CsumStart+4:]) + } else { + protocol = ipProtoUDP + } + nextSegmentDataAt := int(options.HdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBufs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(options.GSOSize) + if nextSegmentEnd 
> len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(options.HdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBufs[i][outOffset:] + + copy(out, in[:iphLen]) + if ipVersion == 4 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + out[10], out[11] = 0, 0 // clear ipv4 header checksum + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksum.Checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. + binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // copy transport header + copy(out[options.CsumStart:options.HdrLen], in[options.CsumStart:options.HdrLen]) + + if protocol == ipProtoTCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(options.GSOSize*uint16(i)) + binary.BigEndian.PutUint32(out[options.CsumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[options.CsumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[options.CsumStart+4:], uint16(segmentDataLen)+(options.HdrLen-options.CsumStart)) + } + + // payload + copy(out[options.HdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // transport checksum + out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + transportHeaderLen := int(options.HdrLen - options.CsumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSum := header.PseudoHeaderChecksum(tcpip.TransportProtocolNumber(protocol), in[srcAddrOffset:srcAddrOffset+addrLen], 
in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum = ^checksum.Checksum(out[options.CsumStart:totalLen], transportCSum) + binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum) + + nextSegmentDataAt += int(options.GSOSize) + } + return i, nil +} diff --git a/tun_offload_errors.go b/tun_offload_errors.go new file mode 100644 index 0000000..2c49fc7 --- /dev/null +++ b/tun_offload_errors.go @@ -0,0 +1,10 @@ +package tun + +import ( + "errors" +) + +// ErrTooManySegments is returned by Device.Read() when segmentation +// overflows the length of supplied buffers. This error should not cause +// reads to cease. +var ErrTooManySegments = errors.New("too many segments") diff --git a/tun_offload_linux.go b/tun_offload_linux.go new file mode 100644 index 0000000..7e44a55 --- /dev/null +++ b/tun_offload_linux.go @@ -0,0 +1,937 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "unsafe" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/gtcpip/header" + + "golang.org/x/sys/unix" +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. 
+type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) toGSOOptions() (GSOOptions, error) { + var gsoType GSOType + switch v.gsoType { + case unix.VIRTIO_NET_HDR_GSO_NONE: + gsoType = GSONone + case unix.VIRTIO_NET_HDR_GSO_TCPV4: + gsoType = GSOTCPv4 + case unix.VIRTIO_NET_HDR_GSO_TCPV6: + gsoType = GSOTCPv6 + case unix.VIRTIO_NET_HDR_GSO_UDP_L4: + gsoType = GSOUDPL4 + default: + return GSOOptions{}, fmt.Errorf("unsupported virtio gsoType: %d", v.gsoType) + } + return GSOOptions{ + GSOType: gsoType, + HdrLen: v.hdrLen, + CsumStart: v.csumStart, + CsumOffset: v.csumOffset, + GSOSize: v.gsoSize, + NeedsCsum: v.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0, + }, nil +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// tcpFlowKey represents the key for a TCP flow. +type tcpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. + isV6 bool +} + +// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. 
+type tcpGROTable struct { + itemsByFlow map[tcpFlowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, idealBatchSize), + itemsPool: make([][]tcpGROItem, idealBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize) + } + return t +} + +func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { + key := tcpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. 
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key tcpFlowKey + sentSeq uint32 // the sequence number + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// udpFlowKey represents the key for a UDP flow. +type udpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + isV6 bool +} + +// udpGROTable holds flow and coalescing information for the purposes of UDP GRO. 
+type udpGROTable struct { + itemsByFlow map[udpFlowKey][]udpGROItem + itemsPool [][]udpGROItem +} + +func newUDPGROTable() *udpGROTable { + u := &udpGROTable{ + itemsByFlow: make(map[udpFlowKey][]udpGROItem, idealBatchSize), + itemsPool: make([][]udpGROItem, idealBatchSize), + } + for i := range u.itemsPool { + u.itemsPool[i] = make([]udpGROItem, 0, idealBatchSize) + } + return u +} + +func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { + key := udpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + items, ok := u.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. 
+func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + item := udpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[udphOffset+udphLen:])), + iphLen: uint8(udphOffset), + cSumKnownInvalid: cSumKnownInvalid, + } + items, ok := u.itemsByFlow[key] + if !ok { + items = u.newItems() + } + items = append(items, item) + u.itemsByFlow[key] = items +} + +func (u *udpGROTable) updateAt(item udpGROItem, i int) { + items, _ := u.itemsByFlow[item.key] + items[i] = item +} + +// udpGROItem represents bookkeeping data for a UDP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type udpGROItem struct { + key udpFlowKey + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. +} + +func (u *udpGROTable) newItems() []udpGROItem { + var items []udpGROItem + items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] + return items +} + +func (u *udpGROTable) reset() { + for k, items := range u.itemsByFlow { + items = items[:0] + u.itemsPool = append(u.itemsPool, items) + delete(u.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. +type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB +// meet all requirements to be merged as part of a GRO operation, otherwise it +// returns false. 
+func ipHeadersCanCoalesce(pktA, pktB []byte) bool { + if len(pktA) < 9 || len(pktB) < 9 { + return false + } + if pktA[0]>>4 == 6 { + if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 { + // cannot coalesce with unequal Traffic class values + return false + } + if pktA[7] != pktB[7] { + // cannot coalesce with unequal Hop limit values + return false + } + } else { + if pktA[1] != pktB[1] { + // cannot coalesce with unequal ToS values + return false + } + if pktA[6]>>5 != pktB[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return false + } + if pktA[8] != pktB[8] { + // cannot coalesce with unequal TTL values + return false + } + } + return true +} + +// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. iphLen and gsoSize describe pkt. bufs is the vector of +// packets involved in the current GRO evaluation. bufsOffset is the offset at +// which packet data begins within bufs. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend +} + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. 
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. 
+ return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + lenForPseudo := uint16(len(pkt) - int(iphLen)) + cSum := header.PseudoHeaderChecksum(tcpip.TransportProtocolNumber(proto), pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^checksum.Checksum(pkt[iphLen:], cSum) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess +) + +// coalesceUDPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front + headersLen := item.iphLen + udphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalescePktInvalidCSum + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) 
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + + item.numMerged++ + return coalesceSuccess +} + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. This function may swap bufs elements in the +// event of a prepend as item's bufs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) + copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) + // Flip the slice headers in bufs as part of prepend. The index of item + // is already being tracked for writing. + bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + } else { + pktHead = bufs[item.bufsIndex][bufsOffset:] + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. 
+ return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. + item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments uint8 = 0x20 +) + +const ( + maxUint16 = 1<<16 - 1 +) + +type groResult int + +const ( + groResultNoop groResult = iota + groResultTableInsert + groResultCoalesced +) + +// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. 
+ return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return groResultNoop + } + if len(pkt) < iphLen+tcphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return groResultNoop + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return groResultTableInsert + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. 
+ // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. + item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return groResultCoalesced + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return groResultNoop + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return groResultTableInsert +} + +// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. 
+ if item.key.isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum.Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum.Checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. 
+func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + udphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 6, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + if item.key.isV6 { + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum.Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Recalculate the UDP len field value + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + + // Calculate the pseudo header checksum and place it at the UDP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the udp header and payload checksum. 
+ addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum.Checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +type groCandidateType uint8 + +const ( + notGROCandidate groCandidateType = iota + tcp4GROCandidate + tcp6GROCandidate + udp4GROCandidate + udp6GROCandidate +) + +type groDisablementFlags int + +const ( + tcpGRODisabled groDisablementFlags = 1 << iota + udpGRODisabled +) + +func (g *groDisablementFlags) disableTCPGRO() { + *g |= tcpGRODisabled +} + +func (g *groDisablementFlags) canTCPGRO() bool { + return (*g)&tcpGRODisabled == 0 +} + +func (g *groDisablementFlags) disableUDPGRO() { + *g |= udpGRODisabled +} + +func (g *groDisablementFlags) canUDPGRO() bool { + return (*g)&udpGRODisabled == 0 +} + +func packetIsGROCandidate(b []byte, gro groDisablementFlags) groCandidateType { + if len(b) < 28 { + return notGROCandidate + } + if b[0]>>4 == 4 { + if b[0]&0x0F != 5 { + // IPv4 packets w/IP options do not coalesce + return notGROCandidate + } + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 && gro.canTCPGRO() { + return tcp4GROCandidate + } + if b[9] == unix.IPPROTO_UDP && gro.canUDPGRO() { + return udp4GROCandidate + } + } else if b[0]>>4 == 6 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 && gro.canTCPGRO() { + return tcp6GROCandidate + } + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && gro.canUDPGRO() { + return udp6GROCandidate + } + } + return notGROCandidate +} + +const ( + udphLen = 8 +) + +// udpGRO 
evaluates the UDP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + if len(pkt) < iphLen+udphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + gsoSize := uint16(len(pkt) - udphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + if !existing { + return groResultTableInsert + } + // With UDP we only check the last item, otherwise we could reorder packets + // for a given flow. We must also always insert a new item, or successfully + // coalesce with an existing item, for the same reason. 
+ item := items[len(items)-1] + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + var pktCSumKnownInvalid bool + if can == coalesceAppend { + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, len(items)-1) + return groResultCoalesced + case coalesceItemInvalidCSum: + // If the existing item has an invalid csum we take no action. A new + // item will be stored after it, and the existing item will never be + // revisited as part of future coalescing candidacy checks. + case coalescePktInvalidCSum: + // We must insert a new item, but we also mark it as invalid csum + // to prevent a repeat checksum validation. + pktCSumKnownInvalid = true + default: + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + return groResultTableInsert +} + +// handleGRO evaluates bufs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcpTable, and udpTable should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. gro indicates if TCP and UDP GRO +// are supported/enabled. 
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error { + for i := range bufs { + if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { + return errors.New("invalid offset") + } + var result groResult + switch packetIsGROCandidate(bufs[i][offset:], gro) { + case tcp4GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, false) + case tcp6GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, true) + case udp4GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, false) + case udp6GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, true) + } + switch result { + case groResultNoop: + hdr := virtioNetHdr{} + err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + fallthrough + case groResultTableInsert: + *toWrite = append(*toWrite, i) + } + } + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + return errors.Join(errTCP, errUDP) +} diff --git a/tun_rules.go b/tun_rules.go index 93b0430..c1b983f 100644 --- a/tun_rules.go +++ b/tun_rules.go @@ -108,7 +108,7 @@ const autoRouteUseSubRanges = runtime.GOOS == "darwin" func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Prefix, error) { var routeRanges []netip.Prefix - if o.AutoRoute && len(o.Inet4Address) > 0 { + if len(o.Inet4Address) > 0 { var inet4Ranges []netip.Prefix if len(o.Inet4RouteAddress) > 0 { inet4Ranges = o.Inet4RouteAddress @@ -119,19 +119,27 @@ func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Pref } } } - } else if autoRouteUseSubRanges && !underNetworkExtension { - inet4Ranges = []netip.Prefix{ - netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8), - netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7), - netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6), - netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5), - 
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4), - netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3), - netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2), - netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1), + } else if o.AutoRoute { + if autoRouteUseSubRanges && !underNetworkExtension { + inet4Ranges = []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1), + } + } else { + inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)} + } + } else if runtime.GOOS == "darwin" { + for _, address := range o.Inet4Address { + if address.Bits() < 32 { + inet4Ranges = append(inet4Ranges, address.Masked()) + } } - } else { - inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)} } if len(o.Inet4RouteExcludeAddress) == 0 { routeRanges = append(routeRanges, inet4Ranges...) 
@@ -161,19 +169,27 @@ func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Pref } } } - } else if autoRouteUseSubRanges && !underNetworkExtension { - inet6Ranges = []netip.Prefix{ - netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8), - netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7), - netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6), - netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5), - netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4), - netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3), - netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2), - netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1), + } else if o.AutoRoute { + if autoRouteUseSubRanges && !underNetworkExtension { + inet6Ranges = []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1), + } + } else { + inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + } + } else if runtime.GOOS == "darwin" { + for _, address := range o.Inet6Address { + if address.Bits() < 128 { + inet6Ranges = append(inet6Ranges, address.Masked()) + } } - } else { - inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)} } if len(o.Inet6RouteExcludeAddress) == 0 { routeRanges = append(routeRanges, inet6Ranges...) 
diff --git a/tun_windows.go b/tun_windows.go index 392b78c..9cb1e96 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -72,7 +72,7 @@ func (t *NativeTun) configure() error { if err != nil { return E.Cause(err, "set ipv4 address") } - if !t.options.EXP_DisableDNSHijack { + if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack { dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is4) if len(dnsServers) == 0 && HasNextAddress(t.options.Inet4Address[0], 1) { dnsServers = []netip.Addr{t.options.Inet4Address[0].Addr().Next()} @@ -83,6 +83,11 @@ func (t *NativeTun) configure() error { return E.Cause(err, "set ipv4 dns") } } + } else { + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), nil, nil) + if err != nil { + return E.Cause(err, "set ipv4 dns") + } } } if len(t.options.Inet6Address) > 0 { @@ -90,7 +95,7 @@ func (t *NativeTun) configure() error { if err != nil { return E.Cause(err, "set ipv6 address") } - if !t.options.EXP_DisableDNSHijack { + if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack { dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is6) if len(dnsServers) == 0 && HasNextAddress(t.options.Inet6Address[0], 1) { dnsServers = []netip.Addr{t.options.Inet6Address[0].Addr().Next()} @@ -101,6 +106,11 @@ func (t *NativeTun) configure() error { return E.Cause(err, "set ipv6 dns") } } + } else { + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), nil, nil) + if err != nil { + return E.Cause(err, "set ipv6 dns") + } } } if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 { @@ -148,6 +158,10 @@ func (t *NativeTun) configure() error { return nil } +func (t *NativeTun) Name() (string, error) { + return t.options.Name, nil +} + func (t *NativeTun) Start() error { if !t.options.AutoRoute { return nil @@ -158,13 +172,7 @@ func (t *NativeTun) Start() error { if err != nil { return err } - for _, routeRange := range routeRanges { - if routeRange.Addr().Is4() { - err = luid.AddRoute(routeRange, gateway4, 
0) - } else { - err = luid.AddRoute(routeRange, gateway6, 0) - } - } + err = addRouteList(luid, routeRanges, gateway4, gateway6, 0) if err != nil { return err } @@ -349,7 +357,40 @@ func (t *NativeTun) Start() error { } func (t *NativeTun) Read(p []byte) (n int, err error) { - return 0, os.ErrInvalid + t.running.Add(1) + defer t.running.Done() +retry: + if t.close.Load() == 1 { + return 0, os.ErrClosed + } + start := nanotime() + shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 + for { + if t.close.Load() == 1 { + return 0, os.ErrClosed + } + var packet []byte + packet, err = t.session.ReceivePacket() + switch err { + case nil: + n = copy(p, packet) + t.session.ReleaseReceivePacket(packet) + t.rate.update(uint64(n)) + return + case windows.ERROR_NO_MORE_ITEMS: + if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { + windows.WaitForSingleObject(t.readWait, windows.INFINITE) + goto retry + } + procyield(1) + continue + case windows.ERROR_HANDLE_EOF: + return 0, os.ErrClosed + case windows.ERROR_INVALID_DATA: + return 0, errors.New("send ring corrupt") + } + return 0, fmt.Errorf("read failed: %w", err) + } } func (t *NativeTun) ReadPacket() ([]byte, func(), error) { @@ -498,6 +539,63 @@ func (t *NativeTun) Close() error { return err } +func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error { + t.options = tunOptions + if !t.options.AutoRoute { + return nil + } + gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr() + routeRanges, err := t.options.BuildAutoRouteRanges(false) + if err != nil { + return err + } + luid := winipcfg.LUID(t.adapter.LUID()) + err = luid.FlushRoutes(windows.AF_UNSPEC) + if err != nil { + return err + } + err = addRouteList(luid, routeRanges, gateway4, gateway6, 0) + if err != nil { + return err + } + err = windnsapi.FlushResolverCache() + if err != nil { + return err + } + return nil +} + +func 
addRouteList(luid winipcfg.LUID, destinations []netip.Prefix, gateway4 netip.Addr, gateway6 netip.Addr, metric uint32) error { + row := winipcfg.MibIPforwardRow2{} + row.Init() + row.InterfaceLUID = luid + row.Metric = metric + nextHop4 := row.NextHop + nextHop6 := row.NextHop + if gateway4.IsValid() { + nextHop4.SetAddr(gateway4) + } + if gateway6.IsValid() { + nextHop6.SetAddr(gateway6) + } + for _, destination := range destinations { + err := row.DestinationPrefix.SetPrefix(destination) + if err != nil { + return err + } + if destination.Addr().Is4() { + row.NextHop = nextHop4 + } else { + row.NextHop = nextHop6 + } + err = row.Create() + if err != nil { + return err + } + } + return nil +} + func generateGUIDByDeviceName(name string) *windows.GUID { hash := md5.New() hash.Write([]byte("wintun"))