mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
674 lines
20 KiB
Go
674 lines
20 KiB
Go
package self_test
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
|
"github.com/quic-go/quic-go/internal/protocol"
|
|
"github.com/quic-go/quic-go/internal/qerr"
|
|
"github.com/quic-go/quic-go/internal/qtls"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type tokenStore struct {
|
|
store quic.TokenStore
|
|
gets chan<- string
|
|
puts chan<- string
|
|
}
|
|
|
|
var _ quic.TokenStore = &tokenStore{}
|
|
|
|
func newTokenStore(gets, puts chan<- string) quic.TokenStore {
|
|
return &tokenStore{
|
|
store: quic.NewLRUTokenStore(10, 4),
|
|
gets: gets,
|
|
puts: puts,
|
|
}
|
|
}
|
|
|
|
func (c *tokenStore) Put(key string, token *quic.ClientToken) {
|
|
c.puts <- key
|
|
c.store.Put(key, token)
|
|
}
|
|
|
|
func (c *tokenStore) Pop(key string) *quic.ClientToken {
|
|
c.gets <- key
|
|
return c.store.Pop(key)
|
|
}
|
|
|
|
func TestHandshakeAddrResolutionHelpers(t *testing.T) {
|
|
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn, err := quic.DialAddr(
|
|
ctx,
|
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
|
getTLSClientConfig(),
|
|
getQuicConfig(nil),
|
|
)
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
serverConn, err := server.Accept(ctx)
|
|
require.NoError(t, err)
|
|
defer serverConn.CloseWithError(0, "")
|
|
}
|
|
|
|
func TestHandshake(t *testing.T) {
|
|
for _, tt := range []struct {
|
|
name string
|
|
conf *tls.Config
|
|
}{
|
|
{"short cert chain", getTLSConfig()},
|
|
{"long cert chain", getTLSConfigWithLongCertChain()},
|
|
} {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
server, err := quic.Listen(newUPDConnLocalhost(t), tt.conf, getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn, err := quic.Dial(ctx, newUPDConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
serverConn, err := server.Accept(ctx)
|
|
require.NoError(t, err)
|
|
defer serverConn.CloseWithError(0, "")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandshakeServerMismatch(t *testing.T) {
|
|
server, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
conf := getTLSClientConfig()
|
|
conf.ServerName = "foo.bar"
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
_, err = quic.Dial(ctx, newUPDConnLocalhost(t), server.Addr(), conf, getQuicConfig(nil))
|
|
require.Error(t, err)
|
|
var transportErr *quic.TransportError
|
|
require.True(t, errors.As(err, &transportErr))
|
|
require.True(t, transportErr.ErrorCode.IsCryptoError())
|
|
require.Contains(t, transportErr.Error(), "x509: certificate is valid for localhost, not foo.bar")
|
|
var certErr *tls.CertificateVerificationError
|
|
require.True(t, errors.As(transportErr, &certErr))
|
|
}
|
|
|
|
func TestHandshakeCipherSuites(t *testing.T) {
|
|
for _, suiteID := range []uint16{
|
|
tls.TLS_AES_128_GCM_SHA256,
|
|
tls.TLS_AES_256_GCM_SHA384,
|
|
tls.TLS_CHACHA20_POLY1305_SHA256,
|
|
} {
|
|
t.Run(tls.CipherSuiteName(suiteID), func(t *testing.T) {
|
|
reset := qtls.SetCipherSuite(suiteID)
|
|
defer reset()
|
|
|
|
ln, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn, err := quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
serverConn, err := ln.Accept(context.Background())
|
|
require.NoError(t, err)
|
|
defer serverConn.CloseWithError(0, "")
|
|
serverStr, err := serverConn.OpenStream()
|
|
require.NoError(t, err)
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
defer serverStr.Close()
|
|
_, err = serverStr.Write(PRData)
|
|
errChan <- err
|
|
}()
|
|
require.NoError(t, <-errChan)
|
|
|
|
str, err := conn.AcceptStream(context.Background())
|
|
require.NoError(t, err)
|
|
data, err := io.ReadAll(str)
|
|
require.NoError(t, err)
|
|
require.Equal(t, PRData, data)
|
|
require.Equal(t, suiteID, conn.ConnectionState().TLS.CipherSuite)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTLSGetConfigForClientError(t *testing.T) {
|
|
tr := &quic.Transport{Conn: newUPDConnLocalhost(t)}
|
|
addTracer(tr)
|
|
defer tr.Close()
|
|
|
|
tlsConf := &tls.Config{
|
|
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return nil, errors.New("nope")
|
|
},
|
|
}
|
|
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
_, err = quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
var transportErr *quic.TransportError
|
|
require.ErrorAs(t, err, &transportErr)
|
|
require.True(t, transportErr.ErrorCode.IsCryptoError())
|
|
}
|
|
|
|
// Since we're not operating on a net.Conn, we need to jump through some hoops to set the addresses on the tls.ClientHelloInfo.
|
|
// Use a recursive setup to test that this works under all conditions.
|
|
func TestTLSConfigGetConfigForClientAddresses(t *testing.T) {
|
|
var local, remote net.Addr
|
|
var local2, remote2 net.Addr
|
|
done := make(chan struct{})
|
|
tlsConf := &tls.Config{
|
|
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
local = info.Conn.LocalAddr()
|
|
remote = info.Conn.RemoteAddr()
|
|
conf := getTLSConfig()
|
|
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
defer close(done)
|
|
local2 = info.Conn.LocalAddr()
|
|
remote2 = info.Conn.RemoteAddr()
|
|
return &(conf.Certificates[0]), nil
|
|
}
|
|
return conf, nil
|
|
},
|
|
}
|
|
server, err := quic.Listen(newUPDConnLocalhost(t), tlsConf, getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn, err := quic.Dial(ctx, newUPDConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for GetCertificate callback")
|
|
}
|
|
|
|
require.Equal(t, server.Addr(), local)
|
|
require.Equal(t, conn.LocalAddr().(*net.UDPAddr).Port, remote.(*net.UDPAddr).Port)
|
|
require.Equal(t, local, local2)
|
|
require.Equal(t, remote, remote2)
|
|
}
|
|
|
|
func TestHandshakeFailsWithoutClientCert(t *testing.T) {
|
|
tlsConf := getTLSConfig()
|
|
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
|
|
|
|
server, err := quic.Listen(newUPDConnLocalhost(t), tlsConf, getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn, err := quic.Dial(ctx, newUPDConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
|
|
// Usually, the error will occur after the client already finished the handshake.
|
|
// However, there's a race condition here. The server's CONNECTION_CLOSE might be
|
|
// received before the connection is returned, so we might already get the error while dialing.
|
|
if err == nil {
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
_, err := conn.AcceptStream(context.Background())
|
|
errChan <- err
|
|
}()
|
|
|
|
err = <-errChan
|
|
}
|
|
|
|
require.Error(t, err)
|
|
var transportErr *quic.TransportError
|
|
require.ErrorAs(t, err, &transportErr)
|
|
require.True(t, transportErr.ErrorCode.IsCryptoError())
|
|
require.Condition(t, func() bool {
|
|
errStr := transportErr.Error()
|
|
return strings.Contains(errStr, "tls: certificate required") ||
|
|
strings.Contains(errStr, "tls: bad certificate")
|
|
})
|
|
}
|
|
|
|
func TestClosedConnectionsInAcceptQueue(t *testing.T) {
|
|
dialer := &quic.Transport{Conn: newUPDConnLocalhost(t)}
|
|
defer dialer.Close()
|
|
|
|
server, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
// Create first connection
|
|
conn1, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
conn2, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer conn2.CloseWithError(0, "")
|
|
// close the first connection
|
|
const appErrCode quic.ApplicationErrorCode = 12345
|
|
require.NoError(t, conn1.CloseWithError(appErrCode, ""))
|
|
|
|
time.Sleep(scaleDuration(25 * time.Millisecond)) // wait for connections to be queued and closed
|
|
|
|
// accept all connections, and find the closed one
|
|
var closedConn quic.Connection
|
|
for i := 0; i < 2; i++ {
|
|
conn, err := server.Accept(ctx)
|
|
require.NoError(t, err)
|
|
if conn.Context().Err() != nil {
|
|
require.Nil(t, closedConn, "only expected a single closed connection")
|
|
closedConn = conn
|
|
}
|
|
}
|
|
require.NotNil(t, closedConn, "expected one closed connection")
|
|
|
|
_, err = closedConn.AcceptStream(context.Background())
|
|
var appErr *quic.ApplicationError
|
|
require.ErrorAs(t, err, &appErr)
|
|
require.Equal(t, appErrCode, appErr.ErrorCode)
|
|
}
|
|
|
|
func TestServerAcceptQueueOverflow(t *testing.T) {
|
|
server, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
dialer := &quic.Transport{Conn: newUPDConnLocalhost(t)}
|
|
defer dialer.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
// fill up the accept queue
|
|
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
|
conn, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
}
|
|
time.Sleep(scaleDuration(25 * time.Millisecond)) // wait for connections to be queued
|
|
|
|
// next connection should be rejected
|
|
conn, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
_, err = conn.AcceptStream(ctx)
|
|
var transportErr *quic.TransportError
|
|
require.ErrorAs(t, err, &transportErr)
|
|
require.Equal(t, quic.ConnectionRefused, transportErr.ErrorCode)
|
|
|
|
// accept one connection to free up a spot
|
|
_, err = server.Accept(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// should be able to dial again
|
|
conn2, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer conn2.CloseWithError(0, "")
|
|
time.Sleep(scaleDuration(25 * time.Millisecond))
|
|
|
|
// but next connection should be rejected again
|
|
conn3, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
_, err = conn3.AcceptStream(ctx)
|
|
require.ErrorAs(t, err, &transportErr)
|
|
require.Equal(t, quic.ConnectionRefused, transportErr.ErrorCode)
|
|
}
|
|
|
|
func TestHandshakingConnectionsClosedOnServerShutdown(t *testing.T) {
|
|
tr := &quic.Transport{Conn: newUPDConnLocalhost(t)}
|
|
addTracer(tr)
|
|
defer tr.Close()
|
|
|
|
rtt := scaleDuration(40 * time.Millisecond)
|
|
connQueued := make(chan struct{})
|
|
tlsConf := &tls.Config{
|
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
close(connQueued)
|
|
// Sleep for a bit.
|
|
// This allows the server to close the connection before the handshake completes.
|
|
time.Sleep(rtt / 2)
|
|
return getTLSConfig(), nil
|
|
},
|
|
}
|
|
|
|
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
|
|
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
|
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
|
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
|
|
})
|
|
require.NoError(t, err)
|
|
defer proxy.Close()
|
|
|
|
errChan := make(chan error, 1)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
go func() {
|
|
_, err := quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
errChan <- err
|
|
}()
|
|
|
|
select {
|
|
case <-connQueued:
|
|
case <-time.After(5 * rtt):
|
|
t.Fatal("timeout waiting for connection queued")
|
|
}
|
|
require.NoError(t, ln.Close())
|
|
|
|
err = <-errChan
|
|
var transportErr *quic.TransportError
|
|
require.ErrorAs(t, err, &transportErr)
|
|
require.Equal(t, quic.ConnectionRefused, transportErr.ErrorCode)
|
|
}
|
|
|
|
func TestALPN(t *testing.T) {
|
|
ln, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
acceptChan := make(chan quic.Connection, 2)
|
|
go func() {
|
|
for {
|
|
conn, err := ln.Accept(context.Background())
|
|
if err != nil {
|
|
return
|
|
}
|
|
acceptChan <- conn
|
|
}
|
|
}()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn, err := quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), getTLSClientConfig(), nil)
|
|
require.NoError(t, err)
|
|
cs := conn.ConnectionState()
|
|
require.Equal(t, alpn, cs.TLS.NegotiatedProtocol)
|
|
|
|
select {
|
|
case c := <-acceptChan:
|
|
require.Equal(t, alpn, c.ConnectionState().TLS.NegotiatedProtocol)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for server connection")
|
|
}
|
|
require.NoError(t, conn.CloseWithError(0, ""))
|
|
|
|
// now try with a different ALPN
|
|
tlsConf := getTLSClientConfig()
|
|
tlsConf.NextProtos = []string{"foobar"}
|
|
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
_, err = quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), tlsConf, nil)
|
|
require.Error(t, err)
|
|
var transportErr *quic.TransportError
|
|
require.ErrorAs(t, err, &transportErr)
|
|
require.True(t, transportErr.ErrorCode.IsCryptoError())
|
|
require.Contains(t, transportErr.Error(), "no application protocol")
|
|
}
|
|
|
|
func TestTokensFromNewTokenFrames(t *testing.T) {
|
|
t.Run("MaxTokenAge: 1 hour", func(t *testing.T) {
|
|
testTokensFromNewTokenFrames(t, 0, true)
|
|
})
|
|
// If unset, the default value is 24h.
|
|
t.Run("MaxTokenAge: default", func(t *testing.T) {
|
|
testTokensFromNewTokenFrames(t, 0, true)
|
|
})
|
|
t.Run("MaxTokenAge: very short", func(t *testing.T) {
|
|
testTokensFromNewTokenFrames(t, time.Microsecond, false)
|
|
})
|
|
}
|
|
|
|
func testTokensFromNewTokenFrames(t *testing.T, maxTokenAge time.Duration, expectTokenUsed bool) {
|
|
addrVerifiedChan := make(chan bool, 2)
|
|
quicConf := getQuicConfig(nil)
|
|
quicConf.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
|
addrVerifiedChan <- info.AddrVerified
|
|
return quicConf, nil
|
|
}
|
|
tr := &quic.Transport{Conn: newUPDConnLocalhost(t), MaxTokenAge: maxTokenAge}
|
|
addTracer(tr)
|
|
defer tr.Close()
|
|
server, err := tr.Listen(getTLSConfig(), quicConf)
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
// dial the first connection and receive the token
|
|
acceptChan := make(chan error, 2)
|
|
go func() {
|
|
_, err := server.Accept(context.Background())
|
|
acceptChan <- err
|
|
_, err = server.Accept(context.Background())
|
|
acceptChan <- err
|
|
}()
|
|
|
|
gets := make(chan string, 2)
|
|
puts := make(chan string, 2)
|
|
ts := newTokenStore(gets, puts)
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn, err := quic.Dial(ctx, newUPDConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{TokenStore: ts}))
|
|
require.NoError(t, err)
|
|
|
|
// verify token store was used
|
|
select {
|
|
case <-gets:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for token store get")
|
|
}
|
|
select {
|
|
case <-puts:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for token store put")
|
|
}
|
|
select {
|
|
case addrVerified := <-addrVerifiedChan:
|
|
require.False(t, addrVerified)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for addr verified")
|
|
}
|
|
select {
|
|
case <-acceptChan:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for accept")
|
|
}
|
|
// received a token. Close this connection.
|
|
require.NoError(t, conn.CloseWithError(0, ""))
|
|
|
|
time.Sleep(scaleDuration(5 * time.Millisecond))
|
|
conn, err = quic.Dial(ctx, newUPDConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{TokenStore: ts}))
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
select {
|
|
case addrVerified := <-addrVerifiedChan:
|
|
// this time, the address was verified using the token
|
|
if expectTokenUsed {
|
|
require.True(t, addrVerified)
|
|
} else {
|
|
require.False(t, addrVerified)
|
|
}
|
|
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for addr verified")
|
|
}
|
|
select {
|
|
case <-gets:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for token store get")
|
|
}
|
|
select {
|
|
case <-acceptChan:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for accept")
|
|
}
|
|
}
|
|
|
|
func TestInvalidToken(t *testing.T) {
|
|
const rtt = 10 * time.Millisecond
|
|
|
|
// The validity period of the retry token is the handshake timeout,
|
|
// which is twice the handshake idle timeout.
|
|
// By setting the handshake timeout shorter than the RTT, the token will have
|
|
// expired by the time it reaches the server.
|
|
serverConfig := getQuicConfig(&quic.Config{HandshakeIdleTimeout: rtt / 5})
|
|
|
|
tr := &quic.Transport{
|
|
Conn: newUPDConnLocalhost(t),
|
|
VerifySourceAddress: func(net.Addr) bool { return true },
|
|
}
|
|
addTracer(tr)
|
|
defer tr.Close()
|
|
|
|
server, err := tr.Listen(getTLSConfig(), serverConfig)
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
serverPort := server.Addr().(*net.UDPAddr).Port
|
|
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
|
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
|
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
|
|
return rtt / 2
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
defer proxy.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
_, err = quic.Dial(ctx, newUPDConnLocalhost(t), proxy.LocalAddr(), getTLSClientConfig(), nil)
|
|
require.Error(t, err)
|
|
var transportErr *quic.TransportError
|
|
require.ErrorAs(t, err, &transportErr)
|
|
require.Equal(t, quic.InvalidToken, transportErr.ErrorCode)
|
|
}
|
|
|
|
func TestGetConfigForClient(t *testing.T) {
|
|
var calledFrom net.Addr
|
|
serverConfig := getQuicConfig(&quic.Config{EnableDatagrams: true})
|
|
serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
|
conf := serverConfig.Clone()
|
|
conf.EnableDatagrams = true
|
|
calledFrom = info.RemoteAddr
|
|
return getQuicConfig(conf), nil
|
|
}
|
|
ln, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), serverConfig)
|
|
require.NoError(t, err)
|
|
|
|
acceptDone := make(chan struct{})
|
|
go func() {
|
|
_, err := ln.Accept(context.Background())
|
|
require.NoError(t, err)
|
|
close(acceptDone)
|
|
}()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn, err := quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{EnableDatagrams: true}))
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
cs := conn.ConnectionState()
|
|
require.True(t, cs.SupportsDatagrams)
|
|
|
|
select {
|
|
case <-acceptDone:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout waiting for accept")
|
|
}
|
|
|
|
require.NoError(t, ln.Close())
|
|
require.Equal(t, conn.LocalAddr().(*net.UDPAddr).Port, calledFrom.(*net.UDPAddr).Port)
|
|
}
|
|
|
|
func TestGetConfigForClientErrorsConnectionRejection(t *testing.T) {
|
|
ln, err := quic.Listen(
|
|
newUPDConnLocalhost(t),
|
|
getTLSConfig(),
|
|
getQuicConfig(&quic.Config{
|
|
GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
|
return nil, errors.New("rejected")
|
|
},
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
acceptChan := make(chan bool, 1)
|
|
go func() {
|
|
_, err := ln.Accept(context.Background())
|
|
acceptChan <- err == nil
|
|
}()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
_, err = quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
var transportErr *quic.TransportError
|
|
require.ErrorAs(t, err, &transportErr)
|
|
require.Equal(t, qerr.ConnectionRefused, transportErr.ErrorCode)
|
|
|
|
// verify no connection was accepted
|
|
ln.Close()
|
|
require.False(t, <-acceptChan)
|
|
}
|
|
|
|
func TestNoPacketsSentWhenClientHelloFails(t *testing.T) {
|
|
conn := newUPDConnLocalhost(t)
|
|
|
|
packetChan := make(chan struct{}, 1)
|
|
go func() {
|
|
for {
|
|
_, _, err := conn.ReadFromUDP(make([]byte, protocol.MaxPacketBufferSize))
|
|
if err != nil {
|
|
return
|
|
}
|
|
select {
|
|
case packetChan <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
}()
|
|
|
|
tlsConf := getTLSClientConfig()
|
|
tlsConf.NextProtos = []string{""}
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
_, err := quic.Dial(ctx, newUPDConnLocalhost(t), conn.LocalAddr(), tlsConf, getQuicConfig(nil))
|
|
|
|
var transportErr *quic.TransportError
|
|
require.ErrorAs(t, err, &transportErr)
|
|
require.True(t, transportErr.ErrorCode.IsCryptoError())
|
|
require.Contains(t, err.Error(), "tls: invalid NextProtos value")
|
|
|
|
// verify no packets were sent
|
|
select {
|
|
case <-packetChan:
|
|
t.Fatal("received unexpected packet")
|
|
case <-time.After(50 * time.Millisecond):
|
|
// no packets received, as expected
|
|
}
|
|
}
|