mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
proxy: add function to simulate NAT rebinding (#4922)
This commit is contained in:
parent
79bae396b4
commit
3e87ea3f50
3 changed files with 149 additions and 35 deletions
|
@ -26,8 +26,7 @@ func testStatelessReset(t *testing.T, connIDLen int) {
|
||||||
var statelessResetKey quic.StatelessResetKey
|
var statelessResetKey quic.StatelessResetKey
|
||||||
rand.Read(statelessResetKey[:])
|
rand.Read(statelessResetKey[:])
|
||||||
|
|
||||||
c, err := net.ListenUDP("udp", nil)
|
c := newUPDConnLocalhost(t)
|
||||||
require.NoError(t, err)
|
|
||||||
tr := &quic.Transport{
|
tr := &quic.Transport{
|
||||||
Conn: c,
|
Conn: c,
|
||||||
StatelessResetKey: &statelessResetKey,
|
StatelessResetKey: &statelessResetKey,
|
||||||
|
@ -61,10 +60,9 @@ func testStatelessReset(t *testing.T, connIDLen int) {
|
||||||
proxy := quicproxy.Proxy{
|
proxy := quicproxy.Proxy{
|
||||||
Conn: newUPDConnLocalhost(t),
|
Conn: newUPDConnLocalhost(t),
|
||||||
ServerAddr: ln.Addr().(*net.UDPAddr),
|
ServerAddr: ln.Addr().(*net.UDPAddr),
|
||||||
DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() },
|
DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Load() },
|
||||||
}
|
}
|
||||||
require.NoError(t, proxy.Start())
|
require.NoError(t, proxy.Start())
|
||||||
require.NoError(t, err)
|
|
||||||
defer proxy.Close()
|
defer proxy.Close()
|
||||||
|
|
||||||
cl := &quic.Transport{
|
cl := &quic.Transport{
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
package quicproxy
|
package quicproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -13,6 +16,9 @@ import (
|
||||||
// Connection is a UDP connection
|
// Connection is a UDP connection
|
||||||
type connection struct {
|
type connection struct {
|
||||||
ClientAddr *net.UDPAddr // Address of the client
|
ClientAddr *net.UDPAddr // Address of the client
|
||||||
|
ServerAddr *net.UDPAddr // Address of the server
|
||||||
|
|
||||||
|
mx sync.Mutex
|
||||||
ServerConn *net.UDPConn // UDP connection to server
|
ServerConn *net.UDPConn // UDP connection to server
|
||||||
|
|
||||||
incomingPackets chan packetEntry
|
incomingPackets chan packetEntry
|
||||||
|
@ -25,6 +31,22 @@ func (c *connection) queuePacket(t time.Time, b []byte) {
|
||||||
c.incomingPackets <- packetEntry{Time: t, Raw: b}
|
c.incomingPackets <- packetEntry{Time: t, Raw: b}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *connection) SwitchConn(conn *net.UDPConn) {
|
||||||
|
c.mx.Lock()
|
||||||
|
defer c.mx.Unlock()
|
||||||
|
|
||||||
|
old := c.ServerConn
|
||||||
|
old.SetReadDeadline(time.Now())
|
||||||
|
c.ServerConn = conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) GetServerConn() *net.UDPConn {
|
||||||
|
c.mx.Lock()
|
||||||
|
defer c.mx.Unlock()
|
||||||
|
|
||||||
|
return c.ServerConn
|
||||||
|
}
|
||||||
|
|
||||||
// Direction is the direction a packet is sent.
|
// Direction is the direction a packet is sent.
|
||||||
type Direction int
|
type Direction int
|
||||||
|
|
||||||
|
@ -118,8 +140,7 @@ type DelayCallback func(dir Direction, packet []byte) time.Duration
|
||||||
|
|
||||||
// Proxy is a QUIC proxy that can drop and delay packets.
|
// Proxy is a QUIC proxy that can drop and delay packets.
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
// Conn is the UDP socket that the proxy listens on for incoming packets
|
// Conn is the UDP socket that the proxy listens on for incoming packets from clients.
|
||||||
// from clients.
|
|
||||||
Conn *net.UDPConn
|
Conn *net.UDPConn
|
||||||
|
|
||||||
// ServerAddr is the address of the server that the proxy forwards packets to.
|
// ServerAddr is the address of the server that the proxy forwards packets to.
|
||||||
|
@ -139,7 +160,6 @@ type Proxy struct {
|
||||||
clientDict map[string]*connection
|
clientDict map[string]*connection
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQuicProxy creates a new UDP proxy
|
|
||||||
func (p *Proxy) Start() error {
|
func (p *Proxy) Start() error {
|
||||||
p.clientDict = make(map[string]*connection)
|
p.clientDict = make(map[string]*connection)
|
||||||
p.closeChan = make(chan struct{})
|
p.closeChan = make(chan struct{})
|
||||||
|
@ -157,6 +177,25 @@ func (p *Proxy) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SwitchConn switches the connection for a client,
|
||||||
|
// identified the address that the client is sending from.
|
||||||
|
func (p *Proxy) SwitchConn(clientAddr *net.UDPAddr, conn *net.UDPConn) error {
|
||||||
|
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.mutex.Lock()
|
||||||
|
defer p.mutex.Unlock()
|
||||||
|
c, ok := p.clientDict[clientAddr.String()]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("client %s not found", clientAddr)
|
||||||
|
}
|
||||||
|
c.SwitchConn(conn)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Close stops the UDP Proxy
|
// Close stops the UDP Proxy
|
||||||
func (p *Proxy) Close() error {
|
func (p *Proxy) Close() error {
|
||||||
p.mutex.Lock()
|
p.mutex.Lock()
|
||||||
|
@ -164,7 +203,7 @@ func (p *Proxy) Close() error {
|
||||||
|
|
||||||
close(p.closeChan)
|
close(p.closeChan)
|
||||||
for _, c := range p.clientDict {
|
for _, c := range p.clientDict {
|
||||||
if err := c.ServerConn.Close(); err != nil {
|
if err := c.GetServerConn().Close(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.Incoming.Close()
|
c.Incoming.Close()
|
||||||
|
@ -177,7 +216,7 @@ func (p *Proxy) Close() error {
|
||||||
func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() }
|
func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() }
|
||||||
|
|
||||||
func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
||||||
conn, err := net.DialUDP("udp", nil, p.ServerAddr)
|
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -189,10 +228,11 @@ func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
||||||
}
|
}
|
||||||
return &connection{
|
return &connection{
|
||||||
ClientAddr: cliAddr,
|
ClientAddr: cliAddr,
|
||||||
ServerConn: conn,
|
ServerAddr: p.ServerAddr,
|
||||||
incomingPackets: make(chan packetEntry, 10),
|
incomingPackets: make(chan packetEntry, 10),
|
||||||
Incoming: newQueue(),
|
Incoming: newQueue(),
|
||||||
Outgoing: newQueue(),
|
Outgoing: newQueue(),
|
||||||
|
ServerConn: conn,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -204,11 +244,10 @@ func (p *Proxy) runProxy() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
raw := buffer[0:n]
|
raw := buffer[:n]
|
||||||
|
|
||||||
saddr := cliaddr.String()
|
|
||||||
p.mutex.Lock()
|
p.mutex.Lock()
|
||||||
conn, ok := p.clientDict[saddr]
|
conn, ok := p.clientDict[cliaddr.String()]
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
conn, err = p.newConnection(cliaddr)
|
conn, err = p.newConnection(cliaddr)
|
||||||
|
@ -216,7 +255,7 @@ func (p *Proxy) runProxy() error {
|
||||||
p.mutex.Unlock()
|
p.mutex.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
p.clientDict[saddr] = conn
|
p.clientDict[cliaddr.String()] = conn
|
||||||
go p.runIncomingConnection(conn)
|
go p.runIncomingConnection(conn)
|
||||||
go p.runOutgoingConnection(conn)
|
go p.runOutgoingConnection(conn)
|
||||||
}
|
}
|
||||||
|
@ -235,15 +274,15 @@ func (p *Proxy) runProxy() error {
|
||||||
}
|
}
|
||||||
if delay == 0 {
|
if delay == 0 {
|
||||||
if p.logger.Debug() {
|
if p.logger.Debug() {
|
||||||
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
|
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerAddr)
|
||||||
}
|
}
|
||||||
if _, err := conn.ServerConn.Write(raw); err != nil {
|
if _, err := conn.GetServerConn().WriteTo(raw, conn.ServerAddr); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if p.logger.Debug() {
|
if p.logger.Debug() {
|
||||||
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay)
|
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerAddr, delay)
|
||||||
}
|
}
|
||||||
conn.queuePacket(now.Add(delay), raw)
|
conn.queuePacket(now.Add(delay), raw)
|
||||||
}
|
}
|
||||||
|
@ -256,8 +295,13 @@ func (p *Proxy) runOutgoingConnection(conn *connection) error {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
buffer := make([]byte, protocol.MaxPacketBufferSize)
|
buffer := make([]byte, protocol.MaxPacketBufferSize)
|
||||||
n, err := conn.ServerConn.Read(buffer)
|
n, err := conn.GetServerConn().Read(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// when the connection is switched out, we set a deadline on the old connection,
|
||||||
|
// in order to return it immediately
|
||||||
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
raw := buffer[0:n]
|
raw := buffer[0:n]
|
||||||
|
@ -315,7 +359,7 @@ func (p *Proxy) runIncomingConnection(conn *connection) error {
|
||||||
conn.Incoming.Add(e)
|
conn.Incoming.Add(e)
|
||||||
case <-conn.Incoming.Timer():
|
case <-conn.Incoming.Timer():
|
||||||
conn.Incoming.SetTimerRead()
|
conn.Incoming.SetTimerRead()
|
||||||
if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil {
|
if _, err := conn.GetServerConn().WriteTo(conn.Incoming.Get(), conn.ServerAddr); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,8 +21,6 @@ func newUPDConnLocalhost(t testing.TB) *net.UDPConn {
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
type packetData []byte
|
|
||||||
|
|
||||||
func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte {
|
func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
hdr := wire.ExtendedHeader{
|
hdr := wire.ExtendedHeader{
|
||||||
|
@ -54,11 +52,10 @@ func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber {
|
||||||
|
|
||||||
// Set up a dumb UDP server.
|
// Set up a dumb UDP server.
|
||||||
// In production this would be a QUIC server.
|
// In production this would be a QUIC server.
|
||||||
func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
|
func runServer(t *testing.T) (*net.UDPAddr, chan []byte) {
|
||||||
serverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
serverConn := newUPDConnLocalhost(t)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
serverReceivedPackets := make(chan packetData, 100)
|
serverReceivedPackets := make(chan []byte, 100)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
@ -69,17 +66,15 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
data := buf[:n]
|
serverReceivedPackets <- buf[:n]
|
||||||
serverReceivedPackets <- packetData(data)
|
// echo the packet
|
||||||
if _, err := serverConn.WriteToUDP(data, addr); err != nil { // echo the packet
|
if _, err := serverConn.WriteToUDP(buf[:n], addr); err != nil {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, serverConn.Close())
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(time.Second):
|
case <-time.After(time.Second):
|
||||||
|
@ -90,7 +85,7 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) {
|
||||||
return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets
|
return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyyingBackAndForth(t *testing.T) {
|
func TestProxyingBackAndForth(t *testing.T) {
|
||||||
serverAddr, _ := runServer(t)
|
serverAddr, _ := runServer(t)
|
||||||
proxy := Proxy{
|
proxy := Proxy{
|
||||||
Conn: newUPDConnLocalhost(t),
|
Conn: newUPDConnLocalhost(t),
|
||||||
|
@ -179,7 +174,6 @@ func TestDropOutgoingPackets(t *testing.T) {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, protocol.MaxPacketBufferSize)
|
buf := make([]byte, protocol.MaxPacketBufferSize)
|
||||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
|
||||||
if _, _, err := clientConn.ReadFromUDP(buf); err != nil {
|
if _, _, err := clientConn.ReadFromUDP(buf); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -355,17 +349,16 @@ func TestDelayOutgoingPackets(t *testing.T) {
|
||||||
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
clientReceivedPackets := make(chan packetData, numPackets)
|
clientReceivedPackets := make(chan []byte, numPackets)
|
||||||
// receive the packets echoed by the server on client side
|
// receive the packets echoed by the server on client side
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, protocol.MaxPacketBufferSize)
|
buf := make([]byte, protocol.MaxPacketBufferSize)
|
||||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
|
||||||
n, _, err := clientConn.ReadFromUDP(buf)
|
n, _, err := clientConn.ReadFromUDP(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
clientReceivedPackets <- packetData(buf[0:n])
|
clientReceivedPackets <- buf[:n]
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -394,3 +387,82 @@ func TestDelayOutgoingPackets(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxySwitchConn(t *testing.T) {
|
||||||
|
serverConn := newUPDConnLocalhost(t)
|
||||||
|
|
||||||
|
type packet struct {
|
||||||
|
Data []byte
|
||||||
|
Addr *net.UDPAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
serverReceivedPackets := make(chan packet, 1)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
buf := make([]byte, 1000)
|
||||||
|
n, addr, err := serverConn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serverReceivedPackets <- packet{Data: buf[:n], Addr: addr}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := Proxy{
|
||||||
|
Conn: newUPDConnLocalhost(t),
|
||||||
|
ServerAddr: serverConn.LocalAddr().(*net.UDPAddr),
|
||||||
|
}
|
||||||
|
require.NoError(t, proxy.Start())
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
clientConn := newUPDConnLocalhost(t)
|
||||||
|
_, err := clientConn.WriteToUDP([]byte("hello"), proxy.LocalAddr().(*net.UDPAddr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
clientConn.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
|
||||||
|
var firstConnAddr *net.UDPAddr
|
||||||
|
select {
|
||||||
|
case p := <-serverReceivedPackets:
|
||||||
|
require.Equal(t, "hello", string(p.Data))
|
||||||
|
require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
|
||||||
|
firstConnAddr = p.Addr
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = serverConn.WriteToUDP([]byte("hi"), firstConnAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
buf := make([]byte, 1000)
|
||||||
|
n, addr, err := clientConn.ReadFromUDP(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "hi", string(buf[:n]))
|
||||||
|
require.Equal(t, proxy.LocalAddr(), addr)
|
||||||
|
|
||||||
|
newConn := newUPDConnLocalhost(t)
|
||||||
|
require.NoError(t, proxy.SwitchConn(clientConn.LocalAddr().(*net.UDPAddr), newConn))
|
||||||
|
|
||||||
|
_, err = clientConn.WriteToUDP([]byte("foobar"), proxy.LocalAddr().(*net.UDPAddr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p := <-serverReceivedPackets:
|
||||||
|
require.Equal(t, "foobar", string(p.Data))
|
||||||
|
require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
|
||||||
|
require.NotEqual(t, firstConnAddr, p.Addr)
|
||||||
|
require.Equal(t, newConn.LocalAddr(), p.Addr)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// the old connection doesn't deliver any packets to the client anymore
|
||||||
|
_, err = serverConn.WriteTo([]byte("invalid"), firstConnAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = serverConn.WriteTo([]byte("foobaz"), newConn.LocalAddr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
n, addr, err = clientConn.ReadFromUDP(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "foobaz", string(buf[:n])) // "invalid" is not delivered
|
||||||
|
require.Equal(t, proxy.LocalAddr(), addr)
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue