use the new crypto/tls QUIC Transport (#3860)

This commit is contained in:
Marten Seemann 2023-07-01 11:15:00 -07:00 committed by GitHub
parent 4998733ae1
commit 3d89e545d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 2197 additions and 1509 deletions

View file

@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@ -23,12 +22,10 @@ import (
. "github.com/onsi/gomega"
)
var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
}
const (
typeClientHello = 1
typeNewSessionTicket = 4
)
type chunk struct {
data []byte
@ -80,54 +77,7 @@ var _ = Describe("Crypto Setup TLS", func() {
}
})
It("returns Handshake() when an error occurs in qtls", func() {
sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream := initStreams()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
ErrorMessage: "local error: tls: unexpected message",
})))
close(done)
}()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
handledMessage := make(chan struct{})
go func() {
defer GinkgoRecover()
server.HandleMessage(fakeCH, protocol.EncryptionInitial)
close(handledMessage)
}()
Eventually(handledMessage).Should(BeClosed())
Eventually(done).Should(BeClosed())
})
It("handles qtls errors occurring before during ClientHello generation", func() {
sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream := initStreams()
tlsConf := testdata.GetTLSConfig()
tlsConf.InsecureSkipVerify = true
@ -135,11 +85,10 @@ var _ = Describe("Crypto Setup TLS", func() {
cl, _ := NewCryptoSetupClient(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
runner,
NewMockHandshakeRunner(mockCtrl),
tlsConf,
false,
&utils.RTTStats{},
@ -148,32 +97,21 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
cl.RunHandshake()
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{
ErrorCode: qerr.InternalError,
ErrorMessage: "tls: invalid NextProtos value",
})))
}))
})
It("errors when a message is received at the wrong encryption level", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
@ -184,90 +122,13 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
close(done)
}()
Expect(server.StartHandshake()).To(Succeed())
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake",
})))
// make the go routine return
Expect(server.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("returns Handshake() when handling a message fails", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
runner,
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
var err error
Expect(sErrChan).To(Receive(&err))
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
close(done)
}()
fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...)
server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level
Eventually(done).Should(BeClosed())
})
It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream := initStreams()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
close(done)
}()
Expect(server.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...)
// wrong encryption level
err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
})
Context("doing the handshake", func() {
@ -297,55 +158,32 @@ var _ = Describe("Crypto Setup TLS", func() {
return rttStats
}
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
server CryptoSetup, sChunkChan <-chan chunk,
) {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
for {
select {
case c := <-cChunkChan:
msgType := messageType(c.data[0])
finished := server.HandleMessage(c.data, c.encLevel)
if msgType == typeFinished {
Expect(finished).To(BeTrue())
} else if msgType == typeClientHello {
// If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys.
_, err := server.GetHandshakeOpener()
Expect(finished).To(Equal(err == nil))
} else {
Expect(finished).To(BeFalse())
}
case c := <-sChunkChan:
msgType := messageType(c.data[0])
finished := client.HandleMessage(c.data, c.encLevel)
if msgType == typeFinished {
Expect(finished).To(BeTrue())
} else if msgType == typeServerHello {
Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom)))
} else {
Expect(finished).To(BeFalse())
}
case <-done: // handshake complete
return
}
}
}()
handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) {
Expect(client.StartHandshake()).To(Succeed())
Expect(server.StartHandshake()).To(Succeed())
go func() {
defer GinkgoRecover()
defer close(done)
server.RunHandshake()
ticket, err := server.GetSessionTicket()
Expect(err).ToNot(HaveOccurred())
if ticket != nil {
client.HandleMessage(ticket, protocol.Encryption1RTT)
for {
select {
case c := <-cChunkChan:
Expect(server.HandleMessage(c.data, c.encLevel)).To(Succeed())
continue
default:
}
}()
select {
case c := <-sChunkChan:
Expect(client.HandleMessage(c.data, c.encLevel)).To(Succeed())
continue
default:
}
// no more messages to send from client and server. Handshake complete?
break
}
client.RunHandshake()
Eventually(done).Should(BeClosed())
ticket, err := server.GetSessionTicket()
Expect(err).ToNot(HaveOccurred())
if ticket != nil {
Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
}
}
handshakeWithTLSConf := func(
@ -359,15 +197,14 @@ var _ = Describe("Crypto Setup TLS", func() {
cErrChan := make(chan error, 1)
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
cRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
client, clientHelloWrittenChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
clientTransportParameters,
cRunner,
clientConf,
@ -383,7 +220,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sErrChan := make(chan error, 1)
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
sRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
if serverTransportParameters.StatelessResetToken == nil {
var token protocol.StatelessResetToken
@ -392,9 +229,8 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
serverTransportParameters,
sRunner,
serverConf,
@ -462,9 +298,8 @@ var _ = Describe("Crypto Setup TLS", func() {
client, chChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
runner,
&tls.Config{InsecureSkipVerify: true},
@ -475,24 +310,15 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RunHandshake()
close(done)
}()
Expect(client.StartHandshake()).To(Succeed())
var ch chunk
Eventually(cChunkChan).Should(Receive(&ch))
Eventually(chChan).Should(Receive(BeNil()))
// make sure the whole ClientHello was written
Expect(len(ch.data)).To(BeNumerically(">=", 4))
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
Expect(ch.data[0]).To(BeEquivalentTo(typeClientHello))
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
Expect(len(ch.data) - 4).To(Equal(length))
// make the go routine return
Expect(client.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("receives transport parameters", func() {
@ -500,14 +326,14 @@ var _ = Describe("Crypto Setup TLS", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 0x42 * time.Second}
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp })
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
cTransportParameters,
cRunner,
clientConf,
@ -521,6 +347,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
var token protocol.StatelessResetToken
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp })
sRunner.EXPECT().OnHandshakeComplete()
sTransportParameters := &wire.TransportParameters{
@ -531,9 +358,8 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
sTransportParameters,
sRunner,
serverConf,
@ -561,13 +387,13 @@ var _ = Describe("Crypto Setup TLS", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
cRunner,
clientConf,
@ -581,14 +407,14 @@ var _ = Describe("Crypto Setup TLS", func() {
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
sRunner.EXPECT().OnHandshakeComplete()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
@ -608,25 +434,23 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(done).Should(BeClosed())
// inject an invalid session ticket
cRunner.EXPECT().OnError(&qerr.TransportError{
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake",
})
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.EncryptionHandshake)
err := client.HandleMessage(b, protocol.EncryptionHandshake)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
})
It("errors when handling the NewSessionTicket fails", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
cRunner,
clientConf,
@ -640,14 +464,14 @@ var _ = Describe("Crypto Setup TLS", func() {
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
sRunner.EXPECT().OnHandshakeComplete()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
@ -667,12 +491,10 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(done).Should(BeClosed())
// inject an invalid session ticket
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
})
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.Encryption1RTT)
err := client.HandleMessage(b, protocol.Encryption1RTT)
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
})
It("uses session resumption", func() {
@ -785,7 +607,6 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}
@ -840,7 +661,6 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}