// uquic/integrationtests/self/handshake_test.go
// 2024-03-09 19:32:15 +09:30
//
// 852 lines
// 27 KiB
// Go

package self_test
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// tokenStore is a quic.TokenStore decorator used by the tests: it forwards
// every call to an inner store and first reports the key it was called with,
// so a spec can observe when tokens are looked up (gets) or saved (puts).
type tokenStore struct {
	store      quic.TokenStore
	gets, puts chan<- string
}
// Compile-time check that *tokenStore satisfies quic.TokenStore.
var _ quic.TokenStore = (*tokenStore)(nil)
// newTokenStore wraps a fresh quic.NewLRUTokenStore(10, 4) so that every Pop
// sends its key on gets and every Put sends its key on puts. Callers should
// pass buffered channels (or drain them), since the sends block otherwise.
func newTokenStore(gets, puts chan<- string) quic.TokenStore {
	ts := &tokenStore{gets: gets, puts: puts}
	ts.store = quic.NewLRUTokenStore(10, 4)
	return ts
}
// Put reports key on the puts channel and then saves token in the inner store.
func (c *tokenStore) Put(key string, token *quic.ClientToken) {
	notify := c.puts
	notify <- key
	c.store.Put(key, token)
}
// Pop reports key on the gets channel and then returns (and removes) the token
// the inner store holds for it, if any.
func (c *tokenStore) Pop(key string) (token *quic.ClientToken) {
	c.gets <- key
	token = c.store.Pop(key)
	return token
}
var _ = Describe("Handshake tests", func() {
// Listener started by runServer; nil when the current spec did not call it.
var server *quic.Listener

// Server-side quic.Config; specs may tweak it before starting a listener.
var serverConfig *quic.Config

// Closed by runServer's accept loop once it exits.
var acceptStopped chan struct{}
// Reset shared state before each spec: a default server config, a fresh
// accept-loop signal channel, and no running server.
BeforeEach(func() {
	serverConfig = getQuicConfig(nil)
	acceptStopped = make(chan struct{})
	server = nil
})
// If the spec started a server via runServer, close it and wait until the
// background accept loop has returned.
AfterEach(func() {
	if server == nil {
		return
	}
	server.Close()
	<-acceptStopped
})
runServer := func(tlsConf *tls.Config) {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
if _, err := server.Accept(context.Background()); err != nil {
return
}
}
}()
}
It("returns the context cancellation error on timeouts", func() {
	ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond))
	defer cancel()

	// Nothing listens on this port, so the dial can only end through the context deadline.
	dialErrs := make(chan error, 1)
	go func() {
		_, dialErr := quic.DialAddr(ctx, "localhost:1234", getTLSClientConfig(), getQuicConfig(nil))
		dialErrs <- dialErr
	}()

	var err error
	Eventually(dialErrs).Should(Receive(&err))
	Expect(err).To(HaveOccurred())
	Expect(err).To(MatchError(context.DeadlineExceeded))
})
It("returns the cancellation reason when a dial is canceled", func() {
	ctx, cancel := context.WithCancelCause(context.Background())

	// Nothing listens on this port; the dial is ended by cancelling ctx with a cause.
	dialErrs := make(chan error, 1)
	go func() {
		_, dialErr := quic.DialAddr(ctx, "localhost:1234", getTLSClientConfig(), getQuicConfig(nil))
		dialErrs <- dialErr
	}()
	cancel(errors.New("application cancelled"))

	var err error
	Eventually(dialErrs).Should(Receive(&err))
	Expect(err).To(HaveOccurred())
	// The cancellation cause (not just context.Canceled) must be surfaced.
	Expect(err).To(MatchError("application cancelled"))
})
Context("using different cipher suites", func() {
	// An ordered slice rather than a map: Go randomizes map iteration order, so
	// ranging over a map would register these specs in a different order on each
	// run and on each parallel Ginkgo process. Ginkgo expects every process to
	// build the same spec tree.
	cipherSuites := []struct {
		name string
		id   uint16
	}{
		{"TLS_AES_128_GCM_SHA256", tls.TLS_AES_128_GCM_SHA256},
		{"TLS_AES_256_GCM_SHA384", tls.TLS_AES_256_GCM_SHA384},
		{"TLS_CHACHA20_POLY1305_SHA256", tls.TLS_CHACHA20_POLY1305_SHA256},
	}
	for _, cs := range cipherSuites {
		// Per-iteration copies captured by the It closure (needed before Go 1.22).
		name := cs.name
		suiteID := cs.id
		It(fmt.Sprintf("using %s", name), func() {
			// Force the given cipher suite for this spec; reset restores the default.
			reset := qtls.SetCipherSuite(suiteID)
			defer reset()

			tlsConf := getTLSConfig()
			ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig)
			Expect(err).ToNot(HaveOccurred())
			defer ln.Close()

			// Server side: accept one connection and send PRData on a new stream.
			go func() {
				defer GinkgoRecover()
				conn, err := ln.Accept(context.Background())
				Expect(err).ToNot(HaveOccurred())
				str, err := conn.OpenStream()
				Expect(err).ToNot(HaveOccurred())
				defer str.Close()
				_, err = str.Write(PRData)
				Expect(err).ToNot(HaveOccurred())
			}()

			conn, err := quic.DialAddr(
				context.Background(),
				fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
				getTLSClientConfig(),
				getQuicConfig(nil),
			)
			Expect(err).ToNot(HaveOccurred())
			str, err := conn.AcceptStream(context.Background())
			Expect(err).ToNot(HaveOccurred())
			data, err := io.ReadAll(str)
			Expect(err).ToNot(HaveOccurred())
			Expect(data).To(Equal(PRData))
			Expect(conn.ConnectionState().TLS.CipherSuite).To(Equal(suiteID))
			Expect(conn.CloseWithError(0, "")).To(Succeed())
		})
	}
})
Context("Certificate validation", func() {
	It("accepts the certificate", func() {
		runServer(getTLSConfig())
		conn, err := quic.DialAddr(
			context.Background(),
			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
			getTLSClientConfig(),
			getQuicConfig(nil),
		)
		Expect(err).ToNot(HaveOccurred())
		conn.CloseWithError(0, "")
	})

	It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
		var local, remote net.Addr
		var local2, remote2 net.Addr
		done := make(chan struct{})
		tlsConf := &tls.Config{
			GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
				local = info.Conn.LocalAddr()
				remote = info.Conn.RemoteAddr()
				conf := getTLSConfig()
				conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
					// Closing done after the writes below publishes local/remote and
					// local2/remote2 to the spec goroutine.
					defer close(done)
					local2 = info.Conn.LocalAddr()
					remote2 = info.Conn.RemoteAddr()
					return &(conf.Certificates[0]), nil
				}
				return conf, nil
			},
		}
		runServer(tlsConf)
		conn, err := quic.DialAddr(
			context.Background(),
			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
			getTLSClientConfig(),
			getQuicConfig(nil),
		)
		Expect(err).ToNot(HaveOccurred())
		defer conn.CloseWithError(0, "")
		Eventually(done).Should(BeClosed())
		Expect(server.Addr()).To(Equal(local))
		Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
		Expect(local).To(Equal(local2))
		Expect(remote).To(Equal(remote2))
	})

	It("works with a long certificate chain", func() {
		runServer(getTLSConfigWithLongCertChain())
		conn, err := quic.DialAddr(
			context.Background(),
			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
			getTLSClientConfig(),
			getQuicConfig(nil),
		)
		Expect(err).ToNot(HaveOccurred())
		conn.CloseWithError(0, "")
	})

	It("errors if the server name doesn't match", func() {
		runServer(getTLSConfig())
		udpConn, err := net.ListenUDP("udp", nil)
		Expect(err).ToNot(HaveOccurred())
		// This socket is created here and handed to quic.Dial, so the spec must
		// close it itself; previously it leaked.
		defer udpConn.Close()
		conf := getTLSClientConfig()
		conf.ServerName = "foo.bar"
		_, err = quic.Dial(
			context.Background(),
			udpConn,
			server.Addr(),
			conf,
			getQuicConfig(nil),
		)
		Expect(err).To(HaveOccurred())
		var transportErr *quic.TransportError
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
		Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar"))
		var certErr *tls.CertificateVerificationError
		Expect(errors.As(transportErr, &certErr)).To(BeTrue())
	})

	It("fails the handshake if the client fails to provide the requested client cert", func() {
		tlsConf := getTLSConfig()
		tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
		runServer(tlsConf)
		conn, err := quic.DialAddr(
			context.Background(),
			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
			getTLSClientConfig(),
			getQuicConfig(nil),
		)
		// Usually, the error will occur after the client already finished the handshake.
		// However, there's a race condition here. The server's CONNECTION_CLOSE might be
		// received before the connection is returned, so we might already get the error while dialing.
		if err == nil {
			// Buffered so the goroutine can always deliver its result and exit,
			// even if Eventually gives up first.
			errChan := make(chan error, 1)
			go func() {
				defer GinkgoRecover()
				_, err := conn.AcceptStream(context.Background())
				errChan <- err
			}()
			Eventually(errChan).Should(Receive(&err))
		}
		Expect(err).To(HaveOccurred())
		var transportErr *quic.TransportError
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
		Expect(transportErr.Error()).To(Or(
			ContainSubstring("tls: certificate required"),
			ContainSubstring("tls: bad certificate"),
		))
	})

	It("uses the ServerName in the tls.Config", func() {
		runServer(getTLSConfig())
		tlsConf := getTLSClientConfig()
		tlsConf.ServerName = "foo.bar"
		_, err := quic.DialAddr(
			context.Background(),
			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
			tlsConf,
			getQuicConfig(nil),
		)
		Expect(err).To(HaveOccurred())
		var transportErr *quic.TransportError
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
		Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar"))
	})
})
Context("queuening and accepting connections", func() {
	var (
		// Shadows the Describe-level server, so the outer AfterEach (which waits
		// for runServer's accept loop) stays a no-op for these specs.
		server *quic.Listener
		pconn  net.PacketConn
		dialer *quic.Transport
	)

	// dial connects to server through the shared dialer transport, so every
	// connection originates from the same local UDP socket.
	dial := func() (quic.Connection, error) {
		remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
		raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
		Expect(err).ToNot(HaveOccurred())
		return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil))
	}

	BeforeEach(func() {
		var err error
		// start the server, but don't call Accept
		server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
		Expect(err).ToNot(HaveOccurred())
		// prepare a (single) packet conn for dialing to the server
		laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
		Expect(err).ToNot(HaveOccurred())
		pconn, err = net.ListenUDP("udp", laddr)
		Expect(err).ToNot(HaveOccurred())
		dialer = &quic.Transport{
			Conn:               pconn,
			ConnectionIDLength: 4,
		}
	})

	AfterEach(func() {
		Expect(server.Close()).To(Succeed())
		Expect(pconn.Close()).To(Succeed())
		Expect(dialer.Close()).To(Succeed())
	})

	It("rejects new connection attempts if connections don't get accepted", func() {
		for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
			conn, err := dial()
			Expect(err).ToNot(HaveOccurred())
			defer conn.CloseWithError(0, "")
		}
		// Scaled like the sibling specs, so slow CI machines get more time.
		time.Sleep(scaleDuration(25 * time.Millisecond)) // wait a bit for the connection to be queued

		conn, err := dial()
		Expect(err).ToNot(HaveOccurred())
		ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
		defer cancel()
		_, err = conn.AcceptStream(ctx)
		var transportErr *quic.TransportError
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))

		// now accept one connection, freeing one spot in the queue
		_, err = server.Accept(context.Background())
		Expect(err).ToNot(HaveOccurred())
		// dial again, and expect that this dial succeeds
		conn2, err := dial()
		Expect(err).ToNot(HaveOccurred())
		defer conn2.CloseWithError(0, "")
		time.Sleep(scaleDuration(25 * time.Millisecond)) // wait a bit for the connection to be queued

		conn3, err := dial()
		Expect(err).ToNot(HaveOccurred())
		ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
		defer cancel()
		_, err = conn3.AcceptStream(ctx)
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
	})

	It("also returns closed connections from the accept queue", func() {
		firstConn, err := dial()
		Expect(err).ToNot(HaveOccurred())
		for i := 1; i < protocol.MaxAcceptQueueSize; i++ {
			conn, err := dial()
			Expect(err).ToNot(HaveOccurred())
			defer conn.CloseWithError(0, "")
		}
		time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
		conn, err := dial()
		Expect(err).ToNot(HaveOccurred())
		ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
		defer cancel()
		_, err = conn.AcceptStream(ctx)
		var transportErr *quic.TransportError
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))

		// Now close one of the connections that are waiting to be accepted.
		const appErrCode quic.ApplicationErrorCode = 12345
		// Without a matcher, Expect(...) asserts nothing; Succeed() checks the returned error.
		Expect(firstConn.CloseWithError(appErrCode, "")).To(Succeed())
		Eventually(firstConn.Context().Done()).Should(BeClosed())
		time.Sleep(scaleDuration(200 * time.Millisecond))

		// dial again, and expect that this fails again
		conn2, err := dial()
		Expect(err).ToNot(HaveOccurred())
		ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
		defer cancel()
		_, err = conn2.AcceptStream(ctx)
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))

		// now accept all connections
		var closedConn quic.Connection
		for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
			conn, err := server.Accept(context.Background())
			Expect(err).ToNot(HaveOccurred())
			if conn.Context().Err() != nil {
				if closedConn != nil {
					Fail("only expected a single closed connection")
				}
				closedConn = conn
			}
		}
		Expect(closedConn).ToNot(BeNil()) // there should be exactly one closed connection
		_, err = closedConn.AcceptStream(context.Background())
		var appErr *quic.ApplicationError
		Expect(errors.As(err, &appErr)).To(BeTrue())
		Expect(appErr.ErrorCode).To(Equal(appErrCode))
	})

	It("closes handshaking connections when the server is closed", func() {
		laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
		Expect(err).ToNot(HaveOccurred())
		udpConn, err := net.ListenUDP("udp", laddr)
		Expect(err).ToNot(HaveOccurred())
		// Created here, so close it here. Deferred before tr.Close, so it runs after it.
		defer udpConn.Close()
		tr := &quic.Transport{Conn: udpConn}
		addTracer(tr)
		defer tr.Close()

		// Block the handshake inside GetConfigForClient until done is closed.
		tlsConf := &tls.Config{}
		done := make(chan struct{})
		tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
			<-done
			return nil, errors.New("closed")
		}
		ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
		Expect(err).ToNot(HaveOccurred())

		errChan := make(chan error, 1)
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		go func() {
			defer GinkgoRecover()
			_, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
			errChan <- err
		}()
		time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
		Expect(ln.Close()).To(Succeed())
		close(done)
		err = <-errChan
		var transportErr *quic.TransportError
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
	})
})
Context("limiting handshakes", func() {
	var conn *net.UDPConn

	BeforeEach(func() {
		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
		Expect(err).ToNot(HaveOccurred())
		conn, err = net.ListenUDP("udp", addr)
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() { conn.Close() })

	It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
		const limit = 3
		tr := &quic.Transport{
			Conn:                     conn,
			MaxUnvalidatedHandshakes: limit,
		}
		addTracer(tr)
		defer tr.Close()

		// Block all handshakes.
		handshakes := make(chan struct{})
		var tlsConf tls.Config
		tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
			handshakes <- struct{}{}
			return getTLSConfig(), nil
		}
		ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
		Expect(err).ToNot(HaveOccurred())
		defer ln.Close()

		const additional = 2
		results := make([]struct{ retry, closed atomic.Bool }, limit+additional)
		// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
		// Since we're dialing limit+additional times, we expect limit handshakes to go through without
		// a Retry, and exactly additional of them to experience a Retry.
		for i := 0; i < limit+additional; i++ {
			go func(index int) {
				defer GinkgoRecover()
				quicConf := getQuicConfig(&quic.Config{
					Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
						return &logging.ConnectionTracer{
							ReceivedRetry:    func(*logging.Header) { results[index].retry.Store(true) },
							ClosedConnection: func(error) { results[index].closed.Store(true) },
						}
					},
				})
				conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
				Expect(err).ToNot(HaveOccurred())
				conn.CloseWithError(0, "")
			}(i)
		}

		// Both counters walk all of results (limit+additional entries). Indexing
		// avoids copying the atomics.
		numRetries := func() (n int) {
			for i := range results {
				if results[i].retry.Load() {
					n++
				}
			}
			return n
		}
		numClosed := func() (n int) {
			for i := range results {
				if results[i].closed.Load() {
					n++
				}
			}
			return n
		}

		Eventually(numRetries).Should(Equal(additional))
		// allow the handshakes to complete
		for i := 0; i < limit+additional; i++ {
			Eventually(handshakes).Should(Receive())
		}
		Eventually(numClosed).Should(Equal(limit + additional))
		Expect(numRetries()).To(Equal(additional)) // just to be on the safe side
	})

	It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
		const limit = 3
		tr := &quic.Transport{
			Conn:          conn,
			MaxHandshakes: limit,
		}
		addTracer(tr)
		defer tr.Close()

		// Block all handshakes.
		handshakes := make(chan struct{})
		var tlsConf tls.Config
		tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
			handshakes <- struct{}{}
			return getTLSConfig(), nil
		}
		ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
		Expect(err).ToNot(HaveOccurred())
		defer ln.Close()

		const additional = 2
		// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
		// Since we're dialing limit+additional times, we expect limit handshakes to be admitted and
		// exactly additional dials to be rejected with CONNECTION_REFUSED. No dial should see a Retry.
		var numSuccessful, numFailed atomic.Int32
		// The tracer is invoked by quic-go itself, presumably on an internal goroutine
		// without GinkgoRecover, where calling Fail would crash the suite. Record the
		// event instead and assert on it from the spec goroutine below.
		var receivedRetry atomic.Bool
		for i := 0; i < limit+additional; i++ {
			go func() {
				defer GinkgoRecover()
				quicConf := getQuicConfig(&quic.Config{
					Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
						return &logging.ConnectionTracer{
							ReceivedRetry: func(*logging.Header) { receivedRetry.Store(true) },
						}
					},
				})
				conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
				if err != nil {
					var transportErr *quic.TransportError
					if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused {
						Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err))
					}
					numFailed.Add(1)
					return
				}
				numSuccessful.Add(1)
				conn.CloseWithError(0, "")
			}()
		}

		Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional))
		// allow the handshakes to complete
		for i := 0; i < limit; i++ {
			Eventually(handshakes).Should(Receive())
		}
		Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit))
		Expect(receivedRetry.Load()).To(BeFalse(), "didn't expect any Retry")

		// make sure that the server is reachable again after these handshakes have completed
		go func() { <-handshakes }() // allow this handshake to complete immediately
		conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
		Expect(err).ToNot(HaveOccurred())
		conn.CloseWithError(0, "")
	})
})
Context("ALPN", func() {
	// Both endpoints must agree on the test's alpn protocol.
	It("negotiates an application protocol", func() {
		ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
		Expect(err).ToNot(HaveOccurred())

		serverChecked := make(chan struct{})
		go func() {
			defer GinkgoRecover()
			serverConn, acceptErr := ln.Accept(context.Background())
			Expect(acceptErr).ToNot(HaveOccurred())
			Expect(serverConn.ConnectionState().TLS.NegotiatedProtocol).To(Equal(alpn))
			close(serverChecked)
		}()

		port := ln.Addr().(*net.UDPAddr).Port
		conn, err := quic.DialAddr(context.Background(), fmt.Sprintf("localhost:%d", port), getTLSClientConfig(), nil)
		Expect(err).ToNot(HaveOccurred())
		defer conn.CloseWithError(0, "")
		Expect(conn.ConnectionState().TLS.NegotiatedProtocol).To(Equal(alpn))
		Eventually(serverChecked).Should(BeClosed())
		Expect(ln.Close()).To(Succeed())
	})

	// A client offering only an unknown protocol must fail with a crypto error.
	It("errors if application protocol negotiation fails", func() {
		runServer(getTLSConfig())
		clientConf := getTLSClientConfig()
		clientConf.NextProtos = []string{"foobar"}

		port := server.Addr().(*net.UDPAddr).Port
		_, err := quic.DialAddr(context.Background(), fmt.Sprintf("localhost:%d", port), clientConf, nil)
		Expect(err).To(HaveOccurred())

		var transportErr *quic.TransportError
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
		Expect(transportErr.Error()).To(ContainSubstring("no application protocol"))
	})
})
Context("using tokens", func() {
	It("uses tokens provided in NEW_TOKEN frames", func() {
		// A local listener (not the shared server), so the outer AfterEach stays a no-op.
		ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
		Expect(err).ToNot(HaveOccurred())
		defer ln.Close()
		serverAddr := fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port)

		// First connection: obtain a token from the server.
		go func() {
			defer GinkgoRecover()
			_, acceptErr := ln.Accept(context.Background())
			Expect(acceptErr).ToNot(HaveOccurred())
		}()
		popped := make(chan string, 100)
		stored := make(chan string, 100)
		quicConf := getQuicConfig(&quic.Config{TokenStore: newTokenStore(popped, stored)})
		conn, err := quic.DialAddr(context.Background(), serverAddr, getTLSClientConfig(), quicConf)
		Expect(err).ToNot(HaveOccurred())
		Expect(popped).To(Receive())
		Eventually(stored).Should(Receive())
		// received a token. Close this connection.
		Expect(conn.CloseWithError(0, "")).To(Succeed())

		// Second connection: the token store must be consulted again.
		accepted := make(chan struct{})
		go func() {
			defer GinkgoRecover()
			defer close(accepted)
			_, acceptErr := ln.Accept(context.Background())
			Expect(acceptErr).ToNot(HaveOccurred())
		}()
		conn, err = quic.DialAddr(context.Background(), serverAddr, getTLSClientConfig(), quicConf)
		Expect(err).ToNot(HaveOccurred())
		defer conn.CloseWithError(0, "")
		Expect(popped).To(Receive())
		Eventually(accepted).Should(BeClosed())
	})

	It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
		const rtt = 10 * time.Millisecond
		// The validity period of the retry token is the handshake timeout,
		// which is twice the handshake idle timeout.
		// By setting the handshake timeout shorter than the RTT, the token will have expired by the time
		// it reaches the server.
		serverConfig.HandshakeIdleTimeout = rtt / 5

		laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
		Expect(err).ToNot(HaveOccurred())
		pconn, err := net.ListenUDP("udp", laddr)
		Expect(err).ToNot(HaveOccurred())
		defer pconn.Close()
		// MaxUnvalidatedHandshakes: -1 presumably makes the server demand address
		// validation (a Retry) for every attempt — see the quic.Transport docs.
		tr := &quic.Transport{Conn: pconn, MaxUnvalidatedHandshakes: -1}
		addTracer(tr)
		defer tr.Close()
		ln, err := tr.Listen(getTLSConfig(), serverConfig)
		Expect(err).ToNot(HaveOccurred())
		defer ln.Close()

		// Delay every packet by half the RTT, in both directions.
		proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
			RemoteAddr:  fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
			DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
		})
		Expect(err).ToNot(HaveOccurred())
		defer proxy.Close()

		_, err = quic.DialAddr(
			context.Background(),
			fmt.Sprintf("localhost:%d", proxy.LocalPort()),
			getTLSClientConfig(),
			nil,
		)
		Expect(err).To(HaveOccurred())
		var transportErr *quic.TransportError
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken))
	})
})
Context("GetConfigForClient", func() {
	It("uses the quic.Config returned by GetConfigForClient", func() {
		// Datagrams are disabled in the base config and only enabled by the
		// per-client config, so datagram support proves the callback's config was used.
		serverConfig.EnableDatagrams = false
		var calledFrom net.Addr
		serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
			conf := serverConfig.Clone()
			conf.EnableDatagrams = true
			calledFrom = info.RemoteAddr
			return getQuicConfig(conf), nil
		}
		ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
		Expect(err).ToNot(HaveOccurred())
		done := make(chan struct{})
		go func() {
			defer GinkgoRecover()
			_, err := ln.Accept(context.Background())
			Expect(err).ToNot(HaveOccurred())
			close(done)
		}()
		conn, err := quic.DialAddr(
			context.Background(),
			fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
			getTLSClientConfig(),
			getQuicConfig(&quic.Config{EnableDatagrams: true}),
		)
		Expect(err).ToNot(HaveOccurred())
		defer conn.CloseWithError(0, "")
		cs := conn.ConnectionState()
		Expect(cs.SupportsDatagrams).To(BeTrue())
		Eventually(done).Should(BeClosed())
		Expect(ln.Close()).To(Succeed())
		Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port))
	})

	It("rejects the connection attempt if GetConfigForClient errors", func() {
		serverConfig.EnableDatagrams = false
		serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
			return nil, errors.New("rejected")
		}
		ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
		Expect(err).ToNot(HaveOccurred())
		defer ln.Close()
		done := make(chan struct{})
		go func() {
			defer GinkgoRecover()
			_, err := ln.Accept(context.Background())
			Expect(err).To(HaveOccurred()) // we don't expect to accept any connection
			close(done)
		}()
		_, err = quic.DialAddr(
			context.Background(),
			fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
			getTLSClientConfig(),
			getQuicConfig(&quic.Config{EnableDatagrams: true}),
		)
		Expect(err).To(HaveOccurred())
		var transportErr *quic.TransportError
		Expect(errors.As(err, &transportErr)).To(BeTrue())
		Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused))

		// Closing the listener unblocks Accept. Waiting on done makes the accept
		// goroutine's assertion run while this spec is still active, instead of
		// after it has ended. The deferred Close stays as a safety net for early failures.
		Expect(ln.Close()).To(Succeed())
		Eventually(done).Should(BeClosed())
	})
})
It("doesn't send any packets when generating the ClientHello fails", func() {
	ln, err := net.ListenUDP("udp", nil)
	Expect(err).ToNot(HaveOccurred())
	// Safety net so the socket is released even if an assertion below fails;
	// the explicit Close at the end is what stops the reader goroutine.
	defer ln.Close()
	done := make(chan struct{})
	// Buffered, and written with a non-blocking send, so the reader goroutine can
	// never get stuck on a send and always observes ln being closed.
	packetChan := make(chan struct{}, 1)
	go func() {
		defer GinkgoRecover()
		defer close(done)
		for {
			_, _, err := ln.ReadFromUDP(make([]byte, protocol.MaxPacketBufferSize))
			if err != nil {
				return
			}
			select {
			case packetChan <- struct{}{}:
			default:
			}
		}
	}()
	// An empty ALPN value is invalid, so building the ClientHello fails.
	tlsConf := getTLSClientConfig()
	tlsConf.NextProtos = []string{""}
	_, err = quic.DialAddr(
		context.Background(),
		fmt.Sprintf("localhost:%d", ln.LocalAddr().(*net.UDPAddr).Port),
		tlsConf,
		nil,
	)
	Expect(err).To(MatchError(&qerr.TransportError{
		ErrorCode:    qerr.InternalError,
		ErrorMessage: "tls: invalid NextProtos value",
	}))
	Consistently(packetChan).ShouldNot(Receive())
	Expect(ln.Close()).To(Succeed())
	Eventually(done).Should(BeClosed())
})
})