diff --git a/common/bufio/copy_direct_test.go b/common/bufio/copy_direct_test.go new file mode 100644 index 0000000..41fed63 --- /dev/null +++ b/common/bufio/copy_direct_test.go @@ -0,0 +1,77 @@ +package bufio + +import ( + "net" + "testing" + + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" + + "github.com/stretchr/testify/require" +) + +func TestCopyWaitTCP(t *testing.T) { + t.Parallel() + inputConn, outputConn := TCPPipe(t) + readWaiter, created := CreateReadWaiter(outputConn) + require.True(t, created) + require.NotNil(t, readWaiter) + readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) + require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{ + Conn: outputConn, + readWaiter: readWaiter, + })) +} + +type readWaitWrapper struct { + net.Conn + readWaiter N.ReadWaiter + buffer *buf.Buffer +} + +func (r *readWaitWrapper) Read(p []byte) (n int, err error) { + if r.buffer != nil { + if r.buffer.Len() > 0 { + return r.buffer.Read(p) + } + if r.buffer.IsEmpty() { + r.buffer.Release() + r.buffer = nil + } + } + buffer, err := r.readWaiter.WaitReadBuffer() + if err != nil { + return + } + r.buffer = buffer + return r.buffer.Read(p) +} + +func TestCopyWaitUDP(t *testing.T) { + t.Parallel() + inputConn, outputConn, outputAddr := UDPPipe(t) + readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn)) + require.True(t, created) + require.NotNil(t, readWaiter) + readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) + require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{ + PacketConn: outputConn, + readWaiter: readWaiter, + }, outputAddr)) +} + +type packetReadWaitWrapper struct { + net.PacketConn + readWaiter N.PacketReadWaiter +} + +func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + buffer, destination, err := r.readWaiter.WaitReadPacket() + if err != nil { + return + } + n = copy(p, buffer.Bytes()) + buffer.Release() + addr = destination.UDPAddr() + return +} diff --git a/common/bufio/net_test.go b/common/bufio/net_test.go index 6baefac..8642572 100644 --- a/common/bufio/net_test.go +++ b/common/bufio/net_test.go @@ -2,13 +2,19 @@ package bufio import ( "context" + "crypto/md5" + "crypto/rand" + "errors" + "io" "net" + "sync" "testing" "time" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/task" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,6 +39,10 @@ func TCPPipe(t *testing.T) (net.Conn, net.Conn) { err = group.Run() require.NoError(t, err) listener.Close() + t.Cleanup(func() { + serverConn.Close() + clientConn.Close() + }) return serverConn, clientConn } @@ -56,3 +66,212 @@ func Timeout(t *testing.T) context.CancelFunc { }() return cancel } + +type hashPair struct { + sendHash map[int][]byte + recvHash map[int][]byte +} + +func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) { + pingCh := make(chan hashPair) + pongCh := make(chan hashPair) + test := func(t *testing.T) error { + defer close(pingCh) + defer close(pongCh) + pingOpen := false + pongOpen := false + var serverPair hashPair + var clientPair hashPair + + for { + if pingOpen && pongOpen { + break + } + + select { + case serverPair, pingOpen = <-pingCh: + assert.True(t, pingOpen) + case clientPair, pongOpen = <-pongCh: + assert.True(t, pongOpen) + case <-time.After(10 * time.Second): + return errors.New("timeout") + } + } + + assert.Equal(t, serverPair.recvHash, clientPair.sendHash) + assert.Equal(t, serverPair.sendHash, clientPair.recvHash) + + return nil + } + + return pingCh, pongCh, test +} + +func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error { + times := 100 + chunkSize := int64(64 * 1024) + + pingCh, pongCh, test := newLargeDataPair() + writeRandData := func(conn net.Conn) (map[int][]byte, error) { + buf := make([]byte, chunkSize) + hashMap := map[int][]byte{} + for i := 0; i < times; i++ { + if _, err := rand.Read(buf[1:]); err != nil { + return nil, err + } + buf[0] = byte(i) + + hash := md5.Sum(buf) + hashMap[i] = hash[:] + + if _, err := conn.Write(buf); err != nil { + return nil, err + } + } + + return hashMap, nil + } + go func() { + hashMap := map[int][]byte{} + buf := make([]byte, chunkSize) + + for i := 0; i < times; i++ { + _, err := io.ReadFull(outputConn, buf) + if err != nil { + t.Log(err.Error()) + return + } + + hash := md5.Sum(buf) + hashMap[int(buf[0])] = hash[:] + } + + sendHash, err := writeRandData(outputConn) + if err != nil { + t.Log(err.Error()) + return + } + + pingCh <- hashPair{ + sendHash: sendHash, + recvHash: hashMap, + } + }() + + go func() { + sendHash, err := writeRandData(inputConn) + if err != nil { + t.Log(err.Error()) + return + } + + hashMap := map[int][]byte{} + buf := make([]byte, chunkSize) + + for i := 0; i < times; i++ { + _, err = io.ReadFull(inputConn, buf) + if err != nil { + t.Log(err.Error()) + return + } + + hash := md5.Sum(buf) + hashMap[int(buf[0])] = hash[:] + } + + pongCh <- hashPair{ + sendHash: sendHash, + recvHash: hashMap, + } + }() + return test(t) +} + +func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error { + rAddr := outputAddr.UDPAddr() + times := 50 + chunkSize := 9000 + pingCh, pongCh, test := newLargeDataPair() + writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) { + hashMap := map[int][]byte{} + mux := sync.Mutex{} + for i := 0; i < times; i++ { + buf := make([]byte, chunkSize) + if _, err := rand.Read(buf[1:]); err != nil { + t.Log(err.Error()) + continue + } + buf[0] = byte(i) + + hash := md5.Sum(buf) + mux.Lock() + hashMap[i] = hash[:] + mux.Unlock() + + if _, err := pc.WriteTo(buf, addr); err != nil { + t.Log(err.Error()) + } + + time.Sleep(10 * time.Millisecond) + } + + return hashMap, nil + } + go func() { + var ( + lAddr net.Addr + err error + ) + hashMap := map[int][]byte{} + buf := make([]byte, 64*1024) + + for i := 0; i < times; i++ { + _, lAddr, err = outputConn.ReadFrom(buf) + if err != nil { + t.Log(err.Error()) + return + } + hash := md5.Sum(buf[:chunkSize]) + hashMap[int(buf[0])] = hash[:] + } + sendHash, err := writeRandData(outputConn, lAddr) + if err != nil { + t.Log(err.Error()) + return + } + + pingCh <- hashPair{ + sendHash: sendHash, + recvHash: hashMap, + } + }() + + go func() { + sendHash, err := writeRandData(inputConn, rAddr) + if err != nil { + t.Log(err.Error()) + return + } + + hashMap := map[int][]byte{} + buf := make([]byte, 64*1024) + + for i := 0; i < times; i++ { + _, _, err := inputConn.ReadFrom(buf) + if err != nil { + t.Log(err.Error()) + return + } + + hash := md5.Sum(buf[:chunkSize]) + hashMap[int(buf[0])] = hash[:] + } + + pongCh <- hashPair{ + sendHash: sendHash, + recvHash: hashMap, + } + }() + + return test(t) +} diff --git a/common/bufio/vectorised_test.go b/common/bufio/vectorised_test.go index 8dccf63..7d2e42d 100644 --- a/common/bufio/vectorised_test.go +++ b/common/bufio/vectorised_test.go @@ -11,8 +11,6 @@ import ( func TestWriteVectorised(t *testing.T) { t.Parallel() inputConn, outputConn := TCPPipe(t) - defer inputConn.Close() - defer outputConn.Close() vectorisedWriter, created := CreateVectorisedWriter(inputConn) require.True(t, created) require.NotNil(t, vectorisedWriter) @@ -36,9 +34,8 @@ func TestWriteVectorised(t *testing.T) { } func TestWriteVectorisedPacket(t *testing.T) { + t.Parallel() inputConn, outputConn, outputAddr := UDPPipe(t) - defer inputConn.Close() - defer outputConn.Close() vectorisedWriter, created := CreateVectorisedPacketWriter(inputConn) require.True(t, created) require.NotNil(t, vectorisedWriter)