diff --git a/internal/checksum_test/sum_bench_test.go b/internal/checksum_test/sum_bench_test.go new file mode 100644 index 0000000..bfc0752 --- /dev/null +++ b/internal/checksum_test/sum_bench_test.go @@ -0,0 +1,33 @@ +package checksum_test + +import ( + "crypto/rand" + "testing" + + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/tschecksum" +) + +func BenchmarkTsChecksum(b *testing.B) { + packet := make([][]byte, 1000) + for i := 0; i < 1000; i++ { + packet[i] = make([]byte, 1500) + rand.Read(packet[i]) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + tschecksum.Checksum(packet[i%1000], 0) + } +} + +func BenchmarkGChecksum(b *testing.B) { + packet := make([][]byte, 1000) + for i := 0; i < 1000; i++ { + packet[i] = make([]byte, 1500) + rand.Read(packet[i]) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum.Checksum(packet[i%1000], 0) + } +} diff --git a/internal/gtcpip/checksum/checksum.go b/internal/gtcpip/checksum/checksum.go index 5d4e117..dfb4dd7 100644 --- a/internal/gtcpip/checksum/checksum.go +++ b/internal/gtcpip/checksum/checksum.go @@ -30,34 +30,6 @@ func Put(b []byte, xsum uint16) { binary.BigEndian.PutUint16(b, xsum) } -// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the -// given byte array. This function uses an optimized version of the checksum -// algorithm. -// -// The initial checksum must have been computed on an even number of bytes. -func Checksum(buf []byte, initial uint16) uint16 { - s, _ := calculateChecksum(buf, false, initial) - return s -} - -// Checksumer calculates checksum defined in RFC 1071. -type Checksumer struct { - sum uint16 - odd bool -} - -// Add adds b to checksum. -func (c *Checksumer) Add(b []byte) { - if len(b) > 0 { - c.sum, c.odd = calculateChecksum(b, c.odd, c.sum) - } -} - -// Checksum returns the latest checksum value. -func (c *Checksumer) Checksum() uint16 { - return c.sum -} - // Combine combines the two uint16 to form their checksum. This is done // by adding them and the carry. // diff --git a/internal/gtcpip/checksum/checksum_default.go b/internal/gtcpip/checksum/checksum_default.go new file mode 100644 index 0000000..ea4585e --- /dev/null +++ b/internal/gtcpip/checksum/checksum_default.go @@ -0,0 +1,13 @@ +//go:build !amd64 + +package checksum + +// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the +// given byte array. This function uses an optimized version of the checksum +// algorithm. +// +// The initial checksum must have been computed on an even number of bytes. +func Checksum(buf []byte, initial uint16) uint16 { + s, _ := calculateChecksum(buf, false, initial) + return s +} diff --git a/internal/gtcpip/checksum/checksum_ts.go b/internal/gtcpip/checksum/checksum_ts.go new file mode 100644 index 0000000..f6766d3 --- /dev/null +++ b/internal/gtcpip/checksum/checksum_ts.go @@ -0,0 +1,9 @@ +//go:build amd64 + +package checksum + +import "github.com/sagernet/sing-tun/internal/tschecksum" + +func Checksum(buf []byte, initial uint16) uint16 { + return tschecksum.Checksum(buf, initial) +} diff --git a/internal/gtcpip/checksum/checksum_unsafe.go b/internal/gtcpip/checksum/checksum_unsafe.go index 66b7ab6..83f35c8 100644 --- a/internal/gtcpip/checksum/checksum_unsafe.go +++ b/internal/gtcpip/checksum/checksum_unsafe.go @@ -1,3 +1,5 @@ +//go:build !amd64 + // Copyright 2023 The gVisor Authors. 
// // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/internal/tschecksum/checksum.go b/internal/tschecksum/checksum.go new file mode 100644 index 0000000..677879f --- /dev/null +++ b/internal/tschecksum/checksum.go @@ -0,0 +1,712 @@ +package tschecksum + +import ( + "encoding/binary" + "math/bits" + "strconv" + + "golang.org/x/sys/cpu" +) + +// checksumGeneric64 is a reference implementation of checksum using 64 bit +// arithmetic for use in testing or when an architecture-specific implementation +// is not available. +func checksumGeneric64(b []byte, initial uint16) uint16 { + var ac uint64 + var carry uint64 + + if cpu.IsBigEndian { + ac = uint64(initial) + } else { + ac = uint64(bits.ReverseBytes16(initial)) + } + + for len(b) >= 128 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry) + } + b = b[128:] + } + if len(b) >= 64 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry) 
+ ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry) + } else { + ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry) + } else { + ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry) + } else { + ac, carry = bits.Add64(ac, uint64(b[0]), carry) + } + } + + folded := ipChecksumFold64(ac, carry) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric32 is a reference implementation of checksum using 32 bit +// arithmetic for use in testing or when an architecture-specific implementation +// is not available. 
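An illustrative note on the byte-order handling above (not part of the patch): the RFC 1071 end-around-carry sum is byte-order independent, so the generic implementations add native-endian words and only byte-swap the folded 16-bit result with bits.ReverseBytes16. A minimal worked example on the bytes {0x01, 0x02, 0x03, 0x04}, where no carries occur (the fold functions handle the general case):

	// Big-endian 16-bit words:    0x0102 + 0x0304 = 0x0406
	// Little-endian 16-bit words: 0x0201 + 0x0403 = 0x0604
	// bits.ReverseBytes16(0x0604) == 0x0406, the same checksum.
	b := []byte{0x01, 0x02, 0x03, 0x04}
	le := binary.LittleEndian.Uint16(b[:2]) + binary.LittleEndian.Uint16(b[2:4]) // 0x0604
	be := binary.BigEndian.Uint16(b[:2]) + binary.BigEndian.Uint16(b[2:4])       // 0x0406
	_ = be == bits.ReverseBytes16(le) // true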
+func checksumGeneric32(b []byte, initial uint16) uint16 { + var ac uint32 + var carry uint32 + + if cpu.IsBigEndian { + ac = uint32(initial) + } else { + ac = uint32(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, 
binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry) + } else { + ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry) + } else { + ac, carry = bits.Add32(ac, uint32(b[0]), carry) + } + } + + folded := ipChecksumFold32(ac, carry) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric32Alternate is an alternate reference implementation of +// checksum using 32 bit arithmetic for use in testing or when an +// architecture-specific implementation is not available. 
+func checksumGeneric32Alternate(b []byte, initial uint16) uint16 { + var ac uint32 + + if cpu.IsBigEndian { + ac = uint32(initial) + } else { + ac = uint32(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + ac += uint32(binary.BigEndian.Uint16(b[16:18])) + ac += uint32(binary.BigEndian.Uint16(b[18:20])) + ac += uint32(binary.BigEndian.Uint16(b[20:22])) + ac += uint32(binary.BigEndian.Uint16(b[22:24])) + ac += uint32(binary.BigEndian.Uint16(b[24:26])) + ac += uint32(binary.BigEndian.Uint16(b[26:28])) + ac += uint32(binary.BigEndian.Uint16(b[28:30])) + ac += uint32(binary.BigEndian.Uint16(b[30:32])) + ac += uint32(binary.BigEndian.Uint16(b[32:34])) + ac += uint32(binary.BigEndian.Uint16(b[34:36])) + ac += uint32(binary.BigEndian.Uint16(b[36:38])) + ac += uint32(binary.BigEndian.Uint16(b[38:40])) + ac += uint32(binary.BigEndian.Uint16(b[40:42])) + ac += uint32(binary.BigEndian.Uint16(b[42:44])) + ac += uint32(binary.BigEndian.Uint16(b[44:46])) + ac += uint32(binary.BigEndian.Uint16(b[46:48])) + ac += uint32(binary.BigEndian.Uint16(b[48:50])) + ac += uint32(binary.BigEndian.Uint16(b[50:52])) + ac += uint32(binary.BigEndian.Uint16(b[52:54])) + ac += uint32(binary.BigEndian.Uint16(b[54:56])) + ac += uint32(binary.BigEndian.Uint16(b[56:58])) + ac += uint32(binary.BigEndian.Uint16(b[58:60])) + ac += uint32(binary.BigEndian.Uint16(b[60:62])) + ac += uint32(binary.BigEndian.Uint16(b[62:64])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + ac += uint32(binary.LittleEndian.Uint16(b[16:18])) + ac += uint32(binary.LittleEndian.Uint16(b[18:20])) + ac += uint32(binary.LittleEndian.Uint16(b[20:22])) + ac += uint32(binary.LittleEndian.Uint16(b[22:24])) + ac += uint32(binary.LittleEndian.Uint16(b[24:26])) + ac += uint32(binary.LittleEndian.Uint16(b[26:28])) + ac += uint32(binary.LittleEndian.Uint16(b[28:30])) + ac += uint32(binary.LittleEndian.Uint16(b[30:32])) + ac += uint32(binary.LittleEndian.Uint16(b[32:34])) + ac += uint32(binary.LittleEndian.Uint16(b[34:36])) + ac += uint32(binary.LittleEndian.Uint16(b[36:38])) + ac += uint32(binary.LittleEndian.Uint16(b[38:40])) + ac += uint32(binary.LittleEndian.Uint16(b[40:42])) + ac += uint32(binary.LittleEndian.Uint16(b[42:44])) + ac += uint32(binary.LittleEndian.Uint16(b[44:46])) + ac += uint32(binary.LittleEndian.Uint16(b[46:48])) + ac += uint32(binary.LittleEndian.Uint16(b[48:50])) + ac += uint32(binary.LittleEndian.Uint16(b[50:52])) + ac += uint32(binary.LittleEndian.Uint16(b[52:54])) + ac += uint32(binary.LittleEndian.Uint16(b[54:56])) + ac += uint32(binary.LittleEndian.Uint16(b[56:58])) + ac += uint32(binary.LittleEndian.Uint16(b[58:60])) + ac += uint32(binary.LittleEndian.Uint16(b[60:62])) + ac += uint32(binary.LittleEndian.Uint16(b[62:64])) + } + b = b[64:] + } + if 
len(b) >= 32 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + ac += uint32(binary.BigEndian.Uint16(b[16:18])) + ac += uint32(binary.BigEndian.Uint16(b[18:20])) + ac += uint32(binary.BigEndian.Uint16(b[20:22])) + ac += uint32(binary.BigEndian.Uint16(b[22:24])) + ac += uint32(binary.BigEndian.Uint16(b[24:26])) + ac += uint32(binary.BigEndian.Uint16(b[26:28])) + ac += uint32(binary.BigEndian.Uint16(b[28:30])) + ac += uint32(binary.BigEndian.Uint16(b[30:32])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + ac += uint32(binary.LittleEndian.Uint16(b[16:18])) + ac += uint32(binary.LittleEndian.Uint16(b[18:20])) + ac += uint32(binary.LittleEndian.Uint16(b[20:22])) + ac += uint32(binary.LittleEndian.Uint16(b[22:24])) + ac += uint32(binary.LittleEndian.Uint16(b[24:26])) + ac += uint32(binary.LittleEndian.Uint16(b[26:28])) + ac += uint32(binary.LittleEndian.Uint16(b[28:30])) + ac += uint32(binary.LittleEndian.Uint16(b[30:32])) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b)) + } else { + ac += uint32(binary.LittleEndian.Uint16(b)) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac += uint32(b[0]) << 8 + } else { 
+ ac += uint32(b[0]) + } + } + + folded := ipChecksumFold32(ac, 0) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric64Alternate is an alternate reference implementation of +// checksum using 64 bit arithmetic for use in testing or when an +// architecture-specific implementation is not available. +func checksumGeneric64Alternate(b []byte, initial uint16) uint16 { + var ac uint64 + + if cpu.IsBigEndian { + ac = uint64(initial) + } else { + ac = uint64(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + ac += uint64(binary.LittleEndian.Uint32(b[16:20])) + ac += uint64(binary.LittleEndian.Uint32(b[20:24])) + ac += uint64(binary.LittleEndian.Uint32(b[24:28])) + ac += uint64(binary.LittleEndian.Uint32(b[28:32])) + ac += uint64(binary.LittleEndian.Uint32(b[32:36])) + ac += uint64(binary.LittleEndian.Uint32(b[36:40])) + ac += uint64(binary.LittleEndian.Uint32(b[40:44])) + ac += uint64(binary.LittleEndian.Uint32(b[44:48])) + ac += uint64(binary.LittleEndian.Uint32(b[48:52])) + ac += uint64(binary.LittleEndian.Uint32(b[52:56])) + ac += uint64(binary.LittleEndian.Uint32(b[56:60])) + ac += uint64(binary.LittleEndian.Uint32(b[60:64])) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + ac += uint64(binary.LittleEndian.Uint32(b[16:20])) + ac += uint64(binary.LittleEndian.Uint32(b[20:24])) + ac += uint64(binary.LittleEndian.Uint32(b[24:28])) + ac += uint64(binary.LittleEndian.Uint32(b[28:32])) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac 
+= uint64(binary.LittleEndian.Uint32(b[12:16])) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b)) + } else { + ac += uint64(binary.LittleEndian.Uint32(b)) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint16(b)) + } else { + ac += uint64(binary.LittleEndian.Uint16(b)) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac += uint64(b[0]) << 8 + } else { + ac += uint64(b[0]) + } + } + + folded := ipChecksumFold64(ac, 0) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 { + sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry)) + // if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff + // therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is + // no need to save the carry flag + sum = (sum >> 16) + (sum & 0xffff) + carry + // sum <= 0x1_fffe therefore this is the last fold needed: + // if (sum >> 16) > 0 then + // (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore + // the addition will not overflow + // otherwise (sum >> 16) == 0 and sum will be unchanged + sum = (sum >> 16) + (sum & 0xffff) + return uint16(sum) +} + +func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 { + sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry + // sum <= 0x1_ffff: + // 0xffff + 0xffff = 0x1_fffe + // initialCarry is 0 or 1, for a combined maximum of 0x1_ffff + sum = (sum >> 16) + (sum & 0xffff) + // sum <= 0x1_0000 therefore this is the last fold needed: + // if (sum >> 16) > 0 then + // (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore + // the addition will not overflow + // otherwise (sum >> 16) == 0 and sum will be unchanged + sum = (sum >> 16) + (sum & 0xffff) + return uint16(sum) +} + +func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) { + sum, carry = initial, carryIn + switch len(addr) { + case 4: // IPv4 + if cpu.IsBigEndian { + sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry) + } else { + sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry) + } + case 16: // IPv6 + if cpu.IsBigEndian { + sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry) + sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry) + } else { + sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry) + sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry) + } + default: + panic("bad addr length") + } + return sum, carry +} + +func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) { + sum, carry = initial, carryIn + switch len(addr) { + case 4: // IPv4 + if cpu.IsBigEndian { + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry) + } else { + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry) + } + case 16: // IPv6 + if cpu.IsBigEndian { + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry) + sum, carry = bits.Add32(sum, 
binary.BigEndian.Uint32(addr[8:12]), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry) + } else { + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry) + } + default: + panic("bad addr length") + } + return sum, carry +} + +func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + var sum uint64 + if cpu.IsBigEndian { + sum = uint64(totalLen) + uint64(protocol) + } else { + sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8 + } + sum, carry := addrPartialChecksum64(srcAddr, sum, 0) + sum, carry = addrPartialChecksum64(dstAddr, sum, carry) + + foldedSum := ipChecksumFold64(sum, carry) + if !cpu.IsBigEndian { + foldedSum = bits.ReverseBytes16(foldedSum) + } + return foldedSum +} + +func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + var sum uint32 + if cpu.IsBigEndian { + sum = uint32(totalLen) + uint32(protocol) + } else { + sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8 + } + sum, carry := addrPartialChecksum32(srcAddr, sum, 0) + sum, carry = addrPartialChecksum32(dstAddr, sum, carry) + + foldedSum := ipChecksumFold32(sum, carry) + if !cpu.IsBigEndian { + foldedSum = bits.ReverseBytes16(foldedSum) + } + return foldedSum +} + +// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and +// dstAddr must be 4 or 16 bytes in length. +func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + if strconv.IntSize < 64 { + return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen) + } + return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen) +} diff --git a/internal/tschecksum/checksum_amd64.go b/internal/tschecksum/checksum_amd64.go new file mode 100644 index 0000000..85b925a --- /dev/null +++ b/internal/tschecksum/checksum_amd64.go @@ -0,0 +1,23 @@ +package tschecksum + +import "golang.org/x/sys/cpu" + +var checksum = checksumAMD64 + +// Checksum computes an IP checksum starting with the provided initial value. +// The length of data should be at least 128 bytes for best performance. Smaller +// buffers will still compute a correct result. +func Checksum(data []byte, initial uint16) uint16 { + return checksum(data, initial) +} + +func init() { + if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 { + checksum = checksumAVX2 + return + } + if cpu.X86.HasSSE2 { + checksum = checksumSSE2 + return + } +} diff --git a/internal/tschecksum/checksum_generated_amd64.go b/internal/tschecksum/checksum_generated_amd64.go new file mode 100644 index 0000000..acc7350 --- /dev/null +++ b/internal/tschecksum/checksum_generated_amd64.go @@ -0,0 +1,18 @@ +// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT. 
+ +package tschecksum + +// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2) +// +//go:noescape +func checksumAVX2(b []byte, initial uint16) uint16 + +// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2) +// +//go:noescape +func checksumSSE2(b []byte, initial uint16) uint16 + +// checksumAMD64 computes an IP checksum using amd64 baseline instructions +// +//go:noescape +func checksumAMD64(b []byte, initial uint16) uint16 diff --git a/internal/tschecksum/checksum_generated_amd64.s b/internal/tschecksum/checksum_generated_amd64.s new file mode 100644 index 0000000..5f2e4c5 --- /dev/null +++ b/internal/tschecksum/checksum_generated_amd64.s @@ -0,0 +1,851 @@ +// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT. + +#include "textflag.h" + +DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" +DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff" +DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112 + +// func checksumAVX2(b []byte, initial uint16) uint16 +// Requires: AVX, AVX2, BMI2 +TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. 
+ // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. + MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + VPXOR Y0, Y0, Y0 + VPXOR Y1, Y1, Y1 + VPXOR Y2, Y2, Y2 + VPXOR Y3, Y3, Y3 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 16(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 32(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 48(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 64(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 80(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 96(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 112(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 128(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 144(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 160(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 176(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 192(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 208(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 224(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 240(DX), Y4 + VPADDD Y4, Y3, Y3 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + ADDQ $0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. 
+ LEAQ xmmLoadMasks<>+0(SB), CX + VMOVDQU -16(DX)(BX*1), X4 + VPAND -16(CX)(BX*8), X4, X4 + VPMOVZXWD X4, Y4 + VPADDD Y4, Y0, Y0 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + VPADDD Y1, Y0, Y0 + VPADDD Y2, Y0, Y0 + VPADDD Y3, Y0, Y0 + + // extract the YMM into a pair of XMM and sum them + VEXTRACTI128 $0x01, Y0, X1 + VPADDD X0, X1, X0 + + // extract the XMM into GP64 + VPEXTRQ $0x00, X0, CX + VPEXTRQ $0x01, X0, DX + + // no more AVX code, clear upper registers to avoid SSE slowdowns + VZEROUPPER + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + RORXQ $0x20, AX, CX + ADCL CX, AX + RORXL $0x10, AX, CX + ADCW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumSSE2(b []byte, initial uint16) uint16 +// Requires: SSE2 +TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + PXOR X0, X0 + PXOR X1, X1 + PXOR X2, X2 + PXOR X3, X3 + PXOR X4, X4 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 16(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 32(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 48(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 64(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 80(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 96(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 112(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 128(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 144(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 160(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 176(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 192(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 208(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 224(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 240(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + ADDQ 
$0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. + LEAQ xmmLoadMasks<>+0(SB), CX + MOVOU -16(DX)(BX*1), X5 + PAND -16(CX)(BX*8), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + PADDD X1, X0 + PADDD X2, X0 + PADDD X3, X0 + + // extract the XMM into GP64 + MOVQ X0, CX + PSRLDQ $0x08, X0 + MOVQ X0, DX + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumAMD64(b []byte, initial uint16) uint16 +TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. 
+ ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // Number of 256 byte iterations into loop counter + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + SHRQ $0x08, CX + JZ startCleanup + CLC + XORQ SI, SI + XORQ DI, DI + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + XORQ R12, R12 + +bigLoop: + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ 32(DX), DI + ADCQ 40(DX), DI + ADCQ 48(DX), DI + ADCQ 56(DX), DI + ADCQ $0x00, R8 + ADDQ 64(DX), R9 + ADCQ 72(DX), R9 + ADCQ 80(DX), R9 + ADCQ 88(DX), R9 + ADCQ $0x00, R10 + ADDQ 96(DX), R11 + ADCQ 104(DX), R11 + ADCQ 112(DX), R11 + ADCQ 120(DX), R11 + ADCQ $0x00, R12 + ADDQ 128(DX), AX + ADCQ 136(DX), AX + ADCQ 144(DX), AX + ADCQ 152(DX), AX + ADCQ $0x00, SI + ADDQ 160(DX), DI + ADCQ 168(DX), DI + ADCQ 176(DX), DI + ADCQ 184(DX), DI + ADCQ $0x00, R8 + ADDQ 192(DX), R9 + ADCQ 200(DX), R9 + ADCQ 208(DX), R9 + ADCQ 216(DX), R9 + ADCQ $0x00, R10 + ADDQ 224(DX), R11 + ADCQ 232(DX), R11 + ADCQ 240(DX), R11 + ADCQ 248(DX), R11 + ADCQ $0x00, R12 + ADDQ $0x00000100, DX + SUBQ $0x01, CX + JNZ bigLoop + ADDQ SI, AX + ADCQ DI, AX + ADCQ R8, AX + ADCQ R9, AX + ADCQ R10, AX + ADCQ R11, AX + ADCQ R12, AX + + // accumulate CF (twice, in case the first time overflows) + ADCQ $0x00, AX + ADCQ $0x00, AX + +startCleanup: + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET diff --git a/internal/tschecksum/checksum_generic.go b/internal/tschecksum/checksum_generic.go new file mode 100644 index 0000000..2d6c134 --- /dev/null +++ b/internal/tschecksum/checksum_generic.go @@ -0,0 +1,15 @@ +// This file contains IP checksum algorithms that are not specific to any +// architecture and don't use hardware acceleration. + +//go:build !amd64 + +package tschecksum + +import "strconv" + +func Checksum(data []byte, initial uint16) uint16 { + if strconv.IntSize < 64 { + return checksumGeneric32(data, initial) + } + return checksumGeneric64(data, initial) +} diff --git a/internal/tschecksum/generate_amd64.go b/internal/tschecksum/generate_amd64.go new file mode 100644 index 0000000..a72a59e --- /dev/null +++ b/internal/tschecksum/generate_amd64.go @@ -0,0 +1,578 @@ +//go:build ignore + +//go:generate go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go + +package main + +import ( + "fmt" + "math" + "math/bits" + + "github.com/mmcloughlin/avo/operand" + "github.com/mmcloughlin/avo/reg" +) + +const checksumSignature = "func(b []byte, initial uint16) uint16" + +func loadParams() (accum, buf, n reg.GPVirtual) { + accum, buf, n = GP64(), GP64(), GP64() + Load(Param("initial"), accum) + XCHGB(accum.As8H(), accum.As8L()) + Load(Param("b").Base(), buf) + Load(Param("b").Len(), n) + return +} + +type simdStrategy int + +const ( + sse2 = iota + avx2 +) + +const tinyBufferSize = 31 // A buffer is tiny if it has at most 31 bytes. 
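The "overlapped read" described in the assembly comments above (read the last 8 bytes of the buffer and shift out the bytes that were already summed) can be illustrated in plain Go. This is a sketch under the assumptions of a little-endian machine and len(b) >= 8, with a hypothetical helper name; it is not part of the patch:

	package tailsketch // hypothetical package, illustrative only

	import "encoding/binary"

	// tailWord returns the trailing len(b)%8 bytes of b as a zero-extended
	// little-endian word, the same value the remainder handling in the
	// generated assembly folds in after its 8-byte passes: load the final
	// 8 bytes (re-reading 8-rem bytes that were already summed) and shift
	// the duplicates out.
	func tailWord(b []byte) uint64 {
		rem := len(b) & 7
		if rem == 0 {
			return 0
		}
		w := binary.LittleEndian.Uint64(b[len(b)-8:])
		return w >> (64 - uint(rem)*8)
	}

Reading past the logical end and shifting keeps the hot path free of per-byte branches while, as the comments note, preserving whatever alignment the start of the buffer had.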
+ +func generateSIMDChecksum(name, doc string, minSIMDSize, chains int, strategy simdStrategy) { + TEXT(name, NOSPLIT|NOFRAME, checksumSignature) + Pragma("noescape") + Doc(doc) + + accum64, buf, n := loadParams() + + handleOddLength(n, buf, accum64) + // no chance of overflow because accum64 was initialized by a uint16 and + // handleOddLength adds at most a uint8 + handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny")) + Label("bufferIsNotTiny") + + const simdReadSize = 16 + + if minSIMDSize > tinyBufferSize { + Comment("skip all SIMD for small buffers") + if minSIMDSize <= math.MaxUint8 { + CMPQ(n, operand.U8(minSIMDSize)) + } else { + CMPQ(n, operand.U32(minSIMDSize)) + } + JGE(operand.LabelRef("startSIMD")) + + handleRemaining(n, buf, accum64, minSIMDSize-1) + JMP(operand.LabelRef("foldAndReturn")) + } + + Label("startSIMD") + + // chains is the number of accumulators to use. This improves speed via + // reduced data dependency. We combine the accumulators once when the big + // loop is complete. + simdAccumulate := make([]reg.VecVirtual, chains) + for i := range simdAccumulate { + switch strategy { + case sse2: + simdAccumulate[i] = XMM() + PXOR(simdAccumulate[i], simdAccumulate[i]) + case avx2: + simdAccumulate[i] = YMM() + VPXOR(simdAccumulate[i], simdAccumulate[i], simdAccumulate[i]) + } + } + var zero reg.VecVirtual + if strategy == sse2 { + zero = XMM() + PXOR(zero, zero) + } + + // Number of loads per big loop + const unroll = 16 + // Number of bytes + loopSize := uint64(simdReadSize * unroll) + if bits.Len64(loopSize) != bits.Len64(loopSize-1)+1 { + panic("loopSize is not a power of 2") + } + loopCount := GP64() + + MOVQ(n, loopCount) + Comment("Update number of bytes remaining after the loop completes") + ANDQ(operand.Imm(loopSize-1), n) + Comment(fmt.Sprintf("Number of %d byte iterations", loopSize)) + SHRQ(operand.Imm(uint64(bits.Len64(loopSize-1))), loopCount) + JZ(operand.LabelRef("smallLoop")) + Label("bigLoop") + for i := 0; i < unroll; i++ { + chain := i % chains + switch strategy { + case sse2: + sse2AccumulateStep(i*simdReadSize, buf, zero, simdAccumulate[chain], simdAccumulate[(chain+chains/2)%chains]) + case avx2: + avx2AccumulateStep(i*simdReadSize, buf, simdAccumulate[chain]) + } + } + ADDQ(operand.U32(loopSize), buf) + DECQ(loopCount) + JNZ(operand.LabelRef("bigLoop")) + + Label("bigCleanup") + + CMPQ(n, operand.Imm(uint64(simdReadSize))) + JLT(operand.LabelRef("doneSmallLoop")) + + Commentf("now read a single %d byte unit of data at a time", simdReadSize) + Label("smallLoop") + + switch strategy { + case sse2: + sse2AccumulateStep(0, buf, zero, simdAccumulate[0], simdAccumulate[1]) + case avx2: + avx2AccumulateStep(0, buf, simdAccumulate[0]) + } + ADDQ(operand.Imm(uint64(simdReadSize)), buf) + SUBQ(operand.Imm(uint64(simdReadSize)), n) + CMPQ(n, operand.Imm(uint64(simdReadSize))) + JGE(operand.LabelRef("smallLoop")) + + Label("doneSmallLoop") + CMPQ(n, operand.Imm(0)) + JE(operand.LabelRef("doneSIMD")) + + Commentf("There are between 1 and %d bytes remaining. 
Perform an overlapped read.", simdReadSize-1) + + maskDataPtr := GP64() + LEAQ(operand.NewDataAddr(operand.NewStaticSymbol("xmmLoadMasks"), 0), maskDataPtr) + dataAddr := operand.Mem{Index: n, Scale: 1, Base: buf, Disp: -simdReadSize} + // scale 8 is only correct here because n is guaranteed to be even and we + // do not generate masks for odd lengths + maskAddr := operand.Mem{Base: maskDataPtr, Index: n, Scale: 8, Disp: -16} + remainder := XMM() + + switch strategy { + case sse2: + MOVOU(dataAddr, remainder) + PAND(maskAddr, remainder) + low := XMM() + MOVOA(remainder, low) + PUNPCKHWL(zero, remainder) + PUNPCKLWL(zero, low) + PADDD(remainder, simdAccumulate[0]) + PADDD(low, simdAccumulate[1]) + case avx2: + // Note: this is very similar to the sse2 path but MOVOU has a massive + // performance hit if used here, presumably due to switching between SSE + // and AVX2 modes. + VMOVDQU(dataAddr, remainder) + VPAND(maskAddr, remainder, remainder) + + temp := YMM() + VPMOVZXWD(remainder, temp) + VPADDD(temp, simdAccumulate[0], simdAccumulate[0]) + } + + Label("doneSIMD") + + Comment("Multi-chain loop is done, combine the accumulators") + for i := range simdAccumulate { + if i == 0 { + continue + } + switch strategy { + case sse2: + PADDD(simdAccumulate[i], simdAccumulate[0]) + case avx2: + VPADDD(simdAccumulate[i], simdAccumulate[0], simdAccumulate[0]) + } + } + + if strategy == avx2 { + Comment("extract the YMM into a pair of XMM and sum them") + tmp := YMM() + VEXTRACTI128(operand.Imm(1), simdAccumulate[0], tmp.AsX()) + + xAccumulate := XMM() + VPADDD(simdAccumulate[0].AsX(), tmp.AsX(), xAccumulate) + simdAccumulate = []reg.VecVirtual{xAccumulate} + } + + Comment("extract the XMM into GP64") + low, high := GP64(), GP64() + switch strategy { + case sse2: + MOVQ(simdAccumulate[0], low) + PSRLDQ(operand.Imm(8), simdAccumulate[0]) + MOVQ(simdAccumulate[0], high) + case avx2: + VPEXTRQ(operand.Imm(0), simdAccumulate[0], low) + VPEXTRQ(operand.Imm(1), simdAccumulate[0], high) + + Comment("no more AVX code, clear upper registers to avoid SSE slowdowns") + VZEROUPPER() + } + ADDQ(low, accum64) + ADCQ(high, accum64) + Label("foldAndReturn") + foldWithCF(accum64, strategy == avx2) + XCHGB(accum64.As8H(), accum64.As8L()) + Store(accum64.As16(), ReturnIndex(0)) + RET() +} + +// handleOddLength generates instructions to incorporate the last byte into +// accum64 if the length is odd. CF may be set if accum64 overflows; be sure to +// handle that if overflow is possible. 
+func handleOddLength(n, buf, accum64 reg.GPVirtual) {
+	Comment("handle odd length buffers; they are difficult to handle in general")
+	TESTQ(operand.U32(1), n)
+	JZ(operand.LabelRef("lengthIsEven"))
+
+	tmp := GP64()
+	MOVBQZX(operand.Mem{Base: buf, Index: n, Scale: 1, Disp: -1}, tmp)
+	DECQ(n)
+	ADDQ(tmp, accum64)
+
+	Label("lengthIsEven")
+}
+
+func sse2AccumulateStep(offset int, buf reg.GPVirtual, zero, accumulate1, accumulate2 reg.VecVirtual) {
+	high, low := XMM(), XMM()
+	MOVOU(operand.Mem{Disp: offset, Base: buf}, high)
+	MOVOA(high, low)
+	PUNPCKHWL(zero, high)
+	PUNPCKLWL(zero, low)
+	PADDD(high, accumulate1)
+	PADDD(low, accumulate2)
+}
+
+func avx2AccumulateStep(offset int, buf reg.GPVirtual, accumulate reg.VecVirtual) {
+	tmp := YMM()
+	VPMOVZXWD(operand.Mem{Disp: offset, Base: buf}, tmp)
+	VPADDD(tmp, accumulate, accumulate)
+}
+
+func generateAMD64Checksum(name, doc string) {
+	TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
+	Pragma("noescape")
+	Doc(doc)
+
+	accum64, buf, n := loadParams()
+
+	handleOddLength(n, buf, accum64)
+	// no chance of overflow because accum64 was initialized by a uint16 and
+	// handleOddLength adds at most a uint8
+	handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
+	Label("bufferIsNotTiny")
+
+	const (
+		// numChains is the number of accumulators and carry counters to use.
+		// This improves speed via reduced data dependency. We combine the
+		// accumulators and carry counters once when the loop is complete.
+		numChains = 4
+		unroll    = 32         // The number of 64-bit reads to perform per iteration of the loop.
+		loopSize  = 8 * unroll // The number of bytes read per iteration of the loop.
+	)
+	if bits.Len(loopSize) != bits.Len(loopSize-1)+1 {
+		panic("loopSize is not a power of 2")
+	}
+	loopCount := GP64()
+
+	Comment(fmt.Sprintf("Number of %d byte iterations into loop counter", loopSize))
+	MOVQ(n, loopCount)
+	Comment("Update number of bytes remaining after the loop completes")
+	ANDQ(operand.Imm(loopSize-1), n)
+	SHRQ(operand.Imm(uint64(bits.Len(loopSize-1))), loopCount)
+	JZ(operand.LabelRef("startCleanup"))
+	CLC()
+
+	chains := make([]struct {
+		accum   reg.GPVirtual
+		carries reg.GPVirtual
+	}, numChains)
+	for i := range chains {
+		if i == 0 {
+			chains[i].accum = accum64
+		} else {
+			chains[i].accum = GP64()
+			XORQ(chains[i].accum, chains[i].accum)
+		}
+		chains[i].carries = GP64()
+		XORQ(chains[i].carries, chains[i].carries)
+	}
+
+	Label("bigLoop")
+
+	var curChain int
+	for i := 0; i < unroll; i++ {
+		// It is significantly faster to use an ADCX/ADOX pair instead of plain
+		// ADC, which results in two dependency chains; however, those
+		// instructions require ADX support, which was added after AVX2. If
+		// AVX2 is available, that's even better than ADCX/ADOX.
+		//
+		// However, multiple dependency chains using multiple accumulators and
+		// occasionally storing CF into temporary counters seem to work almost
+		// as well.
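+		//
+		// With unroll = 32 and numChains = 4, one pass of the generated loop
+		// therefore begins roughly like this (register names are illustrative;
+		// avo assigns the real ones):
+		//
+		//	ADDQ 0(buf), acc0    // fresh carry chain, ignores incoming CF
+		//	ADCQ 8(buf), acc0
+		//	ADCQ 16(buf), acc0
+		//	ADCQ 24(buf), acc0
+		//	ADCQ $0, carries0    // bank the chain's final carry
+		//	ADDQ 32(buf), acc1
+		//	ADCQ 40(buf), acc1
+		//	...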
+		addr := operand.Mem{Disp: i * 8, Base: buf}
+
+		if i%4 == 0 {
+			if i > 0 {
+				ADCQ(operand.Imm(0), chains[curChain].carries)
+				curChain = (curChain + 1) % len(chains)
+			}
+			ADDQ(addr, chains[curChain].accum)
+		} else {
+			ADCQ(addr, chains[curChain].accum)
+		}
+	}
+	ADCQ(operand.Imm(0), chains[curChain].carries)
+	ADDQ(operand.U32(loopSize), buf)
+	SUBQ(operand.Imm(1), loopCount)
+	JNZ(operand.LabelRef("bigLoop"))
+	for i := range chains {
+		if i == 0 {
+			ADDQ(chains[i].carries, accum64)
+			continue
+		}
+		ADCQ(chains[i].accum, accum64)
+		ADCQ(chains[i].carries, accum64)
+	}
+
+	accumulateCF(accum64)
+
+	Label("startCleanup")
+	handleRemaining(n, buf, accum64, loopSize-1)
+	Label("foldAndReturn")
+	foldWithCF(accum64, false)
+
+	XCHGB(accum64.As8H(), accum64.As8L())
+	Store(accum64.As16(), ReturnIndex(0))
+	RET()
+}
+
+// handleTinyBuffers computes checksums if the buffer length (the n parameter)
+// is less than 32. After computing the checksum, a jump to returnLabel will
+// be executed. Otherwise, if the buffer length is at least 32, nothing will be
+// modified; a jump to continueLabel will be executed instead.
+//
+// When jumping to returnLabel, CF may be set and must be accommodated e.g.
+// using foldWithCF or accumulateCF.
+//
+// Anecdotally, this appears to be faster than attempting to coordinate an
+// overlapped read (which would also require special handling for buffers
+// smaller than 8).
+func handleTinyBuffers(n, buf, accum reg.GPVirtual, returnLabel, continueLabel operand.LabelRef) {
+	Comment("handle tiny buffers (<=31 bytes) specially")
+	CMPQ(n, operand.Imm(tinyBufferSize))
+	JGT(continueLabel)
+
+	tmp2, tmp4, tmp8 := GP64(), GP64(), GP64()
+	XORQ(tmp2, tmp2)
+	XORQ(tmp4, tmp4)
+	XORQ(tmp8, tmp8)
+
+	Comment("shift twice to start because length is guaranteed to be even",
+		"n = n >> 2; CF = originalN & 2")
+	SHRQ(operand.Imm(2), n)
+	JNC(operand.LabelRef("handleTiny4"))
+	Comment("tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]")
+	MOVWQZX(operand.Mem{Base: buf}, tmp2)
+	ADDQ(operand.Imm(2), buf)
+
+	Label("handleTiny4")
+	Comment("n = n >> 1; CF = originalN & 4")
+	SHRQ(operand.Imm(1), n)
+	JNC(operand.LabelRef("handleTiny8"))
+	Comment("tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]")
+	MOVLQZX(operand.Mem{Base: buf}, tmp4)
+	ADDQ(operand.Imm(4), buf)
+
+	Label("handleTiny8")
+	Comment("n = n >> 1; CF = originalN & 8")
+	SHRQ(operand.Imm(1), n)
+	JNC(operand.LabelRef("handleTiny16"))
+	Comment("tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]")
+	MOVQ(operand.Mem{Base: buf}, tmp8)
+	ADDQ(operand.Imm(8), buf)
+
+	Label("handleTiny16")
+	Comment("n = n >> 1; CF = originalN & 16",
+		"n == 0 now, otherwise we would have branched after comparing with tinyBufferSize")
+	SHRQ(operand.Imm(1), n)
+	JNC(operand.LabelRef("handleTinyFinish"))
+	ADDQ(operand.Mem{Base: buf}, accum)
+	ADCQ(operand.Mem{Base: buf, Disp: 8}, accum)
+
+	Label("handleTinyFinish")
+	Comment("CF should be included from the previous add, so we use ADCQ.",
+		"If we arrived via the JNC above, then CF=0 due to the branch condition,",
+		"so ADCQ will still produce the correct result.")
+	ADCQ(tmp2, accum)
+	ADCQ(tmp4, accum)
+	ADCQ(tmp8, accum)
+
+	JMP(returnLabel)
+}
+
+// handleRemaining generates a series of conditional unrolled additions,
+// starting with 8 bytes long and doubling each time until the length reaches
+// max. This is the reverse order of what may be intuitive, but makes the branch
+// conditions convenient to compute: perform one right shift each time and test
+// against CF.
+//
+// When done, CF may be set and must be accommodated e.g., using foldWithCF or
+// accumulateCF.
+//
+// If n is not a multiple of 8, an extra 64 bit read at the end of the buffer
+// will be performed, overlapping with data that will be read later. The
+// duplicate data will be shifted off.
+//
+// The original buffer length must have been at least 8 bytes long, even if
+// n < 8, otherwise this will access memory before the start of the buffer,
+// which may be unsafe.
+func handleRemaining(n, buf, accum64 reg.GPVirtual, max int) {
+	Comment("Accumulate carries in this register. It is never expected to overflow.")
+	carries := GP64()
+	XORQ(carries, carries)
+
+	Comment("We will perform an overlapped read for buffers with length not a multiple of 8.",
+		"Overlapped in this context means some memory will be read twice, but a shift will",
+		"eliminate the duplicated data. This extra read is performed at the end of the buffer to",
+		"preserve any alignment that may exist for the start of the buffer.")
+	leftover := reg.RCX
+	MOVQ(n, leftover)
+	SHRQ(operand.Imm(3), n)          // n is now the number of 64 bit reads remaining
+	ANDQ(operand.Imm(0x7), leftover) // leftover is now the number of bytes to read from the end
+	JZ(operand.LabelRef("handleRemaining8"))
+	endBuf := GP64()
+	// endBuf is the position near the end of the buffer that is just past the
+	// last multiple of 8: (buf + len(buf)) & ^0x7
+	LEAQ(operand.Mem{Base: buf, Index: n, Scale: 8}, endBuf)
+
+	overlapRead := GP64()
+	// equivalent to overlapRead = binary.LittleEndian.Uint64(buf[len(buf)-8:len(buf)])
+	MOVQ(operand.Mem{Base: endBuf, Index: leftover, Scale: 1, Disp: -8}, overlapRead)
+
+	Comment("Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)")
+	SHLQ(operand.Imm(3), leftover)  // leftover = leftover * 8
+	NEGQ(leftover)                  // leftover = -leftover; this completes the (-leftoverBytes*8) part of the expression
+	ADDQ(operand.Imm(64), leftover) // now we have (64 - leftoverBytes*8)
+	SHRQ(reg.CL, overlapRead)       // shift right by (64 - leftoverBytes*8); CL is the low 8 bits of leftover (set to RCX above) and variable shift only accepts CL
+
+	ADDQ(overlapRead, accum64)
+	ADCQ(operand.Imm(0), carries)
+
+	for curBytes := 8; curBytes <= max; curBytes *= 2 {
+		Label(fmt.Sprintf("handleRemaining%d", curBytes))
+		SHRQ(operand.Imm(1), n)
+		if curBytes*2 <= max {
+			JNC(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
+		} else {
+			JNC(operand.LabelRef("handleRemainingComplete"))
+		}
+
+		numLoads := curBytes / 8
+		for i := 0; i < numLoads; i++ {
+			addr := operand.Mem{Base: buf, Disp: i * 8}
+			// It is possible to use the multiple dependency chain trick here
+			// that generateAMD64Checksum uses, but anecdotally the benefit
+			// does not appear to outweigh the cost.
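+			//
+			// The first load uses ADDQ rather than ADCQ so the CF left over
+			// from the SHRQ length test above is not folded into the sum; the
+			// carry out of the final ADCQ in this block is banked into
+			// carries just below.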
+ if i == 0 { + ADDQ(addr, accum64) + continue + } + ADCQ(addr, accum64) + } + ADCQ(operand.Imm(0), carries) + + if curBytes > math.MaxUint8 { + ADDQ(operand.U32(uint64(curBytes)), buf) + } else { + ADDQ(operand.U8(uint64(curBytes)), buf) + } + if curBytes*2 >= max { + continue + } + JMP(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2))) + } + Label("handleRemainingComplete") + ADDQ(carries, accum64) +} + +func accumulateCF(accum64 reg.GPVirtual) { + Comment("accumulate CF (twice, in case the first time overflows)") + // accum64 += CF + ADCQ(operand.Imm(0), accum64) + // accum64 += CF again if the previous add overflowed. The previous add was + // 0 or 1. If it overflowed, then accum64 == 0, so adding another 1 can + // never overflow. + ADCQ(operand.Imm(0), accum64) +} + +// foldWithCF generates instructions to fold accum (a GP64) into a 16-bit value +// according to ones-complement arithmetic. BMI2 instructions will be used if +// allowBMI2 is true (requires fewer instructions). +func foldWithCF(accum reg.GPVirtual, allowBMI2 bool) { + Comment("add CF and fold") + + // CF|accum max value starts as 0x1_ffff_ffff_ffff_ffff + + tmp := GP64() + if allowBMI2 { + // effectively, tmp = accum >> 32 (technically, this is a rotate) + RORXQ(operand.Imm(32), accum, tmp) + // accum as uint32 = uint32(accum) + uint32(tmp64) + CF; max value 0xffff_ffff + CF set + ADCL(tmp.As32(), accum.As32()) + // effectively, tmp64 as uint32 = uint32(accum) >> 16 (also a rotate) + RORXL(operand.Imm(16), accum.As32(), tmp.As32()) + // accum as uint16 = uint16(accum) + uint16(tmp) + CF; max value 0xffff + CF unset or 0xfffe + CF set + ADCW(tmp.As16(), accum.As16()) + } else { + // tmp = uint32(accum); max value 0xffff_ffff + // MOVL clears the upper 32 bits of a GP64 so this is equivalent to the + // non-existent MOVLQZX. 
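+		//
+		// Taken together, the sequence below performs the usual ones-complement
+		// fold of a 64 bit sum down to 16 bits, roughly (in Go terms, ignoring
+		// the carry handling done with ADC):
+		//
+		//	accum = (accum >> 32) + (accum & 0xffff_ffff)
+		//	accum = (accum >> 16) + (accum & 0xffff)
+		//	accum = (accum >> 16) + (accum & 0xffff)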
+ MOVL(accum.As32(), tmp.As32()) + // tmp += CF; max value 0x1_0000_0000, CF unset + ADCQ(operand.Imm(0), tmp) + // accum = accum >> 32; max value 0xffff_ffff + SHRQ(operand.Imm(32), accum) + // accum = accum + tmp; max value 0x1_ffff_ffff + CF unset + ADDQ(tmp, accum) + // tmp = uint16(accum); max value 0xffff + MOVWQZX(accum.As16(), tmp) + // accum = accum >> 16; max value 0x1_ffff + SHRQ(operand.Imm(16), accum) + // accum = accum + tmp; max value 0x2_fffe + CF unset + ADDQ(tmp, accum) + // tmp as uint16 = uint16(accum); max value 0xffff + MOVW(accum.As16(), tmp.As16()) + // accum = accum >> 16; max value 0x2 + SHRQ(operand.Imm(16), accum) + // accum as uint16 = uint16(accum) + uint16(tmp); max value 0xffff + CF unset or 0x2 + CF set + ADDW(tmp.As16(), accum.As16()) + } + // accum as uint16 += CF; will not overflow: either CF was 0 or accum <= 0xfffe + ADCW(operand.Imm(0), accum.As16()) +} + +func generateLoadMasks() { + var offset int + // xmmLoadMasks is a table of masks that can be used with PAND to zero all but the last N bytes in an XMM, N=2,4,6,8,10,12,14 + GLOBL("xmmLoadMasks", RODATA|NOPTR) + + for n := 2; n < 16; n += 2 { + var pattern [16]byte + for i := 0; i < len(pattern); i++ { + if i < len(pattern)-n { + pattern[i] = 0 + continue + } + pattern[i] = 0xff + } + DATA(offset, operand.String(pattern[:])) + offset += len(pattern) + } +} + +func main() { + generateLoadMasks() + generateSIMDChecksum("checksumAVX2", "checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)", 256, 4, avx2) + generateSIMDChecksum("checksumSSE2", "checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)", 256, 4, sse2) + generateAMD64Checksum("checksumAMD64", "checksumAMD64 computes an IP checksum using amd64 baseline instructions") + Generate() +}