proxy: rename to Proxy, refactor initialization (#4921)

* proxy: rename to Proxy, refactor initialization

* improve documentation
This commit is contained in:
Marten Seemann 2025-01-25 11:13:33 +01:00 committed by GitHub
parent f5145eb627
commit 79bae396b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 274 additions and 303 deletions

View file

@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync/atomic"
"testing"
@ -28,9 +27,10 @@ func TestConnectionCloseRetransmission(t *testing.T) {
var drop atomic.Bool
dropped := make(chan []byte, 100)
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
proxy := &quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: server.Addr().(*net.UDPAddr),
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration {
return 5 * time.Millisecond // 10ms RTT
},
DropPacket: func(dir quicproxy.Direction, b []byte) bool {
@ -40,8 +40,8 @@ func TestConnectionCloseRetransmission(t *testing.T) {
}
return false
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)

View file

@ -3,7 +3,6 @@ package self_test
import (
"bytes"
"context"
"fmt"
mrand "math/rand/v2"
"net"
"sync/atomic"
@ -138,8 +137,9 @@ func TestDatagramLoss(t *testing.T) {
defer server.Close()
var droppedIncoming, droppedOutgoing, total atomic.Int32
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
proxy := &quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: server.Addr().(*net.UDPAddr),
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
if wire.IsLongHeaderPacket(packet[0]) { // don't drop Long Header packets
return false
@ -159,9 +159,9 @@ func TestDatagramLoss(t *testing.T) {
}
return false
},
DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(numDatagrams*time.Millisecond))

View file

@ -33,9 +33,10 @@ func TestDropTests(t *testing.T) {
defer ln.Close()
var numDroppedPackets atomic.Int32
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
proxy := &quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
DropPacket: func(d quicproxy.Direction, b []byte) bool {
if !d.Is(direction) {
return false
@ -49,8 +50,8 @@ func TestDropTests(t *testing.T) {
}
return drop
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)

View file

@ -2,7 +2,6 @@ package self_test
import (
"context"
"fmt"
"io"
"net"
"testing"
@ -20,11 +19,12 @@ func TestEarlyData(t *testing.T) {
require.NoError(t, err)
defer ln.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
proxy := &quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
connChan := make(chan quic.EarlyConnection)

View file

@ -43,12 +43,13 @@ func startDropTestListenerAndProxy(t *testing.T, rtt, timeout time.Duration, dro
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DropPacket: dropCallback,
DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
t.Cleanup(func() { proxy.Close() })
return ln, proxy.LocalAddr()
}

View file

@ -17,11 +17,12 @@ import (
func handshakeWithRTT(t *testing.T, serverAddr net.Addr, tlsConf *tls.Config, quicConf *quic.Config, rtt time.Duration) quic.Connection {
t.Helper()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: serverAddr.String(),
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: serverAddr.(*net.UDPAddr),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
t.Cleanup(func() { proxy.Close() })
ctx, cancel := context.WithTimeout(context.Background(), 10*rtt)

View file

@ -359,11 +359,12 @@ func TestHandshakingConnectionsClosedOnServerShutdown(t *testing.T) {
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
require.NoError(t, err)
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
errChan := make(chan error, 1)
@ -549,14 +550,12 @@ func TestInvalidToken(t *testing.T) {
require.NoError(t, err)
defer server.Close()
serverPort := server.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
return rtt / 2
},
})
require.NoError(t, err)
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: server.Addr().(*net.UDPAddr),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)

View file

@ -848,16 +848,17 @@ func TestHTTP0RTT(t *testing.T) {
port := startHTTPServer(t, mux)
var num0RTTPackets atomic.Uint32
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", port),
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port},
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
if contains0RTTPacket(data) {
num0RTTPackets.Add(1)
}
return scaleDuration(25 * time.Millisecond)
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
tlsConf := getTLSClientConfigWithoutServerName()

View file

@ -3,7 +3,6 @@ package self_test
import (
"context"
"errors"
"fmt"
"io"
"math"
"net"
@ -204,13 +203,13 @@ func runMITMTest(t *testing.T, serverTr, clientTr *quic.Transport, rtt time.Dura
require.NoError(t, err)
defer ln.Close()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(dir quicproxy.Direction, b []byte) time.Duration { return rtt / 2 },
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(_ quicproxy.Direction, b []byte) time.Duration { return rtt / 2 },
DropPacket: dropCb,
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(time.Second))
@ -413,12 +412,12 @@ func runMITMTestSuccessful(t *testing.T, serverTransport, clientTransport *quic.
require.NoError(t, err)
defer ln.Close()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: delayCb,
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond))

View file

@ -79,10 +79,10 @@ func TestPathMTUDiscovery(t *testing.T) {
var mx sync.Mutex
var maxPacketSizeServer int
var clientPacketSizes []int
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
proxy := &quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
if len(packet) > mtu {
return true
@ -99,8 +99,8 @@ func TestPathMTUDiscovery(t *testing.T) {
}
return false
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
// Make sure to use v4-only socket here.

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net"
"os"
"testing"
"time"
@ -33,13 +34,14 @@ func TestACKBundling(t *testing.T) {
require.NoError(t, err)
defer server.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: server.Addr().String(),
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: server.Addr().(*net.UDPAddr),
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration {
return 5 * time.Millisecond
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientCounter, clientTracer := newPacketTracer()
@ -161,13 +163,14 @@ func testConnAndStreamDataBlocked(t *testing.T, limitStream, limitConn bool) {
require.NoError(t, err)
defer ln.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: ln.Addr().String(),
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration {
return rtt / 2
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
counter, tracer := newPacketTracer()

View file

@ -65,11 +65,12 @@ func TestDownloadWithFixedRTT(t *testing.T) {
}
})
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
t.Cleanup(func() { proxy.Close() })
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@ -109,13 +110,14 @@ func TestDownloadWithReordering(t *testing.T) {
}
})
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration {
return randomDuration(rtt/2, rtt*3/2) / 2
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
t.Cleanup(func() { proxy.Close() })
ctx, cancel := context.WithTimeout(context.Background(), time.Second)

View file

@ -3,7 +3,6 @@ package self_test
import (
"context"
"crypto/rand"
"fmt"
"net"
"sync/atomic"
"testing"
@ -37,7 +36,6 @@ func testStatelessReset(t *testing.T, connIDLen int) {
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
serverPort := ln.Addr().(*net.UDPAddr).Port
serverErr := make(chan error, 1)
go func() {
@ -60,12 +58,12 @@ func testStatelessReset(t *testing.T, connIDLen int) {
}()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DropPacket: func(quicproxy.Direction, []byte) bool {
return drop.Load()
},
})
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() },
}
require.NoError(t, proxy.Start())
require.NoError(t, err)
defer proxy.Close()

View file

@ -111,13 +111,12 @@ func TestIdleTimeout(t *testing.T) {
defer server.Close()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(quicproxy.Direction, []byte) bool {
return drop.Load()
},
})
require.NoError(t, err)
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: server.Addr().(*net.UDPAddr),
DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() },
}
require.NoError(t, proxy.Start())
defer proxy.Close()
conn, err := quic.Dial(
@ -179,11 +178,12 @@ func TestKeepAlive(t *testing.T) {
defer server.Close()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Load() },
})
require.NoError(t, err)
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: server.Addr().(*net.UDPAddr),
DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() },
}
require.NoError(t, proxy.Start())
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@ -320,11 +320,12 @@ func TestTimeoutAfterSendingPacket(t *testing.T) {
defer server.Close()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(d quicproxy.Direction, _ []byte) bool { return d == quicproxy.DirectionOutgoing && drop.Load() },
})
require.NoError(t, err)
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: server.Addr().(*net.UDPAddr),
DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() },
}
require.NoError(t, proxy.Start())
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)

View file

@ -21,19 +21,20 @@ import (
"github.com/stretchr/testify/require"
)
func runCountingProxyAndCount0RTTPackets(t *testing.T, serverPort int, rtt time.Duration) (*quicproxy.QuicProxy, *atomic.Uint32) {
func runCountingProxyAndCount0RTTPackets(t *testing.T, serverPort int, rtt time.Duration) (*quicproxy.Proxy, *atomic.Uint32) {
t.Helper()
var num0RTTPackets atomic.Uint32
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
proxy := &quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: serverPort},
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
if contains0RTTPacket(data) {
num0RTTPackets.Add(1)
}
return rtt / 2
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
t.Cleanup(func() { proxy.Close() })
return proxy, &num0RTTPackets
}
@ -51,11 +52,12 @@ func dialAndReceiveTicket(
require.NoError(t, err)
defer ln.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: ln.Addr().String(),
proxy := &quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientTLSConf = getTLSClientConfig()
@ -361,8 +363,9 @@ func Test0RTTDataLoss(t *testing.T) {
defer ln.Close()
var num0RTTPackets, numDropped atomic.Uint32
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 },
DropPacket: func(_ quicproxy.Direction, data []byte) bool {
if !wire.IsLongHeaderPacket(data[0]) {
@ -380,8 +383,8 @@ func Test0RTTDataLoss(t *testing.T) {
}
return false
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
transfer0RTTData(t, ln, proxy.LocalAddr(), clientConf, nil, PRData)
@ -430,8 +433,9 @@ func Test0RTTRetransmitOnRetry(t *testing.T) {
}
var mutex sync.Mutex
var connIDToCounter []*connIDCounter
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
connID, err := wire.ParseConnectionID(data, 0)
if err != nil {
@ -455,8 +459,8 @@ func Test0RTTRetransmitOnRetry(t *testing.T) {
}
return 2 * time.Millisecond
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
transfer0RTTData(t, ln, proxy.LocalAddr(), clientConf, nil, GeneratePRData(5000)) // ~5 packets
@ -905,8 +909,9 @@ func Test0RTTPacketQueueing(t *testing.T) {
require.NoError(t, err)
defer ln.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: ln.Addr().String(),
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
// delay the client's Initial by 1 RTT
if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 {
@ -914,8 +919,8 @@ func Test0RTTPacketQueueing(t *testing.T) {
}
return rtt / 2
},
})
require.NoError(t, err)
}
require.NoError(t, proxy.Start())
defer proxy.Close()
data := GeneratePRData(5000) // ~5 packets

View file

@ -113,100 +113,52 @@ func (d Direction) Is(dir Direction) bool {
// DropCallback is a callback that determines which packet gets dropped.
type DropCallback func(dir Direction, packet []byte) bool
// NoDropper doesn't drop packets.
var NoDropper DropCallback = func(Direction, []byte) bool {
return false
}
// DelayCallback is a callback that determines how much delay to apply to a packet.
type DelayCallback func(dir Direction, packet []byte) time.Duration
// NoDelay doesn't apply a delay.
var NoDelay DelayCallback = func(Direction, []byte) time.Duration {
return 0
}
// Proxy is a QUIC proxy that can drop and delay packets.
type Proxy struct {
// Conn is the UDP socket that the proxy listens on for incoming packets
// from clients.
Conn *net.UDPConn
// Opts are proxy options.
type Opts struct {
// The address this proxy proxies packets to.
RemoteAddr string
// DropPacket determines whether a packet gets dropped.
// ServerAddr is the address of the server that the proxy forwards packets to.
ServerAddr *net.UDPAddr
// DropPacket is a callback that determines which packet gets dropped.
DropPacket DropCallback
// DelayPacket determines how long a packet gets delayed. This allows
// simulating a connection with non-zero RTTs.
// Note that the RTT is the sum of the delay for the incoming and the outgoing packet.
DelayPacket DelayCallback
}
// QuicProxy is a QUIC proxy that can drop and delay packets.
type QuicProxy struct {
mutex sync.Mutex
// DelayPacket is a callback that determines how much delay to apply to a packet.
DelayPacket DelayCallback
closeChan chan struct{}
logger utils.Logger
conn *net.UDPConn
serverAddr *net.UDPAddr
dropPacket DropCallback
delayPacket DelayCallback
// Mapping from client addresses (as host:port) to connection
// mapping from client addresses (as host:port) to connection
mutex sync.Mutex
clientDict map[string]*connection
logger utils.Logger
}
// NewQuicProxy creates a new UDP proxy
func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
if opts == nil {
opts = &Opts{}
func (p *Proxy) Start() error {
p.clientDict = make(map[string]*connection)
p.closeChan = make(chan struct{})
p.logger = utils.DefaultLogger.WithPrefix("proxy")
if err := p.Conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
return err
}
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", laddr)
if err != nil {
return nil, err
}
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
return nil, err
}
if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil {
return nil, err
}
raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr)
if err != nil {
return nil, err
if err := p.Conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil {
return err
}
packetDropper := NoDropper
if opts.DropPacket != nil {
packetDropper = opts.DropPacket
}
packetDelayer := NoDelay
if opts.DelayPacket != nil {
packetDelayer = opts.DelayPacket
}
p := QuicProxy{
clientDict: make(map[string]*connection),
conn: conn,
closeChan: make(chan struct{}),
serverAddr: raddr,
dropPacket: packetDropper,
delayPacket: packetDelayer,
logger: utils.DefaultLogger.WithPrefix("proxy"),
}
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
p.logger.Debugf("Starting UDP Proxy %s <-> %s", p.Conn.LocalAddr(), p.ServerAddr)
go p.runProxy()
return &p, nil
return nil
}
// Close stops the UDP Proxy
func (p *QuicProxy) Close() error {
func (p *Proxy) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
@ -218,16 +170,14 @@ func (p *QuicProxy) Close() error {
c.Incoming.Close()
c.Outgoing.Close()
}
return p.conn.Close()
return nil
}
// LocalAddr is the address the proxy is listening on.
func (p *QuicProxy) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() }
func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
conn, err := net.DialUDP("udp", nil, p.serverAddr)
func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
conn, err := net.DialUDP("udp", nil, p.ServerAddr)
if err != nil {
return nil, err
}
@ -247,10 +197,10 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
}
// runProxy listens on the proxy address and handles incoming packets.
func (p *QuicProxy) runProxy() error {
func (p *Proxy) runProxy() error {
for {
buffer := make([]byte, protocol.MaxPacketBufferSize)
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
n, cliaddr, err := p.Conn.ReadFromUDP(buffer)
if err != nil {
return err
}
@ -272,14 +222,17 @@ func (p *QuicProxy) runProxy() error {
}
p.mutex.Unlock()
if p.dropPacket(DirectionIncoming, raw) {
if p.DropPacket != nil && p.DropPacket(DirectionIncoming, raw) {
if p.logger.Debug() {
p.logger.Debugf("dropping incoming packet(%d bytes)", n)
}
continue
}
delay := p.delayPacket(DirectionIncoming, raw)
var delay time.Duration
if p.DelayPacket != nil {
delay = p.DelayPacket(DirectionIncoming, raw)
}
if delay == 0 {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
@ -298,7 +251,7 @@ func (p *QuicProxy) runProxy() error {
}
// runConnection handles packets from server to a single client
func (p *QuicProxy) runOutgoingConnection(conn *connection) error {
func (p *Proxy) runOutgoingConnection(conn *connection) error {
outgoingPackets := make(chan packetEntry, 10)
go func() {
for {
@ -309,19 +262,22 @@ func (p *QuicProxy) runOutgoingConnection(conn *connection) error {
}
raw := buffer[0:n]
if p.dropPacket(DirectionOutgoing, raw) {
if p.DropPacket != nil && p.DropPacket(DirectionOutgoing, raw) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet(%d bytes)", n)
}
continue
}
delay := p.delayPacket(DirectionOutgoing, raw)
var delay time.Duration
if p.DelayPacket != nil {
delay = p.DelayPacket(DirectionOutgoing, raw)
}
if delay == 0 {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", len(raw), conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
if _, err := p.Conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return
}
} else {
@ -342,14 +298,14 @@ func (p *QuicProxy) runOutgoingConnection(conn *connection) error {
conn.Outgoing.Add(e)
case <-conn.Outgoing.Timer():
conn.Outgoing.SetTimerRead()
if _, err := p.conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil {
if _, err := p.Conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil {
return err
}
}
}
}
func (p *QuicProxy) runIncomingConnection(conn *connection) error {
func (p *Proxy) runIncomingConnection(conn *connection) error {
for {
select {
case <-p.closeChan:

View file

@ -1,11 +1,8 @@
package quicproxy
import (
"bytes"
"net"
"runtime/pprof"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@ -16,6 +13,14 @@ import (
"github.com/stretchr/testify/require"
)
func newUPDConnLocalhost(t testing.TB) *net.UDPConn {
t.Helper()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
return conn
}
type packetData []byte
func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte {
@ -47,37 +52,6 @@ func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber {
return extHdr.PacketNumber
}
func TestProxyShutdown(t *testing.T) {
isProxyRunning := func() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "proxy.(*QuicProxy).runProxy")
}
proxy, err := NewQuicProxy("localhost:0", nil)
require.NoError(t, err)
require.Eventually(t, func() bool { return isProxyRunning() }, time.Second, 10*time.Millisecond)
conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
_, err = conn.Write(makePacket(t, 1, []byte("foobar")))
require.NoError(t, err)
require.NoError(t, proxy.Close())
// check that the proxy port is not in use anymore
// sometimes it takes a while for the OS to free the port
require.Eventually(t, func() bool {
ln, err := net.ListenUDP("udp", proxy.LocalAddr().(*net.UDPAddr))
if err != nil {
return false
}
ln.Close()
return true
}, time.Second, 10*time.Millisecond)
require.Eventually(t, func() bool { return !isProxyRunning() }, time.Second, 10*time.Millisecond)
}
// Set up a dumb UDP server.
// In production this would be a QUIC server.
func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
@ -116,25 +90,19 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets
}
func startProxy(t *testing.T, opts *Opts) (clientConn *net.UDPConn) {
proxy, err := NewQuicProxy("localhost:0", opts)
require.NoError(t, err)
clientConn, err = net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, proxy.Close())
require.NoError(t, clientConn.Close())
})
return clientConn
}
func TestProxyyingBackAndForth(t *testing.T) {
serverAddr, _ := runServer(t)
clientConn := startProxy(t, &Opts{RemoteAddr: serverAddr.String()})
proxy := Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: serverAddr,
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
// send the first packet
_, err := clientConn.Write(makePacket(t, 1, []byte("foobar")))
_, err = clientConn.Write(makePacket(t, 1, []byte("foobar")))
require.NoError(t, err)
// send the second packet
_, err = clientConn.Write(makePacket(t, 2, []byte("decafbad")))
@ -153,15 +121,20 @@ func TestDropIncomingPackets(t *testing.T) {
const numPackets = 6
serverAddr, serverReceivedPackets := runServer(t)
var counter atomic.Int32
clientConn := startProxy(t, &Opts{
RemoteAddr: serverAddr.String(),
proxy := Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: serverAddr,
DropPacket: func(d Direction, _ []byte) bool {
if d != DirectionIncoming {
return false
}
return counter.Add(1)%2 == 1
},
})
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
for i := 1; i <= numPackets; i++ {
_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
@ -186,15 +159,20 @@ func TestDropOutgoingPackets(t *testing.T) {
const numPackets = 6
serverAddr, serverReceivedPackets := runServer(t)
var counter atomic.Int32
clientConn := startProxy(t, &Opts{
RemoteAddr: serverAddr.String(),
proxy := Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: serverAddr,
DropPacket: func(d Direction, _ []byte) bool {
if d != DirectionOutgoing {
return false
}
return counter.Add(1)%2 == 1
},
})
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
clientReceivedPackets := make(chan struct{}, numPackets)
// receive the packets echoed by the server on client side
@ -234,19 +212,24 @@ func TestDelayIncomingPackets(t *testing.T) {
const delay = 200 * time.Millisecond
serverAddr, serverReceivedPackets := runServer(t)
var counter atomic.Int32
clientConn := startProxy(t, &Opts{
RemoteAddr: serverAddr.String(),
// delay packet 1 by 200 ms
// delay packet 2 by 400 ms
// ...
proxy := Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: serverAddr,
DelayPacket: func(d Direction, _ []byte) time.Duration {
// delay packet 1 by 200 ms
// delay packet 2 by 400 ms
// ...
if d == DirectionOutgoing {
return 0
}
p := counter.Add(1)
return time.Duration(p) * delay
},
})
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
start := time.Now()
for i := 1; i <= numPackets; i++ {
@ -276,19 +259,24 @@ func TestPacketReordering(t *testing.T) {
serverAddr, serverReceivedPackets := runServer(t)
var counter atomic.Int32
clientConn := startProxy(t, &Opts{
RemoteAddr: serverAddr.String(),
// delay packet 1 by 600 ms
// delay packet 2 by 400 ms
// delay packet 3 by 200 ms
proxy := Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: serverAddr,
DelayPacket: func(d Direction, _ []byte) time.Duration {
// delay packet 1 by 600 ms
// delay packet 2 by 400 ms
// delay packet 3 by 200 ms
if d == DirectionOutgoing {
return 0
}
p := counter.Add(1)
return 600*time.Millisecond - time.Duration(p-1)*delay
},
})
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
// send 3 packets
start := time.Now()
@ -310,15 +298,20 @@ func TestPacketReordering(t *testing.T) {
func TestConstantDelay(t *testing.T) { // no reordering expected here
serverAddr, serverReceivedPackets := runServer(t)
clientConn := startProxy(t, &Opts{
RemoteAddr: serverAddr.String(),
proxy := Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: serverAddr,
DelayPacket: func(d Direction, _ []byte) time.Duration {
if d == DirectionOutgoing {
return 0
}
return 100 * time.Millisecond
},
})
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
// send 100 packets
for i := 0; i < 100; i++ {
@ -343,19 +336,24 @@ func TestDelayOutgoingPackets(t *testing.T) {
serverAddr, serverReceivedPackets := runServer(t)
var counter atomic.Int32
clientConn := startProxy(t, &Opts{
RemoteAddr: serverAddr.String(),
// delay packet 1 by 200 ms
// delay packet 2 by 400 ms
// ...
proxy := Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: serverAddr,
DelayPacket: func(d Direction, _ []byte) time.Duration {
// delay packet 1 by 200 ms
// delay packet 2 by 400 ms
// ...
if d == DirectionIncoming {
return 0
}
p := counter.Add(1)
return time.Duration(p) * delay
},
})
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
clientReceivedPackets := make(chan packetData, numPackets)
// receive the packets echoed by the server on client side

View file

@ -2,6 +2,7 @@ package versionnegotiation
import (
"context"
"net"
"testing"
"time"
@ -32,12 +33,17 @@ func TestVersionNegotiationFailure(t *testing.T) {
require.NoError(t, err)
defer ln.Close()
// start the proxy
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: ln.Addr().String(),
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
})
proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer proxyConn.Close()
// start the proxy
proxy := quicproxy.Proxy{
Conn: proxyConn,
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
defer proxy.Close()
startTime := time.Now()
_, err = quic.DialAddr(