package quicproxy import ( "net" "strconv" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func newUPDConnLocalhost(t testing.TB) *net.UDPConn { t.Helper() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) t.Cleanup(func() { conn.Close() }) return conn } type packetData []byte func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte { t.Helper() hdr := wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, Version: protocol.Version1, Length: 4 + protocol.ByteCount(len(payload)), DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}), SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}), }, PacketNumber: p, PacketNumberLen: protocol.PacketNumberLen4, } b, err := hdr.Append(nil, protocol.Version1) require.NoError(t, err) b = append(b, payload...) return b } func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber { t.Helper() hdr, data, _, err := wire.ParsePacket(b) require.NoError(t, err) require.Equal(t, protocol.PacketTypeInitial, hdr.Type) extHdr, err := hdr.ParseExtended(data) require.NoError(t, err) return extHdr.PacketNumber } // Set up a dumb UDP server. // In production this would be a QUIC server. func runServer(t *testing.T) (*net.UDPAddr, chan packetData) { serverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) serverReceivedPackets := make(chan packetData, 100) done := make(chan struct{}) go func() { defer close(done) for { buf := make([]byte, protocol.MaxPacketBufferSize) // the ReadFromUDP will error as soon as the UDP conn is closed n, addr, err := serverConn.ReadFromUDP(buf) if err != nil { return } data := buf[:n] serverReceivedPackets <- packetData(data) if _, err := serverConn.WriteToUDP(data, addr); err != nil { // echo the packet return } } }() t.Cleanup(func() { require.NoError(t, serverConn.Close()) select { case <-done: case <-time.After(time.Second): t.Fatalf("timeout") } }) return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets } func TestProxyyingBackAndForth(t *testing.T) { serverAddr, _ := runServer(t) proxy := Proxy{ Conn: newUPDConnLocalhost(t), ServerAddr: serverAddr, } require.NoError(t, proxy.Start()) defer proxy.Close() clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) // send the first packet _, err = clientConn.Write(makePacket(t, 1, []byte("foobar"))) require.NoError(t, err) // send the second packet _, err = clientConn.Write(makePacket(t, 2, []byte("decafbad"))) require.NoError(t, err) buf := make([]byte, 1024) n, err := clientConn.Read(buf) require.NoError(t, err) require.Contains(t, string(buf[:n]), "foobar") n, err = clientConn.Read(buf) require.NoError(t, err) require.Contains(t, string(buf[:n]), "decafbad") } func TestDropIncomingPackets(t *testing.T) { const numPackets = 6 serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 proxy := Proxy{ Conn: newUPDConnLocalhost(t), ServerAddr: serverAddr, DropPacket: func(d Direction, _ []byte) bool { if d != DirectionIncoming { return false } return counter.Add(1)%2 == 1 }, } require.NoError(t, proxy.Start()) defer proxy.Close() clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) for i := 1; i <= numPackets; i++ { _, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i)))) require.NoError(t, err) } for i := 0; i < numPackets/2; i++ { select { case <-serverReceivedPackets: case <-time.After(time.Second): t.Fatalf("timeout") } } select { case <-serverReceivedPackets: t.Fatalf("received unexpected packet") case <-time.After(100 * time.Millisecond): } } func TestDropOutgoingPackets(t *testing.T) { const numPackets = 6 serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 proxy := Proxy{ Conn: newUPDConnLocalhost(t), ServerAddr: serverAddr, DropPacket: func(d Direction, _ []byte) bool { if d != DirectionOutgoing { return false } return counter.Add(1)%2 == 1 }, } require.NoError(t, proxy.Start()) defer proxy.Close() clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) clientReceivedPackets := make(chan struct{}, numPackets) // receive the packets echoed by the server on client side go func() { for { buf := make([]byte, protocol.MaxPacketBufferSize) // the ReadFromUDP will error as soon as the UDP conn is closed if _, _, err := clientConn.ReadFromUDP(buf); err != nil { return } clientReceivedPackets <- struct{}{} } }() for i := 1; i <= numPackets; i++ { _, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i)))) require.NoError(t, err) } for i := 0; i < numPackets/2; i++ { select { case <-clientReceivedPackets: case <-time.After(time.Second): t.Fatalf("timeout") } } select { case <-clientReceivedPackets: t.Fatalf("received unexpected packet") case <-time.After(100 * time.Millisecond): } require.Len(t, serverReceivedPackets, numPackets) } func TestDelayIncomingPackets(t *testing.T) { const numPackets = 3 const delay = 200 * time.Millisecond serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 proxy := Proxy{ Conn: newUPDConnLocalhost(t), ServerAddr: serverAddr, DelayPacket: func(d Direction, _ []byte) time.Duration { // delay packet 1 by 200 ms // delay packet 2 by 400 ms // ... if d == DirectionOutgoing { return 0 } p := counter.Add(1) return time.Duration(p) * delay }, } require.NoError(t, proxy.Start()) defer proxy.Close() clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) start := time.Now() for i := 1; i <= numPackets; i++ { _, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i)))) require.NoError(t, err) } for i := 1; i <= numPackets; i++ { select { case data := <-serverReceivedPackets: require.WithinDuration(t, start.Add(time.Duration(i)*delay), time.Now(), delay/2) require.Equal(t, protocol.PacketNumber(i), readPacketNumber(t, data)) case <-time.After(time.Second): t.Fatalf("timeout waiting for packet %d", i) } } } func TestPacketReordering(t *testing.T) { const delay = 200 * time.Millisecond expectDelay := func(startTime time.Time, numRTTs int) { expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * delay) now := time.Now() require.True(t, now.After(expectedReceiveTime) || now.Equal(expectedReceiveTime)) require.True(t, now.Before(expectedReceiveTime.Add(delay/2))) } serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 proxy := Proxy{ Conn: newUPDConnLocalhost(t), ServerAddr: serverAddr, DelayPacket: func(d Direction, _ []byte) time.Duration { // delay packet 1 by 600 ms // delay packet 2 by 400 ms // delay packet 3 by 200 ms if d == DirectionOutgoing { return 0 } p := counter.Add(1) return 600*time.Millisecond - time.Duration(p-1)*delay }, } require.NoError(t, proxy.Start()) defer proxy.Close() clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) // send 3 packets start := time.Now() for i := 1; i <= 3; i++ { _, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i)))) require.NoError(t, err) } for i := 1; i <= 3; i++ { select { case packet := <-serverReceivedPackets: expectDelay(start, i) expectedPacketNumber := protocol.PacketNumber(4 - i) // 3, 2, 1 in reverse order require.Equal(t, expectedPacketNumber, readPacketNumber(t, packet)) case <-time.After(time.Second): t.Fatalf("timeout waiting for packet %d", i) } } } func TestConstantDelay(t *testing.T) { // no reordering expected here serverAddr, serverReceivedPackets := runServer(t) proxy := Proxy{ Conn: newUPDConnLocalhost(t), ServerAddr: serverAddr, DelayPacket: func(d Direction, _ []byte) time.Duration { if d == DirectionOutgoing { return 0 } return 100 * time.Millisecond }, } require.NoError(t, proxy.Start()) defer proxy.Close() clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) // send 100 packets for i := 0; i < 100; i++ { _, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i)))) require.NoError(t, err) } require.Eventually(t, func() bool { return len(serverReceivedPackets) == 100 }, 5*time.Second, 10*time.Millisecond) timeout := time.After(5 * time.Second) for i := 0; i < 100; i++ { select { case packet := <-serverReceivedPackets: require.Equal(t, protocol.PacketNumber(i), readPacketNumber(t, packet)) case <-timeout: t.Fatalf("timeout waiting for packet %d", i) } } } func TestDelayOutgoingPackets(t *testing.T) { const numPackets = 3 const delay = 200 * time.Millisecond serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 proxy := Proxy{ Conn: newUPDConnLocalhost(t), ServerAddr: serverAddr, DelayPacket: func(d Direction, _ []byte) time.Duration { // delay packet 1 by 200 ms // delay packet 2 by 400 ms // ... if d == DirectionIncoming { return 0 } p := counter.Add(1) return time.Duration(p) * delay }, } require.NoError(t, proxy.Start()) defer proxy.Close() clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) clientReceivedPackets := make(chan packetData, numPackets) // receive the packets echoed by the server on client side go func() { for { buf := make([]byte, protocol.MaxPacketBufferSize) // the ReadFromUDP will error as soon as the UDP conn is closed n, _, err := clientConn.ReadFromUDP(buf) if err != nil { return } clientReceivedPackets <- packetData(buf[0:n]) } }() start := time.Now() for i := 1; i <= numPackets; i++ { _, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i)))) require.NoError(t, err) } // the packets should have arrived immediately at the server for i := 0; i < numPackets; i++ { select { case <-serverReceivedPackets: case <-time.After(time.Second): t.Fatalf("timeout") } } require.WithinDuration(t, start, time.Now(), delay/2) for i := 1; i <= numPackets; i++ { select { case packet := <-clientReceivedPackets: require.Equal(t, protocol.PacketNumber(i), readPacketNumber(t, packet)) require.WithinDuration(t, start.Add(time.Duration(i)*delay), time.Now(), delay/2) case <-time.After(time.Second): t.Fatalf("timeout waiting for packet %d", i) } } }