proxy: add function to simulate NAT rebinding (#4922)

This commit is contained in:
Marten Seemann 2025-01-26 05:03:08 +01:00 committed by GitHub
parent 79bae396b4
commit 3e87ea3f50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 149 additions and 35 deletions

View file

@ -26,8 +26,7 @@ func testStatelessReset(t *testing.T, connIDLen int) {
var statelessResetKey quic.StatelessResetKey
rand.Read(statelessResetKey[:])
c, err := net.ListenUDP("udp", nil)
require.NoError(t, err)
c := newUPDConnLocalhost(t)
tr := &quic.Transport{
Conn: c,
StatelessResetKey: &statelessResetKey,
@ -61,10 +60,9 @@ func testStatelessReset(t *testing.T, connIDLen int) {
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() },
DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Load() },
}
require.NoError(t, proxy.Start())
require.NoError(t, err)
defer proxy.Close()
cl := &quic.Transport{

View file

@ -1,7 +1,10 @@
package quicproxy
import (
"errors"
"fmt"
"net"
"os"
"sort"
"sync"
"time"
@ -13,6 +16,9 @@ import (
// Connection is a UDP connection
type connection struct {
ClientAddr *net.UDPAddr // Address of the client
ServerAddr *net.UDPAddr // Address of the server
mx sync.Mutex
ServerConn *net.UDPConn // UDP connection to server
incomingPackets chan packetEntry
@ -25,6 +31,22 @@ func (c *connection) queuePacket(t time.Time, b []byte) {
c.incomingPackets <- packetEntry{Time: t, Raw: b}
}
func (c *connection) SwitchConn(conn *net.UDPConn) {
c.mx.Lock()
defer c.mx.Unlock()
old := c.ServerConn
old.SetReadDeadline(time.Now())
c.ServerConn = conn
}
func (c *connection) GetServerConn() *net.UDPConn {
c.mx.Lock()
defer c.mx.Unlock()
return c.ServerConn
}
// Direction is the direction a packet is sent.
type Direction int
@ -118,8 +140,7 @@ type DelayCallback func(dir Direction, packet []byte) time.Duration
// Proxy is a QUIC proxy that can drop and delay packets.
type Proxy struct {
// Conn is the UDP socket that the proxy listens on for incoming packets
// from clients.
// Conn is the UDP socket that the proxy listens on for incoming packets from clients.
Conn *net.UDPConn
// ServerAddr is the address of the server that the proxy forwards packets to.
@ -139,7 +160,6 @@ type Proxy struct {
clientDict map[string]*connection
}
// NewQuicProxy creates a new UDP proxy
func (p *Proxy) Start() error {
p.clientDict = make(map[string]*connection)
p.closeChan = make(chan struct{})
@ -157,6 +177,25 @@ func (p *Proxy) Start() error {
return nil
}
// SwitchConn switches the connection for a client,
// identified the address that the client is sending from.
func (p *Proxy) SwitchConn(clientAddr *net.UDPAddr, conn *net.UDPConn) error {
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
return err
}
if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil {
return err
}
p.mutex.Lock()
defer p.mutex.Unlock()
c, ok := p.clientDict[clientAddr.String()]
if !ok {
return fmt.Errorf("client %s not found", clientAddr)
}
c.SwitchConn(conn)
return nil
}
// Close stops the UDP Proxy
func (p *Proxy) Close() error {
p.mutex.Lock()
@ -164,7 +203,7 @@ func (p *Proxy) Close() error {
close(p.closeChan)
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
if err := c.GetServerConn().Close(); err != nil {
return err
}
c.Incoming.Close()
@ -177,7 +216,7 @@ func (p *Proxy) Close() error {
func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() }
func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
conn, err := net.DialUDP("udp", nil, p.ServerAddr)
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
if err != nil {
return nil, err
}
@ -189,10 +228,11 @@ func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
}
return &connection{
ClientAddr: cliAddr,
ServerConn: conn,
ServerAddr: p.ServerAddr,
incomingPackets: make(chan packetEntry, 10),
Incoming: newQueue(),
Outgoing: newQueue(),
ServerConn: conn,
}, nil
}
@ -204,11 +244,10 @@ func (p *Proxy) runProxy() error {
if err != nil {
return err
}
raw := buffer[0:n]
raw := buffer[:n]
saddr := cliaddr.String()
p.mutex.Lock()
conn, ok := p.clientDict[saddr]
conn, ok := p.clientDict[cliaddr.String()]
if !ok {
conn, err = p.newConnection(cliaddr)
@ -216,7 +255,7 @@ func (p *Proxy) runProxy() error {
p.mutex.Unlock()
return err
}
p.clientDict[saddr] = conn
p.clientDict[cliaddr.String()] = conn
go p.runIncomingConnection(conn)
go p.runOutgoingConnection(conn)
}
@ -235,15 +274,15 @@ func (p *Proxy) runProxy() error {
}
if delay == 0 {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerAddr)
}
if _, err := conn.ServerConn.Write(raw); err != nil {
if _, err := conn.GetServerConn().WriteTo(raw, conn.ServerAddr); err != nil {
return err
}
} else {
now := time.Now()
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay)
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerAddr, delay)
}
conn.queuePacket(now.Add(delay), raw)
}
@ -256,8 +295,13 @@ func (p *Proxy) runOutgoingConnection(conn *connection) error {
go func() {
for {
buffer := make([]byte, protocol.MaxPacketBufferSize)
n, err := conn.ServerConn.Read(buffer)
n, err := conn.GetServerConn().Read(buffer)
if err != nil {
// when the connection is switched out, we set a deadline on the old connection,
// in order to return it immediately
if errors.Is(err, os.ErrDeadlineExceeded) {
continue
}
return
}
raw := buffer[0:n]
@ -315,7 +359,7 @@ func (p *Proxy) runIncomingConnection(conn *connection) error {
conn.Incoming.Add(e)
case <-conn.Incoming.Timer():
conn.Incoming.SetTimerRead()
if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil {
if _, err := conn.GetServerConn().WriteTo(conn.Incoming.Get(), conn.ServerAddr); err != nil {
return err
}
}

View file

@ -21,8 +21,6 @@ func newUPDConnLocalhost(t testing.TB) *net.UDPConn {
return conn
}
type packetData []byte
func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte {
t.Helper()
hdr := wire.ExtendedHeader{
@ -54,11 +52,10 @@ func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber {
// Set up a dumb UDP server.
// In production this would be a QUIC server.
func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
serverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
func runServer(t *testing.T) (*net.UDPAddr, chan []byte) {
serverConn := newUPDConnLocalhost(t)
serverReceivedPackets := make(chan packetData, 100)
serverReceivedPackets := make(chan []byte, 100)
done := make(chan struct{})
go func() {
defer close(done)
@ -69,17 +66,15 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
if err != nil {
return
}
data := buf[:n]
serverReceivedPackets <- packetData(data)
if _, err := serverConn.WriteToUDP(data, addr); err != nil { // echo the packet
serverReceivedPackets <- buf[:n]
// echo the packet
if _, err := serverConn.WriteToUDP(buf[:n], addr); err != nil {
return
}
}
}()
t.Cleanup(func() {
require.NoError(t, serverConn.Close())
select {
case <-done:
case <-time.After(time.Second):
@ -90,7 +85,7 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets
}
func TestProxyyingBackAndForth(t *testing.T) {
func TestProxyingBackAndForth(t *testing.T) {
serverAddr, _ := runServer(t)
proxy := Proxy{
Conn: newUPDConnLocalhost(t),
@ -179,7 +174,6 @@ func TestDropOutgoingPackets(t *testing.T) {
go func() {
for {
buf := make([]byte, protocol.MaxPacketBufferSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
if _, _, err := clientConn.ReadFromUDP(buf); err != nil {
return
}
@ -355,17 +349,16 @@ func TestDelayOutgoingPackets(t *testing.T) {
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
clientReceivedPackets := make(chan packetData, numPackets)
clientReceivedPackets := make(chan []byte, numPackets)
// receive the packets echoed by the server on client side
go func() {
for {
buf := make([]byte, protocol.MaxPacketBufferSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
n, _, err := clientConn.ReadFromUDP(buf)
if err != nil {
return
}
clientReceivedPackets <- packetData(buf[0:n])
clientReceivedPackets <- buf[:n]
}
}()
@ -394,3 +387,82 @@ func TestDelayOutgoingPackets(t *testing.T) {
}
}
}
func TestProxySwitchConn(t *testing.T) {
serverConn := newUPDConnLocalhost(t)
type packet struct {
Data []byte
Addr *net.UDPAddr
}
serverReceivedPackets := make(chan packet, 1)
done := make(chan struct{})
go func() {
defer close(done)
for {
buf := make([]byte, 1000)
n, addr, err := serverConn.ReadFromUDP(buf)
if err != nil {
return
}
serverReceivedPackets <- packet{Data: buf[:n], Addr: addr}
}
}()
proxy := Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: serverConn.LocalAddr().(*net.UDPAddr),
}
require.NoError(t, proxy.Start())
defer proxy.Close()
clientConn := newUPDConnLocalhost(t)
_, err := clientConn.WriteToUDP([]byte("hello"), proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
clientConn.SetReadDeadline(time.Now().Add(time.Second))
var firstConnAddr *net.UDPAddr
select {
case p := <-serverReceivedPackets:
require.Equal(t, "hello", string(p.Data))
require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
firstConnAddr = p.Addr
case <-time.After(time.Second):
t.Fatalf("timeout")
}
_, err = serverConn.WriteToUDP([]byte("hi"), firstConnAddr)
require.NoError(t, err)
buf := make([]byte, 1000)
n, addr, err := clientConn.ReadFromUDP(buf)
require.NoError(t, err)
require.Equal(t, "hi", string(buf[:n]))
require.Equal(t, proxy.LocalAddr(), addr)
newConn := newUPDConnLocalhost(t)
require.NoError(t, proxy.SwitchConn(clientConn.LocalAddr().(*net.UDPAddr), newConn))
_, err = clientConn.WriteToUDP([]byte("foobar"), proxy.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
select {
case p := <-serverReceivedPackets:
require.Equal(t, "foobar", string(p.Data))
require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
require.NotEqual(t, firstConnAddr, p.Addr)
require.Equal(t, newConn.LocalAddr(), p.Addr)
case <-time.After(time.Second):
t.Fatalf("timeout")
}
// the old connection doesn't deliver any packets to the client anymore
_, err = serverConn.WriteTo([]byte("invalid"), firstConnAddr)
require.NoError(t, err)
_, err = serverConn.WriteTo([]byte("foobaz"), newConn.LocalAddr())
require.NoError(t, err)
n, addr, err = clientConn.ReadFromUDP(buf)
require.NoError(t, err)
require.Equal(t, "foobaz", string(buf[:n])) // "invalid" is not delivered
require.Equal(t, proxy.LocalAddr(), addr)
}