mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
use more tls.Config options in the handshake fuzzer
This commit is contained in:
parent
adadc06181
commit
382c923a67
1 changed files with 195 additions and 25 deletions
|
@ -5,8 +5,13 @@ import (
|
|||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math"
|
||||
mrand "math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/fuzzing/internal/helper"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
|
@ -15,8 +20,8 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
var cert *tls.Certificate
|
||||
var certPool *x509.CertPool
|
||||
var cert, clientCert *tls.Certificate
|
||||
var certPool, clientCertPool *x509.CertPool
|
||||
var sessionTicketKey = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
|
||||
|
||||
func init() {
|
||||
|
@ -28,6 +33,15 @@ func init() {
|
|||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
privClient, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
clientCert, clientCertPool, err = helper.GenerateCertificate(privClient)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type messageType uint8
|
||||
|
@ -67,6 +81,49 @@ func (m messageType) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
func appendSuites(suites []uint16, rand uint8) []uint16 {
|
||||
const (
|
||||
s1 = tls.TLS_AES_128_GCM_SHA256
|
||||
s2 = tls.TLS_AES_256_GCM_SHA384
|
||||
s3 = tls.TLS_CHACHA20_POLY1305_SHA256
|
||||
)
|
||||
switch rand % 4 {
|
||||
default:
|
||||
return suites
|
||||
case 1:
|
||||
return append(suites, s1)
|
||||
case 2:
|
||||
return append(suites, s2)
|
||||
case 3:
|
||||
return append(suites, s3)
|
||||
}
|
||||
}
|
||||
|
||||
// consumes 2 bits
|
||||
func getSuites(rand uint8) []uint16 {
|
||||
suites := make([]uint16, 0, 3)
|
||||
for i := 1; i <= 3; i++ {
|
||||
suites = appendSuites(suites, rand>>i%4)
|
||||
}
|
||||
return suites
|
||||
}
|
||||
|
||||
// consumes 3 bits
|
||||
func getClientAuth(rand uint8) tls.ClientAuthType {
|
||||
switch rand {
|
||||
default:
|
||||
return tls.NoClientCert
|
||||
case 0:
|
||||
return tls.RequestClientCert
|
||||
case 1:
|
||||
return tls.RequireAnyClientCert
|
||||
case 2:
|
||||
return tls.VerifyClientCertIfGiven
|
||||
case 3:
|
||||
return tls.RequireAndVerifyClientCert
|
||||
}
|
||||
}
|
||||
|
||||
type chunk struct {
|
||||
data []byte
|
||||
encLevel protocol.EncryptionLevel
|
||||
|
@ -125,6 +182,7 @@ func (r *runner) OnError(err error) {
|
|||
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
|
||||
|
||||
const alpn = "fuzzing"
|
||||
const alpnWrong = "wrong"
|
||||
|
||||
func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
|
||||
switch n % 3 {
|
||||
|
@ -158,8 +216,20 @@ func maxEncLevel(cs handshake.CryptoSetup, encLevel protocol.EncryptionLevel) pr
|
|||
}
|
||||
}
|
||||
|
||||
func getTransportParameters(seed uint8) *wire.TransportParameters {
|
||||
const maxVarInt = math.MaxUint64 / 4
|
||||
r := mrand.New(mrand.NewSource(int64(seed)))
|
||||
return &wire.TransportParameters{
|
||||
InitialMaxData: protocol.ByteCount(r.Int63n(maxVarInt)),
|
||||
InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Int63n(maxVarInt)),
|
||||
InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)),
|
||||
InitialMaxStreamDataUni: protocol.ByteCount(r.Int63n(maxVarInt)),
|
||||
}
|
||||
}
|
||||
|
||||
// PrefixLen is the number of bytes used for configuration
|
||||
const PrefixLen = 4
|
||||
const PrefixLen = 12
|
||||
const confLen = 5
|
||||
|
||||
// Fuzz fuzzes the TLS 1.3 handshake used by QUIC.
|
||||
//go:generate go run ./cmd/corpus.go
|
||||
|
@ -167,18 +237,26 @@ func Fuzz(data []byte) int {
|
|||
if len(data) < PrefixLen {
|
||||
return -1
|
||||
}
|
||||
runConfig1 := data[0]
|
||||
messageConfig1 := data[1]
|
||||
runConfig2 := data[2]
|
||||
messageConfig2 := data[3]
|
||||
data = data[PrefixLen:]
|
||||
dataLen := len(data)
|
||||
var runConfig1, runConfig2 [confLen]byte
|
||||
copy(runConfig1[:], data)
|
||||
data = data[confLen:]
|
||||
messageConfig1 := data[0]
|
||||
data = data[1:]
|
||||
copy(runConfig2[:], data)
|
||||
data = data[confLen:]
|
||||
messageConfig2 := data[0]
|
||||
data = data[1:]
|
||||
if dataLen != len(data)+PrefixLen {
|
||||
panic("incorrect configuration")
|
||||
}
|
||||
|
||||
clientConf := &tls.Config{
|
||||
ServerName: "localhost",
|
||||
NextProtos: []string{alpn},
|
||||
RootCAs: certPool,
|
||||
}
|
||||
useSessionTicketCache := helper.NthBit(runConfig1, 2)
|
||||
useSessionTicketCache := helper.NthBit(runConfig1[0], 2)
|
||||
if useSessionTicketCache {
|
||||
clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5)
|
||||
}
|
||||
|
@ -189,12 +267,89 @@ func Fuzz(data []byte) int {
|
|||
return runHandshake(runConfig2, messageConfig2, clientConf, data)
|
||||
}
|
||||
|
||||
func runHandshake(runConfig uint8, messageConfig uint8, clientConf *tls.Config, data []byte) int {
|
||||
enable0RTTClient := helper.NthBit(runConfig, 0)
|
||||
enable0RTTServer := helper.NthBit(runConfig, 1)
|
||||
sendPostHandshakeMessageToClient := helper.NthBit(runConfig, 3)
|
||||
sendPostHandshakeMessageToServer := helper.NthBit(runConfig, 4)
|
||||
sendSessionTicket := helper.NthBit(runConfig, 5)
|
||||
func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int {
|
||||
serverConf := &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
NextProtos: []string{alpn},
|
||||
SessionTicketKey: sessionTicketKey,
|
||||
}
|
||||
|
||||
enable0RTTClient := helper.NthBit(runConfig[0], 0)
|
||||
enable0RTTServer := helper.NthBit(runConfig[0], 1)
|
||||
sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3)
|
||||
sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4)
|
||||
sendSessionTicket := helper.NthBit(runConfig[0], 5)
|
||||
clientConf.CipherSuites = getSuites(runConfig[0] >> 6)
|
||||
serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111)
|
||||
serverConf.CipherSuites = getSuites(runConfig[1] >> 6)
|
||||
serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3)
|
||||
clientConf.PreferServerCipherSuites = helper.NthBit(runConfig[1], 4)
|
||||
if helper.NthBit(runConfig[2], 0) {
|
||||
clientConf.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
if helper.NthBit(runConfig[2], 1) {
|
||||
serverConf.ClientCAs = clientCertPool
|
||||
} else {
|
||||
serverConf.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
if helper.NthBit(runConfig[2], 2) {
|
||||
serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
if helper.NthBit(runConfig[2], 3) {
|
||||
return nil, errors.New("getting client config failed")
|
||||
}
|
||||
if helper.NthBit(runConfig[2], 4) {
|
||||
return nil, nil
|
||||
}
|
||||
return serverConf, nil
|
||||
}
|
||||
}
|
||||
if helper.NthBit(runConfig[2], 5) {
|
||||
serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if helper.NthBit(runConfig[2], 6) {
|
||||
return nil, errors.New("getting certificate failed")
|
||||
}
|
||||
if helper.NthBit(runConfig[2], 7) {
|
||||
return nil, nil
|
||||
}
|
||||
return clientCert, nil // this certificate will be invalid
|
||||
}
|
||||
}
|
||||
if helper.NthBit(runConfig[3], 0) {
|
||||
serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
if helper.NthBit(runConfig[3], 1) {
|
||||
return errors.New("certificate verification failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if helper.NthBit(runConfig[3], 2) {
|
||||
clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
if helper.NthBit(runConfig[3], 3) {
|
||||
return errors.New("certificate verification failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if helper.NthBit(runConfig[3], 4) {
|
||||
serverConf.NextProtos = []string{alpnWrong}
|
||||
}
|
||||
if helper.NthBit(runConfig[3], 5) {
|
||||
serverConf.NextProtos = []string{alpnWrong, alpn}
|
||||
}
|
||||
if helper.NthBit(runConfig[3], 6) {
|
||||
serverConf.KeyLogWriter = ioutil.Discard
|
||||
}
|
||||
if helper.NthBit(runConfig[3], 7) {
|
||||
clientConf.KeyLogWriter = ioutil.Discard
|
||||
}
|
||||
clientTP := getTransportParameters(runConfig[4] & 0x3)
|
||||
if helper.NthBit(runConfig[4], 3) {
|
||||
clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
|
||||
}
|
||||
serverTP := getTransportParameters(runConfig[4] & 0b00011000)
|
||||
if helper.NthBit(runConfig[4], 3) {
|
||||
serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
|
||||
}
|
||||
|
||||
messageToReplace := messageConfig % 32
|
||||
messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
|
||||
|
@ -208,7 +363,7 @@ func runHandshake(runConfig uint8, messageConfig uint8, clientConf *tls.Config,
|
|||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{},
|
||||
clientTP,
|
||||
runner,
|
||||
clientConf,
|
||||
enable0RTTClient,
|
||||
|
@ -224,13 +379,9 @@ func runHandshake(runConfig uint8, messageConfig uint8, clientConf *tls.Config,
|
|||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{},
|
||||
serverTP,
|
||||
runner,
|
||||
&tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
NextProtos: []string{alpn},
|
||||
SessionTicketKey: sessionTicketKey,
|
||||
},
|
||||
serverConf,
|
||||
enable0RTTServer,
|
||||
utils.NewRTTStats(),
|
||||
nil,
|
||||
|
@ -267,7 +418,7 @@ messageLoop:
|
|||
b := c.data
|
||||
encLevel := c.encLevel
|
||||
if len(b) > 0 && b[0] == messageToReplace {
|
||||
fmt.Println("replacing message to the server", messageType(b[0]).String())
|
||||
fmt.Printf("replacing %s message to the server with %s\n", messageType(b[0]), messageType(data[0]))
|
||||
b = data
|
||||
encLevel = maxEncLevel(server, messageToReplaceEncLevel)
|
||||
}
|
||||
|
@ -276,7 +427,7 @@ messageLoop:
|
|||
b := c.data
|
||||
encLevel := c.encLevel
|
||||
if len(b) > 0 && b[0] == messageToReplace {
|
||||
fmt.Println("replacing message to the client", messageType(b[0]).String())
|
||||
fmt.Printf("replacing %s message to the client with %s\n", messageType(b[0]), messageType(data[0]))
|
||||
b = data
|
||||
encLevel = maxEncLevel(client, messageToReplaceEncLevel)
|
||||
}
|
||||
|
@ -295,7 +446,26 @@ messageLoop:
|
|||
if runner.errored {
|
||||
return 0
|
||||
}
|
||||
if sendSessionTicket {
|
||||
|
||||
sealer, err := client.Get1RTTSealer()
|
||||
if err != nil {
|
||||
panic("expected to get a 1-RTT sealer")
|
||||
}
|
||||
opener, err := server.Get1RTTOpener()
|
||||
if err != nil {
|
||||
panic("expected to get a 1-RTT opener")
|
||||
}
|
||||
const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
||||
encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar"))
|
||||
decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar"))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Decrypting message failed: %s", err.Error()))
|
||||
}
|
||||
if string(decrypted) != msg {
|
||||
panic("wrong message")
|
||||
}
|
||||
|
||||
if sendSessionTicket && !serverConf.SessionTicketsDisabled {
|
||||
ticket, err := server.GetSessionTicket()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue