mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
add integration tests using a very long certificate chain
This will trigger the amplification protection.
This commit is contained in:
parent
e4f02ff68c
commit
e33f7d0fb9
3 changed files with 185 additions and 90 deletions
|
@ -2,6 +2,7 @@ package self_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
|
@ -32,7 +33,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
|
||||
const timeout = 10 * time.Minute
|
||||
|
||||
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, version protocol.VersionNumber) {
|
||||
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) {
|
||||
conf := getQuicConfigForServer(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeTimeout: timeout,
|
||||
|
@ -41,8 +42,14 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
if !doRetry {
|
||||
conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true }
|
||||
}
|
||||
var tlsConf *tls.Config
|
||||
if longCertChain {
|
||||
tlsConf = getTLSConfigWithLongCertChain()
|
||||
} else {
|
||||
tlsConf = getTLSConfig()
|
||||
}
|
||||
var err error
|
||||
ln, err = quic.ListenAddr("localhost:0", getTLSConfig(), conf)
|
||||
ln, err = quic.ListenAddr("localhost:0", tlsConf, conf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
|
@ -184,46 +191,52 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
}
|
||||
|
||||
Context(desc, func() {
|
||||
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
|
||||
app := a
|
||||
for _, lcc := range []bool{false, true} {
|
||||
longCertChain := lcc
|
||||
|
||||
Context(app.name, func() {
|
||||
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
|
||||
var incoming, outgoing int32
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
var p int32
|
||||
switch d {
|
||||
case quicproxy.DirectionIncoming:
|
||||
p = atomic.AddInt32(&incoming, 1)
|
||||
case quicproxy.DirectionOutgoing:
|
||||
p = atomic.AddInt32(&outgoing, 1)
|
||||
}
|
||||
return p == 1 && d.Is(direction)
|
||||
}, doRetry, version)
|
||||
app.run(version)
|
||||
})
|
||||
Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() {
|
||||
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
|
||||
app := a
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
|
||||
var incoming, outgoing int32
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
var p int32
|
||||
switch d {
|
||||
case quicproxy.DirectionIncoming:
|
||||
p = atomic.AddInt32(&incoming, 1)
|
||||
case quicproxy.DirectionOutgoing:
|
||||
p = atomic.AddInt32(&outgoing, 1)
|
||||
}
|
||||
return p == 2 && d.Is(direction)
|
||||
}, doRetry, version)
|
||||
app.run(version)
|
||||
})
|
||||
Context(app.name, func() {
|
||||
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
|
||||
var incoming, outgoing int32
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
var p int32
|
||||
switch d {
|
||||
case quicproxy.DirectionIncoming:
|
||||
p = atomic.AddInt32(&incoming, 1)
|
||||
case quicproxy.DirectionOutgoing:
|
||||
p = atomic.AddInt32(&outgoing, 1)
|
||||
}
|
||||
return p == 1 && d.Is(direction)
|
||||
}, doRetry, longCertChain, version)
|
||||
app.run(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
return d.Is(direction) && stochasticDropper(3)
|
||||
}, doRetry, version)
|
||||
app.run(version)
|
||||
})
|
||||
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
|
||||
var incoming, outgoing int32
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
var p int32
|
||||
switch d {
|
||||
case quicproxy.DirectionIncoming:
|
||||
p = atomic.AddInt32(&incoming, 1)
|
||||
case quicproxy.DirectionOutgoing:
|
||||
p = atomic.AddInt32(&outgoing, 1)
|
||||
}
|
||||
return p == 2 && d.Is(direction)
|
||||
}, doRetry, longCertChain, version)
|
||||
app.run(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
return d.Is(direction) && stochasticDropper(3)
|
||||
}, doRetry, longCertChain, version)
|
||||
app.run(version)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
|
|
@ -51,14 +51,12 @@ var _ = Describe("Handshake tests", func() {
|
|||
server quic.Listener
|
||||
serverConfig *quic.Config
|
||||
acceptStopped chan struct{}
|
||||
tlsServerConf *tls.Config
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
server = nil
|
||||
acceptStopped = make(chan struct{})
|
||||
serverConfig = getQuicConfigForServer(nil)
|
||||
tlsServerConf = getTLSConfig()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -68,10 +66,10 @@ var _ = Describe("Handshake tests", func() {
|
|||
}
|
||||
})
|
||||
|
||||
runServer := func() quic.Listener {
|
||||
runServer := func(tlsConf *tls.Config) {
|
||||
var err error
|
||||
// start the server
|
||||
server, err = quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
|
||||
server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
|
@ -83,7 +81,6 @@ var _ = Describe("Handshake tests", func() {
|
|||
}
|
||||
}
|
||||
}()
|
||||
return server
|
||||
}
|
||||
|
||||
if !israce.Enabled {
|
||||
|
@ -103,7 +100,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
// the server doesn't support the highest supported version, which is the first one the client will try
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
|
||||
server := runServer()
|
||||
runServer(getTLSConfig())
|
||||
defer server.Close()
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
|
@ -119,7 +116,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
// the server doesn't support the highest supported version, which is the first one the client will try
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverConfig.Versions = supportedVersions
|
||||
server := runServer()
|
||||
runServer(getTLSConfig())
|
||||
defer server.Close()
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
|
@ -145,9 +142,11 @@ var _ = Describe("Handshake tests", func() {
|
|||
suiteID := id
|
||||
|
||||
It(fmt.Sprintf("using %s", name), func() {
|
||||
tlsServerConf.CipherSuites = []uint16{suiteID}
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
|
||||
tlsConf := getTLSConfig()
|
||||
tlsConf.CipherSuites = []uint16{suiteID}
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -177,7 +176,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
}
|
||||
})
|
||||
|
||||
Context("Certifiate validation", func() {
|
||||
Context("Certificate validation", func() {
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
version := v
|
||||
|
||||
|
@ -189,11 +188,8 @@ var _ = Describe("Handshake tests", func() {
|
|||
clientConfig = getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}})
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
runServer()
|
||||
})
|
||||
|
||||
It("accepts the certificate", func() {
|
||||
runServer(getTLSConfig())
|
||||
_, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
|
@ -202,7 +198,18 @@ var _ = Describe("Handshake tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("works with a long certificate chain", func() {
|
||||
runServer(getTLSConfigWithLongCertChain())
|
||||
_, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors if the server name doesn't match", func() {
|
||||
runServer(getTLSConfig())
|
||||
_, err := quic.DialAddr(
|
||||
fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
|
@ -212,7 +219,10 @@ var _ = Describe("Handshake tests", func() {
|
|||
})
|
||||
|
||||
It("fails the handshake if the client fails to provide the requested client cert", func() {
|
||||
tlsServerConf.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
tlsConf := getTLSConfig()
|
||||
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
runServer(tlsConf)
|
||||
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
|
@ -234,6 +244,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
})
|
||||
|
||||
It("uses the ServerName in the tls.Config", func() {
|
||||
runServer(getTLSConfig())
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ServerName = "localhost"
|
||||
_, err := quic.DialAddr(
|
||||
|
@ -350,7 +361,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
|
||||
Context("ALPN", func() {
|
||||
It("negotiates an application protocol", func() {
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
@ -379,7 +390,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
})
|
||||
|
||||
It("errors if application protocol negotiation fails", func() {
|
||||
server := runServer()
|
||||
runServer(getTLSConfig())
|
||||
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.NextProtos = []string{"foobar"}
|
||||
|
@ -391,7 +402,6 @@ var _ = Describe("Handshake tests", func() {
|
|||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR"))
|
||||
Expect(err.Error()).To(ContainSubstring("no application protocol"))
|
||||
Expect(server.Close()).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -85,10 +85,9 @@ var (
|
|||
logBuf *syncedBuffer
|
||||
enableQlog bool
|
||||
|
||||
caPrivateKey *rsa.PrivateKey
|
||||
ca *x509.Certificate
|
||||
leafPrivateKey *rsa.PrivateKey
|
||||
leafCert *x509.Certificate
|
||||
tlsConfig *tls.Config
|
||||
tlsConfigLongChain *tls.Config
|
||||
tlsClientConfig *tls.Config
|
||||
)
|
||||
|
||||
// read the logfile command line flag
|
||||
|
@ -97,16 +96,37 @@ func init() {
|
|||
flag.StringVar(&logFileName, "logfile", "", "log file")
|
||||
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
|
||||
|
||||
if err := generateCA(); err != nil {
|
||||
ca, caPrivateKey, err := generateCA()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := generateCertChain(); err != nil {
|
||||
leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tlsConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{tls.Certificate{
|
||||
Certificate: [][]byte{leafCert.Raw},
|
||||
PrivateKey: leafPrivateKey,
|
||||
}},
|
||||
NextProtos: []string{alpn},
|
||||
}
|
||||
tlsConfLongChain, err := generateTLSConfigWithLongCertChain(ca, caPrivateKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tlsConfigLongChain = tlsConfLongChain
|
||||
|
||||
root := x509.NewCertPool()
|
||||
root.AddCert(ca)
|
||||
tlsClientConfig = &tls.Config{
|
||||
RootCAs: root,
|
||||
NextProtos: []string{alpn},
|
||||
}
|
||||
}
|
||||
|
||||
func generateCA() error {
|
||||
caCert := &x509.Certificate{
|
||||
func generateCA() (*x509.Certificate, *rsa.PrivateKey, error) {
|
||||
certTempl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2019),
|
||||
Subject: pkix.Name{},
|
||||
NotBefore: time.Now(),
|
||||
|
@ -116,21 +136,23 @@ func generateCA() error {
|
|||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
var err error
|
||||
caPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
caPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caPrivateKey.PublicKey, caPrivateKey)
|
||||
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
ca, err = x509.ParseCertificate(caBytes)
|
||||
return err
|
||||
ca, err := x509.ParseCertificate(caBytes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return ca, caPrivateKey, nil
|
||||
}
|
||||
|
||||
func generateCertChain() error {
|
||||
cert := &x509.Certificate{
|
||||
func generateLeafCert(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, error) {
|
||||
certTempl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
DNSNames: []string{"localhost"},
|
||||
NotBefore: time.Now(),
|
||||
|
@ -138,36 +160,86 @@ func generateCertChain() error {
|
|||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
}
|
||||
var err error
|
||||
leafPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &leafPrivateKey.PublicKey, caPrivateKey)
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, &privKey.PublicKey, caPrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
leafCert, err = x509.ParseCertificate(certBytes)
|
||||
return err
|
||||
cert, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return cert, privKey, nil
|
||||
}
|
||||
|
||||
func getTLSConfig() *tls.Config {
|
||||
// getTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain.
|
||||
// The Root CA used is the same as for the config returned from getTLSConfig().
|
||||
func generateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*tls.Config, error) {
|
||||
const chainLen = 7
|
||||
certTempl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2019),
|
||||
Subject: pkix.Name{},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
IsCA: true,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
lastCA := ca
|
||||
lastCAPrivKey := caPrivateKey
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certs := make([]*x509.Certificate, chainLen)
|
||||
for i := 0; i < chainLen; i++ {
|
||||
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ca, err := x509.ParseCertificate(caBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certs[i] = ca
|
||||
lastCA = ca
|
||||
lastCAPrivKey = privKey
|
||||
}
|
||||
leafCert, leafPrivateKey, err := generateLeafCert(lastCA, lastCAPrivKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawCerts := make([][]byte, chainLen+1)
|
||||
for i, cert := range certs {
|
||||
rawCerts[chainLen-i] = cert.Raw
|
||||
}
|
||||
rawCerts[0] = leafCert.Raw
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{tls.Certificate{
|
||||
Certificate: [][]byte{leafCert.Raw},
|
||||
Certificate: rawCerts,
|
||||
PrivateKey: leafPrivateKey,
|
||||
}},
|
||||
NextProtos: []string{alpn},
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getTLSConfig() *tls.Config {
|
||||
return tlsConfig.Clone()
|
||||
}
|
||||
|
||||
func getTLSConfigWithLongCertChain() *tls.Config {
|
||||
return tlsConfigLongChain.Clone()
|
||||
}
|
||||
|
||||
func getTLSClientConfig() *tls.Config {
|
||||
root := x509.NewCertPool()
|
||||
root.AddCert(ca)
|
||||
return &tls.Config{
|
||||
RootCAs: root,
|
||||
NextProtos: []string{alpn},
|
||||
}
|
||||
return tlsClientConfig.Clone()
|
||||
}
|
||||
|
||||
func getQuicConfigForClient(conf *quic.Config) *quic.Config {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue