mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
use a callback to pass handshake errors to the session
This commit is contained in:
parent
ed69ae2ce0
commit
743868159f
6 changed files with 66 additions and 57 deletions
|
@ -76,6 +76,7 @@ type cryptoSetup struct {
|
|||
handleParamsCallback func([]byte)
|
||||
|
||||
dropKeyCallback func(protocol.EncryptionLevel)
|
||||
closeCallback func(error)
|
||||
|
||||
alertChan chan uint8
|
||||
// HandleData() sends errors on the messageErrChan
|
||||
|
@ -132,6 +133,7 @@ func NewCryptoSetupClient(
|
|||
tp *TransportParameters,
|
||||
handleParams func([]byte),
|
||||
dropKeys func(protocol.EncryptionLevel),
|
||||
close func(error),
|
||||
tlsConf *tls.Config,
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||
|
@ -143,6 +145,7 @@ func NewCryptoSetupClient(
|
|||
tp,
|
||||
handleParams,
|
||||
dropKeys,
|
||||
close,
|
||||
tlsConf,
|
||||
logger,
|
||||
protocol.PerspectiveClient,
|
||||
|
@ -164,6 +167,7 @@ func NewCryptoSetupServer(
|
|||
tp *TransportParameters,
|
||||
handleParams func([]byte),
|
||||
dropKeys func(protocol.EncryptionLevel),
|
||||
close func(error),
|
||||
tlsConf *tls.Config,
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, error) {
|
||||
|
@ -175,6 +179,7 @@ func NewCryptoSetupServer(
|
|||
tp,
|
||||
handleParams,
|
||||
dropKeys,
|
||||
close,
|
||||
tlsConf,
|
||||
logger,
|
||||
protocol.PerspectiveServer,
|
||||
|
@ -194,6 +199,7 @@ func newCryptoSetup(
|
|||
tp *TransportParameters,
|
||||
handleParams func([]byte),
|
||||
dropKeys func(protocol.EncryptionLevel),
|
||||
close func(error),
|
||||
tlsConf *tls.Config,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
|
@ -213,6 +219,7 @@ func newCryptoSetup(
|
|||
writeEncLevel: protocol.EncryptionInitial,
|
||||
handleParamsCallback: handleParams,
|
||||
dropKeyCallback: dropKeys,
|
||||
closeCallback: close,
|
||||
paramsChan: extHandler.TransportParameters(),
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
|
@ -259,7 +266,7 @@ func (h *cryptoSetup) Received1RTTAck() {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) RunHandshake() error {
|
||||
func (h *cryptoSetup) RunHandshake() {
|
||||
// Handle errors that might occur when HandleData() is called.
|
||||
handshakeComplete := make(chan struct{})
|
||||
handshakeErrChan := make(chan error, 1)
|
||||
|
@ -273,23 +280,20 @@ func (h *cryptoSetup) RunHandshake() error {
|
|||
}()
|
||||
|
||||
select {
|
||||
case <-h.closeChan:
|
||||
close(h.messageChan)
|
||||
// wait until the Handshake() go routine has returned
|
||||
return errors.New("Handshake aborted")
|
||||
case <-handshakeComplete: // return when the handshake is done
|
||||
return nil
|
||||
case <-h.closeChan:
|
||||
// wait until the Handshake() go routine has returned
|
||||
close(h.messageChan)
|
||||
case alert := <-h.alertChan:
|
||||
err := <-handshakeErrChan
|
||||
return qerr.CryptoError(alert, err.Error())
|
||||
handshakeErr := <-handshakeErrChan
|
||||
h.closeCallback(qerr.CryptoError(alert, handshakeErr.Error()))
|
||||
case err := <-h.messageErrChan:
|
||||
// If the handshake errored because of an error that occurred during HandleData(),
|
||||
// that error message will be more useful than the error message generated by Handshake().
|
||||
// Close the message chan that qtls is receiving messages from.
|
||||
// This will make qtls.Handshake() return.
|
||||
// Thereby the go routine running qtls.Handshake() will return.
|
||||
close(h.messageChan)
|
||||
return err
|
||||
h.closeCallback(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -89,6 +89,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
func([]byte) {},
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(err error) { Fail("error callback called") },
|
||||
tlsConf,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
|
@ -108,6 +109,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("returns Handshake() when an error occurs", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
|
@ -118,6 +120,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
func([]byte) {},
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(e error) { sErrChan <- e },
|
||||
testdata.GetTLSConfig(),
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
|
@ -126,8 +129,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := server.RunHandshake()
|
||||
Expect(err).To(MatchError("CRYPTO_ERROR: local error: tls: unexpected message"))
|
||||
server.RunHandshake()
|
||||
Expect(sErrChan).To(Receive(MatchError("CRYPTO_ERROR: local error: tls: unexpected message")))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -143,6 +146,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("returns Handshake() when a message is received at the wrong encryption level", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
|
@ -153,6 +157,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
func([]byte) {},
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(e error) { sErrChan <- e },
|
||||
testdata.GetTLSConfig(),
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
|
@ -161,7 +166,9 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := server.RunHandshake()
|
||||
server.RunHandshake()
|
||||
var err error
|
||||
Expect(sErrChan).To(Receive(&err))
|
||||
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
||||
qerr := err.(*qerr.QuicError)
|
||||
Expect(qerr.IsCryptoError()).To(BeTrue())
|
||||
|
@ -176,6 +183,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("returns Handshake() when handling a message fails", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
|
@ -186,6 +194,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
func([]byte) {},
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(e error) { sErrChan <- e },
|
||||
testdata.GetTLSConfig(),
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
|
@ -194,7 +203,9 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := server.RunHandshake()
|
||||
server.RunHandshake()
|
||||
var err error
|
||||
Expect(sErrChan).To(Receive(&err))
|
||||
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
||||
qerr := err.(*qerr.QuicError)
|
||||
Expect(qerr.IsCryptoError()).To(BeTrue())
|
||||
|
@ -209,6 +220,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("returns Handshake() when it is closed", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
|
@ -219,6 +231,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
func([]byte) {},
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(e error) { sErrChan <- e },
|
||||
testdata.GetTLSConfig(),
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
|
@ -227,8 +240,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := server.RunHandshake()
|
||||
Expect(err).To(MatchError("Handshake aborted"))
|
||||
server.RunHandshake()
|
||||
Consistently(sErrChan).ShouldNot(Receive())
|
||||
close(done)
|
||||
}()
|
||||
Expect(server.Close()).To(Succeed())
|
||||
|
@ -255,13 +268,9 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
}
|
||||
}
|
||||
|
||||
handshake := func(
|
||||
client CryptoSetup,
|
||||
cChunkChan <-chan chunk,
|
||||
server CryptoSetup,
|
||||
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
|
||||
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
|
||||
server CryptoSetup, sChunkChan <-chan chunk) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
|
@ -276,20 +285,19 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
}
|
||||
}()
|
||||
|
||||
serverErrChan := make(chan error)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serverErrChan <- server.RunHandshake()
|
||||
server.RunHandshake()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
clientErr := client.RunHandshake()
|
||||
var serverErr error
|
||||
Eventually(serverErrChan).Should(Receive(&serverErr))
|
||||
return clientErr, serverErr
|
||||
client.RunHandshake()
|
||||
Eventually(done).Should(BeClosed())
|
||||
}
|
||||
|
||||
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cErrChan := make(chan error, 1)
|
||||
client, _, err := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
|
@ -299,12 +307,14 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
func([]byte) {},
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(e error) { cErrChan <- e },
|
||||
clientConf,
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sErrChan := make(chan error, 1)
|
||||
var token [16]byte
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
|
@ -315,12 +325,23 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{StatelessResetToken: &token},
|
||||
func([]byte) {},
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(e error) { sErrChan <- e },
|
||||
serverConf,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return handshake(client, cChunkChan, server, sChunkChan)
|
||||
handshake(client, cChunkChan, server, sChunkChan)
|
||||
var cErr, sErr error
|
||||
select {
|
||||
case sErr = <-sErrChan:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case cErr = <-cErrChan:
|
||||
default:
|
||||
}
|
||||
return cErr, sErr
|
||||
}
|
||||
|
||||
It("handshakes", func() {
|
||||
|
@ -358,6 +379,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
&TransportParameters{},
|
||||
func([]byte) {},
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(error) {},
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
|
@ -396,6 +418,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
cTransportParameters,
|
||||
func(p []byte) { sTransportParametersRcvd = p },
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(error) { Fail("error callback called") },
|
||||
clientConf,
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
|
@ -416,6 +439,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
sTransportParameters,
|
||||
func(p []byte) { cTransportParametersRcvd = p },
|
||||
func(protocol.EncryptionLevel) {},
|
||||
func(error) { Fail("error callback called") },
|
||||
testdata.GetTLSConfig(),
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
|
@ -424,9 +448,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan)
|
||||
Expect(clientErr).ToNot(HaveOccurred())
|
||||
Expect(serverErr).ToNot(HaveOccurred())
|
||||
handshake(client, cChunkChan, server, sChunkChan)
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
|
|
@ -30,7 +30,7 @@ type tlsExtensionHandler interface {
|
|||
|
||||
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
||||
type CryptoSetup interface {
|
||||
RunHandshake() error
|
||||
RunHandshake()
|
||||
io.Closer
|
||||
ChangeConnectionID(protocol.ConnectionID) error
|
||||
|
||||
|
|
|
@ -150,11 +150,9 @@ func (mr *MockCryptoSetupMockRecorder) Received1RTTAck() *gomock.Call {
|
|||
}
|
||||
|
||||
// RunHandshake mocks base method
|
||||
func (m *MockCryptoSetup) RunHandshake() error {
|
||||
func (m *MockCryptoSetup) RunHandshake() {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RunHandshake")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
m.ctrl.Call(m, "RunHandshake")
|
||||
}
|
||||
|
||||
// RunHandshake indicates an expected call of RunHandshake
|
||||
|
|
10
session.go
10
session.go
|
@ -47,7 +47,7 @@ type streamManager interface {
|
|||
}
|
||||
|
||||
type cryptoStreamHandler interface {
|
||||
RunHandshake() error
|
||||
RunHandshake()
|
||||
ChangeConnectionID(protocol.ConnectionID) error
|
||||
Received1RTTAck()
|
||||
io.Closer
|
||||
|
@ -200,6 +200,7 @@ var newSession = func(
|
|||
params,
|
||||
s.processTransportParameters,
|
||||
s.dropEncryptionLevel,
|
||||
s.closeLocal,
|
||||
tlsConf,
|
||||
logger,
|
||||
)
|
||||
|
@ -269,6 +270,7 @@ var newClientSession = func(
|
|||
params,
|
||||
s.processTransportParameters,
|
||||
s.dropEncryptionLevel,
|
||||
s.closeLocal,
|
||||
tlsConf,
|
||||
logger,
|
||||
)
|
||||
|
@ -338,10 +340,8 @@ func (s *session) run() error {
|
|||
defer s.ctxCancel()
|
||||
|
||||
go func() {
|
||||
if err := s.cryptoStreamHandler.RunHandshake(); err != nil {
|
||||
s.closeLocal(err)
|
||||
return
|
||||
}
|
||||
s.cryptoStreamHandler.RunHandshake()
|
||||
// If an error occurred during the handshake, the crypto setup will already have called the close callback.
|
||||
close(s.handshakeCompleteChan)
|
||||
}()
|
||||
if s.perspective == protocol.PerspectiveClient {
|
||||
|
|
|
@ -1121,21 +1121,6 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
})
|
||||
|
||||
It("closes when RunHandshake() errors", func() {
|
||||
testErr := errors.New("crypto setup error")
|
||||
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error()))
|
||||
sessionRunner.EXPECT().Retire(gomock.Any())
|
||||
cryptoSetup.EXPECT().Close()
|
||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Return(testErr)
|
||||
err := sess.run()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
}()
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("calls the onHandshakeComplete callback when the handshake completes", func() {
|
||||
packer.EXPECT().PackPacket().AnyTimes()
|
||||
go func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue