mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
fix error handling in the TLS crypto setup
There are two ways that an error can occur during the handshake: 1. as a return value from qtls.Handshake() 2. when new data is passed to the crypto setup via HandleData() We need to make sure that the RunHandshake() as well as HandleData() both return if an error occurs at any step during the handshake.
This commit is contained in:
parent
82508f1562
commit
2dbc29a5bd
6 changed files with 375 additions and 242 deletions
|
@ -8,7 +8,7 @@ import (
|
|||
)
|
||||
|
||||
type cryptoDataHandler interface {
|
||||
HandleData([]byte, protocol.EncryptionLevel) error
|
||||
HandleData([]byte, protocol.EncryptionLevel)
|
||||
}
|
||||
|
||||
type cryptoStreamManager struct {
|
||||
|
@ -48,8 +48,6 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
|
|||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
if err := m.cryptoHandler.HandleData(data, encLevel); err != nil {
|
||||
return err
|
||||
}
|
||||
m.cryptoHandler.HandleData(data, encLevel)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
@ -54,12 +52,4 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake)
|
||||
Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed())
|
||||
})
|
||||
|
||||
It("returns the error if handling crypto data fails", func() {
|
||||
testErr := errors.New("test error")
|
||||
f := &wire.CryptoFrame{Data: []byte("foobar")}
|
||||
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake).Return(testErr)
|
||||
err := csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -56,8 +56,19 @@ type cryptoSetupTLS struct {
|
|||
writeEncLevel protocol.EncryptionLevel
|
||||
|
||||
handleParamsCallback func(*TransportParameters)
|
||||
|
||||
// There are two ways that an error can occur during the handshake:
|
||||
// 1. as a return value from qtls.Handshake()
|
||||
// 2. when new data is passed to the crypto setup via HandleData()
|
||||
// handshakeErrChan is closed when qtls.Handshake() errors
|
||||
handshakeErrChan chan struct{}
|
||||
// HandleData() sends errors on the messageErrChan
|
||||
messageErrChan chan error
|
||||
// handshakeEvent signals a change of encryption level to the session
|
||||
handshakeEvent chan<- struct{}
|
||||
// handshakeComplete is closed when the handshake completes
|
||||
handshakeComplete chan<- struct{}
|
||||
// transport parameters are sent on the receivedTransportParams, as soon as they are received
|
||||
receivedTransportParams <-chan TransportParameters
|
||||
|
||||
clientHelloWritten bool
|
||||
|
@ -190,6 +201,8 @@ func newCryptoSetupTLS(
|
|||
handshakeComplete: handshakeComplete,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
handshakeErrChan: make(chan struct{}),
|
||||
messageErrChan: make(chan error, 1),
|
||||
clientHelloWrittenChan: make(chan struct{}),
|
||||
messageChan: make(chan []byte, 100),
|
||||
receivedReadKey: make(chan struct{}),
|
||||
|
@ -229,16 +242,37 @@ func (h *cryptoSetupTLS) RunHandshake() error {
|
|||
case protocol.PerspectiveServer:
|
||||
conn = qtls.Server(nil, h.tlsConf)
|
||||
}
|
||||
// Handle errors that might occur when HandleData() is called.
|
||||
handshakeErrChan := make(chan error, 1)
|
||||
handshakeComplete := make(chan struct{})
|
||||
go func() {
|
||||
if err := conn.Handshake(); err != nil {
|
||||
close(h.receivedReadKey)
|
||||
close(h.receivedWriteKey)
|
||||
return err
|
||||
handshakeErrChan <- err
|
||||
return
|
||||
}
|
||||
close(handshakeComplete)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-handshakeComplete: // return when the handshake is done
|
||||
close(h.handshakeComplete)
|
||||
return nil
|
||||
case err := <-handshakeErrChan:
|
||||
// if handleMessageFor{server,client} are waiting for some qtls action, make them return
|
||||
close(h.handshakeErrChan)
|
||||
return err
|
||||
case err := <-h.messageErrChan:
|
||||
// If the handshake errored because of an error that occurred during HandleData(),
|
||||
// that error message will be more useful than the error message generated by Handshake().
|
||||
// Close the message chan that qtls is receiving messages from.
|
||||
// This will make qtls.Handshake() return.
|
||||
// Thereby the go routine running qtls.Handshake() will return.
|
||||
close(h.messageChan)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) error {
|
||||
func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) {
|
||||
var buf *bytes.Buffer
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
|
@ -246,7 +280,8 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev
|
|||
case protocol.EncryptionHandshake:
|
||||
buf = &h.handshakeReadBuf
|
||||
default:
|
||||
return fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel)
|
||||
h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel)
|
||||
return
|
||||
}
|
||||
buf.Write(data)
|
||||
for buf.Len() >= 4 {
|
||||
|
@ -254,15 +289,14 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev
|
|||
// read the TLS message length
|
||||
length := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
||||
if buf.Len() < 4+length { // message not yet complete
|
||||
return nil
|
||||
return
|
||||
}
|
||||
msg := make([]byte, length+4)
|
||||
buf.Read(msg)
|
||||
if err := h.handleMessage(msg, encLevel); err != nil {
|
||||
return err
|
||||
h.messageErrChan <- err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMessage handles a TLS handshake message.
|
||||
|
@ -276,12 +310,13 @@ func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.Encryption
|
|||
h.messageChan <- data
|
||||
switch h.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
return h.handleMessageForClient(msgType)
|
||||
h.handleMessageForClient(msgType)
|
||||
case protocol.PerspectiveServer:
|
||||
return h.handleMessageForServer(msgType)
|
||||
h.handleMessageForServer(msgType)
|
||||
default:
|
||||
panic("")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||
|
@ -300,65 +335,114 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot
|
|||
return fmt.Errorf("unexpected handshake message: %d", msgType)
|
||||
}
|
||||
if encLevel != expected {
|
||||
return fmt.Errorf("expected handshake message %d to have encryption level %s, has %s", msgType, expected, encLevel)
|
||||
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) error {
|
||||
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) {
|
||||
switch msgType {
|
||||
case typeClientHello:
|
||||
params := <-h.receivedTransportParams
|
||||
select {
|
||||
case params := <-h.receivedTransportParams:
|
||||
h.handleParamsCallback(¶ms)
|
||||
<-h.receivedWriteKey // get the handshake write key
|
||||
<-h.receivedWriteKey // get the 1-RTT write key
|
||||
<-h.receivedReadKey // get the handshake read key
|
||||
h.handshakeEvent <- struct{}{}
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
// get the handshake write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
// get the 1-RTT write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
// get the handshake read key
|
||||
// TODO: check that the initial stream doesn't have any more data
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
h.handshakeEvent <- struct{}{}
|
||||
case typeCertificate, typeCertificateVerify:
|
||||
// nothing to do
|
||||
case typeFinished:
|
||||
<-h.receivedReadKey // get the 1-RTT read key
|
||||
h.handshakeEvent <- struct{}{}
|
||||
// get the 1-RTT read key
|
||||
// TODO: check that the handshake stream doesn't have any more data
|
||||
default:
|
||||
// TODO: think about what to do with unknown message types
|
||||
return fmt.Errorf("Received unknown handshake message: %d", msgType)
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
h.handshakeEvent <- struct{}{}
|
||||
default:
|
||||
panic("unexpected handshake message")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) error {
|
||||
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) {
|
||||
switch msgType {
|
||||
case typeServerHello:
|
||||
<-h.receivedReadKey // get the handshake read key
|
||||
// get the handshake read key
|
||||
// TODO: check that the initial stream doesn't have any more data
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
h.handshakeEvent <- struct{}{}
|
||||
case typeEncryptedExtensions:
|
||||
params := <-h.receivedTransportParams
|
||||
select {
|
||||
case params := <-h.receivedTransportParams:
|
||||
h.handleParamsCallback(¶ms)
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
|
||||
// nothing to do
|
||||
case typeFinished:
|
||||
<-h.receivedWriteKey // get the handshake write key
|
||||
// get the handshake write key
|
||||
// TODO: check that the initial stream doesn't have any more data
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
// While the order of these two is not defined by the TLS spec,
|
||||
// we have to do it on the same order as our TLS library does it.
|
||||
<-h.receivedWriteKey // get the handshake write key
|
||||
<-h.receivedReadKey // get the 1-RTT read key
|
||||
// get the handshake write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
// get the 1-RTT read key
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return
|
||||
}
|
||||
// TODO: check that the handshake stream doesn't have any more data
|
||||
h.handshakeEvent <- struct{}{}
|
||||
default:
|
||||
// TODO: think about what to do with unknown extensions
|
||||
return fmt.Errorf("Received unknown handshake message: %d", msgType)
|
||||
panic("unexpected handshake message: ")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadHandshakeMessage is called by TLS.
|
||||
// It blocks until a new handshake message is available.
|
||||
func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) {
|
||||
// TODO: add some error handling here (when the session is closed)
|
||||
return <-h.messageChan, nil
|
||||
msg, ok := <-h.messageChan
|
||||
if !ok {
|
||||
return nil, errors.New("error while handling the handshake message")
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
|
|
|
@ -48,6 +48,77 @@ func (s *stream) Write(b []byte) (int, error) {
|
|||
}
|
||||
|
||||
var _ = Describe("Crypto Setup TLS", func() {
|
||||
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
||||
chunkChan := make(chan chunk, 100)
|
||||
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
||||
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
|
||||
return chunkChan, initialStream, handshakeStream
|
||||
}
|
||||
|
||||
It("returns Handshake() when an error occurs", func() {
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
testdata.GetTLSConfig(),
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := server.RunHandshake()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("received unexpected handshake message"))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
||||
server.HandleData(fakeCH, protocol.EncryptionInitial)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns Handshake() when handling a message fails", func() {
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
testdata.GetTLSConfig(),
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := server.RunHandshake()
|
||||
Expect(err).To(MatchError("expected handshake message ClientHello to have encryption level Initial, has Handshake"))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
||||
server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("doing the handshake", func() {
|
||||
generateCert := func() tls.Certificate {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -67,13 +138,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
}
|
||||
}
|
||||
|
||||
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
||||
chunkChan := make(chan chunk, 100)
|
||||
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
||||
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
|
||||
return chunkChan, initialStream, handshakeStream
|
||||
}
|
||||
|
||||
handshake := func(
|
||||
client CryptoSetupTLS,
|
||||
cChunkChan <-chan chunk,
|
||||
|
@ -86,11 +150,9 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
for {
|
||||
select {
|
||||
case c := <-cChunkChan:
|
||||
err := server.HandleData(c.data, c.encLevel)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server.HandleData(c.data, c.encLevel)
|
||||
case c := <-sChunkChan:
|
||||
err := client.HandleData(c.data, c.encLevel)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client.HandleData(c.data, c.encLevel)
|
||||
case <-done: // handshake complete
|
||||
}
|
||||
}
|
||||
|
@ -202,7 +264,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
Expect(len(ch.data) - 4).To(Equal(length))
|
||||
|
||||
// make the go routine return
|
||||
client.HandleData([]byte{1, 0, 0, 1, 0}, protocol.EncryptionInitial)
|
||||
client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -263,3 +325,4 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -44,7 +44,7 @@ type CryptoSetup interface {
|
|||
type CryptoSetupTLS interface {
|
||||
baseCryptoSetup
|
||||
|
||||
HandleData([]byte, protocol.EncryptionLevel) error
|
||||
HandleData([]byte, protocol.EncryptionLevel)
|
||||
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||
|
|
|
@ -35,10 +35,8 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
|
|||
}
|
||||
|
||||
// HandleData mocks base method
|
||||
func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) error {
|
||||
ret := m.ctrl.Call(m, "HandleData", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) {
|
||||
m.ctrl.Call(m, "HandleData", arg0, arg1)
|
||||
}
|
||||
|
||||
// HandleData indicates an expected call of HandleData
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue