Add test for copy waiter

This commit is contained in:
世界 2023-12-21 20:35:41 +08:00
parent b7a631f798
commit 57b8a4c64a
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 297 additions and 4 deletions

View file

@ -0,0 +1,77 @@
package bufio
import (
"net"
"testing"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
"github.com/stretchr/testify/require"
)
func TestCopyWaitTCP(t *testing.T) {
t.Parallel()
inputConn, outputConn := TCPPipe(t)
readWaiter, created := CreateReadWaiter(outputConn)
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
Conn: outputConn,
readWaiter: readWaiter,
}))
}
type readWaitWrapper struct {
net.Conn
readWaiter N.ReadWaiter
buffer *buf.Buffer
}
func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
if r.buffer != nil {
if r.buffer.Len() > 0 {
return r.buffer.Read(p)
}
if r.buffer.IsEmpty() {
r.buffer.Release()
r.buffer = nil
}
}
buffer, err := r.readWaiter.WaitReadBuffer()
if err != nil {
return
}
r.buffer = buffer
return r.buffer.Read(p)
}
func TestCopyWaitUDP(t *testing.T) {
t.Parallel()
inputConn, outputConn, outputAddr := UDPPipe(t)
readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
PacketConn: outputConn,
readWaiter: readWaiter,
}, outputAddr))
}
type packetReadWaitWrapper struct {
net.PacketConn
readWaiter N.PacketReadWaiter
}
func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer, destination, err := r.readWaiter.WaitReadPacket()
if err != nil {
return
}
n = copy(p, buffer.Bytes())
buffer.Release()
addr = destination.UDPAddr()
return
}

View file

@ -2,13 +2,19 @@ package bufio
import (
"context"
"crypto/md5"
"crypto/rand"
"errors"
"io"
"net"
"sync"
"testing"
"time"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -33,6 +39,10 @@ func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
err = group.Run()
require.NoError(t, err)
listener.Close()
t.Cleanup(func() {
serverConn.Close()
clientConn.Close()
})
return serverConn, clientConn
}
@ -56,3 +66,212 @@ func Timeout(t *testing.T) context.CancelFunc {
}()
return cancel
}
type hashPair struct {
sendHash map[int][]byte
recvHash map[int][]byte
}
func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) {
pingCh := make(chan hashPair)
pongCh := make(chan hashPair)
test := func(t *testing.T) error {
defer close(pingCh)
defer close(pongCh)
pingOpen := false
pongOpen := false
var serverPair hashPair
var clientPair hashPair
for {
if pingOpen && pongOpen {
break
}
select {
case serverPair, pingOpen = <-pingCh:
assert.True(t, pingOpen)
case clientPair, pongOpen = <-pongCh:
assert.True(t, pongOpen)
case <-time.After(10 * time.Second):
return errors.New("timeout")
}
}
assert.Equal(t, serverPair.recvHash, clientPair.sendHash)
assert.Equal(t, serverPair.sendHash, clientPair.recvHash)
return nil
}
return pingCh, pongCh, test
}
func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
times := 100
chunkSize := int64(64 * 1024)
pingCh, pongCh, test := newLargeDataPair()
writeRandData := func(conn net.Conn) (map[int][]byte, error) {
buf := make([]byte, chunkSize)
hashMap := map[int][]byte{}
for i := 0; i < times; i++ {
if _, err := rand.Read(buf[1:]); err != nil {
return nil, err
}
buf[0] = byte(i)
hash := md5.Sum(buf)
hashMap[i] = hash[:]
if _, err := conn.Write(buf); err != nil {
return nil, err
}
}
return hashMap, nil
}
go func() {
hashMap := map[int][]byte{}
buf := make([]byte, chunkSize)
for i := 0; i < times; i++ {
_, err := io.ReadFull(outputConn, buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf)
hashMap[int(buf[0])] = hash[:]
}
sendHash, err := writeRandData(outputConn)
if err != nil {
t.Log(err.Error())
return
}
pingCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
go func() {
sendHash, err := writeRandData(inputConn)
if err != nil {
t.Log(err.Error())
return
}
hashMap := map[int][]byte{}
buf := make([]byte, chunkSize)
for i := 0; i < times; i++ {
_, err = io.ReadFull(inputConn, buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf)
hashMap[int(buf[0])] = hash[:]
}
pongCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
return test(t)
}
func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
rAddr := outputAddr.UDPAddr()
times := 50
chunkSize := 9000
pingCh, pongCh, test := newLargeDataPair()
writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
hashMap := map[int][]byte{}
mux := sync.Mutex{}
for i := 0; i < times; i++ {
buf := make([]byte, chunkSize)
if _, err := rand.Read(buf[1:]); err != nil {
t.Log(err.Error())
continue
}
buf[0] = byte(i)
hash := md5.Sum(buf)
mux.Lock()
hashMap[i] = hash[:]
mux.Unlock()
if _, err := pc.WriteTo(buf, addr); err != nil {
t.Log(err.Error())
}
time.Sleep(10 * time.Millisecond)
}
return hashMap, nil
}
go func() {
var (
lAddr net.Addr
err error
)
hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)
for i := 0; i < times; i++ {
_, lAddr, err = outputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}
sendHash, err := writeRandData(outputConn, lAddr)
if err != nil {
t.Log(err.Error())
return
}
pingCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
go func() {
sendHash, err := writeRandData(inputConn, rAddr)
if err != nil {
t.Log(err.Error())
return
}
hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)
for i := 0; i < times; i++ {
_, _, err := inputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}
pongCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()
return test(t)
}

View file

@ -11,8 +11,6 @@ import (
func TestWriteVectorised(t *testing.T) {
t.Parallel()
inputConn, outputConn := TCPPipe(t)
defer inputConn.Close()
defer outputConn.Close()
vectorisedWriter, created := CreateVectorisedWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)
@ -36,9 +34,8 @@ func TestWriteVectorised(t *testing.T) {
}
func TestWriteVectorisedPacket(t *testing.T) {
t.Parallel()
inputConn, outputConn, outputAddr := UDPPipe(t)
defer inputConn.Close()
defer outputConn.Close()
vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn)
require.True(t, created)
require.NotNil(t, vectorisedWriter)