mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-02 03:37:37 +03:00
423 lines
12 KiB
Go
423 lines
12 KiB
Go
package handshake
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math"
|
|
mrand "math/rand"
|
|
"net"
|
|
"time"
|
|
|
|
tls "github.com/refraction-networking/utls"
|
|
|
|
"github.com/refraction-networking/uquic/fuzzing/internal/helper"
|
|
"github.com/refraction-networking/uquic/internal/handshake"
|
|
"github.com/refraction-networking/uquic/internal/protocol"
|
|
"github.com/refraction-networking/uquic/internal/qtls"
|
|
"github.com/refraction-networking/uquic/internal/utils"
|
|
"github.com/refraction-networking/uquic/internal/wire"
|
|
)
|
|
|
|
var (
|
|
cert, clientCert *tls.Certificate
|
|
certPool, clientCertPool *x509.CertPool
|
|
sessionTicketKey = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
|
|
)
|
|
|
|
func init() {
|
|
priv, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
cert, certPool, err = helper.GenerateCertificate(priv)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
privClient, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
clientCert, clientCertPool, err = helper.GenerateCertificate(privClient)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// messageType is a TLS handshake message type, as found in the first byte of
// a handshake message.
type messageType uint8

// TLS handshake message types.
const (
	typeClientHello         messageType = 1
	typeServerHello         messageType = 2
	typeNewSessionTicket    messageType = 4
	typeEncryptedExtensions messageType = 8
	typeCertificate         messageType = 11
	typeCertificateRequest  messageType = 13
	typeCertificateVerify   messageType = 15
	typeFinished            messageType = 20
)

// messageTypeNames holds the display name of every known message type.
var messageTypeNames = map[messageType]string{
	typeClientHello:         "ClientHello",
	typeServerHello:         "ServerHello",
	typeNewSessionTicket:    "NewSessionTicket",
	typeEncryptedExtensions: "EncryptedExtensions",
	typeCertificate:         "Certificate",
	typeCertificateRequest:  "CertificateRequest",
	typeCertificateVerify:   "CertificateVerify",
	typeFinished:            "Finished",
}

// String returns the message type's name, or a description containing the
// numeric value for types not listed above.
func (m messageType) String() string {
	if name, known := messageTypeNames[m]; known {
		return name
	}
	return fmt.Sprintf("unknown message type: %d", m)
}
|
|
|
|
// consumes 3 bits
|
|
func getClientAuth(rand uint8) tls.ClientAuthType {
|
|
switch rand {
|
|
default:
|
|
return tls.NoClientCert
|
|
case 0:
|
|
return tls.RequestClientCert
|
|
case 1:
|
|
return tls.RequireAnyClientCert
|
|
case 2:
|
|
return tls.VerifyClientCertIfGiven
|
|
case 3:
|
|
return tls.RequireAndVerifyClientCert
|
|
}
|
|
}
|
|
|
|
// alpn is the application protocol offered by the client and, by default,
// by the server.
const alpn = "fuzzing"

// alpnWrong is an application protocol the client never offers; configuring
// the server with it provokes (or, alongside alpn, tests) ALPN negotiation.
const alpnWrong = "wrong"
|
|
|
|
func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
|
|
switch n % 3 {
|
|
default:
|
|
return protocol.EncryptionInitial
|
|
case 1:
|
|
return protocol.EncryptionHandshake
|
|
case 2:
|
|
return protocol.Encryption1RTT
|
|
}
|
|
}
|
|
|
|
func getTransportParameters(seed uint8) *wire.TransportParameters {
|
|
const maxVarInt = math.MaxUint64 / 4
|
|
r := mrand.New(mrand.NewSource(int64(seed)))
|
|
return &wire.TransportParameters{
|
|
ActiveConnectionIDLimit: 2,
|
|
InitialMaxData: protocol.ByteCount(r.Int63n(maxVarInt)),
|
|
InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Int63n(maxVarInt)),
|
|
InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)),
|
|
InitialMaxStreamDataUni: protocol.ByteCount(r.Int63n(maxVarInt)),
|
|
}
|
|
}
|
|
|
|
const (
	// confLen is the number of run-configuration bytes used by one handshake run.
	confLen = 5
	// PrefixLen is the number of bytes used for configuration: for each of
	// the two handshake runs, confLen run-configuration bytes followed by one
	// message-configuration byte.
	PrefixLen = 2 * (confLen + 1)
)
|
|
|
|
// Fuzz fuzzes the TLS 1.3 handshake used by QUIC.
|
|
//
|
|
//go:generate go run ./cmd/corpus.go
|
|
func Fuzz(data []byte) int {
|
|
if len(data) < PrefixLen {
|
|
return -1
|
|
}
|
|
dataLen := len(data)
|
|
var runConfig1, runConfig2 [confLen]byte
|
|
copy(runConfig1[:], data)
|
|
data = data[confLen:]
|
|
messageConfig1 := data[0]
|
|
data = data[1:]
|
|
copy(runConfig2[:], data)
|
|
data = data[confLen:]
|
|
messageConfig2 := data[0]
|
|
data = data[1:]
|
|
if dataLen != len(data)+PrefixLen {
|
|
panic("incorrect configuration")
|
|
}
|
|
|
|
clientConf := &tls.Config{
|
|
MinVersion: tls.VersionTLS13,
|
|
ServerName: "localhost",
|
|
NextProtos: []string{alpn},
|
|
RootCAs: certPool,
|
|
}
|
|
useSessionTicketCache := helper.NthBit(runConfig1[0], 2)
|
|
if useSessionTicketCache {
|
|
clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5)
|
|
}
|
|
|
|
if val := runHandshake(runConfig1, messageConfig1, clientConf, data); val != 1 {
|
|
return val
|
|
}
|
|
return runHandshake(runConfig2, messageConfig2, clientConf, data)
|
|
}
|
|
|
|
func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int {
|
|
serverConf := &tls.Config{
|
|
MinVersion: tls.VersionTLS13,
|
|
Certificates: []tls.Certificate{*cert},
|
|
NextProtos: []string{alpn},
|
|
SessionTicketKey: sessionTicketKey,
|
|
}
|
|
|
|
// This sets the cipher suite for both client and server.
|
|
// The way crypto/tls is designed doesn't allow us to set different cipher suites for client and server.
|
|
resetCipherSuite := func() {}
|
|
switch (runConfig[0] >> 6) % 4 {
|
|
case 0:
|
|
resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_128_GCM_SHA256)
|
|
case 1:
|
|
resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_256_GCM_SHA384)
|
|
case 3:
|
|
resetCipherSuite = qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
|
|
default:
|
|
}
|
|
defer resetCipherSuite()
|
|
|
|
enable0RTTClient := helper.NthBit(runConfig[0], 0)
|
|
enable0RTTServer := helper.NthBit(runConfig[0], 1)
|
|
sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3)
|
|
sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4)
|
|
sendSessionTicket := helper.NthBit(runConfig[0], 5)
|
|
serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111)
|
|
serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3)
|
|
if helper.NthBit(runConfig[2], 0) {
|
|
clientConf.RootCAs = x509.NewCertPool()
|
|
}
|
|
if helper.NthBit(runConfig[2], 1) {
|
|
serverConf.ClientCAs = clientCertPool
|
|
} else {
|
|
serverConf.ClientCAs = x509.NewCertPool()
|
|
}
|
|
if helper.NthBit(runConfig[2], 2) {
|
|
serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
if helper.NthBit(runConfig[2], 3) {
|
|
return nil, errors.New("getting client config failed")
|
|
}
|
|
if helper.NthBit(runConfig[2], 4) {
|
|
return nil, nil
|
|
}
|
|
return serverConf, nil
|
|
}
|
|
}
|
|
if helper.NthBit(runConfig[2], 5) {
|
|
serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
if helper.NthBit(runConfig[2], 6) {
|
|
return nil, errors.New("getting certificate failed")
|
|
}
|
|
if helper.NthBit(runConfig[2], 7) {
|
|
return nil, nil
|
|
}
|
|
return clientCert, nil // this certificate will be invalid
|
|
}
|
|
}
|
|
if helper.NthBit(runConfig[3], 0) {
|
|
serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
if helper.NthBit(runConfig[3], 1) {
|
|
return errors.New("certificate verification failed")
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
if helper.NthBit(runConfig[3], 2) {
|
|
clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
if helper.NthBit(runConfig[3], 3) {
|
|
return errors.New("certificate verification failed")
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
if helper.NthBit(runConfig[3], 4) {
|
|
serverConf.NextProtos = []string{alpnWrong}
|
|
}
|
|
if helper.NthBit(runConfig[3], 5) {
|
|
serverConf.NextProtos = []string{alpnWrong, alpn}
|
|
}
|
|
if helper.NthBit(runConfig[3], 6) {
|
|
serverConf.KeyLogWriter = io.Discard
|
|
}
|
|
if helper.NthBit(runConfig[3], 7) {
|
|
clientConf.KeyLogWriter = io.Discard
|
|
}
|
|
clientTP := getTransportParameters(runConfig[4] & 0x3)
|
|
if helper.NthBit(runConfig[4], 3) {
|
|
clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
|
|
}
|
|
serverTP := getTransportParameters(runConfig[4] & 0b00011000)
|
|
if helper.NthBit(runConfig[4], 3) {
|
|
serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
|
|
}
|
|
|
|
messageToReplace := messageConfig % 32
|
|
messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
|
|
|
|
if len(data) == 0 {
|
|
return -1
|
|
}
|
|
|
|
client := handshake.NewCryptoSetupClient(
|
|
protocol.ConnectionID{},
|
|
clientTP,
|
|
clientConf,
|
|
enable0RTTClient,
|
|
utils.NewRTTStats(),
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("client"),
|
|
protocol.Version1,
|
|
)
|
|
if err := client.StartHandshake(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer client.Close()
|
|
|
|
server := handshake.NewCryptoSetupServer(
|
|
protocol.ConnectionID{},
|
|
&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
|
|
&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
|
|
serverTP,
|
|
serverConf,
|
|
enable0RTTServer,
|
|
utils.NewRTTStats(),
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
protocol.Version1,
|
|
)
|
|
if err := server.StartHandshake(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer server.Close()
|
|
|
|
var clientHandshakeComplete, serverHandshakeComplete bool
|
|
for {
|
|
var processedEvent bool
|
|
clientLoop:
|
|
for {
|
|
ev := client.NextEvent()
|
|
//nolint:exhaustive // only need to process a few events
|
|
switch ev.Kind {
|
|
case handshake.EventNoEvent:
|
|
if !processedEvent && !clientHandshakeComplete { // handshake stuck
|
|
return 1
|
|
}
|
|
break clientLoop
|
|
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
|
|
msg := ev.Data
|
|
encLevel := protocol.EncryptionInitial
|
|
if ev.Kind == handshake.EventWriteHandshakeData {
|
|
encLevel = protocol.EncryptionHandshake
|
|
}
|
|
if msg[0] == messageToReplace {
|
|
fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
|
|
msg = data
|
|
encLevel = messageToReplaceEncLevel
|
|
}
|
|
if err := server.HandleMessage(msg, encLevel); err != nil {
|
|
return 1
|
|
}
|
|
case handshake.EventHandshakeComplete:
|
|
clientHandshakeComplete = true
|
|
}
|
|
processedEvent = true
|
|
}
|
|
|
|
processedEvent = false
|
|
serverLoop:
|
|
for {
|
|
ev := server.NextEvent()
|
|
//nolint:exhaustive // only need to process a few events
|
|
switch ev.Kind {
|
|
case handshake.EventNoEvent:
|
|
if !processedEvent && !serverHandshakeComplete { // handshake stuck
|
|
return 1
|
|
}
|
|
break serverLoop
|
|
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
|
|
encLevel := protocol.EncryptionInitial
|
|
if ev.Kind == handshake.EventWriteHandshakeData {
|
|
encLevel = protocol.EncryptionHandshake
|
|
}
|
|
msg := ev.Data
|
|
if msg[0] == messageToReplace {
|
|
fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
|
|
msg = data
|
|
encLevel = messageToReplaceEncLevel
|
|
}
|
|
if err := client.HandleMessage(msg, encLevel); err != nil {
|
|
return 1
|
|
}
|
|
case handshake.EventHandshakeComplete:
|
|
serverHandshakeComplete = true
|
|
}
|
|
processedEvent = true
|
|
}
|
|
|
|
if serverHandshakeComplete && clientHandshakeComplete {
|
|
break
|
|
}
|
|
}
|
|
|
|
_ = client.ConnectionState()
|
|
_ = server.ConnectionState()
|
|
|
|
sealer, err := client.Get1RTTSealer()
|
|
if err != nil {
|
|
panic("expected to get a 1-RTT sealer")
|
|
}
|
|
opener, err := server.Get1RTTOpener()
|
|
if err != nil {
|
|
panic("expected to get a 1-RTT opener")
|
|
}
|
|
const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
|
encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar"))
|
|
decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar"))
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Decrypting message failed: %s", err.Error()))
|
|
}
|
|
if string(decrypted) != msg {
|
|
panic("wrong message")
|
|
}
|
|
|
|
if sendSessionTicket && !serverConf.SessionTicketsDisabled {
|
|
ticket, err := server.GetSessionTicket()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if ticket == nil {
|
|
panic("empty ticket")
|
|
}
|
|
client.HandleMessage(ticket, protocol.Encryption1RTT)
|
|
}
|
|
|
|
if sendPostHandshakeMessageToClient {
|
|
fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel)
|
|
client.HandleMessage(data, messageToReplaceEncLevel)
|
|
}
|
|
if sendPostHandshakeMessageToServer {
|
|
fmt.Println("sending post handshake message to the server at", messageToReplaceEncLevel)
|
|
server.HandleMessage(data, messageToReplaceEncLevel)
|
|
}
|
|
|
|
return 1
|
|
}
|