add integration tests using a very long certificate chain

This will trigger the amplification protection.
This commit is contained in:
Marten Seemann 2020-05-11 15:58:40 +07:00
parent e4f02ff68c
commit e33f7d0fb9
3 changed files with 185 additions and 90 deletions

View file

@ -2,6 +2,7 @@ package self_test
import (
"context"
"crypto/tls"
"fmt"
mrand "math/rand"
"net"
@ -32,7 +33,7 @@ var _ = Describe("Handshake drop tests", func() {
const timeout = 10 * time.Minute
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, version protocol.VersionNumber) {
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) {
conf := getQuicConfigForServer(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeTimeout: timeout,
@ -41,8 +42,14 @@ var _ = Describe("Handshake drop tests", func() {
if !doRetry {
conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true }
}
var tlsConf *tls.Config
if longCertChain {
tlsConf = getTLSConfigWithLongCertChain()
} else {
tlsConf = getTLSConfig()
}
var err error
ln, err = quic.ListenAddr("localhost:0", getTLSConfig(), conf)
ln, err = quic.ListenAddr("localhost:0", tlsConf, conf)
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
@ -184,46 +191,52 @@ var _ = Describe("Handshake drop tests", func() {
}
Context(desc, func() {
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
app := a
for _, lcc := range []bool{false, true} {
longCertChain := lcc
Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 1 && d.Is(direction)
}, doRetry, version)
app.run(version)
})
Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() {
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
app := a
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 2 && d.Is(direction)
}, doRetry, version)
app.run(version)
})
Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 1 && d.Is(direction)
}, doRetry, longCertChain, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
return d.Is(direction) && stochasticDropper(3)
}, doRetry, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 2 && d.Is(direction)
}, doRetry, longCertChain, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
return d.Is(direction) && stochasticDropper(3)
}, doRetry, longCertChain, version)
app.run(version)
})
})
}
})
}
})

View file

@ -51,14 +51,12 @@ var _ = Describe("Handshake tests", func() {
server quic.Listener
serverConfig *quic.Config
acceptStopped chan struct{}
tlsServerConf *tls.Config
)
BeforeEach(func() {
server = nil
acceptStopped = make(chan struct{})
serverConfig = getQuicConfigForServer(nil)
tlsServerConf = getTLSConfig()
})
AfterEach(func() {
@ -68,10 +66,10 @@ var _ = Describe("Handshake tests", func() {
}
})
runServer := func() quic.Listener {
runServer := func(tlsConf *tls.Config) {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig)
Expect(err).ToNot(HaveOccurred())
go func() {
@ -83,7 +81,6 @@ var _ = Describe("Handshake tests", func() {
}
}
}()
return server
}
if !israce.Enabled {
@ -103,7 +100,7 @@ var _ = Describe("Handshake tests", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
server := runServer()
runServer(getTLSConfig())
defer server.Close()
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
@ -119,7 +116,7 @@ var _ = Describe("Handshake tests", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = supportedVersions
server := runServer()
runServer(getTLSConfig())
defer server.Close()
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
@ -145,9 +142,11 @@ var _ = Describe("Handshake tests", func() {
suiteID := id
It(fmt.Sprintf("using %s", name), func() {
tlsServerConf.CipherSuites = []uint16{suiteID}
ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
tlsConf := getTLSConfig()
tlsConf.CipherSuites = []uint16{suiteID}
ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
go func() {
defer GinkgoRecover()
@ -177,7 +176,7 @@ var _ = Describe("Handshake tests", func() {
}
})
Context("Certifiate validation", func() {
Context("Certificate validation", func() {
for _, v := range protocol.SupportedVersions {
version := v
@ -189,11 +188,8 @@ var _ = Describe("Handshake tests", func() {
clientConfig = getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}})
})
JustBeforeEach(func() {
runServer()
})
It("accepts the certificate", func() {
runServer(getTLSConfig())
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
@ -202,7 +198,18 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
})
It("works with a long certificate chain", func() {
runServer(getTLSConfigWithLongCertChain())
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}),
)
Expect(err).ToNot(HaveOccurred())
})
It("errors if the server name doesn't match", func() {
runServer(getTLSConfig())
_, err := quic.DialAddr(
fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
@ -212,7 +219,10 @@ var _ = Describe("Handshake tests", func() {
})
It("fails the handshake if the client fails to provide the requested client cert", func() {
tlsServerConf.ClientAuth = tls.RequireAndVerifyClientCert
tlsConf := getTLSConfig()
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
runServer(tlsConf)
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
@ -234,6 +244,7 @@ var _ = Describe("Handshake tests", func() {
})
It("uses the ServerName in the tls.Config", func() {
runServer(getTLSConfig())
tlsConf := getTLSClientConfig()
tlsConf.ServerName = "localhost"
_, err := quic.DialAddr(
@ -350,7 +361,7 @@ var _ = Describe("Handshake tests", func() {
Context("ALPN", func() {
It("negotiates an application protocol", func() {
ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
@ -379,7 +390,7 @@ var _ = Describe("Handshake tests", func() {
})
It("errors if application protocol negotiation fails", func() {
server := runServer()
runServer(getTLSConfig())
tlsConf := getTLSClientConfig()
tlsConf.NextProtos = []string{"foobar"}
@ -391,7 +402,6 @@ var _ = Describe("Handshake tests", func() {
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR"))
Expect(err.Error()).To(ContainSubstring("no application protocol"))
Expect(server.Close()).To(Succeed())
})
})

View file

@ -85,10 +85,9 @@ var (
logBuf *syncedBuffer
enableQlog bool
caPrivateKey *rsa.PrivateKey
ca *x509.Certificate
leafPrivateKey *rsa.PrivateKey
leafCert *x509.Certificate
tlsConfig *tls.Config
tlsConfigLongChain *tls.Config
tlsClientConfig *tls.Config
)
// read the logfile command line flag
@ -97,16 +96,37 @@ func init() {
flag.StringVar(&logFileName, "logfile", "", "log file")
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
if err := generateCA(); err != nil {
ca, caPrivateKey, err := generateCA()
if err != nil {
panic(err)
}
if err := generateCertChain(); err != nil {
leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey)
if err != nil {
panic(err)
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{leafCert.Raw},
PrivateKey: leafPrivateKey,
}},
NextProtos: []string{alpn},
}
tlsConfLongChain, err := generateTLSConfigWithLongCertChain(ca, caPrivateKey)
if err != nil {
panic(err)
}
tlsConfigLongChain = tlsConfLongChain
root := x509.NewCertPool()
root.AddCert(ca)
tlsClientConfig = &tls.Config{
RootCAs: root,
NextProtos: []string{alpn},
}
}
func generateCA() error {
caCert := &x509.Certificate{
func generateCA() (*x509.Certificate, *rsa.PrivateKey, error) {
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{},
NotBefore: time.Now(),
@ -116,21 +136,23 @@ func generateCA() error {
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
var err error
caPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
caPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
return nil, nil, err
}
caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caPrivateKey.PublicKey, caPrivateKey)
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey)
if err != nil {
return err
return nil, nil, err
}
ca, err = x509.ParseCertificate(caBytes)
return err
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, nil, err
}
return ca, caPrivateKey, nil
}
func generateCertChain() error {
cert := &x509.Certificate{
func generateLeafCert(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, error) {
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(1),
DNSNames: []string{"localhost"},
NotBefore: time.Now(),
@ -138,36 +160,86 @@ func generateCertChain() error {
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
var err error
leafPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &leafPrivateKey.PublicKey, caPrivateKey)
certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, &privKey.PublicKey, caPrivateKey)
if err != nil {
return err
return nil, nil, err
}
leafCert, err = x509.ParseCertificate(certBytes)
return err
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return nil, nil, err
}
return cert, privKey, nil
}
func getTLSConfig() *tls.Config {
// getTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain.
// The Root CA used is the same as for the config returned from getTLSConfig().
func generateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*tls.Config, error) {
const chainLen = 7
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
lastCA := ca
lastCAPrivKey := caPrivateKey
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
certs := make([]*x509.Certificate, chainLen)
for i := 0; i < chainLen; i++ {
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey)
if err != nil {
return nil, err
}
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, err
}
certs[i] = ca
lastCA = ca
lastCAPrivKey = privKey
}
leafCert, leafPrivateKey, err := generateLeafCert(lastCA, lastCAPrivKey)
if err != nil {
return nil, err
}
rawCerts := make([][]byte, chainLen+1)
for i, cert := range certs {
rawCerts[chainLen-i] = cert.Raw
}
rawCerts[0] = leafCert.Raw
return &tls.Config{
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{leafCert.Raw},
Certificate: rawCerts,
PrivateKey: leafPrivateKey,
}},
NextProtos: []string{alpn},
}
}, nil
}
func getTLSConfig() *tls.Config {
return tlsConfig.Clone()
}
func getTLSConfigWithLongCertChain() *tls.Config {
return tlsConfigLongChain.Clone()
}
func getTLSClientConfig() *tls.Config {
root := x509.NewCertPool()
root.AddCert(ca)
return &tls.Config{
RootCAs: root,
NextProtos: []string{alpn},
}
return tlsClientConfig.Clone()
}
func getQuicConfigForClient(conf *quic.Config) *quic.Config {