move the version negotiation tests to a separate package

This commit is contained in:
Marten Seemann 2023-04-22 11:19:30 +02:00
parent 172123c340
commit 0dbe595d9f
8 changed files with 406 additions and 244 deletions

View file

@ -27,12 +27,18 @@ jobs:
- run:
name: "Run self integration tests"
command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/self
- run:
name: "Run version negotiation tests"
command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/versionnegotiation
- run:
name: "Run self integration tests with race detector"
command: go run github.com/onsi/ginkgo/v2/ginkgo -race -v -randomize-all -trace integrationtests/self
- run:
name: "Run self integration tests with qlog"
command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/self -- -qlog
- run:
name: "Run version negotiation tests with qlog"
command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/versionnegotiation -- -qlog
go119:
<<: *test
go120:

View file

@ -23,13 +23,15 @@ jobs:
run: echo "QLOGFLAG=-- -qlog" >> $GITHUB_ENV
- name: Run tests
run: |
go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self integrationtests
go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self,versionnegotiation integrationtests
go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/versionnegotiation ${{ env.QLOGFLAG }}
go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self ${{ env.QLOGFLAG }}
- name: Run tests (32 bit)
env:
GOARCH: 386
run: |
go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self integrationtests
go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self,versionnegotiation integrationtests
go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/versionnegotiation ${{ env.QLOGFLAG }}
go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self ${{ env.QLOGFLAG }}
- name: save qlogs
if: ${{ always() && env.DEBUG == 'true' }}

View file

@ -10,20 +10,14 @@ import (
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/integrationtests/tools/israce"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
type versioner interface {
GetVersion() protocol.VersionNumber
}
type tokenStore struct {
store quic.TokenStore
gets chan<- string
@ -50,31 +44,6 @@ func (c *tokenStore) Pop(key string) *quic.ClientToken {
return c.store.Pop(key)
}
type versionNegotiationTracer struct {
logging.NullConnectionTracer
loggedVersions bool
receivedVersionNegotiation bool
chosen logging.VersionNumber
clientVersions, serverVersions []logging.VersionNumber
}
var _ logging.ConnectionTracer = &versionNegotiationTracer{}
func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
if t.loggedVersions {
Fail("only expected one call to NegotiatedVersions")
}
t.loggedVersions = true
t.chosen = chosen
t.clientVersions = clientVersions
t.serverVersions = serverVersions
}
func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
t.receivedVersionNegotiation = true
}
var _ = Describe("Handshake tests", func() {
var (
server quic.Listener
@ -112,79 +81,6 @@ var _ = Describe("Handshake tests", func() {
}()
}
if !israce.Enabled {
Context("Version Negotiation", func() {
var supportedVersions []protocol.VersionNumber
BeforeEach(func() {
supportedVersions = protocol.SupportedVersions
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...)
})
AfterEach(func() {
protocol.SupportedVersions = supportedVersions
})
It("when the server supports more versions than the client", func() {
expectedVersion := protocol.SupportedVersions[0]
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
serverTracer := &versionNegotiationTracer{}
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
runServer(getTLSConfig())
defer server.Close()
clientTracer := &versionNegotiationTracer{}
conn, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
Expect(conn.CloseWithError(0, "")).To(Succeed())
Expect(clientTracer.chosen).To(Equal(expectedVersion))
Expect(clientTracer.receivedVersionNegotiation).To(BeFalse())
Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions))
Expect(clientTracer.serverVersions).To(BeEmpty())
Expect(serverTracer.chosen).To(Equal(expectedVersion))
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
Expect(serverTracer.clientVersions).To(BeEmpty())
})
It("when the client supports more versions than the server supports", func() {
expectedVersion := protocol.SupportedVersions[0]
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = supportedVersions
serverTracer := &versionNegotiationTracer{}
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
runServer(getTLSConfig())
defer server.Close()
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
clientTracer := &versionNegotiationTracer{}
conn, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
Versions: clientVersions,
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0]))
Expect(conn.CloseWithError(0, "")).To(Succeed())
Expect(clientTracer.chosen).To(Equal(expectedVersion))
Expect(clientTracer.receivedVersionNegotiation).To(BeTrue())
Expect(clientTracer.clientVersions).To(Equal(clientVersions))
Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions
Expect(serverTracer.chosen).To(Equal(expectedVersion))
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
Expect(serverTracer.clientVersions).To(BeEmpty())
})
})
}
Context("using different cipher suites", func() {
for n, id := range map[string]uint16{
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,

View file

@ -1,20 +1,12 @@
package self_test
import (
"bufio"
"bytes"
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"flag"
"fmt"
"io"
"log"
"math/big"
mrand "math/rand"
"os"
"runtime/pprof"
@ -24,19 +16,17 @@ import (
"testing"
"time"
"golang.org/x/crypto/ed25519"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/integrationtests/tools"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const alpn = "quic-go integration tests"
const alpn = tools.ALPN
const (
dataLen = 500 * 1024 // 500 KB
@ -93,12 +83,13 @@ var (
logFileName string // the log file set in the ginkgo flags
logBufOnce sync.Once
logBuf *syncedBuffer
enableQlog bool
qlogTracer logging.Tracer
enableQlog bool
tlsConfig *tls.Config
tlsConfigLongChain *tls.Config
tlsClientConfig *tls.Config
quicConfigTracer logging.Tracer
)
// read the logfile command line flag
@ -107,11 +98,11 @@ func init() {
flag.StringVar(&logFileName, "logfile", "", "log file")
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
ca, caPrivateKey, err := generateCA()
ca, caPrivateKey, err := tools.GenerateCA()
if err != nil {
panic(err)
}
leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey)
leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey)
if err != nil {
panic(err)
}
@ -122,7 +113,7 @@ func init() {
}},
NextProtos: []string{alpn},
}
tlsConfLongChain, err := generateTLSConfigWithLongCertChain(ca, caPrivateKey)
tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey)
if err != nil {
panic(err)
}
@ -140,126 +131,10 @@ var _ = BeforeSuite(func() {
mrand.Seed(GinkgoRandomSeed())
if enableQlog {
quicConfigTracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename)
f, err := os.Create(filename)
Expect(err).ToNot(HaveOccurred())
bw := bufio.NewWriter(f)
return utils.NewBufferedWriteCloser(bw, f)
})
qlogTracer = tools.NewQlogger(GinkgoWriter)
}
})
func generateCA() (*x509.Certificate, crypto.PrivateKey, error) {
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, pub, priv)
if err != nil {
return nil, nil, err
}
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, nil, err
}
return ca, priv, nil
}
func generateLeafCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*x509.Certificate, crypto.PrivateKey, error) {
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(1),
DNSNames: []string{"localhost"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, pub, caPriv)
if err != nil {
return nil, nil, err
}
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return nil, nil, err
}
return cert, priv, nil
}
// getTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain.
// The Root CA used is the same as for the config returned from getTLSConfig().
func generateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey crypto.PrivateKey) (*tls.Config, error) {
const chainLen = 7
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
lastCA := ca
lastCAPrivKey := caPrivateKey
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
certs := make([]*x509.Certificate, chainLen)
for i := 0; i < chainLen; i++ {
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey)
if err != nil {
return nil, err
}
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, err
}
certs[i] = ca
lastCA = ca
lastCAPrivKey = privKey
}
leafCert, leafPrivateKey, err := generateLeafCert(lastCA, lastCAPrivKey)
if err != nil {
return nil, err
}
rawCerts := make([][]byte, chainLen+1)
for i, cert := range certs {
rawCerts[chainLen-i] = cert.Raw
}
rawCerts[0] = leafCert.Raw
return &tls.Config{
Certificates: []tls.Certificate{{
Certificate: rawCerts,
PrivateKey: leafPrivateKey,
}},
NextProtos: []string{alpn},
}, nil
}
func getTLSConfig() *tls.Config {
return tlsConfig.Clone()
}
@ -278,10 +153,12 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
} else {
conf = conf.Clone()
}
if conf.Tracer == nil {
conf.Tracer = quicConfigTracer
} else if quicConfigTracer != nil {
conf.Tracer = logging.NewMultiplexedTracer(quicConfigTracer, conf.Tracer)
if enableQlog {
if conf.Tracer == nil {
conf.Tracer = qlogTracer
} else if qlogTracer != nil {
conf.Tracer = logging.NewMultiplexedTracer(qlogTracer, conf.Tracer)
}
}
return conf
}

View file

@ -0,0 +1,120 @@
package tools
import (
"crypto"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"time"
)
const ALPN = "quic-go integration tests"
func GenerateCA() (*x509.Certificate, crypto.PrivateKey, error) {
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, pub, priv)
if err != nil {
return nil, nil, err
}
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, nil, err
}
return ca, priv, nil
}
func GenerateLeafCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*x509.Certificate, crypto.PrivateKey, error) {
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(1),
DNSNames: []string{"localhost"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, pub, caPriv)
if err != nil {
return nil, nil, err
}
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return nil, nil, err
}
return cert, priv, nil
}
// GenerateTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain.
// The Root CA used is the same as for the config returned from getTLSConfig().
func GenerateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey crypto.PrivateKey) (*tls.Config, error) {
const chainLen = 7
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
lastCA := ca
lastCAPrivKey := caPrivateKey
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
certs := make([]*x509.Certificate, chainLen)
for i := 0; i < chainLen; i++ {
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey)
if err != nil {
return nil, err
}
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, err
}
certs[i] = ca
lastCA = ca
lastCAPrivKey = privKey
}
leafCert, leafPrivateKey, err := GenerateLeafCert(lastCA, lastCAPrivKey)
if err != nil {
return nil, err
}
rawCerts := make([][]byte, chainLen+1)
for i, cert := range certs {
rawCerts[chainLen-i] = cert.Raw
}
rawCerts[0] = leafCert.Raw
return &tls.Config{
Certificates: []tls.Certificate{{
Certificate: rawCerts,
PrivateKey: leafPrivateKey,
}},
NextProtos: []string{ALPN},
}, nil
}

View file

@ -0,0 +1,31 @@
package tools
import (
"bufio"
"fmt"
"io"
"log"
"os"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
)
func NewQlogger(logger io.Writer) logging.Tracer {
return qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
fmt.Fprintf(logger, "Creating %s.\n", filename)
f, err := os.Create(filename)
if err != nil {
log.Fatalf("failed to create qlog file: %s", err)
return nil
}
bw := bufio.NewWriter(f)
return utils.NewBufferedWriteCloser(bw, f)
})
}

View file

@ -0,0 +1,141 @@
package versionnegotiation
import (
"context"
"crypto/tls"
"fmt"
"net"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/integrationtests/tools/israce"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
type versioner interface {
GetVersion() protocol.VersionNumber
}
type versionNegotiationTracer struct {
logging.NullConnectionTracer
loggedVersions bool
receivedVersionNegotiation bool
chosen logging.VersionNumber
clientVersions, serverVersions []logging.VersionNumber
}
var _ logging.ConnectionTracer = &versionNegotiationTracer{}
func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
if t.loggedVersions {
Fail("only expected one call to NegotiatedVersions")
}
t.loggedVersions = true
t.chosen = chosen
t.clientVersions = clientVersions
t.serverVersions = serverVersions
}
func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
t.receivedVersionNegotiation = true
}
var _ = Describe("Handshake tests", func() {
startServer := func(tlsConf *tls.Config, conf *quic.Config) (quic.Listener, func()) {
server, err := quic.ListenAddr("localhost:0", tlsConf, conf)
Expect(err).ToNot(HaveOccurred())
acceptStopped := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
if _, err := server.Accept(context.Background()); err != nil {
return
}
}
}()
return server, func() {
server.Close()
<-acceptStopped
}
}
var supportedVersions []protocol.VersionNumber
BeforeEach(func() {
supportedVersions = protocol.SupportedVersions
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...)
})
AfterEach(func() {
protocol.SupportedVersions = supportedVersions
})
if !israce.Enabled {
It("when the server supports more versions than the client", func() {
expectedVersion := protocol.SupportedVersions[0]
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig := &quic.Config{}
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
serverTracer := &versionNegotiationTracer{}
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
server, cl := startServer(getTLSConfig(), serverConfig)
defer cl()
clientTracer := &versionNegotiationTracer{}
conn, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
maybeAddQlogTracer(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
Expect(conn.CloseWithError(0, "")).To(Succeed())
Expect(clientTracer.chosen).To(Equal(expectedVersion))
Expect(clientTracer.receivedVersionNegotiation).To(BeFalse())
Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions))
Expect(clientTracer.serverVersions).To(BeEmpty())
Expect(serverTracer.chosen).To(Equal(expectedVersion))
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
Expect(serverTracer.clientVersions).To(BeEmpty())
})
It("when the client supports more versions than the server supports", func() {
expectedVersion := protocol.SupportedVersions[0]
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig := &quic.Config{}
serverConfig.Versions = supportedVersions
serverTracer := &versionNegotiationTracer{}
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
server, cl := startServer(getTLSConfig(), serverConfig)
defer cl()
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
clientTracer := &versionNegotiationTracer{}
conn, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
maybeAddQlogTracer(&quic.Config{
Versions: clientVersions,
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0]))
Expect(conn.CloseWithError(0, "")).To(Succeed())
Expect(clientTracer.chosen).To(Equal(expectedVersion))
Expect(clientTracer.receivedVersionNegotiation).To(BeTrue())
Expect(clientTracer.clientVersions).To(Equal(clientVersions))
Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions
Expect(serverTracer.chosen).To(Equal(expectedVersion))
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
Expect(serverTracer.clientVersions).To(BeEmpty())
})
}
})

View file

@ -0,0 +1,89 @@
package versionnegotiation
import (
"context"
"crypto/tls"
"crypto/x509"
"flag"
"testing"
"github.com/quic-go/quic-go/integrationtests/tools"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var (
enableQlog bool
tlsConfig *tls.Config
tlsClientConfig *tls.Config
)
func init() {
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
ca, caPrivateKey, err := tools.GenerateCA()
if err != nil {
panic(err)
}
leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey)
if err != nil {
panic(err)
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{{
Certificate: [][]byte{leafCert.Raw},
PrivateKey: leafPrivateKey,
}},
NextProtos: []string{tools.ALPN},
}
root := x509.NewCertPool()
root.AddCert(ca)
tlsClientConfig = &tls.Config{
RootCAs: root,
NextProtos: []string{tools.ALPN},
}
}
func getTLSConfig() *tls.Config { return tlsConfig }
func getTLSClientConfig() *tls.Config { return tlsClientConfig }
func TestQuicVersionNegotiation(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Version Negotiation Suite")
}
func maybeAddQlogTracer(c *quic.Config) *quic.Config {
if c == nil {
c = &quic.Config{}
}
if !enableQlog {
return c
}
qlogger := tools.NewQlogger(GinkgoWriter)
if c.Tracer == nil {
c.Tracer = qlogger
} else if qlogger != nil {
c.Tracer = logging.NewMultiplexedTracer(qlogger, c.Tracer)
}
return c
}
type tracer struct {
logging.NullTracer
createNewConnTracer func() logging.ConnectionTracer
}
var _ logging.Tracer = &tracer{}
func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
return &tracer{createNewConnTracer: c}
}
func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
return t.createNewConnTracer()
}