fix error handling in the TLS crypto setup

There are two ways that an error can occur during the handshake:
1. as a return value from qtls.Handshake()
2. when new data is passed to the crypto setup via HandleData()
We need to make sure that the RunHandshake() as well as HandleData()
both return if an error occurs at any step during the handshake.
This commit is contained in:
Marten Seemann 2018-10-18 22:55:02 +01:00
parent 82508f1562
commit 2dbc29a5bd
6 changed files with 375 additions and 242 deletions

View file

@ -8,7 +8,7 @@ import (
) )
type cryptoDataHandler interface { type cryptoDataHandler interface {
HandleData([]byte, protocol.EncryptionLevel) error HandleData([]byte, protocol.EncryptionLevel)
} }
type cryptoStreamManager struct { type cryptoStreamManager struct {
@ -48,8 +48,6 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
if data == nil { if data == nil {
return nil return nil
} }
if err := m.cryptoHandler.HandleData(data, encLevel); err != nil { m.cryptoHandler.HandleData(data, encLevel)
return err
}
} }
} }

View file

@ -1,8 +1,6 @@
package quic package quic
import ( import (
"errors"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -54,12 +52,4 @@ var _ = Describe("Crypto Stream Manager", func() {
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake) cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake)
Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed()) Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed())
}) })
It("returns the error if handling crypto data fails", func() {
testErr := errors.New("test error")
f := &wire.CryptoFrame{Data: []byte("foobar")}
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake).Return(testErr)
err := csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)
Expect(err).To(MatchError(testErr))
})
}) })

View file

@ -55,9 +55,20 @@ type cryptoSetupTLS struct {
readEncLevel protocol.EncryptionLevel readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel writeEncLevel protocol.EncryptionLevel
handleParamsCallback func(*TransportParameters) handleParamsCallback func(*TransportParameters)
handshakeEvent chan<- struct{}
handshakeComplete chan<- struct{} // There are two ways that an error can occur during the handshake:
// 1. as a return value from qtls.Handshake()
// 2. when new data is passed to the crypto setup via HandleData()
// handshakeErrChan is closed when qtls.Handshake() errors
handshakeErrChan chan struct{}
// HandleData() sends errors on the messageErrChan
messageErrChan chan error
// handshakeEvent signals a change of encryption level to the session
handshakeEvent chan<- struct{}
// handshakeComplete is closed when the handshake completes
handshakeComplete chan<- struct{}
// transport parameters are sent on the receivedTransportParams, as soon as they are received
receivedTransportParams <-chan TransportParameters receivedTransportParams <-chan TransportParameters
clientHelloWritten bool clientHelloWritten bool
@ -190,6 +201,8 @@ func newCryptoSetupTLS(
handshakeComplete: handshakeComplete, handshakeComplete: handshakeComplete,
logger: logger, logger: logger,
perspective: perspective, perspective: perspective,
handshakeErrChan: make(chan struct{}),
messageErrChan: make(chan error, 1),
clientHelloWrittenChan: make(chan struct{}), clientHelloWrittenChan: make(chan struct{}),
messageChan: make(chan []byte, 100), messageChan: make(chan []byte, 100),
receivedReadKey: make(chan struct{}), receivedReadKey: make(chan struct{}),
@ -229,16 +242,37 @@ func (h *cryptoSetupTLS) RunHandshake() error {
case protocol.PerspectiveServer: case protocol.PerspectiveServer:
conn = qtls.Server(nil, h.tlsConf) conn = qtls.Server(nil, h.tlsConf)
} }
if err := conn.Handshake(); err != nil { // Handle errors that might occur when HandleData() is called.
close(h.receivedReadKey) handshakeErrChan := make(chan error, 1)
close(h.receivedWriteKey) handshakeComplete := make(chan struct{})
go func() {
if err := conn.Handshake(); err != nil {
handshakeErrChan <- err
return
}
close(handshakeComplete)
}()
select {
case <-handshakeComplete: // return when the handshake is done
close(h.handshakeComplete)
return nil
case err := <-handshakeErrChan:
// if handleMessageFor{server,client} are waiting for some qtls action, make them return
close(h.handshakeErrChan)
return err
case err := <-h.messageErrChan:
// If the handshake errored because of an error that occurred during HandleData(),
// that error message will be more useful than the error message generated by Handshake().
// Close the message chan that qtls is receiving messages from.
// This will make qtls.Handshake() return.
// Thereby the go routine running qtls.Handshake() will return.
close(h.messageChan)
return err return err
} }
close(h.handshakeComplete)
return nil
} }
func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) error { func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) {
var buf *bytes.Buffer var buf *bytes.Buffer
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
@ -246,7 +280,8 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
buf = &h.handshakeReadBuf buf = &h.handshakeReadBuf
default: default:
return fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel) h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel)
return
} }
buf.Write(data) buf.Write(data)
for buf.Len() >= 4 { for buf.Len() >= 4 {
@ -254,15 +289,14 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev
// read the TLS message length // read the TLS message length
length := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) length := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if buf.Len() < 4+length { // message not yet complete if buf.Len() < 4+length { // message not yet complete
return nil return
} }
msg := make([]byte, length+4) msg := make([]byte, length+4)
buf.Read(msg) buf.Read(msg)
if err := h.handleMessage(msg, encLevel); err != nil { if err := h.handleMessage(msg, encLevel); err != nil {
return err h.messageErrChan <- err
} }
} }
return nil
} }
// handleMessage handles a TLS handshake message. // handleMessage handles a TLS handshake message.
@ -276,12 +310,13 @@ func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.Encryption
h.messageChan <- data h.messageChan <- data
switch h.perspective { switch h.perspective {
case protocol.PerspectiveClient: case protocol.PerspectiveClient:
return h.handleMessageForClient(msgType) h.handleMessageForClient(msgType)
case protocol.PerspectiveServer: case protocol.PerspectiveServer:
return h.handleMessageForServer(msgType) h.handleMessageForServer(msgType)
default: default:
panic("") panic("")
} }
return nil
} }
func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
@ -300,65 +335,114 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot
return fmt.Errorf("unexpected handshake message: %d", msgType) return fmt.Errorf("unexpected handshake message: %d", msgType)
} }
if encLevel != expected { if encLevel != expected {
return fmt.Errorf("expected handshake message %d to have encryption level %s, has %s", msgType, expected, encLevel) return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
} }
return nil return nil
} }
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) error { func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) {
switch msgType { switch msgType {
case typeClientHello: case typeClientHello:
params := <-h.receivedTransportParams select {
h.handleParamsCallback(&params) case params := <-h.receivedTransportParams:
<-h.receivedWriteKey // get the handshake write key h.handleParamsCallback(&params)
<-h.receivedWriteKey // get the 1-RTT write key case <-h.handshakeErrChan:
<-h.receivedReadKey // get the handshake read key return
h.handshakeEvent <- struct{}{} }
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
}
// get the 1-RTT write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
}
// get the handshake read key
// TODO: check that the initial stream doesn't have any more data // TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
}
h.handshakeEvent <- struct{}{}
case typeCertificate, typeCertificateVerify: case typeCertificate, typeCertificateVerify:
// nothing to do // nothing to do
case typeFinished: case typeFinished:
<-h.receivedReadKey // get the 1-RTT read key // get the 1-RTT read key
h.handshakeEvent <- struct{}{}
// TODO: check that the handshake stream doesn't have any more data // TODO: check that the handshake stream doesn't have any more data
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
}
h.handshakeEvent <- struct{}{}
default: default:
// TODO: think about what to do with unknown message types panic("unexpected handshake message")
return fmt.Errorf("Received unknown handshake message: %d", msgType)
} }
return nil
} }
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) error { func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) {
switch msgType { switch msgType {
case typeServerHello: case typeServerHello:
<-h.receivedReadKey // get the handshake read key // get the handshake read key
// TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
}
h.handshakeEvent <- struct{}{} h.handshakeEvent <- struct{}{}
case typeEncryptedExtensions: case typeEncryptedExtensions:
params := <-h.receivedTransportParams select {
h.handleParamsCallback(&params) case params := <-h.receivedTransportParams:
h.handleParamsCallback(&params)
case <-h.handshakeErrChan:
return
}
case typeCertificateRequest, typeCertificate, typeCertificateVerify: case typeCertificateRequest, typeCertificate, typeCertificateVerify:
// nothing to do // nothing to do
case typeFinished: case typeFinished:
<-h.receivedWriteKey // get the handshake write key // get the handshake write key
// TODO: check that the initial stream doesn't have any more data // TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
}
// While the order of these two is not defined by the TLS spec, // While the order of these two is not defined by the TLS spec,
// we have to do it on the same order as our TLS library does it. // we have to do it on the same order as our TLS library does it.
<-h.receivedWriteKey // get the handshake write key // get the handshake write key
<-h.receivedReadKey // get the 1-RTT read key select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
}
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
}
// TODO: check that the handshake stream doesn't have any more data // TODO: check that the handshake stream doesn't have any more data
h.handshakeEvent <- struct{}{} h.handshakeEvent <- struct{}{}
default: default:
// TODO: think about what to do with unknown extensions panic("unexpected handshake message: ")
return fmt.Errorf("Received unknown handshake message: %d", msgType)
} }
return nil
} }
// ReadHandshakeMessage is called by TLS. // ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available. // It blocks until a new handshake message is available.
func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) { func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) {
// TODO: add some error handling here (when the session is closed) // TODO: add some error handling here (when the session is closed)
return <-h.messageChan, nil msg, ok := <-h.messageChan
if !ok {
return nil, errors.New("error while handling the handshake message")
}
return msg, nil
} }
func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) { func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {

View file

@ -48,25 +48,6 @@ func (s *stream) Write(b []byte) (int, error) {
} }
var _ = Describe("Crypto Setup TLS", func() { var _ = Describe("Crypto Setup TLS", func() {
generateCert := func() tls.Certificate {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
Expect(err).ToNot(HaveOccurred())
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{},
SignatureAlgorithm: x509.SHA256WithRSA,
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour), // valid for an hour
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
Expect(err).ToNot(HaveOccurred())
return tls.Certificate{
PrivateKey: priv,
Certificate: [][]byte{certDER},
}
}
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 100) chunkChan := make(chan chunk, 100)
initialStream := newStream(chunkChan, protocol.EncryptionInitial) initialStream := newStream(chunkChan, protocol.EncryptionInitial)
@ -74,172 +55,16 @@ var _ = Describe("Crypto Setup TLS", func() {
return chunkChan, initialStream, handshakeStream return chunkChan, initialStream, handshakeStream
} }
handshake := func( It("returns Handshake() when an error occurs", func() {
client CryptoSetupTLS, _, sInitialStream, sHandshakeStream := initStreams()
cChunkChan <-chan chunk,
server CryptoSetupTLS,
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
done := make(chan struct{})
defer close(done)
go func() {
defer GinkgoRecover()
for {
select {
case c := <-cChunkChan:
err := server.HandleData(c.data, c.encLevel)
Expect(err).ToNot(HaveOccurred())
case c := <-sChunkChan:
err := client.HandleData(c.data, c.encLevel)
Expect(err).ToNot(HaveOccurred())
case <-done: // handshake complete
}
}
}()
serverErrChan := make(chan error)
go func() {
defer GinkgoRecover()
serverErrChan <- server.RunHandshake()
}()
clientErr := client.RunHandshake()
var serverErr error
Eventually(serverErrChan).Should(Receive(&serverErr))
return clientErr, serverErr
}
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, _, err := NewCryptoSetupTLSClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
clientConf,
protocol.VersionTLS,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("client"),
protocol.PerspectiveClient,
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupTLSServer( server, err := NewCryptoSetupTLSServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
protocol.ConnectionID{}, protocol.ConnectionID{},
&TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
serverConf,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("server"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
return handshake(client, cChunkChan, server, sChunkChan)
}
It("handshakes", func() {
clientConf := &tls.Config{ServerName: "quic.clemente.io"}
serverConf := testdata.GetTLSConfig()
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("handshakes with client auth", func() {
clientConf := &tls.Config{
ServerName: "quic.clemente.io",
Certificates: []tls.Certificate{generateCert()},
}
serverConf := testdata.GetTLSConfig()
serverConf.ClientAuth = qtls.RequireAnyClientCert
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("signals when it has written the ClientHello", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, chChan, err := NewCryptoSetupTLSClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
&TransportParameters{}, &TransportParameters{},
func(p *TransportParameters) {}, func(p *TransportParameters) {},
make(chan struct{}, 100), make(chan struct{}, 100),
make(chan struct{}), make(chan struct{}),
&tls.Config{InsecureSkipVerify: true},
protocol.VersionTLS,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("client"),
protocol.PerspectiveClient,
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RunHandshake()
close(done)
}()
var ch chunk
Eventually(cChunkChan).Should(Receive(&ch))
Eventually(chChan).Should(BeClosed())
// make sure the whole ClientHello was written
Expect(len(ch.data)).To(BeNumerically(">=", 4))
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
Expect(len(ch.data) - 4).To(Equal(length))
// make the go routine return
client.HandleData([]byte{1, 0, 0, 1, 0}, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed())
})
It("receives transport parameters", func() {
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
client, _, err := NewCryptoSetupTLSClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
cTransportParameters,
func(p *TransportParameters) { sTransportParametersRcvd = p },
make(chan struct{}, 100),
make(chan struct{}),
&tls.Config{ServerName: "quic.clemente.io"},
protocol.VersionTLS,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("client"),
protocol.PerspectiveClient,
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sTransportParameters := &TransportParameters{
IdleTimeout: 0x1337 * time.Second,
StatelessResetToken: bytes.Repeat([]byte{42}, 16),
}
server, err := NewCryptoSetupTLSServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
sTransportParameters,
func(p *TransportParameters) { cTransportParametersRcvd = p },
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS}, []protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS, protocol.VersionTLS,
@ -251,15 +76,253 @@ var _ = Describe("Crypto Setup TLS", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan) err := server.RunHandshake()
Expect(clientErr).ToNot(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("received unexpected handshake message"))
close(done) close(done)
}() }()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleData(fakeCH, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
Expect(cTransportParametersRcvd).ToNot(BeNil()) })
Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout))
Expect(sTransportParametersRcvd).ToNot(BeNil()) It("returns Handshake() when handling a message fails", func() {
Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) _, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupTLSServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("server"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := server.RunHandshake()
Expect(err).To(MatchError("expected handshake message ClientHello to have encryption level Initial, has Handshake"))
close(done)
}()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
Eventually(done).Should(BeClosed())
})
Context("doing the handshake", func() {
generateCert := func() tls.Certificate {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
Expect(err).ToNot(HaveOccurred())
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{},
SignatureAlgorithm: x509.SHA256WithRSA,
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour), // valid for an hour
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
Expect(err).ToNot(HaveOccurred())
return tls.Certificate{
PrivateKey: priv,
Certificate: [][]byte{certDER},
}
}
handshake := func(
client CryptoSetupTLS,
cChunkChan <-chan chunk,
server CryptoSetupTLS,
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
done := make(chan struct{})
defer close(done)
go func() {
defer GinkgoRecover()
for {
select {
case c := <-cChunkChan:
server.HandleData(c.data, c.encLevel)
case c := <-sChunkChan:
client.HandleData(c.data, c.encLevel)
case <-done: // handshake complete
}
}
}()
serverErrChan := make(chan error)
go func() {
defer GinkgoRecover()
serverErrChan <- server.RunHandshake()
}()
clientErr := client.RunHandshake()
var serverErr error
Eventually(serverErrChan).Should(Receive(&serverErr))
return clientErr, serverErr
}
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, _, err := NewCryptoSetupTLSClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
clientConf,
protocol.VersionTLS,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("client"),
protocol.PerspectiveClient,
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupTLSServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
&TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
serverConf,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("server"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
return handshake(client, cChunkChan, server, sChunkChan)
}
It("handshakes", func() {
clientConf := &tls.Config{ServerName: "quic.clemente.io"}
serverConf := testdata.GetTLSConfig()
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("handshakes with client auth", func() {
clientConf := &tls.Config{
ServerName: "quic.clemente.io",
Certificates: []tls.Certificate{generateCert()},
}
serverConf := testdata.GetTLSConfig()
serverConf.ClientAuth = qtls.RequireAnyClientCert
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("signals when it has written the ClientHello", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, chChan, err := NewCryptoSetupTLSClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
&tls.Config{InsecureSkipVerify: true},
protocol.VersionTLS,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("client"),
protocol.PerspectiveClient,
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RunHandshake()
close(done)
}()
var ch chunk
Eventually(cChunkChan).Should(Receive(&ch))
Eventually(chChan).Should(BeClosed())
// make sure the whole ClientHello was written
Expect(len(ch.data)).To(BeNumerically(">=", 4))
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
Expect(len(ch.data) - 4).To(Equal(length))
// make the go routine return
client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed())
})
It("receives transport parameters", func() {
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
client, _, err := NewCryptoSetupTLSClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
cTransportParameters,
func(p *TransportParameters) { sTransportParametersRcvd = p },
make(chan struct{}, 100),
make(chan struct{}),
&tls.Config{ServerName: "quic.clemente.io"},
protocol.VersionTLS,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("client"),
protocol.PerspectiveClient,
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sTransportParameters := &TransportParameters{
IdleTimeout: 0x1337 * time.Second,
StatelessResetToken: bytes.Repeat([]byte{42}, 16),
}
server, err := NewCryptoSetupTLSServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
sTransportParameters,
func(p *TransportParameters) { cTransportParametersRcvd = p },
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("server"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(cTransportParametersRcvd).ToNot(BeNil())
Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout))
Expect(sTransportParametersRcvd).ToNot(BeNil())
Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout))
})
}) })
}) })

View file

@ -44,7 +44,7 @@ type CryptoSetup interface {
type CryptoSetupTLS interface { type CryptoSetupTLS interface {
baseCryptoSetup baseCryptoSetup
HandleData([]byte, protocol.EncryptionLevel) error HandleData([]byte, protocol.EncryptionLevel)
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)

View file

@ -35,10 +35,8 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
} }
// HandleData mocks base method // HandleData mocks base method
func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) error { func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) {
ret := m.ctrl.Call(m, "HandleData", arg0, arg1) m.ctrl.Call(m, "HandleData", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
} }
// HandleData indicates an expected call of HandleData // HandleData indicates an expected call of HandleData