mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 11:57:39 +03:00
Add test for copy waiter
This commit is contained in:
parent
b7a631f798
commit
57b8a4c64a
3 changed files with 297 additions and 4 deletions
77
common/bufio/copy_direct_test.go
Normal file
77
common/bufio/copy_direct_test.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package bufio
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCopyWaitTCP(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn := TCPPipe(t)
|
||||
readWaiter, created := CreateReadWaiter(outputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, readWaiter)
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
||||
require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
|
||||
Conn: outputConn,
|
||||
readWaiter: readWaiter,
|
||||
}))
|
||||
}
|
||||
|
||||
type readWaitWrapper struct {
|
||||
net.Conn
|
||||
readWaiter N.ReadWaiter
|
||||
buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
|
||||
if r.buffer != nil {
|
||||
if r.buffer.Len() > 0 {
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
if r.buffer.IsEmpty() {
|
||||
r.buffer.Release()
|
||||
r.buffer = nil
|
||||
}
|
||||
}
|
||||
buffer, err := r.readWaiter.WaitReadBuffer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r.buffer = buffer
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
|
||||
func TestCopyWaitUDP(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn, outputAddr := UDPPipe(t)
|
||||
readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
|
||||
require.True(t, created)
|
||||
require.NotNil(t, readWaiter)
|
||||
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
|
||||
require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
|
||||
PacketConn: outputConn,
|
||||
readWaiter: readWaiter,
|
||||
}, outputAddr))
|
||||
}
|
||||
|
||||
type packetReadWaitWrapper struct {
|
||||
net.PacketConn
|
||||
readWaiter N.PacketReadWaiter
|
||||
}
|
||||
|
||||
func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
buffer, destination, err := r.readWaiter.WaitReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = copy(p, buffer.Bytes())
|
||||
buffer.Release()
|
||||
addr = destination.UDPAddr()
|
||||
return
|
||||
}
|
|
@ -2,13 +2,19 @@ package bufio
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -33,6 +39,10 @@ func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
|
|||
err = group.Run()
|
||||
require.NoError(t, err)
|
||||
listener.Close()
|
||||
t.Cleanup(func() {
|
||||
serverConn.Close()
|
||||
clientConn.Close()
|
||||
})
|
||||
return serverConn, clientConn
|
||||
}
|
||||
|
||||
|
@ -56,3 +66,212 @@ func Timeout(t *testing.T) context.CancelFunc {
|
|||
}()
|
||||
return cancel
|
||||
}
|
||||
|
||||
type hashPair struct {
|
||||
sendHash map[int][]byte
|
||||
recvHash map[int][]byte
|
||||
}
|
||||
|
||||
func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) {
|
||||
pingCh := make(chan hashPair)
|
||||
pongCh := make(chan hashPair)
|
||||
test := func(t *testing.T) error {
|
||||
defer close(pingCh)
|
||||
defer close(pongCh)
|
||||
pingOpen := false
|
||||
pongOpen := false
|
||||
var serverPair hashPair
|
||||
var clientPair hashPair
|
||||
|
||||
for {
|
||||
if pingOpen && pongOpen {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case serverPair, pingOpen = <-pingCh:
|
||||
assert.True(t, pingOpen)
|
||||
case clientPair, pongOpen = <-pongCh:
|
||||
assert.True(t, pongOpen)
|
||||
case <-time.After(10 * time.Second):
|
||||
return errors.New("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, serverPair.recvHash, clientPair.sendHash)
|
||||
assert.Equal(t, serverPair.sendHash, clientPair.recvHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return pingCh, pongCh, test
|
||||
}
|
||||
|
||||
func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
|
||||
times := 100
|
||||
chunkSize := int64(64 * 1024)
|
||||
|
||||
pingCh, pongCh, test := newLargeDataPair()
|
||||
writeRandData := func(conn net.Conn) (map[int][]byte, error) {
|
||||
buf := make([]byte, chunkSize)
|
||||
hashMap := map[int][]byte{}
|
||||
for i := 0; i < times; i++ {
|
||||
if _, err := rand.Read(buf[1:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf[0] = byte(i)
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[i] = hash[:]
|
||||
|
||||
if _, err := conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return hashMap, nil
|
||||
}
|
||||
go func() {
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, chunkSize)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, err := io.ReadFull(outputConn, buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
sendHash, err := writeRandData(outputConn)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pingCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
sendHash, err := writeRandData(inputConn)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, chunkSize)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, err = io.ReadFull(inputConn, buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
pongCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
return test(t)
|
||||
}
|
||||
|
||||
func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
|
||||
rAddr := outputAddr.UDPAddr()
|
||||
times := 50
|
||||
chunkSize := 9000
|
||||
pingCh, pongCh, test := newLargeDataPair()
|
||||
writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
|
||||
hashMap := map[int][]byte{}
|
||||
mux := sync.Mutex{}
|
||||
for i := 0; i < times; i++ {
|
||||
buf := make([]byte, chunkSize)
|
||||
if _, err := rand.Read(buf[1:]); err != nil {
|
||||
t.Log(err.Error())
|
||||
continue
|
||||
}
|
||||
buf[0] = byte(i)
|
||||
|
||||
hash := md5.Sum(buf)
|
||||
mux.Lock()
|
||||
hashMap[i] = hash[:]
|
||||
mux.Unlock()
|
||||
|
||||
if _, err := pc.WriteTo(buf, addr); err != nil {
|
||||
t.Log(err.Error())
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
return hashMap, nil
|
||||
}
|
||||
go func() {
|
||||
var (
|
||||
lAddr net.Addr
|
||||
err error
|
||||
)
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, 64*1024)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, lAddr, err = outputConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
hash := md5.Sum(buf[:chunkSize])
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
sendHash, err := writeRandData(outputConn, lAddr)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pingCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
sendHash, err := writeRandData(inputConn, rAddr)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hashMap := map[int][]byte{}
|
||||
buf := make([]byte, 64*1024)
|
||||
|
||||
for i := 0; i < times; i++ {
|
||||
_, _, err := inputConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
hash := md5.Sum(buf[:chunkSize])
|
||||
hashMap[int(buf[0])] = hash[:]
|
||||
}
|
||||
|
||||
pongCh <- hashPair{
|
||||
sendHash: sendHash,
|
||||
recvHash: hashMap,
|
||||
}
|
||||
}()
|
||||
|
||||
return test(t)
|
||||
}
|
||||
|
|
|
@ -11,8 +11,6 @@ import (
|
|||
func TestWriteVectorised(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn := TCPPipe(t)
|
||||
defer inputConn.Close()
|
||||
defer outputConn.Close()
|
||||
vectorisedWriter, created := CreateVectorisedWriter(inputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, vectorisedWriter)
|
||||
|
@ -36,9 +34,8 @@ func TestWriteVectorised(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestWriteVectorisedPacket(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputConn, outputConn, outputAddr := UDPPipe(t)
|
||||
defer inputConn.Close()
|
||||
defer outputConn.Close()
|
||||
vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn)
|
||||
require.True(t, created)
|
||||
require.NotNil(t, vectorisedWriter)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue