From 0dbe595d9fcb88dc68114a7b271bacc4597f2c13 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 22 Apr 2023 11:19:30 +0200 Subject: [PATCH] move the version negotiation tests to a separate package --- .circleci/config.yml | 6 + .github/workflows/integration.yml | 6 +- integrationtests/self/handshake_test.go | 104 ------------ integrationtests/self/self_suite_test.go | 153 ++---------------- integrationtests/tools/crypto.go | 120 ++++++++++++++ integrationtests/tools/qlog.go | 31 ++++ .../versionnegotiation/handshake_test.go | 141 ++++++++++++++++ .../versionnegotiation_suite_test.go | 89 ++++++++++ 8 files changed, 406 insertions(+), 244 deletions(-) create mode 100644 integrationtests/tools/crypto.go create mode 100644 integrationtests/tools/qlog.go create mode 100644 integrationtests/versionnegotiation/handshake_test.go create mode 100644 integrationtests/versionnegotiation/versionnegotiation_suite_test.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 3cb867f3..2ea17cc7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -27,12 +27,18 @@ jobs: - run: name: "Run self integration tests" command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/self + - run: + name: "Run version negotiation tests" + command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/versionnegotiation - run: name: "Run self integration tests with race detector" command: go run github.com/onsi/ginkgo/v2/ginkgo -race -v -randomize-all -trace integrationtests/self - run: name: "Run self integration tests with qlog" command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/self -- -qlog + - run: + name: "Run version negotiation tests with qlog" + command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/versionnegotiation -- -qlog go119: <<: *test go120: diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 07b7dc99..d13a5c7c 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -23,13 +23,15 @@ jobs: run: echo "QLOGFLAG=-- -qlog" >> $GITHUB_ENV - name: Run tests run: | - go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self integrationtests + go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self,versionnegotiation integrationtests + go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/versionnegotiation ${{ env.QLOGFLAG }} go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self ${{ env.QLOGFLAG }} - name: Run tests (32 bit) env: GOARCH: 386 run: | - go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self integrationtests + go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self,versionnegotiation integrationtests + go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/versionnegotiation ${{ env.QLOGFLAG }} go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self ${{ env.QLOGFLAG }} - name: save qlogs if: ${{ always() && env.DEBUG == 'true' }} diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 401fee99..7f046c4d 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -10,20 +10,14 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/integrationtests/tools/israce" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qtls" - "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -type versioner interface { - GetVersion() protocol.VersionNumber -} - type tokenStore struct { store quic.TokenStore gets chan<- string @@ -50,31 +44,6 @@ func (c *tokenStore) Pop(key string) *quic.ClientToken { return c.store.Pop(key) } -type versionNegotiationTracer struct { - logging.NullConnectionTracer - - loggedVersions bool - receivedVersionNegotiation bool - chosen logging.VersionNumber - clientVersions, serverVersions []logging.VersionNumber -} - -var _ logging.ConnectionTracer = &versionNegotiationTracer{} - -func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { - if t.loggedVersions { - Fail("only expected one call to NegotiatedVersions") - } - t.loggedVersions = true - t.chosen = chosen - t.clientVersions = clientVersions - t.serverVersions = serverVersions -} - -func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { - t.receivedVersionNegotiation = true -} - var _ = Describe("Handshake tests", func() { var ( server quic.Listener @@ -112,79 +81,6 @@ var _ = Describe("Handshake tests", func() { }() } - if !israce.Enabled { - Context("Version Negotiation", func() { - var supportedVersions []protocol.VersionNumber - - BeforeEach(func() { - supportedVersions = protocol.SupportedVersions - protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...) - }) - - AfterEach(func() { - protocol.SupportedVersions = supportedVersions - }) - - It("when the server supports more versions than the client", func() { - expectedVersion := protocol.SupportedVersions[0] - // the server doesn't support the highest supported version, which is the first one the client will try - // but it supports a bunch of versions that the client doesn't speak - serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} - serverTracer := &versionNegotiationTracer{} - serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) - runServer(getTLSConfig()) - defer server.Close() - clientTracer := &versionNegotiationTracer{} - conn, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - getTLSClientConfig(), - getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - Expect(clientTracer.chosen).To(Equal(expectedVersion)) - Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) - Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions)) - Expect(clientTracer.serverVersions).To(BeEmpty()) - Expect(serverTracer.chosen).To(Equal(expectedVersion)) - Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) - Expect(serverTracer.clientVersions).To(BeEmpty()) - }) - - It("when the client supports more versions than the server supports", func() { - expectedVersion := protocol.SupportedVersions[0] - // the server doesn't support the highest supported version, which is the first one the client will try - // but it supports a bunch of versions that the client doesn't speak - serverConfig.Versions = supportedVersions - serverTracer := &versionNegotiationTracer{} - serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) - runServer(getTLSConfig()) - defer server.Close() - clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} - clientTracer := &versionNegotiationTracer{} - conn, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - getTLSClientConfig(), - getQuicConfig(&quic.Config{ - Versions: clientVersions, - Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - Expect(clientTracer.chosen).To(Equal(expectedVersion)) - Expect(clientTracer.receivedVersionNegotiation).To(BeTrue()) - Expect(clientTracer.clientVersions).To(Equal(clientVersions)) - Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions - Expect(serverTracer.chosen).To(Equal(expectedVersion)) - Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) - Expect(serverTracer.clientVersions).To(BeEmpty()) - }) - }) - } - Context("using different cipher suites", func() { for n, id := range map[string]uint16{ "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 943b8753..83d1b9e0 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -1,20 +1,12 @@ package self_test import ( - "bufio" "bytes" "context" - "crypto" - "crypto/rand" - "crypto/rsa" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" "flag" - "fmt" - "io" "log" - "math/big" mrand "math/rand" "os" "runtime/pprof" @@ -24,19 +16,17 @@ import ( "testing" "time" - "golang.org/x/crypto/ed25519" - "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/integrationtests/tools" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" - "github.com/quic-go/quic-go/qlog" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -const alpn = "quic-go integration tests" +const alpn = tools.ALPN const ( dataLen = 500 * 1024 // 500 KB @@ -93,12 +83,13 @@ var ( logFileName string // the log file set in the ginkgo flags logBufOnce sync.Once logBuf *syncedBuffer - enableQlog bool + + qlogTracer logging.Tracer + enableQlog bool tlsConfig *tls.Config tlsConfigLongChain *tls.Config tlsClientConfig *tls.Config - quicConfigTracer logging.Tracer ) // read the logfile command line flag @@ -107,11 +98,11 @@ func init() { flag.StringVar(&logFileName, "logfile", "", "log file") flag.BoolVar(&enableQlog, "qlog", false, "enable qlog") - ca, caPrivateKey, err := generateCA() + ca, caPrivateKey, err := tools.GenerateCA() if err != nil { panic(err) } - leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey) + leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey) if err != nil { panic(err) } @@ -122,7 +113,7 @@ func init() { }}, NextProtos: []string{alpn}, } - tlsConfLongChain, err := generateTLSConfigWithLongCertChain(ca, caPrivateKey) + tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey) if err != nil { panic(err) } @@ -140,126 +131,10 @@ var _ = BeforeSuite(func() { mrand.Seed(GinkgoRandomSeed()) if enableQlog { - quicConfigTracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser { - role := "server" - if p == logging.PerspectiveClient { - role = "client" - } - filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role) - fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename) - f, err := os.Create(filename) - Expect(err).ToNot(HaveOccurred()) - bw := bufio.NewWriter(f) - return utils.NewBufferedWriteCloser(bw, f) - }) + qlogTracer = tools.NewQlogger(GinkgoWriter) } }) -func generateCA() (*x509.Certificate, crypto.PrivateKey, error) { - certTempl := &x509.Certificate{ - SerialNumber: big.NewInt(2019), - Subject: pkix.Name{}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, pub, priv) - if err != nil { - return nil, nil, err - } - ca, err := x509.ParseCertificate(caBytes) - if err != nil { - return nil, nil, err - } - return ca, priv, nil -} - -func generateLeafCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*x509.Certificate, crypto.PrivateKey, error) { - certTempl := &x509.Certificate{ - SerialNumber: big.NewInt(1), - DNSNames: []string{"localhost"}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - } - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, pub, caPriv) - if err != nil { - return nil, nil, err - } - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - return nil, nil, err - } - return cert, priv, nil -} - -// getTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain. -// The Root CA used is the same as for the config returned from getTLSConfig(). -func generateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey crypto.PrivateKey) (*tls.Config, error) { - const chainLen = 7 - certTempl := &x509.Certificate{ - SerialNumber: big.NewInt(2019), - Subject: pkix.Name{}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - - lastCA := ca - lastCAPrivKey := caPrivateKey - privKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - certs := make([]*x509.Certificate, chainLen) - for i := 0; i < chainLen; i++ { - caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey) - if err != nil { - return nil, err - } - ca, err := x509.ParseCertificate(caBytes) - if err != nil { - return nil, err - } - certs[i] = ca - lastCA = ca - lastCAPrivKey = privKey - } - leafCert, leafPrivateKey, err := generateLeafCert(lastCA, lastCAPrivKey) - if err != nil { - return nil, err - } - - rawCerts := make([][]byte, chainLen+1) - for i, cert := range certs { - rawCerts[chainLen-i] = cert.Raw - } - rawCerts[0] = leafCert.Raw - - return &tls.Config{ - Certificates: []tls.Certificate{{ - Certificate: rawCerts, - PrivateKey: leafPrivateKey, - }}, - NextProtos: []string{alpn}, - }, nil -} - func getTLSConfig() *tls.Config { return tlsConfig.Clone() } @@ -278,10 +153,12 @@ func getQuicConfig(conf *quic.Config) *quic.Config { } else { conf = conf.Clone() } - if conf.Tracer == nil { - conf.Tracer = quicConfigTracer - } else if quicConfigTracer != nil { - conf.Tracer = logging.NewMultiplexedTracer(quicConfigTracer, conf.Tracer) + if enableQlog { + if conf.Tracer == nil { + conf.Tracer = qlogTracer + } else if qlogTracer != nil { + conf.Tracer = logging.NewMultiplexedTracer(qlogTracer, conf.Tracer) + } } return conf } diff --git a/integrationtests/tools/crypto.go b/integrationtests/tools/crypto.go new file mode 100644 index 00000000..d0bdfb41 --- /dev/null +++ b/integrationtests/tools/crypto.go @@ -0,0 +1,120 @@ +package tools + +import ( + "crypto" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" +) + +const ALPN = "quic-go integration tests" + +func GenerateCA() (*x509.Certificate, crypto.PrivateKey, error) { + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, pub, priv) + if err != nil { + return nil, nil, err + } + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, nil, err + } + return ca, priv, nil +} + +func GenerateLeafCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*x509.Certificate, crypto.PrivateKey, error) { + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{"localhost"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, pub, caPriv) + if err != nil { + return nil, nil, err + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, nil, err + } + return cert, priv, nil +} + +// GenerateTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain. +// The Root CA used is the same as for the config returned from getTLSConfig(). +func GenerateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey crypto.PrivateKey) (*tls.Config, error) { + const chainLen = 7 + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + lastCA := ca + lastCAPrivKey := caPrivateKey + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + certs := make([]*x509.Certificate, chainLen) + for i := 0; i < chainLen; i++ { + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey) + if err != nil { + return nil, err + } + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, err + } + certs[i] = ca + lastCA = ca + lastCAPrivKey = privKey + } + leafCert, leafPrivateKey, err := GenerateLeafCert(lastCA, lastCAPrivKey) + if err != nil { + return nil, err + } + + rawCerts := make([][]byte, chainLen+1) + for i, cert := range certs { + rawCerts[chainLen-i] = cert.Raw + } + rawCerts[0] = leafCert.Raw + + return &tls.Config{ + Certificates: []tls.Certificate{{ + Certificate: rawCerts, + PrivateKey: leafPrivateKey, + }}, + NextProtos: []string{ALPN}, + }, nil +} diff --git a/integrationtests/tools/qlog.go b/integrationtests/tools/qlog.go new file mode 100644 index 00000000..a0854260 --- /dev/null +++ b/integrationtests/tools/qlog.go @@ -0,0 +1,31 @@ +package tools + +import ( + "bufio" + "fmt" + "io" + "log" + "os" + + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" + "github.com/quic-go/quic-go/qlog" +) + +func NewQlogger(logger io.Writer) logging.Tracer { + return qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser { + role := "server" + if p == logging.PerspectiveClient { + role = "client" + } + filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role) + fmt.Fprintf(logger, "Creating %s.\n", filename) + f, err := os.Create(filename) + if err != nil { + log.Fatalf("failed to create qlog file: %s", err) + return nil + } + bw := bufio.NewWriter(f) + return utils.NewBufferedWriteCloser(bw, f) + }) +} diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go new file mode 100644 index 00000000..b2cd1269 --- /dev/null +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -0,0 +1,141 @@ +package versionnegotiation + +import ( + "context" + "crypto/tls" + "fmt" + "net" + + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/integrationtests/tools/israce" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/logging" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +type versioner interface { + GetVersion() protocol.VersionNumber +} + +type versionNegotiationTracer struct { + logging.NullConnectionTracer + + loggedVersions bool + receivedVersionNegotiation bool + chosen logging.VersionNumber + clientVersions, serverVersions []logging.VersionNumber +} + +var _ logging.ConnectionTracer = &versionNegotiationTracer{} + +func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + if t.loggedVersions { + Fail("only expected one call to NegotiatedVersions") + } + t.loggedVersions = true + t.chosen = chosen + t.clientVersions = clientVersions + t.serverVersions = serverVersions +} + +func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { + t.receivedVersionNegotiation = true +} + +var _ = Describe("Handshake tests", func() { + startServer := func(tlsConf *tls.Config, conf *quic.Config) (quic.Listener, func()) { + server, err := quic.ListenAddr("localhost:0", tlsConf, conf) + Expect(err).ToNot(HaveOccurred()) + + acceptStopped := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(acceptStopped) + for { + if _, err := server.Accept(context.Background()); err != nil { + return + } + } + }() + + return server, func() { + server.Close() + <-acceptStopped + } + } + + var supportedVersions []protocol.VersionNumber + + BeforeEach(func() { + supportedVersions = protocol.SupportedVersions + protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...) + }) + + AfterEach(func() { + protocol.SupportedVersions = supportedVersions + }) + + if !israce.Enabled { + It("when the server supports more versions than the client", func() { + expectedVersion := protocol.SupportedVersions[0] + // the server doesn't support the highest supported version, which is the first one the client will try + // but it supports a bunch of versions that the client doesn't speak + serverConfig := &quic.Config{} + serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} + serverTracer := &versionNegotiationTracer{} + serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) + server, cl := startServer(getTLSConfig(), serverConfig) + defer cl() + clientTracer := &versionNegotiationTracer{} + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + maybeAddQlogTracer(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + Expect(clientTracer.chosen).To(Equal(expectedVersion)) + Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) + Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions)) + Expect(clientTracer.serverVersions).To(BeEmpty()) + Expect(serverTracer.chosen).To(Equal(expectedVersion)) + Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverTracer.clientVersions).To(BeEmpty()) + }) + + It("when the client supports more versions than the server supports", func() { + expectedVersion := protocol.SupportedVersions[0] + // the server doesn't support the highest supported version, which is the first one the client will try + // but it supports a bunch of versions that the client doesn't speak + serverConfig := &quic.Config{} + serverConfig.Versions = supportedVersions + serverTracer := &versionNegotiationTracer{} + serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) + server, cl := startServer(getTLSConfig(), serverConfig) + defer cl() + clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} + clientTracer := &versionNegotiationTracer{} + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + maybeAddQlogTracer(&quic.Config{ + Versions: clientVersions, + Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + Expect(clientTracer.chosen).To(Equal(expectedVersion)) + Expect(clientTracer.receivedVersionNegotiation).To(BeTrue()) + Expect(clientTracer.clientVersions).To(Equal(clientVersions)) + Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions + Expect(serverTracer.chosen).To(Equal(expectedVersion)) + Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverTracer.clientVersions).To(BeEmpty()) + }) + } +}) diff --git a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go new file mode 100644 index 00000000..0e1894d1 --- /dev/null +++ b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go @@ -0,0 +1,89 @@ +package versionnegotiation + +import ( + "context" + "crypto/tls" + "crypto/x509" + "flag" + "testing" + + "github.com/quic-go/quic-go/integrationtests/tools" + "github.com/quic-go/quic-go/logging" + + "github.com/quic-go/quic-go" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var ( + enableQlog bool + tlsConfig *tls.Config + tlsClientConfig *tls.Config +) + +func init() { + flag.BoolVar(&enableQlog, "qlog", false, "enable qlog") + + ca, caPrivateKey, err := tools.GenerateCA() + if err != nil { + panic(err) + } + leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey) + if err != nil { + panic(err) + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{{ + Certificate: [][]byte{leafCert.Raw}, + PrivateKey: leafPrivateKey, + }}, + NextProtos: []string{tools.ALPN}, + } + + root := x509.NewCertPool() + root.AddCert(ca) + tlsClientConfig = &tls.Config{ + RootCAs: root, + NextProtos: []string{tools.ALPN}, + } +} + +func getTLSConfig() *tls.Config { return tlsConfig } +func getTLSClientConfig() *tls.Config { return tlsClientConfig } + +func TestQuicVersionNegotiation(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Version Negotiation Suite") +} + +func maybeAddQlogTracer(c *quic.Config) *quic.Config { + if c == nil { + c = &quic.Config{} + } + if !enableQlog { + return c + } + qlogger := tools.NewQlogger(GinkgoWriter) + if c.Tracer == nil { + c.Tracer = qlogger + } else if qlogger != nil { + c.Tracer = logging.NewMultiplexedTracer(qlogger, c.Tracer) + } + return c +} + +type tracer struct { + logging.NullTracer + createNewConnTracer func() logging.ConnectionTracer +} + +var _ logging.Tracer = &tracer{} + +func newTracer(c func() logging.ConnectionTracer) logging.Tracer { + return &tracer{createNewConnTracer: c} +} + +func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { + return t.createNewConnTracer() +}