Improve vectorised writer

This commit is contained in:
世界 2023-12-14 17:42:52 +08:00
parent edd320c3a8
commit 2e36fa6849
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
13 changed files with 432 additions and 55 deletions

View file

@ -14,21 +14,17 @@ on:
jobs:
build:
name: Debug build
name: Linux Debug build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ${{ steps.version.outputs.go_version }}
go-version: ">=1.21.0 <1.22.0"
- name: Add cache to Go proxy
run: |
version=`git rev-parse HEAD`
@ -41,3 +37,83 @@ jobs:
- name: Build
run: |
make test
build_go118:
name: Linux Debug build (Go 1.18)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ">=1.18.0 <1.19.0"
continue-on-error: true
- name: Build
run: |
make test
build_go119:
name: Linux Debug build (Go 1.19)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ">=1.19.0 <1.20.0"
continue-on-error: true
- name: Build
run: |
make test
build_go120:
name: Linux Debug build (Go 1.20)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ">=1.20.0 <1.21.0"
continue-on-error: true
- name: Build
run: |
make test
build__windows:
name: Windows Debug build
runs-on: windows-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ">=1.21.0 <1.22.0"
continue-on-error: true
- name: Build
run: |
make test
build_darwin:
name: macOS Debug build
runs-on: macos-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ">=1.21.0 <1.22.0"
continue-on-error: true
- name: Build
run: |
make test

View file

@ -18,4 +18,4 @@ lint_install:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
test:
go test -v ./...
go test $(shell go list ./... | grep -v /internal/)

34
common/bufio/addr_bsd.go Normal file
View file

@ -0,0 +1,34 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
if destination.Addr().Is4() {
sa := unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: destination.Addr().As4(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet4
} else {
sa := unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: destination.Addr().As16(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet6
}
return
}

View file

@ -0,0 +1,30 @@
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen uint32) {
if destination.Addr().Is4() {
sa := unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: destination.Addr().As4(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet4
} else {
sa := unix.RawSockaddrInet6{
Family: unix.AF_INET6,
Addr: destination.Addr().As16(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = unix.SizeofSockaddrInet6
}
return
}

View file

@ -0,0 +1,30 @@
package bufio
import (
"encoding/binary"
"net/netip"
"unsafe"
"golang.org/x/sys/windows"
)
func ToSockaddr(destination netip.AddrPort) (name unsafe.Pointer, nameLen int32) {
if destination.Addr().Is4() {
sa := windows.RawSockaddrInet4{
Family: windows.AF_INET,
Addr: destination.Addr().As4(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = int32(unsafe.Sizeof(sa))
} else {
sa := windows.RawSockaddrInet6{
Family: windows.AF_INET6,
Addr: destination.Addr().As16(),
}
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], destination.Port())
name = unsafe.Pointer(&sa)
nameLen = int32(unsafe.Sizeof(sa))
}
return
}

60
common/bufio/net_test.go Normal file
View file

@ -0,0 +1,60 @@
package bufio
import (
"context"
"net"
"testing"
"time"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
"github.com/stretchr/testify/require"
)
func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
var (
group task.Group
serverConn net.Conn
clientConn net.Conn
)
group.Append0(func(ctx context.Context) error {
var serverErr error
serverConn, serverErr = listener.Accept()
require.NoError(t, serverErr)
return nil
})
group.Append0(func(ctx context.Context) error {
var clientErr error
clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
require.NoError(t, clientErr)
return nil
})
err = group.Run()
require.NoError(t, err)
listener.Close()
return serverConn, clientConn
}
func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) {
serverConn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
clientConn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr())
}
func Timeout(t *testing.T) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-ctx.Done():
return
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
}()
return cancel
}

View file

@ -33,10 +33,10 @@ func CreateVectorisedWriter(writer any) (N.VectorisedWriter, bool) {
case syscall.Conn:
rawConn, err := w.SyscallConn()
if err == nil {
return &SyscallVectorisedWriter{writer, rawConn}, true
return &SyscallVectorisedWriter{upstream: writer, rawConn: rawConn}, true
}
case syscall.RawConn:
return &SyscallVectorisedWriter{writer, w}, true
return &SyscallVectorisedWriter{upstream: writer, rawConn: w}, true
}
return nil, false
}
@ -48,10 +48,10 @@ func CreateVectorisedPacketWriter(writer any) (N.VectorisedPacketWriter, bool) {
case syscall.Conn:
rawConn, err := w.SyscallConn()
if err == nil {
return &SyscallVectorisedPacketWriter{writer, rawConn}, true
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: rawConn}, true
}
case syscall.RawConn:
return &SyscallVectorisedPacketWriter{writer, w}, true
return &SyscallVectorisedPacketWriter{upstream: writer, rawConn: w}, true
}
return nil, false
}
@ -111,6 +111,7 @@ var _ N.VectorisedWriter = (*SyscallVectorisedWriter)(nil)
type SyscallVectorisedWriter struct {
upstream any
rawConn syscall.RawConn
syscallVectorisedWriterFields
}
func (w *SyscallVectorisedWriter) Upstream() any {
@ -126,6 +127,7 @@ var _ N.VectorisedPacketWriter = (*SyscallVectorisedPacketWriter)(nil)
type SyscallVectorisedPacketWriter struct {
upstream any
rawConn syscall.RawConn
syscallVectorisedWriterFields
}
func (w *SyscallVectorisedPacketWriter) Upstream() any {

View file

@ -0,0 +1,63 @@
package bufio
import (
"crypto/rand"
"io"
"testing"
"github.com/stretchr/testify/require"
)
func TestWriteVectorised(t *testing.T) {
t.Parallel()
inputConn, outputConn := TCPPipe(t)
defer inputConn.Close()
defer outputConn.Close()
vectorisedWriter, created := CreateVectorisedWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)
var bufA [1024]byte
var bufB [1024]byte
var bufC [2048]byte
_, err := io.ReadFull(rand.Reader, bufA[:])
require.NoError(t, err)
_, err = io.ReadFull(rand.Reader, bufB[:])
require.NoError(t, err)
copy(bufC[:], bufA[:])
copy(bufC[1024:], bufB[:])
finish := Timeout(t)
_, err = WriteVectorised(vectorisedWriter, [][]byte{bufA[:], bufB[:]})
require.NoError(t, err)
output := make([]byte, 2048)
_, err = io.ReadFull(outputConn, output)
finish()
require.NoError(t, err)
require.Equal(t, bufC[:], output)
}
func TestWriteVectorisedPacket(t *testing.T) {
inputConn, outputConn, outputAddr := UDPPipe(t)
defer inputConn.Close()
defer outputConn.Close()
vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)
var bufA [1024]byte
var bufB [1024]byte
var bufC [2048]byte
_, err := io.ReadFull(rand.Reader, bufA[:])
require.NoError(t, err)
_, err = io.ReadFull(rand.Reader, bufB[:])
require.NoError(t, err)
copy(bufC[:], bufA[:])
copy(bufC[1024:], bufB[:])
finish := Timeout(t)
_, err = WriteVectorisedPacket(vectorisedWriter, [][]byte{bufA[:], bufB[:]}, outputAddr)
require.NoError(t, err)
output := make([]byte, 2048)
n, _, err := outputConn.ReadFrom(output)
finish()
require.NoError(t, err)
require.Equal(t, 2048, n)
require.Equal(t, bufC[:], output)
}

View file

@ -3,6 +3,8 @@
package bufio
import (
"os"
"sync"
"unsafe"
"github.com/sagernet/sing/common/buf"
@ -11,15 +13,28 @@ import (
"golang.org/x/sys/unix"
)
type syscallVectorisedWriterFields struct {
access sync.Mutex
iovecList *[]unix.Iovec
}
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
iovecList := make([]unix.Iovec, 0, len(buffers))
for _, buffer := range buffers {
var iovec unix.Iovec
iovec.Base = &buffer.Bytes()[0]
iovec.SetLen(buffer.Len())
iovecList = append(iovecList, iovec)
var iovecList []unix.Iovec
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for index, buffer := range buffers {
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
iovecList[index].SetLen(buffer.Len())
}
if w.iovecList == nil {
w.iovecList = new([]unix.Iovec)
}
*w.iovecList = iovecList // cache
var innerErr unix.Errno
err := w.rawConn.Write(func(fd uintptr) (done bool) {
//nolint:staticcheck
@ -28,28 +43,42 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})
if innerErr != 0 {
err = innerErr
err = os.NewSyscallError("SYS_WRITEV", innerErr)
}
for index := range iovecList {
iovecList[index] = unix.Iovec{}
}
return err
}
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
var sockaddr unix.Sockaddr
if destination.IsIPv4() {
sockaddr = &unix.SockaddrInet4{
Port: int(destination.Port),
Addr: destination.Addr.As4(),
}
} else {
sockaddr = &unix.SockaddrInet6{
Port: int(destination.Port),
Addr: destination.Addr.As16(),
}
var iovecList []unix.Iovec
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for index, buffer := range buffers {
iovecList = append(iovecList, unix.Iovec{Base: &buffer.Bytes()[0]})
iovecList[index].SetLen(buffer.Len())
}
if w.iovecList == nil {
w.iovecList = new([]unix.Iovec)
}
*w.iovecList = iovecList // cache
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
_, innerErr = unix.SendmsgBuffers(int(fd), buf.ToSliceMulti(buffers), nil, sockaddr, 0)
var msg unix.Msghdr
name, nameLen := ToSockaddr(destination.AddrPort())
msg.Name = (*byte)(name)
msg.Namelen = nameLen
if len(iovecList) > 0 {
msg.Iov = &iovecList[0]
msg.SetIovlen(len(iovecList))
}
_, innerErr = sendmsg(int(fd), &msg, 0)
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
@ -57,3 +86,6 @@ func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buf
}
return err
}
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
func sendmsg(s int, msg *unix.Msghdr, flags int) (n int, err error)

View file

@ -1,62 +1,93 @@
package bufio
import (
"sync"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/windows"
)
type syscallVectorisedWriterFields struct {
access sync.Mutex
iovecList *[]windows.WSABuf
}
func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
iovecList := make([]*windows.WSABuf, len(buffers))
for i, buffer := range buffers {
iovecList[i] = &windows.WSABuf{
Len: uint32(buffer.Len()),
Buf: &buffer.Bytes()[0],
}
var iovecList []windows.WSABuf
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for _, buffer := range buffers {
iovecList = append(iovecList, windows.WSABuf{
Buf: &buffer.Bytes()[0],
Len: uint32(buffer.Len()),
})
}
if w.iovecList == nil {
w.iovecList = new([]windows.WSABuf)
}
*w.iovecList = iovecList // cache
var n uint32
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
innerErr = windows.WSASend(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
innerErr = windows.WSASend(windows.Handle(fd), &iovecList[0], uint32(len(iovecList)), &n, 0, nil, nil)
return innerErr != windows.WSAEWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = windows.WSABuf{}
}
return err
}
func (w *SyscallVectorisedPacketWriter) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
w.access.Lock()
defer w.access.Unlock()
defer buf.ReleaseMulti(buffers)
iovecList := make([]*windows.WSABuf, len(buffers))
for i, buffer := range buffers {
iovecList[i] = &windows.WSABuf{
Len: uint32(buffer.Len()),
var iovecList []windows.WSABuf
if w.iovecList != nil {
iovecList = *w.iovecList
}
iovecList = iovecList[:0]
for _, buffer := range buffers {
iovecList = append(iovecList, windows.WSABuf{
Buf: &buffer.Bytes()[0],
}
Len: uint32(buffer.Len()),
})
}
var sockaddr windows.Sockaddr
if destination.IsIPv4() {
sockaddr = &windows.SockaddrInet4{
Port: int(destination.Port),
Addr: destination.Addr.As4(),
}
} else {
sockaddr = &windows.SockaddrInet6{
Port: int(destination.Port),
Addr: destination.Addr.As16(),
}
if w.iovecList == nil {
w.iovecList = new([]windows.WSABuf)
}
*w.iovecList = iovecList // cache
var n uint32
var innerErr error
err := w.rawConn.Write(func(fd uintptr) (done bool) {
innerErr = windows.WSASendto(windows.Handle(fd), iovecList[0], uint32(len(iovecList)), &n, 0, sockaddr, nil, nil)
name, nameLen := ToSockaddr(destination.AddrPort())
innerErr = windows.WSASendTo(
windows.Handle(fd),
&iovecList[0],
uint32(len(iovecList)),
&n,
0,
(*windows.RawSockaddrAny)(name),
nameLen,
nil,
nil)
return innerErr != windows.WSAEWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
for index := range iovecList {
iovecList[index] = windows.WSABuf{}
}
return err
}

11
go.mod
View file

@ -2,4 +2,13 @@ module github.com/sagernet/sing
go 1.18
require golang.org/x/sys v0.15.0
require (
github.com/stretchr/testify v1.8.4
golang.org/x/sys v0.15.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

10
go.sum
View file

@ -1,2 +1,12 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=