mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
fix error handling in the TLS crypto setup
There are two ways that an error can occur during the handshake: 1. as a return value from qtls.Handshake() 2. when new data is passed to the crypto setup via HandleData() We need to make sure that the RunHandshake() as well as HandleData() both return if an error occurs at any step during the handshake.
This commit is contained in:
parent
82508f1562
commit
2dbc29a5bd
6 changed files with 375 additions and 242 deletions
|
@ -8,7 +8,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type cryptoDataHandler interface {
|
type cryptoDataHandler interface {
|
||||||
HandleData([]byte, protocol.EncryptionLevel) error
|
HandleData([]byte, protocol.EncryptionLevel)
|
||||||
}
|
}
|
||||||
|
|
||||||
type cryptoStreamManager struct {
|
type cryptoStreamManager struct {
|
||||||
|
@ -48,8 +48,6 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
|
||||||
if data == nil {
|
if data == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err := m.cryptoHandler.HandleData(data, encLevel); err != nil {
|
m.cryptoHandler.HandleData(data, encLevel)
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
@ -54,12 +52,4 @@ var _ = Describe("Crypto Stream Manager", func() {
|
||||||
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake)
|
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake)
|
||||||
Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed())
|
Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns the error if handling crypto data fails", func() {
|
|
||||||
testErr := errors.New("test error")
|
|
||||||
f := &wire.CryptoFrame{Data: []byte("foobar")}
|
|
||||||
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake).Return(testErr)
|
|
||||||
err := csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
|
@ -55,9 +55,20 @@ type cryptoSetupTLS struct {
|
||||||
readEncLevel protocol.EncryptionLevel
|
readEncLevel protocol.EncryptionLevel
|
||||||
writeEncLevel protocol.EncryptionLevel
|
writeEncLevel protocol.EncryptionLevel
|
||||||
|
|
||||||
handleParamsCallback func(*TransportParameters)
|
handleParamsCallback func(*TransportParameters)
|
||||||
handshakeEvent chan<- struct{}
|
|
||||||
handshakeComplete chan<- struct{}
|
// There are two ways that an error can occur during the handshake:
|
||||||
|
// 1. as a return value from qtls.Handshake()
|
||||||
|
// 2. when new data is passed to the crypto setup via HandleData()
|
||||||
|
// handshakeErrChan is closed when qtls.Handshake() errors
|
||||||
|
handshakeErrChan chan struct{}
|
||||||
|
// HandleData() sends errors on the messageErrChan
|
||||||
|
messageErrChan chan error
|
||||||
|
// handshakeEvent signals a change of encryption level to the session
|
||||||
|
handshakeEvent chan<- struct{}
|
||||||
|
// handshakeComplete is closed when the handshake completes
|
||||||
|
handshakeComplete chan<- struct{}
|
||||||
|
// transport parameters are sent on the receivedTransportParams, as soon as they are received
|
||||||
receivedTransportParams <-chan TransportParameters
|
receivedTransportParams <-chan TransportParameters
|
||||||
|
|
||||||
clientHelloWritten bool
|
clientHelloWritten bool
|
||||||
|
@ -190,6 +201,8 @@ func newCryptoSetupTLS(
|
||||||
handshakeComplete: handshakeComplete,
|
handshakeComplete: handshakeComplete,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
perspective: perspective,
|
perspective: perspective,
|
||||||
|
handshakeErrChan: make(chan struct{}),
|
||||||
|
messageErrChan: make(chan error, 1),
|
||||||
clientHelloWrittenChan: make(chan struct{}),
|
clientHelloWrittenChan: make(chan struct{}),
|
||||||
messageChan: make(chan []byte, 100),
|
messageChan: make(chan []byte, 100),
|
||||||
receivedReadKey: make(chan struct{}),
|
receivedReadKey: make(chan struct{}),
|
||||||
|
@ -229,16 +242,37 @@ func (h *cryptoSetupTLS) RunHandshake() error {
|
||||||
case protocol.PerspectiveServer:
|
case protocol.PerspectiveServer:
|
||||||
conn = qtls.Server(nil, h.tlsConf)
|
conn = qtls.Server(nil, h.tlsConf)
|
||||||
}
|
}
|
||||||
if err := conn.Handshake(); err != nil {
|
// Handle errors that might occur when HandleData() is called.
|
||||||
close(h.receivedReadKey)
|
handshakeErrChan := make(chan error, 1)
|
||||||
close(h.receivedWriteKey)
|
handshakeComplete := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
if err := conn.Handshake(); err != nil {
|
||||||
|
handshakeErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
close(handshakeComplete)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-handshakeComplete: // return when the handshake is done
|
||||||
|
close(h.handshakeComplete)
|
||||||
|
return nil
|
||||||
|
case err := <-handshakeErrChan:
|
||||||
|
// if handleMessageFor{server,client} are waiting for some qtls action, make them return
|
||||||
|
close(h.handshakeErrChan)
|
||||||
|
return err
|
||||||
|
case err := <-h.messageErrChan:
|
||||||
|
// If the handshake errored because of an error that occurred during HandleData(),
|
||||||
|
// that error message will be more useful than the error message generated by Handshake().
|
||||||
|
// Close the message chan that qtls is receiving messages from.
|
||||||
|
// This will make qtls.Handshake() return.
|
||||||
|
// Thereby the go routine running qtls.Handshake() will return.
|
||||||
|
close(h.messageChan)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
close(h.handshakeComplete)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) error {
|
func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) {
|
||||||
var buf *bytes.Buffer
|
var buf *bytes.Buffer
|
||||||
switch encLevel {
|
switch encLevel {
|
||||||
case protocol.EncryptionInitial:
|
case protocol.EncryptionInitial:
|
||||||
|
@ -246,7 +280,8 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev
|
||||||
case protocol.EncryptionHandshake:
|
case protocol.EncryptionHandshake:
|
||||||
buf = &h.handshakeReadBuf
|
buf = &h.handshakeReadBuf
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel)
|
h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
buf.Write(data)
|
buf.Write(data)
|
||||||
for buf.Len() >= 4 {
|
for buf.Len() >= 4 {
|
||||||
|
@ -254,15 +289,14 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev
|
||||||
// read the TLS message length
|
// read the TLS message length
|
||||||
length := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
length := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
||||||
if buf.Len() < 4+length { // message not yet complete
|
if buf.Len() < 4+length { // message not yet complete
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
msg := make([]byte, length+4)
|
msg := make([]byte, length+4)
|
||||||
buf.Read(msg)
|
buf.Read(msg)
|
||||||
if err := h.handleMessage(msg, encLevel); err != nil {
|
if err := h.handleMessage(msg, encLevel); err != nil {
|
||||||
return err
|
h.messageErrChan <- err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleMessage handles a TLS handshake message.
|
// handleMessage handles a TLS handshake message.
|
||||||
|
@ -276,12 +310,13 @@ func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.Encryption
|
||||||
h.messageChan <- data
|
h.messageChan <- data
|
||||||
switch h.perspective {
|
switch h.perspective {
|
||||||
case protocol.PerspectiveClient:
|
case protocol.PerspectiveClient:
|
||||||
return h.handleMessageForClient(msgType)
|
h.handleMessageForClient(msgType)
|
||||||
case protocol.PerspectiveServer:
|
case protocol.PerspectiveServer:
|
||||||
return h.handleMessageForServer(msgType)
|
h.handleMessageForServer(msgType)
|
||||||
default:
|
default:
|
||||||
panic("")
|
panic("")
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||||
|
@ -300,65 +335,114 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot
|
||||||
return fmt.Errorf("unexpected handshake message: %d", msgType)
|
return fmt.Errorf("unexpected handshake message: %d", msgType)
|
||||||
}
|
}
|
||||||
if encLevel != expected {
|
if encLevel != expected {
|
||||||
return fmt.Errorf("expected handshake message %d to have encryption level %s, has %s", msgType, expected, encLevel)
|
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) error {
|
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) {
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case typeClientHello:
|
case typeClientHello:
|
||||||
params := <-h.receivedTransportParams
|
select {
|
||||||
h.handleParamsCallback(¶ms)
|
case params := <-h.receivedTransportParams:
|
||||||
<-h.receivedWriteKey // get the handshake write key
|
h.handleParamsCallback(¶ms)
|
||||||
<-h.receivedWriteKey // get the 1-RTT write key
|
case <-h.handshakeErrChan:
|
||||||
<-h.receivedReadKey // get the handshake read key
|
return
|
||||||
h.handshakeEvent <- struct{}{}
|
}
|
||||||
|
// get the handshake write key
|
||||||
|
select {
|
||||||
|
case <-h.receivedWriteKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// get the 1-RTT write key
|
||||||
|
select {
|
||||||
|
case <-h.receivedWriteKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// get the handshake read key
|
||||||
// TODO: check that the initial stream doesn't have any more data
|
// TODO: check that the initial stream doesn't have any more data
|
||||||
|
select {
|
||||||
|
case <-h.receivedReadKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handshakeEvent <- struct{}{}
|
||||||
case typeCertificate, typeCertificateVerify:
|
case typeCertificate, typeCertificateVerify:
|
||||||
// nothing to do
|
// nothing to do
|
||||||
case typeFinished:
|
case typeFinished:
|
||||||
<-h.receivedReadKey // get the 1-RTT read key
|
// get the 1-RTT read key
|
||||||
h.handshakeEvent <- struct{}{}
|
|
||||||
// TODO: check that the handshake stream doesn't have any more data
|
// TODO: check that the handshake stream doesn't have any more data
|
||||||
|
select {
|
||||||
|
case <-h.receivedReadKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handshakeEvent <- struct{}{}
|
||||||
default:
|
default:
|
||||||
// TODO: think about what to do with unknown message types
|
panic("unexpected handshake message")
|
||||||
return fmt.Errorf("Received unknown handshake message: %d", msgType)
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) error {
|
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) {
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case typeServerHello:
|
case typeServerHello:
|
||||||
<-h.receivedReadKey // get the handshake read key
|
// get the handshake read key
|
||||||
|
// TODO: check that the initial stream doesn't have any more data
|
||||||
|
select {
|
||||||
|
case <-h.receivedReadKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
h.handshakeEvent <- struct{}{}
|
h.handshakeEvent <- struct{}{}
|
||||||
case typeEncryptedExtensions:
|
case typeEncryptedExtensions:
|
||||||
params := <-h.receivedTransportParams
|
select {
|
||||||
h.handleParamsCallback(¶ms)
|
case params := <-h.receivedTransportParams:
|
||||||
|
h.handleParamsCallback(¶ms)
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
|
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
|
||||||
// nothing to do
|
// nothing to do
|
||||||
case typeFinished:
|
case typeFinished:
|
||||||
<-h.receivedWriteKey // get the handshake write key
|
// get the handshake write key
|
||||||
// TODO: check that the initial stream doesn't have any more data
|
// TODO: check that the initial stream doesn't have any more data
|
||||||
|
select {
|
||||||
|
case <-h.receivedWriteKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
// While the order of these two is not defined by the TLS spec,
|
// While the order of these two is not defined by the TLS spec,
|
||||||
// we have to do it on the same order as our TLS library does it.
|
// we have to do it on the same order as our TLS library does it.
|
||||||
<-h.receivedWriteKey // get the handshake write key
|
// get the handshake write key
|
||||||
<-h.receivedReadKey // get the 1-RTT read key
|
select {
|
||||||
|
case <-h.receivedWriteKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// get the 1-RTT read key
|
||||||
|
select {
|
||||||
|
case <-h.receivedReadKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
// TODO: check that the handshake stream doesn't have any more data
|
// TODO: check that the handshake stream doesn't have any more data
|
||||||
h.handshakeEvent <- struct{}{}
|
h.handshakeEvent <- struct{}{}
|
||||||
default:
|
default:
|
||||||
// TODO: think about what to do with unknown extensions
|
panic("unexpected handshake message: ")
|
||||||
return fmt.Errorf("Received unknown handshake message: %d", msgType)
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadHandshakeMessage is called by TLS.
|
// ReadHandshakeMessage is called by TLS.
|
||||||
// It blocks until a new handshake message is available.
|
// It blocks until a new handshake message is available.
|
||||||
func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) {
|
func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) {
|
||||||
// TODO: add some error handling here (when the session is closed)
|
// TODO: add some error handling here (when the session is closed)
|
||||||
return <-h.messageChan, nil
|
msg, ok := <-h.messageChan
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("error while handling the handshake message")
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||||
|
|
|
@ -48,25 +48,6 @@ func (s *stream) Write(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ = Describe("Crypto Setup TLS", func() {
|
var _ = Describe("Crypto Setup TLS", func() {
|
||||||
generateCert := func() tls.Certificate {
|
|
||||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
tmpl := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
Subject: pkix.Name{},
|
|
||||||
SignatureAlgorithm: x509.SHA256WithRSA,
|
|
||||||
NotBefore: time.Now(),
|
|
||||||
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
}
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
return tls.Certificate{
|
|
||||||
PrivateKey: priv,
|
|
||||||
Certificate: [][]byte{certDER},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
||||||
chunkChan := make(chan chunk, 100)
|
chunkChan := make(chan chunk, 100)
|
||||||
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
||||||
|
@ -74,172 +55,16 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
return chunkChan, initialStream, handshakeStream
|
return chunkChan, initialStream, handshakeStream
|
||||||
}
|
}
|
||||||
|
|
||||||
handshake := func(
|
It("returns Handshake() when an error occurs", func() {
|
||||||
client CryptoSetupTLS,
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
cChunkChan <-chan chunk,
|
|
||||||
server CryptoSetupTLS,
|
|
||||||
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
|
|
||||||
done := make(chan struct{})
|
|
||||||
defer close(done)
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case c := <-cChunkChan:
|
|
||||||
err := server.HandleData(c.data, c.encLevel)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
case c := <-sChunkChan:
|
|
||||||
err := client.HandleData(c.data, c.encLevel)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
case <-done: // handshake complete
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverErrChan := make(chan error)
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
serverErrChan <- server.RunHandshake()
|
|
||||||
}()
|
|
||||||
|
|
||||||
clientErr := client.RunHandshake()
|
|
||||||
var serverErr error
|
|
||||||
Eventually(serverErrChan).Should(Receive(&serverErr))
|
|
||||||
return clientErr, serverErr
|
|
||||||
}
|
|
||||||
|
|
||||||
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
|
|
||||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
|
||||||
client, _, err := NewCryptoSetupTLSClient(
|
|
||||||
cInitialStream,
|
|
||||||
cHandshakeStream,
|
|
||||||
protocol.ConnectionID{},
|
|
||||||
&TransportParameters{},
|
|
||||||
func(p *TransportParameters) {},
|
|
||||||
make(chan struct{}, 100),
|
|
||||||
make(chan struct{}),
|
|
||||||
clientConf,
|
|
||||||
protocol.VersionTLS,
|
|
||||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
|
||||||
protocol.VersionTLS,
|
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
|
||||||
protocol.PerspectiveClient,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
|
||||||
server, err := NewCryptoSetupTLSServer(
|
server, err := NewCryptoSetupTLSServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
&TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)},
|
|
||||||
func(p *TransportParameters) {},
|
|
||||||
make(chan struct{}, 100),
|
|
||||||
make(chan struct{}),
|
|
||||||
serverConf,
|
|
||||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
|
||||||
protocol.VersionTLS,
|
|
||||||
utils.DefaultLogger.WithPrefix("server"),
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
return handshake(client, cChunkChan, server, sChunkChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
It("handshakes", func() {
|
|
||||||
clientConf := &tls.Config{ServerName: "quic.clemente.io"}
|
|
||||||
serverConf := testdata.GetTLSConfig()
|
|
||||||
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
|
|
||||||
Expect(clientErr).ToNot(HaveOccurred())
|
|
||||||
Expect(serverErr).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handshakes with client auth", func() {
|
|
||||||
clientConf := &tls.Config{
|
|
||||||
ServerName: "quic.clemente.io",
|
|
||||||
Certificates: []tls.Certificate{generateCert()},
|
|
||||||
}
|
|
||||||
serverConf := testdata.GetTLSConfig()
|
|
||||||
serverConf.ClientAuth = qtls.RequireAnyClientCert
|
|
||||||
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
|
|
||||||
Expect(clientErr).ToNot(HaveOccurred())
|
|
||||||
Expect(serverErr).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("signals when it has written the ClientHello", func() {
|
|
||||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
|
||||||
client, chChan, err := NewCryptoSetupTLSClient(
|
|
||||||
cInitialStream,
|
|
||||||
cHandshakeStream,
|
|
||||||
protocol.ConnectionID{},
|
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
func(p *TransportParameters) {},
|
func(p *TransportParameters) {},
|
||||||
make(chan struct{}, 100),
|
make(chan struct{}, 100),
|
||||||
make(chan struct{}),
|
make(chan struct{}),
|
||||||
&tls.Config{InsecureSkipVerify: true},
|
|
||||||
protocol.VersionTLS,
|
|
||||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
|
||||||
protocol.VersionTLS,
|
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
|
||||||
protocol.PerspectiveClient,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
client.RunHandshake()
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
var ch chunk
|
|
||||||
Eventually(cChunkChan).Should(Receive(&ch))
|
|
||||||
Eventually(chChan).Should(BeClosed())
|
|
||||||
// make sure the whole ClientHello was written
|
|
||||||
Expect(len(ch.data)).To(BeNumerically(">=", 4))
|
|
||||||
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
|
|
||||||
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
|
|
||||||
Expect(len(ch.data) - 4).To(Equal(length))
|
|
||||||
|
|
||||||
// make the go routine return
|
|
||||||
client.HandleData([]byte{1, 0, 0, 1, 0}, protocol.EncryptionInitial)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("receives transport parameters", func() {
|
|
||||||
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
|
|
||||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
|
||||||
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
|
|
||||||
client, _, err := NewCryptoSetupTLSClient(
|
|
||||||
cInitialStream,
|
|
||||||
cHandshakeStream,
|
|
||||||
protocol.ConnectionID{},
|
|
||||||
cTransportParameters,
|
|
||||||
func(p *TransportParameters) { sTransportParametersRcvd = p },
|
|
||||||
make(chan struct{}, 100),
|
|
||||||
make(chan struct{}),
|
|
||||||
&tls.Config{ServerName: "quic.clemente.io"},
|
|
||||||
protocol.VersionTLS,
|
|
||||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
|
||||||
protocol.VersionTLS,
|
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
|
||||||
protocol.PerspectiveClient,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
|
||||||
sTransportParameters := &TransportParameters{
|
|
||||||
IdleTimeout: 0x1337 * time.Second,
|
|
||||||
StatelessResetToken: bytes.Repeat([]byte{42}, 16),
|
|
||||||
}
|
|
||||||
server, err := NewCryptoSetupTLSServer(
|
|
||||||
sInitialStream,
|
|
||||||
sHandshakeStream,
|
|
||||||
protocol.ConnectionID{},
|
|
||||||
sTransportParameters,
|
|
||||||
func(p *TransportParameters) { cTransportParametersRcvd = p },
|
|
||||||
make(chan struct{}, 100),
|
|
||||||
make(chan struct{}),
|
|
||||||
testdata.GetTLSConfig(),
|
testdata.GetTLSConfig(),
|
||||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||||
protocol.VersionTLS,
|
protocol.VersionTLS,
|
||||||
|
@ -251,15 +76,253 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan)
|
err := server.RunHandshake()
|
||||||
Expect(clientErr).ToNot(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
Expect(serverErr).ToNot(HaveOccurred())
|
Expect(err.Error()).To(ContainSubstring("received unexpected handshake message"))
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
||||||
|
server.HandleData(fakeCH, protocol.EncryptionInitial)
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
Expect(cTransportParametersRcvd).ToNot(BeNil())
|
})
|
||||||
Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout))
|
|
||||||
Expect(sTransportParametersRcvd).ToNot(BeNil())
|
It("returns Handshake() when handling a message fails", func() {
|
||||||
Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout))
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
|
server, err := NewCryptoSetupTLSServer(
|
||||||
|
sInitialStream,
|
||||||
|
sHandshakeStream,
|
||||||
|
protocol.ConnectionID{},
|
||||||
|
&TransportParameters{},
|
||||||
|
func(p *TransportParameters) {},
|
||||||
|
make(chan struct{}, 100),
|
||||||
|
make(chan struct{}),
|
||||||
|
testdata.GetTLSConfig(),
|
||||||
|
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||||
|
protocol.VersionTLS,
|
||||||
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
|
protocol.PerspectiveServer,
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
err := server.RunHandshake()
|
||||||
|
Expect(err).To(MatchError("expected handshake message ClientHello to have encryption level Initial, has Handshake"))
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
||||||
|
server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("doing the handshake", func() {
|
||||||
|
generateCert := func() tls.Certificate {
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{},
|
||||||
|
SignatureAlgorithm: x509.SHA256WithRSA,
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return tls.Certificate{
|
||||||
|
PrivateKey: priv,
|
||||||
|
Certificate: [][]byte{certDER},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handshake := func(
|
||||||
|
client CryptoSetupTLS,
|
||||||
|
cChunkChan <-chan chunk,
|
||||||
|
server CryptoSetupTLS,
|
||||||
|
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case c := <-cChunkChan:
|
||||||
|
server.HandleData(c.data, c.encLevel)
|
||||||
|
case c := <-sChunkChan:
|
||||||
|
client.HandleData(c.data, c.encLevel)
|
||||||
|
case <-done: // handshake complete
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
serverErrChan := make(chan error)
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
serverErrChan <- server.RunHandshake()
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientErr := client.RunHandshake()
|
||||||
|
var serverErr error
|
||||||
|
Eventually(serverErrChan).Should(Receive(&serverErr))
|
||||||
|
return clientErr, serverErr
|
||||||
|
}
|
||||||
|
|
||||||
|
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
|
||||||
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
|
client, _, err := NewCryptoSetupTLSClient(
|
||||||
|
cInitialStream,
|
||||||
|
cHandshakeStream,
|
||||||
|
protocol.ConnectionID{},
|
||||||
|
&TransportParameters{},
|
||||||
|
func(p *TransportParameters) {},
|
||||||
|
make(chan struct{}, 100),
|
||||||
|
make(chan struct{}),
|
||||||
|
clientConf,
|
||||||
|
protocol.VersionTLS,
|
||||||
|
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||||
|
protocol.VersionTLS,
|
||||||
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
|
protocol.PerspectiveClient,
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||||
|
server, err := NewCryptoSetupTLSServer(
|
||||||
|
sInitialStream,
|
||||||
|
sHandshakeStream,
|
||||||
|
protocol.ConnectionID{},
|
||||||
|
&TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)},
|
||||||
|
func(p *TransportParameters) {},
|
||||||
|
make(chan struct{}, 100),
|
||||||
|
make(chan struct{}),
|
||||||
|
serverConf,
|
||||||
|
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||||
|
protocol.VersionTLS,
|
||||||
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
|
protocol.PerspectiveServer,
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
return handshake(client, cChunkChan, server, sChunkChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
It("handshakes", func() {
|
||||||
|
clientConf := &tls.Config{ServerName: "quic.clemente.io"}
|
||||||
|
serverConf := testdata.GetTLSConfig()
|
||||||
|
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
|
||||||
|
Expect(clientErr).ToNot(HaveOccurred())
|
||||||
|
Expect(serverErr).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("handshakes with client auth", func() {
|
||||||
|
clientConf := &tls.Config{
|
||||||
|
ServerName: "quic.clemente.io",
|
||||||
|
Certificates: []tls.Certificate{generateCert()},
|
||||||
|
}
|
||||||
|
serverConf := testdata.GetTLSConfig()
|
||||||
|
serverConf.ClientAuth = qtls.RequireAnyClientCert
|
||||||
|
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
|
||||||
|
Expect(clientErr).ToNot(HaveOccurred())
|
||||||
|
Expect(serverErr).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("signals when it has written the ClientHello", func() {
|
||||||
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
|
client, chChan, err := NewCryptoSetupTLSClient(
|
||||||
|
cInitialStream,
|
||||||
|
cHandshakeStream,
|
||||||
|
protocol.ConnectionID{},
|
||||||
|
&TransportParameters{},
|
||||||
|
func(p *TransportParameters) {},
|
||||||
|
make(chan struct{}, 100),
|
||||||
|
make(chan struct{}),
|
||||||
|
&tls.Config{InsecureSkipVerify: true},
|
||||||
|
protocol.VersionTLS,
|
||||||
|
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||||
|
protocol.VersionTLS,
|
||||||
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
|
protocol.PerspectiveClient,
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
client.RunHandshake()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
var ch chunk
|
||||||
|
Eventually(cChunkChan).Should(Receive(&ch))
|
||||||
|
Eventually(chChan).Should(BeClosed())
|
||||||
|
// make sure the whole ClientHello was written
|
||||||
|
Expect(len(ch.data)).To(BeNumerically(">=", 4))
|
||||||
|
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
|
||||||
|
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
|
||||||
|
Expect(len(ch.data) - 4).To(Equal(length))
|
||||||
|
|
||||||
|
// make the go routine return
|
||||||
|
client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("receives transport parameters", func() {
|
||||||
|
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
|
||||||
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
|
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
|
||||||
|
client, _, err := NewCryptoSetupTLSClient(
|
||||||
|
cInitialStream,
|
||||||
|
cHandshakeStream,
|
||||||
|
protocol.ConnectionID{},
|
||||||
|
cTransportParameters,
|
||||||
|
func(p *TransportParameters) { sTransportParametersRcvd = p },
|
||||||
|
make(chan struct{}, 100),
|
||||||
|
make(chan struct{}),
|
||||||
|
&tls.Config{ServerName: "quic.clemente.io"},
|
||||||
|
protocol.VersionTLS,
|
||||||
|
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||||
|
protocol.VersionTLS,
|
||||||
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
|
protocol.PerspectiveClient,
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||||
|
sTransportParameters := &TransportParameters{
|
||||||
|
IdleTimeout: 0x1337 * time.Second,
|
||||||
|
StatelessResetToken: bytes.Repeat([]byte{42}, 16),
|
||||||
|
}
|
||||||
|
server, err := NewCryptoSetupTLSServer(
|
||||||
|
sInitialStream,
|
||||||
|
sHandshakeStream,
|
||||||
|
protocol.ConnectionID{},
|
||||||
|
sTransportParameters,
|
||||||
|
func(p *TransportParameters) { cTransportParametersRcvd = p },
|
||||||
|
make(chan struct{}, 100),
|
||||||
|
make(chan struct{}),
|
||||||
|
testdata.GetTLSConfig(),
|
||||||
|
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||||
|
protocol.VersionTLS,
|
||||||
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
|
protocol.PerspectiveServer,
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan)
|
||||||
|
Expect(clientErr).ToNot(HaveOccurred())
|
||||||
|
Expect(serverErr).ToNot(HaveOccurred())
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
Expect(cTransportParametersRcvd).ToNot(BeNil())
|
||||||
|
Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout))
|
||||||
|
Expect(sTransportParametersRcvd).ToNot(BeNil())
|
||||||
|
Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -44,7 +44,7 @@ type CryptoSetup interface {
|
||||||
type CryptoSetupTLS interface {
|
type CryptoSetupTLS interface {
|
||||||
baseCryptoSetup
|
baseCryptoSetup
|
||||||
|
|
||||||
HandleData([]byte, protocol.EncryptionLevel) error
|
HandleData([]byte, protocol.EncryptionLevel)
|
||||||
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||||
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||||
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||||
|
|
|
@ -35,10 +35,8 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleData mocks base method
|
// HandleData mocks base method
|
||||||
func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) error {
|
func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) {
|
||||||
ret := m.ctrl.Call(m, "HandleData", arg0, arg1)
|
m.ctrl.Call(m, "HandleData", arg0, arg1)
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleData indicates an expected call of HandleData
|
// HandleData indicates an expected call of HandleData
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue