mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
use a synchronous API for the crypto setup (#3939)
This commit is contained in:
parent
2c0e7e02b0
commit
469a6153b6
18 changed files with 696 additions and 1032 deletions
|
@ -13,70 +13,12 @@ import (
|
|||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type chunk struct {
|
||||
data []byte
|
||||
encLevel protocol.EncryptionLevel
|
||||
}
|
||||
|
||||
type stream struct {
|
||||
chunkChan chan<- chunk
|
||||
encLevel protocol.EncryptionLevel
|
||||
}
|
||||
|
||||
func (s *stream) Write(b []byte) (int, error) {
|
||||
data := append([]byte{}, b...)
|
||||
select {
|
||||
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
|
||||
default:
|
||||
panic("chunkChan too small")
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
||||
chunkChan := make(chan chunk, 10)
|
||||
initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial}
|
||||
handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake}
|
||||
return chunkChan, initialStream, handshakeStream
|
||||
}
|
||||
|
||||
type handshakeRunner interface {
|
||||
OnReceivedParams(*wire.TransportParameters)
|
||||
OnHandshakeComplete()
|
||||
OnReceivedReadKeys()
|
||||
DropKeys(protocol.EncryptionLevel)
|
||||
}
|
||||
|
||||
type runner struct {
|
||||
handshakeComplete chan<- struct{}
|
||||
}
|
||||
|
||||
var _ handshakeRunner = &runner{}
|
||||
|
||||
func newRunner(handshakeComplete chan<- struct{}) *runner {
|
||||
return &runner{handshakeComplete: handshakeComplete}
|
||||
}
|
||||
|
||||
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
|
||||
func (r *runner) OnReceivedReadKeys() {}
|
||||
func (r *runner) OnHandshakeComplete() {
|
||||
close(r.handshakeComplete)
|
||||
}
|
||||
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
|
||||
|
||||
const alpn = "fuzz"
|
||||
|
||||
func main() {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
var client, server handshake.CryptoSetup
|
||||
clientHandshakeCompleted := make(chan struct{})
|
||||
client, _ = handshake.NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
nil,
|
||||
client := handshake.NewCryptoSetupClient(
|
||||
protocol.ConnectionID{},
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
||||
newRunner(clientHandshakeCompleted),
|
||||
&tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
ServerName: "localhost",
|
||||
|
@ -91,17 +33,11 @@ func main() {
|
|||
protocol.Version1,
|
||||
)
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
config := testdata.GetTLSConfig()
|
||||
config.NextProtos = []string{alpn}
|
||||
serverHandshakeCompleted := make(chan struct{})
|
||||
server = handshake.NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
server := handshake.NewCryptoSetupServer(
|
||||
protocol.ConnectionID{},
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
||||
newRunner(serverHandshakeCompleted),
|
||||
config,
|
||||
false,
|
||||
utils.NewRTTStats(),
|
||||
|
@ -118,29 +54,55 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
<-serverHandshakeCompleted
|
||||
<-clientHandshakeCompleted
|
||||
close(done)
|
||||
}()
|
||||
|
||||
var clientHandshakeComplete, serverHandshakeComplete bool
|
||||
var messages [][]byte
|
||||
messageLoop:
|
||||
for {
|
||||
select {
|
||||
case c := <-cChunkChan:
|
||||
messages = append(messages, c.data)
|
||||
if err := server.HandleMessage(c.data, c.encLevel); err != nil {
|
||||
log.Fatal(err)
|
||||
clientLoop:
|
||||
for {
|
||||
ev := client.NextEvent()
|
||||
//nolint:exhaustive // only need to process a few events
|
||||
switch ev.Kind {
|
||||
case handshake.EventNoEvent:
|
||||
break clientLoop
|
||||
case handshake.EventWriteInitialData:
|
||||
messages = append(messages, ev.Data)
|
||||
if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case handshake.EventWriteHandshakeData:
|
||||
messages = append(messages, ev.Data)
|
||||
if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case handshake.EventHandshakeComplete:
|
||||
clientHandshakeComplete = true
|
||||
}
|
||||
case c := <-sChunkChan:
|
||||
messages = append(messages, c.data)
|
||||
if err := client.HandleMessage(c.data, c.encLevel); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
serverLoop:
|
||||
for {
|
||||
ev := server.NextEvent()
|
||||
//nolint:exhaustive // only need to process a few events
|
||||
switch ev.Kind {
|
||||
case handshake.EventNoEvent:
|
||||
break serverLoop
|
||||
case handshake.EventWriteInitialData:
|
||||
messages = append(messages, ev.Data)
|
||||
if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case handshake.EventWriteHandshakeData:
|
||||
messages = append(messages, ev.Data)
|
||||
if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case handshake.EventHandshakeComplete:
|
||||
serverHandshakeComplete = true
|
||||
}
|
||||
case <-done:
|
||||
break messageLoop
|
||||
}
|
||||
|
||||
if serverHandshakeComplete && clientHandshakeComplete {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -126,57 +126,6 @@ func getClientAuth(rand uint8) tls.ClientAuthType {
|
|||
}
|
||||
}
|
||||
|
||||
type chunk struct {
|
||||
data []byte
|
||||
encLevel protocol.EncryptionLevel
|
||||
}
|
||||
|
||||
type stream struct {
|
||||
chunkChan chan<- chunk
|
||||
encLevel protocol.EncryptionLevel
|
||||
}
|
||||
|
||||
func (s *stream) Write(b []byte) (int, error) {
|
||||
data := append([]byte{}, b...)
|
||||
select {
|
||||
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
|
||||
default:
|
||||
panic("chunkChan too small")
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
||||
chunkChan := make(chan chunk, 10)
|
||||
initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial}
|
||||
handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake}
|
||||
return chunkChan, initialStream, handshakeStream
|
||||
}
|
||||
|
||||
type handshakeRunner interface {
|
||||
OnReceivedParams(*wire.TransportParameters)
|
||||
OnHandshakeComplete()
|
||||
OnReceivedReadKeys()
|
||||
DropKeys(protocol.EncryptionLevel)
|
||||
}
|
||||
|
||||
type runner struct {
|
||||
handshakeComplete chan<- struct{}
|
||||
}
|
||||
|
||||
var _ handshakeRunner = &runner{}
|
||||
|
||||
func newRunner(handshakeComplete chan<- struct{}) *runner {
|
||||
return &runner{handshakeComplete: handshakeComplete}
|
||||
}
|
||||
|
||||
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
|
||||
func (r *runner) OnReceivedReadKeys() {}
|
||||
func (r *runner) OnHandshakeComplete() {
|
||||
close(r.handshakeComplete)
|
||||
}
|
||||
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
|
||||
|
||||
const (
|
||||
alpn = "fuzzing"
|
||||
alpnWrong = "wrong"
|
||||
|
@ -193,28 +142,6 @@ func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
|
|||
}
|
||||
}
|
||||
|
||||
func maxEncLevel(cs handshake.CryptoSetup, encLevel protocol.EncryptionLevel) protocol.EncryptionLevel {
|
||||
//nolint:exhaustive
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return protocol.EncryptionInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
// Handshake opener not available. We can't possibly read a Handshake handshake message.
|
||||
if opener, err := cs.GetHandshakeOpener(); err != nil || opener == nil {
|
||||
return protocol.EncryptionInitial
|
||||
}
|
||||
return protocol.EncryptionHandshake
|
||||
case protocol.Encryption1RTT:
|
||||
// 1-RTT opener not available. We can't possibly read a post-handshake message.
|
||||
if opener, err := cs.Get1RTTOpener(); err != nil || opener == nil {
|
||||
return maxEncLevel(cs, protocol.EncryptionHandshake)
|
||||
}
|
||||
return protocol.Encryption1RTT
|
||||
default:
|
||||
panic("unexpected encryption level")
|
||||
}
|
||||
}
|
||||
|
||||
func getTransportParameters(seed uint8) *wire.TransportParameters {
|
||||
const maxVarInt = math.MaxUint64 / 4
|
||||
r := mrand.New(mrand.NewSource(int64(seed)))
|
||||
|
@ -357,16 +284,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
|
|||
messageToReplace := messageConfig % 32
|
||||
messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
|
||||
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
var client, server handshake.CryptoSetup
|
||||
clientHandshakeCompleted := make(chan struct{})
|
||||
client, _ = handshake.NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
nil,
|
||||
if len(data) == 0 {
|
||||
return -1
|
||||
}
|
||||
|
||||
client := handshake.NewCryptoSetupClient(
|
||||
protocol.ConnectionID{},
|
||||
clientTP,
|
||||
newRunner(clientHandshakeCompleted),
|
||||
clientConf,
|
||||
enable0RTTClient,
|
||||
utils.NewRTTStats(),
|
||||
|
@ -374,16 +298,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
|
|||
utils.DefaultLogger.WithPrefix("client"),
|
||||
protocol.Version1,
|
||||
)
|
||||
if err := client.StartHandshake(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
serverHandshakeCompleted := make(chan struct{})
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
server = handshake.NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
server := handshake.NewCryptoSetupServer(
|
||||
protocol.ConnectionID{},
|
||||
serverTP,
|
||||
newRunner(serverHandshakeCompleted),
|
||||
serverConf,
|
||||
enable0RTTServer,
|
||||
utils.NewRTTStats(),
|
||||
|
@ -391,57 +312,69 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
|
|||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.Version1,
|
||||
)
|
||||
|
||||
if len(data) == 0 {
|
||||
return -1
|
||||
}
|
||||
|
||||
if err := client.StartHandshake(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if err := server.StartHandshake(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
<-serverHandshakeCompleted
|
||||
<-clientHandshakeCompleted
|
||||
close(done)
|
||||
}()
|
||||
|
||||
messageLoop:
|
||||
var clientHandshakeComplete, serverHandshakeComplete bool
|
||||
for {
|
||||
select {
|
||||
case c := <-cChunkChan:
|
||||
b := c.data
|
||||
encLevel := c.encLevel
|
||||
if len(b) > 0 && b[0] == messageToReplace {
|
||||
fmt.Printf("replacing %s message to the server with %s\n", messageType(b[0]), messageType(data[0]))
|
||||
b = data
|
||||
encLevel = maxEncLevel(server, messageToReplaceEncLevel)
|
||||
clientLoop:
|
||||
for {
|
||||
var processedEvent bool
|
||||
ev := client.NextEvent()
|
||||
//nolint:exhaustive // only need to process a few events
|
||||
switch ev.Kind {
|
||||
case handshake.EventNoEvent:
|
||||
if !processedEvent && !clientHandshakeComplete { // handshake stuck
|
||||
return 1
|
||||
}
|
||||
break clientLoop
|
||||
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
|
||||
msg := ev.Data
|
||||
if msg[0] == messageToReplace {
|
||||
fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
|
||||
msg = data
|
||||
}
|
||||
if err := server.HandleMessage(msg, messageToReplaceEncLevel); err != nil {
|
||||
return 1
|
||||
}
|
||||
case handshake.EventHandshakeComplete:
|
||||
clientHandshakeComplete = true
|
||||
}
|
||||
if err := server.HandleMessage(b, encLevel); err != nil {
|
||||
break messageLoop
|
||||
processedEvent = true
|
||||
}
|
||||
|
||||
serverLoop:
|
||||
for {
|
||||
var processedEvent bool
|
||||
ev := server.NextEvent()
|
||||
//nolint:exhaustive // only need to process a few events
|
||||
switch ev.Kind {
|
||||
case handshake.EventNoEvent:
|
||||
if !processedEvent && !serverHandshakeComplete { // handshake stuck
|
||||
return 1
|
||||
}
|
||||
break serverLoop
|
||||
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
|
||||
msg := ev.Data
|
||||
if msg[0] == messageToReplace {
|
||||
fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
|
||||
msg = data
|
||||
}
|
||||
if err := client.HandleMessage(msg, messageToReplaceEncLevel); err != nil {
|
||||
return 1
|
||||
}
|
||||
case handshake.EventHandshakeComplete:
|
||||
serverHandshakeComplete = true
|
||||
}
|
||||
case c := <-sChunkChan:
|
||||
b := c.data
|
||||
encLevel := c.encLevel
|
||||
if len(b) > 0 && b[0] == messageToReplace {
|
||||
fmt.Printf("replacing %s message to the client with %s\n", messageType(b[0]), messageType(data[0]))
|
||||
b = data
|
||||
encLevel = maxEncLevel(client, messageToReplaceEncLevel)
|
||||
}
|
||||
if err := client.HandleMessage(b, encLevel); err != nil {
|
||||
break messageLoop
|
||||
}
|
||||
case <-done: // test done
|
||||
break messageLoop
|
||||
processedEvent = true
|
||||
}
|
||||
|
||||
if serverHandshakeComplete && clientHandshakeComplete {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
<-done
|
||||
_ = client.ConnectionState()
|
||||
_ = server.ConnectionState()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue