diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index ddc5f633..e6c5e55d 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -72,11 +72,9 @@ type cryptoSetup struct { messageChan chan []byte - paramsChan <-chan []byte - handleParamsCallback func([]byte) + paramsChan <-chan []byte - dropKeyCallback func(protocol.EncryptionLevel) - closeCallback func(error) + runner handshakeRunner alertChan chan uint8 // HandleData() sends errors on the messageErrChan @@ -131,9 +129,7 @@ func NewCryptoSetupClient( connID protocol.ConnectionID, remoteAddr net.Addr, tp *TransportParameters, - handleParams func([]byte), - dropKeys func(protocol.EncryptionLevel), - close func(error), + runner handshakeRunner, tlsConf *tls.Config, logger utils.Logger, ) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { @@ -143,9 +139,7 @@ func NewCryptoSetupClient( oneRTTStream, connID, tp, - handleParams, - dropKeys, - close, + runner, tlsConf, logger, protocol.PerspectiveClient, @@ -165,9 +159,7 @@ func NewCryptoSetupServer( connID protocol.ConnectionID, remoteAddr net.Addr, tp *TransportParameters, - handleParams func([]byte), - dropKeys func(protocol.EncryptionLevel), - close func(error), + runner handshakeRunner, tlsConf *tls.Config, logger utils.Logger, ) (CryptoSetup, error) { @@ -177,9 +169,7 @@ func NewCryptoSetupServer( oneRTTStream, connID, tp, - handleParams, - dropKeys, - close, + runner, tlsConf, logger, protocol.PerspectiveServer, @@ -197,9 +187,7 @@ func newCryptoSetup( oneRTTStream io.Writer, connID protocol.ConnectionID, tp *TransportParameters, - handleParams func([]byte), - dropKeys func(protocol.EncryptionLevel), - close func(error), + runner handshakeRunner, tlsConf *tls.Config, logger utils.Logger, perspective protocol.Perspective, @@ -217,9 +205,7 @@ func newCryptoSetup( oneRTTStream: oneRTTStream, readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, - handleParamsCallback: handleParams, - dropKeyCallback: dropKeys, - closeCallback: close, + runner: runner, paramsChan: extHandler.TransportParameters(), logger: logger, perspective: perspective, @@ -254,7 +240,7 @@ func (h *cryptoSetup) Received1RTTAck() { if h.initialOpener != nil { h.initialOpener = nil h.initialSealer = nil - h.dropKeyCallback(protocol.EncryptionInitial) + h.runner.DropKeys(protocol.EncryptionInitial) h.logger.Debugf("Dropping Initial keys.") } // drop handshake keys @@ -262,7 +248,7 @@ func (h *cryptoSetup) Received1RTTAck() { h.handshakeOpener = nil h.handshakeSealer = nil h.logger.Debugf("Dropping Handshake keys.") - h.dropKeyCallback(protocol.EncryptionHandshake) + h.runner.DropKeys(protocol.EncryptionHandshake) } } @@ -286,14 +272,14 @@ func (h *cryptoSetup) RunHandshake() { close(h.messageChan) case alert := <-h.alertChan: handshakeErr := <-handshakeErrChan - h.closeCallback(qerr.CryptoError(alert, handshakeErr.Error())) + h.runner.OnError(qerr.CryptoError(alert, handshakeErr.Error())) case err := <-h.messageErrChan: // If the handshake errored because of an error that occurred during HandleData(), // that error message will be more useful than the error message generated by Handshake(). // Close the message chan that qtls is receiving messages from. // This will make qtls.Handshake() return. close(h.messageChan) - h.closeCallback(err) + h.runner.OnError(err) } } @@ -358,7 +344,7 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool { h.logger.Debugf("Sending HelloRetryRequest") return false case data := <-h.paramsChan: - h.handleParamsCallback(data) + h.runner.OnReceivedParams(data) case <-h.handshakeDone: return false } @@ -423,7 +409,7 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool { case typeEncryptedExtensions: select { case data := <-h.paramsChan: - h.handleParamsCallback(data) + h.runner.OnReceivedParams(data) case <-h.handshakeDone: return false } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 7282f7d2..8c99469a 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -12,6 +12,7 @@ import ( "math/big" "time" + gomock "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/testdata" @@ -87,9 +88,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, &TransportParameters{}, - func([]byte) {}, - func(protocol.EncryptionLevel) {}, - func(err error) { Fail("error callback called") }, + NewMockHandshakeRunner(mockCtrl), tlsConf, utils.DefaultLogger.WithPrefix("server"), ) @@ -110,6 +109,8 @@ var _ = Describe("Crypto Setup TLS", func() { It("returns Handshake() when an error occurs", func() { sErrChan := make(chan error, 1) + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) _, sInitialStream, sHandshakeStream := initStreams() server, err := NewCryptoSetupServer( sInitialStream, @@ -118,9 +119,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, &TransportParameters{}, - func([]byte) {}, - func(protocol.EncryptionLevel) {}, - func(e error) { sErrChan <- e }, + runner, testdata.GetTLSConfig(), utils.DefaultLogger.WithPrefix("server"), ) @@ -148,6 +147,8 @@ var _ = Describe("Crypto Setup TLS", func() { It("returns Handshake() when a message is received at the wrong encryption level", func() { sErrChan := make(chan error, 1) _, sInitialStream, sHandshakeStream := initStreams() + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, @@ -155,9 +156,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, &TransportParameters{}, - func([]byte) {}, - func(protocol.EncryptionLevel) {}, - func(e error) { sErrChan <- e }, + runner, testdata.GetTLSConfig(), utils.DefaultLogger.WithPrefix("server"), ) @@ -185,6 +184,8 @@ var _ = Describe("Crypto Setup TLS", func() { It("returns Handshake() when handling a message fails", func() { sErrChan := make(chan error, 1) _, sInitialStream, sHandshakeStream := initStreams() + runner := NewMockHandshakeRunner(mockCtrl) + runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, @@ -192,9 +193,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, &TransportParameters{}, - func([]byte) {}, - func(protocol.EncryptionLevel) {}, - func(e error) { sErrChan <- e }, + runner, testdata.GetTLSConfig(), utils.DefaultLogger.WithPrefix("server"), ) @@ -220,7 +219,6 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("returns Handshake() when it is closed", func() { - sErrChan := make(chan error, 1) _, sInitialStream, sHandshakeStream := initStreams() server, err := NewCryptoSetupServer( sInitialStream, @@ -229,9 +227,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, &TransportParameters{}, - func([]byte) {}, - func(protocol.EncryptionLevel) {}, - func(e error) { sErrChan <- e }, + NewMockHandshakeRunner(mockCtrl), testdata.GetTLSConfig(), utils.DefaultLogger.WithPrefix("server"), ) @@ -241,7 +237,6 @@ var _ = Describe("Crypto Setup TLS", func() { go func() { defer GinkgoRecover() server.RunHandshake() - Consistently(sErrChan).ShouldNot(Receive()) close(done) }() Expect(server.Close()).To(Succeed()) @@ -298,6 +293,9 @@ var _ = Describe("Crypto Setup TLS", func() { handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) { cChunkChan, cInitialStream, cHandshakeStream := initStreams() cErrChan := make(chan error, 1) + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, @@ -305,9 +303,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, &TransportParameters{}, - func([]byte) {}, - func(protocol.EncryptionLevel) {}, - func(e error) { cErrChan <- e }, + cRunner, clientConf, utils.DefaultLogger.WithPrefix("client"), ) @@ -315,6 +311,9 @@ var _ = Describe("Crypto Setup TLS", func() { sChunkChan, sInitialStream, sHandshakeStream := initStreams() sErrChan := make(chan error, 1) + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) var token [16]byte server, err := NewCryptoSetupServer( sInitialStream, @@ -323,9 +322,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, &TransportParameters{StatelessResetToken: &token}, - func([]byte) {}, - func(protocol.EncryptionLevel) {}, - func(e error) { sErrChan <- e }, + sRunner, serverConf, utils.DefaultLogger.WithPrefix("server"), ) @@ -369,6 +366,7 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("signals when it has written the ClientHello", func() { + runner := NewMockHandshakeRunner(mockCtrl) cChunkChan, cInitialStream, cHandshakeStream := initStreams() client, chChan, err := NewCryptoSetupClient( cInitialStream, @@ -377,9 +375,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, &TransportParameters{}, - func([]byte) {}, - func(protocol.EncryptionLevel) {}, - func(error) {}, + runner, &tls.Config{InsecureSkipVerify: true}, utils.DefaultLogger.WithPrefix("client"), ) @@ -401,6 +397,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(len(ch.data) - 4).To(Equal(length)) // make the go routine return + runner.EXPECT().OnError(gomock.Any()) client.HandleMessage([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial) Eventually(done).Should(BeClosed()) }) @@ -409,6 +406,8 @@ var _ = Describe("Crypto Setup TLS", func() { var cTransportParametersRcvd, sTransportParametersRcvd []byte cChunkChan, cInitialStream, cHandshakeStream := initStreams() cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b }) client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, @@ -416,9 +415,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, cTransportParameters, - func(p []byte) { sTransportParametersRcvd = p }, - func(protocol.EncryptionLevel) {}, - func(error) { Fail("error callback called") }, + cRunner, clientConf, utils.DefaultLogger.WithPrefix("client"), ) @@ -426,6 +423,8 @@ var _ = Describe("Crypto Setup TLS", func() { sChunkChan, sInitialStream, sHandshakeStream := initStreams() var token [16]byte + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { cTransportParametersRcvd = b }) sTransportParameters := &TransportParameters{ IdleTimeout: 0x1337 * time.Second, StatelessResetToken: &token, @@ -437,9 +436,7 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, nil, sTransportParameters, - func(p []byte) { cTransportParametersRcvd = p }, - func(protocol.EncryptionLevel) {}, - func(error) { Fail("error callback called") }, + sRunner, testdata.GetTLSConfig(), utils.DefaultLogger.WithPrefix("server"), ) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 948a34a8..f34421ab 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -28,6 +28,12 @@ type tlsExtensionHandler interface { TransportParameters() <-chan []byte } +type handshakeRunner interface { + OnReceivedParams([]byte) + OnError(error) + DropKeys(protocol.EncryptionLevel) +} + // CryptoSetup handles the handshake and protecting / unprotecting packets type CryptoSetup interface { RunHandshake() diff --git a/internal/handshake/mock_handshake_runner_test.go b/internal/handshake/mock_handshake_runner_test.go new file mode 100644 index 00000000..37fb494e --- /dev/null +++ b/internal/handshake/mock_handshake_runner_test.go @@ -0,0 +1,71 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: HandshakeRunner) + +// Package handshake is a generated GoMock package. +package handshake + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockHandshakeRunner is a mock of HandshakeRunner interface +type MockHandshakeRunner struct { + ctrl *gomock.Controller + recorder *MockHandshakeRunnerMockRecorder +} + +// MockHandshakeRunnerMockRecorder is the mock recorder for MockHandshakeRunner +type MockHandshakeRunnerMockRecorder struct { + mock *MockHandshakeRunner +} + +// NewMockHandshakeRunner creates a new mock instance +func NewMockHandshakeRunner(ctrl *gomock.Controller) *MockHandshakeRunner { + mock := &MockHandshakeRunner{ctrl: ctrl} + mock.recorder = &MockHandshakeRunnerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockHandshakeRunner) EXPECT() *MockHandshakeRunnerMockRecorder { + return m.recorder +} + +// DropKeys mocks base method +func (m *MockHandshakeRunner) DropKeys(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropKeys", arg0) +} + +// DropKeys indicates an expected call of DropKeys +func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0) +} + +// OnError mocks base method +func (m *MockHandshakeRunner) OnError(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnError", arg0) +} + +// OnError indicates an expected call of OnError +func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0) +} + +// OnReceivedParams mocks base method +func (m *MockHandshakeRunner) OnReceivedParams(arg0 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnReceivedParams", arg0) +} + +// OnReceivedParams indicates an expected call of OnReceivedParams +func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0) +} diff --git a/internal/handshake/mockgen.go b/internal/handshake/mockgen.go new file mode 100644 index 00000000..1c3b24b0 --- /dev/null +++ b/internal/handshake/mockgen.go @@ -0,0 +1,3 @@ +package handshake + +//go:generate sh -c "../mockgen_internal.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner" diff --git a/session.go b/session.go index cb37beef..987a0423 100644 --- a/session.go +++ b/session.go @@ -71,6 +71,16 @@ func (p *receivedPacket) Clone() *receivedPacket { } } +type handshakeRunner struct { + onReceivedParams func([]byte) + onError func(error) + dropKeys func(protocol.EncryptionLevel) +} + +func (r *handshakeRunner) OnReceivedParams(b []byte) { r.onReceivedParams(b) } +func (r *handshakeRunner) OnError(e error) { r.onError(e) } +func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) } + type closeError struct { err error remote bool @@ -198,9 +208,11 @@ var newSession = func( clientDestConnID, conn.RemoteAddr(), params, - s.processTransportParameters, - s.dropEncryptionLevel, - s.closeLocal, + &handshakeRunner{ + onReceivedParams: s.processTransportParameters, + onError: s.closeLocal, + dropKeys: s.dropEncryptionLevel, + }, tlsConf, logger, ) @@ -268,9 +280,11 @@ var newClientSession = func( s.destConnID, conn.RemoteAddr(), params, - s.processTransportParameters, - s.dropEncryptionLevel, - s.closeLocal, + &handshakeRunner{ + onReceivedParams: s.processTransportParameters, + onError: s.closeLocal, + dropKeys: s.dropEncryptionLevel, + }, tlsConf, logger, )