uquic/integrationtests/tools/proxy/proxy_test.go
2024-09-21 05:15:13 -07:00

424 lines
12 KiB
Go

package quicproxy
import (
"bytes"
"fmt"
"net"
"runtime"
"runtime/pprof"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
// packetData holds the raw bytes of a single UDP datagram, as read from a
// socket by the helpers and tests in this file.
type packetData []byte
// makePacket serializes an Initial packet header (QUIC version 1) with packet
// number p and appends payload to it.
//
// Destination and source connection ID are set to the same fixed 8-byte
// value, and the packet number is encoded using 4 bytes. The test is failed
// immediately if the header cannot be serialized.
func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte {
	t.Helper()

	connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37})
	header := wire.Header{
		Type:    protocol.PacketTypeInitial,
		Version: protocol.Version1,
		// 4 bytes for the packet number (see PacketNumberLen4 below), plus the payload
		Length:           4 + protocol.ByteCount(len(payload)),
		DestConnectionID: connID,
		SrcConnectionID:  connID,
	}
	extHdr := wire.ExtendedHeader{
		Header:          header,
		PacketNumber:    p,
		PacketNumberLen: protocol.PacketNumberLen4,
	}

	raw, err := extHdr.Append(nil, protocol.Version1)
	require.NoError(t, err)
	return append(raw, payload...)
}
// readPacketNumber parses b as a QUIC packet and returns the packet number
// found in its extended header.
// It fails the test if parsing fails, or if the packet is not an Initial packet.
func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber {
	t.Helper()

	header, packetBytes, _, parseErr := wire.ParsePacket(b)
	require.NoError(t, parseErr)
	require.Equal(t, protocol.PacketTypeInitial, header.Type)

	extended, extErr := header.ParseExtended(packetBytes)
	require.NoError(t, extErr)
	return extended.PacketNumber
}
// TestProxySetupt checks the state of a freshly created proxy: it has no
// client entries, it listens on a non-zero port on 127.0.0.1, and that port
// can't be bound a second time while the proxy is running.
func TestProxySetupt(t *testing.T) {
	proxy, err := NewQuicProxy("localhost:0", nil)
	require.NoError(t, err)
	require.Len(t, proxy.clientDict, 0)

	port := proxy.LocalPort()

	// Check that the proxy port is in use:
	// binding it again has to fail with the OS-specific "address in use" error.
	addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(port))
	require.NoError(t, err)
	_, err = net.ListenUDP("udp", addr)

	var wantErrFormat string
	switch runtime.GOOS {
	case "windows":
		wantErrFormat = "listen udp 127.0.0.1:%d: bind: Only one usage of each socket address (protocol/network address/port) is normally permitted."
	default:
		wantErrFormat = "listen udp 127.0.0.1:%d: bind: address already in use"
	}
	require.EqualError(t, err, fmt.Sprintf(wantErrFormat, port))

	require.Equal(t, "127.0.0.1:"+strconv.Itoa(port), proxy.LocalAddr().String())
	require.NotZero(t, port)
	require.NoError(t, proxy.Close())
}
// TestProxyShutdown checks that closing the proxy stops its runProxy
// goroutine and frees the UDP port it was listening on.
func TestProxyShutdown(t *testing.T) {
	// isProxyRunning reports whether a dump of all goroutine stacks contains
	// a frame of QuicProxy.runProxy.
	isProxyRunning := func() bool {
		var b bytes.Buffer
		// The destination is an in-memory buffer, so the error is ignored on purpose.
		// A dump that lacks the marker is simply treated as "not running".
		_ = pprof.Lookup("goroutine").WriteTo(&b, 1)
		return strings.Contains(b.String(), "proxy.(*QuicProxy).runProxy")
	}

	proxy, err := NewQuicProxy("localhost:0", nil)
	require.NoError(t, err)
	port := proxy.LocalPort()
	require.Eventually(t, isProxyRunning, time.Second, 10*time.Millisecond)

	conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)
	// Make sure the client socket doesn't outlive the test.
	defer conn.Close()

	_, err = conn.Write(makePacket(t, 1, []byte("foobar")))
	require.NoError(t, err)

	require.NoError(t, proxy.Close())

	// check that the proxy port is not in use anymore
	addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(port))
	require.NoError(t, err)
	// sometimes it takes a while for the OS to free the port
	require.Eventually(t, func() bool {
		ln, err := net.ListenUDP("udp", addr)
		if err != nil {
			return false
		}
		ln.Close()
		return true
	}, time.Second, 10*time.Millisecond)

	require.Eventually(t, func() bool { return !isProxyRunning() }, time.Second, 10*time.Millisecond)
}
// runServer sets up a dumb UDP server listening on an ephemeral port on 127.0.0.1.
// In production this would be a QUIC server.
//
// Every datagram the server receives is first pushed onto the returned
// channel, and then echoed back to its sender. The channel has a capacity of
// 100: the read loop blocks once that many packets are buffered and unread.
//
// The server is shut down when the test finishes: the cleanup closes the
// socket and waits up to one second for the read loop to return.
func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
	// Attribute setup failures to the calling test, like the other helpers in this file.
	t.Helper()

	serverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
	require.NoError(t, err)

	serverReceivedPackets := make(chan packetData, 100)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			// a fresh buffer per packet, since the slice is handed over to the channel
			buf := make([]byte, protocol.MaxPacketBufferSize)
			// the ReadFromUDP will error as soon as the UDP conn is closed
			n, addr, err := serverConn.ReadFromUDP(buf)
			if err != nil {
				return
			}
			data := buf[:n]
			serverReceivedPackets <- packetData(data)
			if _, err := serverConn.WriteToUDP(data, addr); err != nil { // echo the packet
				return
			}
		}
	}()

	t.Cleanup(func() {
		require.NoError(t, serverConn.Close())
		select {
		case <-done:
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	})
	return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets
}
// startProxy starts a proxy on an ephemeral localhost port, configured with
// opts, and returns a UDP socket connected to the proxy's local address.
// Both the proxy and the socket are closed when the test finishes.
func startProxy(t *testing.T, opts *Opts) (clientConn *net.UDPConn) {
	// Attribute setup failures to the calling test, like the other helpers in this file.
	t.Helper()

	proxy, err := NewQuicProxy("localhost:0", opts)
	require.NoError(t, err)
	clientConn, err = net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)

	t.Cleanup(func() {
		// Close both before asserting on either result. A failing assertion
		// aborts this function, and must not leave the client socket open.
		proxyErr := proxy.Close()
		connErr := clientConn.Close()
		require.NoError(t, proxyErr)
		require.NoError(t, connErr)
	})
	return clientConn
}
// TestProxyyingBackAndForth sends two packets through the proxy to the echo
// server, and checks that both echoes make it back to the client, in order.
func TestProxyyingBackAndForth(t *testing.T) {
	serverAddr, _ := runServer(t)
	clientConn := startProxy(t, &Opts{RemoteAddr: serverAddr.String()})

	// send the first packet
	_, err := clientConn.Write(makePacket(t, 1, []byte("foobar")))
	require.NoError(t, err)
	// send the second packet
	_, err = clientConn.Write(makePacket(t, 2, []byte("decafbad")))
	require.NoError(t, err)

	buf := make([]byte, 1024)
	for _, want := range []string{"foobar", "decafbad"} {
		// Bound every read: a lost datagram should fail the test after a
		// second, instead of blocking it indefinitely.
		require.NoError(t, clientConn.SetReadDeadline(time.Now().Add(time.Second)))
		n, err := clientConn.Read(buf)
		require.NoError(t, err)
		require.Contains(t, string(buf[:n]), want)
	}
}
// TestDropIncomingPackets lets the proxy drop every other packet travelling
// in DirectionIncoming (the 1st, 3rd, 5th, ...), and checks that the server
// receives exactly half of the packets written by the client.
func TestDropIncomingPackets(t *testing.T) {
	const numPackets = 6

	serverAddr, serverReceivedPackets := runServer(t)

	var incoming atomic.Int32
	dropEveryOther := func(d Direction, _ []byte) bool {
		// only incoming packets are counted, the odd-numbered ones are dropped
		return d == DirectionIncoming && incoming.Add(1)%2 == 1
	}
	clientConn := startProxy(t, &Opts{
		RemoteAddr: serverAddr.String(),
		DropPacket: dropEveryOther,
	})

	for pn := 1; pn <= numPackets; pn++ {
		payload := []byte("foobar" + strconv.Itoa(pn))
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(pn), payload))
		require.NoError(t, err)
	}

	// half of the packets are expected to make it to the server...
	for remaining := numPackets / 2; remaining > 0; remaining-- {
		select {
		case <-serverReceivedPackets:
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	}

	// ... and not a single one more
	select {
	case <-serverReceivedPackets:
		t.Fatalf("received unexpected packet")
	case <-time.After(100 * time.Millisecond):
	}
}
// TestDropOutgoingPackets lets the proxy drop every other packet travelling
// in DirectionOutgoing (the 1st, 3rd, 5th, ...). The server is expected to
// receive all packets, while only half of its echoes reach the client.
func TestDropOutgoingPackets(t *testing.T) {
	const numPackets = 6

	serverAddr, serverReceivedPackets := runServer(t)

	var outgoing atomic.Int32
	dropEveryOther := func(d Direction, _ []byte) bool {
		// only outgoing packets are counted, the odd-numbered ones are dropped
		return d == DirectionOutgoing && outgoing.Add(1)%2 == 1
	}
	clientConn := startProxy(t, &Opts{
		RemoteAddr: serverAddr.String(),
		DropPacket: dropEveryOther,
	})

	// signal every packet echoed by the server that arrives on the client side
	clientReceivedPackets := make(chan struct{}, numPackets)
	go func() {
		for {
			scratch := make([]byte, protocol.MaxPacketBufferSize)
			// ReadFromUDP errors as soon as the UDP conn is closed, which ends this goroutine
			_, _, readErr := clientConn.ReadFromUDP(scratch)
			if readErr != nil {
				return
			}
			clientReceivedPackets <- struct{}{}
		}
	}()

	for pn := 1; pn <= numPackets; pn++ {
		payload := []byte("foobar" + strconv.Itoa(pn))
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(pn), payload))
		require.NoError(t, err)
	}

	// half of the echoes are expected to make it back to the client...
	for remaining := numPackets / 2; remaining > 0; remaining-- {
		select {
		case <-clientReceivedPackets:
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	}

	// ... and not a single one more
	select {
	case <-clientReceivedPackets:
		t.Fatalf("received unexpected packet")
	case <-time.After(100 * time.Millisecond):
	}

	// nothing was dropped on the way to the server
	require.Len(t, serverReceivedPackets, numPackets)
}
// TestDelayIncomingPackets delays the n-th packet sent towards the server by
// n * 200 ms, and checks that the packets arrive at the server in order,
// each within 100 ms of its scheduled arrival time.
func TestDelayIncomingPackets(t *testing.T) {
	const (
		numPackets = 3
		delay      = 200 * time.Millisecond
	)

	serverAddr, serverReceivedPackets := runServer(t)

	var delayed atomic.Int32
	// delay packet 1 by 200 ms
	// delay packet 2 by 400 ms
	// ...
	increasingDelay := func(d Direction, _ []byte) time.Duration {
		switch d {
		case DirectionOutgoing:
			return 0
		default:
			return time.Duration(delayed.Add(1)) * delay
		}
	}
	clientConn := startProxy(t, &Opts{
		RemoteAddr:  serverAddr.String(),
		DelayPacket: increasingDelay,
	})

	start := time.Now()
	for pn := 1; pn <= numPackets; pn++ {
		payload := []byte("foobar" + strconv.Itoa(pn))
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(pn), payload))
		require.NoError(t, err)
	}

	for pn := 1; pn <= numPackets; pn++ {
		select {
		case <-time.After(time.Second):
			t.Fatalf("timeout waiting for packet %d", pn)
		case data := <-serverReceivedPackets:
			scheduled := start.Add(time.Duration(pn) * delay)
			require.WithinDuration(t, scheduled, time.Now(), delay/2)
			require.Equal(t, protocol.PacketNumber(pn), readPacketNumber(t, data))
		}
	}
}
// TestPacketReordering applies decreasing delays to three packets sent
// towards the server (600 ms, 400 ms, 200 ms), and checks that the server
// receives them in reverse order, each at its scheduled time.
func TestPacketReordering(t *testing.T) {
	const delay = 200 * time.Millisecond

	// expectDelay asserts that the current time lies in the half-open window
	// [startTime + numRTTs*delay, startTime + numRTTs*delay + delay/2).
	expectDelay := func(startTime time.Time, numRTTs int) {
		earliest := startTime.Add(time.Duration(numRTTs) * delay)
		latest := earliest.Add(delay / 2)
		now := time.Now()
		require.True(t, !now.Before(earliest))
		require.True(t, now.Before(latest))
	}

	serverAddr, serverReceivedPackets := runServer(t)

	var delayed atomic.Int32
	// delay packet 1 by 600 ms
	// delay packet 2 by 400 ms
	// delay packet 3 by 200 ms
	decreasingDelay := func(d Direction, _ []byte) time.Duration {
		switch d {
		case DirectionOutgoing:
			return 0
		default:
			seq := delayed.Add(1)
			return 600*time.Millisecond - time.Duration(seq-1)*delay
		}
	}
	clientConn := startProxy(t, &Opts{
		RemoteAddr:  serverAddr.String(),
		DelayPacket: decreasingDelay,
	})

	// send 3 packets
	start := time.Now()
	for pn := 1; pn <= 3; pn++ {
		payload := []byte("foobar" + strconv.Itoa(pn))
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(pn), payload))
		require.NoError(t, err)
	}

	for arrival := 1; arrival <= 3; arrival++ {
		select {
		case <-time.After(time.Second):
			t.Fatalf("timeout waiting for packet %d", arrival)
		case packet := <-serverReceivedPackets:
			expectDelay(start, arrival)
			// packets arrive in reverse order: 3, 2, 1
			want := protocol.PacketNumber(4 - arrival)
			require.Equal(t, want, readPacketNumber(t, packet))
		}
	}
}
// TestConstantDelay delays every packet sent towards the server by the same
// 100 ms, and checks that all 100 packets arrive in the order they were sent.
func TestConstantDelay(t *testing.T) { // no reordering expected here
	const numPackets = 100

	serverAddr, serverReceivedPackets := runServer(t)

	fixedDelay := func(d Direction, _ []byte) time.Duration {
		switch d {
		case DirectionOutgoing:
			return 0
		default:
			return 100 * time.Millisecond
		}
	}
	clientConn := startProxy(t, &Opts{
		RemoteAddr:  serverAddr.String(),
		DelayPacket: fixedDelay,
	})

	// send 100 packets
	for pn := 0; pn < numPackets; pn++ {
		payload := []byte("foobar" + strconv.Itoa(pn))
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(pn), payload))
		require.NoError(t, err)
	}

	// wait until the server has buffered all of them
	allReceived := func() bool { return len(serverReceivedPackets) == numPackets }
	require.Eventually(t, allReceived, 5*time.Second, 10*time.Millisecond)

	// a single timeout covers draining the entire channel
	timeout := time.After(5 * time.Second)
	for pn := 0; pn < numPackets; pn++ {
		select {
		case <-timeout:
			t.Fatalf("timeout waiting for packet %d", pn)
		case packet := <-serverReceivedPackets:
			require.Equal(t, protocol.PacketNumber(pn), readPacketNumber(t, packet))
		}
	}
}
// TestDelayOutgoingPackets delays the n-th packet echoed by the server by
// n * 200 ms. The packets must reach the server without delay, while the
// echoes reach the client in order, each within 100 ms of its scheduled time.
func TestDelayOutgoingPackets(t *testing.T) {
	const (
		numPackets = 3
		delay      = 200 * time.Millisecond
	)

	serverAddr, serverReceivedPackets := runServer(t)

	var delayed atomic.Int32
	// delay packet 1 by 200 ms
	// delay packet 2 by 400 ms
	// ...
	increasingDelay := func(d Direction, _ []byte) time.Duration {
		switch d {
		case DirectionIncoming:
			return 0
		default:
			return time.Duration(delayed.Add(1)) * delay
		}
	}
	clientConn := startProxy(t, &Opts{
		RemoteAddr:  serverAddr.String(),
		DelayPacket: increasingDelay,
	})

	// collect the packets echoed by the server on the client side
	clientReceivedPackets := make(chan packetData, numPackets)
	go func() {
		for {
			// a fresh buffer per packet, since the slice is handed over to the channel
			scratch := make([]byte, protocol.MaxPacketBufferSize)
			// ReadFromUDP errors as soon as the UDP conn is closed, which ends this goroutine
			n, _, readErr := clientConn.ReadFromUDP(scratch)
			if readErr != nil {
				return
			}
			clientReceivedPackets <- packetData(scratch[:n])
		}
	}()

	start := time.Now()
	for pn := 1; pn <= numPackets; pn++ {
		payload := []byte("foobar" + strconv.Itoa(pn))
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(pn), payload))
		require.NoError(t, err)
	}

	// the packets should have arrived immediately at the server
	for remaining := numPackets; remaining > 0; remaining-- {
		select {
		case <-serverReceivedPackets:
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	}
	require.WithinDuration(t, start, time.Now(), delay/2)

	// the echoes arrive at the client one by one, each later than the previous
	for pn := 1; pn <= numPackets; pn++ {
		select {
		case <-time.After(time.Second):
			t.Fatalf("timeout waiting for packet %d", pn)
		case packet := <-clientReceivedPackets:
			require.Equal(t, protocol.PacketNumber(pn), readPacketNumber(t, packet))
			scheduled := start.Add(time.Duration(pn) * delay)
			require.WithinDuration(t, scheduled, time.Now(), delay/2)
		}
	}
}