use a struct to pass callbacks from the session to the crypto setup

This commit is contained in:
Marten Seemann 2019-05-31 15:45:21 +08:00
parent 743868159f
commit 4fd6a7cc99
6 changed files with 143 additions and 66 deletions

View file

@ -72,11 +72,9 @@ type cryptoSetup struct {
messageChan chan []byte
paramsChan <-chan []byte
handleParamsCallback func([]byte)
paramsChan <-chan []byte
dropKeyCallback func(protocol.EncryptionLevel)
closeCallback func(error)
runner handshakeRunner
alertChan chan uint8
// HandleData() sends errors on the messageErrChan
@ -131,9 +129,7 @@ func NewCryptoSetupClient(
connID protocol.ConnectionID,
remoteAddr net.Addr,
tp *TransportParameters,
handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
close func(error),
runner handshakeRunner,
tlsConf *tls.Config,
logger utils.Logger,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
@ -143,9 +139,7 @@ func NewCryptoSetupClient(
oneRTTStream,
connID,
tp,
handleParams,
dropKeys,
close,
runner,
tlsConf,
logger,
protocol.PerspectiveClient,
@ -165,9 +159,7 @@ func NewCryptoSetupServer(
connID protocol.ConnectionID,
remoteAddr net.Addr,
tp *TransportParameters,
handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
close func(error),
runner handshakeRunner,
tlsConf *tls.Config,
logger utils.Logger,
) (CryptoSetup, error) {
@ -177,9 +169,7 @@ func NewCryptoSetupServer(
oneRTTStream,
connID,
tp,
handleParams,
dropKeys,
close,
runner,
tlsConf,
logger,
protocol.PerspectiveServer,
@ -197,9 +187,7 @@ func newCryptoSetup(
oneRTTStream io.Writer,
connID protocol.ConnectionID,
tp *TransportParameters,
handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
close func(error),
runner handshakeRunner,
tlsConf *tls.Config,
logger utils.Logger,
perspective protocol.Perspective,
@ -217,9 +205,7 @@ func newCryptoSetup(
oneRTTStream: oneRTTStream,
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
handleParamsCallback: handleParams,
dropKeyCallback: dropKeys,
closeCallback: close,
runner: runner,
paramsChan: extHandler.TransportParameters(),
logger: logger,
perspective: perspective,
@ -254,7 +240,7 @@ func (h *cryptoSetup) Received1RTTAck() {
if h.initialOpener != nil {
h.initialOpener = nil
h.initialSealer = nil
h.dropKeyCallback(protocol.EncryptionInitial)
h.runner.DropKeys(protocol.EncryptionInitial)
h.logger.Debugf("Dropping Initial keys.")
}
// drop handshake keys
@ -262,7 +248,7 @@ func (h *cryptoSetup) Received1RTTAck() {
h.handshakeOpener = nil
h.handshakeSealer = nil
h.logger.Debugf("Dropping Handshake keys.")
h.dropKeyCallback(protocol.EncryptionHandshake)
h.runner.DropKeys(protocol.EncryptionHandshake)
}
}
@ -286,14 +272,14 @@ func (h *cryptoSetup) RunHandshake() {
close(h.messageChan)
case alert := <-h.alertChan:
handshakeErr := <-handshakeErrChan
h.closeCallback(qerr.CryptoError(alert, handshakeErr.Error()))
h.runner.OnError(qerr.CryptoError(alert, handshakeErr.Error()))
case err := <-h.messageErrChan:
// If the handshake errored because of an error that occurred during HandleData(),
// that error message will be more useful than the error message generated by Handshake().
// Close the message chan that qtls is receiving messages from.
// This will make qtls.Handshake() return.
close(h.messageChan)
h.closeCallback(err)
h.runner.OnError(err)
}
}
@ -358,7 +344,7 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
h.logger.Debugf("Sending HelloRetryRequest")
return false
case data := <-h.paramsChan:
h.handleParamsCallback(data)
h.runner.OnReceivedParams(data)
case <-h.handshakeDone:
return false
}
@ -423,7 +409,7 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
case typeEncryptedExtensions:
select {
case data := <-h.paramsChan:
h.handleParamsCallback(data)
h.runner.OnReceivedParams(data)
case <-h.handshakeDone:
return false
}

View file

@ -12,6 +12,7 @@ import (
"math/big"
"time"
gomock "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata"
@ -87,9 +88,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(err error) { Fail("error callback called") },
NewMockHandshakeRunner(mockCtrl),
tlsConf,
utils.DefaultLogger.WithPrefix("server"),
)
@ -110,6 +109,8 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when an error occurs", func() {
sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupServer(
sInitialStream,
@ -118,9 +119,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
runner,
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@ -148,6 +147,8 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when a message is received at the wrong encryption level", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
@ -155,9 +156,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
runner,
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@ -185,6 +184,8 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when handling a message fails", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
@ -192,9 +193,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
runner,
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@ -220,7 +219,6 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("returns Handshake() when it is closed", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupServer(
sInitialStream,
@ -229,9 +227,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
NewMockHandshakeRunner(mockCtrl),
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@ -241,7 +237,6 @@ var _ = Describe("Crypto Setup TLS", func() {
go func() {
defer GinkgoRecover()
server.RunHandshake()
Consistently(sErrChan).ShouldNot(Receive())
close(done)
}()
Expect(server.Close()).To(Succeed())
@ -298,6 +293,9 @@ var _ = Describe("Crypto Setup TLS", func() {
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cErrChan := make(chan error, 1)
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
@ -305,9 +303,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { cErrChan <- e },
cRunner,
clientConf,
utils.DefaultLogger.WithPrefix("client"),
)
@ -315,6 +311,9 @@ var _ = Describe("Crypto Setup TLS", func() {
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1)
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
var token [16]byte
server, err := NewCryptoSetupServer(
sInitialStream,
@ -323,9 +322,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
&TransportParameters{StatelessResetToken: &token},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(e error) { sErrChan <- e },
sRunner,
serverConf,
utils.DefaultLogger.WithPrefix("server"),
)
@ -369,6 +366,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("signals when it has written the ClientHello", func() {
runner := NewMockHandshakeRunner(mockCtrl)
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, chChan, err := NewCryptoSetupClient(
cInitialStream,
@ -377,9 +375,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
func(error) {},
runner,
&tls.Config{InsecureSkipVerify: true},
utils.DefaultLogger.WithPrefix("client"),
)
@ -401,6 +397,7 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(len(ch.data) - 4).To(Equal(length))
// make the go routine return
runner.EXPECT().OnError(gomock.Any())
client.HandleMessage([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed())
})
@ -409,6 +406,8 @@ var _ = Describe("Crypto Setup TLS", func() {
var cTransportParametersRcvd, sTransportParametersRcvd []byte
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b })
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
@ -416,9 +415,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
cTransportParameters,
func(p []byte) { sTransportParametersRcvd = p },
func(protocol.EncryptionLevel) {},
func(error) { Fail("error callback called") },
cRunner,
clientConf,
utils.DefaultLogger.WithPrefix("client"),
)
@ -426,6 +423,8 @@ var _ = Describe("Crypto Setup TLS", func() {
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
var token [16]byte
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { cTransportParametersRcvd = b })
sTransportParameters := &TransportParameters{
IdleTimeout: 0x1337 * time.Second,
StatelessResetToken: &token,
@ -437,9 +436,7 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
nil,
sTransportParameters,
func(p []byte) { cTransportParametersRcvd = p },
func(protocol.EncryptionLevel) {},
func(error) { Fail("error callback called") },
sRunner,
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)

View file

@ -28,6 +28,12 @@ type tlsExtensionHandler interface {
TransportParameters() <-chan []byte
}
type handshakeRunner interface {
OnReceivedParams([]byte)
OnError(error)
DropKeys(protocol.EncryptionLevel)
}
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
RunHandshake()

View file

@ -0,0 +1,71 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: HandshakeRunner)
// Package handshake is a generated GoMock package.
package handshake
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockHandshakeRunner is a mock of HandshakeRunner interface
type MockHandshakeRunner struct {
ctrl *gomock.Controller
recorder *MockHandshakeRunnerMockRecorder
}
// MockHandshakeRunnerMockRecorder is the mock recorder for MockHandshakeRunner
type MockHandshakeRunnerMockRecorder struct {
mock *MockHandshakeRunner
}
// NewMockHandshakeRunner creates a new mock instance
func NewMockHandshakeRunner(ctrl *gomock.Controller) *MockHandshakeRunner {
mock := &MockHandshakeRunner{ctrl: ctrl}
mock.recorder = &MockHandshakeRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockHandshakeRunner) EXPECT() *MockHandshakeRunnerMockRecorder {
return m.recorder
}
// DropKeys mocks base method
func (m *MockHandshakeRunner) DropKeys(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DropKeys", arg0)
}
// DropKeys indicates an expected call of DropKeys
func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0)
}
// OnError mocks base method
func (m *MockHandshakeRunner) OnError(arg0 error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnError", arg0)
}
// OnError indicates an expected call of OnError
func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0)
}
// OnReceivedParams mocks base method
func (m *MockHandshakeRunner) OnReceivedParams(arg0 []byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnReceivedParams", arg0)
}
// OnReceivedParams indicates an expected call of OnReceivedParams
func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0)
}

View file

@ -0,0 +1,3 @@
package handshake
//go:generate sh -c "../mockgen_internal.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner"

View file

@ -71,6 +71,16 @@ func (p *receivedPacket) Clone() *receivedPacket {
}
}
type handshakeRunner struct {
onReceivedParams func([]byte)
onError func(error)
dropKeys func(protocol.EncryptionLevel)
}
func (r *handshakeRunner) OnReceivedParams(b []byte) { r.onReceivedParams(b) }
func (r *handshakeRunner) OnError(e error) { r.onError(e) }
func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) }
type closeError struct {
err error
remote bool
@ -198,9 +208,11 @@ var newSession = func(
clientDestConnID,
conn.RemoteAddr(),
params,
s.processTransportParameters,
s.dropEncryptionLevel,
s.closeLocal,
&handshakeRunner{
onReceivedParams: s.processTransportParameters,
onError: s.closeLocal,
dropKeys: s.dropEncryptionLevel,
},
tlsConf,
logger,
)
@ -268,9 +280,11 @@ var newClientSession = func(
s.destConnID,
conn.RemoteAddr(),
params,
s.processTransportParameters,
s.dropEncryptionLevel,
s.closeLocal,
&handshakeRunner{
onReceivedParams: s.processTransportParameters,
onError: s.closeLocal,
dropKeys: s.dropEncryptionLevel,
},
tlsConf,
logger,
)