use a synchronous API for the crypto setup (#3939)

This commit is contained in:
Marten Seemann 2023-07-21 10:00:42 -07:00 committed by GitHub
parent 2c0e7e02b0
commit 469a6153b6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 696 additions and 1032 deletions

View file

@ -13,70 +13,12 @@ import (
"github.com/quic-go/quic-go/internal/wire"
)
type chunk struct {
data []byte
encLevel protocol.EncryptionLevel
}
type stream struct {
chunkChan chan<- chunk
encLevel protocol.EncryptionLevel
}
func (s *stream) Write(b []byte) (int, error) {
data := append([]byte{}, b...)
select {
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
default:
panic("chunkChan too small")
}
return len(b), nil
}
func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 10)
initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial}
handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake}
return chunkChan, initialStream, handshakeStream
}
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnReceivedReadKeys()
DropKeys(protocol.EncryptionLevel)
}
type runner struct {
handshakeComplete chan<- struct{}
}
var _ handshakeRunner = &runner{}
func newRunner(handshakeComplete chan<- struct{}) *runner {
return &runner{handshakeComplete: handshakeComplete}
}
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
func (r *runner) OnReceivedReadKeys() {}
func (r *runner) OnHandshakeComplete() {
close(r.handshakeComplete)
}
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
const alpn = "fuzz"
func main() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
var client, server handshake.CryptoSetup
clientHandshakeCompleted := make(chan struct{})
client, _ = handshake.NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
client := handshake.NewCryptoSetupClient(
protocol.ConnectionID{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
newRunner(clientHandshakeCompleted),
&tls.Config{
MinVersion: tls.VersionTLS13,
ServerName: "localhost",
@ -91,17 +33,11 @@ func main() {
protocol.Version1,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
config := testdata.GetTLSConfig()
config.NextProtos = []string{alpn}
serverHandshakeCompleted := make(chan struct{})
server = handshake.NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
server := handshake.NewCryptoSetupServer(
protocol.ConnectionID{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
newRunner(serverHandshakeCompleted),
config,
false,
utils.NewRTTStats(),
@ -118,29 +54,55 @@ func main() {
log.Fatal(err)
}
done := make(chan struct{})
go func() {
<-serverHandshakeCompleted
<-clientHandshakeCompleted
close(done)
}()
var clientHandshakeComplete, serverHandshakeComplete bool
var messages [][]byte
messageLoop:
for {
select {
case c := <-cChunkChan:
messages = append(messages, c.data)
if err := server.HandleMessage(c.data, c.encLevel); err != nil {
log.Fatal(err)
clientLoop:
for {
ev := client.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
break clientLoop
case handshake.EventWriteInitialData:
messages = append(messages, ev.Data)
if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
log.Fatal(err)
}
case handshake.EventWriteHandshakeData:
messages = append(messages, ev.Data)
if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
log.Fatal(err)
}
case handshake.EventHandshakeComplete:
clientHandshakeComplete = true
}
case c := <-sChunkChan:
messages = append(messages, c.data)
if err := client.HandleMessage(c.data, c.encLevel); err != nil {
log.Fatal(err)
}
serverLoop:
for {
ev := server.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
break serverLoop
case handshake.EventWriteInitialData:
messages = append(messages, ev.Data)
if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
log.Fatal(err)
}
case handshake.EventWriteHandshakeData:
messages = append(messages, ev.Data)
if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
log.Fatal(err)
}
case handshake.EventHandshakeComplete:
serverHandshakeComplete = true
}
case <-done:
break messageLoop
}
if serverHandshakeComplete && clientHandshakeComplete {
break
}
}

View file

@ -126,57 +126,6 @@ func getClientAuth(rand uint8) tls.ClientAuthType {
}
}
type chunk struct {
data []byte
encLevel protocol.EncryptionLevel
}
type stream struct {
chunkChan chan<- chunk
encLevel protocol.EncryptionLevel
}
func (s *stream) Write(b []byte) (int, error) {
data := append([]byte{}, b...)
select {
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
default:
panic("chunkChan too small")
}
return len(b), nil
}
func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 10)
initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial}
handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake}
return chunkChan, initialStream, handshakeStream
}
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnReceivedReadKeys()
DropKeys(protocol.EncryptionLevel)
}
type runner struct {
handshakeComplete chan<- struct{}
}
var _ handshakeRunner = &runner{}
func newRunner(handshakeComplete chan<- struct{}) *runner {
return &runner{handshakeComplete: handshakeComplete}
}
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
func (r *runner) OnReceivedReadKeys() {}
func (r *runner) OnHandshakeComplete() {
close(r.handshakeComplete)
}
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
const (
alpn = "fuzzing"
alpnWrong = "wrong"
@ -193,28 +142,6 @@ func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
}
}
func maxEncLevel(cs handshake.CryptoSetup, encLevel protocol.EncryptionLevel) protocol.EncryptionLevel {
//nolint:exhaustive
switch encLevel {
case protocol.EncryptionInitial:
return protocol.EncryptionInitial
case protocol.EncryptionHandshake:
// Handshake opener not available. We can't possibly read a Handshake handshake message.
if opener, err := cs.GetHandshakeOpener(); err != nil || opener == nil {
return protocol.EncryptionInitial
}
return protocol.EncryptionHandshake
case protocol.Encryption1RTT:
// 1-RTT opener not available. We can't possibly read a post-handshake message.
if opener, err := cs.Get1RTTOpener(); err != nil || opener == nil {
return maxEncLevel(cs, protocol.EncryptionHandshake)
}
return protocol.Encryption1RTT
default:
panic("unexpected encryption level")
}
}
func getTransportParameters(seed uint8) *wire.TransportParameters {
const maxVarInt = math.MaxUint64 / 4
r := mrand.New(mrand.NewSource(int64(seed)))
@ -357,16 +284,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
messageToReplace := messageConfig % 32
messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
var client, server handshake.CryptoSetup
clientHandshakeCompleted := make(chan struct{})
client, _ = handshake.NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
if len(data) == 0 {
return -1
}
client := handshake.NewCryptoSetupClient(
protocol.ConnectionID{},
clientTP,
newRunner(clientHandshakeCompleted),
clientConf,
enable0RTTClient,
utils.NewRTTStats(),
@ -374,16 +298,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
utils.DefaultLogger.WithPrefix("client"),
protocol.Version1,
)
if err := client.StartHandshake(); err != nil {
log.Fatal(err)
}
serverHandshakeCompleted := make(chan struct{})
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
server = handshake.NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
server := handshake.NewCryptoSetupServer(
protocol.ConnectionID{},
serverTP,
newRunner(serverHandshakeCompleted),
serverConf,
enable0RTTServer,
utils.NewRTTStats(),
@ -391,57 +312,69 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
if len(data) == 0 {
return -1
}
if err := client.StartHandshake(); err != nil {
log.Fatal(err)
}
if err := server.StartHandshake(); err != nil {
log.Fatal(err)
}
done := make(chan struct{})
go func() {
<-serverHandshakeCompleted
<-clientHandshakeCompleted
close(done)
}()
messageLoop:
var clientHandshakeComplete, serverHandshakeComplete bool
for {
select {
case c := <-cChunkChan:
b := c.data
encLevel := c.encLevel
if len(b) > 0 && b[0] == messageToReplace {
fmt.Printf("replacing %s message to the server with %s\n", messageType(b[0]), messageType(data[0]))
b = data
encLevel = maxEncLevel(server, messageToReplaceEncLevel)
clientLoop:
for {
var processedEvent bool
ev := client.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
if !processedEvent && !clientHandshakeComplete { // handshake stuck
return 1
}
break clientLoop
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
msg := ev.Data
if msg[0] == messageToReplace {
fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
msg = data
}
if err := server.HandleMessage(msg, messageToReplaceEncLevel); err != nil {
return 1
}
case handshake.EventHandshakeComplete:
clientHandshakeComplete = true
}
if err := server.HandleMessage(b, encLevel); err != nil {
break messageLoop
processedEvent = true
}
serverLoop:
for {
var processedEvent bool
ev := server.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
if !processedEvent && !serverHandshakeComplete { // handshake stuck
return 1
}
break serverLoop
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
msg := ev.Data
if msg[0] == messageToReplace {
fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
msg = data
}
if err := client.HandleMessage(msg, messageToReplaceEncLevel); err != nil {
return 1
}
case handshake.EventHandshakeComplete:
serverHandshakeComplete = true
}
case c := <-sChunkChan:
b := c.data
encLevel := c.encLevel
if len(b) > 0 && b[0] == messageToReplace {
fmt.Printf("replacing %s message to the client with %s\n", messageType(b[0]), messageType(data[0]))
b = data
encLevel = maxEncLevel(client, messageToReplaceEncLevel)
}
if err := client.HandleMessage(b, encLevel); err != nil {
break messageLoop
}
case <-done: // test done
break messageLoop
processedEvent = true
}
if serverHandshakeComplete && clientHandshakeComplete {
break
}
}
<-done
_ = client.ConnectionState()
_ = server.ConnectionState()