From a95b7c286859bbe5ee4ec6a5aae4928e3715397d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 1 Feb 2019 17:22:37 +0900 Subject: [PATCH] refactor how transport parameters are passed from the extension handler --- internal/handshake/crypto_setup.go | 51 +++++++++---------- internal/handshake/interface.go | 1 + .../handshake/tls_extension_handler_client.go | 16 +++--- .../tls_extension_handler_client_test.go | 13 ++--- .../handshake/tls_extension_handler_server.go | 15 +++--- .../tls_extension_handler_server_test.go | 13 ++--- 6 files changed, 51 insertions(+), 58 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 7d0c2a10..d45ff920 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -60,6 +60,8 @@ type cryptoSetup struct { readEncLevel protocol.EncryptionLevel writeEncLevel protocol.EncryptionLevel + extHandler tlsExtensionHandler + handleParamsCallback func(*TransportParameters) // There are two ways that an error can occur during the handshake: @@ -71,8 +73,6 @@ type cryptoSetup struct { messageErrChan chan error // handshakeDone is closed as soon as the go routine running qtls.Handshake() returns handshakeDone chan struct{} - // transport parameters are sent on the receivedTransportParams, as soon as they are received - receivedTransportParams <-chan TransportParameters // is closed when Close() is called closeChan chan struct{} @@ -117,7 +117,7 @@ func NewCryptoSetupClient( logger utils.Logger, perspective protocol.Perspective, ) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { - extHandler, receivedTransportParams := newExtensionHandlerClient( + extHandler := newExtensionHandlerClient( params, origConnID, initialVersion, @@ -130,7 +130,6 @@ func NewCryptoSetupClient( handshakeStream, connID, extHandler, - receivedTransportParams, handleParams, tlsConf, logger, @@ -156,7 +155,7 @@ func NewCryptoSetupServer( logger utils.Logger, perspective protocol.Perspective, ) (CryptoSetup, error) { - extHandler, receivedTransportParams := newExtensionHandlerServer( + extHandler := newExtensionHandlerServer( params, supportedVersions, currentVersion, @@ -167,7 +166,6 @@ func NewCryptoSetupServer( handshakeStream, connID, extHandler, - receivedTransportParams, handleParams, tlsConf, logger, @@ -185,7 +183,6 @@ func newCryptoSetup( handshakeStream io.Writer, connID protocol.ConnectionID, extHandler tlsExtensionHandler, - transportParamChan <-chan TransportParameters, handleParams func(*TransportParameters), tlsConf *tls.Config, logger utils.Logger, @@ -196,24 +193,24 @@ func newCryptoSetup( return nil, nil, err } cs := &cryptoSetup{ - initialStream: initialStream, - initialSealer: initialSealer, - initialOpener: initialOpener, - handshakeStream: handshakeStream, - readEncLevel: protocol.EncryptionInitial, - writeEncLevel: protocol.EncryptionInitial, - handleParamsCallback: handleParams, - receivedTransportParams: transportParamChan, - logger: logger, - perspective: perspective, - handshakeDone: make(chan struct{}), - handshakeErrChan: make(chan struct{}), - messageErrChan: make(chan error, 1), - clientHelloWrittenChan: make(chan struct{}), - messageChan: make(chan []byte, 100), - receivedReadKey: make(chan struct{}), - receivedWriteKey: make(chan struct{}), - closeChan: make(chan struct{}), + initialStream: initialStream, + initialSealer: initialSealer, + initialOpener: initialOpener, + handshakeStream: handshakeStream, + readEncLevel: protocol.EncryptionInitial, + writeEncLevel: protocol.EncryptionInitial, + handleParamsCallback: handleParams, + extHandler: extHandler, + logger: logger, + perspective: perspective, + handshakeDone: make(chan struct{}), + handshakeErrChan: make(chan struct{}), + messageErrChan: make(chan error, 1), + clientHelloWrittenChan: make(chan struct{}), + messageChan: make(chan []byte, 100), + receivedReadKey: make(chan struct{}), + receivedWriteKey: make(chan struct{}), + closeChan: make(chan struct{}), } qtlsConf := tlsConfigToQtlsConfig(tlsConf) qtlsConf.AlternativeRecordLayer = cs @@ -312,7 +309,7 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool { switch msgType { case typeClientHello: select { - case params := <-h.receivedTransportParams: + case params := <-h.extHandler.TransportParameters(): h.handleParamsCallback(¶ms) case <-h.handshakeErrChan: return false @@ -370,7 +367,7 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool { return true case typeEncryptedExtensions: select { - case params := <-h.receivedTransportParams: + case params := <-h.extHandler.TransportParameters(): h.handleParamsCallback(¶ms) case <-h.handshakeErrChan: return false diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 38d8e4a6..ecec76e4 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -25,6 +25,7 @@ type Sealer interface { type tlsExtensionHandler interface { GetExtensions(msgType uint8) []qtls.Extension ReceivedExtensions(msgType uint8, exts []qtls.Extension) error + TransportParameters() <-chan TransportParameters } // CryptoSetup handles the handshake and protecting / unprotecting packets diff --git a/internal/handshake/tls_extension_handler_client.go b/internal/handshake/tls_extension_handler_client.go index 5e4e1c4d..ce1f914d 100644 --- a/internal/handshake/tls_extension_handler_client.go +++ b/internal/handshake/tls_extension_handler_client.go @@ -12,7 +12,7 @@ import ( type extensionHandlerClient struct { ourParams *TransportParameters - paramsChan chan<- TransportParameters + paramsChan chan TransportParameters origConnID protocol.ConnectionID initialVersion protocol.VersionNumber @@ -32,20 +32,16 @@ func newExtensionHandlerClient( supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, logger utils.Logger, -) (tlsExtensionHandler, <-chan TransportParameters) { - // The client reads the transport parameters from the Encrypted Extensions message. - // The paramsChan is used in the session's run loop's select statement. - // We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately. - paramsChan := make(chan TransportParameters) +) tlsExtensionHandler { return &extensionHandlerClient{ ourParams: params, - paramsChan: paramsChan, + paramsChan: make(chan TransportParameters), origConnID: origConnID, initialVersion: initialVersion, supportedVersions: supportedVersions, version: version, logger: logger, - }, paramsChan + } } func (h *extensionHandlerClient) GetExtensions(msgType uint8) []qtls.Extension { @@ -111,3 +107,7 @@ func (h *extensionHandlerClient) ReceivedExtensions(msgType uint8, exts []qtls.E h.paramsChan <- params return nil } + +func (h *extensionHandlerClient) TransportParameters() <-chan TransportParameters { + return h.paramsChan +} diff --git a/internal/handshake/tls_extension_handler_client_test.go b/internal/handshake/tls_extension_handler_client_test.go index 66fc150e..63b4dfa5 100644 --- a/internal/handshake/tls_extension_handler_client_test.go +++ b/internal/handshake/tls_extension_handler_client_test.go @@ -13,15 +13,12 @@ import ( ) var _ = Describe("TLS Extension Handler, for the client", func() { - var ( - handler *extensionHandlerClient - paramsChan <-chan TransportParameters - ) + var handler *extensionHandlerClient version := protocol.VersionNumber(0x42) BeforeEach(func() { var h tlsExtensionHandler - h, paramsChan = newExtensionHandlerClient( + h = newExtensionHandlerClient( &TransportParameters{}, nil, version, @@ -83,7 +80,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() { }() var params TransportParameters Consistently(done).ShouldNot(BeClosed()) - Expect(paramsChan).To(Receive(¶ms)) + Expect(handler.TransportParameters()).To(Receive(¶ms)) Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second)) Eventually(done).Should(BeClosed()) }) @@ -133,7 +130,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - Eventually(paramsChan).Should(Receive()) + Eventually(handler.TransportParameters()).Should(Receive()) close(done) }() @@ -206,7 +203,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - Eventually(paramsChan).Should(Receive()) + Eventually(handler.TransportParameters()).Should(Receive()) close(done) }() diff --git a/internal/handshake/tls_extension_handler_server.go b/internal/handshake/tls_extension_handler_server.go index 732f8dc0..919475fb 100644 --- a/internal/handshake/tls_extension_handler_server.go +++ b/internal/handshake/tls_extension_handler_server.go @@ -11,7 +11,7 @@ import ( type extensionHandlerServer struct { ourParams *TransportParameters - paramsChan chan<- TransportParameters + paramsChan chan TransportParameters version protocol.VersionNumber supportedVersions []protocol.VersionNumber @@ -27,17 +27,14 @@ func newExtensionHandlerServer( supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, logger utils.Logger, -) (tlsExtensionHandler, <-chan TransportParameters) { - // Processing the ClientHello is performed statelessly (and from a single go-routine). - // Therefore, we have to use a buffered chan to pass the transport parameters to that go routine. - paramsChan := make(chan TransportParameters) +) tlsExtensionHandler { return &extensionHandlerServer{ ourParams: params, - paramsChan: paramsChan, + paramsChan: make(chan TransportParameters), supportedVersions: supportedVersions, version: version, logger: logger, - }, paramsChan + } } func (h *extensionHandlerServer) GetExtensions(msgType uint8) []qtls.Extension { @@ -84,3 +81,7 @@ func (h *extensionHandlerServer) ReceivedExtensions(msgType uint8, exts []qtls.E h.paramsChan <- chtp.Parameters return nil } + +func (h *extensionHandlerServer) TransportParameters() <-chan TransportParameters { + return h.paramsChan +} diff --git a/internal/handshake/tls_extension_handler_server_test.go b/internal/handshake/tls_extension_handler_server_test.go index e64c40c2..b08d13c2 100644 --- a/internal/handshake/tls_extension_handler_server_test.go +++ b/internal/handshake/tls_extension_handler_server_test.go @@ -12,14 +12,11 @@ import ( ) var _ = Describe("TLS Extension Handler, for the server", func() { - var ( - handler *extensionHandlerServer - paramsChan <-chan TransportParameters - ) + var handler *extensionHandlerServer BeforeEach(func() { var h tlsExtensionHandler - h, paramsChan = newExtensionHandlerServer( + h = newExtensionHandlerServer( &TransportParameters{}, nil, protocol.VersionWhatever, @@ -79,7 +76,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() { }() var params TransportParameters Consistently(done).ShouldNot(BeClosed()) - Expect(paramsChan).To(Receive(¶ms)) + Expect(handler.TransportParameters()).To(Receive(¶ms)) Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second)) Eventually(done).Should(BeClosed()) }) @@ -103,7 +100,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - <-paramsChan + <-handler.TransportParameters() close(done) }() handler.version = 42 @@ -123,7 +120,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - <-paramsChan + <-handler.TransportParameters() close(done) }() handler.version = 42