From 79bae396b4f4d56fb1701b196f6f1b79c119ccbb Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 25 Jan 2025 11:13:33 +0100 Subject: [PATCH] proxy: rename to Proxy, refactor initialization (#4921) * proxy: rename to Proxy, refactor initialization * improve documentation --- integrationtests/self/close_test.go | 12 +- integrationtests/self/datagram_test.go | 12 +- integrationtests/self/drop_test.go | 11 +- integrationtests/self/early_data_test.go | 10 +- integrationtests/self/handshake_drop_test.go | 11 +- integrationtests/self/handshake_rtt_test.go | 11 +- integrationtests/self/handshake_test.go | 23 ++- integrationtests/self/http_test.go | 9 +- integrationtests/self/mitm_test.go | 23 ++- integrationtests/self/mtu_test.go | 12 +- integrationtests/self/packetization_test.go | 23 +-- integrationtests/self/rtt_test.go | 22 +-- integrationtests/self/stateless_reset_test.go | 14 +- integrationtests/self/timeout_test.go | 35 ++-- integrationtests/self/zero_rtt_test.go | 47 +++--- integrationtests/tools/proxy/proxy.go | 136 ++++++---------- integrationtests/tools/proxy/proxy_test.go | 150 +++++++++--------- .../versionnegotiation/rtt_test.go | 16 +- 18 files changed, 274 insertions(+), 303 deletions(-) diff --git a/integrationtests/self/close_test.go b/integrationtests/self/close_test.go index fd6619b8..292a1563 100644 --- a/integrationtests/self/close_test.go +++ b/integrationtests/self/close_test.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "errors" - "fmt" "net" "sync/atomic" "testing" @@ -28,9 +27,10 @@ func TestConnectionCloseRetransmission(t *testing.T) { var drop atomic.Bool dropped := make(chan []byte, 100) - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { + proxy := &quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: server.Addr().(*net.UDPAddr), + DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return 5 * time.Millisecond // 10ms RTT }, DropPacket: func(dir quicproxy.Direction, b []byte) bool { @@ -40,8 +40,8 @@ func TestConnectionCloseRetransmission(t *testing.T) { } return false }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index f16d4305..11ae654d 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -3,7 +3,6 @@ package self_test import ( "bytes" "context" - "fmt" mrand "math/rand/v2" "net" "sync/atomic" @@ -138,8 +137,9 @@ func TestDatagramLoss(t *testing.T) { defer server.Close() var droppedIncoming, droppedOutgoing, total atomic.Int32 - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + proxy := &quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: server.Addr().(*net.UDPAddr), DropPacket: func(dir quicproxy.Direction, packet []byte) bool { if wire.IsLongHeaderPacket(packet[0]) { // don't drop Long Header packets return false @@ -159,9 +159,9 @@ func TestDatagramLoss(t *testing.T) { } return false }, - DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { return rtt / 2 }, - }) - require.NoError(t, err) + DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, + } + require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(numDatagrams*time.Millisecond)) diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index 993c9845..d4e58273 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -33,9 +33,10 @@ func TestDropTests(t *testing.T) { defer ln.Close() var numDroppedPackets atomic.Int32 - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, + proxy := &quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, DropPacket: func(d quicproxy.Direction, b []byte) bool { if !d.Is(direction) { return false @@ -49,8 +50,8 @@ func TestDropTests(t *testing.T) { } return drop }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) diff --git a/integrationtests/self/early_data_test.go b/integrationtests/self/early_data_test.go index e15179e6..9e30409f 100644 --- a/integrationtests/self/early_data_test.go +++ b/integrationtests/self/early_data_test.go @@ -2,7 +2,6 @@ package self_test import ( "context" - "fmt" "io" "net" "testing" @@ -20,11 +19,12 @@ func TestEarlyData(t *testing.T) { require.NoError(t, err) defer ln.Close() - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + proxy := &quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() connChan := make(chan quic.EarlyConnection) diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 1f35bf9d..bf7feb76 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -43,12 +43,13 @@ func startDropTestListenerAndProxy(t *testing.T, rtt, timeout time.Duration, dro require.NoError(t, err) t.Cleanup(func() { ln.Close() }) - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), DropPacket: dropCallback, - DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { return rtt / 2 }, - }) - require.NoError(t, err) + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, + } + require.NoError(t, proxy.Start()) t.Cleanup(func() { proxy.Close() }) return ln, proxy.LocalAddr() } diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index c85cbd5e..992a19c6 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -17,11 +17,12 @@ import ( func handshakeWithRTT(t *testing.T, serverAddr net.Addr, tlsConf *tls.Config, quicConf *quic.Config, rtt time.Duration) quic.Connection { t.Helper() - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: serverAddr.String(), - DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, - }) - require.NoError(t, err) + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: serverAddr.(*net.UDPAddr), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, + } + require.NoError(t, proxy.Start()) t.Cleanup(func() { proxy.Close() }) ctx, cancel := context.WithTimeout(context.Background(), 10*rtt) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 020351f4..3bb81ef8 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -359,11 +359,12 @@ func TestHandshakingConnectionsClosedOnServerShutdown(t *testing.T) { ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) require.NoError(t, err) - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() errChan := make(chan error, 1) @@ -549,14 +550,12 @@ func TestInvalidToken(t *testing.T) { require.NoError(t, err) defer server.Close() - serverPort := server.Addr().(*net.UDPAddr).Port - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { - return rtt / 2 - }, - }) - require.NoError(t, err) + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: server.Addr().(*net.UDPAddr), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, + } + require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 55c202c5..bf4a0bd5 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -848,16 +848,17 @@ func TestHTTP0RTT(t *testing.T) { port := startHTTPServer(t, mux) var num0RTTPackets atomic.Uint32 - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", port), + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}, DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { if contains0RTTPacket(data) { num0RTTPackets.Add(1) } return scaleDuration(25 * time.Millisecond) }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() tlsConf := getTLSClientConfigWithoutServerName() diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index b5096d5f..bca522a2 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -3,7 +3,6 @@ package self_test import ( "context" "errors" - "fmt" "io" "math" "net" @@ -204,13 +203,13 @@ func runMITMTest(t *testing.T, serverTr, clientTr *quic.Transport, rtt time.Dura require.NoError(t, err) defer ln.Close() - serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(dir quicproxy.Direction, b []byte) time.Duration { return rtt / 2 }, + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), + DelayPacket: func(_ quicproxy.Direction, b []byte) time.Duration { return rtt / 2 }, DropPacket: dropCb, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(time.Second)) @@ -413,12 +412,12 @@ func runMITMTestSuccessful(t *testing.T, serverTransport, clientTransport *quic. require.NoError(t, err) defer ln.Close() - serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: delayCb, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) diff --git a/integrationtests/self/mtu_test.go b/integrationtests/self/mtu_test.go index 700e3124..fb45ea49 100644 --- a/integrationtests/self/mtu_test.go +++ b/integrationtests/self/mtu_test.go @@ -79,10 +79,10 @@ func TestPathMTUDiscovery(t *testing.T) { var mx sync.Mutex var maxPacketSizeServer int var clientPacketSizes []int - serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, + proxy := &quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), + DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, DropPacket: func(dir quicproxy.Direction, packet []byte) bool { if len(packet) > mtu { return true @@ -99,8 +99,8 @@ func TestPathMTUDiscovery(t *testing.T) { } return false }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() // Make sure to use v4-only socket here. diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 3b0349e7..86a30f59 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net" "os" "testing" "time" @@ -33,13 +34,14 @@ func TestACKBundling(t *testing.T) { require.NoError(t, err) defer server.Close() - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: server.Addr().String(), - DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: server.Addr().(*net.UDPAddr), + DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return 5 * time.Millisecond }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() clientCounter, clientTracer := newPacketTracer() @@ -161,13 +163,14 @@ func testConnAndStreamDataBlocked(t *testing.T, limitStream, limitConn bool) { require.NoError(t, err) defer ln.Close() - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: ln.Addr().String(), - DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), + DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() counter, tracer := newPacketTracer() diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index 16793785..64c95848 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -65,11 +65,12 @@ func TestDownloadWithFixedRTT(t *testing.T) { } }) - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, - }) - require.NoError(t, err) + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port}, + DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, + } + require.NoError(t, proxy.Start()) t.Cleanup(func() { proxy.Close() }) ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -109,13 +110,14 @@ func TestDownloadWithReordering(t *testing.T) { } }) - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port}, + DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return randomDuration(rtt/2, rtt*3/2) / 2 }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) t.Cleanup(func() { proxy.Close() }) ctx, cancel := context.WithTimeout(context.Background(), time.Second) diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 49500ff8..14babc64 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -3,7 +3,6 @@ package self_test import ( "context" "crypto/rand" - "fmt" "net" "sync/atomic" "testing" @@ -37,7 +36,6 @@ func testStatelessReset(t *testing.T, connIDLen int) { defer tr.Close() ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) - serverPort := ln.Addr().(*net.UDPAddr).Port serverErr := make(chan error, 1) go func() { @@ -60,12 +58,12 @@ func testStatelessReset(t *testing.T, connIDLen int) { }() var drop atomic.Bool - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DropPacket: func(quicproxy.Direction, []byte) bool { - return drop.Load() - }, - }) + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), + DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() }, + } + require.NoError(t, proxy.Start()) require.NoError(t, err) defer proxy.Close() diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index d861b816..6f83f6df 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -111,13 +111,12 @@ func TestIdleTimeout(t *testing.T) { defer server.Close() var drop atomic.Bool - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - DropPacket: func(quicproxy.Direction, []byte) bool { - return drop.Load() - }, - }) - require.NoError(t, err) + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: server.Addr().(*net.UDPAddr), + DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() }, + } + require.NoError(t, proxy.Start()) defer proxy.Close() conn, err := quic.Dial( @@ -179,11 +178,12 @@ func TestKeepAlive(t *testing.T) { defer server.Close() var drop atomic.Bool - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Load() }, - }) - require.NoError(t, err) + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: server.Addr().(*net.UDPAddr), + DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() }, + } + require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -320,11 +320,12 @@ func TestTimeoutAfterSendingPacket(t *testing.T) { defer server.Close() var drop atomic.Bool - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - DropPacket: func(d quicproxy.Direction, _ []byte) bool { return d == quicproxy.DirectionOutgoing && drop.Load() }, - }) - require.NoError(t, err) + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: server.Addr().(*net.UDPAddr), + DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() }, + } + require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index ed86537b..220cad9e 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -21,19 +21,20 @@ import ( "github.com/stretchr/testify/require" ) -func runCountingProxyAndCount0RTTPackets(t *testing.T, serverPort int, rtt time.Duration) (*quicproxy.QuicProxy, *atomic.Uint32) { +func runCountingProxyAndCount0RTTPackets(t *testing.T, serverPort int, rtt time.Duration) (*quicproxy.Proxy, *atomic.Uint32) { t.Helper() var num0RTTPackets atomic.Uint32 - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + proxy := &quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: serverPort}, DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { if contains0RTTPacket(data) { num0RTTPackets.Add(1) } return rtt / 2 }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) t.Cleanup(func() { proxy.Close() }) return proxy, &num0RTTPackets } @@ -51,11 +52,12 @@ func dialAndReceiveTicket( require.NoError(t, err) defer ln.Close() - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: ln.Addr().String(), + proxy := &quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() clientTLSConf = getTLSClientConfig() @@ -361,8 +363,9 @@ func Test0RTTDataLoss(t *testing.T) { defer ln.Close() var num0RTTPackets, numDropped atomic.Uint32 - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 }, DropPacket: func(_ quicproxy.Direction, data []byte) bool { if !wire.IsLongHeaderPacket(data[0]) { @@ -380,8 +383,8 @@ func Test0RTTDataLoss(t *testing.T) { } return false }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() transfer0RTTData(t, ln, proxy.LocalAddr(), clientConf, nil, PRData) @@ -430,8 +433,9 @@ func Test0RTTRetransmitOnRetry(t *testing.T) { } var mutex sync.Mutex var connIDToCounter []*connIDCounter - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { connID, err := wire.ParseConnectionID(data, 0) if err != nil { @@ -455,8 +459,8 @@ func Test0RTTRetransmitOnRetry(t *testing.T) { } return 2 * time.Millisecond }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() transfer0RTTData(t, ln, proxy.LocalAddr(), clientConf, nil, GeneratePRData(5000)) // ~5 packets @@ -905,8 +909,9 @@ func Test0RTTPacketQueueing(t *testing.T) { require.NoError(t, err) defer ln.Close() - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: ln.Addr().String(), + proxy := quicproxy.Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { // delay the client's Initial by 1 RTT if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { @@ -914,8 +919,8 @@ func Test0RTTPacketQueueing(t *testing.T) { } return rtt / 2 }, - }) - require.NoError(t, err) + } + require.NoError(t, proxy.Start()) defer proxy.Close() data := GeneratePRData(5000) // ~5 packets diff --git a/integrationtests/tools/proxy/proxy.go b/integrationtests/tools/proxy/proxy.go index d2873bf9..2ab49085 100644 --- a/integrationtests/tools/proxy/proxy.go +++ b/integrationtests/tools/proxy/proxy.go @@ -113,100 +113,52 @@ func (d Direction) Is(dir Direction) bool { // DropCallback is a callback that determines which packet gets dropped. type DropCallback func(dir Direction, packet []byte) bool -// NoDropper doesn't drop packets. -var NoDropper DropCallback = func(Direction, []byte) bool { - return false -} - // DelayCallback is a callback that determines how much delay to apply to a packet. type DelayCallback func(dir Direction, packet []byte) time.Duration -// NoDelay doesn't apply a delay. -var NoDelay DelayCallback = func(Direction, []byte) time.Duration { - return 0 -} +// Proxy is a QUIC proxy that can drop and delay packets. +type Proxy struct { + // Conn is the UDP socket that the proxy listens on for incoming packets + // from clients. + Conn *net.UDPConn -// Opts are proxy options. -type Opts struct { - // The address this proxy proxies packets to. - RemoteAddr string - // DropPacket determines whether a packet gets dropped. + // ServerAddr is the address of the server that the proxy forwards packets to. + ServerAddr *net.UDPAddr + + // DropPacket is a callback that determines which packet gets dropped. DropPacket DropCallback - // DelayPacket determines how long a packet gets delayed. This allows - // simulating a connection with non-zero RTTs. - // Note that the RTT is the sum of the delay for the incoming and the outgoing packet. - DelayPacket DelayCallback -} -// QuicProxy is a QUIC proxy that can drop and delay packets. -type QuicProxy struct { - mutex sync.Mutex + // DelayPacket is a callback that determines how much delay to apply to a packet. + DelayPacket DelayCallback closeChan chan struct{} + logger utils.Logger - conn *net.UDPConn - serverAddr *net.UDPAddr - - dropPacket DropCallback - delayPacket DelayCallback - - // Mapping from client addresses (as host:port) to connection + // mapping from client addresses (as host:port) to connection + mutex sync.Mutex clientDict map[string]*connection - - logger utils.Logger } // NewQuicProxy creates a new UDP proxy -func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) { - if opts == nil { - opts = &Opts{} +func (p *Proxy) Start() error { + p.clientDict = make(map[string]*connection) + p.closeChan = make(chan struct{}) + p.logger = utils.DefaultLogger.WithPrefix("proxy") + + if err := p.Conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { + return err } - laddr, err := net.ResolveUDPAddr("udp", local) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", laddr) - if err != nil { - return nil, err - } - if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { - return nil, err - } - if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil { - return nil, err - } - raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr) - if err != nil { - return nil, err + if err := p.Conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil { + return err } - packetDropper := NoDropper - if opts.DropPacket != nil { - packetDropper = opts.DropPacket - } - - packetDelayer := NoDelay - if opts.DelayPacket != nil { - packetDelayer = opts.DelayPacket - } - - p := QuicProxy{ - clientDict: make(map[string]*connection), - conn: conn, - closeChan: make(chan struct{}), - serverAddr: raddr, - dropPacket: packetDropper, - delayPacket: packetDelayer, - logger: utils.DefaultLogger.WithPrefix("proxy"), - } - - p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr) + p.logger.Debugf("Starting UDP Proxy %s <-> %s", p.Conn.LocalAddr(), p.ServerAddr) go p.runProxy() - return &p, nil + return nil } // Close stops the UDP Proxy -func (p *QuicProxy) Close() error { +func (p *Proxy) Close() error { p.mutex.Lock() defer p.mutex.Unlock() @@ -218,16 +170,14 @@ func (p *QuicProxy) Close() error { c.Incoming.Close() c.Outgoing.Close() } - return p.conn.Close() + return nil } // LocalAddr is the address the proxy is listening on. -func (p *QuicProxy) LocalAddr() net.Addr { - return p.conn.LocalAddr() -} +func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() } -func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { - conn, err := net.DialUDP("udp", nil, p.serverAddr) +func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { + conn, err := net.DialUDP("udp", nil, p.ServerAddr) if err != nil { return nil, err } @@ -247,10 +197,10 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { } // runProxy listens on the proxy address and handles incoming packets. -func (p *QuicProxy) runProxy() error { +func (p *Proxy) runProxy() error { for { buffer := make([]byte, protocol.MaxPacketBufferSize) - n, cliaddr, err := p.conn.ReadFromUDP(buffer) + n, cliaddr, err := p.Conn.ReadFromUDP(buffer) if err != nil { return err } @@ -272,14 +222,17 @@ func (p *QuicProxy) runProxy() error { } p.mutex.Unlock() - if p.dropPacket(DirectionIncoming, raw) { + if p.DropPacket != nil && p.DropPacket(DirectionIncoming, raw) { if p.logger.Debug() { p.logger.Debugf("dropping incoming packet(%d bytes)", n) } continue } - delay := p.delayPacket(DirectionIncoming, raw) + var delay time.Duration + if p.DelayPacket != nil { + delay = p.DelayPacket(DirectionIncoming, raw) + } if delay == 0 { if p.logger.Debug() { p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr()) @@ -298,7 +251,7 @@ func (p *QuicProxy) runProxy() error { } // runConnection handles packets from server to a single client -func (p *QuicProxy) runOutgoingConnection(conn *connection) error { +func (p *Proxy) runOutgoingConnection(conn *connection) error { outgoingPackets := make(chan packetEntry, 10) go func() { for { @@ -309,19 +262,22 @@ func (p *QuicProxy) runOutgoingConnection(conn *connection) error { } raw := buffer[0:n] - if p.dropPacket(DirectionOutgoing, raw) { + if p.DropPacket != nil && p.DropPacket(DirectionOutgoing, raw) { if p.logger.Debug() { p.logger.Debugf("dropping outgoing packet(%d bytes)", n) } continue } - delay := p.delayPacket(DirectionOutgoing, raw) + var delay time.Duration + if p.DelayPacket != nil { + delay = p.DelayPacket(DirectionOutgoing, raw) + } if delay == 0 { if p.logger.Debug() { p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", len(raw), conn.ClientAddr) } - if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil { + if _, err := p.Conn.WriteToUDP(raw, conn.ClientAddr); err != nil { return } } else { @@ -342,14 +298,14 @@ func (p *QuicProxy) runOutgoingConnection(conn *connection) error { conn.Outgoing.Add(e) case <-conn.Outgoing.Timer(): conn.Outgoing.SetTimerRead() - if _, err := p.conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil { + if _, err := p.Conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil { return err } } } } -func (p *QuicProxy) runIncomingConnection(conn *connection) error { +func (p *Proxy) runIncomingConnection(conn *connection) error { for { select { case <-p.closeChan: diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index adae8953..39d0f8cb 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -1,11 +1,8 @@ package quicproxy import ( - "bytes" "net" - "runtime/pprof" "strconv" - "strings" "sync/atomic" "testing" "time" @@ -16,6 +13,14 @@ import ( "github.com/stretchr/testify/require" ) +func newUPDConnLocalhost(t testing.TB) *net.UDPConn { + t.Helper() + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + return conn +} + type packetData []byte func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte { @@ -47,37 +52,6 @@ func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber { return extHdr.PacketNumber } -func TestProxyShutdown(t *testing.T) { - isProxyRunning := func() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "proxy.(*QuicProxy).runProxy") - } - - proxy, err := NewQuicProxy("localhost:0", nil) - require.NoError(t, err) - require.Eventually(t, func() bool { return isProxyRunning() }, time.Second, 10*time.Millisecond) - - conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) - require.NoError(t, err) - _, err = conn.Write(makePacket(t, 1, []byte("foobar"))) - require.NoError(t, err) - - require.NoError(t, proxy.Close()) - - // check that the proxy port is not in use anymore - // sometimes it takes a while for the OS to free the port - require.Eventually(t, func() bool { - ln, err := net.ListenUDP("udp", proxy.LocalAddr().(*net.UDPAddr)) - if err != nil { - return false - } - ln.Close() - return true - }, time.Second, 10*time.Millisecond) - require.Eventually(t, func() bool { return !isProxyRunning() }, time.Second, 10*time.Millisecond) -} - // Set up a dumb UDP server. // In production this would be a QUIC server. func runServer(t *testing.T) (*net.UDPAddr, chan packetData) { @@ -116,25 +90,19 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) { return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets } -func startProxy(t *testing.T, opts *Opts) (clientConn *net.UDPConn) { - proxy, err := NewQuicProxy("localhost:0", opts) - require.NoError(t, err) - clientConn, err = net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) - require.NoError(t, err) - - t.Cleanup(func() { - require.NoError(t, proxy.Close()) - require.NoError(t, clientConn.Close()) - }) - return clientConn -} - func TestProxyyingBackAndForth(t *testing.T) { serverAddr, _ := runServer(t) - clientConn := startProxy(t, &Opts{RemoteAddr: serverAddr.String()}) + proxy := Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: serverAddr, + } + require.NoError(t, proxy.Start()) + defer proxy.Close() + clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) // send the first packet - _, err := clientConn.Write(makePacket(t, 1, []byte("foobar"))) + _, err = clientConn.Write(makePacket(t, 1, []byte("foobar"))) require.NoError(t, err) // send the second packet _, err = clientConn.Write(makePacket(t, 2, []byte("decafbad"))) @@ -153,15 +121,20 @@ func TestDropIncomingPackets(t *testing.T) { const numPackets = 6 serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 - clientConn := startProxy(t, &Opts{ - RemoteAddr: serverAddr.String(), + proxy := Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: serverAddr, DropPacket: func(d Direction, _ []byte) bool { if d != DirectionIncoming { return false } return counter.Add(1)%2 == 1 }, - }) + } + require.NoError(t, proxy.Start()) + defer proxy.Close() + clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) for i := 1; i <= numPackets; i++ { _, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i)))) @@ -186,15 +159,20 @@ func TestDropOutgoingPackets(t *testing.T) { const numPackets = 6 serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 - clientConn := startProxy(t, &Opts{ - RemoteAddr: serverAddr.String(), + proxy := Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: serverAddr, DropPacket: func(d Direction, _ []byte) bool { if d != DirectionOutgoing { return false } return counter.Add(1)%2 == 1 }, - }) + } + require.NoError(t, proxy.Start()) + defer proxy.Close() + clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) clientReceivedPackets := make(chan struct{}, numPackets) // receive the packets echoed by the server on client side @@ -234,19 +212,24 @@ func TestDelayIncomingPackets(t *testing.T) { const delay = 200 * time.Millisecond serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 - clientConn := startProxy(t, &Opts{ - RemoteAddr: serverAddr.String(), - // delay packet 1 by 200 ms - // delay packet 2 by 400 ms - // ... + proxy := Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: serverAddr, DelayPacket: func(d Direction, _ []byte) time.Duration { + // delay packet 1 by 200 ms + // delay packet 2 by 400 ms + // ... if d == DirectionOutgoing { return 0 } p := counter.Add(1) return time.Duration(p) * delay }, - }) + } + require.NoError(t, proxy.Start()) + defer proxy.Close() + clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) start := time.Now() for i := 1; i <= numPackets; i++ { @@ -276,19 +259,24 @@ func TestPacketReordering(t *testing.T) { serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 - clientConn := startProxy(t, &Opts{ - RemoteAddr: serverAddr.String(), - // delay packet 1 by 600 ms - // delay packet 2 by 400 ms - // delay packet 3 by 200 ms + proxy := Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: serverAddr, DelayPacket: func(d Direction, _ []byte) time.Duration { + // delay packet 1 by 600 ms + // delay packet 2 by 400 ms + // delay packet 3 by 200 ms if d == DirectionOutgoing { return 0 } p := counter.Add(1) return 600*time.Millisecond - time.Duration(p-1)*delay }, - }) + } + require.NoError(t, proxy.Start()) + defer proxy.Close() + clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) // send 3 packets start := time.Now() @@ -310,15 +298,20 @@ func TestPacketReordering(t *testing.T) { func TestConstantDelay(t *testing.T) { // no reordering expected here serverAddr, serverReceivedPackets := runServer(t) - clientConn := startProxy(t, &Opts{ - RemoteAddr: serverAddr.String(), + proxy := Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: serverAddr, DelayPacket: func(d Direction, _ []byte) time.Duration { if d == DirectionOutgoing { return 0 } return 100 * time.Millisecond }, - }) + } + require.NoError(t, proxy.Start()) + defer proxy.Close() + clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) // send 100 packets for i := 0; i < 100; i++ { @@ -343,19 +336,24 @@ func TestDelayOutgoingPackets(t *testing.T) { serverAddr, serverReceivedPackets := runServer(t) var counter atomic.Int32 - clientConn := startProxy(t, &Opts{ - RemoteAddr: serverAddr.String(), - // delay packet 1 by 200 ms - // delay packet 2 by 400 ms - // ... + proxy := Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: serverAddr, DelayPacket: func(d Direction, _ []byte) time.Duration { + // delay packet 1 by 200 ms + // delay packet 2 by 400 ms + // ... if d == DirectionIncoming { return 0 } p := counter.Add(1) return time.Duration(p) * delay }, - }) + } + require.NoError(t, proxy.Start()) + defer proxy.Close() + clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) clientReceivedPackets := make(chan packetData, numPackets) // receive the packets echoed by the server on client side diff --git a/integrationtests/versionnegotiation/rtt_test.go b/integrationtests/versionnegotiation/rtt_test.go index 34aff17e..35f99e38 100644 --- a/integrationtests/versionnegotiation/rtt_test.go +++ b/integrationtests/versionnegotiation/rtt_test.go @@ -2,6 +2,7 @@ package versionnegotiation import ( "context" + "net" "testing" "time" @@ -32,12 +33,17 @@ func TestVersionNegotiationFailure(t *testing.T) { require.NoError(t, err) defer ln.Close() - // start the proxy - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: ln.Addr().String(), - DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, - }) + proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) + defer proxyConn.Close() + // start the proxy + proxy := quicproxy.Proxy{ + Conn: proxyConn, + ServerAddr: ln.Addr().(*net.UDPAddr), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, + } + require.NoError(t, proxy.Start()) + defer proxy.Close() startTime := time.Now() _, err = quic.DialAddr(