mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
proxy: add function to simulate NAT rebinding (#4922)
This commit is contained in:
parent
79bae396b4
commit
3e87ea3f50
3 changed files with 149 additions and 35 deletions
|
@ -26,8 +26,7 @@ func testStatelessReset(t *testing.T, connIDLen int) {
|
|||
var statelessResetKey quic.StatelessResetKey
|
||||
rand.Read(statelessResetKey[:])
|
||||
|
||||
c, err := net.ListenUDP("udp", nil)
|
||||
require.NoError(t, err)
|
||||
c := newUPDConnLocalhost(t)
|
||||
tr := &quic.Transport{
|
||||
Conn: c,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
|
@ -61,10 +60,9 @@ func testStatelessReset(t *testing.T, connIDLen int) {
|
|||
proxy := quicproxy.Proxy{
|
||||
Conn: newUPDConnLocalhost(t),
|
||||
ServerAddr: ln.Addr().(*net.UDPAddr),
|
||||
DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() },
|
||||
DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Load() },
|
||||
}
|
||||
require.NoError(t, proxy.Start())
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
cl := &quic.Transport{
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
package quicproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -13,6 +16,9 @@ import (
|
|||
// Connection is a UDP connection
|
||||
type connection struct {
|
||||
ClientAddr *net.UDPAddr // Address of the client
|
||||
ServerAddr *net.UDPAddr // Address of the server
|
||||
|
||||
mx sync.Mutex
|
||||
ServerConn *net.UDPConn // UDP connection to server
|
||||
|
||||
incomingPackets chan packetEntry
|
||||
|
@ -25,6 +31,22 @@ func (c *connection) queuePacket(t time.Time, b []byte) {
|
|||
c.incomingPackets <- packetEntry{Time: t, Raw: b}
|
||||
}
|
||||
|
||||
func (c *connection) SwitchConn(conn *net.UDPConn) {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
old := c.ServerConn
|
||||
old.SetReadDeadline(time.Now())
|
||||
c.ServerConn = conn
|
||||
}
|
||||
|
||||
func (c *connection) GetServerConn() *net.UDPConn {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
return c.ServerConn
|
||||
}
|
||||
|
||||
// Direction is the direction a packet is sent.
|
||||
type Direction int
|
||||
|
||||
|
@ -118,8 +140,7 @@ type DelayCallback func(dir Direction, packet []byte) time.Duration
|
|||
|
||||
// Proxy is a QUIC proxy that can drop and delay packets.
|
||||
type Proxy struct {
|
||||
// Conn is the UDP socket that the proxy listens on for incoming packets
|
||||
// from clients.
|
||||
// Conn is the UDP socket that the proxy listens on for incoming packets from clients.
|
||||
Conn *net.UDPConn
|
||||
|
||||
// ServerAddr is the address of the server that the proxy forwards packets to.
|
||||
|
@ -139,7 +160,6 @@ type Proxy struct {
|
|||
clientDict map[string]*connection
|
||||
}
|
||||
|
||||
// NewQuicProxy creates a new UDP proxy
|
||||
func (p *Proxy) Start() error {
|
||||
p.clientDict = make(map[string]*connection)
|
||||
p.closeChan = make(chan struct{})
|
||||
|
@ -157,6 +177,25 @@ func (p *Proxy) Start() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SwitchConn switches the connection for a client,
|
||||
// identified the address that the client is sending from.
|
||||
func (p *Proxy) SwitchConn(clientAddr *net.UDPAddr, conn *net.UDPConn) error {
|
||||
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil {
|
||||
return err
|
||||
}
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
c, ok := p.clientDict[clientAddr.String()]
|
||||
if !ok {
|
||||
return fmt.Errorf("client %s not found", clientAddr)
|
||||
}
|
||||
c.SwitchConn(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close stops the UDP Proxy
|
||||
func (p *Proxy) Close() error {
|
||||
p.mutex.Lock()
|
||||
|
@ -164,7 +203,7 @@ func (p *Proxy) Close() error {
|
|||
|
||||
close(p.closeChan)
|
||||
for _, c := range p.clientDict {
|
||||
if err := c.ServerConn.Close(); err != nil {
|
||||
if err := c.GetServerConn().Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Incoming.Close()
|
||||
|
@ -177,7 +216,7 @@ func (p *Proxy) Close() error {
|
|||
func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() }
|
||||
|
||||
func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
||||
conn, err := net.DialUDP("udp", nil, p.ServerAddr)
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -189,10 +228,11 @@ func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
|||
}
|
||||
return &connection{
|
||||
ClientAddr: cliAddr,
|
||||
ServerConn: conn,
|
||||
ServerAddr: p.ServerAddr,
|
||||
incomingPackets: make(chan packetEntry, 10),
|
||||
Incoming: newQueue(),
|
||||
Outgoing: newQueue(),
|
||||
ServerConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -204,11 +244,10 @@ func (p *Proxy) runProxy() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
raw := buffer[:n]
|
||||
|
||||
saddr := cliaddr.String()
|
||||
p.mutex.Lock()
|
||||
conn, ok := p.clientDict[saddr]
|
||||
conn, ok := p.clientDict[cliaddr.String()]
|
||||
|
||||
if !ok {
|
||||
conn, err = p.newConnection(cliaddr)
|
||||
|
@ -216,7 +255,7 @@ func (p *Proxy) runProxy() error {
|
|||
p.mutex.Unlock()
|
||||
return err
|
||||
}
|
||||
p.clientDict[saddr] = conn
|
||||
p.clientDict[cliaddr.String()] = conn
|
||||
go p.runIncomingConnection(conn)
|
||||
go p.runOutgoingConnection(conn)
|
||||
}
|
||||
|
@ -235,15 +274,15 @@ func (p *Proxy) runProxy() error {
|
|||
}
|
||||
if delay == 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
|
||||
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerAddr)
|
||||
}
|
||||
if _, err := conn.ServerConn.Write(raw); err != nil {
|
||||
if _, err := conn.GetServerConn().WriteTo(raw, conn.ServerAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
now := time.Now()
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay)
|
||||
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerAddr, delay)
|
||||
}
|
||||
conn.queuePacket(now.Add(delay), raw)
|
||||
}
|
||||
|
@ -256,8 +295,13 @@ func (p *Proxy) runOutgoingConnection(conn *connection) error {
|
|||
go func() {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxPacketBufferSize)
|
||||
n, err := conn.ServerConn.Read(buffer)
|
||||
n, err := conn.GetServerConn().Read(buffer)
|
||||
if err != nil {
|
||||
// when the connection is switched out, we set a deadline on the old connection,
|
||||
// in order to return it immediately
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
|
@ -315,7 +359,7 @@ func (p *Proxy) runIncomingConnection(conn *connection) error {
|
|||
conn.Incoming.Add(e)
|
||||
case <-conn.Incoming.Timer():
|
||||
conn.Incoming.SetTimerRead()
|
||||
if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil {
|
||||
if _, err := conn.GetServerConn().WriteTo(conn.Incoming.Get(), conn.ServerAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,8 +21,6 @@ func newUPDConnLocalhost(t testing.TB) *net.UDPConn {
|
|||
return conn
|
||||
}
|
||||
|
||||
type packetData []byte
|
||||
|
||||
func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte {
|
||||
t.Helper()
|
||||
hdr := wire.ExtendedHeader{
|
||||
|
@ -54,11 +52,10 @@ func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber {
|
|||
|
||||
// Set up a dumb UDP server.
|
||||
// In production this would be a QUIC server.
|
||||
func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
|
||||
serverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
func runServer(t *testing.T) (*net.UDPAddr, chan []byte) {
|
||||
serverConn := newUPDConnLocalhost(t)
|
||||
|
||||
serverReceivedPackets := make(chan packetData, 100)
|
||||
serverReceivedPackets := make(chan []byte, 100)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
@ -69,17 +66,15 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
data := buf[:n]
|
||||
serverReceivedPackets <- packetData(data)
|
||||
if _, err := serverConn.WriteToUDP(data, addr); err != nil { // echo the packet
|
||||
|
||||
serverReceivedPackets <- buf[:n]
|
||||
// echo the packet
|
||||
if _, err := serverConn.WriteToUDP(buf[:n], addr); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, serverConn.Close())
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
|
@ -90,7 +85,7 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
|
|||
return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets
|
||||
}
|
||||
|
||||
func TestProxyyingBackAndForth(t *testing.T) {
|
||||
func TestProxyingBackAndForth(t *testing.T) {
|
||||
serverAddr, _ := runServer(t)
|
||||
proxy := Proxy{
|
||||
Conn: newUPDConnLocalhost(t),
|
||||
|
@ -179,7 +174,6 @@ func TestDropOutgoingPackets(t *testing.T) {
|
|||
go func() {
|
||||
for {
|
||||
buf := make([]byte, protocol.MaxPacketBufferSize)
|
||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
||||
if _, _, err := clientConn.ReadFromUDP(buf); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -355,17 +349,16 @@ func TestDelayOutgoingPackets(t *testing.T) {
|
|||
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
||||
require.NoError(t, err)
|
||||
|
||||
clientReceivedPackets := make(chan packetData, numPackets)
|
||||
clientReceivedPackets := make(chan []byte, numPackets)
|
||||
// receive the packets echoed by the server on client side
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, protocol.MaxPacketBufferSize)
|
||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
||||
n, _, err := clientConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
clientReceivedPackets <- packetData(buf[0:n])
|
||||
clientReceivedPackets <- buf[:n]
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -394,3 +387,82 @@ func TestDelayOutgoingPackets(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxySwitchConn(t *testing.T) {
|
||||
serverConn := newUPDConnLocalhost(t)
|
||||
|
||||
type packet struct {
|
||||
Data []byte
|
||||
Addr *net.UDPAddr
|
||||
}
|
||||
|
||||
serverReceivedPackets := make(chan packet, 1)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
buf := make([]byte, 1000)
|
||||
n, addr, err := serverConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
serverReceivedPackets <- packet{Data: buf[:n], Addr: addr}
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := Proxy{
|
||||
Conn: newUPDConnLocalhost(t),
|
||||
ServerAddr: serverConn.LocalAddr().(*net.UDPAddr),
|
||||
}
|
||||
require.NoError(t, proxy.Start())
|
||||
defer proxy.Close()
|
||||
|
||||
clientConn := newUPDConnLocalhost(t)
|
||||
_, err := clientConn.WriteToUDP([]byte("hello"), proxy.LocalAddr().(*net.UDPAddr))
|
||||
require.NoError(t, err)
|
||||
clientConn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
|
||||
var firstConnAddr *net.UDPAddr
|
||||
select {
|
||||
case p := <-serverReceivedPackets:
|
||||
require.Equal(t, "hello", string(p.Data))
|
||||
require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
|
||||
firstConnAddr = p.Addr
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout")
|
||||
}
|
||||
|
||||
_, err = serverConn.WriteToUDP([]byte("hi"), firstConnAddr)
|
||||
require.NoError(t, err)
|
||||
buf := make([]byte, 1000)
|
||||
n, addr, err := clientConn.ReadFromUDP(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hi", string(buf[:n]))
|
||||
require.Equal(t, proxy.LocalAddr(), addr)
|
||||
|
||||
newConn := newUPDConnLocalhost(t)
|
||||
require.NoError(t, proxy.SwitchConn(clientConn.LocalAddr().(*net.UDPAddr), newConn))
|
||||
|
||||
_, err = clientConn.WriteToUDP([]byte("foobar"), proxy.LocalAddr().(*net.UDPAddr))
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case p := <-serverReceivedPackets:
|
||||
require.Equal(t, "foobar", string(p.Data))
|
||||
require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
|
||||
require.NotEqual(t, firstConnAddr, p.Addr)
|
||||
require.Equal(t, newConn.LocalAddr(), p.Addr)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout")
|
||||
}
|
||||
|
||||
// the old connection doesn't deliver any packets to the client anymore
|
||||
_, err = serverConn.WriteTo([]byte("invalid"), firstConnAddr)
|
||||
require.NoError(t, err)
|
||||
_, err = serverConn.WriteTo([]byte("foobaz"), newConn.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
n, addr, err = clientConn.ReadFromUDP(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foobaz", string(buf[:n])) // "invalid" is not delivered
|
||||
require.Equal(t, proxy.LocalAddr(), addr)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue