use a callback to pass handshake errors to the session

This commit is contained in:
Marten Seemann 2019-05-31 14:13:34 +08:00
parent ed69ae2ce0
commit 743868159f
6 changed files with 66 additions and 57 deletions

View file

@ -76,6 +76,7 @@ type cryptoSetup struct {
handleParamsCallback func([]byte)
dropKeyCallback func(protocol.EncryptionLevel)
closeCallback func(error)
alertChan chan uint8
// HandleData() sends errors on the messageErrChan
@ -132,6 +133,7 @@ func NewCryptoSetupClient(
tp *TransportParameters,
handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
close func(error),
tlsConf *tls.Config,
logger utils.Logger,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
@ -143,6 +145,7 @@ func NewCryptoSetupClient(
tp,
handleParams,
dropKeys,
close,
tlsConf,
logger,
protocol.PerspectiveClient,
@ -164,6 +167,7 @@ func NewCryptoSetupServer(
tp *TransportParameters,
handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
close func(error),
tlsConf *tls.Config,
logger utils.Logger,
) (CryptoSetup, error) {
@ -175,6 +179,7 @@ func NewCryptoSetupServer(
tp,
handleParams,
dropKeys,
close,
tlsConf,
logger,
protocol.PerspectiveServer,
@ -194,6 +199,7 @@ func newCryptoSetup(
tp *TransportParameters,
handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
close func(error),
tlsConf *tls.Config,
logger utils.Logger,
perspective protocol.Perspective,
@ -213,6 +219,7 @@ func newCryptoSetup(
writeEncLevel: protocol.EncryptionInitial,
handleParamsCallback: handleParams,
dropKeyCallback: dropKeys,
closeCallback: close,
paramsChan: extHandler.TransportParameters(),
logger: logger,
perspective: perspective,
@ -259,7 +266,7 @@ func (h *cryptoSetup) Received1RTTAck() {
}
}
func (h *cryptoSetup) RunHandshake() error {
func (h *cryptoSetup) RunHandshake() {
// Handle errors that might occur when HandleData() is called.
handshakeComplete := make(chan struct{})
handshakeErrChan := make(chan error, 1)
@ -273,23 +280,20 @@ func (h *cryptoSetup) RunHandshake() error {
}()
select {
case <-h.closeChan:
close(h.messageChan)
// wait until the Handshake() go routine has returned
return errors.New("Handshake aborted")
case <-handshakeComplete: // return when the handshake is done
return nil
case <-h.closeChan:
// wait until the Handshake() go routine has returned
close(h.messageChan)
case alert := <-h.alertChan:
err := <-handshakeErrChan
return qerr.CryptoError(alert, err.Error())
handshakeErr := <-handshakeErrChan
h.closeCallback(qerr.CryptoError(alert, handshakeErr.Error()))
case err := <-h.messageErrChan:
// If the handshake errored because of an error that occurred during HandleData(),
// that error message will be more useful than the error message generated by Handshake().
// Close the message chan that qtls is receiving messages from.
// This will make qtls.Handshake() return.
// Thereby the go routine running qtls.Handshake() will return.
close(h.messageChan)
return err
h.closeCallback(err)
}
}

View file

@ -89,6 +89,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(err error) { Fail("error callback called") },
tlsConf,
utils.DefaultLogger.WithPrefix("server"),
)
@ -108,6 +109,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("returns Handshake() when an error occurs", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupServer(
sInitialStream,
@ -118,6 +120,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@ -126,8 +129,8 @@ var _ = Describe("Crypto Setup TLS", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := server.RunHandshake()
Expect(err).To(MatchError("CRYPTO_ERROR: local error: tls: unexpected message"))
server.RunHandshake()
Expect(sErrChan).To(Receive(MatchError("CRYPTO_ERROR: local error: tls: unexpected message")))
close(done)
}()
@ -143,6 +146,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("returns Handshake() when a message is received at the wrong encryption level", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupServer(
sInitialStream,
@ -153,6 +157,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@ -161,7 +166,9 @@ var _ = Describe("Crypto Setup TLS", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := server.RunHandshake()
server.RunHandshake()
var err error
Expect(sErrChan).To(Receive(&err))
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
qerr := err.(*qerr.QuicError)
Expect(qerr.IsCryptoError()).To(BeTrue())
@ -176,6 +183,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("returns Handshake() when handling a message fails", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupServer(
sInitialStream,
@ -186,6 +194,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@ -194,7 +203,9 @@ var _ = Describe("Crypto Setup TLS", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := server.RunHandshake()
server.RunHandshake()
var err error
Expect(sErrChan).To(Receive(&err))
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
qerr := err.(*qerr.QuicError)
Expect(qerr.IsCryptoError()).To(BeTrue())
@ -209,6 +220,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("returns Handshake() when it is closed", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupServer(
sInitialStream,
@ -219,6 +231,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@ -227,8 +240,8 @@ var _ = Describe("Crypto Setup TLS", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := server.RunHandshake()
Expect(err).To(MatchError("Handshake aborted"))
server.RunHandshake()
Consistently(sErrChan).ShouldNot(Receive())
close(done)
}()
Expect(server.Close()).To(Succeed())
@ -255,13 +268,9 @@ var _ = Describe("Crypto Setup TLS", func() {
}
}
handshake := func(
client CryptoSetup,
cChunkChan <-chan chunk,
server CryptoSetup,
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
server CryptoSetup, sChunkChan <-chan chunk) {
done := make(chan struct{})
defer close(done)
go func() {
defer GinkgoRecover()
for {
@ -276,20 +285,19 @@ var _ = Describe("Crypto Setup TLS", func() {
}
}()
serverErrChan := make(chan error)
go func() {
defer GinkgoRecover()
serverErrChan <- server.RunHandshake()
server.RunHandshake()
close(done)
}()
clientErr := client.RunHandshake()
var serverErr error
Eventually(serverErrChan).Should(Receive(&serverErr))
return clientErr, serverErr
client.RunHandshake()
Eventually(done).Should(BeClosed())
}
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cErrChan := make(chan error, 1)
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
@ -299,12 +307,14 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { cErrChan <- e },
clientConf,
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1)
var token [16]byte
server, err := NewCryptoSetupServer(
sInitialStream,
@ -315,12 +325,23 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{StatelessResetToken: &token},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
serverConf,
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
return handshake(client, cChunkChan, server, sChunkChan)
handshake(client, cChunkChan, server, sChunkChan)
var cErr, sErr error
select {
case sErr = <-sErrChan:
default:
}
select {
case cErr = <-cErrChan:
default:
}
return cErr, sErr
}
It("handshakes", func() {
@ -358,6 +379,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(error) {},
&tls.Config{InsecureSkipVerify: true},
utils.DefaultLogger.WithPrefix("client"),
)
@ -396,6 +418,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cTransportParameters,
func(p []byte) { sTransportParametersRcvd = p },
func(protocol.EncryptionLevel) {},
func(error) { Fail("error callback called") },
clientConf,
utils.DefaultLogger.WithPrefix("client"),
)
@ -416,6 +439,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sTransportParameters,
func(p []byte) { cTransportParametersRcvd = p },
func(protocol.EncryptionLevel) {},
func(error) { Fail("error callback called") },
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@ -424,9 +448,7 @@ var _ = Describe("Crypto Setup TLS", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())

View file

@ -30,7 +30,7 @@ type tlsExtensionHandler interface {
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
RunHandshake() error
RunHandshake()
io.Closer
ChangeConnectionID(protocol.ConnectionID) error

View file

@ -150,11 +150,9 @@ func (mr *MockCryptoSetupMockRecorder) Received1RTTAck() *gomock.Call {
}
// RunHandshake mocks base method
func (m *MockCryptoSetup) RunHandshake() error {
func (m *MockCryptoSetup) RunHandshake() {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RunHandshake")
ret0, _ := ret[0].(error)
return ret0
m.ctrl.Call(m, "RunHandshake")
}
// RunHandshake indicates an expected call of RunHandshake

View file

@ -47,7 +47,7 @@ type streamManager interface {
}
type cryptoStreamHandler interface {
RunHandshake() error
RunHandshake()
ChangeConnectionID(protocol.ConnectionID) error
Received1RTTAck()
io.Closer
@ -200,6 +200,7 @@ var newSession = func(
params,
s.processTransportParameters,
s.dropEncryptionLevel,
s.closeLocal,
tlsConf,
logger,
)
@ -269,6 +270,7 @@ var newClientSession = func(
params,
s.processTransportParameters,
s.dropEncryptionLevel,
s.closeLocal,
tlsConf,
logger,
)
@ -338,10 +340,8 @@ func (s *session) run() error {
defer s.ctxCancel()
go func() {
if err := s.cryptoStreamHandler.RunHandshake(); err != nil {
s.closeLocal(err)
return
}
s.cryptoStreamHandler.RunHandshake()
// If an error occurred during the handshake, the crypto setup will already have called the close callback.
close(s.handshakeCompleteChan)
}()
if s.perspective == protocol.PerspectiveClient {

View file

@ -1121,21 +1121,6 @@ var _ = Describe("Session", func() {
})
})
It("closes when RunHandshake() errors", func() {
testErr := errors.New("crypto setup error")
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error()))
sessionRunner.EXPECT().Retire(gomock.Any())
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().Return(testErr)
err := sess.run()
Expect(err).To(MatchError(testErr))
}()
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("calls the onHandshakeComplete callback when the handshake completes", func() {
packer.EXPECT().PackPacket().AnyTimes()
go func() {