Add tailscale checksum

世界 2024-11-22 15:45:38 +08:00
parent 4ebeb2fa86
commit 06b4d4ecd1
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 2254 additions and 28 deletions


@@ -0,0 +1,33 @@
package checksum_test
import (
"crypto/rand"
"testing"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/tschecksum"
)
func BenchmarkTsChecksum(b *testing.B) {
packet := make([][]byte, 1000)
for i := 0; i < 1000; i++ {
packet[i] = make([]byte, 1500)
rand.Read(packet[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tschecksum.Checksum(packet[i%1000], 0)
}
}
func BenchmarkGChecksum(b *testing.B) {
packet := make([][]byte, 1000)
for i := 0; i < 1000; i++ {
packet[i] = make([]byte, 1500)
rand.Read(packet[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
checksum.Checksum(packet[i%1000], 0)
}
}
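
The benchmarks above measure speed only. A correctness companion comparing tschecksum against a plain RFC 1071 reference sum is sketched below; it is illustrative rather than part of this commit, and referenceChecksum is a name introduced here. It exercises odd lengths as well as the tiny-buffer path.

package checksum_test

import (
	"crypto/rand"
	"encoding/binary"
	"testing"

	"github.com/sagernet/sing-tun/internal/tschecksum"
)

// referenceChecksum is a straightforward RFC 1071 sum: big-endian 16-bit words,
// end-around carry, and an odd trailing byte padded with a zero low byte.
func referenceChecksum(b []byte, initial uint16) uint16 {
	sum := uint32(initial)
	for len(b) >= 2 {
		sum += uint32(binary.BigEndian.Uint16(b))
		b = b[2:]
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8
	}
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return uint16(sum)
}

func TestChecksumMatchesReference(t *testing.T) {
	for size := 0; size <= 256; size++ {
		buf := make([]byte, size)
		rand.Read(buf)
		want := referenceChecksum(buf, 0)
		if got := tschecksum.Checksum(buf, 0); got != want {
			t.Fatalf("length %d: got %#04x, want %#04x", size, got, want)
		}
	}
}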


@@ -30,34 +30,6 @@ func Put(b []byte, xsum uint16) {
binary.BigEndian.PutUint16(b, xsum)
}
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
// given byte array. This function uses an optimized version of the checksum
// algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
s, _ := calculateChecksum(buf, false, initial)
return s
}
// Checksumer calculates checksum defined in RFC 1071.
type Checksumer struct {
sum uint16
odd bool
}
// Add adds b to checksum.
func (c *Checksumer) Add(b []byte) {
if len(b) > 0 {
c.sum, c.odd = calculateChecksum(b, c.odd, c.sum)
}
}
// Checksum returns the latest checksum value.
func (c *Checksumer) Checksum() uint16 {
return c.sum
}
// Combine combines the two uint16 to form their checksum. This is done
// by adding them and the carry.
//


@@ -0,0 +1,13 @@
//go:build !amd64
package checksum
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
// given byte array. This function uses an optimized version of the checksum
// algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
s, _ := calculateChecksum(buf, false, initial)
return s
}


@@ -0,0 +1,9 @@
//go:build amd64
package checksum
import "github.com/sagernet/sing-tun/internal/tschecksum"
func Checksum(buf []byte, initial uint16) uint16 {
return tschecksum.Checksum(buf, initial)
}
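
With this wrapper, existing callers of the gtcpip checksum package pick up the accelerated implementation transparently on amd64. As a usage sketch of the surviving API (the helper name is hypothetical, and the example assumes code inside sing-tun because the package is internal), an ICMPv4 checksum is the one's complement of the sum over the message with its checksum field zeroed:

package checksum_example

import "github.com/sagernet/sing-tun/internal/gtcpip/checksum"

// fillICMPChecksum computes and stores the checksum of an ICMPv4 message.
func fillICMPChecksum(icmp []byte) {
	icmp[2], icmp[3] = 0, 0 // the checksum field must be zero while summing
	checksum.Put(icmp[2:4], ^checksum.Checksum(icmp, 0))
}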


@@ -1,3 +1,5 @@
//go:build !amd64
// Copyright 2023 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");


@@ -0,0 +1,712 @@
package tschecksum
import (
"encoding/binary"
"math/bits"
"strconv"
"golang.org/x/sys/cpu"
)
// checksumGeneric64 is a reference implementation of checksum using 64 bit
// arithmetic for use in testing or when an architecture-specific implementation
// is not available.
func checksumGeneric64(b []byte, initial uint16) uint16 {
var ac uint64
var carry uint64
if cpu.IsBigEndian {
ac = uint64(initial)
} else {
ac = uint64(bits.ReverseBytes16(initial))
}
for len(b) >= 128 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry)
}
b = b[128:]
}
if len(b) >= 64 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry)
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry)
} else {
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry)
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry)
} else {
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry)
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry)
} else {
ac, carry = bits.Add64(ac, uint64(b[0]), carry)
}
}
folded := ipChecksumFold64(ac, carry)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric32 is a reference implementation of checksum using 32 bit
// arithmetic for use in testing or when an architecture-specific implementation
// is not available.
func checksumGeneric32(b []byte, initial uint16) uint16 {
var ac uint32
var carry uint32
if cpu.IsBigEndian {
ac = uint32(initial)
} else {
ac = uint32(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry)
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry)
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry)
} else {
ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry)
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry)
} else {
ac, carry = bits.Add32(ac, uint32(b[0]), carry)
}
}
folded := ipChecksumFold32(ac, carry)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric32Alternate is an alternate reference implementation of
// checksum using 32 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric32Alternate(b []byte, initial uint16) uint16 {
var ac uint32
if cpu.IsBigEndian {
ac = uint32(initial)
} else {
ac = uint32(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
ac += uint32(binary.BigEndian.Uint16(b[32:34]))
ac += uint32(binary.BigEndian.Uint16(b[34:36]))
ac += uint32(binary.BigEndian.Uint16(b[36:38]))
ac += uint32(binary.BigEndian.Uint16(b[38:40]))
ac += uint32(binary.BigEndian.Uint16(b[40:42]))
ac += uint32(binary.BigEndian.Uint16(b[42:44]))
ac += uint32(binary.BigEndian.Uint16(b[44:46]))
ac += uint32(binary.BigEndian.Uint16(b[46:48]))
ac += uint32(binary.BigEndian.Uint16(b[48:50]))
ac += uint32(binary.BigEndian.Uint16(b[50:52]))
ac += uint32(binary.BigEndian.Uint16(b[52:54]))
ac += uint32(binary.BigEndian.Uint16(b[54:56]))
ac += uint32(binary.BigEndian.Uint16(b[56:58]))
ac += uint32(binary.BigEndian.Uint16(b[58:60]))
ac += uint32(binary.BigEndian.Uint16(b[60:62]))
ac += uint32(binary.BigEndian.Uint16(b[62:64]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
ac += uint32(binary.LittleEndian.Uint16(b[32:34]))
ac += uint32(binary.LittleEndian.Uint16(b[34:36]))
ac += uint32(binary.LittleEndian.Uint16(b[36:38]))
ac += uint32(binary.LittleEndian.Uint16(b[38:40]))
ac += uint32(binary.LittleEndian.Uint16(b[40:42]))
ac += uint32(binary.LittleEndian.Uint16(b[42:44]))
ac += uint32(binary.LittleEndian.Uint16(b[44:46]))
ac += uint32(binary.LittleEndian.Uint16(b[46:48]))
ac += uint32(binary.LittleEndian.Uint16(b[48:50]))
ac += uint32(binary.LittleEndian.Uint16(b[50:52]))
ac += uint32(binary.LittleEndian.Uint16(b[52:54]))
ac += uint32(binary.LittleEndian.Uint16(b[54:56]))
ac += uint32(binary.LittleEndian.Uint16(b[56:58]))
ac += uint32(binary.LittleEndian.Uint16(b[58:60]))
ac += uint32(binary.LittleEndian.Uint16(b[60:62]))
ac += uint32(binary.LittleEndian.Uint16(b[62:64]))
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b))
} else {
ac += uint32(binary.LittleEndian.Uint16(b))
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac += uint32(b[0]) << 8
} else {
ac += uint32(b[0])
}
}
folded := ipChecksumFold32(ac, 0)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric64Alternate is an alternate reference implementation of
// checksum using 64 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric64Alternate(b []byte, initial uint16) uint16 {
var ac uint64
if cpu.IsBigEndian {
ac = uint64(initial)
} else {
ac = uint64(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
ac += uint64(binary.LittleEndian.Uint32(b[32:36]))
ac += uint64(binary.LittleEndian.Uint32(b[36:40]))
ac += uint64(binary.LittleEndian.Uint32(b[40:44]))
ac += uint64(binary.LittleEndian.Uint32(b[44:48]))
ac += uint64(binary.LittleEndian.Uint32(b[48:52]))
ac += uint64(binary.LittleEndian.Uint32(b[52:56]))
ac += uint64(binary.LittleEndian.Uint32(b[56:60]))
ac += uint64(binary.LittleEndian.Uint32(b[60:64]))
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b))
} else {
ac += uint64(binary.LittleEndian.Uint32(b))
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint16(b))
} else {
ac += uint64(binary.LittleEndian.Uint16(b))
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac += uint64(b[0]) << 8
} else {
ac += uint64(b[0])
}
}
folded := ipChecksumFold64(ac, 0)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 {
sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry))
// if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff
// therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is
// no need to save the carry flag
sum = (sum >> 16) + (sum & 0xffff) + carry
// sum <= 0x1_fffe therefore this is the last fold needed:
// if (sum >> 16) > 0 then
// (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore
// the addition will not overflow
// otherwise (sum >> 16) == 0 and sum will be unchanged
sum = (sum >> 16) + (sum & 0xffff)
return uint16(sum)
}
func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 {
sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry
// sum <= 0x1_ffff:
// 0xffff + 0xffff = 0x1_fffe
// initialCarry is 0 or 1, for a combined maximum of 0x1_ffff
sum = (sum >> 16) + (sum & 0xffff)
// sum <= 0x1_0000 therefore this is the last fold needed:
// if (sum >> 16) > 0 then
// (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore
// the addition will not overflow
// otherwise (sum >> 16) == 0 and sum will be unchanged
sum = (sum >> 16) + (sum & 0xffff)
return uint16(sum)
}
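// For illustration, the worst case permitted by the bounds above is
// unfolded = 0xffff_ffff with initialCarry = 1:
//	0xffff + 0xffff + 1 = 0x1_ffff (initial combine)
//	0x0001 + 0xffff     = 0x1_0000 (first fold)
//	0x0001 + 0x0000     = 0x0001   (second fold)
// which is why two folds follow the initial combine.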
func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) {
sum, carry = initial, carryIn
switch len(addr) {
case 4: // IPv4
if cpu.IsBigEndian {
sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry)
} else {
sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry)
}
case 16: // IPv6
if cpu.IsBigEndian {
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry)
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry)
} else {
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry)
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry)
}
default:
panic("bad addr length")
}
return sum, carry
}
func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) {
sum, carry = initial, carryIn
switch len(addr) {
case 4: // IPv4
if cpu.IsBigEndian {
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
} else {
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
}
case 16: // IPv6
if cpu.IsBigEndian {
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry)
} else {
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry)
}
default:
panic("bad addr length")
}
return sum, carry
}
func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
var sum uint64
if cpu.IsBigEndian {
sum = uint64(totalLen) + uint64(protocol)
} else {
sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8
}
sum, carry := addrPartialChecksum64(srcAddr, sum, 0)
sum, carry = addrPartialChecksum64(dstAddr, sum, carry)
foldedSum := ipChecksumFold64(sum, carry)
if !cpu.IsBigEndian {
foldedSum = bits.ReverseBytes16(foldedSum)
}
return foldedSum
}
func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
var sum uint32
if cpu.IsBigEndian {
sum = uint32(totalLen) + uint32(protocol)
} else {
sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8
}
sum, carry := addrPartialChecksum32(srcAddr, sum, 0)
sum, carry = addrPartialChecksum32(dstAddr, sum, carry)
foldedSum := ipChecksumFold32(sum, carry)
if !cpu.IsBigEndian {
foldedSum = bits.ReverseBytes16(foldedSum)
}
return foldedSum
}
// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and
// dstAddr must be 4 or 16 bytes in length.
func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
if strconv.IntSize < 64 {
return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen)
}
return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen)
}
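
Checksum and PseudoHeaderChecksum compose: the folded pseudo-header sum is passed to Checksum as the initial value. A minimal sketch of a transport checksum built this way follows; the helper name and the fixed 20-byte IPv4 header are assumptions made here for illustration, and the package is internal to sing-tun.

package tschecksum_example

import (
	"encoding/binary"

	"github.com/sagernet/sing-tun/internal/tschecksum"
)

// fillUDPChecksum computes the UDP checksum for an IPv4 packet with a
// 20-byte header (no options). Protocol 17 is UDP; the checksum field is
// zeroed before summing, and an all-zero result is sent as 0xffff per RFC 768.
func fillUDPChecksum(packet []byte) {
	udp := packet[20:]
	udp[6], udp[7] = 0, 0
	pseudo := tschecksum.PseudoHeaderChecksum(17, packet[12:16], packet[16:20], uint16(len(udp)))
	sum := ^tschecksum.Checksum(udp, pseudo)
	if sum == 0 {
		sum = 0xffff
	}
	binary.BigEndian.PutUint16(udp[6:8], sum)
}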


@@ -0,0 +1,23 @@
package tschecksum
import "golang.org/x/sys/cpu"
var checksum = checksumAMD64
// Checksum computes an IP checksum starting with the provided initial value.
// The length of data should be at least 128 bytes for best performance. Smaller
// buffers will still compute a correct result.
func Checksum(data []byte, initial uint16) uint16 {
return checksum(data, initial)
}
func init() {
if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 {
checksum = checksumAVX2
return
}
if cpu.X86.HasSSE2 {
checksum = checksumSSE2
return
}
}
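
Whichever function init selects, all three implementations must agree on every input, including odd lengths, the tiny-buffer path, and the overlapped-read cleanup. An illustrative internal consistency test, not part of this commit, could look like the sketch below; the AVX2 case is gated on the same CPU features init checks.

//go:build amd64

package tschecksum

import (
	"crypto/rand"
	"testing"

	"golang.org/x/sys/cpu"
)

func TestAMD64VariantsAgree(t *testing.T) {
	for size := 0; size <= 512; size++ {
		buf := make([]byte, size)
		rand.Read(buf)
		want := checksumAMD64(buf, 0x1234)
		if got := checksumSSE2(buf, 0x1234); got != want {
			t.Fatalf("SSE2 mismatch at length %d: %#04x != %#04x", size, got, want)
		}
		if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 {
			if got := checksumAVX2(buf, 0x1234); got != want {
				t.Fatalf("AVX2 mismatch at length %d: %#04x != %#04x", size, got, want)
			}
		}
	}
}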


@@ -0,0 +1,18 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
package tschecksum
// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)
//
//go:noescape
func checksumAVX2(b []byte, initial uint16) uint16
// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)
//
//go:noescape
func checksumSSE2(b []byte, initial uint16) uint16
// checksumAMD64 computes an IP checksum using amd64 baseline instructions
//
//go:noescape
func checksumAMD64(b []byte, initial uint16) uint16


@@ -0,0 +1,851 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
#include "textflag.h"
DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"
DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112
// func checksumAVX2(b []byte, initial uint16) uint16
// Requires: AVX, AVX2, BMI2
TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// skip all SIMD for small buffers
CMPQ BX, $0x00000100
JGE startSIMD
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
JMP foldAndReturn
startSIMD:
VPXOR Y0, Y0, Y0
VPXOR Y1, Y1, Y1
VPXOR Y2, Y2, Y2
VPXOR Y3, Y3, Y3
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
// Number of 256 byte iterations
SHRQ $0x08, CX
JZ smallLoop
bigLoop:
VPMOVZXWD (DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 16(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 32(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 48(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 64(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 80(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 96(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 112(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 128(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 144(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 160(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 176(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 192(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 208(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 224(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 240(DX), Y4
VPADDD Y4, Y3, Y3
ADDQ $0x00000100, DX
DECQ CX
JNZ bigLoop
CMPQ BX, $0x10
JLT doneSmallLoop
// now read a single 16 byte unit of data at a time
smallLoop:
VPMOVZXWD (DX), Y4
VPADDD Y4, Y0, Y0
ADDQ $0x10, DX
SUBQ $0x10, BX
CMPQ BX, $0x10
JGE smallLoop
doneSmallLoop:
CMPQ BX, $0x00
JE doneSIMD
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
LEAQ xmmLoadMasks<>+0(SB), CX
VMOVDQU -16(DX)(BX*1), X4
VPAND -16(CX)(BX*8), X4, X4
VPMOVZXWD X4, Y4
VPADDD Y4, Y0, Y0
doneSIMD:
// Multi-chain loop is done, combine the accumulators
VPADDD Y1, Y0, Y0
VPADDD Y2, Y0, Y0
VPADDD Y3, Y0, Y0
// extract the YMM into a pair of XMM and sum them
VEXTRACTI128 $0x01, Y0, X1
VPADDD X0, X1, X0
// extract the XMM into GP64
VPEXTRQ $0x00, X0, CX
VPEXTRQ $0x01, X0, DX
// no more AVX code, clear upper registers to avoid SSE slowdowns
VZEROUPPER
ADDQ CX, AX
ADCQ DX, AX
foldAndReturn:
// add CF and fold
RORXQ $0x20, AX, CX
ADCL CX, AX
RORXL $0x10, AX, CX
ADCW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
// func checksumSSE2(b []byte, initial uint16) uint16
// Requires: SSE2
TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// skip all SIMD for small buffers
CMPQ BX, $0x00000100
JGE startSIMD
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
JMP foldAndReturn
startSIMD:
PXOR X0, X0
PXOR X1, X1
PXOR X2, X2
PXOR X3, X3
PXOR X4, X4
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
// Number of 256 byte iterations
SHRQ $0x08, CX
JZ smallLoop
bigLoop:
MOVOU (DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 16(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 32(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 48(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 64(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 80(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 96(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 112(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 128(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 144(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 160(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 176(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 192(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 208(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 224(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 240(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
ADDQ $0x00000100, DX
DECQ CX
JNZ bigLoop
CMPQ BX, $0x10
JLT doneSmallLoop
// now read a single 16 byte unit of data at a time
smallLoop:
MOVOU (DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X1
ADDQ $0x10, DX
SUBQ $0x10, BX
CMPQ BX, $0x10
JGE smallLoop
doneSmallLoop:
CMPQ BX, $0x00
JE doneSIMD
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
LEAQ xmmLoadMasks<>+0(SB), CX
MOVOU -16(DX)(BX*1), X5
PAND -16(CX)(BX*8), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X1
doneSIMD:
// Multi-chain loop is done, combine the accumulators
PADDD X1, X0
PADDD X2, X0
PADDD X3, X0
// extract the XMM into GP64
MOVQ X0, CX
PSRLDQ $0x08, X0
MOVQ X0, DX
ADDQ CX, AX
ADCQ DX, AX
foldAndReturn:
// add CF and fold
MOVL AX, CX
ADCQ $0x00, CX
SHRQ $0x20, AX
ADDQ CX, AX
MOVWQZX AX, CX
SHRQ $0x10, AX
ADDQ CX, AX
MOVW AX, CX
SHRQ $0x10, AX
ADDW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
// func checksumAMD64(b []byte, initial uint16) uint16
TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// Number of 256 byte iterations into loop counter
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
SHRQ $0x08, CX
JZ startCleanup
CLC
XORQ SI, SI
XORQ DI, DI
XORQ R8, R8
XORQ R9, R9
XORQ R10, R10
XORQ R11, R11
XORQ R12, R12
bigLoop:
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ 32(DX), DI
ADCQ 40(DX), DI
ADCQ 48(DX), DI
ADCQ 56(DX), DI
ADCQ $0x00, R8
ADDQ 64(DX), R9
ADCQ 72(DX), R9
ADCQ 80(DX), R9
ADCQ 88(DX), R9
ADCQ $0x00, R10
ADDQ 96(DX), R11
ADCQ 104(DX), R11
ADCQ 112(DX), R11
ADCQ 120(DX), R11
ADCQ $0x00, R12
ADDQ 128(DX), AX
ADCQ 136(DX), AX
ADCQ 144(DX), AX
ADCQ 152(DX), AX
ADCQ $0x00, SI
ADDQ 160(DX), DI
ADCQ 168(DX), DI
ADCQ 176(DX), DI
ADCQ 184(DX), DI
ADCQ $0x00, R8
ADDQ 192(DX), R9
ADCQ 200(DX), R9
ADCQ 208(DX), R9
ADCQ 216(DX), R9
ADCQ $0x00, R10
ADDQ 224(DX), R11
ADCQ 232(DX), R11
ADCQ 240(DX), R11
ADCQ 248(DX), R11
ADCQ $0x00, R12
ADDQ $0x00000100, DX
SUBQ $0x01, CX
JNZ bigLoop
ADDQ SI, AX
ADCQ DI, AX
ADCQ R8, AX
ADCQ R9, AX
ADCQ R10, AX
ADCQ R11, AX
ADCQ R12, AX
// accumulate CF (twice, in case the first time overflows)
ADCQ $0x00, AX
ADCQ $0x00, AX
startCleanup:
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
foldAndReturn:
// add CF and fold
MOVL AX, CX
ADCQ $0x00, CX
SHRQ $0x20, AX
ADDQ CX, AX
MOVWQZX AX, CX
SHRQ $0x10, AX
ADDQ CX, AX
MOVW AX, CX
SHRQ $0x10, AX
ADDW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET


@@ -0,0 +1,15 @@
// This file contains IP checksum algorithms that are not specific to any
// architecture and don't use hardware acceleration.
//go:build !amd64
package tschecksum
import "strconv"
func Checksum(data []byte, initial uint16) uint16 {
if strconv.IntSize < 64 {
return checksumGeneric32(data, initial)
}
return checksumGeneric64(data, initial)
}


@@ -0,0 +1,578 @@
//go:build ignore
//go:generate go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go
package main
import (
"fmt"
"math"
"math/bits"
"github.com/mmcloughlin/avo/operand"
"github.com/mmcloughlin/avo/reg"
)
const checksumSignature = "func(b []byte, initial uint16) uint16"
func loadParams() (accum, buf, n reg.GPVirtual) {
accum, buf, n = GP64(), GP64(), GP64()
Load(Param("initial"), accum)
XCHGB(accum.As8H(), accum.As8L())
Load(Param("b").Base(), buf)
Load(Param("b").Len(), n)
return
}
type simdStrategy int
const (
sse2 = iota
avx2
)
const tinyBufferSize = 31 // A buffer is tiny if it has at most 31 bytes.
func generateSIMDChecksum(name, doc string, minSIMDSize, chains int, strategy simdStrategy) {
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
Pragma("noescape")
Doc(doc)
accum64, buf, n := loadParams()
handleOddLength(n, buf, accum64)
// no chance of overflow because accum64 was initialized by a uint16 and
// handleOddLength adds at most a uint8
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
Label("bufferIsNotTiny")
const simdReadSize = 16
if minSIMDSize > tinyBufferSize {
Comment("skip all SIMD for small buffers")
if minSIMDSize <= math.MaxUint8 {
CMPQ(n, operand.U8(minSIMDSize))
} else {
CMPQ(n, operand.U32(minSIMDSize))
}
JGE(operand.LabelRef("startSIMD"))
handleRemaining(n, buf, accum64, minSIMDSize-1)
JMP(operand.LabelRef("foldAndReturn"))
}
Label("startSIMD")
// chains is the number of accumulators to use. This improves speed via
// reduced data dependency. We combine the accumulators once when the big
// loop is complete.
simdAccumulate := make([]reg.VecVirtual, chains)
for i := range simdAccumulate {
switch strategy {
case sse2:
simdAccumulate[i] = XMM()
PXOR(simdAccumulate[i], simdAccumulate[i])
case avx2:
simdAccumulate[i] = YMM()
VPXOR(simdAccumulate[i], simdAccumulate[i], simdAccumulate[i])
}
}
var zero reg.VecVirtual
if strategy == sse2 {
zero = XMM()
PXOR(zero, zero)
}
// Number of loads per big loop
const unroll = 16
// Number of bytes
loopSize := uint64(simdReadSize * unroll)
if bits.Len64(loopSize) != bits.Len64(loopSize-1)+1 {
panic("loopSize is not a power of 2")
}
loopCount := GP64()
MOVQ(n, loopCount)
Comment("Update number of bytes remaining after the loop completes")
ANDQ(operand.Imm(loopSize-1), n)
Comment(fmt.Sprintf("Number of %d byte iterations", loopSize))
SHRQ(operand.Imm(uint64(bits.Len64(loopSize-1))), loopCount)
JZ(operand.LabelRef("smallLoop"))
Label("bigLoop")
for i := 0; i < unroll; i++ {
chain := i % chains
switch strategy {
case sse2:
sse2AccumulateStep(i*simdReadSize, buf, zero, simdAccumulate[chain], simdAccumulate[(chain+chains/2)%chains])
case avx2:
avx2AccumulateStep(i*simdReadSize, buf, simdAccumulate[chain])
}
}
ADDQ(operand.U32(loopSize), buf)
DECQ(loopCount)
JNZ(operand.LabelRef("bigLoop"))
Label("bigCleanup")
CMPQ(n, operand.Imm(uint64(simdReadSize)))
JLT(operand.LabelRef("doneSmallLoop"))
Commentf("now read a single %d byte unit of data at a time", simdReadSize)
Label("smallLoop")
switch strategy {
case sse2:
sse2AccumulateStep(0, buf, zero, simdAccumulate[0], simdAccumulate[1])
case avx2:
avx2AccumulateStep(0, buf, simdAccumulate[0])
}
ADDQ(operand.Imm(uint64(simdReadSize)), buf)
SUBQ(operand.Imm(uint64(simdReadSize)), n)
CMPQ(n, operand.Imm(uint64(simdReadSize)))
JGE(operand.LabelRef("smallLoop"))
Label("doneSmallLoop")
CMPQ(n, operand.Imm(0))
JE(operand.LabelRef("doneSIMD"))
Commentf("There are between 1 and %d bytes remaining. Perform an overlapped read.", simdReadSize-1)
maskDataPtr := GP64()
LEAQ(operand.NewDataAddr(operand.NewStaticSymbol("xmmLoadMasks"), 0), maskDataPtr)
dataAddr := operand.Mem{Index: n, Scale: 1, Base: buf, Disp: -simdReadSize}
// scale 8 is only correct here because n is guaranteed to be even and we
// do not generate masks for odd lengths
maskAddr := operand.Mem{Base: maskDataPtr, Index: n, Scale: 8, Disp: -16}
remainder := XMM()
switch strategy {
case sse2:
MOVOU(dataAddr, remainder)
PAND(maskAddr, remainder)
low := XMM()
MOVOA(remainder, low)
PUNPCKHWL(zero, remainder)
PUNPCKLWL(zero, low)
PADDD(remainder, simdAccumulate[0])
PADDD(low, simdAccumulate[1])
case avx2:
// Note: this is very similar to the sse2 path but MOVOU has a massive
// performance hit if used here, presumably due to switching between SSE
// and AVX2 modes.
VMOVDQU(dataAddr, remainder)
VPAND(maskAddr, remainder, remainder)
temp := YMM()
VPMOVZXWD(remainder, temp)
VPADDD(temp, simdAccumulate[0], simdAccumulate[0])
}
Label("doneSIMD")
Comment("Multi-chain loop is done, combine the accumulators")
for i := range simdAccumulate {
if i == 0 {
continue
}
switch strategy {
case sse2:
PADDD(simdAccumulate[i], simdAccumulate[0])
case avx2:
VPADDD(simdAccumulate[i], simdAccumulate[0], simdAccumulate[0])
}
}
if strategy == avx2 {
Comment("extract the YMM into a pair of XMM and sum them")
tmp := YMM()
VEXTRACTI128(operand.Imm(1), simdAccumulate[0], tmp.AsX())
xAccumulate := XMM()
VPADDD(simdAccumulate[0].AsX(), tmp.AsX(), xAccumulate)
simdAccumulate = []reg.VecVirtual{xAccumulate}
}
Comment("extract the XMM into GP64")
low, high := GP64(), GP64()
switch strategy {
case sse2:
MOVQ(simdAccumulate[0], low)
PSRLDQ(operand.Imm(8), simdAccumulate[0])
MOVQ(simdAccumulate[0], high)
case avx2:
VPEXTRQ(operand.Imm(0), simdAccumulate[0], low)
VPEXTRQ(operand.Imm(1), simdAccumulate[0], high)
Comment("no more AVX code, clear upper registers to avoid SSE slowdowns")
VZEROUPPER()
}
ADDQ(low, accum64)
ADCQ(high, accum64)
Label("foldAndReturn")
foldWithCF(accum64, strategy == avx2)
XCHGB(accum64.As8H(), accum64.As8L())
Store(accum64.As16(), ReturnIndex(0))
RET()
}
// handleOddLength generates instructions to incorporate the last byte into
// accum64 if the length is odd. CF may be set if accum64 overflows; be sure to
// handle that if overflow is possible.
func handleOddLength(n, buf, accum64 reg.GPVirtual) {
Comment("handle odd length buffers; they are difficult to handle in general")
TESTQ(operand.U32(1), n)
JZ(operand.LabelRef("lengthIsEven"))
tmp := GP64()
MOVBQZX(operand.Mem{Base: buf, Index: n, Scale: 1, Disp: -1}, tmp)
DECQ(n)
ADDQ(tmp, accum64)
Label("lengthIsEven")
}
func sse2AccumulateStep(offset int, buf reg.GPVirtual, zero, accumulate1, accumulate2 reg.VecVirtual) {
high, low := XMM(), XMM()
MOVOU(operand.Mem{Disp: offset, Base: buf}, high)
MOVOA(high, low)
PUNPCKHWL(zero, high)
PUNPCKLWL(zero, low)
PADDD(high, accumulate1)
PADDD(low, accumulate2)
}
func avx2AccumulateStep(offset int, buf reg.GPVirtual, accumulate reg.VecVirtual) {
tmp := YMM()
VPMOVZXWD(operand.Mem{Disp: offset, Base: buf}, tmp)
VPADDD(tmp, accumulate, accumulate)
}
func generateAMD64Checksum(name, doc string) {
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
Pragma("noescape")
Doc(doc)
accum64, buf, n := loadParams()
handleOddLength(n, buf, accum64)
// no chance of overflow because accum64 was initialized by a uint16 and
// handleOddLength adds at most a uint8
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
Label("bufferIsNotTiny")
const (
// numChains is the number of accumulators and carry counters to use.
// This improves speed via reduced data dependency. We combine the
// accumulators and carry counters once when the loop is complete.
numChains = 4
unroll = 32 // The number of 64-bit reads to perform per iteration of the loop.
loopSize = 8 * unroll // The number of bytes read per iteration of the loop.
)
if bits.Len(loopSize) != bits.Len(loopSize-1)+1 {
panic("loopSize is not a power of 2")
}
loopCount := GP64()
Comment(fmt.Sprintf("Number of %d byte iterations into loop counter", loopSize))
MOVQ(n, loopCount)
Comment("Update number of bytes remaining after the loop completes")
ANDQ(operand.Imm(loopSize-1), n)
SHRQ(operand.Imm(uint64(bits.Len(loopSize-1))), loopCount)
JZ(operand.LabelRef("startCleanup"))
CLC()
chains := make([]struct {
accum reg.GPVirtual
carries reg.GPVirtual
}, numChains)
for i := range chains {
if i == 0 {
chains[i].accum = accum64
} else {
chains[i].accum = GP64()
XORQ(chains[i].accum, chains[i].accum)
}
chains[i].carries = GP64()
XORQ(chains[i].carries, chains[i].carries)
}
Label("bigLoop")
var curChain int
for i := 0; i < unroll; i++ {
// It is significantly faster to use an ADCX/ADOX pair instead of plain
// ADC, since that creates two independent dependency chains; however,
// ADCX/ADOX require ADX support, which was added after AVX2. If AVX2 is
// available, that's even better than ADCX/ADOX.
//
// However, multiple dependency chains using multiple accumulators and
// occasionally storing CF into temporary counters seems to work almost
// as well.
addr := operand.Mem{Disp: i * 8, Base: buf}
if i%4 == 0 {
if i > 0 {
ADCQ(operand.Imm(0), chains[curChain].carries)
curChain = (curChain + 1) % len(chains)
}
ADDQ(addr, chains[curChain].accum)
} else {
ADCQ(addr, chains[curChain].accum)
}
}
ADCQ(operand.Imm(0), chains[curChain].carries)
ADDQ(operand.U32(loopSize), buf)
SUBQ(operand.Imm(1), loopCount)
JNZ(operand.LabelRef("bigLoop"))
for i := range chains {
if i == 0 {
ADDQ(chains[i].carries, accum64)
continue
}
ADCQ(chains[i].accum, accum64)
ADCQ(chains[i].carries, accum64)
}
accumulateCF(accum64)
Label("startCleanup")
handleRemaining(n, buf, accum64, loopSize-1)
Label("foldAndReturn")
foldWithCF(accum64, false)
XCHGB(accum64.As8H(), accum64.As8L())
Store(accum64.As16(), ReturnIndex(0))
RET()
}
// handleTinyBuffers computes checksums if the buffer length (the n parameter)
// is less than 32. After computing the checksum, a jump to returnLabel will
// be executed. Otherwise, if the buffer length is at least 32, nothing will be
// modified; a jump to continueLabel will be executed instead.
//
// When jumping to returnLabel, CF may be set and must be accommodated e.g.
// using foldWithCF or accumulateCF.
//
// Anecdotally, this appears to be faster than attempting to coordinate an
// overlapped read (which would also require special handling for buffers
// smaller than 8).
func handleTinyBuffers(n, buf, accum reg.GPVirtual, returnLabel, continueLabel operand.LabelRef) {
Comment("handle tiny buffers (<=31 bytes) specially")
CMPQ(n, operand.Imm(tinyBufferSize))
JGT(continueLabel)
tmp2, tmp4, tmp8 := GP64(), GP64(), GP64()
XORQ(tmp2, tmp2)
XORQ(tmp4, tmp4)
XORQ(tmp8, tmp8)
Comment("shift twice to start because length is guaranteed to be even",
"n = n >> 2; CF = originalN & 2")
SHRQ(operand.Imm(2), n)
JNC(operand.LabelRef("handleTiny4"))
Comment("tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]")
MOVWQZX(operand.Mem{Base: buf}, tmp2)
ADDQ(operand.Imm(2), buf)
Label("handleTiny4")
Comment("n = n >> 1; CF = originalN & 4")
SHRQ(operand.Imm(1), n)
JNC(operand.LabelRef("handleTiny8"))
Comment("tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]")
MOVLQZX(operand.Mem{Base: buf}, tmp4)
ADDQ(operand.Imm(4), buf)
Label("handleTiny8")
Comment("n = n >> 1; CF = originalN & 8")
SHRQ(operand.Imm(1), n)
JNC(operand.LabelRef("handleTiny16"))
Comment("tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]")
MOVQ(operand.Mem{Base: buf}, tmp8)
ADDQ(operand.Imm(8), buf)
Label("handleTiny16")
Comment("n = n >> 1; CF = originalN & 16",
"n == 0 now, otherwise we would have branched after comparing with tinyBufferSize")
SHRQ(operand.Imm(1), n)
JNC(operand.LabelRef("handleTinyFinish"))
ADDQ(operand.Mem{Base: buf}, accum)
ADCQ(operand.Mem{Base: buf, Disp: 8}, accum)
Label("handleTinyFinish")
Comment("CF should be included from the previous add, so we use ADCQ.",
"If we arrived via the JNC above, then CF=0 due to the branch condition,",
"so ADCQ will still produce the correct result.")
ADCQ(tmp2, accum)
ADCQ(tmp4, accum)
ADCQ(tmp8, accum)
JMP(returnLabel)
}
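
// tinyBufferReference is a hedged, illustrative Go sketch (not emitted or
// called by the generator) of the strategy handleTinyBuffers encodes: the
// length is even and at most 31 here, so testing one bit at a time selects a
// 2, 4, 8 and 16 byte chunk to fold in. Little-endian loads are assumed, as in
// the generated amd64 code; the returned carry is folded by the caller.
func tinyBufferReference(buf []byte, accum uint64) (uint64, uint64) {
	le := func(b []byte) uint64 { // little-endian load of up to 8 bytes
		var v uint64
		for i, c := range b {
			v |= uint64(c) << (8 * i)
		}
		return v
	}
	n := len(buf)
	var carry, cf uint64
	if n&2 != 0 {
		accum, cf = bits.Add64(accum, le(buf[:2]), 0)
		carry += cf
		buf = buf[2:]
	}
	if n&4 != 0 {
		accum, cf = bits.Add64(accum, le(buf[:4]), 0)
		carry += cf
		buf = buf[4:]
	}
	if n&8 != 0 {
		accum, cf = bits.Add64(accum, le(buf[:8]), 0)
		carry += cf
		buf = buf[8:]
	}
	if n&16 != 0 {
		accum, cf = bits.Add64(accum, le(buf[:8]), 0)
		carry += cf
		accum, cf = bits.Add64(accum, le(buf[8:16]), 0)
		carry += cf
	}
	return accum, carry
}
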
// handleRemaining generates a series of conditional unrolled additions,
// starting with an 8-byte block and doubling each time until the block size
// reaches max. This is the reverse of the order that may seem intuitive, but
// it makes the branch conditions convenient to compute: perform one right
// shift each time and test CF.
//
// When done, CF may be set and must be accommodated e.g., using foldWithCF or
// accumulateCF.
//
// If n is not a multiple of 8, an extra 64 bit read at the end of the buffer
// will be performed, overlapping with data that will be read later. The
// duplicate data will be shifted off.
//
// The original buffer length must have been at least 8 bytes long, even if
// n < 8, otherwise this will access memory before the start of the buffer,
// which may be unsafe.
func handleRemaining(n, buf, accum64 reg.GPVirtual, max int) {
Comment("Accumulate carries in this register. It is never expected to overflow.")
carries := GP64()
XORQ(carries, carries)
Comment("We will perform an overlapped read for buffers with length not a multiple of 8.",
"Overlapped in this context means some memory will be read twice, but a shift will",
"eliminate the duplicated data. This extra read is performed at the end of the buffer to",
"preserve any alignment that may exist for the start of the buffer.")
leftover := reg.RCX
MOVQ(n, leftover)
SHRQ(operand.Imm(3), n) // n is now the number of 64 bit reads remaining
ANDQ(operand.Imm(0x7), leftover) // leftover is now the number of bytes to read from the end
JZ(operand.LabelRef("handleRemaining8"))
endBuf := GP64()
// endBuf is the position near the end of the buffer that is just past the
// last full 8-byte chunk: buf + (len(buf) &^ 0x7)
LEAQ(operand.Mem{Base: buf, Index: n, Scale: 8}, endBuf)
overlapRead := GP64()
// equivalent to overlapRead = binary.LittleEndian.Uint64(buf[len(buf)-8:len(buf)])
MOVQ(operand.Mem{Base: endBuf, Index: leftover, Scale: 1, Disp: -8}, overlapRead)
Comment("Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)")
SHLQ(operand.Imm(3), leftover) // leftover = leftover * 8
NEGQ(leftover) // leftover = -leftover; this completes the (-leftoverBytes*8) part of the expression
ADDQ(operand.Imm(64), leftover) // now we have (64 - leftoverBytes*8)
SHRQ(reg.CL, overlapRead) // shift right by (64 - leftoverBytes*8); CL is the low 8 bits of leftover (set to RCX above) and variable shift only accepts CL
ADDQ(overlapRead, accum64)
ADCQ(operand.Imm(0), carries)
for curBytes := 8; curBytes <= max; curBytes *= 2 {
Label(fmt.Sprintf("handleRemaining%d", curBytes))
SHRQ(operand.Imm(1), n)
if curBytes*2 <= max {
JNC(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
} else {
JNC(operand.LabelRef("handleRemainingComplete"))
}
numLoads := curBytes / 8
for i := 0; i < numLoads; i++ {
addr := operand.Mem{Base: buf, Disp: i * 8}
// It is possible to use the multiple dependency chains trick here
// that generateAMD64Checksum uses, but anecdotally the benefit does
// not appear to outweigh the cost.
if i == 0 {
ADDQ(addr, accum64)
continue
}
ADCQ(addr, accum64)
}
ADCQ(operand.Imm(0), carries)
if curBytes > math.MaxUint8 {
ADDQ(operand.U32(uint64(curBytes)), buf)
} else {
ADDQ(operand.U8(uint64(curBytes)), buf)
}
if curBytes*2 >= max {
continue
}
JMP(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
}
Label("handleRemainingComplete")
ADDQ(carries, accum64)
}
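
// overlappedTailReference is a hedged, illustrative Go sketch (not emitted or
// called by the generator) of the overlapped read above: when the remaining
// length is not a multiple of 8, read the final 8 bytes of the buffer and
// shift out the bytes that the aligned 8-byte reads will also cover. It
// assumes len(buf) >= 8 and little-endian order, as the generated code does.
func overlappedTailReference(buf []byte) uint64 {
	leftover := len(buf) & 7
	if leftover == 0 {
		return 0
	}
	var v uint64
	for i, c := range buf[len(buf)-8:] { // little-endian load of the last 8 bytes
		v |= uint64(c) << (8 * i)
	}
	// Keep only the leftover bytes; the rest is read again by the aligned passes.
	return v >> uint(64-leftover*8)
}
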
func accumulateCF(accum64 reg.GPVirtual) {
Comment("accumulate CF (twice, in case the first time overflows)")
// accum64 += CF
ADCQ(operand.Imm(0), accum64)
// accum64 += CF again if the previous add overflowed. The previous add was
// 0 or 1. If it overflowed, then accum64 == 0, so adding another 1 can
// never overflow.
ADCQ(operand.Imm(0), accum64)
}
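
// accumulateCFReference is a hedged Go sketch (not emitted or called by the
// generator) of the sequence accumulateCF emits: fold the carry flag into the
// accumulator, then fold once more in case that addition itself wrapped.
func accumulateCFReference(accum, cf uint64) uint64 {
	accum += cf
	if accum < cf { // the first addition wrapped, so accum is now 0
		accum++ // adding 1 to 0 cannot wrap again
	}
	return accum
}
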
// foldWithCF generates instructions to fold accum (a GP64) into a 16-bit value
// according to ones-complement arithmetic. BMI2 instructions will be used if
// allowBMI2 is true (requires fewer instructions).
func foldWithCF(accum reg.GPVirtual, allowBMI2 bool) {
Comment("add CF and fold")
// CF|accum max value starts as 0x1_ffff_ffff_ffff_ffff
tmp := GP64()
if allowBMI2 {
// effectively, tmp = accum >> 32 (technically, this is a rotate)
RORXQ(operand.Imm(32), accum, tmp)
// accum as uint32 = uint32(accum) + uint32(tmp) + CF; max value 0xffff_ffff + CF set
ADCL(tmp.As32(), accum.As32())
// effectively, tmp as uint32 = uint32(accum) >> 16 (also a rotate)
RORXL(operand.Imm(16), accum.As32(), tmp.As32())
// accum as uint16 = uint16(accum) + uint16(tmp) + CF; max value 0xffff + CF unset or 0xfffe + CF set
ADCW(tmp.As16(), accum.As16())
} else {
// tmp = uint32(accum); max value 0xffff_ffff
// MOVL clears the upper 32 bits of a GP64 so this is equivalent to the
// non-existent MOVLQZX.
MOVL(accum.As32(), tmp.As32())
// tmp += CF; max value 0x1_0000_0000, CF unset
ADCQ(operand.Imm(0), tmp)
// accum = accum >> 32; max value 0xffff_ffff
SHRQ(operand.Imm(32), accum)
// accum = accum + tmp; max value 0x1_ffff_ffff + CF unset
ADDQ(tmp, accum)
// tmp = uint16(accum); max value 0xffff
MOVWQZX(accum.As16(), tmp)
// accum = accum >> 16; max value 0x1_ffff
SHRQ(operand.Imm(16), accum)
// accum = accum + tmp; max value 0x2_fffe + CF unset
ADDQ(tmp, accum)
// tmp as uint16 = uint16(accum); max value 0xffff
MOVW(accum.As16(), tmp.As16())
// accum = accum >> 16; max value 0x2
SHRQ(operand.Imm(16), accum)
// accum as uint16 = uint16(accum) + uint16(tmp); max value 0xffff + CF unset or 0x2 + CF set
ADDW(tmp.As16(), accum.As16())
}
// accum as uint16 += CF; will not overflow: either CF was 0 or accum <= 0xfffe
ADCW(operand.Imm(0), accum.As16())
}
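
// foldReference is a hedged, illustrative Go sketch (not emitted or called by
// the generator) of the reduction foldWithCF encodes: fold a 64-bit
// ones-complement sum plus a final carry down to 16 bits by repeatedly adding
// the high half into the low half. The byte swap applied at the return site of
// the generated functions is a separate step.
func foldReference(accum, cf uint64) uint16 {
	accum = (accum & 0xffff_ffff) + (accum >> 32) + cf
	for accum > 0xffff {
		accum = (accum & 0xffff) + (accum >> 16)
	}
	return uint16(accum)
}
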
func generateLoadMasks() {
var offset int
// xmmLoadMasks is a table of masks that can be used with PAND to zero all but the last N bytes in an XMM, N=2,4,6,8,10,12,14
GLOBL("xmmLoadMasks", RODATA|NOPTR)
for n := 2; n < 16; n += 2 {
var pattern [16]byte
for i := 0; i < len(pattern); i++ {
if i < len(pattern)-n {
pattern[i] = 0
continue
}
pattern[i] = 0xff
}
DATA(offset, operand.String(pattern[:]))
offset += len(pattern)
}
}
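
// loadMaskReference is a hedged Go sketch (not emitted or called by the
// generator) of one entry in the xmmLoadMasks table: for an even n in [2, 14],
// a 16-byte mask whose final n bytes are 0xff, so a PAND keeps only the
// trailing n bytes of an XMM register.
func loadMaskReference(n int) [16]byte {
	var mask [16]byte
	for i := 16 - n; i < 16; i++ {
		mask[i] = 0xff
	}
	return mask
}
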
func main() {
generateLoadMasks()
generateSIMDChecksum("checksumAVX2", "checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)", 256, 4, avx2)
generateSIMDChecksum("checksumSSE2", "checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)", 256, 4, sse2)
generateAMD64Checksum("checksumAMD64", "checksumAMD64 computes an IP checksum using amd64 baseline instructions")
Generate()
}