use more tls.Config options in the handshake fuzzer

This commit is contained in:
Marten Seemann 2020-09-06 13:00:57 +07:00
parent adadc06181
commit 382c923a67

View file

@ -5,8 +5,13 @@ import (
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"math"
mrand "math/rand"
"time"
"github.com/lucas-clemente/quic-go/fuzzing/internal/helper"
"github.com/lucas-clemente/quic-go/internal/handshake"
@ -15,8 +20,8 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
var cert *tls.Certificate
var certPool *x509.CertPool
var cert, clientCert *tls.Certificate
var certPool, clientCertPool *x509.CertPool
var sessionTicketKey = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
func init() {
@ -28,6 +33,15 @@ func init() {
if err != nil {
log.Fatal(err)
}
privClient, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
log.Fatal(err)
}
clientCert, clientCertPool, err = helper.GenerateCertificate(privClient)
if err != nil {
log.Fatal(err)
}
}
type messageType uint8
@ -67,6 +81,49 @@ func (m messageType) String() string {
}
}
func appendSuites(suites []uint16, rand uint8) []uint16 {
const (
s1 = tls.TLS_AES_128_GCM_SHA256
s2 = tls.TLS_AES_256_GCM_SHA384
s3 = tls.TLS_CHACHA20_POLY1305_SHA256
)
switch rand % 4 {
default:
return suites
case 1:
return append(suites, s1)
case 2:
return append(suites, s2)
case 3:
return append(suites, s3)
}
}
// consumes 2 bits
func getSuites(rand uint8) []uint16 {
suites := make([]uint16, 0, 3)
for i := 1; i <= 3; i++ {
suites = appendSuites(suites, rand>>i%4)
}
return suites
}
// consumes 3 bits
func getClientAuth(rand uint8) tls.ClientAuthType {
switch rand {
default:
return tls.NoClientCert
case 0:
return tls.RequestClientCert
case 1:
return tls.RequireAnyClientCert
case 2:
return tls.VerifyClientCertIfGiven
case 3:
return tls.RequireAndVerifyClientCert
}
}
type chunk struct {
data []byte
encLevel protocol.EncryptionLevel
@ -125,6 +182,7 @@ func (r *runner) OnError(err error) {
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
const alpn = "fuzzing"
const alpnWrong = "wrong"
func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
switch n % 3 {
@ -158,8 +216,20 @@ func maxEncLevel(cs handshake.CryptoSetup, encLevel protocol.EncryptionLevel) pr
}
}
func getTransportParameters(seed uint8) *wire.TransportParameters {
const maxVarInt = math.MaxUint64 / 4
r := mrand.New(mrand.NewSource(int64(seed)))
return &wire.TransportParameters{
InitialMaxData: protocol.ByteCount(r.Int63n(maxVarInt)),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Int63n(maxVarInt)),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)),
InitialMaxStreamDataUni: protocol.ByteCount(r.Int63n(maxVarInt)),
}
}
// PrefixLen is the number of bytes used for configuration
const PrefixLen = 4
const PrefixLen = 12
const confLen = 5
// Fuzz fuzzes the TLS 1.3 handshake used by QUIC.
//go:generate go run ./cmd/corpus.go
@ -167,18 +237,26 @@ func Fuzz(data []byte) int {
if len(data) < PrefixLen {
return -1
}
runConfig1 := data[0]
messageConfig1 := data[1]
runConfig2 := data[2]
messageConfig2 := data[3]
data = data[PrefixLen:]
dataLen := len(data)
var runConfig1, runConfig2 [confLen]byte
copy(runConfig1[:], data)
data = data[confLen:]
messageConfig1 := data[0]
data = data[1:]
copy(runConfig2[:], data)
data = data[confLen:]
messageConfig2 := data[0]
data = data[1:]
if dataLen != len(data)+PrefixLen {
panic("incorrect configuration")
}
clientConf := &tls.Config{
ServerName: "localhost",
NextProtos: []string{alpn},
RootCAs: certPool,
}
useSessionTicketCache := helper.NthBit(runConfig1, 2)
useSessionTicketCache := helper.NthBit(runConfig1[0], 2)
if useSessionTicketCache {
clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5)
}
@ -189,12 +267,89 @@ func Fuzz(data []byte) int {
return runHandshake(runConfig2, messageConfig2, clientConf, data)
}
func runHandshake(runConfig uint8, messageConfig uint8, clientConf *tls.Config, data []byte) int {
enable0RTTClient := helper.NthBit(runConfig, 0)
enable0RTTServer := helper.NthBit(runConfig, 1)
sendPostHandshakeMessageToClient := helper.NthBit(runConfig, 3)
sendPostHandshakeMessageToServer := helper.NthBit(runConfig, 4)
sendSessionTicket := helper.NthBit(runConfig, 5)
func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int {
serverConf := &tls.Config{
Certificates: []tls.Certificate{*cert},
NextProtos: []string{alpn},
SessionTicketKey: sessionTicketKey,
}
enable0RTTClient := helper.NthBit(runConfig[0], 0)
enable0RTTServer := helper.NthBit(runConfig[0], 1)
sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3)
sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4)
sendSessionTicket := helper.NthBit(runConfig[0], 5)
clientConf.CipherSuites = getSuites(runConfig[0] >> 6)
serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111)
serverConf.CipherSuites = getSuites(runConfig[1] >> 6)
serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3)
clientConf.PreferServerCipherSuites = helper.NthBit(runConfig[1], 4)
if helper.NthBit(runConfig[2], 0) {
clientConf.RootCAs = x509.NewCertPool()
}
if helper.NthBit(runConfig[2], 1) {
serverConf.ClientCAs = clientCertPool
} else {
serverConf.ClientCAs = x509.NewCertPool()
}
if helper.NthBit(runConfig[2], 2) {
serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
if helper.NthBit(runConfig[2], 3) {
return nil, errors.New("getting client config failed")
}
if helper.NthBit(runConfig[2], 4) {
return nil, nil
}
return serverConf, nil
}
}
if helper.NthBit(runConfig[2], 5) {
serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if helper.NthBit(runConfig[2], 6) {
return nil, errors.New("getting certificate failed")
}
if helper.NthBit(runConfig[2], 7) {
return nil, nil
}
return clientCert, nil // this certificate will be invalid
}
}
if helper.NthBit(runConfig[3], 0) {
serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
if helper.NthBit(runConfig[3], 1) {
return errors.New("certificate verification failed")
}
return nil
}
}
if helper.NthBit(runConfig[3], 2) {
clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
if helper.NthBit(runConfig[3], 3) {
return errors.New("certificate verification failed")
}
return nil
}
}
if helper.NthBit(runConfig[3], 4) {
serverConf.NextProtos = []string{alpnWrong}
}
if helper.NthBit(runConfig[3], 5) {
serverConf.NextProtos = []string{alpnWrong, alpn}
}
if helper.NthBit(runConfig[3], 6) {
serverConf.KeyLogWriter = ioutil.Discard
}
if helper.NthBit(runConfig[3], 7) {
clientConf.KeyLogWriter = ioutil.Discard
}
clientTP := getTransportParameters(runConfig[4] & 0x3)
if helper.NthBit(runConfig[4], 3) {
clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
}
serverTP := getTransportParameters(runConfig[4] & 0b00011000)
if helper.NthBit(runConfig[4], 3) {
serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
}
messageToReplace := messageConfig % 32
messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
@ -208,7 +363,7 @@ func runHandshake(runConfig uint8, messageConfig uint8, clientConf *tls.Config,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
clientTP,
runner,
clientConf,
enable0RTTClient,
@ -224,13 +379,9 @@ func runHandshake(runConfig uint8, messageConfig uint8, clientConf *tls.Config,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
serverTP,
runner,
&tls.Config{
Certificates: []tls.Certificate{*cert},
NextProtos: []string{alpn},
SessionTicketKey: sessionTicketKey,
},
serverConf,
enable0RTTServer,
utils.NewRTTStats(),
nil,
@ -267,7 +418,7 @@ messageLoop:
b := c.data
encLevel := c.encLevel
if len(b) > 0 && b[0] == messageToReplace {
fmt.Println("replacing message to the server", messageType(b[0]).String())
fmt.Printf("replacing %s message to the server with %s\n", messageType(b[0]), messageType(data[0]))
b = data
encLevel = maxEncLevel(server, messageToReplaceEncLevel)
}
@ -276,7 +427,7 @@ messageLoop:
b := c.data
encLevel := c.encLevel
if len(b) > 0 && b[0] == messageToReplace {
fmt.Println("replacing message to the client", messageType(b[0]).String())
fmt.Printf("replacing %s message to the client with %s\n", messageType(b[0]), messageType(data[0]))
b = data
encLevel = maxEncLevel(client, messageToReplaceEncLevel)
}
@ -295,7 +446,26 @@ messageLoop:
if runner.errored {
return 0
}
if sendSessionTicket {
sealer, err := client.Get1RTTSealer()
if err != nil {
panic("expected to get a 1-RTT sealer")
}
opener, err := server.Get1RTTOpener()
if err != nil {
panic("expected to get a 1-RTT opener")
}
const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar"))
decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar"))
if err != nil {
panic(fmt.Sprintf("Decrypting message failed: %s", err.Error()))
}
if string(decrypted) != msg {
panic("wrong message")
}
if sendSessionTicket && !serverConf.SessionTicketsDisabled {
ticket, err := server.GetSessionTicket()
if err != nil {
panic(err)