fix error handling in the TLS crypto setup

There are two ways that an error can occur during the handshake:
1. as a return value from qtls.Handshake()
2. when new data is passed to the crypto setup via HandleData()
We need to make sure that the RunHandshake() as well as HandleData()
both return if an error occurs at any step during the handshake.
This commit is contained in:
Marten Seemann 2018-10-18 22:55:02 +01:00
parent 82508f1562
commit 2dbc29a5bd
6 changed files with 375 additions and 242 deletions

View file

@ -8,7 +8,7 @@ import (
)
type cryptoDataHandler interface {
HandleData([]byte, protocol.EncryptionLevel) error
HandleData([]byte, protocol.EncryptionLevel)
}
type cryptoStreamManager struct {
@ -48,8 +48,6 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
if data == nil {
return nil
}
if err := m.cryptoHandler.HandleData(data, encLevel); err != nil {
return err
}
m.cryptoHandler.HandleData(data, encLevel)
}
}

View file

@ -1,8 +1,6 @@
package quic
import (
"errors"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
@ -54,12 +52,4 @@ var _ = Describe("Crypto Stream Manager", func() {
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake)
Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed())
})
It("returns the error if handling crypto data fails", func() {
testErr := errors.New("test error")
f := &wire.CryptoFrame{Data: []byte("foobar")}
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake).Return(testErr)
err := csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)
Expect(err).To(MatchError(testErr))
})
})

View file

@ -56,8 +56,19 @@ type cryptoSetupTLS struct {
writeEncLevel protocol.EncryptionLevel
handleParamsCallback func(*TransportParameters)
// There are two ways that an error can occur during the handshake:
// 1. as a return value from qtls.Handshake()
// 2. when new data is passed to the crypto setup via HandleData()
// handshakeErrChan is closed when qtls.Handshake() errors
handshakeErrChan chan struct{}
// HandleData() sends errors on the messageErrChan
messageErrChan chan error
// handshakeEvent signals a change of encryption level to the session
handshakeEvent chan<- struct{}
// handshakeComplete is closed when the handshake completes
handshakeComplete chan<- struct{}
// transport parameters are sent on the receivedTransportParams, as soon as they are received
receivedTransportParams <-chan TransportParameters
clientHelloWritten bool
@ -190,6 +201,8 @@ func newCryptoSetupTLS(
handshakeComplete: handshakeComplete,
logger: logger,
perspective: perspective,
handshakeErrChan: make(chan struct{}),
messageErrChan: make(chan error, 1),
clientHelloWrittenChan: make(chan struct{}),
messageChan: make(chan []byte, 100),
receivedReadKey: make(chan struct{}),
@ -229,16 +242,37 @@ func (h *cryptoSetupTLS) RunHandshake() error {
case protocol.PerspectiveServer:
conn = qtls.Server(nil, h.tlsConf)
}
// Handle errors that might occur when HandleData() is called.
handshakeErrChan := make(chan error, 1)
handshakeComplete := make(chan struct{})
go func() {
if err := conn.Handshake(); err != nil {
close(h.receivedReadKey)
close(h.receivedWriteKey)
return err
handshakeErrChan <- err
return
}
close(handshakeComplete)
}()
select {
case <-handshakeComplete: // return when the handshake is done
close(h.handshakeComplete)
return nil
case err := <-handshakeErrChan:
// if handleMessageFor{server,client} are waiting for some qtls action, make them return
close(h.handshakeErrChan)
return err
case err := <-h.messageErrChan:
// If the handshake errored because of an error that occurred during HandleData(),
// that error message will be more useful than the error message generated by Handshake().
// Close the message chan that qtls is receiving messages from.
// This will make qtls.Handshake() return.
// Thereby the go routine running qtls.Handshake() will return.
close(h.messageChan)
return err
}
}
func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) error {
func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) {
var buf *bytes.Buffer
switch encLevel {
case protocol.EncryptionInitial:
@ -246,7 +280,8 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev
case protocol.EncryptionHandshake:
buf = &h.handshakeReadBuf
default:
return fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel)
h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel)
return
}
buf.Write(data)
for buf.Len() >= 4 {
@ -254,15 +289,14 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev
// read the TLS message length
length := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if buf.Len() < 4+length { // message not yet complete
return nil
return
}
msg := make([]byte, length+4)
buf.Read(msg)
if err := h.handleMessage(msg, encLevel); err != nil {
return err
h.messageErrChan <- err
}
}
return nil
}
// handleMessage handles a TLS handshake message.
@ -276,12 +310,13 @@ func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.Encryption
h.messageChan <- data
switch h.perspective {
case protocol.PerspectiveClient:
return h.handleMessageForClient(msgType)
h.handleMessageForClient(msgType)
case protocol.PerspectiveServer:
return h.handleMessageForServer(msgType)
h.handleMessageForServer(msgType)
default:
panic("")
}
return nil
}
func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
@ -300,65 +335,114 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot
return fmt.Errorf("unexpected handshake message: %d", msgType)
}
if encLevel != expected {
return fmt.Errorf("expected handshake message %d to have encryption level %s, has %s", msgType, expected, encLevel)
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
}
return nil
}
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) error {
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) {
switch msgType {
case typeClientHello:
params := <-h.receivedTransportParams
select {
case params := <-h.receivedTransportParams:
h.handleParamsCallback(&params)
<-h.receivedWriteKey // get the handshake write key
<-h.receivedWriteKey // get the 1-RTT write key
<-h.receivedReadKey // get the handshake read key
h.handshakeEvent <- struct{}{}
case <-h.handshakeErrChan:
return
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
}
// get the 1-RTT write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
}
// get the handshake read key
// TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
}
h.handshakeEvent <- struct{}{}
case typeCertificate, typeCertificateVerify:
// nothing to do
case typeFinished:
<-h.receivedReadKey // get the 1-RTT read key
h.handshakeEvent <- struct{}{}
// get the 1-RTT read key
// TODO: check that the handshake stream doesn't have any more data
default:
// TODO: think about what to do with unknown message types
return fmt.Errorf("Received unknown handshake message: %d", msgType)
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
}
h.handshakeEvent <- struct{}{}
default:
panic("unexpected handshake message")
}
return nil
}
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) error {
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) {
switch msgType {
case typeServerHello:
<-h.receivedReadKey // get the handshake read key
// get the handshake read key
// TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
}
h.handshakeEvent <- struct{}{}
case typeEncryptedExtensions:
params := <-h.receivedTransportParams
select {
case params := <-h.receivedTransportParams:
h.handleParamsCallback(&params)
case <-h.handshakeErrChan:
return
}
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
// nothing to do
case typeFinished:
<-h.receivedWriteKey // get the handshake write key
// get the handshake write key
// TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
}
// While the order of these two is not defined by the TLS spec,
// we have to do it on the same order as our TLS library does it.
<-h.receivedWriteKey // get the handshake write key
<-h.receivedReadKey // get the 1-RTT read key
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
}
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
}
// TODO: check that the handshake stream doesn't have any more data
h.handshakeEvent <- struct{}{}
default:
// TODO: think about what to do with unknown extensions
return fmt.Errorf("Received unknown handshake message: %d", msgType)
panic("unexpected handshake message: ")
}
return nil
}
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) {
// TODO: add some error handling here (when the session is closed)
return <-h.messageChan, nil
msg, ok := <-h.messageChan
if !ok {
return nil, errors.New("error while handling the handshake message")
}
return msg, nil
}
func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {

View file

@ -48,6 +48,77 @@ func (s *stream) Write(b []byte) (int, error) {
}
var _ = Describe("Crypto Setup TLS", func() {
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 100)
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
return chunkChan, initialStream, handshakeStream
}
It("returns Handshake() when an error occurs", func() {
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupTLSServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("server"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := server.RunHandshake()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("received unexpected handshake message"))
close(done)
}()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleData(fakeCH, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed())
})
It("returns Handshake() when handling a message fails", func() {
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupTLSServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
utils.DefaultLogger.WithPrefix("server"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := server.RunHandshake()
Expect(err).To(MatchError("expected handshake message ClientHello to have encryption level Initial, has Handshake"))
close(done)
}()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
Eventually(done).Should(BeClosed())
})
Context("doing the handshake", func() {
generateCert := func() tls.Certificate {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
Expect(err).ToNot(HaveOccurred())
@ -67,13 +138,6 @@ var _ = Describe("Crypto Setup TLS", func() {
}
}
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 100)
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
return chunkChan, initialStream, handshakeStream
}
handshake := func(
client CryptoSetupTLS,
cChunkChan <-chan chunk,
@ -86,11 +150,9 @@ var _ = Describe("Crypto Setup TLS", func() {
for {
select {
case c := <-cChunkChan:
err := server.HandleData(c.data, c.encLevel)
Expect(err).ToNot(HaveOccurred())
server.HandleData(c.data, c.encLevel)
case c := <-sChunkChan:
err := client.HandleData(c.data, c.encLevel)
Expect(err).ToNot(HaveOccurred())
client.HandleData(c.data, c.encLevel)
case <-done: // handshake complete
}
}
@ -202,7 +264,7 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(len(ch.data) - 4).To(Equal(length))
// make the go routine return
client.HandleData([]byte{1, 0, 0, 1, 0}, protocol.EncryptionInitial)
client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed())
})
@ -263,3 +325,4 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout))
})
})
})

View file

@ -44,7 +44,7 @@ type CryptoSetup interface {
type CryptoSetupTLS interface {
baseCryptoSetup
HandleData([]byte, protocol.EncryptionLevel) error
HandleData([]byte, protocol.EncryptionLevel)
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)

View file

@ -35,10 +35,8 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
}
// HandleData mocks base method
func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) error {
ret := m.ctrl.Call(m, "HandleData", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) {
m.ctrl.Call(m, "HandleData", arg0, arg1)
}
// HandleData indicates an expected call of HandleData