diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index 1083f606..0cd15cc7 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -5,8 +5,13 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" + "errors" "fmt" + "io/ioutil" "log" + "math" + mrand "math/rand" + "time" "github.com/lucas-clemente/quic-go/fuzzing/internal/helper" "github.com/lucas-clemente/quic-go/internal/handshake" @@ -15,8 +20,8 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -var cert *tls.Certificate -var certPool *x509.CertPool +var cert, clientCert *tls.Certificate +var certPool, clientCertPool *x509.CertPool var sessionTicketKey = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} func init() { @@ -28,6 +33,15 @@ func init() { if err != nil { log.Fatal(err) } + + privClient, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + log.Fatal(err) + } + clientCert, clientCertPool, err = helper.GenerateCertificate(privClient) + if err != nil { + log.Fatal(err) + } } type messageType uint8 @@ -67,6 +81,49 @@ func (m messageType) String() string { } } +func appendSuites(suites []uint16, rand uint8) []uint16 { + const ( + s1 = tls.TLS_AES_128_GCM_SHA256 + s2 = tls.TLS_AES_256_GCM_SHA384 + s3 = tls.TLS_CHACHA20_POLY1305_SHA256 + ) + switch rand % 4 { + default: + return suites + case 1: + return append(suites, s1) + case 2: + return append(suites, s2) + case 3: + return append(suites, s3) + } +} + +// consumes 2 bits +func getSuites(rand uint8) []uint16 { + suites := make([]uint16, 0, 3) + for i := 1; i <= 3; i++ { + suites = appendSuites(suites, rand>>i%4) + } + return suites +} + +// consumes 3 bits +func getClientAuth(rand uint8) tls.ClientAuthType { + switch rand { + default: + return tls.NoClientCert + case 0: + return tls.RequestClientCert + case 1: + return tls.RequireAnyClientCert + case 2: + return tls.VerifyClientCertIfGiven + case 3: + return tls.RequireAndVerifyClientCert + } +} + type chunk struct { data []byte encLevel protocol.EncryptionLevel @@ -125,6 +182,7 @@ func (r *runner) OnError(err error) { func (r *runner) DropKeys(protocol.EncryptionLevel) {} const alpn = "fuzzing" +const alpnWrong = "wrong" func toEncryptionLevel(n uint8) protocol.EncryptionLevel { switch n % 3 { @@ -158,8 +216,20 @@ func maxEncLevel(cs handshake.CryptoSetup, encLevel protocol.EncryptionLevel) pr } } +func getTransportParameters(seed uint8) *wire.TransportParameters { + const maxVarInt = math.MaxUint64 / 4 + r := mrand.New(mrand.NewSource(int64(seed))) + return &wire.TransportParameters{ + InitialMaxData: protocol.ByteCount(r.Int63n(maxVarInt)), + InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Int63n(maxVarInt)), + InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)), + InitialMaxStreamDataUni: protocol.ByteCount(r.Int63n(maxVarInt)), + } +} + // PrefixLen is the number of bytes used for configuration -const PrefixLen = 4 +const PrefixLen = 12 +const confLen = 5 // Fuzz fuzzes the TLS 1.3 handshake used by QUIC. //go:generate go run ./cmd/corpus.go @@ -167,18 +237,26 @@ func Fuzz(data []byte) int { if len(data) < PrefixLen { return -1 } - runConfig1 := data[0] - messageConfig1 := data[1] - runConfig2 := data[2] - messageConfig2 := data[3] - data = data[PrefixLen:] + dataLen := len(data) + var runConfig1, runConfig2 [confLen]byte + copy(runConfig1[:], data) + data = data[confLen:] + messageConfig1 := data[0] + data = data[1:] + copy(runConfig2[:], data) + data = data[confLen:] + messageConfig2 := data[0] + data = data[1:] + if dataLen != len(data)+PrefixLen { + panic("incorrect configuration") + } clientConf := &tls.Config{ ServerName: "localhost", NextProtos: []string{alpn}, RootCAs: certPool, } - useSessionTicketCache := helper.NthBit(runConfig1, 2) + useSessionTicketCache := helper.NthBit(runConfig1[0], 2) if useSessionTicketCache { clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5) } @@ -189,12 +267,89 @@ func Fuzz(data []byte) int { return runHandshake(runConfig2, messageConfig2, clientConf, data) } -func runHandshake(runConfig uint8, messageConfig uint8, clientConf *tls.Config, data []byte) int { - enable0RTTClient := helper.NthBit(runConfig, 0) - enable0RTTServer := helper.NthBit(runConfig, 1) - sendPostHandshakeMessageToClient := helper.NthBit(runConfig, 3) - sendPostHandshakeMessageToServer := helper.NthBit(runConfig, 4) - sendSessionTicket := helper.NthBit(runConfig, 5) +func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int { + serverConf := &tls.Config{ + Certificates: []tls.Certificate{*cert}, + NextProtos: []string{alpn}, + SessionTicketKey: sessionTicketKey, + } + + enable0RTTClient := helper.NthBit(runConfig[0], 0) + enable0RTTServer := helper.NthBit(runConfig[0], 1) + sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3) + sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4) + sendSessionTicket := helper.NthBit(runConfig[0], 5) + clientConf.CipherSuites = getSuites(runConfig[0] >> 6) + serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111) + serverConf.CipherSuites = getSuites(runConfig[1] >> 6) + serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3) + clientConf.PreferServerCipherSuites = helper.NthBit(runConfig[1], 4) + if helper.NthBit(runConfig[2], 0) { + clientConf.RootCAs = x509.NewCertPool() + } + if helper.NthBit(runConfig[2], 1) { + serverConf.ClientCAs = clientCertPool + } else { + serverConf.ClientCAs = x509.NewCertPool() + } + if helper.NthBit(runConfig[2], 2) { + serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + if helper.NthBit(runConfig[2], 3) { + return nil, errors.New("getting client config failed") + } + if helper.NthBit(runConfig[2], 4) { + return nil, nil + } + return serverConf, nil + } + } + if helper.NthBit(runConfig[2], 5) { + serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + if helper.NthBit(runConfig[2], 6) { + return nil, errors.New("getting certificate failed") + } + if helper.NthBit(runConfig[2], 7) { + return nil, nil + } + return clientCert, nil // this certificate will be invalid + } + } + if helper.NthBit(runConfig[3], 0) { + serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + if helper.NthBit(runConfig[3], 1) { + return errors.New("certificate verification failed") + } + return nil + } + } + if helper.NthBit(runConfig[3], 2) { + clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + if helper.NthBit(runConfig[3], 3) { + return errors.New("certificate verification failed") + } + return nil + } + } + if helper.NthBit(runConfig[3], 4) { + serverConf.NextProtos = []string{alpnWrong} + } + if helper.NthBit(runConfig[3], 5) { + serverConf.NextProtos = []string{alpnWrong, alpn} + } + if helper.NthBit(runConfig[3], 6) { + serverConf.KeyLogWriter = ioutil.Discard + } + if helper.NthBit(runConfig[3], 7) { + clientConf.KeyLogWriter = ioutil.Discard + } + clientTP := getTransportParameters(runConfig[4] & 0x3) + if helper.NthBit(runConfig[4], 3) { + clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 + } + serverTP := getTransportParameters(runConfig[4] & 0b00011000) + if helper.NthBit(runConfig[4], 3) { + serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 + } messageToReplace := messageConfig % 32 messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6) @@ -208,7 +363,7 @@ func runHandshake(runConfig uint8, messageConfig uint8, clientConf *tls.Config, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + clientTP, runner, clientConf, enable0RTTClient, @@ -224,13 +379,9 @@ func runHandshake(runConfig uint8, messageConfig uint8, clientConf *tls.Config, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + serverTP, runner, - &tls.Config{ - Certificates: []tls.Certificate{*cert}, - NextProtos: []string{alpn}, - SessionTicketKey: sessionTicketKey, - }, + serverConf, enable0RTTServer, utils.NewRTTStats(), nil, @@ -267,7 +418,7 @@ messageLoop: b := c.data encLevel := c.encLevel if len(b) > 0 && b[0] == messageToReplace { - fmt.Println("replacing message to the server", messageType(b[0]).String()) + fmt.Printf("replacing %s message to the server with %s\n", messageType(b[0]), messageType(data[0])) b = data encLevel = maxEncLevel(server, messageToReplaceEncLevel) } @@ -276,7 +427,7 @@ messageLoop: b := c.data encLevel := c.encLevel if len(b) > 0 && b[0] == messageToReplace { - fmt.Println("replacing message to the client", messageType(b[0]).String()) + fmt.Printf("replacing %s message to the client with %s\n", messageType(b[0]), messageType(data[0])) b = data encLevel = maxEncLevel(client, messageToReplaceEncLevel) } @@ -295,7 +446,26 @@ messageLoop: if runner.errored { return 0 } - if sendSessionTicket { + + sealer, err := client.Get1RTTSealer() + if err != nil { + panic("expected to get a 1-RTT sealer") + } + opener, err := server.Get1RTTOpener() + if err != nil { + panic("expected to get a 1-RTT opener") + } + const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." + encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar")) + decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar")) + if err != nil { + panic(fmt.Sprintf("Decrypting message failed: %s", err.Error())) + } + if string(decrypted) != msg { + panic("wrong message") + } + + if sendSessionTicket && !serverConf.SessionTicketsDisabled { ticket, err := server.GetSessionTicket() if err != nil { panic(err)