diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 14babc64..ce573063 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -26,8 +26,7 @@ func testStatelessReset(t *testing.T, connIDLen int) { var statelessResetKey quic.StatelessResetKey rand.Read(statelessResetKey[:]) - c, err := net.ListenUDP("udp", nil) - require.NoError(t, err) + c := newUPDConnLocalhost(t) tr := &quic.Transport{ Conn: c, StatelessResetKey: &statelessResetKey, @@ -61,10 +60,9 @@ func testStatelessReset(t *testing.T, connIDLen int) { proxy := quicproxy.Proxy{ Conn: newUPDConnLocalhost(t), ServerAddr: ln.Addr().(*net.UDPAddr), - DropPacket: func(_ quicproxy.Direction, _ []byte) bool { return drop.Load() }, + DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Load() }, } require.NoError(t, proxy.Start()) - require.NoError(t, err) defer proxy.Close() cl := &quic.Transport{ diff --git a/integrationtests/tools/proxy/proxy.go b/integrationtests/tools/proxy/proxy.go index 2ab49085..d18e3276 100644 --- a/integrationtests/tools/proxy/proxy.go +++ b/integrationtests/tools/proxy/proxy.go @@ -1,7 +1,10 @@ package quicproxy import ( + "errors" + "fmt" "net" + "os" "sort" "sync" "time" @@ -13,6 +16,9 @@ import ( // Connection is a UDP connection type connection struct { ClientAddr *net.UDPAddr // Address of the client + ServerAddr *net.UDPAddr // Address of the server + + mx sync.Mutex ServerConn *net.UDPConn // UDP connection to server incomingPackets chan packetEntry @@ -25,6 +31,22 @@ func (c *connection) queuePacket(t time.Time, b []byte) { c.incomingPackets <- packetEntry{Time: t, Raw: b} } +func (c *connection) SwitchConn(conn *net.UDPConn) { + c.mx.Lock() + defer c.mx.Unlock() + + old := c.ServerConn + old.SetReadDeadline(time.Now()) + c.ServerConn = conn +} + +func (c *connection) GetServerConn() *net.UDPConn { + c.mx.Lock() + defer c.mx.Unlock() + + return c.ServerConn +} + // Direction is the direction a packet is sent. type Direction int @@ -118,8 +140,7 @@ type DelayCallback func(dir Direction, packet []byte) time.Duration // Proxy is a QUIC proxy that can drop and delay packets. type Proxy struct { - // Conn is the UDP socket that the proxy listens on for incoming packets - // from clients. + // Conn is the UDP socket that the proxy listens on for incoming packets from clients. Conn *net.UDPConn // ServerAddr is the address of the server that the proxy forwards packets to. @@ -139,7 +160,6 @@ type Proxy struct { clientDict map[string]*connection } -// NewQuicProxy creates a new UDP proxy func (p *Proxy) Start() error { p.clientDict = make(map[string]*connection) p.closeChan = make(chan struct{}) @@ -157,6 +177,25 @@ func (p *Proxy) Start() error { return nil } +// SwitchConn switches the connection for a client, +// identified the address that the client is sending from. +func (p *Proxy) SwitchConn(clientAddr *net.UDPAddr, conn *net.UDPConn) error { + if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { + return err + } + if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil { + return err + } + p.mutex.Lock() + defer p.mutex.Unlock() + c, ok := p.clientDict[clientAddr.String()] + if !ok { + return fmt.Errorf("client %s not found", clientAddr) + } + c.SwitchConn(conn) + return nil +} + // Close stops the UDP Proxy func (p *Proxy) Close() error { p.mutex.Lock() @@ -164,7 +203,7 @@ func (p *Proxy) Close() error { close(p.closeChan) for _, c := range p.clientDict { - if err := c.ServerConn.Close(); err != nil { + if err := c.GetServerConn().Close(); err != nil { return err } c.Incoming.Close() @@ -177,7 +216,7 @@ func (p *Proxy) Close() error { func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() } func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { - conn, err := net.DialUDP("udp", nil, p.ServerAddr) + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) if err != nil { return nil, err } @@ -189,10 +228,11 @@ func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { } return &connection{ ClientAddr: cliAddr, - ServerConn: conn, + ServerAddr: p.ServerAddr, incomingPackets: make(chan packetEntry, 10), Incoming: newQueue(), Outgoing: newQueue(), + ServerConn: conn, }, nil } @@ -204,11 +244,10 @@ func (p *Proxy) runProxy() error { if err != nil { return err } - raw := buffer[0:n] + raw := buffer[:n] - saddr := cliaddr.String() p.mutex.Lock() - conn, ok := p.clientDict[saddr] + conn, ok := p.clientDict[cliaddr.String()] if !ok { conn, err = p.newConnection(cliaddr) @@ -216,7 +255,7 @@ func (p *Proxy) runProxy() error { p.mutex.Unlock() return err } - p.clientDict[saddr] = conn + p.clientDict[cliaddr.String()] = conn go p.runIncomingConnection(conn) go p.runOutgoingConnection(conn) } @@ -235,15 +274,15 @@ func (p *Proxy) runProxy() error { } if delay == 0 { if p.logger.Debug() { - p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr()) + p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerAddr) } - if _, err := conn.ServerConn.Write(raw); err != nil { + if _, err := conn.GetServerConn().WriteTo(raw, conn.ServerAddr); err != nil { return err } } else { now := time.Now() if p.logger.Debug() { - p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay) + p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerAddr, delay) } conn.queuePacket(now.Add(delay), raw) } @@ -256,8 +295,13 @@ func (p *Proxy) runOutgoingConnection(conn *connection) error { go func() { for { buffer := make([]byte, protocol.MaxPacketBufferSize) - n, err := conn.ServerConn.Read(buffer) + n, err := conn.GetServerConn().Read(buffer) if err != nil { + // when the connection is switched out, we set a deadline on the old connection, + // in order to return it immediately + if errors.Is(err, os.ErrDeadlineExceeded) { + continue + } return } raw := buffer[0:n] @@ -315,7 +359,7 @@ func (p *Proxy) runIncomingConnection(conn *connection) error { conn.Incoming.Add(e) case <-conn.Incoming.Timer(): conn.Incoming.SetTimerRead() - if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil { + if _, err := conn.GetServerConn().WriteTo(conn.Incoming.Get(), conn.ServerAddr); err != nil { return err } } diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index 39d0f8cb..ac963221 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -21,8 +21,6 @@ func newUPDConnLocalhost(t testing.TB) *net.UDPConn { return conn } -type packetData []byte - func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte { t.Helper() hdr := wire.ExtendedHeader{ @@ -54,11 +52,10 @@ func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber { // Set up a dumb UDP server. // In production this would be a QUIC server. -func runServer(t *testing.T) (*net.UDPAddr, chan packetData) { - serverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) - require.NoError(t, err) +func runServer(t *testing.T) (*net.UDPAddr, chan []byte) { + serverConn := newUPDConnLocalhost(t) - serverReceivedPackets := make(chan packetData, 100) + serverReceivedPackets := make(chan []byte, 100) done := make(chan struct{}) go func() { defer close(done) @@ -69,17 +66,15 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) { if err != nil { return } - data := buf[:n] - serverReceivedPackets <- packetData(data) - if _, err := serverConn.WriteToUDP(data, addr); err != nil { // echo the packet - + serverReceivedPackets <- buf[:n] + // echo the packet + if _, err := serverConn.WriteToUDP(buf[:n], addr); err != nil { return } } }() t.Cleanup(func() { - require.NoError(t, serverConn.Close()) select { case <-done: case <-time.After(time.Second): @@ -90,7 +85,7 @@ func runServer(t *testing.T) (*net.UDPAddr, chan packetData) { return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets } -func TestProxyyingBackAndForth(t *testing.T) { +func TestProxyingBackAndForth(t *testing.T) { serverAddr, _ := runServer(t) proxy := Proxy{ Conn: newUPDConnLocalhost(t), @@ -179,7 +174,6 @@ func TestDropOutgoingPackets(t *testing.T) { go func() { for { buf := make([]byte, protocol.MaxPacketBufferSize) - // the ReadFromUDP will error as soon as the UDP conn is closed if _, _, err := clientConn.ReadFromUDP(buf); err != nil { return } @@ -355,17 +349,16 @@ func TestDelayOutgoingPackets(t *testing.T) { clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) - clientReceivedPackets := make(chan packetData, numPackets) + clientReceivedPackets := make(chan []byte, numPackets) // receive the packets echoed by the server on client side go func() { for { buf := make([]byte, protocol.MaxPacketBufferSize) - // the ReadFromUDP will error as soon as the UDP conn is closed n, _, err := clientConn.ReadFromUDP(buf) if err != nil { return } - clientReceivedPackets <- packetData(buf[0:n]) + clientReceivedPackets <- buf[:n] } }() @@ -394,3 +387,82 @@ func TestDelayOutgoingPackets(t *testing.T) { } } } + +func TestProxySwitchConn(t *testing.T) { + serverConn := newUPDConnLocalhost(t) + + type packet struct { + Data []byte + Addr *net.UDPAddr + } + + serverReceivedPackets := make(chan packet, 1) + done := make(chan struct{}) + go func() { + defer close(done) + for { + buf := make([]byte, 1000) + n, addr, err := serverConn.ReadFromUDP(buf) + if err != nil { + return + } + serverReceivedPackets <- packet{Data: buf[:n], Addr: addr} + } + }() + + proxy := Proxy{ + Conn: newUPDConnLocalhost(t), + ServerAddr: serverConn.LocalAddr().(*net.UDPAddr), + } + require.NoError(t, proxy.Start()) + defer proxy.Close() + + clientConn := newUPDConnLocalhost(t) + _, err := clientConn.WriteToUDP([]byte("hello"), proxy.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) + clientConn.SetReadDeadline(time.Now().Add(time.Second)) + + var firstConnAddr *net.UDPAddr + select { + case p := <-serverReceivedPackets: + require.Equal(t, "hello", string(p.Data)) + require.NotEqual(t, clientConn.LocalAddr(), p.Addr) + firstConnAddr = p.Addr + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + _, err = serverConn.WriteToUDP([]byte("hi"), firstConnAddr) + require.NoError(t, err) + buf := make([]byte, 1000) + n, addr, err := clientConn.ReadFromUDP(buf) + require.NoError(t, err) + require.Equal(t, "hi", string(buf[:n])) + require.Equal(t, proxy.LocalAddr(), addr) + + newConn := newUPDConnLocalhost(t) + require.NoError(t, proxy.SwitchConn(clientConn.LocalAddr().(*net.UDPAddr), newConn)) + + _, err = clientConn.WriteToUDP([]byte("foobar"), proxy.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) + + select { + case p := <-serverReceivedPackets: + require.Equal(t, "foobar", string(p.Data)) + require.NotEqual(t, clientConn.LocalAddr(), p.Addr) + require.NotEqual(t, firstConnAddr, p.Addr) + require.Equal(t, newConn.LocalAddr(), p.Addr) + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + // the old connection doesn't deliver any packets to the client anymore + _, err = serverConn.WriteTo([]byte("invalid"), firstConnAddr) + require.NoError(t, err) + _, err = serverConn.WriteTo([]byte("foobaz"), newConn.LocalAddr()) + require.NoError(t, err) + n, addr, err = clientConn.ReadFromUDP(buf) + require.NoError(t, err) + require.Equal(t, "foobaz", string(buf[:n])) // "invalid" is not delivered + require.Equal(t, proxy.LocalAddr(), addr) +}