mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
494 lines
14 KiB
Go
494 lines
14 KiB
Go
package quicproxy
|
|
|
|
import (
|
|
"net"
|
|
"strconv"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go/internal/protocol"
|
|
"github.com/quic-go/quic-go/internal/wire"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestPacketQueue(t *testing.T) {
|
|
q := newQueue()
|
|
|
|
getPackets := func() []string {
|
|
packets := make([]string, 0, len(q.Packets))
|
|
for _, p := range q.Packets {
|
|
packets = append(packets, string(p.Raw))
|
|
}
|
|
return packets
|
|
}
|
|
|
|
require.Empty(t, getPackets())
|
|
now := time.Now()
|
|
|
|
q.Add(packetEntry{Time: now, Raw: []byte("p3")})
|
|
require.Equal(t, []string{"p3"}, getPackets())
|
|
q.Add(packetEntry{Time: now.Add(time.Second), Raw: []byte("p4")})
|
|
require.Equal(t, []string{"p3", "p4"}, getPackets())
|
|
q.Add(packetEntry{Time: now.Add(-time.Second), Raw: []byte("p1")})
|
|
require.Equal(t, []string{"p1", "p3", "p4"}, getPackets())
|
|
q.Add(packetEntry{Time: now.Add(time.Second), Raw: []byte("p5")})
|
|
require.Equal(t, []string{"p1", "p3", "p4", "p5"}, getPackets())
|
|
q.Add(packetEntry{Time: now.Add(-time.Second), Raw: []byte("p2")})
|
|
require.Equal(t, []string{"p1", "p2", "p3", "p4", "p5"}, getPackets())
|
|
}
|
|
|
|
func newUPDConnLocalhost(t testing.TB) *net.UDPConn {
|
|
t.Helper()
|
|
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { conn.Close() })
|
|
return conn
|
|
}
|
|
|
|
func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte {
|
|
t.Helper()
|
|
hdr := wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
Version: protocol.Version1,
|
|
Length: 4 + protocol.ByteCount(len(payload)),
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}),
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}),
|
|
},
|
|
PacketNumber: p,
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
}
|
|
b, err := hdr.Append(nil, protocol.Version1)
|
|
require.NoError(t, err)
|
|
b = append(b, payload...)
|
|
return b
|
|
}
|
|
|
|
func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber {
|
|
t.Helper()
|
|
hdr, data, _, err := wire.ParsePacket(b)
|
|
require.NoError(t, err)
|
|
require.Equal(t, protocol.PacketTypeInitial, hdr.Type)
|
|
extHdr, err := hdr.ParseExtended(data)
|
|
require.NoError(t, err)
|
|
return extHdr.PacketNumber
|
|
}
|
|
|
|
// Set up a dumb UDP server.
|
|
// In production this would be a QUIC server.
|
|
func runServer(t *testing.T) (*net.UDPAddr, chan []byte) {
|
|
serverConn := newUPDConnLocalhost(t)
|
|
|
|
serverReceivedPackets := make(chan []byte, 100)
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
for {
|
|
buf := make([]byte, protocol.MaxPacketBufferSize)
|
|
// the ReadFromUDP will error as soon as the UDP conn is closed
|
|
n, addr, err := serverConn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
serverReceivedPackets <- buf[:n]
|
|
// echo the packet
|
|
if _, err := serverConn.WriteToUDP(buf[:n], addr); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
})
|
|
|
|
return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets
|
|
}
|
|
|
|
func TestProxyingBackAndForth(t *testing.T) {
|
|
serverAddr, _ := runServer(t)
|
|
proxy := Proxy{
|
|
Conn: newUPDConnLocalhost(t),
|
|
ServerAddr: serverAddr,
|
|
}
|
|
require.NoError(t, proxy.Start())
|
|
defer proxy.Close()
|
|
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
|
require.NoError(t, err)
|
|
|
|
// send the first packet
|
|
_, err = clientConn.Write(makePacket(t, 1, []byte("foobar")))
|
|
require.NoError(t, err)
|
|
// send the second packet
|
|
_, err = clientConn.Write(makePacket(t, 2, []byte("decafbad")))
|
|
require.NoError(t, err)
|
|
|
|
buf := make([]byte, 1024)
|
|
n, err := clientConn.Read(buf)
|
|
require.NoError(t, err)
|
|
require.Contains(t, string(buf[:n]), "foobar")
|
|
n, err = clientConn.Read(buf)
|
|
require.NoError(t, err)
|
|
require.Contains(t, string(buf[:n]), "decafbad")
|
|
}
|
|
|
|
func TestDropIncomingPackets(t *testing.T) {
|
|
const numPackets = 6
|
|
serverAddr, serverReceivedPackets := runServer(t)
|
|
var counter atomic.Int32
|
|
proxy := Proxy{
|
|
Conn: newUPDConnLocalhost(t),
|
|
ServerAddr: serverAddr,
|
|
DropPacket: func(d Direction, _ []byte) bool {
|
|
if d != DirectionIncoming {
|
|
return false
|
|
}
|
|
return counter.Add(1)%2 == 1
|
|
},
|
|
}
|
|
require.NoError(t, proxy.Start())
|
|
defer proxy.Close()
|
|
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
|
require.NoError(t, err)
|
|
|
|
for i := 1; i <= numPackets; i++ {
|
|
_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
for i := 0; i < numPackets/2; i++ {
|
|
select {
|
|
case <-serverReceivedPackets:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
select {
|
|
case <-serverReceivedPackets:
|
|
t.Fatalf("received unexpected packet")
|
|
case <-time.After(100 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestDropOutgoingPackets(t *testing.T) {
|
|
const numPackets = 6
|
|
serverAddr, serverReceivedPackets := runServer(t)
|
|
var counter atomic.Int32
|
|
proxy := Proxy{
|
|
Conn: newUPDConnLocalhost(t),
|
|
ServerAddr: serverAddr,
|
|
DropPacket: func(d Direction, _ []byte) bool {
|
|
if d != DirectionOutgoing {
|
|
return false
|
|
}
|
|
return counter.Add(1)%2 == 1
|
|
},
|
|
}
|
|
require.NoError(t, proxy.Start())
|
|
defer proxy.Close()
|
|
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
|
require.NoError(t, err)
|
|
|
|
clientReceivedPackets := make(chan struct{}, numPackets)
|
|
// receive the packets echoed by the server on client side
|
|
go func() {
|
|
for {
|
|
buf := make([]byte, protocol.MaxPacketBufferSize)
|
|
if _, _, err := clientConn.ReadFromUDP(buf); err != nil {
|
|
return
|
|
}
|
|
clientReceivedPackets <- struct{}{}
|
|
}
|
|
}()
|
|
|
|
for i := 1; i <= numPackets; i++ {
|
|
_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
for i := 0; i < numPackets/2; i++ {
|
|
select {
|
|
case <-clientReceivedPackets:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
select {
|
|
case <-clientReceivedPackets:
|
|
t.Fatalf("received unexpected packet")
|
|
case <-time.After(100 * time.Millisecond):
|
|
}
|
|
require.Len(t, serverReceivedPackets, numPackets)
|
|
}
|
|
|
|
func TestDelayIncomingPackets(t *testing.T) {
|
|
const numPackets = 3
|
|
const delay = 200 * time.Millisecond
|
|
serverAddr, serverReceivedPackets := runServer(t)
|
|
var counter atomic.Int32
|
|
proxy := Proxy{
|
|
Conn: newUPDConnLocalhost(t),
|
|
ServerAddr: serverAddr,
|
|
DelayPacket: func(d Direction, _ []byte) time.Duration {
|
|
// delay packet 1 by 200 ms
|
|
// delay packet 2 by 400 ms
|
|
// ...
|
|
if d == DirectionOutgoing {
|
|
return 0
|
|
}
|
|
p := counter.Add(1)
|
|
return time.Duration(p) * delay
|
|
},
|
|
}
|
|
require.NoError(t, proxy.Start())
|
|
defer proxy.Close()
|
|
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
|
require.NoError(t, err)
|
|
|
|
start := time.Now()
|
|
for i := 1; i <= numPackets; i++ {
|
|
_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
for i := 1; i <= numPackets; i++ {
|
|
select {
|
|
case data := <-serverReceivedPackets:
|
|
require.WithinDuration(t, start.Add(time.Duration(i)*delay), time.Now(), delay/2)
|
|
require.Equal(t, protocol.PacketNumber(i), readPacketNumber(t, data))
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout waiting for packet %d", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPacketReordering(t *testing.T) {
|
|
const delay = 200 * time.Millisecond
|
|
expectDelay := func(startTime time.Time, numRTTs int) {
|
|
expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * delay)
|
|
now := time.Now()
|
|
require.True(t, now.After(expectedReceiveTime) || now.Equal(expectedReceiveTime))
|
|
require.True(t, now.Before(expectedReceiveTime.Add(delay/2)))
|
|
}
|
|
|
|
serverAddr, serverReceivedPackets := runServer(t)
|
|
var counter atomic.Int32
|
|
proxy := Proxy{
|
|
Conn: newUPDConnLocalhost(t),
|
|
ServerAddr: serverAddr,
|
|
DelayPacket: func(d Direction, _ []byte) time.Duration {
|
|
// delay packet 1 by 600 ms
|
|
// delay packet 2 by 400 ms
|
|
// delay packet 3 by 200 ms
|
|
if d == DirectionOutgoing {
|
|
return 0
|
|
}
|
|
p := counter.Add(1)
|
|
return 600*time.Millisecond - time.Duration(p-1)*delay
|
|
},
|
|
}
|
|
require.NoError(t, proxy.Start())
|
|
defer proxy.Close()
|
|
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
|
require.NoError(t, err)
|
|
|
|
// send 3 packets
|
|
start := time.Now()
|
|
for i := 1; i <= 3; i++ {
|
|
_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
require.NoError(t, err)
|
|
}
|
|
for i := 1; i <= 3; i++ {
|
|
select {
|
|
case packet := <-serverReceivedPackets:
|
|
expectDelay(start, i)
|
|
expectedPacketNumber := protocol.PacketNumber(4 - i) // 3, 2, 1 in reverse order
|
|
require.Equal(t, expectedPacketNumber, readPacketNumber(t, packet))
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout waiting for packet %d", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestConstantDelay(t *testing.T) { // no reordering expected here
|
|
serverAddr, serverReceivedPackets := runServer(t)
|
|
proxy := Proxy{
|
|
Conn: newUPDConnLocalhost(t),
|
|
ServerAddr: serverAddr,
|
|
DelayPacket: func(d Direction, _ []byte) time.Duration {
|
|
if d == DirectionOutgoing {
|
|
return 0
|
|
}
|
|
return 100 * time.Millisecond
|
|
},
|
|
}
|
|
require.NoError(t, proxy.Start())
|
|
defer proxy.Close()
|
|
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
|
require.NoError(t, err)
|
|
|
|
// send 100 packets
|
|
for i := 0; i < 100; i++ {
|
|
_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
require.NoError(t, err)
|
|
}
|
|
require.Eventually(t, func() bool { return len(serverReceivedPackets) == 100 }, 5*time.Second, 10*time.Millisecond)
|
|
timeout := time.After(5 * time.Second)
|
|
for i := 0; i < 100; i++ {
|
|
select {
|
|
case packet := <-serverReceivedPackets:
|
|
require.Equal(t, protocol.PacketNumber(i), readPacketNumber(t, packet))
|
|
case <-timeout:
|
|
t.Fatalf("timeout waiting for packet %d", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDelayOutgoingPackets(t *testing.T) {
|
|
const numPackets = 3
|
|
const delay = 200 * time.Millisecond
|
|
|
|
serverAddr, serverReceivedPackets := runServer(t)
|
|
var counter atomic.Int32
|
|
proxy := Proxy{
|
|
Conn: newUPDConnLocalhost(t),
|
|
ServerAddr: serverAddr,
|
|
DelayPacket: func(d Direction, _ []byte) time.Duration {
|
|
// delay packet 1 by 200 ms
|
|
// delay packet 2 by 400 ms
|
|
// ...
|
|
if d == DirectionIncoming {
|
|
return 0
|
|
}
|
|
p := counter.Add(1)
|
|
return time.Duration(p) * delay
|
|
},
|
|
}
|
|
require.NoError(t, proxy.Start())
|
|
defer proxy.Close()
|
|
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
|
require.NoError(t, err)
|
|
|
|
clientReceivedPackets := make(chan []byte, numPackets)
|
|
// receive the packets echoed by the server on client side
|
|
go func() {
|
|
for {
|
|
buf := make([]byte, protocol.MaxPacketBufferSize)
|
|
n, _, err := clientConn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
clientReceivedPackets <- buf[:n]
|
|
}
|
|
}()
|
|
|
|
start := time.Now()
|
|
for i := 1; i <= numPackets; i++ {
|
|
_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
require.NoError(t, err)
|
|
}
|
|
// the packets should have arrived immediately at the server
|
|
for i := 0; i < numPackets; i++ {
|
|
select {
|
|
case <-serverReceivedPackets:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
require.WithinDuration(t, start, time.Now(), delay/2)
|
|
|
|
for i := 1; i <= numPackets; i++ {
|
|
select {
|
|
case packet := <-clientReceivedPackets:
|
|
require.Equal(t, protocol.PacketNumber(i), readPacketNumber(t, packet))
|
|
require.WithinDuration(t, start.Add(time.Duration(i)*delay), time.Now(), delay/2)
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout waiting for packet %d", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProxySwitchConn(t *testing.T) {
|
|
serverConn := newUPDConnLocalhost(t)
|
|
|
|
type packet struct {
|
|
Data []byte
|
|
Addr *net.UDPAddr
|
|
}
|
|
|
|
serverReceivedPackets := make(chan packet, 1)
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
for {
|
|
buf := make([]byte, 1000)
|
|
n, addr, err := serverConn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
serverReceivedPackets <- packet{Data: buf[:n], Addr: addr}
|
|
}
|
|
}()
|
|
|
|
proxy := Proxy{
|
|
Conn: newUPDConnLocalhost(t),
|
|
ServerAddr: serverConn.LocalAddr().(*net.UDPAddr),
|
|
}
|
|
require.NoError(t, proxy.Start())
|
|
defer proxy.Close()
|
|
|
|
clientConn := newUPDConnLocalhost(t)
|
|
_, err := clientConn.WriteToUDP([]byte("hello"), proxy.LocalAddr().(*net.UDPAddr))
|
|
require.NoError(t, err)
|
|
clientConn.SetReadDeadline(time.Now().Add(time.Second))
|
|
|
|
var firstConnAddr *net.UDPAddr
|
|
select {
|
|
case p := <-serverReceivedPackets:
|
|
require.Equal(t, "hello", string(p.Data))
|
|
require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
|
|
firstConnAddr = p.Addr
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
_, err = serverConn.WriteToUDP([]byte("hi"), firstConnAddr)
|
|
require.NoError(t, err)
|
|
buf := make([]byte, 1000)
|
|
n, addr, err := clientConn.ReadFromUDP(buf)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "hi", string(buf[:n]))
|
|
require.Equal(t, proxy.LocalAddr(), addr)
|
|
|
|
newConn := newUPDConnLocalhost(t)
|
|
require.NoError(t, proxy.SwitchConn(clientConn.LocalAddr().(*net.UDPAddr), newConn))
|
|
|
|
_, err = clientConn.WriteToUDP([]byte("foobar"), proxy.LocalAddr().(*net.UDPAddr))
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case p := <-serverReceivedPackets:
|
|
require.Equal(t, "foobar", string(p.Data))
|
|
require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
|
|
require.NotEqual(t, firstConnAddr, p.Addr)
|
|
require.Equal(t, newConn.LocalAddr(), p.Addr)
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
// the old connection doesn't deliver any packets to the client anymore
|
|
_, err = serverConn.WriteTo([]byte("invalid"), firstConnAddr)
|
|
require.NoError(t, err)
|
|
_, err = serverConn.WriteTo([]byte("foobaz"), newConn.LocalAddr())
|
|
require.NoError(t, err)
|
|
n, addr, err = clientConn.ReadFromUDP(buf)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "foobaz", string(buf[:n])) // "invalid" is not delivered
|
|
require.Equal(t, proxy.LocalAddr(), addr)
|
|
}
|