uquic/integrationtests/self/handshake_drop_test.go
Marten Seemann 79bae396b4
proxy: rename to Proxy, refactor initialization (#4921)
* proxy: rename to Proxy, refactor initialization

* improve documentation
2025-01-25 11:13:33 +01:00

285 lines
8 KiB
Go

package self_test
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
mrand "math/rand"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/quicvarint"
"github.com/stretchr/testify/require"
)
func startDropTestListenerAndProxy(t *testing.T, rtt, timeout time.Duration, dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (_ *quic.Listener, proxyAddr net.Addr) {
t.Helper()
conf := getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
})
var tlsConf *tls.Config
if longCertChain {
tlsConf = getTLSConfigWithLongCertChain()
} else {
tlsConf = getTLSConfig()
}
tr := &quic.Transport{
Conn: newUPDConnLocalhost(t),
VerifySourceAddress: func(net.Addr) bool { return doRetry },
}
t.Cleanup(func() { tr.Close() })
ln, err := tr.Listen(tlsConf, conf)
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
proxy := quicproxy.Proxy{
Conn: newUPDConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DropPacket: dropCallback,
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
t.Cleanup(func() { proxy.Close() })
return ln, proxy.LocalAddr()
}
func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, addr net.Addr, timeout time.Duration, data []byte) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := quic.Dial(
ctx,
newUPDConnLocalhost(t),
addr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
str, err := conn.OpenUniStream()
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
defer str.Close()
_, err := str.Write(data)
errChan <- err
}()
serverConn, err := ln.Accept(ctx)
require.NoError(t, err)
serverStr, err := serverConn.AcceptUniStream(ctx)
require.NoError(t, err)
b, err := io.ReadAll(&readerWithTimeout{Reader: serverStr, Timeout: timeout})
require.NoError(t, err)
require.Equal(t, b, data)
serverConn.CloseWithError(0, "")
}
func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, addr net.Addr, timeout time.Duration, data []byte) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := quic.Dial(
ctx,
newUPDConnLocalhost(t),
addr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
defer close(errChan)
defer conn.CloseWithError(0, "")
str, err := conn.AcceptUniStream(ctx)
if err != nil {
errChan <- err
return
}
b, err := io.ReadAll(&readerWithTimeout{Reader: str, Timeout: timeout})
if err != nil {
errChan <- err
return
}
if !bytes.Equal(b, data) {
errChan <- fmt.Errorf("data mismatch: %x != %x", b, data)
return
}
}()
serverConn, err := ln.Accept(ctx)
require.NoError(t, err)
serverStr, err := serverConn.OpenUniStream()
require.NoError(t, err)
_, err = serverStr.Write(data)
require.NoError(t, err)
require.NoError(t, serverStr.Close())
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(timeout):
t.Fatal("server connection not closed")
}
select {
case <-conn.Context().Done():
case <-time.After(timeout):
t.Fatal("server connection not closed")
}
}
func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, addr net.Addr, timeout time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := quic.Dial(
ctx,
newUPDConnLocalhost(t),
addr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
serverConn, err := ln.Accept(ctx)
require.NoError(t, err)
serverConn.CloseWithError(0, "")
}
func dropCallbackDropNthPacket(direction quicproxy.Direction, n int) quicproxy.DropCallback {
var incoming, outgoing atomic.Int32
return func(d quicproxy.Direction, packet []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = incoming.Add(1)
case quicproxy.DirectionOutgoing:
p = outgoing.Add(1)
}
return p == int32(n) && d.Is(direction)
}
}
func dropCallbackDropOneThird(direction quicproxy.Direction) quicproxy.DropCallback {
const maxSequentiallyDropped = 10
var mx sync.Mutex
var incoming, outgoing int
return func(d quicproxy.Direction, _ []byte) bool {
drop := mrand.Int63n(int64(3)) == 0
mx.Lock()
defer mx.Unlock()
// never drop more than 10 consecutive packets
if d.Is(quicproxy.DirectionIncoming) {
if drop {
incoming++
if incoming > maxSequentiallyDropped {
drop = false
}
}
if !drop {
incoming = 0
}
}
if d.Is(quicproxy.DirectionOutgoing) {
if drop {
outgoing++
if outgoing > maxSequentiallyDropped {
drop = false
}
}
if !drop {
outgoing = 0
}
}
return drop
}
}
func TestHandshakeWithPacketLoss(t *testing.T) {
data := GeneratePRData(5000)
const timeout = 2 * time.Minute
const rtt = 20 * time.Millisecond
type dropPattern struct {
name string
fn quicproxy.DropCallback
}
type serverConfig struct {
longCertChain bool
doRetry bool
}
for _, direction := range []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth} {
for _, dropPattern := range []dropPattern{
{name: "drop 1st packet", fn: dropCallbackDropNthPacket(direction, 1)},
{name: "drop 2nd packet", fn: dropCallbackDropNthPacket(direction, 2)},
{name: "drop 1/3 of packets", fn: dropCallbackDropOneThird(direction)},
} {
t.Run(fmt.Sprintf("%s in %s direction", dropPattern.name, direction), func(t *testing.T) {
for _, conf := range []serverConfig{
{longCertChain: false, doRetry: true},
{longCertChain: false, doRetry: false},
{longCertChain: true, doRetry: false},
} {
t.Run(fmt.Sprintf("retry: %t", conf.doRetry), func(t *testing.T) {
t.Run("client speaks first", func(t *testing.T) {
ln, proxyAddr := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
dropTestProtocolClientSpeaksFirst(t, ln, proxyAddr, timeout, data)
})
t.Run("server speaks first", func(t *testing.T) {
ln, proxyAddr := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
dropTestProtocolServerSpeaksFirst(t, ln, proxyAddr, timeout, data)
})
t.Run("nobody speaks", func(t *testing.T) {
ln, proxyAddr := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
dropTestProtocolNobodySpeaks(t, ln, proxyAddr, timeout)
})
})
}
})
}
}
}
func TestPostQuantumClientHello(t *testing.T) {
origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
t.Cleanup(func() { wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient })
b := make([]byte, 2500) // the ClientHello will now span across 3 packets
mrand.New(mrand.NewSource(time.Now().UnixNano())).Read(b)
wire.AdditionalTransportParametersClient = map[uint64][]byte{
// Avoid random collisions with the greased transport parameters.
uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
}
ln, proxyPort := startDropTestListenerAndProxy(t, 10*time.Millisecond, 20*time.Second, dropCallbackDropOneThird(quicproxy.DirectionIncoming), false, false)
dropTestProtocolClientSpeaksFirst(t, ln, proxyPort, time.Minute, GeneratePRData(5000))
}